Limbicnation committed on
Commit
875def4
·
verified ·
1 Parent(s): dd8e2a1

Upload sprite_lora_final.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. sprite_lora_final.py +242 -0
sprite_lora_final.py ADDED
@@ -0,0 +1,242 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "torch>=2.0.0",
#     "diffusers>=0.25.0",
#     "transformers>=4.35.0",
#     "accelerate>=0.24.0",
#     "peft>=0.7.0",
#     "huggingface-hub>=0.20.0",
#     "safetensors>=0.4.0",
#     "Pillow>=10.0.0",
#     "numpy>=1.24.0",
#     "tqdm>=4.66.0",
#     "bitsandbytes>=0.41.0",
# ]
# ///
16
+
17
+ """
18
+ Resume FLUX.2-klein-4B LoRA training from step 500 checkpoint.
19
+ Output: Limbicnation/pixel-art-lora
20
+ """
21
+
22
+ import os
23
+ import sys
24
+ import torch
25
+ import torch.nn.functional as F
26
+ from pathlib import Path
27
+ from tqdm import tqdm
28
+ from PIL import Image
29
+ import numpy as np
30
+
31
# Get token
# Fail fast before any Hub calls: reject a missing token and the literal
# string "$HF_TOKEN" (an unexpanded shell placeholder passed through verbatim).
token = os.environ.get("HF_TOKEN")
if not token or token == "$HF_TOKEN":
    print("ERROR: HF_TOKEN not set")
    sys.exit(1)

# Re-export so libraries imported below pick the token up from the environment.
os.environ["HF_TOKEN"] = token
38
+
39
# Import after setting token
# FIX: get_peft_model_state_dict is called in main() when saving checkpoints
# and the final model, but the original import list omitted it, causing a
# NameError at the first save.
from huggingface_hub import login, hf_hub_download, snapshot_download, create_repo, upload_file
from diffusers import FluxPipeline
from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    set_peft_model_state_dict,
)
from safetensors.torch import load_file, save_file
from accelerate import Accelerator
45
+
46
# Hub model repo holding the step-500 LoRA checkpoint to resume from.
CHECKPOINT_REPO = "Limbicnation/sprite-lora-checkpoint-step500"
# Hub dataset repo with training images (*.png) and caption sidecars (*.txt).
DATASET_REPO = "Limbicnation/sprite-lora-training-data"
# Destination repo for intermediate checkpoints and the final LoRA weights.
OUTPUT_REPO = "Limbicnation/pixel-art-lora"
# Base diffusion model that the LoRA adapts.
BASE_MODEL = "black-forest-labs/FLUX.2-klein-4B"
50
+
51
def main():
    """Resume LoRA training of FLUX.2-klein-4B from step 500 to step 1000.

    Downloads the step-500 checkpoint and sprite dataset from the Hub,
    rebuilds the LoRA-wrapped transformer, runs a flow-matching training
    loop with gradient accumulation, and uploads checkpoints plus the
    final LoRA weights to OUTPUT_REPO.
    """
    # FIX: the saving code below calls get_peft_model_state_dict, which the
    # module-level imports omitted (only set_peft_model_state_dict was
    # imported) — without this, the first save raises NameError.
    from peft import get_peft_model_state_dict

    print("=" * 70)
    print("🚀 FLUX.2-klein-4B LoRA Training - Final")
    print("=" * 70)
    print(f"Base model: {BASE_MODEL}")
    print(f"Output: {OUTPUT_REPO}")
    print("Resume: Step 500 -> 1000")

    # Login
    print("\n🔑 Authenticating...")
    login(token=token, add_to_git_credential=False)
    print("✅ Authenticated")

    # Download the step-500 LoRA checkpoint file.
    print("\n📥 Downloading checkpoint...")
    os.makedirs("checkpoint", exist_ok=True)
    hf_hub_download(
        repo_id=CHECKPOINT_REPO,
        filename="pytorch_lora_weights.safetensors",
        repo_type="model",
        local_dir="checkpoint",
        token=token,
    )
    print("✅ Checkpoint downloaded")

    # Download the full training dataset snapshot.
    print("\n📥 Downloading dataset...")
    snapshot_download(
        repo_id=DATASET_REPO,
        repo_type="dataset",
        local_dir="data",
        token=token,
    )
    image_files = list(Path("data").rglob("*.png"))
    print(f"✅ Dataset: {len(image_files)} images")

    # Setup accelerator: effective batch size 4 (batch_size=1 x 4 accumulation).
    accelerator = Accelerator(gradient_accumulation_steps=4, mixed_precision="bf16")
    device = accelerator.device
    print(f"\n⚙️ Device: {device}")

    # Load model
    print(f"\n📥 Loading {BASE_MODEL}...")
    pipe = FluxPipeline.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.bfloat16,
        token=token,
    )
    pipe.enable_model_cpu_offload()
    print("✅ Model loaded")

    # Apply LoRA to the attention q/k/v projections of all 19 transformer blocks.
    print("\n🔧 Applying LoRA (rank=64, alpha=128)...")
    target_modules = []
    for i in range(19):
        target_modules.extend([
            f"transformer_blocks.{i}.attn.to_q",
            f"transformer_blocks.{i}.attn.to_k",
            f"transformer_blocks.{i}.attn.to_v",
        ])

    lora_config = LoraConfig(r=64, lora_alpha=128, target_modules=target_modules, use_rslora=True)
    pipe.transformer = get_peft_model(pipe.transformer, lora_config)

    # Load the step-500 adapter weights into the freshly wrapped model.
    print("\n🔄 Loading checkpoint...")
    state_dict = load_file("checkpoint/pytorch_lora_weights.safetensors")
    set_peft_model_state_dict(pipe.transformer, state_dict)
    print("✅ Checkpoint loaded, resuming from step 500")

    global_step = 500

    # Create output repo
    print("\n📤 Creating output repo...")
    create_repo(OUTPUT_REPO, exist_ok=True, repo_type="model", token=token)

    # 8-bit AdamW over the trainable (LoRA) parameters only.
    trainable = [p for p in pipe.transformer.parameters() if p.requires_grad]
    import bitsandbytes as bnb
    optimizer = bnb.optim.AdamW8bit(trainable, lr=1e-4)

    # Dataset: each image resized to res x res and scaled to [-1, 1]; the
    # caption comes from a sidecar .txt next to the .png (empty if missing).
    class Dataset(torch.utils.data.Dataset):
        def __init__(self, root, res=512):
            self.imgs = sorted(list(Path(root).rglob("*.png")))
            self.res = res

        def __len__(self):
            return len(self.imgs)

        def __getitem__(self, idx):
            img = Image.open(self.imgs[idx]).convert("RGB").resize((self.res, self.res))
            img = torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 255.0 * 2 - 1
            txt = self.imgs[idx].with_suffix(".txt")
            cap = txt.read_text().strip() if txt.exists() else ""
            return {"images": img, "captions": cap}

    dataset = Dataset("data/images")
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)
    print(f"✅ Dataset ready: {len(dataset)} images")

    # Prepare
    pipe.transformer, optimizer, dataloader = accelerator.prepare(
        pipe.transformer, optimizer, dataloader
    )

    # Training
    print("\n" + "=" * 70)
    print("🏋️ Training: Step 500 -> 1000")
    print("=" * 70)

    pipe.transformer.train()
    pbar = tqdm(total=1000, initial=global_step, desc="Training")

    while global_step < 1000:
        for batch in dataloader:
            with accelerator.accumulate(pipe.transformer):
                imgs = batch["images"].to(device)
                caps = [f"pixel art sprite, {c}" for c in batch["captions"]]

                # Flow-matching setup: blend latents toward noise at a random
                # timestep; the model's regression target is (noise - latents).
                with torch.no_grad():
                    latents = pipe.vae.encode(imgs).latent_dist.sample()
                    noise = torch.randn_like(latents)
                    t = torch.rand(latents.shape[0], device=device) * 1000
                    sigmas = t.view(-1, 1, 1, 1) / 1000
                    noisy = (1 - sigmas) * latents + sigmas * noise
                    target = noise - latents

                with torch.no_grad():
                    prompt_embeds = pipe.encode_prompt(caps)[0]

                # NOTE(review): FLUX transformer forward typically also takes
                # pooled_projections/img_ids/txt_ids — confirm this call
                # signature against the installed diffusers version.
                output = pipe.transformer(
                    hidden_states=noisy,
                    timestep=t,
                    encoder_hidden_states=prompt_embeds,
                    return_dict=False,
                )[0]

                loss = F.mse_loss(output.float(), target.float())
                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(pipe.transformer.parameters(), 1.0)

                optimizer.step()
                optimizer.zero_grad()

            # Count an optimizer step only when gradients actually synced,
            # i.e. once per gradient_accumulation_steps micro-batches.
            if accelerator.sync_gradients:
                global_step += 1
                pbar.update(1)
                pbar.set_postfix({"loss": f"{loss.item():.4f}"})

                if global_step % 500 == 0:
                    print(f"\n💾 Saving checkpoint at step {global_step}...")
                    os.makedirs(f"output/step_{global_step}", exist_ok=True)
                    save_file(
                        get_peft_model_state_dict(accelerator.unwrap_model(pipe.transformer)),
                        f"output/step_{global_step}/pytorch_lora_weights.safetensors",
                    )
                    upload_file(
                        path_or_fileobj=f"output/step_{global_step}/pytorch_lora_weights.safetensors",
                        path_in_repo=f"step_{global_step}/pytorch_lora_weights.safetensors",
                        repo_id=OUTPUT_REPO,
                        repo_type="model",
                        token=token,
                    )
                    print("✅ Checkpoint saved")

            if global_step >= 1000:
                break

    pbar.close()

    # Final save
    print("\n💾 Saving final model...")
    os.makedirs("output/final", exist_ok=True)
    save_file(
        get_peft_model_state_dict(accelerator.unwrap_model(pipe.transformer)),
        "output/final/pytorch_lora_weights.safetensors",
    )
    upload_file(
        path_or_fileobj="output/final/pytorch_lora_weights.safetensors",
        path_in_repo="pytorch_lora_weights.safetensors",
        repo_id=OUTPUT_REPO,
        repo_type="model",
        token=token,
    )

    print("\n" + "=" * 70)
    print("✅ Training Complete!")
    print("=" * 70)
    print(f"\n📤 Model: https://huggingface.co/{OUTPUT_REPO}")
240
+
241
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()