"""
Resume FLUX.2-klein-4B LoRA training from the step-500 checkpoint.
Runs on Hugging Face Jobs infrastructure.
"""
|
|
import os
import subprocess
import sys
from pathlib import Path

import torch
from huggingface_hub import hf_hub_download, snapshot_download, create_repo, upload_folder
|
|
| |
# Hugging Face Hub repositories:
#   CHECKPOINT_REPO - LoRA weights saved at step 500, used to resume training.
#   DATASET_REPO    - training images + captions (dataset repo).
#   OUTPUT_REPO     - destination for the finished LoRA (trainer pushes here).
CHECKPOINT_REPO = "Limbicnation/sprite-lora-checkpoint-step500"
DATASET_REPO = "Limbicnation/sprite-lora-training-data"
OUTPUT_REPO = "Limbicnation/sprite-lora-final"
|
|
def main():
    """Resume FLUX.2-klein-4B LoRA training from step 500 and run to step 1000.

    Downloads the step-500 checkpoint and the training dataset from the Hub,
    clones the trainer repository, builds the training config, and launches
    training. The finished LoRA is pushed to ``OUTPUT_REPO`` by the trainer.

    Raises:
        RuntimeError: if the downloaded dataset contains no .png images.
        subprocess.CalledProcessError: if cloning the trainer repo fails.
    """
    print("=" * 70)
    print("🚀 FLUX.2-klein-4B LoRA Training (Resuming from Step 500)")
    print("=" * 70)

    # Fetch the step-500 LoRA weights. local_dir must match
    # resume_from_checkpoint in the config below.
    print("\n📥 Downloading checkpoint from Hugging Face Hub...")
    checkpoint_path = hf_hub_download(
        repo_id=CHECKPOINT_REPO,
        filename="pytorch_lora_weights.safetensors",
        repo_type="model",
        local_dir="./checkpoint_step500",
    )
    print(f"   ✅ Checkpoint downloaded: {checkpoint_path}")

    print("\n📥 Downloading training dataset...")
    dataset_path = snapshot_download(
        repo_id=DATASET_REPO,
        repo_type="dataset",
        local_dir="./training_data",
    )
    print(f"   ✅ Dataset downloaded to: {dataset_path}")

    image_files = list(Path(dataset_path).rglob("*.png"))
    print(f"   Found {len(image_files)} training images")
    # Fail fast rather than launching a GPU job against an empty dataset.
    if not image_files:
        raise RuntimeError(f"No .png training images found under {dataset_path}")

    print("\n🏗️ Setting up trainer...")
    # Clone the trainer repo unless it is already present. subprocess.run with
    # an argument list (shell=False) replaces the old
    # os.system("git clone ... || true"), which both invited shell quoting
    # issues and silently discarded every failure mode.
    if not Path("./klein-lora-trainer").exists():
        subprocess.run(
            ["git", "clone", "https://github.com/Limbicnation/klein-lora-trainer.git"],
            check=True,
        )
    sys.path.insert(0, "./klein-lora-trainer")

    # Imported here because the package only exists after the clone above.
    from flux2_klein_trainer.config import TrainingConfig, ModelConfig, LoRAConfig, DatasetConfig
    from flux2_klein_trainer.trainer import KleinLoRATrainer

    config = TrainingConfig(
        model=ModelConfig(
            pretrained_model_name="black-forest-labs/FLUX.2-klein-4B",
            dtype="bfloat16",
            enable_cpu_offload=True,
        ),
        lora=LoRAConfig(
            rank=64,
            alpha=128,
        ),
        dataset=DatasetConfig(
            data_dir="./training_data/images",
            caption_ext="txt",
            resolution=512,
        ),
        output_dir="./output/sprite_lora_final",
        resume_from_checkpoint="./checkpoint_step500",
        num_train_steps=1000,
        batch_size=1,
        gradient_accumulation_steps=4,
        learning_rate=1e-4,
        optimizer="adamw_8bit",
        save_every=500,
        sample_every=500,
        trigger_word="pixel art sprite",
        push_to_hub=True,
        hub_model_id=OUTPUT_REPO,
    )

    # Report from the config object itself so this summary can never drift
    # out of sync with the values actually used for training.
    print("\n📊 Training Configuration:")
    print("   Resume from: Step 500")
    print(f"   Target steps: {config.num_train_steps}")
    print(f"   Batch size: {config.batch_size}")
    print(f"   LoRA rank: {config.lora.rank}")
    print(f"   Learning rate: {config.learning_rate}")
    print(f"   Dataset: {len(image_files)} images")

    # Ensure the destination model repo exists before the trainer pushes to it.
    print(f"\n🤗 Output will be pushed to: {OUTPUT_REPO}")
    create_repo(OUTPUT_REPO, exist_ok=True, repo_type="model")

    print("\n" + "=" * 70)
    print("🏗️ Starting Training")
    print("=" * 70)

    trainer = KleinLoRATrainer(config)
    trainer.train()

    print("\n" + "=" * 70)
    print("✅ Training Complete!")
    print("=" * 70)
    print(f"\n🤗 Final model saved to: {OUTPUT_REPO}")
    print(f"   https://huggingface.co/{OUTPUT_REPO}")
|
|
# Script entry point: run the full resume-and-train pipeline.
if __name__ == "__main__":
    main()
|
|