Limbicnation committed on
Commit
024b957
·
verified ·
1 Parent(s): ea5603a

Upload sprite_lora_resume_v2.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. sprite_lora_resume_v2.py +106 -0
sprite_lora_resume_v2.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "torch>=2.0.0",
5
+ # "diffusers>=0.25.0",
6
+ # "transformers>=4.35.0",
7
+ # "accelerate>=0.24.0",
8
+ # "peft>=0.7.0",
9
+ # "bitsandbytes>=0.41.0",
10
+ # "huggingface-hub>=0.20.0",
11
+ # "safetensors>=0.4.0",
12
+ # "omegaconf>=2.3.0",
13
+ # "Pillow>=10.0.0",
14
+ # "numpy>=1.24.0",
15
+ # "tqdm>=4.66.0",
16
+ # ]
17
+ # ///
18
+
19
+ """
20
+ Resume FLUX.2-klein-4B LoRA training from step 500 checkpoint.
21
+ Output: Limbicnation/pixel-art-lora
22
+ """
23
+
24
+ import os
25
+ import sys
26
+ import torch
27
+ from pathlib import Path
28
+ from huggingface_hub import hf_hub_download, snapshot_download, create_repo, upload_folder
29
+
30
+ CHECKPOINT_REPO = "Limbicnation/sprite-lora-checkpoint-step500"
31
+ DATASET_REPO = "Limbicnation/sprite-lora-training-data"
32
+ OUTPUT_REPO = "Limbicnation/pixel-art-lora"
33
+
34
def main():
    """Resume FLUX.2-klein-4B LoRA training from the step-500 checkpoint.

    Steps: download the LoRA checkpoint and the training dataset from the
    Hub, clone the klein-lora-trainer repo onto sys.path, build the training
    config (resuming from ./checkpoint_step500), create the output repo, and
    run training. The trainer pushes the result to OUTPUT_REPO.

    Raises:
        RuntimeError: if the downloaded dataset contains no .png images.
    """
    # Local import: only the git-clone step below needs it.
    import subprocess

    print("="*70)
    print("🚀 FLUX.2-klein-4B LoRA Training (Resuming from Step 500)")
    print("="*70)

    # Download checkpoint
    print("\n📥 Downloading checkpoint...")
    checkpoint_path = hf_hub_download(
        repo_id=CHECKPOINT_REPO,
        filename="pytorch_lora_weights.safetensors",
        repo_type="model",
        local_dir="./checkpoint_step500",
    )
    print(f" ✅ Checkpoint: {checkpoint_path}")

    # Download dataset
    print("\n📥 Downloading dataset...")
    dataset_path = snapshot_download(
        repo_id=DATASET_REPO,
        repo_type="dataset",
        local_dir="./training_data",
    )
    image_files = list(Path(dataset_path).rglob("*.png"))
    # Fail fast instead of starting a long training run on an empty dataset.
    if not image_files:
        raise RuntimeError(f"No .png images found under {dataset_path}")
    print(f" ✅ Dataset: {len(image_files)} images")

    # Clone trainer. Best-effort, as in the original: a non-zero exit
    # (e.g. the directory already exists) is deliberately ignored, but we
    # use an argument list instead of a shell string (no shell injection,
    # no silent loss of *every* kind of error).
    print("\n📥 Setting up trainer...")
    subprocess.run(
        ["git", "clone", "https://github.com/Limbicnation/klein-lora-trainer.git"],
        stderr=subprocess.DEVNULL,
        check=False,
    )
    sys.path.insert(0, "./klein-lora-trainer")

    # Imports must come after the sys.path insertion above.
    from flux2_klein_trainer.config import TrainingConfig, ModelConfig, LoRAConfig, DatasetConfig
    from flux2_klein_trainer.trainer import KleinLoRATrainer

    # Config — rank-64 LoRA at 512px, resuming at step 500 of 1000, with
    # an effective batch size of 4 (1 x 4 gradient accumulation).
    config = TrainingConfig(
        model=ModelConfig(
            pretrained_model_name="black-forest-labs/FLUX.2-klein-4B",
            dtype="bfloat16",
            enable_cpu_offload=True,
        ),
        lora=LoRAConfig(rank=64, alpha=128),
        dataset=DatasetConfig(
            data_dir="./training_data/images",
            caption_ext="txt",
            resolution=512,
        ),
        output_dir="./output",
        resume_from_checkpoint="./checkpoint_step500",
        num_train_steps=1000,
        batch_size=1,
        gradient_accumulation_steps=4,
        learning_rate=1e-4,
        optimizer="adamw_8bit",
        save_every=500,
        sample_every=500,
        trigger_word="pixel art sprite",
        push_to_hub=True,
        hub_model_id=OUTPUT_REPO,
    )

    print(f"\n📤 Output: {OUTPUT_REPO}")
    # Idempotent: exist_ok=True so a re-run does not fail.
    create_repo(OUTPUT_REPO, exist_ok=True, repo_type="model")

    # Train
    print("\n🏋️ Starting Training...")
    trainer = KleinLoRATrainer(config)
    trainer.train()

    print("\n✅ Complete!")
    print(f"📤 Model saved to: {OUTPUT_REPO}")
104
+
105
+ if __name__ == "__main__":
106
+ main()