Limbicnation committed on
Commit
ea5603a
·
verified ·
1 Parent(s): 010318d

Upload sprite_lora_resume.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. sprite_lora_resume.py +131 -0
sprite_lora_resume.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "torch>=2.0.0",
5
+ # "diffusers>=0.25.0",
6
+ # "transformers>=4.35.0",
7
+ # "accelerate>=0.24.0",
8
+ # "peft>=0.7.0",
9
+ # "bitsandbytes>=0.41.0",
10
+ # "huggingface-hub>=0.20.0",
11
+ # "safetensors>=0.4.0",
12
+ # "omegaconf>=2.3.0",
13
+ # "Pillow>=10.0.0",
14
+ # "numpy>=1.24.0",
15
+ # "tqdm>=4.66.0",
16
+ # ]
17
+ # ///
18
+
19
+ """
20
+ Resume FLUX.2-klein-4B LoRA training from step 500 checkpoint.
21
+ Runs on Hugging Face Jobs infrastructure.
22
+ """
23
+
24
import os
import subprocess
import sys
from pathlib import Path

import torch
from huggingface_hub import hf_hub_download, snapshot_download, create_repo, upload_folder
29
+
30
# Configuration
# Hub model repo holding the step-500 LoRA checkpoint to resume from.
CHECKPOINT_REPO = "Limbicnation/sprite-lora-checkpoint-step500"
# Hub dataset repo containing the training images (and, presumably,
# matching .txt captions — see caption_ext below; confirm repo layout).
DATASET_REPO = "Limbicnation/sprite-lora-training-data"
# Destination Hub repo for the final trained LoRA weights.
OUTPUT_REPO = "Limbicnation/sprite-lora-final"
34
+
35
def main():
    """Resume FLUX.2-klein-4B LoRA training from the step-500 checkpoint.

    Downloads the LoRA checkpoint and the training dataset from the
    Hugging Face Hub, clones the external trainer repo, builds a training
    config that resumes at step 500 and runs to step 1000, then trains and
    pushes the final weights to ``OUTPUT_REPO``.

    Side effects: network downloads, a local ``git clone``, local writes
    under ``./checkpoint_step500``, ``./training_data`` and ``./output``,
    and pushes to the Hub. Returns ``None``.
    """
    print("="*70)
    print("🚀 FLUX.2-klein-4B LoRA Training (Resuming from Step 500)")
    print("="*70)

    # Step 1: Download checkpoint
    print("\n📥 Downloading checkpoint from Hugging Face Hub...")
    checkpoint_path = hf_hub_download(
        repo_id=CHECKPOINT_REPO,
        filename="pytorch_lora_weights.safetensors",
        repo_type="model",
        local_dir="./checkpoint_step500",
    )
    print(f" ✅ Checkpoint downloaded: {checkpoint_path}")

    # Step 2: Download dataset
    print("\n📥 Downloading training dataset...")
    dataset_path = snapshot_download(
        repo_id=DATASET_REPO,
        repo_type="dataset",
        local_dir="./training_data",
    )
    print(f" ✅ Dataset downloaded to: {dataset_path}")

    # Count images — only PNGs are counted by this glob.
    image_files = list(Path(dataset_path).rglob("*.png"))
    print(f" Found {len(image_files)} training images")

    # Step 3: Setup and run training
    print("\n🏋️ Setting up trainer...")

    # Clone the trainer repo. Use subprocess with an argv list (no shell
    # string) instead of os.system, and skip the clone when the checkout
    # already exists — this mirrors the original best-effort
    # `git clone ... 2>/dev/null || true` without masking unrelated errors.
    if not Path("./klein-lora-trainer").exists():
        subprocess.run(
            ["git", "clone", "https://github.com/Limbicnation/klein-lora-trainer.git"],
            stderr=subprocess.DEVNULL,  # original redirected stderr to /dev/null
            check=False,  # best-effort, as with `|| true`
        )

    sys.path.insert(0, "./klein-lora-trainer")

    # Imported here (not at module top) because the package only exists on
    # disk after the clone above.
    from flux2_klein_trainer.config import TrainingConfig, ModelConfig, LoRAConfig, DatasetConfig
    from flux2_klein_trainer.trainer import KleinLoRATrainer

    # Build config.
    # NOTE(review): field names and semantics are assumed from the external
    # klein-lora-trainer repo — confirm against flux2_klein_trainer.config.
    config = TrainingConfig(
        model=ModelConfig(
            pretrained_model_name="black-forest-labs/FLUX.2-klein-4B",
            dtype="bfloat16",
            enable_cpu_offload=True,  # Low VRAM mode
        ),
        lora=LoRAConfig(
            rank=64,
            alpha=128,
        ),
        dataset=DatasetConfig(
            data_dir="./training_data/images",
            caption_ext="txt",
            resolution=512,
        ),
        output_dir="./output/sprite_lora_final",
        resume_from_checkpoint="./checkpoint_step500",
        num_train_steps=1000,  # Train 500 more steps (500 -> 1000)
        batch_size=1,
        gradient_accumulation_steps=4,
        learning_rate=1e-4,
        optimizer="adamw_8bit",
        save_every=500,
        sample_every=500,
        trigger_word="pixel art sprite",
        push_to_hub=True,
        hub_model_id=OUTPUT_REPO,
    )

    print("\n📋 Training Configuration:")
    # Plain strings here — the originals were f-strings with no placeholders.
    print(" Resume from: Step 500")
    print(" Target steps: 1000")
    print(" Batch size: 1")
    print(" LoRA rank: 64")
    print(" Learning rate: 1e-4")
    print(f" Dataset: {len(image_files)} images")

    # Create output repo (idempotent thanks to exist_ok=True).
    print(f"\n📤 Output will be pushed to: {OUTPUT_REPO}")
    create_repo(OUTPUT_REPO, exist_ok=True, repo_type="model")

    # Start training
    print("\n" + "="*70)
    print("🏋️ Starting Training")
    print("="*70)

    trainer = KleinLoRATrainer(config)
    trainer.train()

    print("\n" + "="*70)
    print("✅ Training Complete!")
    print("="*70)
    print(f"\n📤 Final model saved to: {OUTPUT_REPO}")
    print(f" https://huggingface.co/{OUTPUT_REPO}")
129
+
130
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()