| """Train and evaluate all image-input world models locally.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import gc |
| import random |
| import importlib |
| import json |
| from contextlib import nullcontext |
| from pathlib import Path |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader |
|
|
| from experiments.shared.src.data.image_dataset import ImageTrajectoryDataset |
| from experiments.shared.src.methods import PAPER_LEARNED_METHODS |
| from experiments.shared.src.utils.parameter_count import save_parameter_count |
| from experiments.shared.src.vision.clean_renderer import render_clean_boat_history_tensor |
|
|
|
|
# Methods trained by default when ``--methods`` is not passed on the command line.
METHODS = PAPER_LEARNED_METHODS


# Normalisation constant for x/y positions. ``encode_absolute_pose`` maps a
# coordinate v to ``(v - POSITION_SCALE) / POSITION_SCALE`` (so the range
# [0, 2 * POSITION_SCALE] becomes [-1, 1]); "relative_motion" targets divide
# x/y offsets from the origin by it. ``decode_predictions`` inverts both.
POSITION_SCALE = 5.0
|
|
|
|
def build_method(method: str):
    """Instantiate the default config and model for an experiment package.

    Imports ``experiments.<method>.src.config`` and ``experiments.<method>.src.model``
    (in that order), calls ``default_config()`` on the former and
    ``build_model(cfg)`` on the latter.

    Returns:
        ``(cfg, model)``.

    Raises:
        ModuleNotFoundError: If either module cannot be imported.
    """
    config_module, model_module = (
        importlib.import_module(f"experiments.{method}.src.config"),
        importlib.import_module(f"experiments.{method}.src.model"),
    )
    cfg = config_module.default_config()
    model = model_module.build_model(cfg)
    return cfg, model
|
|
|
|
def loader_kwargs(num_workers: int) -> dict:
    """Return extra ``DataLoader`` keyword arguments for multi-process loading.

    With no worker processes (``num_workers <= 0``) nothing extra is passed.
    Otherwise workers are started with the ``spawn`` context, kept alive across
    epochs, and prefetch four batches each.
    """
    if num_workers > 0:
        return dict(
            multiprocessing_context="spawn",
            persistent_workers=True,
            prefetch_factor=4,
        )
    return {}
|
|
|
|
def dataloader_pin_memory(device: torch.device) -> bool:
    """Return whether DataLoaders feeding *device* should use pinned host memory.

    Only CUDA devices request pinned memory; every other device type does not.
    """
    wants_pinned = device.type == "cuda"
    return wants_pinned
|
|
|
|
def configure_training_runtime(device: torch.device) -> None:
    """Apply process-wide PyTorch performance settings before training.

    * Sets float32 matmul precision to ``"high"`` when this torch build exposes
      ``set_float32_matmul_precision``.
    * Enables cuDNN autotuning (``cudnn.benchmark``) only for CUDA devices.
    """
    if hasattr(torch, "set_float32_matmul_precision"):
        torch.set_float32_matmul_precision("high")
    enable_cudnn_autotune = device.type == "cuda" and hasattr(torch.backends, "cudnn")
    if enable_cudnn_autotune:
        torch.backends.cudnn.benchmark = True
|
|
|
|
def autocast_context(device: torch.device, precision: str):
    """Return the mixed-precision context to wrap forward passes in.

    Autocast is used only on CUDA and only when *precision* is not ``"fp32"``:
    ``"bf16"`` selects ``torch.bfloat16``; any other value selects ``torch.float16``.
    In all remaining cases a no-op ``nullcontext`` is returned.
    """
    if device.type == "cuda" and precision != "fp32":
        half_dtype = {"bf16": torch.bfloat16}.get(precision, torch.float16)
        return torch.autocast(device_type="cuda", dtype=half_dtype)
    return nullcontext()
|
|
|
|
def method_seed(base_seed: int, method: str) -> int:
    """Derive a per-method seed from *base_seed* and the method name.

    The offset is a position-weighted sum of character codes
    (``1*ord(c0) + 2*ord(c1) + ...``), so different method names generally get
    different seeds while staying deterministic across runs.
    """
    offset = 0
    for position, char in enumerate(method, start=1):
        offset += position * ord(char)
    return int(base_seed) + offset
|
|
|
|
def set_training_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy's global RNG and torch (plus all CUDA devices if present)."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
|
|
|
|
def required_model_history(model, default_history_len: int) -> int:
    """Return how many trailing history steps *model* needs, capped at *default_history_len*.

    Reads ``model.config.history_len`` and ``model.config.context_len`` when
    present. A missing ``history_len`` falls back to *default_history_len*, and a
    missing ``context_len`` falls back to the resolved ``history_len``. The larger
    of the two is used, then capped at *default_history_len*.
    """
    config = getattr(model, "config", None)
    history_len = int(getattr(config, "history_len", default_history_len))
    context_len = int(getattr(config, "context_len", history_len))
    needed = history_len if history_len >= context_len else context_len
    cap = int(default_history_len)
    return needed if needed < cap else cap
|
|
|
|
def selected_history_indices(model, default_history_len: int) -> list[int]:
    """Return the history positions (along the time axis) that *model* consumes.

    If the model defines ``selected_history_indices(n)``, its result is used
    (converted to plain ``int``s). Otherwise the last
    ``required_model_history(model, n)`` positions of an ``n``-step history are
    selected.
    """
    if not hasattr(model, "selected_history_indices"):
        needed = required_model_history(model, default_history_len)
        return [*range(default_history_len - needed, default_history_len)]
    return list(map(int, model.selected_history_indices(default_history_len)))
|
|
|
|
def _take_last_steps(x: torch.Tensor, count: int) -> torch.Tensor:
    """Return the last *count* entries of *x* along dim 1.

    All entries are kept when *count* exceeds the length, and a non-positive
    *count* gives an empty slice.
    """
    # Use an explicit start index: ``x[:, -count:]`` silently returns the WHOLE
    # tensor when count == 0, because ``-0 == 0``.
    start = max(int(x.shape[1]) - int(count), 0)
    return x[:, start:]


def prepare_batch(batch, args, device: torch.device):
    """Trim a dataset batch to the model's history and move it onto *device*.

    ``batch`` is unpacked as the 8-tuple ``(observation_hist, actions,
    future_actions, targets, origin, prev_origin, flow_type_id, boat_id)``.

    History selection along dim 1 of ``observation_hist`` and ``actions``:
      * if ``args.history_indices`` is set (not ``None``), those positions are gathered;
      * otherwise the last ``args.model_history_len`` steps are kept. If that
        attribute is missing, the full history is kept.

    When ``args.render_mode == "device"``, the selected observations are moved to
    *device* and rendered into images by ``render_clean_boat_history_tensor``.
    Otherwise they are moved to *device* as-is; presumably they are
    already-rendered images from the dataset.

    Returns:
        ``(images, actions, future_actions, targets, origin, prev_origin,
        flow_type_id, boat_id)``. ``prev_origin`` and ``flow_type_id`` are
        returned untouched (not moved to *device*). ``boat_id`` is moved only in
        device-render mode.
    """
    observation_hist, actions, future_actions, targets, origin, prev_origin, flow_type_id, boat_id = batch
    history_indices = getattr(args, "history_indices", None)
    if history_indices is None:
        model_history_len = int(getattr(args, "model_history_len", observation_hist.shape[1]))
        observation_hist = _take_last_steps(observation_hist, model_history_len)
        actions = _take_last_steps(actions, model_history_len)
    else:
        observation_hist = observation_hist[:, history_indices]
        actions = actions[:, history_indices]
    actions = actions.to(device, non_blocking=True)
    future_actions = future_actions.to(device, non_blocking=True)
    targets = targets.to(device, non_blocking=True)
    origin = origin.to(device, non_blocking=True)
    if args.render_mode == "device":
        states = observation_hist.to(device, non_blocking=True)
        boat_id = boat_id.to(device, non_blocking=True)
        images = render_clean_boat_history_tensor(
            states,
            boat_id,
            image_size=args.image_size,
            visual_scale=args.visual_scale,
        )
    else:
        images = observation_hist.to(device, non_blocking=True)
    return images, actions, future_actions, targets, origin, prev_origin, flow_type_id, boat_id
|
|
|
|
def rollout_from_encoded(model, z: torch.Tensor, c: torch.Tensor, future_actions: torch.Tensor) -> torch.Tensor:
    """Roll the latent state forward open-loop and decode every step.

    Starting from latent *z*, applies ``model.step(state, action_t, c)`` once per
    slice of *future_actions* along dim 1 and decodes each new state with
    ``model.decoder``.

    Returns:
        The decoded outputs stacked along a new dim 1 (one entry per future action).
    """
    state = z
    decoded_steps: list[torch.Tensor] = []
    for action in future_actions.unbind(dim=1):
        state = model.step(state, action, c)
        decoded_steps.append(model.decoder(state))
    return torch.stack(decoded_steps, dim=1)
|
|
|
|
| def save_eval_checkpoint(model, checkpoint_dir: Path, checkpoint_name: str, step: int | None = None) -> str: |
| checkpoint_dir.mkdir(parents=True, exist_ok=True) |
| path = checkpoint_dir / checkpoint_name |
| if step is not None: |
| stem = Path(checkpoint_name).stem |
| suffix = Path(checkpoint_name).suffix or ".pt" |
| path = checkpoint_dir / f"{stem}_step_{step:06d}{suffix}" |
| torch.save(model.state_dict(), path) |
| return path.name |
|
|
|
|
def _append_trace(trace_path: Path, method: str, step: int, loss: float) -> None:
    """Append one ``{"method", "step", "loss"}`` JSON-lines record to *trace_path*."""
    with trace_path.open("a") as f:
        f.write(json.dumps({"method": method, "step": int(step), "loss": loss}) + "\n")


def _make_image_loader(dataset, args, device: torch.device, *, shuffle: bool, drop_last: bool) -> DataLoader:
    """Wrap *dataset* in a DataLoader using the shared batch/worker/pinning settings from *args*."""
    return DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=shuffle,
        num_workers=args.num_workers,
        pin_memory=dataloader_pin_memory(device),
        drop_last=drop_last,
        **loader_kwargs(args.num_workers),
    )


def _training_loss(model, batch, args, device: torch.device) -> torch.Tensor:
    """Run one forward pass on *batch* and return the combined scalar training loss."""
    images, actions, future_actions, targets, origin, *_ = prepare_batch(batch, args, device)
    train_targets = encode_targets(targets, origin, args.target_mode)
    with autocast_context(device, args.precision):
        z, c = model.encode(images, actions)
        pred = rollout_from_encoded(model, z, c, future_actions)
        # Predictions are cast to fp32 before any loss term is computed.
        loss = weighted_pose_loss(pred.float(), train_targets.float(), args.heading_weight)
        if args.motion_weight > 0.0:
            # The motion term compares decoded world-frame poses against raw targets.
            pred_abs = decode_predictions(pred.float(), origin, args.target_mode)
            loss = loss + args.motion_weight * motion_delta_loss(pred_abs, targets)
        if args.current_pose_weight > 0.0:
            # Auxiliary term: decoding the encoded latent z should reproduce `origin`
            # (presumably the current pose).
            # NOTE(review): this target is always the absolute-normalized pose, even
            # when target_mode == "relative_motion", where the same decoder is trained
            # to emit origin-relative poses for future steps. Confirm this is intended.
            current_target = encode_absolute_pose(origin)
            loss = loss + args.current_pose_weight * weighted_pose_loss(
                model.decoder(z).float(),
                current_target,
                args.heading_weight,
            )
    return loss


def _optimizer_step(model, optimizer, scaler, loss: torch.Tensor) -> None:
    """Backpropagate *loss*, clip gradients to a global norm of 5.0, and step *optimizer*."""
    optimizer.zero_grad(set_to_none=True)
    if scaler.is_enabled():
        scaler.scale(loss).backward()
        # Unscale before clipping so the norm is measured on the true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        scaler.step(optimizer)
        scaler.update()
    else:
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()


def train_method(method: str, args) -> dict[str, object]:
    """Train one world-model method, evaluate it on the test split, and persist the results.

    Files written under ``experiments/<method>/``:
      * ``checkpoint/``: a ``<stem>_step_XXXXXX<suffix>`` checkpoint every
        ``args.checkpoint_interval`` steps (when > 0), plus the final
        ``args.checkpoint_name``.
      * ``result/<stem>_training_trace.jsonl``: mean training loss per logging window.
      * ``result/parameter_count.json``: written by ``save_parameter_count``.
      * ``result/<stem>_training.json``: the returned result dict.

    ``args.model_history_len`` and ``args.history_indices`` are overwritten for
    this method so that ``prepare_batch`` selects the history the model needs.
    A non-positive ``args.log_every`` disables periodic logging; the mean loss of
    all steps is then recorded once at the end.

    Args:
        method: Name of a package under ``experiments`` that provides
            ``src.config`` and ``src.model``.
        args: Parsed CLI namespace (see ``main``).

    Returns:
        A summary dict with training settings, ``final_train_loss`` (``None`` when
        no training step ran) and evaluation metrics under ``"prediction"``.

    Raises:
        ValueError: If ``args.steps > 0`` but the training loader yields no
            batches (fewer windows than ``batch_size`` with ``drop_last=True``).
            Without this check the step loop would spin forever.
    """
    set_training_seed(method_seed(args.seed, method))
    _cfg, model = build_method(method)
    device = torch.device(args.device)
    configure_training_runtime(device)
    args.model_history_len = required_model_history(model, args.history_len)
    args.history_indices = selected_history_indices(model, args.history_len)
    model.to(device)
    if device.type == "cuda":
        model.to(memory_format=torch.channels_last)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
    # Loss scaling is only needed for fp16 autocast on CUDA.
    scaler = torch.amp.GradScaler("cuda", enabled=(device.type == "cuda" and args.precision == "fp16"))
    out_dir = Path("experiments") / method
    checkpoint_dir = out_dir / "checkpoint"
    result_dir = out_dir / "result"
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    result_dir.mkdir(parents=True, exist_ok=True)
    trace_path = result_dir / f"{Path(args.checkpoint_name).stem}_training_trace.jsonl"
    trace_path.write_text("")  # truncate any trace left over from a previous run
    train_ds = ImageTrajectoryDataset(
        args.train_source,
        history_len=args.history_len,
        horizon=args.horizon,
        episodes=args.train_episodes,
        image_size=args.image_size,
        visual_scale=args.visual_scale,
        max_windows=args.train_windows,
        seed=args.seed,
        return_aux=True,
        render_images=args.render_mode == "dataset",
    )
    train_loader = _make_image_loader(train_ds, args, device, shuffle=True, drop_last=True)
    if args.steps > 0 and len(train_loader) == 0:
        raise ValueError(
            f"{method}: training loader yields no batches "
            f"({len(train_ds)} windows < batch_size={args.batch_size} with drop_last=True)"
        )

    logged_losses: list[float] = []
    model.train()
    step = 0
    # Accumulate on-device to avoid a host sync on every step; synced only when logging.
    running_loss = torch.zeros((), device=device)
    running_count = 0
    saved_checkpoints: list[str] = []
    # Re-iterate the loader (new shuffle each pass) until the step budget is spent.
    while step < args.steps:
        for batch in train_loader:
            step += 1
            loss = _training_loss(model, batch, args, device)
            _optimizer_step(model, optimizer, scaler, loss)
            running_loss = running_loss + loss.detach()
            running_count += 1
            if args.log_every > 0 and step % args.log_every == 0:
                mean_loss = float((running_loss / running_count).item())
                logged_losses.append(mean_loss)
                print(f"{method} step {step:05d} loss={mean_loss:.5f}", flush=True)
                _append_trace(trace_path, method, step, mean_loss)
                running_loss = torch.zeros((), device=device)
                running_count = 0
            if args.checkpoint_interval > 0 and step % args.checkpoint_interval == 0:
                saved_checkpoints.append(save_eval_checkpoint(model, checkpoint_dir, args.checkpoint_name, step))
            if step >= args.steps:
                break
    # Flush the partial logging window left over after the last full one.
    if running_count:
        mean_loss = float((running_loss / running_count).item())
        logged_losses.append(mean_loss)
        _append_trace(trace_path, method, step, mean_loss)
    final_checkpoint = save_eval_checkpoint(model, checkpoint_dir, args.checkpoint_name)
    counts = save_parameter_count(model, result_dir / "parameter_count.json")
    # Release training data (and its worker processes) before building the test set.
    del train_loader, train_ds
    gc.collect()

    test_ds = ImageTrajectoryDataset(
        args.test_source,
        history_len=args.history_len,
        horizon=args.horizon,
        episodes=args.test_episodes,
        max_windows=args.test_windows,
        seed=args.seed + 1,
        image_size=args.image_size,
        visual_scale=args.visual_scale,
        return_aux=True,
        render_images=args.render_mode == "dataset",
    )
    test_loader = _make_image_loader(test_ds, args, device, shuffle=False, drop_last=False)
    metrics = evaluate_model(model, test_loader, device, args.horizon, args.target_mode, args)
    result = {
        "method": method,
        "steps": int(args.steps),
        "batch_size": int(args.batch_size),
        "train_samples": int(args.steps * args.batch_size),
        "final_train_loss": float(logged_losses[-1]) if logged_losses else None,
        "total_parameters": int(counts["total"]),
        "target_mode": args.target_mode,
        "position_scale": POSITION_SCALE,
        "heading_weight": float(args.heading_weight),
        "current_pose_weight": float(args.current_pose_weight),
        "motion_weight": float(args.motion_weight),
        "precision": args.precision,
        "checkpoint_name": args.checkpoint_name,
        "final_checkpoint": final_checkpoint,
        "intermediate_checkpoints": saved_checkpoints,
        "checkpoint_interval": int(args.checkpoint_interval),
        "prediction": metrics,
    }
    result_name = f"{Path(args.checkpoint_name).stem}_training.json"
    (result_dir / result_name).write_text(json.dumps(result, indent=2))
    del test_loader, test_ds, model
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()
    return result
|
|
|
|
@torch.no_grad()
def evaluate_model(model, loader, device: torch.device, horizon: int, target_mode: str, args) -> dict[str, float]:
    """Measure mean open-loop prediction error per horizon step over *loader*.

    For each batch, ``model.rollout`` is run under the configured autocast
    context. Its output is decoded back to world poses with ``decode_predictions``.
    Two errors are computed against ``targets``:
      * ``pos<k>``: mean Euclidean x/y error at step k;
      * ``heading<k>``: mean absolute heading error at step k, in radians. The
        angle difference is wrapped via atan2(sin, cos), so values lie in [0, pi].
    Only k in (1, 3, 6, 8, 10, 20) with k <= *horizon* is reported.

    The model is switched to eval mode and left there. An empty loader divides
    by zero sample count, so NumPy produces NaN metrics with a RuntimeWarning.
    """
    model.eval()
    position_error_total = np.zeros(horizon, dtype=np.float64)
    heading_error_total = np.zeros(horizon, dtype=np.float64)
    n_samples = 0
    for batch in loader:
        images, actions, future_actions, targets, origin, *_ = prepare_batch(batch, args, device)
        with autocast_context(device, args.precision):
            raw_prediction = model.rollout(images, actions, future_actions)
        predicted = decode_predictions(raw_prediction.float(), origin, target_mode)
        position_error = torch.linalg.norm(predicted[..., :2] - targets[..., :2], dim=-1)
        angle_gap = torch.atan2(predicted[..., 3], predicted[..., 2]) - torch.atan2(targets[..., 3], targets[..., 2])
        heading_error = torch.atan2(torch.sin(angle_gap), torch.cos(angle_gap)).abs()
        position_error_total += position_error.sum(dim=0).cpu().numpy()
        heading_error_total += heading_error.sum(dim=0).cpu().numpy()
        n_samples += int(images.shape[0])
    position_mean = position_error_total / n_samples
    heading_mean = heading_error_total / n_samples
    metrics: dict[str, float] = {}
    for k in (1, 3, 6, 8, 10, 20):
        if k > horizon:
            continue
        metrics[f"pos{k}"] = float(position_mean[k - 1])
        metrics[f"heading{k}"] = float(heading_mean[k - 1])
    return metrics
|
|
|
|
def encode_targets(targets: torch.Tensor, origin: torch.Tensor, target_mode: str) -> torch.Tensor:
    """Convert world-frame target poses into the representation the model is trained on.

    Channels 0-1 of *targets*/*origin* are x/y. Channels 2-3 are read as a
    (cos, sin) heading pair via ``atan2(ch3, ch2)``. *targets* is assumed to be
    (batch, horizon, >=4) and *origin* (batch, >=4); TODO confirm.

    Modes:
      * ``"absolute_normalized"``: ``encode_absolute_pose(targets)``.
      * ``"relative_motion"``: the x/y offset from *origin* divided by
        ``POSITION_SCALE``, in world axes (not rotated into the origin heading),
        concatenated with (cos, sin) of the heading change relative to *origin*.

    Raises:
        ValueError: For any other *target_mode*.
    """
    if target_mode == "absolute_normalized":
        return encode_absolute_pose(targets)
    if target_mode != "relative_motion":
        raise ValueError(f"unknown target_mode: {target_mode}")
    offset_xy = (targets[..., :2] - origin[:, None, :2]) / POSITION_SCALE
    origin_heading = torch.atan2(origin[:, 3], origin[:, 2]).unsqueeze(1)
    target_heading = torch.atan2(targets[..., 3], targets[..., 2])
    turn = target_heading - origin_heading
    turn_vector = torch.stack([turn.cos(), turn.sin()], dim=-1)
    return torch.cat([offset_xy, turn_vector], dim=-1)
|
|
|
|
def encode_absolute_pose(obs: torch.Tensor) -> torch.Tensor:
    """Normalise x/y of *obs* as ``(v - POSITION_SCALE) / POSITION_SCALE`` and keep channels 2-3 unchanged.

    Any channels beyond index 3 are dropped.
    """
    position, heading = obs[..., :2], obs[..., 2:4]
    normalized_position = (position - POSITION_SCALE) / POSITION_SCALE
    return torch.cat([normalized_position, heading], dim=-1)
|
|
|
|
def decode_predictions(predictions: torch.Tensor, origin: torch.Tensor, target_mode: str) -> torch.Tensor:
    """Map model outputs back to world-frame ``(x, y, cos, sin)`` poses; the inverse of ``encode_targets``.

    * ``"absolute_normalized"``: x/y are un-normalised with ``POSITION_SCALE``.
      Channels 2-3 are passed through unchanged (not re-normalised), and
      *origin* is not used.
    * ``"relative_motion"``: x/y offsets are scaled by ``POSITION_SCALE`` and
      added to the origin x/y. The predicted heading change ``atan2(ch3, ch2)``
      is added to the origin heading and re-emitted as a unit (cos, sin) pair.

    Raises:
        ValueError: For any other *target_mode*.
    """
    if target_mode == "absolute_normalized":
        world_xy = predictions[..., :2] * POSITION_SCALE + POSITION_SCALE
        return torch.cat([world_xy, predictions[..., 2:4]], dim=-1)
    if target_mode != "relative_motion":
        raise ValueError(f"unknown target_mode: {target_mode}")
    world_xy = predictions[..., :2] * POSITION_SCALE + origin[:, None, :2]
    base_heading = torch.atan2(origin[:, 3], origin[:, 2]).unsqueeze(1)
    absolute_heading = base_heading + torch.atan2(predictions[..., 3], predictions[..., 2])
    heading_vector = torch.stack([absolute_heading.cos(), absolute_heading.sin()], dim=-1)
    return torch.cat([world_xy, heading_vector], dim=-1)
|
|
|
|
def weighted_pose_loss(predictions: torch.Tensor, targets: torch.Tensor, heading_weight: float) -> torch.Tensor:
    """Return ``MSE(x/y) + heading_weight * MSE(heading channels 2-3)`` as a scalar tensor."""
    pred_xy, pred_dir = predictions[..., :2], predictions[..., 2:4]
    true_xy, true_dir = targets[..., :2], targets[..., 2:4]
    return F.mse_loss(pred_xy, true_xy) + heading_weight * F.mse_loss(pred_dir, true_dir)
|
|
|
|
def motion_delta_loss(predictions: torch.Tensor, targets: torch.Tensor, turn_weight: float = 0.2) -> torch.Tensor:
    """Penalise mismatched step-to-step motion between predicted and target trajectories.

    Args:
        predictions: Poses indexed ``[batch, time, channel]``. Channels 0-1 are
            x/y and channels 2-3 a (cos, sin) heading pair, read via ``atan2(ch3, ch2)``.
        targets: Ground-truth poses with the same layout.
        turn_weight: Weight of the heading-change term relative to the
            translation term. Defaults to 0.2, the previously hard-coded value.

    Returns:
        Scalar ``MSE(Δxy) + turn_weight * MSE(Δheading)``, where Δ is the
        difference between consecutive time steps and heading changes are
        wrapped to [-pi, pi]. Returns 0 when fewer than two time steps are given:
        there are no consecutive deltas, and the mean over an empty tensor would
        be NaN and poison the total loss.
    """
    if predictions.shape[1] < 2:
        return predictions.new_zeros(())
    pred_delta = predictions[:, 1:, :2] - predictions[:, :-1, :2]
    target_delta = targets[:, 1:, :2] - targets[:, :-1, :2]
    pred_angle = torch.atan2(predictions[..., 3], predictions[..., 2])
    target_angle = torch.atan2(targets[..., 3], targets[..., 2])
    # Wrap per-step heading changes to [-pi, pi] so a crossing of +/-pi is not a huge turn.
    pred_turn = torch.atan2(
        torch.sin(pred_angle[:, 1:] - pred_angle[:, :-1]),
        torch.cos(pred_angle[:, 1:] - pred_angle[:, :-1]),
    )
    target_turn = torch.atan2(
        torch.sin(target_angle[:, 1:] - target_angle[:, :-1]),
        torch.cos(target_angle[:, 1:] - target_angle[:, :-1]),
    )
    return F.mse_loss(pred_delta, target_delta) + turn_weight * F.mse_loss(pred_turn, target_turn)
|
|
|
|
def _build_parser() -> argparse.ArgumentParser:
    """Declare the command-line interface of the train/evaluate sweep."""
    parser = argparse.ArgumentParser()
    # Methods and data sources.
    parser.add_argument("--methods", nargs="+", default=METHODS)
    parser.add_argument("--train-source", default="data/paper/train.npz")
    parser.add_argument("--test-source", default="data/paper/test.npz")
    parser.add_argument("--train-episodes", type=int, default=512)
    parser.add_argument("--test-episodes", type=int, default=256)
    # Window geometry and dataset size.
    parser.add_argument("--history-len", type=int, default=32)
    parser.add_argument("--horizon", type=int, default=20)
    parser.add_argument("--train-windows", type=int, default=65536)
    parser.add_argument("--test-windows", type=int, default=8192)
    # Optimisation schedule.
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--steps", type=int, default=16000)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--log-every", type=int, default=200)
    parser.add_argument("--seed", type=int, default=19)
    # Runtime.
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--num-workers", type=int, default=4)
    # NOTE(review): parsed but never read in this file; kept so existing command lines still work.
    parser.add_argument("--episode-chunk-size", type=int, default=64)
    # Rendering and numeric precision.
    parser.add_argument("--image-size", type=int, default=160)
    parser.add_argument("--visual-scale", type=float, default=2.5)
    parser.add_argument("--render-mode", choices=["device", "dataset"], default="device")
    parser.add_argument("--precision", choices=["fp32", "bf16", "fp16"], default="fp32")
    # Target representation and loss weights.
    parser.add_argument("--target-mode", choices=["absolute_normalized", "relative_motion"], default="absolute_normalized")
    parser.add_argument("--heading-weight", type=float, default=2.0)
    parser.add_argument("--current-pose-weight", type=float, default=1.0)
    parser.add_argument("--motion-weight", type=float, default=0.0)
    # Checkpointing.
    parser.add_argument("--checkpoint-name", default="paper.pt")
    parser.add_argument("--checkpoint-interval", type=int, default=2000)
    return parser


def main() -> None:
    """Parse CLI flags, train and evaluate each requested method in turn, and print all results as JSON."""
    args = _build_parser().parse_args()
    results = []
    for method in args.methods:
        results.append(train_method(method, args))
    print(json.dumps(results, indent=2))
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse CLI flags and run the train/evaluate sweep.
    main()
|
|