"""Train and evaluate all image-input world models locally.""" from __future__ import annotations import argparse import gc import random import importlib import json from contextlib import nullcontext from pathlib import Path import numpy as np import torch import torch.nn.functional as F from torch.utils.data import DataLoader from experiments.shared.src.data.image_dataset import ImageTrajectoryDataset from experiments.shared.src.methods import PAPER_LEARNED_METHODS from experiments.shared.src.utils.parameter_count import save_parameter_count from experiments.shared.src.vision.clean_renderer import render_clean_boat_history_tensor METHODS = PAPER_LEARNED_METHODS POSITION_SCALE = 5.0 def build_method(method: str): config_module = importlib.import_module(f"experiments.{method}.src.config") model_module = importlib.import_module(f"experiments.{method}.src.model") cfg = config_module.default_config() return cfg, model_module.build_model(cfg) def loader_kwargs(num_workers: int) -> dict: if num_workers <= 0: return {} return { "multiprocessing_context": "spawn", "persistent_workers": True, "prefetch_factor": 4, } def dataloader_pin_memory(device: torch.device) -> bool: return device.type == "cuda" def configure_training_runtime(device: torch.device) -> None: if hasattr(torch, "set_float32_matmul_precision"): torch.set_float32_matmul_precision("high") if device.type == "cuda" and hasattr(torch.backends, "cudnn"): torch.backends.cudnn.benchmark = True def autocast_context(device: torch.device, precision: str): if device.type != "cuda" or precision == "fp32": return nullcontext() dtype = torch.bfloat16 if precision == "bf16" else torch.float16 return torch.autocast(device_type="cuda", dtype=dtype) def method_seed(base_seed: int, method: str) -> int: return int(base_seed) + sum((idx + 1) * ord(char) for idx, char in enumerate(method)) def set_training_seed(seed: int) -> None: random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) def required_model_history(model, default_history_len: int) -> int: config = getattr(model, "config", None) history_len = int(getattr(config, "history_len", default_history_len)) context_len = int(getattr(config, "context_len", history_len)) return min(int(default_history_len), max(history_len, context_len)) def selected_history_indices(model, default_history_len: int) -> list[int]: if hasattr(model, "selected_history_indices"): return [int(i) for i in model.selected_history_indices(default_history_len)] needed = required_model_history(model, default_history_len) return list(range(default_history_len - needed, default_history_len)) def prepare_batch(batch, args, device: torch.device): observation_hist, actions, future_actions, targets, origin, prev_origin, flow_type_id, boat_id = batch history_indices = getattr(args, "history_indices", None) if history_indices is None: model_history_len = int(getattr(args, "model_history_len", observation_hist.shape[1])) observation_hist = observation_hist[:, -model_history_len:] actions = actions[:, -model_history_len:] else: observation_hist = observation_hist[:, history_indices] actions = actions[:, history_indices] actions = actions.to(device, non_blocking=True) future_actions = future_actions.to(device, non_blocking=True) targets = targets.to(device, non_blocking=True) origin = origin.to(device, non_blocking=True) if args.render_mode == "device": states = observation_hist.to(device, non_blocking=True) boat_id = boat_id.to(device, non_blocking=True) images = render_clean_boat_history_tensor( states, boat_id, image_size=args.image_size, visual_scale=args.visual_scale, ) else: images = observation_hist.to(device, non_blocking=True) return images, actions, future_actions, targets, origin, prev_origin, flow_type_id, boat_id def rollout_from_encoded(model, z: torch.Tensor, c: torch.Tensor, future_actions: torch.Tensor) -> torch.Tensor: preds = [] cur = z for t in range(future_actions.shape[1]): cur = model.step(cur, future_actions[:, t], c) preds.append(model.decoder(cur)) return torch.stack(preds, dim=1) def save_eval_checkpoint(model, checkpoint_dir: Path, checkpoint_name: str, step: int | None = None) -> str: checkpoint_dir.mkdir(parents=True, exist_ok=True) path = checkpoint_dir / checkpoint_name if step is not None: stem = Path(checkpoint_name).stem suffix = Path(checkpoint_name).suffix or ".pt" path = checkpoint_dir / f"{stem}_step_{step:06d}{suffix}" torch.save(model.state_dict(), path) return path.name def train_method(method: str, args) -> dict[str, float | int | list[float]]: set_training_seed(method_seed(args.seed, method)) cfg, model = build_method(method) device = torch.device(args.device) configure_training_runtime(device) args.model_history_len = required_model_history(model, args.history_len) args.history_indices = selected_history_indices(model, args.history_len) model.to(device) if device.type == "cuda": model.to(memory_format=torch.channels_last) optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4) scaler = torch.amp.GradScaler("cuda", enabled=(device.type == "cuda" and args.precision == "fp16")) out_dir = Path("experiments") / method checkpoint_dir = out_dir / "checkpoint" result_dir = out_dir / "result" checkpoint_dir.mkdir(parents=True, exist_ok=True) result_dir.mkdir(parents=True, exist_ok=True) trace_path = result_dir / f"{Path(args.checkpoint_name).stem}_training_trace.jsonl" trace_path.write_text("") train_ds = ImageTrajectoryDataset( args.train_source, history_len=args.history_len, horizon=args.horizon, episodes=args.train_episodes, image_size=args.image_size, visual_scale=args.visual_scale, max_windows=args.train_windows, seed=args.seed, return_aux=True, render_images=args.render_mode == "dataset", ) train_loader = DataLoader( train_ds, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=dataloader_pin_memory(device), drop_last=True, **loader_kwargs(args.num_workers), ) logged_losses = [] model.train() step = 0 running_loss = torch.zeros((), device=device) running_count = 0 saved_checkpoints: list[str] = [] while step < args.steps: for batch in train_loader: step += 1 images, actions, future_actions, targets, origin, _prev_origin, _flow_type_id, _boat_id = prepare_batch(batch, args, device) train_targets = encode_targets(targets, origin, args.target_mode) with autocast_context(device, args.precision): z, c = model.encode(images, actions) pred = rollout_from_encoded(model, z, c, future_actions) loss = weighted_pose_loss(pred.float(), train_targets.float(), args.heading_weight) if args.motion_weight > 0.0: pred_abs = decode_predictions(pred.float(), origin, args.target_mode) loss = loss + args.motion_weight * motion_delta_loss(pred_abs, targets) if args.current_pose_weight > 0.0: current_target = encode_absolute_pose(origin) loss = loss + args.current_pose_weight * weighted_pose_loss( model.decoder(z).float(), current_target, args.heading_weight, ) optimizer.zero_grad(set_to_none=True) if scaler.is_enabled(): scaler.scale(loss).backward() scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0) scaler.step(optimizer) scaler.update() else: loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0) optimizer.step() running_loss = running_loss + loss.detach() running_count += 1 if step % args.log_every == 0: mean_loss = float((running_loss / max(running_count, 1)).item()) logged_losses.append(mean_loss) print(f"{method} step {step:05d} loss={mean_loss:.5f}", flush=True) with trace_path.open("a") as f: f.write(json.dumps({"method": method, "step": int(step), "loss": mean_loss}) + "\n") running_loss = torch.zeros((), device=device) running_count = 0 if args.checkpoint_interval > 0 and step % args.checkpoint_interval == 0: saved_checkpoints.append(save_eval_checkpoint(model, checkpoint_dir, args.checkpoint_name, step)) if step >= args.steps: break if running_count: mean_loss = float((running_loss / running_count).item()) logged_losses.append(mean_loss) with trace_path.open("a") as f: f.write(json.dumps({"method": method, "step": int(step), "loss": mean_loss}) + "\n") final_checkpoint = save_eval_checkpoint(model, checkpoint_dir, args.checkpoint_name) counts = save_parameter_count(model, result_dir / "parameter_count.json") del train_loader, train_ds gc.collect() test_ds = ImageTrajectoryDataset( args.test_source, history_len=args.history_len, horizon=args.horizon, episodes=args.test_episodes, max_windows=args.test_windows, seed=args.seed + 1, image_size=args.image_size, visual_scale=args.visual_scale, return_aux=True, render_images=args.render_mode == "dataset", ) test_loader = DataLoader( test_ds, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=dataloader_pin_memory(device), **loader_kwargs(args.num_workers), ) metrics = evaluate_model(model, test_loader, device, args.horizon, args.target_mode, args) result = { "method": method, "steps": int(args.steps), "batch_size": int(args.batch_size), "train_samples": int(args.steps * args.batch_size), "final_train_loss": float(logged_losses[-1]), "total_parameters": int(counts["total"]), "target_mode": args.target_mode, "position_scale": POSITION_SCALE, "heading_weight": float(args.heading_weight), "current_pose_weight": float(args.current_pose_weight), "motion_weight": float(args.motion_weight), "precision": args.precision, "checkpoint_name": args.checkpoint_name, "final_checkpoint": final_checkpoint, "intermediate_checkpoints": saved_checkpoints, "checkpoint_interval": int(args.checkpoint_interval), "prediction": metrics, } result_name = f"{Path(args.checkpoint_name).stem}_training.json" (result_dir / result_name).write_text(json.dumps(result, indent=2)) del test_loader, test_ds, model gc.collect() if device.type == "cuda": torch.cuda.empty_cache() return result @torch.no_grad() def evaluate_model(model, loader, device: torch.device, horizon: int, target_mode: str, args) -> dict[str, float]: model.eval() pos_sums = np.zeros(horizon, dtype=np.float64) heading_sums = np.zeros(horizon, dtype=np.float64) count = 0 for batch in loader: images, actions, future_actions, targets, origin, _prev_origin, _flow_type_id, _boat_id = prepare_batch(batch, args, device) with autocast_context(device, args.precision): encoded = model.rollout(images, actions, future_actions) pred = decode_predictions(encoded.float(), origin, target_mode) pos = torch.linalg.norm(pred[..., :2] - targets[..., :2], dim=-1) pred_angle = torch.atan2(pred[..., 3], pred[..., 2]) target_angle = torch.atan2(targets[..., 3], targets[..., 2]) heading = torch.atan2(torch.sin(pred_angle - target_angle), torch.cos(pred_angle - target_angle)).abs() pos_sums += pos.sum(dim=0).cpu().numpy() heading_sums += heading.sum(dim=0).cpu().numpy() count += int(images.shape[0]) pos_mean = pos_sums / count heading_mean = heading_sums / count result: dict[str, float] = {} for step in [1, 3, 6, 8, 10, 20]: if horizon >= step: result[f"pos{step}"] = float(pos_mean[step - 1]) result[f"heading{step}"] = float(heading_mean[step - 1]) return result def encode_targets(targets: torch.Tensor, origin: torch.Tensor, target_mode: str) -> torch.Tensor: if target_mode == "absolute_normalized": return encode_absolute_pose(targets) if target_mode == "relative_motion": rel_xy = (targets[..., :2] - origin[:, None, :2]) / POSITION_SCALE origin_angle = torch.atan2(origin[:, 3], origin[:, 2]) target_angle = torch.atan2(targets[..., 3], targets[..., 2]) delta = target_angle - origin_angle[:, None] rel_heading = torch.stack([torch.cos(delta), torch.sin(delta)], dim=-1) return torch.cat([rel_xy, rel_heading], dim=-1) raise ValueError(f"unknown target_mode: {target_mode}") def encode_absolute_pose(obs: torch.Tensor) -> torch.Tensor: xy = (obs[..., :2] - POSITION_SCALE) / POSITION_SCALE return torch.cat([xy, obs[..., 2:4]], dim=-1) def decode_predictions(predictions: torch.Tensor, origin: torch.Tensor, target_mode: str) -> torch.Tensor: if target_mode == "absolute_normalized": xy = predictions[..., :2] * POSITION_SCALE + POSITION_SCALE return torch.cat([xy, predictions[..., 2:4]], dim=-1) if target_mode == "relative_motion": xy = predictions[..., :2] * POSITION_SCALE + origin[:, None, :2] origin_angle = torch.atan2(origin[:, 3], origin[:, 2]) delta = torch.atan2(predictions[..., 3], predictions[..., 2]) angle = origin_angle[:, None] + delta heading = torch.stack([torch.cos(angle), torch.sin(angle)], dim=-1) return torch.cat([xy, heading], dim=-1) raise ValueError(f"unknown target_mode: {target_mode}") def weighted_pose_loss(predictions: torch.Tensor, targets: torch.Tensor, heading_weight: float) -> torch.Tensor: pos_loss = F.mse_loss(predictions[..., :2], targets[..., :2]) heading_loss = F.mse_loss(predictions[..., 2:4], targets[..., 2:4]) return pos_loss + heading_weight * heading_loss def motion_delta_loss(predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: pred_delta = predictions[:, 1:, :2] - predictions[:, :-1, :2] target_delta = targets[:, 1:, :2] - targets[:, :-1, :2] pred_angle = torch.atan2(predictions[..., 3], predictions[..., 2]) target_angle = torch.atan2(targets[..., 3], targets[..., 2]) pred_turn = torch.atan2( torch.sin(pred_angle[:, 1:] - pred_angle[:, :-1]), torch.cos(pred_angle[:, 1:] - pred_angle[:, :-1]), ) target_turn = torch.atan2( torch.sin(target_angle[:, 1:] - target_angle[:, :-1]), torch.cos(target_angle[:, 1:] - target_angle[:, :-1]), ) return F.mse_loss(pred_delta, target_delta) + 0.2 * F.mse_loss(pred_turn, target_turn) def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--methods", nargs="+", default=METHODS) parser.add_argument("--train-source", default="data/paper/train.npz") parser.add_argument("--test-source", default="data/paper/test.npz") parser.add_argument("--train-episodes", type=int, default=512) parser.add_argument("--test-episodes", type=int, default=256) parser.add_argument("--history-len", type=int, default=32) parser.add_argument("--horizon", type=int, default=20) parser.add_argument("--train-windows", type=int, default=65536) parser.add_argument("--test-windows", type=int, default=8192) parser.add_argument("--batch-size", type=int, default=64) parser.add_argument("--steps", type=int, default=16000) parser.add_argument("--lr", type=float, default=3e-4) parser.add_argument("--log-every", type=int, default=200) parser.add_argument("--seed", type=int, default=19) parser.add_argument("--device", default="cuda") parser.add_argument("--num-workers", type=int, default=4) parser.add_argument("--episode-chunk-size", type=int, default=64) parser.add_argument("--image-size", type=int, default=160) parser.add_argument("--visual-scale", type=float, default=2.5) parser.add_argument("--render-mode", choices=["device", "dataset"], default="device") parser.add_argument("--precision", choices=["fp32", "bf16", "fp16"], default="fp32") parser.add_argument("--target-mode", choices=["absolute_normalized", "relative_motion"], default="absolute_normalized") parser.add_argument("--heading-weight", type=float, default=2.0) parser.add_argument("--current-pose-weight", type=float, default=1.0) parser.add_argument("--motion-weight", type=float, default=0.0) parser.add_argument("--checkpoint-name", default="paper.pt") parser.add_argument("--checkpoint-interval", type=int, default=2000) args = parser.parse_args() results = [train_method(method, args) for method in args.methods] print(json.dumps(results, indent=2)) if __name__ == "__main__": main()