AlexWortega commited on
Commit
5f30413
·
verified ·
1 Parent(s): 47f8396

Upload eval_utils.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. eval_utils.py +274 -0
eval_utils.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Evaluation and WandB visualization for diffusion models on The Well.
3
+
4
+ Produces:
5
+ - Single-step comparison images: Condition | Ground Truth | Prediction
6
+ - Multi-step rollout videos: GT trajectory vs Predicted trajectory (side-by-side)
7
+ - Per-step MSE metrics for rollout quality analysis
8
+ """
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn.functional as F
12
+ import logging
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ # ---------------------------------------------------------------------------
18
+ # Colormap helpers
19
+ # ---------------------------------------------------------------------------
20
+
21
def _get_colormap(name="RdBu_r"):
    """Return a matplotlib colormap object for *name*.

    matplotlib is imported lazily so this module can be loaded (and the
    non-visualization code used) without paying the import cost, and the
    "Agg" backend is forced because we only rasterize to arrays — no
    display is available on training nodes.
    """
    import matplotlib
    matplotlib.use("Agg")
    try:
        # matplotlib >= 3.5: colormap registry. cm.get_cmap() was
        # deprecated in 3.7 and removed in 3.9, so prefer the registry.
        return matplotlib.colormaps[name]
    except AttributeError:
        # Legacy fallback for old matplotlib versions.
        import matplotlib.cm as cm
        return cm.get_cmap(name)
27
+
28
# Instantiated colormaps, keyed by name (avoids re-resolving per frame).
_CMAP_CACHE = {}


def apply_colormap(field_01, cmap_name="RdBu_r"):
    """Map an [H, W] float array in [0, 1] to an [H, W, 3] uint8 RGB image."""
    try:
        cmap = _CMAP_CACHE[cmap_name]
    except KeyError:
        cmap = _CMAP_CACHE[cmap_name] = _get_colormap(cmap_name)
    rgba = cmap(np.clip(field_01, 0, 1))
    # Drop the alpha channel and scale to 8-bit.
    return (255 * rgba[..., :3]).astype(np.uint8)
36
+
37
+
38
def normalize_for_vis(f, vmin=None, vmax=None):
    """Normalize *f* to [0, 1] using robust (2nd/98th percentile) bounds.

    Bounds not supplied by the caller are derived from the data itself;
    the span is floored at 1e-8 so constant fields do not divide by zero.

    Returns:
        (normalized_array, vmin, vmax)
    """
    lo = np.percentile(f, 2) if vmin is None else vmin
    hi = np.percentile(f, 98) if vmax is None else vmax
    span = max(hi - lo, 1e-8)
    return np.clip((f - lo) / span, 0, 1), lo, hi
45
+
46
+
47
+ # ---------------------------------------------------------------------------
48
+ # Single-step evaluation
49
+ # ---------------------------------------------------------------------------
50
+
51
def _comparison_image(cond, gt, pred, cmap="RdBu_r"):
    """Render Cond | GT | Pred side by side as one [H, W*3+4, 3] uint8 image.

    A single 2–98 percentile color range is computed over all three
    fields together, so the panels share a scale and are directly
    comparable.
    """
    pooled = np.concatenate([cond.flat, gt.flat, pred.flat])
    lo, hi = np.percentile(pooled, [2, 98])

    panels = []
    for field in (cond, gt, pred):
        scaled, _, _ = normalize_for_vis(field, lo, hi)
        panels.append(apply_colormap(scaled, cmap))

    # Light-gray 2px divider between panels.
    divider = np.full((cond.shape[0], 2, 3), 200, dtype=np.uint8)
    return np.concatenate(
        [panels[0], divider, panels[1], divider, panels[2]], axis=1
    )
63
+
64
+
65
@torch.no_grad()
def single_step_eval(model, val_loader, device, n_batches=4, ddim_steps=50):
    """Compute val MSE and generate comparison images.

    Runs DDIM sampling on up to ``n_batches`` validation batches and
    accumulates a sample-weighted mean MSE; Cond | GT | Pred panels are
    rendered for (at most) the first four samples and first four
    channels of the first batch only.

    Returns:
        metrics: dict {'val/mse': float}
        comparisons: list of (image_array, caption_string)
    """
    from data_pipeline import prepare_batch

    model.eval()
    mse_sum = 0.0
    seen = 0
    snapshot = None  # (cond, target, pred) CPU tensors from batch 0

    for batch_idx, batch in enumerate(val_loader):
        if batch_idx >= n_batches:
            break
        cond, target = prepare_batch(batch, device)
        pred = model.sample_ddim(cond, shape=target.shape, steps=ddim_steps)

        batch_size = target.shape[0]
        mse_sum += F.mse_loss(pred, target).item() * batch_size
        seen += batch_size

        if snapshot is None:
            snapshot = (cond[:4].cpu(), target[:4].cpu(), pred[:4].cpu())

    comparisons = []
    if snapshot is not None:
        cond_s, tgt_s, pred_s = snapshot
        for sample in range(cond_s.shape[0]):
            for ch in range(min(cond_s.shape[1], 4)):
                panel = _comparison_image(
                    cond_s[sample, ch].numpy(),
                    tgt_s[sample, ch].numpy(),
                    pred_s[sample, ch].numpy(),
                )
                comparisons.append((panel, f"sample{sample}_ch{ch}"))

    model.train()
    return {"val/mse": mse_sum / max(seen, 1)}, comparisons
107
+
108
+
109
+ # ---------------------------------------------------------------------------
110
+ # Multi-step rollout evaluation (produces WandB video)
111
+ # ---------------------------------------------------------------------------
112
+
113
@torch.no_grad()
def rollout_eval(
    model, rollout_loader, device,
    n_rollout=20, ddim_steps=50, channel=0, cmap="RdBu_r",
):
    """Autoregressive rollout with GT comparison video.

    Creates side-by-side video: Ground Truth | Prediction
    and computes per-step MSE. Only the first sample of the first batch
    from ``rollout_loader`` is used.

    Args:
        model: GaussianDiffusion instance.
        rollout_loader: DataLoader with n_steps_output >= n_rollout.
        device: torch device.
        n_rollout: autoregressive prediction steps.
        ddim_steps: DDIM denoising steps per prediction.
        channel: which field channel to visualize.
        cmap: matplotlib colormap.

    Returns:
        video: [T, 3, H, W_combined] uint8 for wandb.Video (T = n_steps + 1,
            since the shared initial condition is frame 0).
        per_step_mse: list[float] of length min(n_rollout, T_out) — may be
            shorter than n_rollout if the loader provides fewer GT frames.
    """
    model.eval()
    batch = next(iter(rollout_loader))

    # Raw tensors from The Well (channels-last, keep time dim)
    inp = batch["input_fields"][:1]  # [1, Ti, H, W, C]
    out = batch["output_fields"][:1]  # [1, To, H, W, C]

    # Clamp the rollout length to the number of GT frames available.
    T_out = out.shape[1]
    n_steps = min(n_rollout, T_out)
    C = inp.shape[-1]

    # First condition frame → channels-first on device
    x_cond = inp[:, 0].permute(0, 3, 1, 2).float().to(device)  # [1, C, H, W]

    # Ground truth frames (channels-first, CPU)
    gt_frames = [out[:, t].permute(0, 3, 1, 2).float() for t in range(n_steps)]

    # Autoregressive prediction: each step conditions on the previous
    # prediction, so errors compound — exactly what per_step_mse tracks.
    pred_frames = []
    per_step_mse = []
    cond = x_cond

    for t in range(n_steps):
        # eta=0.0 → deterministic DDIM sampling.
        # NOTE(review): shape=cond.shape assumes the model predicts a frame
        # with the same [1, C, H, W] layout as the condition — confirm
        # against GaussianDiffusion.sample_ddim.
        pred = model.sample_ddim(cond, shape=cond.shape, steps=ddim_steps, eta=0.0)
        pred_cpu = pred.cpu()
        pred_frames.append(pred_cpu)

        # MSE on CPU against the GT frame for this step.
        mse_t = F.mse_loss(pred_cpu, gt_frames[t]).item()
        per_step_mse.append(mse_t)

        cond = pred  # feed prediction back as next condition (stays on device)
        if (t + 1) % 5 == 0:
            logger.info(f" rollout step {t+1}/{n_steps}, mse={mse_t:.6f}")

    # --- build video ---
    ch = min(channel, C - 1)  # clamp requested channel to what exists

    # Shared color range across all frames (condition + GT + predictions)
    # so colors are comparable over the whole video.
    all_vals = [x_cond[0, ch].cpu().numpy().flat]
    for t in range(n_steps):
        all_vals.append(gt_frames[t][0, ch].numpy().flat)
        all_vals.append(pred_frames[t][0, ch].numpy().flat)
    all_vals = np.concatenate(list(all_vals))
    vmin, vmax = np.percentile(all_vals, 2), np.percentile(all_vals, 98)

    def to_rgb(field_2d):
        # [H, W] float → [H, W, 3] uint8 using the shared vmin/vmax.
        n, _, _ = normalize_for_vis(field_2d, vmin, vmax)
        return apply_colormap(n, cmap)

    H, W = x_cond.shape[2], x_cond.shape[3]
    sep = np.full((H, 4, 3), 200, dtype=np.uint8)  # 4px gray divider

    # Compose one video frame: GT panel | separator | prediction panel.
    def _label_frame(gt_rgb, pred_rgb):
        """Concatenate with separator."""
        return np.concatenate([gt_rgb, sep, pred_rgb], axis=1)

    frames = []

    # Frame 0: initial condition (same for both panels)
    init_rgb = to_rgb(x_cond[0, ch].cpu().numpy())
    # transpose(2, 0, 1): HWC → CHW, the layout wandb.Video expects.
    frames.append(_label_frame(init_rgb, init_rgb).transpose(2, 0, 1))

    # Frames 1..N
    for t in range(n_steps):
        gt_rgb = to_rgb(gt_frames[t][0, ch].numpy())
        pred_rgb = to_rgb(pred_frames[t][0, ch].numpy())
        frames.append(_label_frame(gt_rgb, pred_rgb).transpose(2, 0, 1))

    video = np.stack(frames).astype(np.uint8)  # [T, 3, H, W_combined]

    model.train()
    return video, per_step_mse
209
+
210
+
211
+ # ---------------------------------------------------------------------------
212
+ # Full evaluation entry point
213
+ # ---------------------------------------------------------------------------
214
+
215
def run_evaluation(
    model, val_loader, rollout_loader, device,
    global_step, wandb_run=None,
    n_val_batches=4, n_rollout=20, ddim_steps=50,
):
    """Run full evaluation: single-step metrics + rollout video.

    Logs everything to WandB if wandb_run is provided.

    Args:
        model: diffusion model exposing ``sample_ddim`` (see rollout_eval).
        val_loader: DataLoader for single-step evaluation.
        rollout_loader: DataLoader for the autoregressive rollout.
        device: torch device.
        global_step: training step used as the WandB log step.
        wandb_run: active WandB run, or None to skip logging.
        n_val_batches: validation batches for single-step eval.
        n_rollout: autoregressive rollout length.
        ddim_steps: DDIM denoising steps per sample.

    Returns:
        dict of all metrics.
    """
    logger.info("Running single-step evaluation...")
    metrics, comparisons = single_step_eval(
        model, val_loader, device, n_batches=n_val_batches, ddim_steps=ddim_steps
    )
    logger.info(f" val/mse = {metrics['val/mse']:.6f}")

    logger.info(f"Running {n_rollout}-step rollout evaluation...")
    video, rollout_mse = rollout_eval(
        model, rollout_loader, device, n_rollout=n_rollout, ddim_steps=ddim_steps
    )
    logger.info(f" rollout MSE (step 1/last): {rollout_mse[0]:.6f} / {rollout_mse[-1]:.6f}")

    # Aggregate rollout metrics (mean over all steps + final step, plus one
    # scalar per step for step-wise curves in the dashboard).
    metrics["val/rollout_mse_mean"] = float(np.mean(rollout_mse))
    metrics["val/rollout_mse_final"] = rollout_mse[-1]
    for t, m in enumerate(rollout_mse):
        metrics[f"val/rollout_mse_step{t}"] = m

    # WandB logging
    if wandb_run is not None:
        # Imported lazily so wandb is only required when logging is enabled.
        import wandb

        wandb_run.log(metrics, step=global_step)

        # Comparison images (Cond | GT | Pred) — cap at 8 to bound upload size.
        for img, caption in comparisons[:8]:
            wandb_run.log(
                {f"eval/{caption}": wandb.Image(img, caption="Cond | GT | Pred")},
                step=global_step,
            )

        # Rollout video (GT | Pred side-by-side)
        wandb_run.log(
            {"eval/rollout_video": wandb.Video(video, fps=4, format="mp4",
                                               caption="Left=GT Right=Prediction")},
            step=global_step,
        )

        # Rollout MSE curve as a custom chart
        table = wandb.Table(columns=["step", "mse"], data=[[t, m] for t, m in enumerate(rollout_mse)])
        wandb_run.log(
            {"eval/rollout_mse_curve": wandb.plot.line(
                table, "step", "mse", title="Rollout MSE vs Step"
            )},
            step=global_step,
        )

    return metrics