AbstractPhil committed on
Commit
b08d242
·
verified ·
1 Parent(s): 8ed8311

Create lune_mask_trainer.py

Browse files
Files changed (1) hide show
  1. lune_mask_trainer.py +627 -0
lune_mask_trainer.py ADDED
@@ -0,0 +1,627 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SD15 Flow-Matching Trainer - ControlNet Pose Edition
3
+ Author: AbstractPhil
4
+
5
+ Trains Lune on controlnet pose dataset with transparent backgrounds.
6
+
7
+ License: MIT
8
+ """
9
+
10
+ import os
11
+ import json
12
+ import datetime
13
+ import random
14
+ from dataclasses import dataclass, asdict, field
15
+ from tqdm.auto import tqdm
16
+ import matplotlib.pyplot as plt
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from torch.utils.tensorboard import SummaryWriter
21
+ from torch.utils.data import DataLoader
22
+
23
+ import datasets
24
+ from diffusers import UNet2DConditionModel, AutoencoderKL
25
+ from transformers import CLIPTextModel, CLIPTokenizer
26
+ from huggingface_hub import HfApi, create_repo, hf_hub_download
27
+
28
+
29
@dataclass
class TrainConfig:
    """Configuration for the SD1.5 flow-matching pose trainer.

    Every field has a default, so ``TrainConfig()`` describes a complete run.
    The instance is serialised with ``dataclasses.asdict`` when the run starts
    and again inside every checkpoint.
    """

    # --- Locations: local output, source checkpoint, dataset ---
    output_dir: str = "./outputs"
    model_repo: str = "AbstractPhil/sd15-flow-lune"
    checkpoint_filename: str = "sd15_flow_pretrain_pose_controlnet_t500_700_s8312.pt"
    dataset_name: str = "AbstractPhil/CN_pose3D_V7_512"
    # When enabled, the loss is averaged only over pixels kept by the mask
    # stored in `mask_column` of each dataset row.
    use_masks: bool = True
    mask_column: str = "mask"

    # --- HuggingFace Hub upload ---
    hf_repo_id: str = "AbstractPhil/sd15-flow-lune"
    upload_to_hub: bool = True

    # --- Run identification (used in checkpoint file and folder names) ---
    run_name: str = "pretrain_pose_controlnet_v7_v10_t400_600"

    # --- Resume optimizer/scheduler state from the loaded checkpoint ---
    continue_from_checkpoint: bool = False

    seed: int = 42
    batch_size: int = 64

    # --- Optimisation ---
    base_lr: float = 2e-6
    # Strength of the shift transform applied to the sampled sigmas.
    shift: float = 2.5
    # Probability of zeroing a sample's text conditioning (for CFG support).
    dropout: float = 0.1
    # Clamp value used by the Min-SNR loss weighting.
    min_snr_gamma: float = 5.0

    # --- Timestep sampling bounds, expressed on a 0-1000 scale ---
    min_timestep: float = 400.0
    max_timestep: float = 600.0

    # --- Schedule ---
    num_train_epochs: int = 1
    warmup_epochs: int = 1
    checkpointing_steps: int = 2500
    num_workers: int = 0

    # --- Factor applied to VAE latents after encoding ---
    vae_scale: float = 0.18215

    # --- Caption preprocessing ---
    # Separator used to split a caption into tokens.
    delimiter: str = ","
    # Number of leading tokens kept in place when the caption is shuffled.
    # The prepended prompt is added before tokenisation, so it occupies the
    # first of these slots.
    preserved_count: int = 2
    # Tokens dropped from every caption (exact match after cleaning).
    remove_these: list = field(
        default_factory=lambda: [
            "simple background",
            "white background",
        ]
    )
    # Added to the front of the caption before the shuffle.
    prepend_prompt: str = "doll"
    # Added to the end of the caption after the shuffle.
    append_prompt: str = "transparent background"
    shuffle_prompt: bool = True
80
+
81
+
82
def preprocess_caption(text: str, config: "TrainConfig") -> str:
    """Clean, filter, optionally shuffle, and decorate a dataset caption.

    Steps, in order:
      1. Lowercase, turn "." into ``config.delimiter``, collapse repeated
         delimiters and repeated spaces, strip a leading/trailing delimiter.
      2. Put ``config.prepend_prompt`` in front (it becomes the first token).
      3. Split on the delimiter and drop tokens listed in
         ``config.remove_these``.
      4. If ``config.shuffle_prompt`` is set, shuffle every token after the
         first ``config.preserved_count`` using the global ``random`` state.
      5. Add ``config.append_prompt`` at the end.

    Args:
        text: Raw caption. ``None`` or ``""`` short-circuits to the append
            prompt alone (or ``""`` when no append prompt is configured).
        config: Object exposing ``delimiter``, ``prepend_prompt``,
            ``append_prompt``, ``remove_these``, ``preserved_count`` and
            ``shuffle_prompt``.

    Returns:
        The processed caption, tokens joined by ``"<delimiter> "``.

    Raises:
        ValueError: If ``config.delimiter`` is empty (the clean-up below could
            never terminate and the caption could not be split).
    """
    # Missing caption: nothing to clean or shuffle.
    if text is None or text == "":
        return config.append_prompt if config.append_prompt else ""

    delim = config.delimiter
    if not delim:
        raise ValueError("config.delimiter must be a non-empty string")

    # Basic cleaning.
    text = text.lower().replace(".", delim).strip()

    # Collapse runs of delimiters and runs of spaces. The space check must
    # look for a *double* space: testing for a single space would loop
    # forever on any caption containing one.
    doubled_delim = delim + delim
    while doubled_delim in text:
        text = text.replace(doubled_delim, delim)
    while "  " in text:
        text = text.replace("  ", " ")
    text = text.strip()

    # Strip one leading/trailing delimiter; slice by its real length so
    # multi-character delimiters are removed whole.
    if text.startswith(delim):
        text = text[len(delim):].strip()
    if text.endswith(delim):
        text = text[:-len(delim)].strip()

    joiner = f"{delim} "

    # Prepend before tokenising so the prefix is the first (preserved) token.
    if config.prepend_prompt:
        text = f"{config.prepend_prompt}{joiner}{text}" if text else config.prepend_prompt

    # Tokenise only when there is something to do with the tokens; otherwise
    # the cleaned text is passed through untouched.
    if text and (config.shuffle_prompt or config.remove_these):
        tokens = [piece.strip() for piece in text.split(delim) if piece.strip()]

        if config.remove_these:
            tokens = [tok for tok in tokens if tok not in config.remove_these]

        if config.shuffle_prompt:
            preserved = tokens[:config.preserved_count]
            shuffleable = tokens[config.preserved_count:]
            random.shuffle(shuffleable)
            tokens = preserved + shuffleable

        text = joiner.join(tokens)

    # Append after shuffling so the suffix always comes last.
    if config.append_prompt:
        text = f"{text}{joiner}{config.append_prompt}" if text else config.append_prompt

    return text
151
+
152
+
153
def load_student_unet(repo_id: str, filename: str, device="cuda"):
    """Download a training checkpoint and load its weights into an SD1.5 UNet.

    The checkpoint is fetched from the HuggingFace Hub, the stock SD1.5 UNet
    architecture is instantiated, and the checkpoint's ``"student"`` state
    dict is loaded into it (a leading ``"unet."`` on any key is stripped).

    Args:
        repo_id: Hub model repository holding the checkpoint.
        filename: Name of the ``.pt`` file inside that repository.
        device: Device the returned UNet is moved to.

    Returns:
        ``(unet, checkpoint)``: the UNet on ``device`` and the raw checkpoint
        dict, so the caller can optionally restore optimizer/scheduler state.

    Raises:
        KeyError: If the checkpoint has no ``"student"`` entry.
    """
    print(f"Downloading checkpoint from {repo_id}/{filename}...")
    checkpoint_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        repo_type="model",
    )

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only point this at checkpoints from trusted repos.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    print("Loading SD1.5 UNet architecture...")
    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        subfolder="unet",
        torch_dtype=torch.float32,
    )

    # Strip the "unet." prefix, if present, so keys match the bare UNet.
    prefix = "unet."
    cleaned_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint["student"].items()
    }

    # strict=False tolerates partial checkpoints, but a silent mismatch would
    # leave the base SD1.5 weights in place, so report what did not line up.
    load_result = unet.load_state_dict(cleaned_dict, strict=False)
    if load_result.missing_keys:
        print(f"⚠ {len(load_result.missing_keys)} UNet keys missing from checkpoint "
              f"(e.g. {load_result.missing_keys[:3]})")
    if load_result.unexpected_keys:
        print(f"⚠ {len(load_result.unexpected_keys)} unexpected keys in checkpoint "
              f"(e.g. {load_result.unexpected_keys[:3]})")

    print(f"✓ Loaded UNet from step {checkpoint.get('gstep', 'unknown')}")

    return unet.to(device), checkpoint
185
+
186
+
187
def train(config: TrainConfig):
    """Fine-tune the student SD1.5 UNet with a (optionally masked) flow-matching loss.

    Workflow:
      1. Seed the RNGs, create a timestamped output directory and a
         TensorBoard writer, and (best effort) prepare the Hub repository.
      2. Load the frozen SD1.5 VAE and CLIP text encoder and the dataset.
      3. Build a DataLoader whose collate function encodes images to latents,
         downsamples masks to latent resolution and encodes captions.
      4. Load the student UNet, create AdamW plus a warmup (or constant) LR
         schedule, and optionally restore optimizer/scheduler state.
      5. Run the training loop, logging to TensorBoard and saving/uploading
         checkpoints periodically and at the end of every epoch.

    Args:
        config: Run configuration. ``config.upload_to_hub`` is set to False on
            this object if the Hub repository cannot be prepared.

    Note:
        The device is fixed to ``"cuda"``.
    """
    device = "cuda"
    torch.backends.cuda.matmul.allow_tf32 = True

    # Seed every RNG the run draws from. Python's `random` drives the caption
    # shuffle in preprocess_caption, so it has to be seeded as well.
    random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed(config.seed)

    # Setup output directory
    date_time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    real_output_dir = os.path.join(config.output_dir, date_time)
    os.makedirs(real_output_dir, exist_ok=True)
    t_writer = SummaryWriter(log_dir=real_output_dir, flush_secs=60)

    # Initialize HuggingFace API (best effort: failure only disables uploads)
    hf_api = None
    if config.upload_to_hub:
        try:
            hf_api = HfApi()
            create_repo(
                repo_id=config.hf_repo_id,
                repo_type="model",
                exist_ok=True,
                private=False,
            )
            print(f"✓ HuggingFace repo ready: {config.hf_repo_id}")
        except Exception as e:
            print(f"⚠ Hub upload disabled: {e}")
            config.upload_to_hub = False

    # Save config
    config_path = os.path.join(real_output_dir, "config.json")
    with open(config_path, "w") as f:
        json.dump(asdict(config), f, indent=2)

    # Uploading the config is as optional as the checkpoint uploads: a Hub
    # hiccup here must not abort the run before training starts.
    if config.upload_to_hub and hf_api is not None:
        try:
            hf_api.upload_file(
                path_or_fileobj=config_path,
                path_in_repo="config.json",
                repo_id=config.hf_repo_id,
                repo_type="model",
            )
        except Exception as e:
            print(f"⚠ Config upload failed: {e}")

    # Load SD1.5 VAE and CLIP (both frozen, used only for encoding)
    print("\nLoading SD1.5 VAE and CLIP...")
    vae = AutoencoderKL.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        subfolder="vae",
        torch_dtype=torch.float32,
    ).to(device)
    vae.requires_grad_(False)
    vae.eval()

    tokenizer = CLIPTokenizer.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        subfolder="tokenizer",
    )
    text_encoder = CLIPTextModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        subfolder="text_encoder",
        torch_dtype=torch.float32,
    ).to(device)
    text_encoder.requires_grad_(False)
    text_encoder.eval()

    print("✓ VAE and CLIP loaded")

    # Load dataset - columns: image, conditioning_image, mask, text
    print(f"\nLoading dataset: {config.dataset_name}")
    train_dataset = datasets.load_dataset(
        config.dataset_name,
        split="train",
    )

    print(f"✓ Loaded {len(train_dataset):,} images")
    print(f" Columns: {train_dataset.column_names}")

    # Calculate steps (full batches only; a trailing partial batch is extra)
    steps_per_epoch = len(train_dataset) // config.batch_size
    total_steps = steps_per_epoch * config.num_train_epochs
    warmup_steps = steps_per_epoch * config.warmup_epochs

    print("\nTraining schedule:")
    print(f" Total images: {len(train_dataset):,}")
    print(f" Batch size: {config.batch_size}")
    print(f" Steps per epoch: {steps_per_epoch:,}")
    print(f" Total epochs: {config.num_train_epochs}")
    print(f" Total steps: {total_steps:,}")
    print(f" Warmup steps: {warmup_steps:,}")
    print("\nTimestep range:")
    print(f" Min timestep: {config.min_timestep}")
    print(f" Max timestep: {config.max_timestep}")
    print(f" Training on: {config.max_timestep - config.min_timestep} timestep range")
    print("\nPrompt preprocessing:")
    print(f" Shuffle: {config.shuffle_prompt}")
    print(f" Preserved tokens: {config.preserved_count}")
    print(f" Prepend: '{config.prepend_prompt}'")
    print(f" Append: '{config.append_prompt}'")
    print(f" Remove: {config.remove_these}")

    @torch.no_grad()
    def collate_fn(examples):
        """Encode images, masks (optional), and prompts at runtime.

        Returns CPU tensors ``(latents, masks, encoder_hidden_states)`` plus
        the per-sample ids and processed prompts.
        """
        import numpy as np

        images = []
        masks = []
        prompts = []
        image_ids = []

        for idx, ex in enumerate(examples):
            # Convert PIL image to tensor
            img = ex['image'].convert('RGB')
            img = torch.tensor(np.array(img)).permute(2, 0, 1).float() / 255.0
            img = img * 2.0 - 1.0  # Normalize to [-1, 1]
            images.append(img)

            # Conditionally load mask
            if config.use_masks and config.mask_column in ex:
                # Mask (0=ignore, 255=keep) -> convert to [0, 1]
                mask = ex[config.mask_column].convert('L')
                mask = torch.tensor(np.array(mask)).float() / 255.0
                masks.append(mask)

            # Preprocess caption with config
            raw_text = ex['text']
            processed_prompt = preprocess_caption(raw_text, config)
            prompts.append(processed_prompt)
            # NOTE(review): this is the position inside the batch, not a
            # dataset-level identifier - confirm that is what the log needs.
            image_ids.append(idx)

        images = torch.stack(images).to(device)

        # Encode images with VAE
        latents = vae.encode(images).latent_dist.sample()
        latents = latents * config.vae_scale

        # Conditionally process masks
        if config.use_masks and masks:
            masks = torch.stack(masks).to(device)
            # Downsample masks from image resolution to the latent's spatial
            # size so they can weight the per-position loss.
            masks_downsampled = F.interpolate(
                masks.unsqueeze(1),
                size=latents.shape[-2:],
                mode='nearest',
            ).squeeze(1)
        else:
            # Create dummy masks (all ones) for consistent batch structure
            masks_downsampled = torch.ones(
                (latents.shape[0], latents.shape[2], latents.shape[3]),
                dtype=torch.float32,
            )

        # Encode prompts with CLIP
        text_inputs = tokenizer(
            prompts,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        ).to(device)

        encoder_hidden_states = text_encoder(text_inputs.input_ids)[0]

        return (
            latents.cpu(),
            masks_downsampled.cpu(),
            encoder_hidden_states.cpu(),
            image_ids,
            prompts,
        )

    train_dataloader = DataLoader(
        dataset=train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=config.num_workers,
        pin_memory=True,
    )

    # Load student UNet
    print("\nLoading model from HuggingFace...")
    unet, checkpoint = load_student_unet(config.model_repo, config.checkpoint_filename, device=device)
    unet.requires_grad_(True)
    unet.train()

    # Fresh optimizer
    optimizer = torch.optim.AdamW(
        unet.parameters(),
        lr=config.base_lr,
        betas=(0.9, 0.999),
        weight_decay=0.01,
        eps=1e-8,
    )

    # Warmup scheduler (constant LR when continuing a previous run)
    if config.continue_from_checkpoint:
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda step: 1.0,
        )
    else:
        def get_lr_scale(step):
            # Linear ramp from 0 to 1 over warmup_steps, then flat.
            if step < warmup_steps:
                return step / warmup_steps
            return 1.0

        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=get_lr_scale,
        )

    # Optionally continue from checkpoint
    start_step = 0

    if config.continue_from_checkpoint:
        if "opt" in checkpoint and "scheduler" in checkpoint:
            optimizer.load_state_dict(checkpoint["opt"])
            scheduler.load_state_dict(checkpoint["scheduler"])
            start_step = checkpoint.get("gstep", 0)
            print(f"✓ Resumed optimizer and scheduler from step {start_step}")
            print(f" Will train for {config.num_train_epochs} more epoch(s) = {total_steps:,} additional steps")
        else:
            print("⚠ No optimizer/scheduler state in checkpoint, starting fresh")
    else:
        print("✓ Starting with fresh optimizer (no state loaded)")

    global_step = start_step
    end_step = start_step + total_steps
    train_logs = {
        "train_step": [],
        "train_loss": [],
        "train_timestep": [],
        "trained_images": [],
    }

    def get_prediction(batch, log_to=None):
        """Run one forward pass and return the scalar training loss.

        When ``log_to`` is given, per-sample loss, timestep and prompt are
        appended to its lists.
        """
        latents, masks, encoder_hidden_states, ids, prompts = batch

        latents = latents.to(dtype=torch.float32, device=device)
        if config.use_masks:
            masks = masks.to(dtype=torch.float32, device=device)
        encoder_hidden_states = encoder_hidden_states.to(dtype=torch.float32, device=device)

        batch_size = latents.shape[0]

        # Apply dropout for CFG support
        dropout_mask = torch.rand(batch_size, device=device) < config.dropout
        encoder_hidden_states = encoder_hidden_states.clone()
        encoder_hidden_states[dropout_mask] = 0

        # Sample sigmas uniformly in [min_timestep, max_timestep] / 1000
        min_sigma = config.min_timestep / 1000.0
        max_sigma = config.max_timestep / 1000.0

        sigmas = torch.rand(batch_size, device=device)
        sigmas = min_sigma + sigmas * (max_sigma - min_sigma)

        # Apply shift transformation.
        # NOTE: the shift is applied after the range constraint, so for
        # shift != 1 the timesteps fed to the UNet fall outside
        # [min_timestep, max_timestep] (higher when shift > 1).
        sigmas = (config.shift * sigmas) / (1 + (config.shift - 1) * sigmas)
        timesteps = sigmas * 1000
        # Keep the flat [B] view for the SNR weights. Recovering it later
        # with squeeze() would collapse a batch of one to a 0-d tensor and
        # break the [:, None, None] indexing below.
        sigmas_flat = sigmas
        sigmas = sigmas[:, None, None, None]

        # Flow matching
        noise = torch.randn_like(latents)
        noisy_latents = noise * sigmas + latents * (1 - sigmas)
        target = noise - latents

        # Predict velocity (standard 4-channel input)
        pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]

        # Calculate loss
        loss = F.mse_loss(pred, target, reduction="none")
        loss = loss.mean(dim=1)  # Average over channels: [B, H, W]

        # Apply Min-SNR weighting for velocity prediction
        # SNR = (1 - sigma)^2 / sigma^2
        snr = ((1 - sigmas_flat) ** 2) / (sigmas_flat ** 2 + 1e-8)
        snr_weight = torch.minimum(snr, torch.ones_like(snr) * config.min_snr_gamma) / snr

        # Velocity prediction adjustment: divide by (SNR + 1)
        snr_weight = snr_weight / (snr + 1)
        snr_weight = snr_weight[:, None, None]  # [B, 1, 1] for broadcasting

        loss = loss * snr_weight  # Apply SNR weighting

        # Conditionally apply mask
        if config.use_masks:
            # Apply mask: only compute loss on non-masked regions
            # masks: [B, H, W] with 1=keep, 0=ignore
            masked_loss = loss * masks

            # Average over spatial dimensions, weighted by mask
            loss_per_sample = masked_loss.sum(dim=[1, 2]) / (masks.sum(dim=[1, 2]) + 1e-8)
        else:
            # Standard spatial average
            loss_per_sample = loss.mean(dim=[1, 2])

        if log_to is not None:
            for i in range(batch_size):
                log_to["train_step"].append(global_step)
                log_to["train_loss"].append(loss_per_sample[i].item())
                log_to["train_timestep"].append(timesteps[i].item())
                log_to["trained_images"].append({
                    "step": global_step,
                    "id": ids[i],
                    "prompt": prompts[i],
                })

        return loss_per_sample.mean()

    def plot_logs(log_dict):
        """Scatter per-sample loss against timestep, coloured by step."""
        plt.figure(figsize=(10, 6))
        plt.scatter(
            log_dict["train_timestep"],
            log_dict["train_loss"],
            s=3,
            c=log_dict["train_step"],
            marker=".",
            cmap='cool',
        )
        plt.xlabel("timestep")
        plt.ylabel("loss")
        plt.yscale("log")
        plt.colorbar(label="step")

    def save_checkpoint(step, relative_epoch):
        """Write UNet weights, the full training state and the sample log; upload if enabled."""
        checkpoint_path = os.path.join(real_output_dir, f"{config.run_name}_checkpoint-{step:08}")
        os.makedirs(checkpoint_path, exist_ok=True)

        # Save UNet weights as diffusers format
        unet.save_pretrained(
            os.path.join(checkpoint_path, "unet"),
            safe_serialization=True,
        )

        # Save complete checkpoint
        pt_filename = f"sd15_flow_{config.run_name}_s{step}.pt"
        pt_path = os.path.join(checkpoint_path, pt_filename)

        torch.save({
            "cfg": asdict(config),
            "student": unet.state_dict(),
            "opt": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "gstep": step,
            "relative_epoch": relative_epoch,
        }, pt_path)

        # Save metadata
        metadata = {
            "step": step,
            "relative_epoch": relative_epoch,
            "trained_images": train_logs["trained_images"],
        }
        metadata_path = os.path.join(checkpoint_path, "trained_images.json")
        with open(metadata_path, "w") as f:
            json.dump(metadata, f, indent=2)

        print(f"✓ Checkpoint saved at step {step} (relative epoch {relative_epoch})")

        # Upload to hub (best effort)
        if config.upload_to_hub and hf_api is not None:
            try:
                hf_api.upload_file(
                    path_or_fileobj=pt_path,
                    path_in_repo=pt_filename,
                    repo_id=config.hf_repo_id,
                    repo_type="model",
                )
                hf_api.upload_folder(
                    folder_path=os.path.join(checkpoint_path, "unet"),
                    path_in_repo=f"{config.run_name}/checkpoint-{step:08}/unet",
                    repo_id=config.hf_repo_id,
                    repo_type="model",
                )
                hf_api.upload_file(
                    path_or_fileobj=metadata_path,
                    path_in_repo=f"{config.run_name}/checkpoint-{step:08}/trained_images.json",
                    repo_id=config.hf_repo_id,
                    repo_type="model",
                )
                print(f"✓ Uploaded to hub: {config.hf_repo_id}")
            except Exception as e:
                print(f"⚠ Upload failed: {e}")

    print("\nStarting training...")
    progress_bar = tqdm(total=total_steps, initial=0)

    epoch = 0
    last_saved_step = None  # avoids writing the same step twice at epoch end
    try:
        while global_step < end_step:
            epoch += 1
            for batch in train_dataloader:
                if global_step >= end_step:
                    break

                loss = get_prediction(batch, log_to=train_logs)
                loss_value = loss.item()
                t_writer.add_scalar("train/loss", loss_value, global_step)
                t_writer.add_scalar("train/lr", scheduler.get_last_lr()[0], global_step)

                # Log timestep distribution of THIS batch; the final batch of
                # an epoch can be smaller than config.batch_size.
                recent_timesteps = train_logs["train_timestep"][-batch[0].shape[0]:]
                if recent_timesteps:
                    t_writer.add_scalar("train/mean_timestep", sum(recent_timesteps) / len(recent_timesteps), global_step)
                    t_writer.add_scalar("train/min_timestep", min(recent_timesteps), global_step)
                    t_writer.add_scalar("train/max_timestep", max(recent_timesteps), global_step)

                loss.backward()

                grad_norm = torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
                t_writer.add_scalar("train/grad_norm", grad_norm.item(), global_step)

                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

                progress_bar.update(1)
                progress_bar.set_postfix({
                    "epoch": epoch,
                    "loss": f"{loss_value:.4f}",
                    "lr": f"{scheduler.get_last_lr()[0]:.2e}",
                    "gstep": global_step,
                })
                global_step += 1

                if global_step % 100 == 0:
                    plot_logs(train_logs)
                    t_writer.add_figure("train_loss", plt.gcf(), global_step)
                    plt.close()

                # checkpointing_steps <= 0 disables periodic checkpoints
                # instead of raising ZeroDivisionError.
                if config.checkpointing_steps > 0 and global_step % config.checkpointing_steps == 0:
                    save_checkpoint(global_step, epoch)
                    last_saved_step = global_step

            # End of epoch checkpoint (skipped if this step was just saved)
            if last_saved_step != global_step:
                save_checkpoint(global_step, epoch)
                last_saved_step = global_step
    finally:
        # Flush TensorBoard events and release the progress bar even when
        # training is interrupted.
        progress_bar.close()
        t_writer.close()

    print("\n✅ Training complete!")
623
+
624
+
625
if __name__ == "__main__":
    # Script entry point: run training with the default configuration.
    config = TrainConfig()
    train(config=config)