AbstractPhil committed on
Commit
df0879c
·
verified ·
1 Parent(s): fdba648

Create trainer.py

Browse files
Files changed (1) hide show
  1. trainer.py +458 -0
trainer.py ADDED
@@ -0,0 +1,458 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SD15 Flow-Matching trainer
3
+ Author: AbstractPhil
4
+
5
+ Loads the current-format .pt checkpoint and runs several validation checks to confirm the weights were loaded correctly before training.
6
+
7
+ Trains flow matching for sd15.
8
+
9
+ License: MIT
10
+ If you use my work, a citation wouldn't hurt.
11
+
12
+ """
13
+
14
+ import os
15
+ import json
16
+ import datetime
17
+ from dataclasses import dataclass, asdict
18
+ from tqdm.auto import tqdm
19
+ import matplotlib.pyplot as plt
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from torch.utils.tensorboard import SummaryWriter
24
+ from torch.utils.data import DataLoader
25
+
26
+ import datasets
27
+ from diffusers import UNet2DConditionModel
28
+ from huggingface_hub import HfApi, create_repo, hf_hub_download
29
+
30
+
31
@dataclass
class TrainConfig:
    """Settings for one SD1.5 flow-matching fine-tuning run.

    The whole object is serialized with ``asdict`` into ``config.json`` and
    into every ``.pt`` checkpoint, so all fields must stay JSON-serializable.
    """

    # --- Locations -------------------------------------------------------
    # Root directory; each run writes into a timestamped sub-directory.
    output_dir: str = "./outputs"
    # Hub repo and file name of the starting checkpoint (must hold a
    # "student" state dict).
    model_repo: str = "AbstractPhil/sd15-flow-matching-try2"
    checkpoint_filename: str = "sd15_flowmatch_david_weighted_2_e34.pt"
    # Hub dataset streamed during training.
    dataset_name: str = "AbstractPhil/sd15-latent-distillation-500k"

    # --- HuggingFace upload settings --------------------------------------
    hf_repo_id: str = "AbstractPhil/sd15-flow-lune"
    upload_to_hub: bool = True

    # --- Optimisation -----------------------------------------------------
    seed: int = 42
    batch_size: int = 16
    # Effective learning rate is base_lr * sqrt(batch_size).
    base_lr: float = 2e-6
    # Timestep-shift factor applied to the uniformly sampled sigmas.
    shift: float = 2.0
    # Per-sample probability of zeroing the text conditioning (CFG support).
    dropout: float = 0.1

    # --- Schedule / data loading -------------------------------------------
    max_train_steps: int = 50_000
    checkpointing_steps: int = 1000
    num_workers: int = 0

    # VAE scaling factor - raw latents are multiplied by this value.
    vae_scale: float = 0.18215
54
+
55
+
56
def _state_dict_diff_stats(reference, candidate, keys, tol=1e-6):
    """Measure how far the tensors in ``candidate`` are from ``reference``.

    Args:
        reference: Mapping of parameter name -> tensor used as the baseline.
        candidate: Mapping of parameter name -> tensor compared against it.
        keys: Names to compare; every name must exist in both mappings.
        tol: A tensor counts as "different" when its largest absolute
            element-wise difference exceeds this value.

    Returns:
        Tuple ``(total, differing, diff_sum, diff_max)`` where ``total`` is
        the number of elements compared, ``differing`` is the number of
        elements belonging to tensors that differ (whole-tensor granularity),
        ``diff_sum`` is the summed absolute difference over those tensors and
        ``diff_max`` is the largest absolute difference seen.
    """
    total = 0
    differing = 0
    diff_sum = 0.0
    diff_max = 0.0

    for key in keys:
        ref = reference[key]
        cand = candidate[key].float()  # Convert to fp32 for comparison

        if ref.shape != cand.shape:
            print(f"⚠ Shape mismatch for {key}: {ref.shape} vs {cand.shape}")
            continue

        total += ref.numel()
        if ref.numel() == 0:
            # max() of an empty tensor raises; nothing to compare anyway.
            continue

        diff = (ref - cand).abs()
        peak = diff.max().item()  # single reduction, reused below
        if peak > tol:
            differing += ref.numel()
            diff_sum += diff.sum().item()
            diff_max = max(diff_max, peak)

    return total, differing, diff_sum, diff_max


def load_student_unet(repo_id: str, filename: str, device="cuda") -> UNet2DConditionModel:
    """Load UNet from .pt checkpoint containing student state_dict.

    Downloads ``filename`` from the Hub model repo ``repo_id``, instantiates
    the stock SD1.5 UNet in fp32, loads the checkpoint's ``"student"`` weights
    into it (stripping a leading ``"unet."`` from keys) and prints a
    verification report comparing the weights before and after loading.

    Args:
        repo_id: Hub model repository holding the checkpoint.
        filename: Name of the ``.pt`` file inside that repository.
        device: Device the returned model is moved to.

    Returns:
        The UNet with student weights loaded, moved to ``device``.

    Raises:
        KeyError: If the checkpoint has no ``"student"`` entry.
    """
    # Download checkpoint from HuggingFace
    print(f"Downloading checkpoint from {repo_id}/{filename}...")
    checkpoint_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        repo_type="model",
    )
    print(f"✓ Downloaded to: {checkpoint_path}")

    # NOTE(security): torch.load unpickles the file. Only point this at
    # checkpoints from a trusted repo; consider weights_only=True if the
    # checkpoint contents allow it.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    if "student" not in checkpoint:
        raise KeyError(
            f"Checkpoint {repo_id}/{filename} has no 'student' state dict"
        )

    # Initialize UNet with SD1.5 config in fp32
    print("Loading SD1.5 UNet architecture...")
    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        subfolder="unet",
        torch_dtype=torch.float32,
    )

    # Snapshot the base weights; state_dict() tensors alias the live
    # parameters, so they must be cloned before loading over them.
    original_state_dict = {k: v.clone() for k, v in unet.state_dict().items()}

    # Load student weights and strip the "unet." prefix if present
    prefix = "unet."
    cleaned_student_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint["student"].items()
    }

    print(f"\n{'='*70}")
    print("WEIGHT VERIFICATION")
    print(f"{'='*70}")

    # 1. Compare keys
    original_keys = set(original_state_dict.keys())
    student_keys = set(cleaned_student_dict.keys())
    matching_keys = original_keys & student_keys

    print(f"Original UNet keys: {len(original_keys)}")
    print(f"Student checkpoint keys: {len(student_keys)}")
    print(f"Matching keys: {len(matching_keys)}")

    # 2. Compare student weights vs original BEFORE loading
    total_params, different_params, mean_diff_sum, max_diff = _state_dict_diff_stats(
        original_state_dict, cleaned_student_dict, matching_keys
    )

    pct_different = (different_params / total_params * 100) if total_params > 0 else 0
    avg_diff = mean_diff_sum / different_params if different_params > 0 else 0

    print(f"\nStudent vs Original (BEFORE loading):")
    print(f" Total parameters: {total_params:,}")
    print(f" Parameters different: {different_params:,} ({pct_different:.1f}%)")
    print(f" Average difference: {avg_diff:.6f}")
    print(f" Max difference: {max_diff:.6f}")

    # 3. Load weights (non-strict: report key mismatches instead of raising)
    load_result = unet.load_state_dict(cleaned_student_dict, strict=False)

    if load_result.missing_keys:
        print(f"\n⚠ Missing keys during load: {len(load_result.missing_keys)}")
        for key in load_result.missing_keys[:3]:
            print(f" - {key}")

    if load_result.unexpected_keys:
        print(f"⚠ Unexpected keys during load: {len(load_result.unexpected_keys)}")
        for key in load_result.unexpected_keys[:3]:
            print(f" - {key}")

    # 4. Verify weights actually changed after loading
    total_params_after, changed_params, mean_diff_after, max_diff_after = (
        _state_dict_diff_stats(original_state_dict, unet.state_dict(), matching_keys)
    )

    pct_changed = (changed_params / total_params_after * 100) if total_params_after > 0 else 0
    avg_diff_after = mean_diff_after / changed_params if changed_params > 0 else 0

    print(f"\nOriginal vs Loaded (AFTER loading):")
    print(f" Parameters changed: {changed_params:,} ({pct_changed:.1f}%)")
    print(f" Average difference: {avg_diff_after:.6f}")
    print(f" Max difference: {max_diff_after:.6f}")

    print(f"\n{'='*70}")
    # Verification checks
    if pct_different < 50:
        print(f"⚠️ WARNING: Student weights only {pct_different:.1f}% different from base!")
        print(" This checkpoint may not be trained.")
    elif pct_changed < 90:
        print(f"⚠️ WARNING: Only {pct_changed:.1f}% of weights changed after loading!")
        print(" The load may have failed.")
    else:
        print(f"✅ Weights loaded successfully!")
        print(f" Checkpoint step: {checkpoint.get('gstep', 'unknown')}")
        print(f" {pct_different:.1f}% of weights differ from base SD1.5")

    print(f"{'='*70}\n")

    return unet.to(device)
201
+
202
+
203
def train(config: TrainConfig):
    """Fine-tune the SD1.5 student UNet with a flow-matching objective.

    Creates a timestamped run directory under ``config.output_dir``, streams
    pre-computed latents and CLIP embeddings from ``config.dataset_name``,
    trains in fp32 on CUDA for ``config.max_train_steps`` optimizer steps and
    writes (and optionally uploads) a checkpoint every
    ``config.checkpointing_steps`` steps plus one at the end.

    Args:
        config: Run settings. ``config.upload_to_hub`` is set to ``False``
            in place if the Hub repository cannot be prepared.

    Raises:
        RuntimeError: If the dataset yields no batches during a full pass.
    """
    device = "cuda"
    torch.backends.cuda.matmul.allow_tf32 = True

    torch.manual_seed(config.seed)
    torch.cuda.manual_seed(config.seed)

    # Setup output directory
    date_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    real_output_dir = os.path.join(config.output_dir, date_time)
    os.makedirs(real_output_dir, exist_ok=True)
    t_writer = SummaryWriter(log_dir=real_output_dir, flush_secs=60)

    # Initialize HuggingFace API
    hf_api = None
    if config.upload_to_hub:
        try:
            hf_api = HfApi()
            create_repo(
                repo_id=config.hf_repo_id,
                repo_type="model",
                exist_ok=True,
                private=False,
            )
            print(f"✓ HuggingFace repo ready: {config.hf_repo_id}")
        except Exception as e:
            print(f"⚠ Hub upload disabled: {e}")
            config.upload_to_hub = False

    # Save config locally and to hub
    config_path = os.path.join(real_output_dir, "config.json")
    with open(config_path, "w") as f:
        json.dump(asdict(config), f, indent=2)

    if config.upload_to_hub and hf_api is not None:
        # Best-effort, like the checkpoint uploads: a transient Hub failure
        # should not abort the run before training starts.
        try:
            hf_api.upload_file(
                path_or_fileobj=config_path,
                path_in_repo="config.json",
                repo_id=config.hf_repo_id,
                repo_type="model",
            )
        except Exception as e:
            print(f"⚠ Config upload failed: {e}")

    # Load dataset in streaming mode
    # NOTE(security): trust_remote_code=True executes code shipped with the
    # dataset repo; only use with trusted datasets.
    print(f"\nLoading dataset (streaming): {config.dataset_name}")
    train_dataset = datasets.load_dataset(
        config.dataset_name,
        split="train",
        streaming=True,
        trust_remote_code=True,
    )
    train_dataset = train_dataset.shuffle(seed=config.seed, buffer_size=1000)
    print(f"✓ Dataset loaded in streaming mode")

    def collate_fn(examples):
        """Stack a list of dataset rows into batch tensors plus id/prompt lists."""
        # Latents are RAW from VAE - need to scale them
        latents = torch.stack([torch.tensor(ex["latent"]) for ex in examples])
        latents = latents * config.vae_scale  # Scale: ~[-6, 6] -> ~[-1, 1]

        clip_embeddings = torch.stack([torch.tensor(ex["clip_embedding"]) for ex in examples])
        ids = [ex["id"] for ex in examples]
        prompts = [ex["prompt"] for ex in examples]

        return latents, clip_embeddings, ids, prompts

    train_dataloader = DataLoader(
        dataset=train_dataset,
        batch_size=config.batch_size,
        collate_fn=collate_fn,
        num_workers=config.num_workers,
    )

    # Verify first batch latent range (on GPU for speed)
    print("\nVerifying latent scaling on first batch...")
    first_batch = next(iter(train_dataloader))
    latents_check, _, _, _ = first_batch
    print(f"Raw latent range: [{latents_check.min():.3f}, {latents_check.max():.3f}]")
    latents_check = latents_check.to(device)
    print(f"After GPU transfer: [{latents_check.min():.3f}, {latents_check.max():.3f}]")
    print(f"Expected: ~[-1, 1] for properly scaled latents")
    del latents_check

    # Load pretrained student UNet
    print(f"\nLoading model from HuggingFace...")
    unet = load_student_unet(config.model_repo, config.checkpoint_filename, device=device)
    unet.requires_grad_(True)
    unet.enable_gradient_checkpointing()
    unet.train()

    optimizer = torch.optim.Adam(
        unet.parameters(),
        lr=config.base_lr * (config.batch_size ** 0.5),
    )

    global_step = 0
    train_logs = {
        "train_step": [],
        "train_loss": [],
        "train_timestep": [],
        "trained_images": [],
    }

    def get_prediction(batch, log_to=None):
        """Run one forward pass and return the mean flow-matching loss.

        When ``log_to`` is given, per-sample loss, timestep, id and prompt
        are appended to its lists, tagged with the current ``global_step``.
        """
        latents, encoder_hidden_states, ids, prompts = batch

        # Everything in fp32
        latents = latents.to(dtype=torch.float32, device=device)
        encoder_hidden_states = encoder_hidden_states.to(dtype=torch.float32, device=device)

        batch_size = latents.shape[0]

        # Apply dropout to conditioning for CFG support
        dropout_mask = torch.rand(batch_size, device=device) < config.dropout
        encoder_hidden_states = encoder_hidden_states.clone()
        encoder_hidden_states[dropout_mask] = 0

        # Sample timesteps with shift
        sigmas = torch.rand(batch_size, device=device)
        sigmas = (config.shift * sigmas) / (1 + (config.shift - 1) * sigmas)
        timesteps = sigmas * 1000
        sigmas = sigmas[:, None, None, None]

        # Flow matching forward process
        noise = torch.randn_like(latents)
        noisy_latents = noise * sigmas + latents * (1 - sigmas)
        target = noise - latents

        # Predict velocity
        pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]

        # Per-sample loss: average over every non-batch dimension.
        loss = F.mse_loss(pred, target, reduction="none")
        loss = loss.mean(dim=list(range(1, len(loss.shape))))

        if log_to is not None:
            # One device->host transfer per tensor instead of one per sample.
            log_to["train_step"].extend([global_step] * batch_size)
            log_to["train_loss"].extend(loss.detach().tolist())
            log_to["train_timestep"].extend(timesteps.tolist())
            log_to["trained_images"].extend(
                {"step": global_step, "id": sample_id, "prompt": prompt}
                for sample_id, prompt in zip(ids, prompts)
            )

        return loss.mean()

    def plot_logs(log_dict):
        """Scatter-plot loss against timestep, coloured by training step."""
        plt.figure(figsize=(10, 6))
        plt.scatter(
            log_dict["train_timestep"],
            log_dict["train_loss"],
            s=3,
            c=log_dict["train_step"],
            marker=".",
            cmap='cool',
        )
        plt.xlabel("timestep")
        plt.ylabel("loss")
        plt.yscale("log")
        plt.colorbar(label="step")

    def save_checkpoint(step):
        """Write diffusers weights, a .pt bundle and metadata; upload if enabled."""
        checkpoint_path = os.path.join(real_output_dir, f"checkpoint-{step:08}")
        os.makedirs(checkpoint_path, exist_ok=True)

        # Save UNet weights as diffusers format
        unet.save_pretrained(
            os.path.join(checkpoint_path, "unet"),
            safe_serialization=True,
        )

        # Save complete checkpoint in .pt format
        pt_filename = f"sd15_flow_lune_e{step//1000}_s{step}.pt"
        pt_path = os.path.join(checkpoint_path, pt_filename)

        torch.save({
            "cfg": asdict(config),
            "student": unet.state_dict(),
            "opt": optimizer.state_dict(),
            "gstep": step,
        }, pt_path)

        # Save training metadata (the full history so far, not just the
        # samples seen since the previous checkpoint)
        metadata = {
            "step": step,
            "trained_images": train_logs["trained_images"],
        }
        metadata_path = os.path.join(checkpoint_path, "trained_images.json")
        with open(metadata_path, "w") as f:
            json.dump(metadata, f, indent=2)

        print(f"✓ Checkpoint saved at step {step}")

        # Upload to HuggingFace Hub
        if config.upload_to_hub and hf_api is not None:
            try:
                hf_api.upload_file(
                    path_or_fileobj=pt_path,
                    path_in_repo=pt_filename,
                    repo_id=config.hf_repo_id,
                    repo_type="model",
                )

                hf_api.upload_folder(
                    folder_path=os.path.join(checkpoint_path, "unet"),
                    path_in_repo=f"checkpoint-{step:08}/unet",
                    repo_id=config.hf_repo_id,
                    repo_type="model",
                )

                hf_api.upload_file(
                    path_or_fileobj=metadata_path,
                    path_in_repo=f"checkpoint-{step:08}/trained_images.json",
                    repo_id=config.hf_repo_id,
                    repo_type="model",
                )

                print(f"✓ Uploaded to hub: {config.hf_repo_id}")
            except Exception as e:
                print(f"⚠ Upload failed: {e}")

    print("\nStarting training...")
    progress_bar = tqdm(range(0, config.max_train_steps))

    epoch = 0
    try:
        # A single pass over the streamed dataset may contain fewer batches
        # than max_train_steps, so keep re-iterating until the step budget is
        # reached instead of stopping silently when the stream runs dry.
        while True:
            steps_at_epoch_start = global_step
            if epoch > 0 and hasattr(train_dataset, "set_epoch"):
                # Changes the shuffle order for repeat passes.
                train_dataset.set_epoch(epoch)

            for batch in train_dataloader:
                loss = get_prediction(batch, log_to=train_logs)
                loss_value = loss.item()  # one sync, reused for both logs
                t_writer.add_scalar("train/loss", loss_value, global_step)

                loss.backward()

                grad_norm = torch.nn.utils.clip_grad_norm_(unet.parameters(), 2.0)
                t_writer.add_scalar("train/grad_norm", grad_norm.detach().item(), global_step)

                optimizer.step()
                optimizer.zero_grad()

                progress_bar.update(1)
                progress_bar.set_postfix({"loss": f"{loss_value:.4f}"})
                global_step += 1

                if global_step % 100 == 0:
                    plot_logs(train_logs)
                    t_writer.add_figure("train_loss", plt.gcf(), global_step)
                    plt.close()

                finished = global_step >= config.max_train_steps
                # Single save even when the final step is also a regular
                # checkpoint step (avoids a duplicate write and upload).
                if finished or global_step % config.checkpointing_steps == 0:
                    save_checkpoint(global_step)

                if finished:
                    print("\n✅ Training complete!")
                    return

            if global_step == steps_at_epoch_start:
                # Guard against spinning forever on an empty dataset.
                raise RuntimeError(
                    f"Dataset {config.dataset_name} yielded no batches; cannot continue training."
                )
            epoch += 1
    finally:
        # Flush pending TensorBoard events and release the progress bar even
        # if training is interrupted.
        progress_bar.close()
        t_writer.close()
454
+
455
+
456
# Script entry point: run one training job with the default settings.
if __name__ == "__main__":
    config = TrainConfig()
    train(config)