"""Evaluation and WandB visualization for diffusion models on The Well.

Produces:
- Single-step comparison images: Condition | Ground Truth | Prediction
- Multi-step rollout videos: GT trajectory vs Predicted trajectory (side-by-side)
- Per-step MSE metrics for rollout quality analysis
"""

import logging

import numpy as np
import torch
import torch.nn.functional as F

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Colormap helpers
# ---------------------------------------------------------------------------

def _get_colormap(name="RdBu_r"):
    """Return a matplotlib colormap object (lazy import, headless backend).

    Uses the modern ``matplotlib.colormaps`` registry (3.5+); falls back to
    ``cm.get_cmap`` on old installs.  ``cm.get_cmap`` was deprecated in
    matplotlib 3.7 and removed in 3.9, so the registry path is required on
    current matplotlib.
    """
    import matplotlib

    matplotlib.use("Agg")  # no display on training nodes
    try:
        return matplotlib.colormaps[name]
    except AttributeError:  # matplotlib < 3.5
        import matplotlib.cm as cm
        return cm.get_cmap(name)


# Cache of colormap objects keyed by name, to avoid repeated registry lookups.
_CMAP_CACHE = {}


def apply_colormap(field_01, cmap_name="RdBu_r"):
    """Map a [H, W] float array in [0, 1] to a [H, W, 3] uint8 RGB image."""
    if cmap_name not in _CMAP_CACHE:
        _CMAP_CACHE[cmap_name] = _get_colormap(cmap_name)
    rgba = _CMAP_CACHE[cmap_name](np.clip(field_01, 0, 1))
    return (rgba[:, :, :3] * 255).astype(np.uint8)


def normalize_for_vis(f, vmin=None, vmax=None):
    """Percentile-robust normalization of *f* to [0, 1].

    Args:
        f: numpy array of arbitrary shape.
        vmin, vmax: optional fixed range; defaults to the 2nd/98th
            percentiles of *f* (robust to outliers).

    Returns:
        (normalized_array, vmin, vmax) — the range is returned so callers
        can reuse it across related frames for a consistent color scale.
    """
    if vmin is None:
        vmin = np.percentile(f, 2)
    if vmax is None:
        vmax = np.percentile(f, 98)
    # max(..., 1e-8) guards against divide-by-zero on constant fields.
    return np.clip((f - vmin) / max(vmax - vmin, 1e-8), 0, 1), vmin, vmax


# ---------------------------------------------------------------------------
# Single-step evaluation
# ---------------------------------------------------------------------------

def _comparison_image(cond, gt, pred, cmap="RdBu_r"):
    """Build a [H, W*3+4, 3] uint8 image: Cond | GT | Pred.

    All three panels share one color range (computed from the pooled
    values) so they are directly comparable.
    """
    vals = np.concatenate([cond.ravel(), gt.ravel(), pred.ravel()])
    vmin, vmax = np.percentile(vals, 2), np.percentile(vals, 98)

    def rgb(f):
        n, _, _ = normalize_for_vis(f, vmin, vmax)
        return apply_colormap(n, cmap)

    H = cond.shape[0]
    sep = np.full((H, 2, 3), 200, dtype=np.uint8)  # light-gray separator
    return np.concatenate([rgb(cond), sep, rgb(gt), sep, rgb(pred)], axis=1)


@torch.no_grad()
def single_step_eval(model, val_loader, device, n_batches=4, ddim_steps=50):
    """Compute validation MSE and generate comparison images.

    Args:
        model: GaussianDiffusion instance exposing ``sample_ddim``.
        val_loader: validation DataLoader.
        device: torch device.
        n_batches: number of validation batches to evaluate.
        ddim_steps: DDIM denoising steps per sample.

    Returns:
        metrics: dict {'val/mse': float}.
        comparisons: list of (image_array, caption_string) — up to 4
            samples x 4 channels from the first batch.
    """
    from data_pipeline import prepare_batch

    was_training = model.training  # restore the caller's mode afterwards
    model.eval()
    total_mse, n_samples = 0.0, 0
    first_data = None

    for i, batch in enumerate(val_loader):
        if i >= n_batches:
            break
        x_cond, x_target = prepare_batch(batch, device)
        x_pred = model.sample_ddim(x_cond, shape=x_target.shape, steps=ddim_steps)

        mse = F.mse_loss(x_pred, x_target).item()
        total_mse += mse * x_target.shape[0]
        n_samples += x_target.shape[0]

        if i == 0:
            first_data = (x_cond[:4].cpu(), x_target[:4].cpu(), x_pred[:4].cpu())

    avg_mse = total_mse / max(n_samples, 1)

    comparisons = []
    if first_data is not None:
        xc, xt, xp = first_data
        n_ch = min(xc.shape[1], 4)
        for b in range(xc.shape[0]):
            for ch in range(n_ch):
                img = _comparison_image(
                    xc[b, ch].numpy(), xt[b, ch].numpy(), xp[b, ch].numpy()
                )
                comparisons.append((img, f"sample{b}_ch{ch}"))

    if was_training:
        model.train()
    return {"val/mse": avg_mse}, comparisons


# ---------------------------------------------------------------------------
# Multi-step rollout evaluation (produces WandB video)
# ---------------------------------------------------------------------------

@torch.no_grad()
def rollout_eval(
    model,
    rollout_loader,
    device,
    n_rollout=20,
    ddim_steps=50,
    channel=0,
    cmap="RdBu_r",
):
    """Autoregressive rollout with GT comparison video.

    Creates a side-by-side video (Ground Truth | Prediction) and computes
    per-step MSE.

    Args:
        model: GaussianDiffusion instance.
        rollout_loader: DataLoader with n_steps_output >= n_rollout.
        device: torch device.
        n_rollout: autoregressive prediction steps.
        ddim_steps: DDIM denoising steps per prediction.
        channel: which field channel to visualize.
        cmap: matplotlib colormap.

    Returns:
        video: [T, 3, H, W_combined] uint8 for wandb.Video.
        per_step_mse: list[float] of length min(n_rollout, n_steps_output).
    """
    was_training = model.training
    model.eval()
    batch = next(iter(rollout_loader))

    # Raw tensors from The Well (channels-last, time dimension kept).
    inp = batch["input_fields"][:1]   # [1, Ti, H, W, C]
    out = batch["output_fields"][:1]  # [1, To, H, W, C]

    n_steps = min(n_rollout, out.shape[1])
    C = inp.shape[-1]

    # Seed with the LAST input frame: output step 0 immediately follows the
    # final input frame, so inp[:, -1] is the correct rollout condition.
    # (Equivalent to the former inp[:, 0] only when Ti == 1; for Ti > 1 the
    # first frame would misalign predictions with the GT targets below.)
    x_cond = inp[:, -1].permute(0, 3, 1, 2).float().to(device)  # [1, C, H, W]

    # Ground-truth frames (channels-first, kept on CPU).
    gt_frames = [out[:, t].permute(0, 3, 1, 2).float() for t in range(n_steps)]

    # Autoregressive prediction: each output is fed back as the next condition.
    pred_frames, per_step_mse = [], []
    cond = x_cond
    for t in range(n_steps):
        pred = model.sample_ddim(cond, shape=cond.shape, steps=ddim_steps, eta=0.0)
        pred_cpu = pred.cpu()
        pred_frames.append(pred_cpu)
        mse_t = F.mse_loss(pred_cpu, gt_frames[t]).item()
        per_step_mse.append(mse_t)
        cond = pred
        if (t + 1) % 5 == 0:
            logger.info(f"  rollout step {t+1}/{n_steps}, mse={mse_t:.6f}")

    # --- build video ---
    ch = min(channel, C - 1)

    # Shared color range across every frame so the video scale is stable.
    all_vals = [x_cond[0, ch].cpu().numpy().ravel()]
    for t in range(n_steps):
        all_vals.append(gt_frames[t][0, ch].numpy().ravel())
        all_vals.append(pred_frames[t][0, ch].numpy().ravel())
    all_vals = np.concatenate(all_vals)
    vmin, vmax = np.percentile(all_vals, 2), np.percentile(all_vals, 98)

    def to_rgb(field_2d):
        n, _, _ = normalize_for_vis(field_2d, vmin, vmax)
        return apply_colormap(n, cmap)

    H = x_cond.shape[2]
    sep = np.full((H, 4, 3), 200, dtype=np.uint8)  # light-gray separator

    def _side_by_side(gt_rgb, pred_rgb):
        """Concatenate GT and prediction panels with a separator column."""
        return np.concatenate([gt_rgb, sep, pred_rgb], axis=1)

    frames = []
    # Frame 0: initial condition (identical in both panels).
    init_rgb = to_rgb(x_cond[0, ch].cpu().numpy())
    frames.append(_side_by_side(init_rgb, init_rgb).transpose(2, 0, 1))
    # Frames 1..N: GT vs prediction.
    for t in range(n_steps):
        gt_rgb = to_rgb(gt_frames[t][0, ch].numpy())
        pred_rgb = to_rgb(pred_frames[t][0, ch].numpy())
        frames.append(_side_by_side(gt_rgb, pred_rgb).transpose(2, 0, 1))

    video = np.stack(frames).astype(np.uint8)  # [T, 3, H, W_combined]

    if was_training:
        model.train()
    return video, per_step_mse


# ---------------------------------------------------------------------------
# Full evaluation entry point
# ---------------------------------------------------------------------------

def run_evaluation(
    model,
    val_loader,
    rollout_loader,
    device,
    global_step,
    wandb_run=None,
    n_val_batches=4,
    n_rollout=20,
    ddim_steps=50,
):
    """Run full evaluation: single-step metrics + rollout video.

    Logs everything to WandB if *wandb_run* is provided.

    Returns:
        dict of all metrics.
    """
    logger.info("Running single-step evaluation...")
    metrics, comparisons = single_step_eval(
        model, val_loader, device, n_batches=n_val_batches, ddim_steps=ddim_steps
    )
    logger.info(f"  val/mse = {metrics['val/mse']:.6f}")

    logger.info(f"Running {n_rollout}-step rollout evaluation...")
    video, rollout_mse = rollout_eval(
        model, rollout_loader, device, n_rollout=n_rollout, ddim_steps=ddim_steps
    )

    # Guard against an empty rollout (e.g. loader with zero output steps):
    # indexing rollout_mse[0]/[-1] would otherwise raise IndexError.
    if rollout_mse:
        logger.info(
            f"  rollout MSE (step 1/last): {rollout_mse[0]:.6f} / {rollout_mse[-1]:.6f}"
        )
        metrics["val/rollout_mse_mean"] = float(np.mean(rollout_mse))
        metrics["val/rollout_mse_final"] = rollout_mse[-1]
    for t, m in enumerate(rollout_mse):
        metrics[f"val/rollout_mse_step{t}"] = m

    # WandB logging
    if wandb_run is not None:
        import wandb

        wandb_run.log(metrics, step=global_step)

        # Comparison images (Cond | GT | Pred)
        for img, caption in comparisons[:8]:
            wandb_run.log(
                {f"eval/{caption}": wandb.Image(img, caption="Cond | GT | Pred")},
                step=global_step,
            )

        # Rollout video (GT | Pred side-by-side)
        wandb_run.log(
            {"eval/rollout_video": wandb.Video(video, fps=4, format="mp4", caption="Left=GT Right=Prediction")},
            step=global_step,
        )

        # Rollout MSE curve as a custom chart
        table = wandb.Table(columns=["step", "mse"], data=[[t, m] for t, m in enumerate(rollout_mse)])
        wandb_run.log(
            {"eval/rollout_mse_curve": wandb.plot.line(
                table, "step", "mse", title="Rollout MSE vs Step"
            )},
            step=global_step,
        )

    return metrics