"""Evaluate trained image-input world models on long open-loop rollouts.""" from __future__ import annotations import argparse import importlib import json from pathlib import Path import numpy as np import torch from torch.utils.data import DataLoader from experiments.shared.src.data.image_dataset import ImageTrajectoryDataset from experiments.shared.src.methods import PAPER_LEARNED_METHODS from experiments.shared.src.vision.clean_renderer import render_clean_boat_history_tensor from experiments.train_image_world_models import configure_training_runtime from experiments.train_image_world_models import autocast_context from experiments.train_image_world_models import decode_predictions from experiments.train_image_world_models import required_model_history from experiments.train_image_world_models import selected_history_indices METHODS = PAPER_LEARNED_METHODS def loader_kwargs(num_workers: int) -> dict: if num_workers <= 0: return {} return { "multiprocessing_context": "spawn", "persistent_workers": True, "prefetch_factor": 4, } def prepare_batch(batch, args, device: torch.device): observation_hist, actions, future_actions, targets, origin, prev_origin, flow_type_id, boat_id = batch history_indices = getattr(args, "history_indices", None) if history_indices is None: model_history_len = int(getattr(args, "model_history_len", observation_hist.shape[1])) observation_hist = observation_hist[:, -model_history_len:] actions = actions[:, -model_history_len:] else: observation_hist = observation_hist[:, history_indices] actions = actions[:, history_indices] actions = actions.to(device, non_blocking=True) future_actions = future_actions.to(device, non_blocking=True) targets = targets.to(device, non_blocking=True) origin = origin.to(device, non_blocking=True) if args.render_mode == "device": states = observation_hist.to(device, non_blocking=True) boat_id_device = boat_id.to(device, non_blocking=True) images = render_clean_boat_history_tensor( states, boat_id_device, image_size=args.image_size, visual_scale=args.visual_scale, ) else: images = observation_hist.to(device, non_blocking=True) return images, actions, future_actions, targets, origin, prev_origin, flow_type_id, boat_id def build_method(method: str): config_module = importlib.import_module(f"experiments.{method}.src.config") model_module = importlib.import_module(f"experiments.{method}.src.model") cfg = config_module.default_config() return cfg, model_module.build_model(cfg) def load_flow_names(source_npz: str) -> dict[int, str]: src = np.load(source_npz, allow_pickle=False) metadata = json.loads(str(src["metadata"])) return {int(v): str(k) for k, v in metadata["flows"].items()} def load_group_names(source_npz: str, key: str) -> dict[int, str]: src = np.load(source_npz, allow_pickle=False) metadata = json.loads(str(src["metadata"])) return {int(v): str(k) for k, v in metadata[key].items()} @torch.no_grad() def rollout_with_context(model, images: torch.Tensor, actions: torch.Tensor, future_actions: torch.Tensor, mode: str) -> torch.Tensor: z, c = model.encode(images, actions) if mode == "zero": c = torch.zeros_like(c) elif mode == "shuffled": c = c.roll(shifts=1, dims=0) if hasattr(model, "rollout_with_context"): return model.rollout_with_context(z, c, future_actions) preds = [] cur = z for t in range(future_actions.shape[1]): cur = model.step(cur, future_actions[:, t], c) preds.append(model.decoder(cur)) return torch.stack(preds, dim=1) @torch.no_grad() def evaluate_model( model, loader, device: torch.device, horizon: int, target_mode: str, flow_names: dict[int, str], traj_names: dict[int, str], boat_names: dict[int, str], context_mode: str, args, ) -> dict: model.eval() steps = [s for s in [1, 3, 6, 8, 10, 20, 30, 40, 60] if s <= horizon] pos_sum = np.zeros(horizon, dtype=np.float64) heading_sum = np.zeros(horizon, dtype=np.float64) flow_pos: dict[int, np.ndarray] = {} flow_heading: dict[int, np.ndarray] = {} flow_count: dict[int, int] = {} traj_pos: dict[int, np.ndarray] = {} traj_heading: dict[int, np.ndarray] = {} traj_count: dict[int, int] = {} boat_pos: dict[int, np.ndarray] = {} boat_heading: dict[int, np.ndarray] = {} boat_count: dict[int, int] = {} count = 0 cursor = 0 for batch in loader: images, actions, future_actions, targets, origin, _prev_origin, flow_type_id, _boat_id = prepare_batch(batch, args, device) with autocast_context(device, args.precision): if context_mode == "inferred": encoded = model.rollout(images, actions, future_actions) else: encoded = rollout_with_context(model, images, actions, future_actions, context_mode) pred = decode_predictions(encoded.float(), origin, target_mode) pos = torch.linalg.norm(pred[..., :2] - targets[..., :2], dim=-1) pred_angle = torch.atan2(pred[..., 3], pred[..., 2]) target_angle = torch.atan2(targets[..., 3], targets[..., 2]) heading = torch.atan2(torch.sin(pred_angle - target_angle), torch.cos(pred_angle - target_angle)).abs() pos_np = pos.cpu().numpy() heading_np = heading.cpu().numpy() pos_sum += pos_np.sum(axis=0) heading_sum += heading_np.sum(axis=0) count += int(pos_np.shape[0]) flow_np = flow_type_id.numpy() batch_indices = loader.dataset.indices[cursor : cursor + int(pos_np.shape[0])] cursor += int(pos_np.shape[0]) traj_np = np.array([loader.dataset.traj_type_ids[ep] for ep, _t in batch_indices], dtype=np.int64) boat_np = np.array([loader.dataset.boat_ids[ep] for ep, _t in batch_indices], dtype=np.int64) for flow_id in np.unique(flow_np): mask = flow_np == flow_id fid = int(flow_id) flow_pos.setdefault(fid, np.zeros(horizon, dtype=np.float64)) flow_heading.setdefault(fid, np.zeros(horizon, dtype=np.float64)) flow_count[fid] = flow_count.get(fid, 0) + int(mask.sum()) flow_pos[fid] += pos_np[mask].sum(axis=0) flow_heading[fid] += heading_np[mask].sum(axis=0) for traj_id in np.unique(traj_np): mask = traj_np == traj_id tid = int(traj_id) traj_pos.setdefault(tid, np.zeros(horizon, dtype=np.float64)) traj_heading.setdefault(tid, np.zeros(horizon, dtype=np.float64)) traj_count[tid] = traj_count.get(tid, 0) + int(mask.sum()) traj_pos[tid] += pos_np[mask].sum(axis=0) traj_heading[tid] += heading_np[mask].sum(axis=0) for boat_id in np.unique(boat_np): mask = boat_np == boat_id bid = int(boat_id) boat_pos.setdefault(bid, np.zeros(horizon, dtype=np.float64)) boat_heading.setdefault(bid, np.zeros(horizon, dtype=np.float64)) boat_count[bid] = boat_count.get(bid, 0) + int(mask.sum()) boat_pos[bid] += pos_np[mask].sum(axis=0) boat_heading[bid] += heading_np[mask].sum(axis=0) result = summarize(pos_sum / count, heading_sum / count, steps) by_flow = {} for fid, n in sorted(flow_count.items()): by_flow[flow_names.get(fid, str(fid))] = summarize(flow_pos[fid] / n, flow_heading[fid] / n, steps) result["by_flow"] = by_flow result["by_trajectory"] = { traj_names.get(tid, str(tid)): summarize(traj_pos[tid] / n, traj_heading[tid] / n, steps) for tid, n in sorted(traj_count.items()) } result["by_boat"] = { boat_names.get(bid, str(bid)): summarize(boat_pos[bid] / n, boat_heading[bid] / n, steps) for bid, n in sorted(boat_count.items()) } return result def summarize(pos_mean: np.ndarray, heading_mean: np.ndarray, steps: list[int]) -> dict[str, float]: result: dict[str, float] = {} for step in steps: result[f"pos{step}"] = float(pos_mean[step - 1]) result[f"heading{step}"] = float(heading_mean[step - 1]) return result def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--methods", nargs="+", default=METHODS) parser.add_argument("--test-source", default="data/paper/test.npz") parser.add_argument("--test-episodes", type=int, default=256) parser.add_argument("--history-len", type=int, default=32) parser.add_argument("--horizon", type=int, default=60) parser.add_argument("--test-windows", type=int, default=4096) parser.add_argument("--batch-size", type=int, default=64) parser.add_argument("--seed", type=int, default=20) parser.add_argument("--device", default="cuda") parser.add_argument("--target-mode", choices=["absolute_normalized", "relative_motion"], default="absolute_normalized") parser.add_argument("--checkpoint-name", default="image_local.pt") parser.add_argument("--out", default="experiments/reports/image_long_rollout_eval.json") parser.add_argument("--num-workers", type=int, default=4) parser.add_argument("--image-size", type=int, default=160) parser.add_argument("--visual-scale", type=float, default=2.5) parser.add_argument("--render-mode", choices=["device", "dataset"], default="device") parser.add_argument("--precision", choices=["fp32", "bf16", "fp16"], default="fp32") args = parser.parse_args() device = torch.device(args.device) configure_training_runtime(device) flow_names = load_flow_names(args.test_source) traj_names = load_group_names(args.test_source, "trajectories") boat_names = load_group_names(args.test_source, "boats") ds = ImageTrajectoryDataset( args.test_source, history_len=args.history_len, horizon=args.horizon, episodes=args.test_episodes, max_windows=args.test_windows, seed=args.seed, image_size=args.image_size, visual_scale=args.visual_scale, return_aux=True, render_images=args.render_mode == "dataset", ) loader = DataLoader( ds, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=device.type == "cuda", **loader_kwargs(args.num_workers), ) payload = [] for method in args.methods: _cfg, model = build_method(method) state = torch.load(Path("experiments") / method / "checkpoint" / args.checkpoint_name, map_location="cpu") model.load_state_dict(state) model.to(device) if device.type == "cuda": model.to(memory_format=torch.channels_last) args.model_history_len = required_model_history(model, args.history_len) args.history_indices = selected_history_indices(model, args.history_len) item = { "method": method, "inferred": evaluate_model(model, loader, device, args.horizon, args.target_mode, flow_names, traj_names, boat_names, "inferred", args), } if method == "flowmo": item["context_zero"] = evaluate_model(model, loader, device, args.horizon, args.target_mode, flow_names, traj_names, boat_names, "zero", args) item["context_shuffled"] = evaluate_model(model, loader, device, args.horizon, args.target_mode, flow_names, traj_names, boat_names, "shuffled", args) payload.append(item) out = Path(args.out) out.parent.mkdir(parents=True, exist_ok=True) out.write_text(json.dumps(payload, indent=2)) print(json.dumps(payload, indent=2)) if __name__ == "__main__": main()