Limbicnation committed on
Commit
3c0c0a3
·
verified ·
1 Parent(s): 5038a7c

Upload sprite_lora_resume_v4.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. sprite_lora_resume_v4.py +124 -0
sprite_lora_resume_v4.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "torch>=2.0.0",
# "torchvision>=0.15.0",
# "diffusers>=0.25.0",
# "transformers>=4.35.0",
# "accelerate>=0.24.0",
# "peft>=0.7.0",
# "bitsandbytes>=0.41.0",
# "huggingface-hub>=0.20.0",
# "safetensors>=0.4.0",
# "omegaconf>=2.3.0",
# "Pillow>=10.0.0",
# "numpy>=1.24.0",
# "tqdm>=4.66.0",
# ]
# ///

"""
Resume FLUX LoRA training from step 500 checkpoint.
Uses standard FluxPipeline from diffusers.
Output: Limbicnation/pixel-art-lora
"""

import os
import sys
import torch
from pathlib import Path
from huggingface_hub import hf_hub_download, snapshot_download, create_repo, upload_folder, HfApi

# Hub model repo holding the step-500 LoRA weights this run resumes from.
CHECKPOINT_REPO = "Limbicnation/sprite-lora-checkpoint-step500"
# Hub dataset repo with the training images (captions read as sidecar .txt
# files — see caption_ext="txt" in the DatasetConfig built in main()).
DATASET_REPO = "Limbicnation/sprite-lora-training-data"
# Destination Hub repo the trained LoRA is pushed to (push_to_hub=True below).
OUTPUT_REPO = "Limbicnation/pixel-art-lora"
36
def main():
    """Resume FLUX LoRA training from the step-500 checkpoint and push to the Hub.

    Workflow:
      1. Download the step-500 LoRA checkpoint and the training dataset.
      2. Clone the klein-lora-trainer repo and patch it to use the standard
         diffusers ``FluxPipeline`` instead of ``Flux2KleinPipeline``.
      3. Build a ``TrainingConfig`` and run ``KleinLoRATrainer`` to step 1000.

    Side effects: creates ./checkpoint_step500, ./training_data, ./output and
    ./klein-lora-trainer locally; creates and pushes to OUTPUT_REPO on the Hub.
    Requires Hub credentials in the environment; trainer behavior beyond what
    is configured here is defined by the cloned repo (not visible from this
    file).
    """
    import subprocess  # local import: only needed for the git clone below

    print("="*70)
    print("🚀 FLUX LoRA Training (Resuming from Step 500)")
    print("="*70)

    # --- Download the step-500 LoRA checkpoint --------------------------------
    print("\n📥 Downloading checkpoint...")
    os.makedirs("./checkpoint_step500", exist_ok=True)
    checkpoint_path = hf_hub_download(
        repo_id=CHECKPOINT_REPO,
        filename="pytorch_lora_weights.safetensors",
        repo_type="model",
        local_dir="./checkpoint_step500",
    )
    print(f" ✅ Checkpoint: {checkpoint_path}")

    # --- Download the training dataset ----------------------------------------
    print("\n📥 Downloading dataset...")
    dataset_path = snapshot_download(
        repo_id=DATASET_REPO,
        repo_type="dataset",
        local_dir="./training_data",
    )
    image_files = list(Path(dataset_path).rglob("*.png"))
    print(f" ✅ Dataset: {len(image_files)} images")

    # --- Clone trainer repo (best-effort; fine if it already exists) ----------
    print("\n📥 Setting up trainer...")
    # subprocess.run with an argument list avoids shell string construction;
    # check=False + suppressed stderr reproduces the original
    # `git clone ... 2>/dev/null || true` behavior (ignore a failed clone,
    # e.g. when the directory already exists from a previous run).
    subprocess.run(
        ["git", "clone", "https://github.com/Limbicnation/klein-lora-trainer.git"],
        stderr=subprocess.DEVNULL,
        check=False,
    )

    # Patch trainer.py before importing it: swap the repo's Flux2KleinPipeline
    # references for the standard diffusers FluxPipeline.
    trainer_file = Path("./klein-lora-trainer/flux2_klein_trainer/trainer.py")
    if trainer_file.exists():
        content = trainer_file.read_text()
        content = content.replace(
            "from diffusers import Flux2KleinPipeline",
            "from diffusers import FluxPipeline",
        )
        content = content.replace("Flux2KleinPipeline", "FluxPipeline")
        trainer_file.write_text(content)
        print(" ✅ Fixed imports in trainer.py")

    sys.path.insert(0, "./klein-lora-trainer")

    # Import only after the patch above so the fixed module is what gets loaded.
    from flux2_klein_trainer.config import TrainingConfig, ModelConfig, LoRAConfig, DatasetConfig
    from flux2_klein_trainer.trainer import KleinLoRATrainer

    # --- Build the training configuration -------------------------------------
    config = TrainingConfig(
        model=ModelConfig(
            pretrained_model_name="black-forest-labs/FLUX.1-dev",  # Use standard FLUX
            dtype="bfloat16",
            enable_cpu_offload=True,
        ),
        lora=LoRAConfig(rank=64, alpha=128),
        dataset=DatasetConfig(
            data_dir="./training_data/images",
            caption_ext="txt",
            resolution=512,
        ),
        output_dir="./output",
        resume_from_checkpoint="./checkpoint_step500",
        num_train_steps=1000,
        batch_size=1,
        gradient_accumulation_steps=4,
        learning_rate=1e-4,
        optimizer="adamw_8bit",
        save_every=500,
        sample_every=500,
        trigger_word="pixel art sprite",
        push_to_hub=True,
        hub_model_id=OUTPUT_REPO,
    )

    print(f"\n📤 Output: {OUTPUT_REPO}")
    create_repo(OUTPUT_REPO, exist_ok=True, repo_type="model")

    # --- Train -----------------------------------------------------------------
    print("\n🏋️ Starting Training...")
    trainer = KleinLoRATrainer(config)
    trainer.train()

    print("\n" + "="*70)
    print("✅ Training Complete!")
    print("="*70)
    print(f"\n📤 Model saved to: {OUTPUT_REPO}")
    print(f" https://huggingface.co/{OUTPUT_REPO}")

if __name__ == "__main__":
    main()