Limbicnation commited on
Commit
8c747e6
·
verified ·
1 Parent(s): d015479

Upload sprite_lora_resume_v7.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. sprite_lora_resume_v7.py +330 -0
sprite_lora_resume_v7.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "torch>=2.0.0",
5
+ # "torchvision>=0.15.0",
6
+ # "diffusers>=0.25.0",
7
+ # "transformers>=4.35.0",
8
+ # "accelerate>=0.24.0",
9
+ # "peft>=0.7.0",
10
+ # "bitsandbytes>=0.41.0",
11
+ # "huggingface-hub>=0.20.0",
12
+ # "safetensors>=0.4.0",
13
+ # "omegaconf>=2.3.0",
14
+ # "Pillow>=10.0.0",
15
+ # "numpy>=1.24.0",
16
+ # "tqdm>=4.66.0",
17
+ # ]
18
+ # ///
19
+
20
+ """
21
+ Resume FLUX.2-klein-4B LoRA training from step 500 checkpoint.
22
+ Uses standard FluxPipeline from diffusers.
23
+ Output: Limbicnation/pixel-art-lora
24
+ """
25
+
26
+ import os
27
+ import sys
28
+ import torch
29
+ import torch.nn.functional as F
30
+ from pathlib import Path
31
+ from tqdm import tqdm
32
+ from PIL import Image
33
+ from huggingface_hub import (
34
+ hf_hub_download,
35
+ snapshot_download,
36
+ create_repo,
37
+ upload_folder,
38
+ login,
39
+ HfApi
40
+ )
41
+ from diffusers import FluxPipeline
42
+ from transformers import CLIPTokenizer, T5EncoderModel
43
+ from peft import LoraConfig, get_peft_model, set_peft_model_state_dict, get_peft_model_state_dict
44
+ from safetensors.torch import load_file, save_file
45
+ from accelerate import Accelerator
46
+
47
+ # Configuration
48
+ CHECKPOINT_REPO = "Limbicnation/sprite-lora-checkpoint-step500"
49
+ DATASET_REPO = "Limbicnation/sprite-lora-training-data"
50
+ OUTPUT_REPO = "Limbicnation/pixel-art-lora"
51
+ BASE_MODEL = "black-forest-labs/FLUX.2-klein-4B"
52
+
53
+ def train(token):
54
+ print("="*70)
55
+ print("🚀 FLUX.2-klein-4B LoRA Training Script v7")
56
+ print("="*70)
57
+ print(f"\n Base model: {BASE_MODEL}")
58
+ print(f" Dataset: {DATASET_REPO}")
59
+ print(f" Output: {OUTPUT_REPO}")
60
+ print(f" Steps: 1000")
61
+ print(f" LoRA: rank=64, alpha=128")
62
+ print("="*70)
63
+
64
+ # Authenticate
65
+ print("\n🔑 Authenticating...")
66
+ login(token=token, add_to_git_credential=False)
67
+ print("✅ Authenticated")
68
+
69
+ # Download checkpoint
70
+ print("\n📥 Downloading checkpoint...")
71
+ os.makedirs("./checkpoint_step500", exist_ok=True)
72
+ checkpoint_path = hf_hub_download(
73
+ repo_id=CHECKPOINT_REPO,
74
+ filename="pytorch_lora_weights.safetensors",
75
+ repo_type="model",
76
+ local_dir="./checkpoint_step500",
77
+ token=token
78
+ )
79
+ print(f"✅ Checkpoint: {checkpoint_path}")
80
+
81
+ # Download dataset
82
+ print("\n📥 Downloading dataset...")
83
+ dataset_path = snapshot_download(
84
+ repo_id=DATASET_REPO,
85
+ repo_type="dataset",
86
+ local_dir="./training_data",
87
+ token=token
88
+ )
89
+ image_files = list(Path(dataset_path).rglob("*.png"))
90
+ print(f"✅ Dataset: {len(image_files)} images")
91
+
92
+ # Setup accelerator
93
+ print("\n⚙️ Setting up accelerator...")
94
+ accelerator = Accelerator(
95
+ gradient_accumulation_steps=4,
96
+ mixed_precision="bf16"
97
+ )
98
+ device = accelerator.device
99
+
100
+ print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
101
+
102
+ # Load pipeline
103
+ print(f"\n📥 Loading {BASE_MODEL}...")
104
+ pipe = FluxPipeline.from_pretrained(
105
+ BASE_MODEL,
106
+ torch_dtype=torch.bfloat16,
107
+ token=token
108
+ )
109
+
110
+ # Enable CPU offloading for memory efficiency
111
+ print("💾 Enabling CPU offloading...")
112
+ pipe.enable_model_cpu_offload()
113
+
114
+ # Apply LoRA
115
+ print("\n🔧 Applying LoRA (rank=64, alpha=128)...")
116
+
117
+ # Get target modules for FLUX transformer
118
+ target_modules = []
119
+ for i in range(19): # FLUX.2-klein has 19 transformer blocks
120
+ target_modules.extend([
121
+ f"transformer_blocks.{i}.attn.to_q",
122
+ f"transformer_blocks.{i}.attn.to_k",
123
+ f"transformer_blocks.{i}.attn.to_v",
124
+ f"transformer_blocks.{i}.attn.to_out.0",
125
+ ])
126
+
127
+ lora_config = LoraConfig(
128
+ r=64,
129
+ lora_alpha=128,
130
+ lora_dropout=0.0,
131
+ target_modules=target_modules,
132
+ use_rslora=True
133
+ )
134
+
135
+ pipe.transformer = get_peft_model(pipe.transformer, lora_config)
136
+ pipe.transformer.print_trainable_parameters()
137
+
138
+ # Load checkpoint
139
+ print("\n🔄 Loading checkpoint from step 500...")
140
+ state_dict = load_file(checkpoint_path)
141
+ set_peft_model_state_dict(pipe.transformer, state_dict)
142
+ print("✅ Checkpoint loaded")
143
+
144
+ global_step = 500
145
+
146
+ # Create output dir
147
+ output_dir = Path("./output")
148
+ output_dir.mkdir(exist_ok=True)
149
+
150
+ # Create output repo
151
+ print(f"\n📤 Creating output repo: {OUTPUT_REPO}")
152
+ create_repo(OUTPUT_REPO, exist_ok=True, repo_type="model", token=token)
153
+
154
+ # Setup optimizer
155
+ print("\n⚙️ Setting up optimizer (AdamW 8-bit)...")
156
+ trainable_params = [p for p in pipe.transformer.parameters() if p.requires_grad]
157
+
158
+ import bitsandbytes as bnb
159
+ optimizer = bnb.optim.AdamW8bit(
160
+ trainable_params,
161
+ lr=1e-4,
162
+ betas=(0.9, 0.999),
163
+ weight_decay=0.01
164
+ )
165
+
166
+ # Simple dataset
167
+ print("\n📂 Loading dataset...")
168
+
169
+ class SimpleDataset(torch.utils.data.Dataset):
170
+ def __init__(self, data_dir, resolution=512):
171
+ self.data_dir = Path(data_dir)
172
+ self.resolution = resolution
173
+ self.image_files = sorted(list(self.data_dir.rglob("*.png")))
174
+
175
+ def __len__(self):
176
+ return len(self.image_files)
177
+
178
+ def __getitem__(self, idx):
179
+ img_path = self.image_files[idx]
180
+ caption_path = img_path.with_suffix(".txt")
181
+
182
+ image = Image.open(img_path).convert("RGB")
183
+ image = image.resize((self.resolution, self.resolution))
184
+ image = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
185
+ image = image * 2.0 - 1.0
186
+
187
+ caption = ""
188
+ if caption_path.exists():
189
+ caption = caption_path.read_text().strip()
190
+
191
+ return {"images": image, "captions": caption}
192
+
193
+ import numpy as np
194
+ dataset = SimpleDataset("./training_data/images")
195
+ dataloader = torch.utils.data.DataLoader(
196
+ dataset,
197
+ batch_size=1,
198
+ shuffle=True,
199
+ num_workers=0
200
+ )
201
+
202
+ print(f"✅ Dataset loaded: {len(dataset)} images")
203
+
204
+ # Prepare with accelerator
205
+ pipe.transformer, optimizer, dataloader = accelerator.prepare(
206
+ pipe.transformer, optimizer, dataloader
207
+ )
208
+
209
+ # Training loop
210
+ print("\n" + "="*70)
211
+ print("🏋️ Starting Training")
212
+ print("="*70)
213
+ print(f"Resuming from step {global_step} to step 1000")
214
+ print(f"Steps remaining: {1000 - global_step}")
215
+
216
+ pipe.transformer.train()
217
+
218
+ progress_bar = tqdm(total=1000, initial=global_step, desc="Training")
219
+
220
+ while global_step < 1000:
221
+ for batch in dataloader:
222
+ with accelerator.accumulate(pipe.transformer):
223
+ # Get batch data
224
+ images = batch["images"].to(device)
225
+ captions = batch["captions"]
226
+
227
+ # Add trigger word
228
+ captions = [f"pixel art sprite, {c}" for c in captions]
229
+
230
+ # Training step (simplified flow matching)
231
+ with torch.no_grad():
232
+ latents = pipe.vae.encode(images).latent_dist.sample()
233
+ noise = torch.randn_like(latents)
234
+ timesteps = torch.rand(latents.shape[0], device=device) * 1000
235
+
236
+ # Flow matching
237
+ sigmas = timesteps.view(-1, 1, 1, 1) / 1000
238
+ noisy_latents = (1 - sigmas) * latents + sigmas * noise
239
+ target = noise - latents
240
+
241
+ # Get text embeddings
242
+ with torch.no_grad():
243
+ prompt_embeds = pipe.encode_prompt(captions)[0]
244
+
245
+ # Predict
246
+ model_output = pipe.transformer(
247
+ hidden_states=noisy_latents,
248
+ timestep=timesteps,
249
+ encoder_hidden_states=prompt_embeds,
250
+ return_dict=False
251
+ )[0]
252
+
253
+ # Loss
254
+ loss = F.mse_loss(model_output.float(), target.float())
255
+
256
+ accelerator.backward(loss)
257
+
258
+ if accelerator.sync_gradients:
259
+ accelerator.clip_grad_norm_(pipe.transformer.parameters(), 1.0)
260
+
261
+ optimizer.step()
262
+ optimizer.zero_grad()
263
+
264
+ if accelerator.sync_gradients:
265
+ global_step += 1
266
+ progress_bar.update(1)
267
+ progress_bar.set_postfix({"loss": loss.item()})
268
+
269
+ # Save checkpoint every 500 steps
270
+ if global_step % 500 == 0:
271
+ print(f"\n💾 Saving checkpoint at step {global_step}...")
272
+ save_dir = output_dir / f"step_{global_step}"
273
+ save_dir.mkdir(exist_ok=True)
274
+
275
+ unwrapped = accelerator.unwrap_model(pipe.transformer)
276
+ save_file(
277
+ get_peft_model_state_dict(unwrapped),
278
+ save_dir / "pytorch_lora_weights.safetensors"
279
+ )
280
+
281
+ # Push to hub
282
+ upload_folder(
283
+ folder_path=save_dir,
284
+ repo_id=OUTPUT_REPO,
285
+ repo_type="model",
286
+ token=token
287
+ )
288
+ print(f"✅ Checkpoint pushed to {OUTPUT_REPO}")
289
+
290
+ if global_step >= 1000:
291
+ break
292
+
293
+ progress_bar.close()
294
+
295
+ # Final save
296
+ print("\n💾 Saving final checkpoint...")
297
+ save_dir = output_dir / "final"
298
+ save_dir.mkdir(exist_ok=True)
299
+
300
+ unwrapped = accelerator.unwrap_model(pipe.transformer)
301
+ save_file(
302
+ get_peft_model_state_dict(unwrapped),
303
+ save_dir / "pytorch_lora_weights.safetensors"
304
+ )
305
+
306
+ upload_folder(
307
+ folder_path=save_dir,
308
+ repo_id=OUTPUT_REPO,
309
+ repo_type="model",
310
+ token=token
311
+ )
312
+
313
+ print("\n" + "="*70)
314
+ print("✅ Training Complete!")
315
+ print("="*70)
316
+ print(f"\n📤 Model saved to: {OUTPUT_REPO}")
317
+ print(f" https://huggingface.co/{OUTPUT_REPO}")
318
+
319
+ def main():
320
+ # Get token from environment
321
+ token = os.environ.get("HF_TOKEN")
322
+ if not token:
323
+ print("❌ HF_TOKEN not found in environment!")
324
+ sys.exit(1)
325
+
326
+ print(f"Using token: {token[:7]}...")
327
+ train(token)
328
+
329
+ if __name__ == "__main__":
330
+ main()