|
|
""" |
|
|
Evaluation and WandB visualization for diffusion models on The Well. |
|
|
|
|
|
Produces: |
|
|
- Single-step comparison images: Condition | Ground Truth | Prediction |
|
|
- Multi-step rollout videos: GT trajectory vs Predicted trajectory (side-by-side) |
|
|
- Per-step MSE metrics for rollout quality analysis |
|
|
""" |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import logging |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_colormap(name="RdBu_r"):
    """Return a matplotlib colormap callable for *name*.

    Forces the non-interactive Agg backend so this works on headless
    machines. ``matplotlib.cm.get_cmap`` was deprecated in matplotlib 3.7
    and removed in 3.9; prefer the ``matplotlib.colormaps`` registry and
    fall back to the legacy accessor for older installs.
    """
    import matplotlib

    matplotlib.use("Agg")
    try:
        # matplotlib >= 3.5: registry-based lookup (the supported API).
        return matplotlib.colormaps[name]
    except AttributeError:
        # Older matplotlib without the registry: legacy accessor.
        import matplotlib.cm as cm

        return cm.get_cmap(name)
|
|
|
|
|
# Cache of colormap callables keyed by colormap name, so matplotlib is only
# imported and configured once per colormap (see _get_colormap).
_CMAP_CACHE = {}
|
|
|
|
|
def apply_colormap(field_01, cmap_name="RdBu_r"):
    """Map a [H, W] float array in [0, 1] to a [H, W, 3] uint8 RGB image."""
    try:
        cmap = _CMAP_CACHE[cmap_name]
    except KeyError:
        cmap = _CMAP_CACHE[cmap_name] = _get_colormap(cmap_name)
    # Clip defensively: values outside [0, 1] would wrap the colormap.
    rgba = cmap(np.clip(field_01, 0, 1))
    # Drop the alpha channel and scale to 8-bit.
    rgb = rgba[..., :3] * 255
    return rgb.astype(np.uint8)
|
|
|
|
|
|
|
|
def normalize_for_vis(f, vmin=None, vmax=None):
    """Scale *f* into [0, 1] using percentile-robust bounds.

    Bounds default to the 2nd/98th percentiles of *f*, which keeps the
    visualization stable in the presence of outliers.

    Returns:
        (normalized array, vmin, vmax) so callers can reuse the bounds.
    """
    lo = np.percentile(f, 2) if vmin is None else vmin
    hi = np.percentile(f, 98) if vmax is None else vmax
    span = max(hi - lo, 1e-8)  # guard against a constant field (hi == lo)
    scaled = (f - lo) / span
    return np.clip(scaled, 0, 1), lo, hi
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _comparison_image(cond, gt, pred, cmap="RdBu_r"):
    """Render Condition | Ground Truth | Prediction as one uint8 image.

    All three panels share a single color scale (2nd/98th percentiles of
    the pooled values) so they are directly comparable. A light-gray
    2-pixel separator sits between panels, giving shape [H, 3*W + 4, 3].
    """
    pooled = np.concatenate([cond.flat, gt.flat, pred.flat])
    lo = np.percentile(pooled, 2)
    hi = np.percentile(pooled, 98)

    def _panel(field):
        normed, _, _ = normalize_for_vis(field, lo, hi)
        return apply_colormap(normed, cmap)

    height = cond.shape[0]
    divider = np.full((height, 2, 3), 200, dtype=np.uint8)
    panels = [_panel(cond), divider, _panel(gt), divider, _panel(pred)]
    return np.concatenate(panels, axis=1)
|
|
|
|
|
|
|
|
@torch.no_grad()
def single_step_eval(model, val_loader, device, n_batches=4, ddim_steps=50):
    """Compute single-step validation MSE and build comparison images.

    Runs DDIM sampling conditioned on each validation batch and compares
    predictions against the ground-truth next state.

    Args:
        model: diffusion model exposing ``sample_ddim(cond, shape, steps)``.
        val_loader: validation DataLoader; batches are converted via
            ``data_pipeline.prepare_batch``.
        device: torch device for inference.
        n_batches: number of validation batches to evaluate.
        ddim_steps: DDIM denoising steps per sample.

    Returns:
        metrics: dict ``{'val/mse': float}`` (sample-weighted mean).
        comparisons: list of ``(image_array, caption)`` tuples built from
            the first batch only (up to 4 samples x 4 channels).
    """
    from data_pipeline import prepare_batch

    # Remember the caller's mode: blindly calling model.train() afterwards
    # would clobber a model that was already in eval mode.
    was_training = model.training
    model.eval()
    total_mse, n_samples = 0.0, 0
    first_data = None

    try:
        for i, batch in enumerate(val_loader):
            if i >= n_batches:
                break
            x_cond, x_target = prepare_batch(batch, device)
            x_pred = model.sample_ddim(x_cond, shape=x_target.shape, steps=ddim_steps)

            # Weight by batch size so the average is per-sample, not per-batch
            # (the last batch may be smaller).
            mse = F.mse_loss(x_pred, x_target).item()
            total_mse += mse * x_target.shape[0]
            n_samples += x_target.shape[0]

            if i == 0:
                # Keep a few samples from the first batch for visualization.
                first_data = (x_cond[:4].cpu(), x_target[:4].cpu(), x_pred[:4].cpu())
    finally:
        # Restore the original mode even if sampling raised.
        if was_training:
            model.train()

    avg_mse = total_mse / max(n_samples, 1)  # guard: empty loader

    comparisons = []
    if first_data is not None:
        xc, xt, xp = first_data
        n_ch = min(xc.shape[1], 4)  # cap channels to keep logging cheap
        for b in range(xc.shape[0]):
            for ch in range(n_ch):
                img = _comparison_image(
                    xc[b, ch].numpy(), xt[b, ch].numpy(), xp[b, ch].numpy()
                )
                comparisons.append((img, f"sample{b}_ch{ch}"))

    return {"val/mse": avg_mse}, comparisons
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def rollout_eval(
    model, rollout_loader, device,
    n_rollout=20, ddim_steps=50, channel=0, cmap="RdBu_r",
):
    """Autoregressive rollout with GT comparison video.

    Creates side-by-side video: Ground Truth | Prediction
    and computes per-step MSE.

    Args:
        model: GaussianDiffusion instance.
        rollout_loader: DataLoader with n_steps_output >= n_rollout.
        device: torch device.
        n_rollout: autoregressive prediction steps.
        ddim_steps: DDIM denoising steps per prediction.
        channel: which field channel to visualize.
        cmap: matplotlib colormap.

    Returns:
        video: [T, 3, H, W_combined] uint8 for wandb.Video.
        per_step_mse: list[float] of length n_rollout.
    """
    model.eval()
    # Only the first batch (and first sample) is used for the rollout.
    batch = next(iter(rollout_loader))

    # Restrict to a single trajectory; assumes The Well layout
    # [B, T, H, W, C] — TODO confirm against the loader.
    inp = batch["input_fields"][:1]
    out = batch["output_fields"][:1]

    T_out = out.shape[1]
    # Cannot roll out further than the available ground-truth horizon.
    n_steps = min(n_rollout, T_out)
    C = inp.shape[-1]

    # [1, H, W, C] -> [1, C, H, W]: channels-first for the model.
    x_cond = inp[:, 0].permute(0, 3, 1, 2).float().to(device)

    # Ground-truth frames kept on CPU, channels-first, one per rollout step.
    gt_frames = [out[:, t].permute(0, 3, 1, 2).float() for t in range(n_steps)]

    pred_frames = []
    per_step_mse = []
    cond = x_cond

    for t in range(n_steps):
        # eta=0.0 => deterministic DDIM sampling; prediction has the same
        # shape as the conditioning frame.
        pred = model.sample_ddim(cond, shape=cond.shape, steps=ddim_steps, eta=0.0)
        pred_cpu = pred.cpu()
        pred_frames.append(pred_cpu)

        # Per-step error vs ground truth (errors typically accumulate
        # because each step conditions on the previous prediction).
        mse_t = F.mse_loss(pred_cpu, gt_frames[t]).item()
        per_step_mse.append(mse_t)

        # Autoregressive step: feed the prediction back in as conditioning.
        cond = pred
        if (t + 1) % 5 == 0:
            logger.info(f" rollout step {t+1}/{n_steps}, mse={mse_t:.6f}")

    # Clamp the requested channel to what the data actually has.
    ch = min(channel, C - 1)

    # Pool values from the initial condition, all GT frames, and all
    # predictions so every video frame shares one color scale.
    all_vals = [x_cond[0, ch].cpu().numpy().flat]
    for t in range(n_steps):
        all_vals.append(gt_frames[t][0, ch].numpy().flat)
        all_vals.append(pred_frames[t][0, ch].numpy().flat)
    all_vals = np.concatenate(list(all_vals))
    vmin, vmax = np.percentile(all_vals, 2), np.percentile(all_vals, 98)

    def to_rgb(field_2d):
        # Shared-scale normalization, then colormap to uint8 RGB.
        n, _, _ = normalize_for_vis(field_2d, vmin, vmax)
        return apply_colormap(n, cmap)

    # NOTE(review): W is unused below; kept for readability of the unpack.
    H, W = x_cond.shape[2], x_cond.shape[3]
    sep = np.full((H, 4, 3), 200, dtype=np.uint8)

    def _label_frame(gt_rgb, pred_rgb):
        """Concatenate with separator."""
        return np.concatenate([gt_rgb, sep, pred_rgb], axis=1)

    frames = []

    # Frame 0 shows the initial condition on both sides (identical panels).
    init_rgb = to_rgb(x_cond[0, ch].cpu().numpy())
    frames.append(_label_frame(init_rgb, init_rgb).transpose(2, 0, 1))

    # One video frame per rollout step: GT left, prediction right.
    # transpose(2, 0, 1): HWC -> CHW as wandb.Video expects.
    for t in range(n_steps):
        gt_rgb = to_rgb(gt_frames[t][0, ch].numpy())
        pred_rgb = to_rgb(pred_frames[t][0, ch].numpy())
        frames.append(_label_frame(gt_rgb, pred_rgb).transpose(2, 0, 1))

    video = np.stack(frames).astype(np.uint8)

    # NOTE(review): unconditionally restores train mode — assumes this is
    # called from a training loop; confirm callers expect this.
    model.train()
    return video, per_step_mse
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_evaluation(
    model, val_loader, rollout_loader, device,
    global_step, wandb_run=None,
    n_val_batches=4, n_rollout=20, ddim_steps=50,
):
    """Run full evaluation: single-step metrics + rollout video.

    Logs everything to WandB if wandb_run is provided.

    Args:
        model: GaussianDiffusion instance.
        val_loader: single-step validation DataLoader.
        rollout_loader: DataLoader providing long output trajectories.
        device: torch device.
        global_step: training step used as the WandB x-axis.
        wandb_run: optional wandb run; when given, metrics, comparison
            images, the rollout video, and an MSE-vs-step curve are logged.
        n_val_batches: batches used for single-step MSE.
        n_rollout: autoregressive rollout length.
        ddim_steps: DDIM denoising steps per prediction.

    Returns:
        dict of all metrics.
    """
    logger.info("Running single-step evaluation...")
    metrics, comparisons = single_step_eval(
        model, val_loader, device, n_batches=n_val_batches, ddim_steps=ddim_steps
    )
    logger.info(f" val/mse = {metrics['val/mse']:.6f}")

    logger.info(f"Running {n_rollout}-step rollout evaluation...")
    video, rollout_mse = rollout_eval(
        model, rollout_loader, device, n_rollout=n_rollout, ddim_steps=ddim_steps
    )

    # Guard: the rollout can yield zero steps when the loader's trajectories
    # are shorter than expected; indexing rollout_mse[0]/[-1] would raise.
    if rollout_mse:
        logger.info(f" rollout MSE (step 1/last): {rollout_mse[0]:.6f} / {rollout_mse[-1]:.6f}")
        metrics["val/rollout_mse_mean"] = float(np.mean(rollout_mse))
        metrics["val/rollout_mse_final"] = rollout_mse[-1]
        for t, m in enumerate(rollout_mse):
            metrics[f"val/rollout_mse_step{t}"] = m
    else:
        logger.warning("Rollout produced no steps; skipping rollout metrics.")

    if wandb_run is not None:
        import wandb

        wandb_run.log(metrics, step=global_step)

        # Cap the number of images to keep eval logging cheap.
        for img, caption in comparisons[:8]:
            wandb_run.log(
                {f"eval/{caption}": wandb.Image(img, caption="Cond | GT | Pred")},
                step=global_step,
            )

        wandb_run.log(
            {"eval/rollout_video": wandb.Video(video, fps=4, format="mp4",
                                               caption="Left=GT Right=Prediction")},
            step=global_step,
        )

        if rollout_mse:
            table = wandb.Table(columns=["step", "mse"], data=[[t, m] for t, m in enumerate(rollout_mse)])
            wandb_run.log(
                {"eval/rollout_mse_curve": wandb.plot.line(
                    table, "step", "mse", title="Rollout MSE vs Step"
                )},
                step=global_step,
            )

    return metrics
|
|
|