| """Evaluate trained image-input world models on long open-loop rollouts.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import importlib |
| import json |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| from torch.utils.data import DataLoader |
|
|
| from experiments.shared.src.data.image_dataset import ImageTrajectoryDataset |
| from experiments.shared.src.methods import PAPER_LEARNED_METHODS |
| from experiments.shared.src.vision.clean_renderer import render_clean_boat_history_tensor |
| from experiments.train_image_world_models import configure_training_runtime |
| from experiments.train_image_world_models import autocast_context |
| from experiments.train_image_world_models import decode_predictions |
| from experiments.train_image_world_models import required_model_history |
| from experiments.train_image_world_models import selected_history_indices |
|
|
|
|
# Methods evaluated when --methods is not given on the command line (see main()).
METHODS = PAPER_LEARNED_METHODS
|
|
|
|
def loader_kwargs(num_workers: int) -> dict:
    """Return extra ``DataLoader`` keyword arguments for multi-worker loading.

    Args:
        num_workers: number of DataLoader worker processes.

    Returns:
        An empty dict when ``num_workers <= 0``. Otherwise, options that only make
        sense with worker processes: a ``spawn`` multiprocessing context, workers
        kept alive between iterations, and a prefetch depth of 4 batches per worker.
    """
    if num_workers > 0:
        return dict(
            multiprocessing_context="spawn",
            persistent_workers=True,
            prefetch_factor=4,
        )
    return {}
|
|
|
|
def prepare_batch(batch, args, device: torch.device):
    """Trim the observation history to what the model consumes and move tensors to ``device``.

    Args:
        batch: 8-tuple ``(observation_hist, actions, future_actions, targets, origin,
            prev_origin, flow_type_id, boat_id)`` as yielded by the loader.
        args: namespace read for ``history_indices`` (explicit frame indices along
            dim 1), else ``model_history_len`` (keep the last N frames; defaults to
            the full history), plus ``render_mode`` and, in ``"device"`` mode,
            ``image_size`` / ``visual_scale``.
        device: target device for model inputs.

    Returns:
        ``(images, actions, future_actions, targets, origin, prev_origin,
        flow_type_id, boat_id)``. ``prev_origin``, ``flow_type_id`` and
        ``boat_id`` are returned unchanged (not moved to ``device``).
    """
    (obs, hist_actions, fut_actions, tgt, org,
     prev_org, flow_ids, boat_ids) = batch

    indices = getattr(args, "history_indices", None)
    if indices is not None:
        obs = obs[:, indices]
        hist_actions = hist_actions[:, indices]
    else:
        keep = int(getattr(args, "model_history_len", obs.shape[1]))
        obs, hist_actions = obs[:, -keep:], hist_actions[:, -keep:]

    def _to_device(tensor):
        return tensor.to(device, non_blocking=True)

    hist_actions, fut_actions, tgt, org = (
        _to_device(t) for t in (hist_actions, fut_actions, tgt, org)
    )

    if args.render_mode != "device":
        # The dataset already rendered images; just transfer them.
        images = _to_device(obs)
    else:
        # The history holds raw states; rasterize them on the target device.
        images = render_clean_boat_history_tensor(
            _to_device(obs),
            _to_device(boat_ids),
            image_size=args.image_size,
            visual_scale=args.visual_scale,
        )
    return images, hist_actions, fut_actions, tgt, org, prev_org, flow_ids, boat_ids
|
|
|
|
def build_method(method: str):
    """Instantiate a method's default config and model from its experiment package.

    Imports ``experiments.<method>.src.config`` and ``experiments.<method>.src.model``
    (in that order), calls ``default_config()`` on the former and passes the result
    to ``build_model`` on the latter.

    Returns:
        ``(cfg, model)``.
    """
    config_module = importlib.import_module(f"experiments.{method}.src.config")
    build_model = importlib.import_module(f"experiments.{method}.src.model").build_model
    cfg = config_module.default_config()
    return cfg, build_model(cfg)
|
|
|
|
def load_flow_names(source_npz: str) -> dict[int, str]:
    """Return a ``{flow_id: flow_name}`` map from an ``.npz`` file's metadata.

    The archive must contain a ``metadata`` entry holding a JSON string whose
    ``"flows"`` object maps flow names to integer ids (ids may be stored as
    strings; they are converted with ``int``).

    Args:
        source_npz: path to the ``.npz`` archive.

    Returns:
        Mapping from integer flow id to flow name.
    """
    # np.load on an .npz returns an NpzFile that keeps the archive open until it
    # is closed; the context manager releases the handle even if parsing fails.
    with np.load(source_npz, allow_pickle=False) as src:
        metadata = json.loads(str(src["metadata"]))
    return {int(v): str(k) for k, v in metadata["flows"].items()}
|
|
|
|
def load_group_names(source_npz: str, key: str) -> dict[int, str]:
    """Return a ``{group_id: group_name}`` map for ``key`` from an ``.npz`` file's metadata.

    The archive must contain a ``metadata`` entry holding a JSON string whose
    ``key`` object maps group names to integer ids (ids may be stored as strings;
    they are converted with ``int``).

    Args:
        source_npz: path to the ``.npz`` archive.
        key: metadata section to read (``main`` uses ``"trajectories"`` and ``"boats"``).

    Returns:
        Mapping from integer group id to group name.

    Raises:
        KeyError: if ``key`` is missing from the metadata.
    """
    # Close the NpzFile deterministically instead of leaking the archive handle.
    with np.load(source_npz, allow_pickle=False) as src:
        metadata = json.loads(str(src["metadata"]))
    return {int(v): str(k) for k, v in metadata[key].items()}
|
|
|
|
@torch.no_grad()
def rollout_with_context(model, images: torch.Tensor, actions: torch.Tensor, future_actions: torch.Tensor, mode: str) -> torch.Tensor:
    """Roll the model open-loop with its encoded context optionally ablated.

    Args:
        model: must provide ``encode(images, actions) -> (z, c)`` and either a
            ``rollout_with_context(z, c, future_actions)`` method or ``step(z, a, c)``
            plus ``decoder(z)`` for the step-by-step fallback.
        images, actions: history inputs passed to ``model.encode``.
        future_actions: actions indexed along dim 1, one per rollout step.
        mode: ``"inferred"`` keeps the encoded context, ``"zero"`` replaces it with
            zeros, ``"shuffled"`` rolls it by one position along the batch axis so
            each sample is paired with a neighbouring sample's context.

    Returns:
        The model's own ``rollout_with_context`` output when available; otherwise
        the decoded state after each step, stacked along dim 1.

    Raises:
        ValueError: if ``mode`` is not one of the three modes above. Previously an
            unknown mode silently fell back to the inferred context, which would
            mislabel ablation results.
    """
    if mode not in ("inferred", "zero", "shuffled"):
        raise ValueError(f"unknown context mode {mode!r}; expected 'inferred', 'zero' or 'shuffled'")
    z, c = model.encode(images, actions)
    if mode == "zero":
        c = torch.zeros_like(c)
    elif mode == "shuffled":
        # NOTE(review): with batch size 1 this roll is a no-op, and with an
        # unshuffled loader adjacent samples may come from the same episode, so
        # the "shuffled" context can be close to the true one — confirm intent.
        c = c.roll(shifts=1, dims=0)
    if hasattr(model, "rollout_with_context"):
        return model.rollout_with_context(z, c, future_actions)
    preds = []
    cur = z
    for t in range(future_actions.shape[1]):
        cur = model.step(cur, future_actions[:, t], c)
        preds.append(model.decoder(cur))
    return torch.stack(preds, dim=1)
|
|
|
|
class _GroupErrorSums:
    """Running per-group sums of position/heading error curves, keyed by integer group id."""

    def __init__(self, horizon: int) -> None:
        self.horizon = horizon
        self.pos: dict[int, np.ndarray] = {}
        self.heading: dict[int, np.ndarray] = {}
        self.count: dict[int, int] = {}

    def add(self, group_ids: np.ndarray, pos_np: np.ndarray, heading_np: np.ndarray) -> None:
        """Add each sample's per-step error curve to the sums of its group.

        ``group_ids[i]`` is the group of row ``i`` of ``pos_np`` / ``heading_np``.
        """
        for group_id in np.unique(group_ids):
            mask = group_ids == group_id
            gid = int(group_id)
            if gid not in self.count:
                self.pos[gid] = np.zeros(self.horizon, dtype=np.float64)
                self.heading[gid] = np.zeros(self.horizon, dtype=np.float64)
                self.count[gid] = 0
            self.count[gid] += int(mask.sum())
            self.pos[gid] += pos_np[mask].sum(axis=0)
            self.heading[gid] += heading_np[mask].sum(axis=0)

    def summaries(self, names: dict[int, str], steps: list[int]) -> dict[str, dict[str, float]]:
        """Return mean-error summaries per group, ordered by id and labelled via ``names``.

        Ids missing from ``names`` are labelled with ``str(id)``.
        """
        return {
            names.get(gid, str(gid)): summarize(self.pos[gid] / n, self.heading[gid] / n, steps)
            for gid, n in sorted(self.count.items())
        }


@torch.no_grad()
def evaluate_model(
    model,
    loader,
    device: torch.device,
    horizon: int,
    target_mode: str,
    flow_names: dict[int, str],
    traj_names: dict[int, str],
    boat_names: dict[int, str],
    context_mode: str,
    args,
) -> dict:
    """Compute mean open-loop position and heading errors per rollout step.

    Position error is the Euclidean distance between the first two channels of
    prediction and target. Heading error is the absolute wrapped difference (in
    radians) between ``atan2(x[..., 3], x[..., 2])`` of prediction and target
    (presumably channels 2/3 hold cos/sin of the heading — TODO confirm).

    Args:
        model: world model exposing ``rollout`` (for ``"inferred"``) or the
            interface required by ``rollout_with_context`` (other modes).
        loader: must iterate ``loader.dataset`` in order (no shuffling); the
            dataset must expose ``indices`` (``(episode, t)`` pairs aligned with
            iteration order), ``traj_type_ids`` and ``boat_ids`` indexed by episode.
        device: device for model inputs.
        horizon: number of predicted steps per window.
        target_mode: forwarded to ``decode_predictions``.
        flow_names, traj_names, boat_names: id-to-name maps used to label groups.
        context_mode: ``"inferred"`` uses ``model.rollout``; any other value is
            forwarded to ``rollout_with_context``.
        args: namespace forwarded to ``prepare_batch``; ``args.precision`` selects
            the autocast mode.

    Returns:
        ``summarize`` output (``pos{k}`` / ``heading{k}`` for k in
        1, 3, 6, 8, 10, 20, 30, 40, 60 up to ``horizon``) plus ``by_flow``,
        ``by_trajectory`` and ``by_boat`` dicts of the same summaries per group.

    Raises:
        ValueError: if the loader yields no samples (the means would be NaN).
        RuntimeError: if ``loader.dataset.indices`` runs out before the loader does.
    """
    model.eval()
    steps = [s for s in [1, 3, 6, 8, 10, 20, 30, 40, 60] if s <= horizon]
    pos_sum = np.zeros(horizon, dtype=np.float64)
    heading_sum = np.zeros(horizon, dtype=np.float64)
    by_flow = _GroupErrorSums(horizon)
    by_traj = _GroupErrorSums(horizon)
    by_boat = _GroupErrorSums(horizon)
    dataset = loader.dataset
    count = 0
    for batch in loader:
        images, actions, future_actions, targets, origin, _prev_origin, flow_type_id, _boat_id = prepare_batch(batch, args, device)
        with autocast_context(device, args.precision):
            if context_mode == "inferred":
                encoded = model.rollout(images, actions, future_actions)
            else:
                encoded = rollout_with_context(model, images, actions, future_actions, context_mode)
        pred = decode_predictions(encoded.float(), origin, target_mode)
        pos = torch.linalg.norm(pred[..., :2] - targets[..., :2], dim=-1)
        pred_angle = torch.atan2(pred[..., 3], pred[..., 2])
        target_angle = torch.atan2(targets[..., 3], targets[..., 2])
        # Wrap the angle difference into [-pi, pi] before taking its magnitude.
        delta = pred_angle - target_angle
        heading = torch.atan2(torch.sin(delta), torch.cos(delta)).abs()
        pos_np = pos.cpu().numpy()
        heading_np = heading.cpu().numpy()
        n = int(pos_np.shape[0])
        pos_sum += pos_np.sum(axis=0)
        heading_sum += heading_np.sum(axis=0)
        # Batches arrive in dataset order, so the next n entries of dataset.indices
        # describe exactly the samples in this batch.
        batch_indices = dataset.indices[count : count + n]
        if len(batch_indices) != n:
            raise RuntimeError(
                f"loader.dataset.indices has {len(batch_indices)} entries for a batch of {n}; "
                "indices are not aligned with loader iteration order"
            )
        count += n
        traj_np = np.array([dataset.traj_type_ids[ep] for ep, _t in batch_indices], dtype=np.int64)
        boat_np = np.array([dataset.boat_ids[ep] for ep, _t in batch_indices], dtype=np.int64)
        by_flow.add(flow_type_id.numpy(), pos_np, heading_np)
        by_traj.add(traj_np, pos_np, heading_np)
        by_boat.add(boat_np, pos_np, heading_np)
    if count == 0:
        raise ValueError("evaluation loader yielded no samples; cannot average rollout errors")
    result = summarize(pos_sum / count, heading_sum / count, steps)
    result["by_flow"] = by_flow.summaries(flow_names, steps)
    result["by_trajectory"] = by_traj.summaries(traj_names, steps)
    result["by_boat"] = by_boat.summaries(boat_names, steps)
    return result
|
|
|
|
def summarize(pos_mean: np.ndarray, heading_mean: np.ndarray, steps: list[int]) -> dict[str, float]:
    """Pick selected 1-based steps out of the mean error curves.

    Args:
        pos_mean: mean position error per step (index 0 is step 1).
        heading_mean: mean heading error per step (index 0 is step 1).
        steps: 1-based step numbers to report.

    Returns:
        ``{"pos<k>": ..., "heading<k>": ...}`` for each ``k`` in ``steps``, in
        that order, with values converted to Python ``float``.
    """
    return {
        f"{label}{step}": float(curve[step - 1])
        for step in steps
        for label, curve in (("pos", pos_mean), ("heading", heading_mean))
    }
|
|
|
|
def main() -> None:
    """Command-line entry point: evaluate each method's checkpoint and write a JSON report.

    For every name in ``--methods`` the model is built with ``build_method``, its
    weights are loaded from ``experiments/<method>/checkpoint/<--checkpoint-name>``
    and it is evaluated with inferred context. ``flowmo`` is additionally
    evaluated with zeroed and shuffled context. The list of per-method results is
    written to ``--out`` and echoed to stdout.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    add("--methods", nargs="+", default=METHODS)
    add("--test-source", default="data/paper/test.npz")
    add("--test-episodes", type=int, default=256)
    add("--history-len", type=int, default=32)
    add("--horizon", type=int, default=60)
    add("--test-windows", type=int, default=4096)
    add("--batch-size", type=int, default=64)
    add("--seed", type=int, default=20)
    add("--device", default="cuda")
    add("--target-mode", choices=["absolute_normalized", "relative_motion"], default="absolute_normalized")
    add("--checkpoint-name", default="image_local.pt")
    add("--out", default="experiments/reports/image_long_rollout_eval.json")
    add("--num-workers", type=int, default=4)
    add("--image-size", type=int, default=160)
    add("--visual-scale", type=float, default=2.5)
    add("--render-mode", choices=["device", "dataset"], default="device")
    add("--precision", choices=["fp32", "bf16", "fp16"], default="fp32")
    args = parser.parse_args()

    device = torch.device(args.device)
    configure_training_runtime(device)

    source = args.test_source
    flow_names = load_flow_names(source)
    traj_names = load_group_names(source, "trajectories")
    boat_names = load_group_names(source, "boats")

    dataset = ImageTrajectoryDataset(
        source,
        history_len=args.history_len,
        horizon=args.horizon,
        episodes=args.test_episodes,
        max_windows=args.test_windows,
        seed=args.seed,
        image_size=args.image_size,
        visual_scale=args.visual_scale,
        return_aux=True,
        # Render in the dataset only when the device-side renderer is not used.
        render_images=args.render_mode == "dataset",
    )
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,  # evaluate_model relies on in-order iteration
        num_workers=args.num_workers,
        pin_memory=device.type == "cuda",
        **loader_kwargs(args.num_workers),
    )

    def evaluate(model, context_mode: str) -> dict:
        return evaluate_model(
            model, loader, device, args.horizon, args.target_mode,
            flow_names, traj_names, boat_names, context_mode, args,
        )

    # (result key, context mode) pairs for the flowmo context ablations.
    ablations = (("context_zero", "zero"), ("context_shuffled", "shuffled"))

    payload = []
    for method in args.methods:
        _cfg, model = build_method(method)
        checkpoint = Path("experiments") / method / "checkpoint" / args.checkpoint_name
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
        # (consider weights_only=True where the installed torch supports it).
        model.load_state_dict(torch.load(checkpoint, map_location="cpu"))
        model.to(device)
        if device.type == "cuda":
            model.to(memory_format=torch.channels_last)
        # prepare_batch reads these from args to choose which history frames the model sees.
        args.model_history_len = required_model_history(model, args.history_len)
        args.history_indices = selected_history_indices(model, args.history_len)
        results = {"method": method, "inferred": evaluate(model, "inferred")}
        if method == "flowmo":
            for key, context_mode in ablations:
                results[key] = evaluate(model, context_mode)
        payload.append(results)

    report = json.dumps(payload, indent=2)
    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(report)
    print(report)
|
|
|
|
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|