|
"""
Resume FLUX LoRA training from the step-500 checkpoint.

Uses the standard FluxPipeline from diffusers (patched into the vendored
klein-lora-trainer at runtime).
Output repo: Limbicnation/pixel-art-lora
"""

import os
import sys
from pathlib import Path

import torch
from huggingface_hub import (
    HfApi,
    create_repo,
    hf_hub_download,
    snapshot_download,
    upload_folder,
)
|
|
# Hugging Face Hub repository IDs used throughout the script.
CHECKPOINT_REPO = "Limbicnation/sprite-lora-checkpoint-step500"  # LoRA weights saved at step 500
DATASET_REPO = "Limbicnation/sprite-lora-training-data"  # training images (+ caption .txt files)
OUTPUT_REPO = "Limbicnation/pixel-art-lora"  # destination repo for the finished LoRA
|
|
def main():
    """Resume FLUX LoRA training from the step-500 checkpoint and push the result to the Hub.

    Steps:
      1. Download the step-500 LoRA checkpoint from ``CHECKPOINT_REPO``.
      2. Download the training dataset snapshot from ``DATASET_REPO``.
      3. Clone the klein-lora-trainer repo and patch it to use ``FluxPipeline``.
      4. Build a ``TrainingConfig`` that resumes from the checkpoint and train
         to step 1000, pushing the result to ``OUTPUT_REPO``.

    Side effects: network downloads, a git clone, local file writes, and an
    upload to the Hugging Face Hub. Requires HF credentials in the environment.
    """
    banner = "=" * 70
    print(banner)
    print("🚀 FLUX LoRA Training (Resuming from Step 500)")
    print(banner)

    # --- 1. Checkpoint ----------------------------------------------------
    print("\n📥 Downloading checkpoint...")
    os.makedirs("./checkpoint_step500", exist_ok=True)
    checkpoint_path = hf_hub_download(
        repo_id=CHECKPOINT_REPO,
        filename="pytorch_lora_weights.safetensors",
        repo_type="model",
        local_dir="./checkpoint_step500",
    )
    print(f"   ✅ Checkpoint: {checkpoint_path}")

    # --- 2. Dataset -------------------------------------------------------
    print("\n📥 Downloading dataset...")
    dataset_path = snapshot_download(
        repo_id=DATASET_REPO,
        repo_type="dataset",
        local_dir="./training_data",
    )
    image_files = list(Path(dataset_path).rglob("*.png"))
    print(f"   ✅ Dataset: {len(image_files)} images")

    # --- 3. Trainer setup -------------------------------------------------
    print("\n📥 Setting up trainer...")
    # NOTE(review): static, trusted command string; `|| true` keeps a
    # pre-existing clone from aborting the script. os.system is kept because
    # the redirect and `||` are shell features; do not pass untrusted input here.
    os.system("git clone https://github.com/Limbicnation/klein-lora-trainer.git 2>/dev/null || true")

    # Patch the vendored trainer so it uses the standard diffusers FluxPipeline
    # instead of the (unavailable) Flux2KleinPipeline.
    trainer_file = Path("./klein-lora-trainer/flux2_klein_trainer/trainer.py")
    if trainer_file.exists():
        content = trainer_file.read_text()
        content = content.replace(
            "from diffusers import Flux2KleinPipeline",
            "from diffusers import FluxPipeline",
        )
        content = content.replace("Flux2KleinPipeline", "FluxPipeline")
        trainer_file.write_text(content)
        print("   ✅ Fixed imports in trainer.py")

    sys.path.insert(0, "./klein-lora-trainer")

    # Imported lazily: the package only exists after the clone above.
    from flux2_klein_trainer.config import TrainingConfig, ModelConfig, LoRAConfig, DatasetConfig
    from flux2_klein_trainer.trainer import KleinLoRATrainer

    # --- 4. Training configuration ---------------------------------------
    config = TrainingConfig(
        model=ModelConfig(
            pretrained_model_name="black-forest-labs/FLUX.1-dev",
            dtype="bfloat16",
            enable_cpu_offload=True,  # fit FLUX.1-dev on limited VRAM
        ),
        lora=LoRAConfig(rank=64, alpha=128),
        dataset=DatasetConfig(
            data_dir="./training_data/images",
            caption_ext="txt",
            resolution=512,
        ),
        output_dir="./output",
        resume_from_checkpoint="./checkpoint_step500",
        num_train_steps=1000,  # total steps; ~500 remain after the resume
        batch_size=1,
        gradient_accumulation_steps=4,  # effective batch size of 4
        learning_rate=1e-4,
        optimizer="adamw_8bit",
        save_every=500,
        sample_every=500,
        trigger_word="pixel art sprite",
        push_to_hub=True,
        hub_model_id=OUTPUT_REPO,
    )

    print(f"\n🤗 Output: {OUTPUT_REPO}")
    # exist_ok=True makes repo creation idempotent across reruns.
    create_repo(OUTPUT_REPO, exist_ok=True, repo_type="model")

    # --- Train ------------------------------------------------------------
    print("\n🏋️ Starting Training...")
    trainer = KleinLoRATrainer(config)
    trainer.train()

    print("\n" + banner)
    print("✅ Training Complete!")
    print(banner)
    print(f"\n🤗 Model saved to: {OUTPUT_REPO}")
    print(f"   https://huggingface.co/{OUTPUT_REPO}")
|
|
# Script entry point: run the training flow only when executed directly.
if __name__ == "__main__":
    main()
|
|