Limbicnation committed on
Commit
5038a7c
·
verified ·
1 Parent(s): 024b957

Upload sprite_lora_resume_v3.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. sprite_lora_resume_v3.py +123 -0
sprite_lora_resume_v3.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "torch>=2.0.0",
5
+ # "diffusers>=0.25.0",
6
+ # "transformers>=4.35.0",
7
+ # "accelerate>=0.24.0",
8
+ # "peft>=0.7.0",
9
+ # "bitsandbytes>=0.41.0",
10
+ # "huggingface-hub>=0.20.0",
11
+ # "safetensors>=0.4.0",
12
+ # "omegaconf>=2.3.0",
13
+ # "Pillow>=10.0.0",
14
+ # "numpy>=1.24.0",
15
+ # "tqdm>=4.66.0",
16
+ # ]
17
+ # ///
18
+
19
+ """
20
+ Resume FLUX LoRA training from step 500 checkpoint.
21
+ Uses standard FluxPipeline from diffusers.
22
+ Output: Limbicnation/pixel-art-lora
23
+ """
24
+
25
import os
import subprocess
import sys
from pathlib import Path

import torch
from huggingface_hub import HfApi, create_repo, hf_hub_download, snapshot_download, upload_folder
30
+
31
# Hub model repo holding the step-500 LoRA weights this run resumes from.
CHECKPOINT_REPO = "Limbicnation/sprite-lora-checkpoint-step500"
# Hub dataset repo containing the training images (PNGs; captions presumably
# sit alongside as .txt files — see caption_ext in the config below).
DATASET_REPO = "Limbicnation/sprite-lora-training-data"
# Destination Hub repo the trained LoRA is pushed to.
OUTPUT_REPO = "Limbicnation/pixel-art-lora"
35
def _download_checkpoint() -> str:
    """Download the step-500 LoRA checkpoint file and return its local path."""
    print("\n📥 Downloading checkpoint...")
    os.makedirs("./checkpoint_step500", exist_ok=True)
    checkpoint_path = hf_hub_download(
        repo_id=CHECKPOINT_REPO,
        filename="pytorch_lora_weights.safetensors",
        repo_type="model",
        local_dir="./checkpoint_step500",
    )
    print(f" ✅ Checkpoint: {checkpoint_path}")
    return checkpoint_path


def _download_dataset() -> str:
    """Download the full training-data snapshot and return its local directory."""
    print("\n📥 Downloading dataset...")
    dataset_path = snapshot_download(
        repo_id=DATASET_REPO,
        repo_type="dataset",
        local_dir="./training_data",
    )
    image_files = list(Path(dataset_path).rglob("*.png"))
    print(f" ✅ Dataset: {len(image_files)} images")
    return dataset_path


def _setup_trainer() -> None:
    """Clone the trainer repo (best-effort) and patch it to use FluxPipeline.

    The upstream trainer imports ``Flux2KleinPipeline``; this rewrites the
    source to use the standard ``FluxPipeline`` from diffusers, then puts the
    clone on ``sys.path`` so it can be imported.
    """
    print("\n📥 Setting up trainer...")
    # Best-effort clone: if the directory already exists, git fails and we
    # carry on (check=False, stderr discarded). An argument list with
    # subprocess.run avoids building a shell command string (the original
    # used os.system with "2>/dev/null || true").
    subprocess.run(
        ["git", "clone", "https://github.com/Limbicnation/klein-lora-trainer.git"],
        check=False,
        stderr=subprocess.DEVNULL,
    )

    # Replace Flux2KleinPipeline with FluxPipeline in trainer.py.
    trainer_file = Path("./klein-lora-trainer/flux2_klein_trainer/trainer.py")
    if trainer_file.exists():
        content = trainer_file.read_text()
        content = content.replace("from diffusers import Flux2KleinPipeline", "from diffusers import FluxPipeline")
        content = content.replace("Flux2KleinPipeline", "FluxPipeline")
        trainer_file.write_text(content)
        print(" ✅ Fixed imports in trainer.py")

    sys.path.insert(0, "./klein-lora-trainer")


def main():
    """Resume FLUX LoRA training from the step-500 checkpoint and push the result.

    Steps: download checkpoint + dataset, clone and patch the trainer repo,
    build the training config, create the output Hub repo, then train.
    """
    print("="*70)
    print("🚀 FLUX LoRA Training (Resuming from Step 500)")
    print("="*70)

    _download_checkpoint()
    _download_dataset()
    _setup_trainer()

    # Import only after _setup_trainer() has cloned/patched the source and
    # added it to sys.path.
    from flux2_klein_trainer.config import TrainingConfig, ModelConfig, LoRAConfig, DatasetConfig
    from flux2_klein_trainer.trainer import KleinLoRATrainer

    config = TrainingConfig(
        model=ModelConfig(
            pretrained_model_name="black-forest-labs/FLUX.1-dev",  # Use standard FLUX
            dtype="bfloat16",
            enable_cpu_offload=True,
        ),
        lora=LoRAConfig(rank=64, alpha=128),
        dataset=DatasetConfig(
            data_dir="./training_data/images",
            caption_ext="txt",
            resolution=512,
        ),
        output_dir="./output",
        resume_from_checkpoint="./checkpoint_step500",
        num_train_steps=1000,
        batch_size=1,
        gradient_accumulation_steps=4,
        learning_rate=1e-4,
        optimizer="adamw_8bit",
        save_every=500,
        sample_every=500,
        trigger_word="pixel art sprite",
        push_to_hub=True,
        hub_model_id=OUTPUT_REPO,
    )

    print(f"\n📤 Output: {OUTPUT_REPO}")
    create_repo(OUTPUT_REPO, exist_ok=True, repo_type="model")

    # Train
    print("\n🏋️ Starting Training...")
    trainer = KleinLoRATrainer(config)
    trainer.train()

    print("\n" + "="*70)
    print("✅ Training Complete!")
    print("="*70)
    print(f"\n📤 Model saved to: {OUTPUT_REPO}")
    print(f" https://huggingface.co/{OUTPUT_REPO}")
122
# Script entry point: run the full resume-training pipeline when executed
# directly (e.g. `uv run sprite_lora_resume_v3.py`), but not on import.
if __name__ == "__main__":
    main()