dreamlessx committed on
Commit
ef63076
·
verified ·
1 Parent(s): eac09b2

Upload landmarkdiff/validation.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. landmarkdiff/validation.py +231 -0
landmarkdiff/validation.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Validation callback for training loop monitoring.
2
+
3
+ Periodically generates sample images from the validation set, computes
4
+ SSIM and LPIPS against the ground-truth targets, and writes the images
5
+ and metrics to disk.
6
+
7
+ Designed for use with train_controlnet.py — call at regular intervals
8
+ during training to monitor quality without disrupting the training loop.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import json
14
+ import time
15
+ from pathlib import Path
16
+
17
+ import cv2
18
+ import numpy as np
19
+ import torch
20
+ import torch.nn.functional as F
21
+ from PIL import Image
22
+
23
+ from landmarkdiff.evaluation import compute_ssim, compute_lpips, compute_nme
24
+
25
+
26
class ValidationCallback:
    """Generate and score validation samples at intervals during training.

    For each of the first ``num_samples`` items of ``val_dataset`` the callback
    denoises from pure noise with the ControlNet + UNet, decodes the result
    with the VAE, scores it against the item's target with SSIM and LPIPS and
    writes a side-by-side comparison image.  Aggregate metrics are appended to
    ``self.history`` and written to disk as JSON.

    Dataset items are indexed with ``"conditioning"`` and ``"target"`` and both
    values are treated as channel-first image tensors scaled to ``[0, 1]``
    (they are multiplied by 255 and transposed to HWC for saving).

    Note:
        ``guidance_scale`` is stored on the instance but is not applied by
        :meth:`run`; sampling is a single conditional pass using the first
        row of ``text_embeddings`` (no classifier-free guidance).

    Usage::

        val_cb = ValidationCallback(
            val_dataset=val_dataset,
            output_dir=Path("checkpoints/val"),
            num_samples=8,
        )

        # In training loop:
        if global_step % val_every == 0:
            val_metrics = val_cb.run(
                controlnet=ema_controlnet,
                vae=vae,
                unet=unet,
                text_embeddings=text_embeddings,
                noise_scheduler=noise_scheduler,
                device=device,
                weight_dtype=weight_dtype,
                global_step=global_step,
            )
    """

    # Number of generated images per row in the summary grid.
    _GRID_COLUMNS = 4

    def __init__(
        self,
        val_dataset,
        output_dir: Path,
        num_samples: int = 8,
        num_inference_steps: int = 25,
        guidance_scale: float = 7.5,
        seed: int | None = None,
    ):
        """Create the callback and its output directory.

        Args:
            val_dataset: Sized, indexable dataset of validation items.
            output_dir: Directory that receives per-step folders and the
                running ``validation_history.json``; created if missing.
            num_samples: Upper bound on samples generated per run; clamped to
                ``[0, len(val_dataset)]``.
            num_inference_steps: Number of denoising steps per sample.
            guidance_scale: Stored for callers; not used by :meth:`run`.
            seed: Optional base seed.  When given, sample ``i`` always starts
                from noise drawn with ``seed + i`` from a dedicated CPU
                generator, so metrics from different training steps are
                comparable and the global RNG stream is not consumed for the
                initial noise.  ``None`` (default) draws from the global RNG.
        """
        self.val_dataset = val_dataset
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        # Clamp at zero so a negative request cannot leak into the state.
        self.num_samples = max(0, min(num_samples, len(val_dataset)))
        self.num_inference_steps = num_inference_steps
        self.guidance_scale = guidance_scale
        self.seed = seed
        self.history: list[dict] = []

    @torch.no_grad()
    def run(
        self,
        controlnet: torch.nn.Module,
        vae,
        unet,
        text_embeddings: torch.Tensor,
        noise_scheduler,
        device: torch.device,
        weight_dtype: torch.dtype,
        global_step: int,
    ) -> dict:
        """Run validation: generate samples and compute metrics.

        Args:
            controlnet: ControlNet to evaluate.  It is switched to eval mode
                for the duration of the call and restored to whatever mode it
                was in on entry, even if validation raises.
            vae: Autoencoder used to obtain the latent shape and to decode.
            unet: Denoising UNet that receives the ControlNet residuals.
            text_embeddings: Prompt embeddings; only the first row is used.
            noise_scheduler: Training scheduler whose ``config`` seeds the
                DPM-Solver multistep scheduler used for inference.
            device: Device on which sampling runs.
            weight_dtype: Dtype for ControlNet/UNet inputs.
            global_step: Current training step; names the output folder.

        Returns:
            Dict with keys ``step``, ``ssim_mean``, ``ssim_std``,
            ``lpips_mean``, ``lpips_std`` and ``time_seconds``.  The mean/std
            values are NaN when no finite score was collected.
        """
        from diffusers import DPMSolverMultistepScheduler

        t0 = time.perf_counter()

        # Remember the caller's mode: an EMA copy is usually already in eval
        # mode and must not be flipped to train mode afterwards.
        was_training = controlnet.training
        controlnet.eval()
        try:
            step_dir = self.output_dir / f"step-{global_step}"
            step_dir.mkdir(parents=True, exist_ok=True)

            # Inference scheduler (DPM-Solver multistep) built from the
            # training scheduler's config.
            scheduler = DPMSolverMultistepScheduler.from_config(noise_scheduler.config)

            vae_dtype = self._module_dtype(vae, weight_dtype)
            encoder_hidden_states = text_embeddings[:1]

            ssim_scores = []
            lpips_scores = []
            generated_images = []

            for i in range(self.num_samples):
                sample = self.val_dataset[i]
                conditioning = sample["conditioning"].unsqueeze(0).to(device, dtype=weight_dtype)
                target = sample["target"].unsqueeze(0).to(device, dtype=weight_dtype)

                # The target is encoded only to learn the latent shape.
                vae_input = (target * 2 - 1).to(dtype=vae_dtype)
                latent_shape = vae.encode(vae_input).latent_dist.sample().shape

                sample_latents = self._denoise(
                    scheduler=scheduler,
                    controlnet=controlnet,
                    unet=unet,
                    conditioning=conditioning,
                    encoder_hidden_states=encoder_hidden_states,
                    latent_shape=latent_shape,
                    device=device,
                    weight_dtype=weight_dtype,
                    index=i,
                )

                # Decode in the VAE's own parameter dtype so the input always
                # matches its weights (float32 whenever the VAE is kept in
                # full precision).
                decoded = vae.decode(
                    sample_latents.to(dtype=vae_dtype) / vae.config.scaling_factor
                ).sample
                decoded = ((decoded + 1) / 2).clamp(0, 1)

                gen_np = self._to_uint8_hwc(decoded)
                tgt_np = self._to_uint8_hwc(target)
                cond_np = self._to_uint8_hwc(conditioning)

                # BGR for metrics (our metrics expect BGR)
                gen_bgr = gen_np[:, :, ::-1].copy()
                tgt_bgr = tgt_np[:, :, ::-1].copy()

                ssim_scores.append(compute_ssim(gen_bgr, tgt_bgr))
                lpips_scores.append(compute_lpips(gen_bgr, tgt_bgr))
                generated_images.append(gen_np)

                # Save comparison: conditioning | generated | target
                comparison = np.hstack([cond_np, gen_np, tgt_np])
                Image.fromarray(comparison).save(step_dir / f"val_{i:02d}.png")

            ssim_mean, ssim_std = self._nan_stats(ssim_scores)
            lpips_mean, lpips_std = self._nan_stats(lpips_scores)
            metrics = {
                "step": global_step,
                "ssim_mean": ssim_mean,
                "ssim_std": ssim_std,
                "lpips_mean": lpips_mean,
                "lpips_std": lpips_std,
                "time_seconds": round(time.perf_counter() - t0, 1),
            }
            self.history.append(metrics)

            # Per-step metrics plus the full running history.
            with open(step_dir / "metrics.json", "w", encoding="utf-8") as f:
                json.dump(metrics, f, indent=2)
            with open(self.output_dir / "validation_history.json", "w", encoding="utf-8") as f:
                json.dump(self.history, f, indent=2)

            # Comparison grid (all generated samples in one image)
            if generated_images:
                grid = self._make_grid(generated_images)
                Image.fromarray(grid).save(step_dir / "grid.png")
        finally:
            controlnet.train(was_training)

        print(
            f"  Validation @ step {global_step}: "
            f"SSIM={metrics['ssim_mean']:.4f}±{metrics['ssim_std']:.4f} "
            f"LPIPS={metrics['lpips_mean']:.4f}±{metrics['lpips_std']:.4f} "
            f"({metrics['time_seconds']:.1f}s)"
        )

        return metrics

    def _denoise(
        self,
        scheduler,
        controlnet,
        unet,
        conditioning,
        encoder_hidden_states,
        latent_shape,
        device,
        weight_dtype,
        index: int,
    ):
        """Denoise one sample from pure noise and return the final latents."""
        # The multistep solver keeps per-trajectory state (step index and
        # previous model outputs); set_timesteps resets it, so it must be
        # called once per sample rather than once per validation run.
        scheduler.set_timesteps(self.num_inference_steps, device=device)

        noise = self._initial_noise(latent_shape, device, weight_dtype, index)
        latents = noise * scheduler.init_noise_sigma

        for t in scheduler.timesteps:
            scaled = scheduler.scale_model_input(latents, t)

            # ControlNet
            down_samples, mid_sample = controlnet(
                scaled,
                t,
                encoder_hidden_states=encoder_hidden_states,
                controlnet_cond=conditioning,
                return_dict=False,
            )

            # UNet with ControlNet residuals
            noise_pred = unet(
                scaled,
                t,
                encoder_hidden_states=encoder_hidden_states,
                down_block_additional_residuals=down_samples,
                mid_block_additional_residual=mid_sample,
            ).sample

            latents = scheduler.step(noise_pred, t, latents).prev_sample

        return latents

    def _initial_noise(self, latent_shape, device, weight_dtype, index: int):
        """Draw the starting noise for sample ``index`` in ``weight_dtype``."""
        if self.seed is None:
            return torch.randn(latent_shape, device=device, dtype=weight_dtype)
        # A CPU generator yields the same values regardless of the target
        # device and leaves the global (training) RNG stream untouched.
        generator = torch.Generator(device="cpu").manual_seed(self.seed + index)
        noise = torch.randn(latent_shape, generator=generator, dtype=torch.float32)
        return noise.to(device=device, dtype=weight_dtype)

    @staticmethod
    def _module_dtype(module, fallback):
        """Return the dtype of ``module``'s first parameter, else ``fallback``."""
        parameters = getattr(module, "parameters", None)
        if callable(parameters):
            first = next(iter(parameters()), None)
            if first is not None:
                return first.dtype
        return fallback

    @staticmethod
    def _to_uint8_hwc(batch):
        """Convert the first image of a [0, 1] NCHW batch to a uint8 HWC array."""
        return (batch[0].float().permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)

    @staticmethod
    def _nan_stats(values):
        """Return ``(mean, std)`` over non-NaN values; ``(nan, nan)`` if none."""
        arr = np.asarray(values, dtype=np.float64)
        valid = arr[~np.isnan(arr)]
        if valid.size == 0:
            return float("nan"), float("nan")
        return float(valid.mean()), float(valid.std())

    def _make_grid(self, images):
        """Tile equally-sized images row by row, padding the last row with black."""
        columns = self._GRID_COLUMNS
        blank = np.zeros_like(images[0])
        rows = []
        for start in range(0, len(images), columns):
            row = list(images[start:start + columns])
            row.extend([blank] * (columns - len(row)))
            rows.append(np.hstack(row))
        return np.vstack(rows)

    def plot_history(self, output_path: str | None = None) -> None:
        """Plot validation SSIM and LPIPS over training steps.

        Does nothing when no validation run has been recorded or when
        matplotlib is not installed (plotting is best-effort).

        Args:
            output_path: Destination image path; defaults to
                ``validation_curves.png`` inside ``output_dir``.
        """
        if not self.history:
            return

        try:
            import matplotlib
            matplotlib.use("Agg")  # headless backend; no display required
            import matplotlib.pyplot as plt
        except ImportError:
            return

        steps = [h["step"] for h in self.history]
        ssim = [h["ssim_mean"] for h in self.history]
        lpips = [h["lpips_mean"] for h in self.history]

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

        ax1.plot(steps, ssim, "b-o", markersize=4)
        ax1.set_xlabel("Training Step")
        ax1.set_ylabel("SSIM")
        ax1.set_title("Validation SSIM (higher=better)")
        ax1.grid(alpha=0.3)

        ax2.plot(steps, lpips, "r-o", markersize=4)
        ax2.set_xlabel("Training Step")
        ax2.set_ylabel("LPIPS")
        ax2.set_title("Validation LPIPS (lower=better)")
        ax2.grid(alpha=0.3)

        fig.tight_layout()
        path = output_path or str(self.output_dir / "validation_curves.png")
        fig.savefig(path, dpi=150, bbox_inches="tight")
        # Close this specific figure rather than whichever one is current.
        plt.close(fig)