File size: 9,228 Bytes
5f30413 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 |
"""
Evaluation and WandB visualization for diffusion models on The Well.
Produces:
- Single-step comparison images: Condition | Ground Truth | Prediction
- Multi-step rollout videos: GT trajectory vs Predicted trajectory (side-by-side)
- Per-step MSE metrics for rollout quality analysis
"""
import numpy as np
import torch
import torch.nn.functional as F
import logging
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Colormap helpers
# ---------------------------------------------------------------------------
def _get_colormap(name="RdBu_r"):
    """Return a matplotlib colormap object for *name*.

    Imports matplotlib lazily (and forces the non-GUI Agg backend) so the
    module stays importable on headless machines.

    Note: ``matplotlib.cm.get_cmap`` was deprecated in matplotlib 3.7 and
    removed in 3.9, so prefer the ``matplotlib.colormaps`` registry and only
    fall back to the legacy accessor on old matplotlib versions.
    """
    import matplotlib
    matplotlib.use("Agg")
    try:
        # matplotlib >= 3.5: registry-based lookup (the supported API).
        return matplotlib.colormaps[name]
    except AttributeError:
        # matplotlib < 3.5 fallback: legacy cm.get_cmap accessor.
        import matplotlib.cm as cm
        return cm.get_cmap(name)
# Cache of colormap objects, keyed by name, so matplotlib is only consulted
# once per colormap.
_CMAP_CACHE = {}


def apply_colormap(field_01, cmap_name="RdBu_r"):
    """Map a [H, W] float array in [0, 1] to a [H, W, 3] uint8 RGB image."""
    try:
        cmap = _CMAP_CACHE[cmap_name]
    except KeyError:
        cmap = _CMAP_CACHE[cmap_name] = _get_colormap(cmap_name)
    rgba = cmap(np.clip(field_01, 0, 1))
    # Drop the alpha channel and scale to 8-bit.
    return (rgba[:, :, :3] * 255).astype(np.uint8)
def normalize_for_vis(f, vmin=None, vmax=None):
    """Rescale *f* to [0, 1] using robust (2nd/98th percentile) bounds.

    Bounds may be supplied explicitly so several fields can share one color
    range. Returns ``(normalized, vmin, vmax)`` for reuse by callers.
    """
    lo = np.percentile(f, 2) if vmin is None else vmin
    hi = np.percentile(f, 98) if vmax is None else vmax
    # Guard the denominator so a flat field maps to 0 instead of dividing by 0.
    span = max(hi - lo, 1e-8)
    return np.clip((f - lo) / span, 0, 1), lo, hi
# ---------------------------------------------------------------------------
# Single-step evaluation
# ---------------------------------------------------------------------------
def _comparison_image(cond, gt, pred, cmap="RdBu_r"):
    """Render Cond | GT | Pred side by side as one [H, W*3+4, 3] uint8 image.

    A single 2-98 percentile color range is computed over all three fields so
    the panels are directly comparable.
    """
    pooled = np.concatenate([cond.flat, gt.flat, pred.flat])
    lo = np.percentile(pooled, 2)
    hi = np.percentile(pooled, 98)

    def _panel(field):
        scaled, _, _ = normalize_for_vis(field, lo, hi)
        return apply_colormap(scaled, cmap)

    height = cond.shape[0]
    # Light-gray 2px separator between panels.
    divider = np.full((height, 2, 3), 200, dtype=np.uint8)
    parts = [_panel(cond), divider, _panel(gt), divider, _panel(pred)]
    return np.concatenate(parts, axis=1)
@torch.no_grad()
def single_step_eval(model, val_loader, device, n_batches=4, ddim_steps=50):
    """Compute single-step validation MSE and build comparison images.

    Args:
        model: diffusion model exposing ``sample_ddim`` and eval/train modes.
        val_loader: validation DataLoader.
        device: torch device for sampling.
        n_batches: maximum number of validation batches to evaluate.
        ddim_steps: DDIM denoising steps per prediction.

    Returns:
        metrics: dict {'val/mse': float}
        comparisons: list of (image_array, caption_string)
    """
    from data_pipeline import prepare_batch

    model.eval()
    mse_sum = 0.0
    count = 0
    saved = None  # tensors kept from the first batch for visualization
    for batch_idx, batch in enumerate(val_loader):
        if batch_idx >= n_batches:
            break
        x_cond, x_target = prepare_batch(batch, device)
        x_pred = model.sample_ddim(x_cond, shape=x_target.shape, steps=ddim_steps)
        batch_size = x_target.shape[0]
        # Weight the batch-mean MSE by batch size for a correct sample average.
        mse_sum += F.mse_loss(x_pred, x_target).item() * batch_size
        count += batch_size
        if batch_idx == 0:
            saved = (x_cond[:4].cpu(), x_target[:4].cpu(), x_pred[:4].cpu())

    comparisons = []
    if saved is not None:
        cond_t, tgt_t, pred_t = saved
        for b in range(cond_t.shape[0]):
            # Visualize at most 4 channels per sample.
            for ch in range(min(cond_t.shape[1], 4)):
                image = _comparison_image(
                    cond_t[b, ch].numpy(), tgt_t[b, ch].numpy(), pred_t[b, ch].numpy()
                )
                comparisons.append((image, f"sample{b}_ch{ch}"))
    model.train()
    return {"val/mse": mse_sum / max(count, 1)}, comparisons
# ---------------------------------------------------------------------------
# Multi-step rollout evaluation (produces WandB video)
# ---------------------------------------------------------------------------
@torch.no_grad()
def rollout_eval(
    model, rollout_loader, device,
    n_rollout=20, ddim_steps=50, channel=0, cmap="RdBu_r",
):
    """Autoregressive rollout with a ground-truth comparison video.

    Each prediction is fed back as the next condition. Produces a
    side-by-side (Ground Truth | Prediction) video and per-step MSE.

    Args:
        model: GaussianDiffusion instance.
        rollout_loader: DataLoader with n_steps_output >= n_rollout.
        device: torch device.
        n_rollout: autoregressive prediction steps.
        ddim_steps: DDIM denoising steps per prediction.
        channel: which field channel to visualize.
        cmap: matplotlib colormap.

    Returns:
        video: [T, 3, H, W_combined] uint8 for wandb.Video.
        per_step_mse: list[float] of length min(n_rollout, available GT steps).
    """
    model.eval()
    batch = next(iter(rollout_loader))

    # The Well stores fields channels-last; keep only the first trajectory.
    inputs = batch["input_fields"][:1]    # [1, Ti, H, W, C]
    targets = batch["output_fields"][:1]  # [1, To, H, W, C]
    n_steps = min(n_rollout, targets.shape[1])
    n_channels = inputs.shape[-1]

    # Initial condition frame → channels-first on device: [1, C, H, W]
    cond0 = inputs[:, 0].permute(0, 3, 1, 2).float().to(device)
    # Ground-truth frames, channels-first, kept on CPU.
    gt = [targets[:, t].permute(0, 3, 1, 2).float() for t in range(n_steps)]

    # --- autoregressive prediction loop ---
    preds = []
    per_step_mse = []
    cond = cond0
    for step in range(n_steps):
        sampled = model.sample_ddim(cond, shape=cond.shape, steps=ddim_steps, eta=0.0)
        on_cpu = sampled.cpu()
        preds.append(on_cpu)
        per_step_mse.append(F.mse_loss(on_cpu, gt[step]).item())
        cond = sampled  # prediction becomes the next condition
        if (step + 1) % 5 == 0:
            logger.info(f" rollout step {step+1}/{n_steps}, mse={per_step_mse[-1]:.6f}")

    # --- render the comparison video ---
    ch = min(channel, n_channels - 1)

    # One shared color range across the condition and every GT/pred frame.
    pooled = [cond0[0, ch].cpu().numpy().flat]
    for step in range(n_steps):
        pooled.append(gt[step][0, ch].numpy().flat)
        pooled.append(preds[step][0, ch].numpy().flat)
    pooled = np.concatenate(list(pooled))
    vmin = np.percentile(pooled, 2)
    vmax = np.percentile(pooled, 98)

    def _render(field_2d):
        scaled, _, _ = normalize_for_vis(field_2d, vmin, vmax)
        return apply_colormap(scaled, cmap)

    height = cond0.shape[2]
    divider = np.full((height, 4, 3), 200, dtype=np.uint8)

    def _side_by_side(left_rgb, right_rgb):
        # GT panel on the left, prediction on the right, gray separator between.
        return np.concatenate([left_rgb, divider, right_rgb], axis=1)

    frames = []
    # Frame 0 shows the shared initial condition in both panels.
    first = _render(cond0[0, ch].cpu().numpy())
    frames.append(_side_by_side(first, first).transpose(2, 0, 1))
    # Frames 1..N: ground truth vs prediction at each rollout step.
    for step in range(n_steps):
        left = _render(gt[step][0, ch].numpy())
        right = _render(preds[step][0, ch].numpy())
        frames.append(_side_by_side(left, right).transpose(2, 0, 1))

    video = np.stack(frames).astype(np.uint8)  # [T, 3, H, W_combined]
    model.train()
    return video, per_step_mse
# ---------------------------------------------------------------------------
# Full evaluation entry point
# ---------------------------------------------------------------------------
def run_evaluation(
    model, val_loader, rollout_loader, device,
    global_step, wandb_run=None,
    n_val_batches=4, n_rollout=20, ddim_steps=50,
):
    """Run the full evaluation: single-step metrics plus rollout video.

    Everything is logged to WandB when *wandb_run* is provided.

    Returns:
        dict of all metrics.
    """
    logger.info("Running single-step evaluation...")
    metrics, comparisons = single_step_eval(
        model, val_loader, device, n_batches=n_val_batches, ddim_steps=ddim_steps
    )
    logger.info(f" val/mse = {metrics['val/mse']:.6f}")

    logger.info(f"Running {n_rollout}-step rollout evaluation...")
    video, rollout_mse = rollout_eval(
        model, rollout_loader, device, n_rollout=n_rollout, ddim_steps=ddim_steps
    )
    logger.info(f" rollout MSE (step 1/last): {rollout_mse[0]:.6f} / {rollout_mse[-1]:.6f}")

    # Fold per-step and aggregate rollout numbers into the metrics dict.
    metrics["val/rollout_mse_mean"] = float(np.mean(rollout_mse))
    metrics["val/rollout_mse_final"] = rollout_mse[-1]
    metrics.update({f"val/rollout_mse_step{t}": m for t, m in enumerate(rollout_mse)})

    if wandb_run is not None:
        import wandb

        wandb_run.log(metrics, step=global_step)

        # Single-step comparison images (Cond | GT | Pred); cap at 8.
        for img, caption in comparisons[:8]:
            payload = {f"eval/{caption}": wandb.Image(img, caption="Cond | GT | Pred")}
            wandb_run.log(payload, step=global_step)

        # Side-by-side rollout video (GT | Pred).
        rollout_video = wandb.Video(video, fps=4, format="mp4",
                                    caption="Left=GT Right=Prediction")
        wandb_run.log({"eval/rollout_video": rollout_video}, step=global_step)

        # Rollout MSE vs step as a custom line chart.
        rows = [[t, m] for t, m in enumerate(rollout_mse)]
        table = wandb.Table(columns=["step", "mse"], data=rows)
        curve = wandb.plot.line(table, "step", "mse", title="Rollout MSE vs Step")
        wandb_run.log({"eval/rollout_mse_curve": curve}, step=global_step)

    return metrics
|