"""Closed-loop planning evaluation for clean-image world models and controllers.""" from __future__ import annotations import argparse import importlib import json from collections import deque from pathlib import Path import numpy as np import torch import torch.nn.functional as F from driftwm.sim.env import SurfaceBoatEnv from driftwm.sim.flow import sample_flow from driftwm.sim.render import render_frame, save_gif from experiments.shared.src.methods import PAPER_LEARNED_METHODS, TRADITIONAL_METHODS from experiments.shared.src.vision.clean_renderer import render_clean_boat_array from experiments.train_image_world_models import autocast_context LEARNED_METHODS = PAPER_LEARNED_METHODS POSITION_SCALE = 5.0 def build_method(method: str): config_module = importlib.import_module(f"experiments.{method}.src.config") model_module = importlib.import_module(f"experiments.{method}.src.model") cfg = config_module.default_config() return cfg, model_module.build_model(cfg) def decode_absolute(prediction: torch.Tensor) -> torch.Tensor: xy = prediction[..., :2] * POSITION_SCALE + POSITION_SCALE return torch.cat([xy, prediction[..., 2:4]], dim=-1) def clean_observation(env: SurfaceBoatEnv, image_size: int, visual_scale: float) -> np.ndarray: image = render_clean_boat_array(env.full_state()[:6], env.spec, image_size=image_size, visual_scale=visual_scale) return np.transpose(image, (2, 0, 1)) def pad_action(action: np.ndarray, action_dim: int) -> np.ndarray: out = np.zeros((action_dim,), dtype=np.float32) action = np.asarray(action, dtype=np.float32) out[: min(len(action), action_dim)] = action[: min(len(action), action_dim)] return out def task_goals(task: str, rng: np.random.Generator) -> np.ndarray: if task == "waypoint_square": return np.array([[2.5, 2.5], [7.5, 2.5], [7.5, 7.5], [2.5, 7.5]], dtype=np.float32) if task == "waypoint_zigzag": return np.array([[2.5, 7.0], [4.2, 3.0], [5.8, 7.0], [7.5, 3.0]], dtype=np.float32) if task == "station_keeping": return np.array([[5.0, 5.0]], dtype=np.float32) return np.array([[8.0, 8.0]], dtype=np.float32) def set_task_state(env: SurfaceBoatEnv, state: np.ndarray) -> None: env.state[:6] = np.asarray(state, dtype=np.float32) env.last_flow_velocity = env.flow_at(env.state[:2]).astype(np.float32) def reset_task(env: SurfaceBoatEnv, task: str, flow_type: str, rng: np.random.Generator) -> None: if task == "station_keeping": flow = sample_flow(flow_type, rng, flow_id=10_000 + int(rng.integers(1, 1_000_000)), workspace=env.workspace) env.reset(flow_type=flow_type, flow=flow, random_velocity=False) set_task_state(env, np.array([5.0, 5.0, 0.3, 0.0, 0.0, 0.0], dtype=np.float32)) return flow = sample_flow(flow_type, rng, flow_id=10_000 + int(rng.integers(1, 1_000_000)), workspace=env.workspace) env.reset(flow_type=flow_type, flow=flow, random_velocity=False) set_task_state(env, np.array([2.0, 2.0, float(rng.uniform(-np.pi, np.pi)), 0.0, 0.0, 0.0], dtype=np.float32)) def rollout_latent(model, z: torch.Tensor, c: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: cur = z.repeat(actions.shape[0], 1) ctx = c.repeat(actions.shape[0], 1) if c.numel() else c preds = [] for t in range(actions.shape[1]): cur = model.step(cur, actions[:, t], ctx) preds.append(model.decoder(cur)) return decode_absolute(torch.stack(preds, dim=1)).float() def warm_start_mean( previous_mean: np.ndarray | None, horizon: int, action_dim: int, active_action_dim: int, device: torch.device, ) -> torch.Tensor: mean = torch.zeros((horizon, action_dim), dtype=torch.float32, device=device) if previous_mean is None: return mean previous = torch.as_tensor(previous_mean, dtype=torch.float32, device=device) steps = min(horizon, max(0, previous.shape[0] - 1)) if steps > 0: mean[:steps, :active_action_dim] = previous[1 : 1 + steps, :active_action_dim] if previous.shape[0] > 0 and steps < horizon: mean[steps:, :active_action_dim] = previous[-1, :active_action_dim] return mean.clamp(-1.0, 1.0) def sample_action_sequences(mean: torch.Tensor, std: torch.Tensor, population: int, knots: int) -> torch.Tensor: horizon, action_dim = mean.shape if knots >= horizon: noise = torch.randn(population, horizon, action_dim, device=mean.device) return mean.unsqueeze(0) + std.unsqueeze(0) * noise knots = max(2, knots) knot_idx = torch.linspace(0, horizon - 1, knots, device=mean.device).round().long() knot_mean = mean[knot_idx] knot_std = std[knot_idx] knot_samples = knot_mean.unsqueeze(0) + knot_std.unsqueeze(0) * torch.randn( population, knots, action_dim, device=mean.device, ) samples = F.interpolate( knot_samples.permute(0, 2, 1), size=horizon, mode="linear", align_corners=True, ).permute(0, 2, 1) return samples def route_points_tensor( current_pos: torch.Tensor, goals: np.ndarray, goal_idx: int, ) -> torch.Tensor: remaining = torch.as_tensor(goals[goal_idx:], dtype=torch.float32, device=current_pos.device) return torch.cat([current_pos.reshape(1, 2).detach(), remaining], dim=0) def route_projection(pos: torch.Tensor, route_points: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: starts = route_points[:-1] ends = route_points[1:] seg = ends - starts seg_len = torch.linalg.norm(seg, dim=-1).clamp_min(1.0e-6) seg_len_sq = (seg_len * seg_len).clamp_min(1.0e-6) rel = pos[:, :, None, :] - starts.view(1, 1, -1, 2) t = (rel * seg.view(1, 1, -1, 2)).sum(dim=-1) / seg_len_sq.view(1, 1, -1) t = t.clamp(0.0, 1.0) proj = starts.view(1, 1, -1, 2) + t[..., None] * seg.view(1, 1, -1, 2) dist_sq = ((pos[:, :, None, :] - proj) ** 2).sum(dim=-1) min_dist_sq, idx = dist_sq.min(dim=-1) cum = torch.cat([torch.zeros(1, device=pos.device), seg_len.cumsum(dim=0)[:-1]], dim=0) along = cum.view(1, 1, -1) + t * seg_len.view(1, 1, -1) route_s = along.gather(dim=-1, index=idx[..., None]).squeeze(-1) return min_dist_sq, route_s, seg_len.sum() def route_points_at_s(route_points: torch.Tensor, s: torch.Tensor) -> torch.Tensor: starts = route_points[:-1] ends = route_points[1:] seg = ends - starts seg_len = torch.linalg.norm(seg, dim=-1).clamp_min(1.0e-6) cum_end = seg_len.cumsum(dim=0) cum_start = cum_end - seg_len flat_s = s.reshape(-1).clamp(0.0, float(cum_end[-1].detach().cpu())) idx = torch.searchsorted(cum_end, flat_s, right=False).clamp(max=seg_len.numel() - 1) local = ((flat_s - cum_start[idx]) / seg_len[idx]).clamp(0.0, 1.0) pts = starts[idx] + local[:, None] * seg[idx] return pts.reshape(*s.shape, 2) def learned_plan( model, image_history: deque, action_history: deque, goals: np.ndarray, goal_idx: int, active_action_dim: int, args, prev_action: np.ndarray, previous_mean: np.ndarray | None, context_mode: str, donor_context: torch.Tensor | None, ) -> tuple[np.ndarray, np.ndarray | None, np.ndarray]: device = next(model.parameters()).device images = torch.as_tensor(np.asarray(image_history, dtype=np.uint8), device=device).unsqueeze(0) actions = torch.as_tensor(np.asarray(action_history, dtype=np.float32), device=device).unsqueeze(0) with torch.no_grad(), autocast_context(device, args.precision): z, c = model.encode(images, actions) if c.numel() and context_mode == "zero": c = torch.zeros_like(c) if c.numel() and context_mode == "shuffled" and donor_context is not None: c = donor_context.to(device=device, dtype=torch.float32) z = z.detach() c = c.detach() goal = goals[goal_idx] goal_t = torch.as_tensor(goal, dtype=torch.float32, device=device).view(1, 2) with torch.no_grad(), autocast_context(device, args.precision): current_pos = decode_absolute(model.decoder(z)).float().detach()[..., :2] route_points = route_points_tensor(current_pos[0], goals, goal_idx) mean = warm_start_mean( previous_mean, args.cem_horizon, model.config.action_dim, active_action_dim, device, ) std = torch.full_like(mean, args.cem_action_std) prev = torch.zeros((model.config.action_dim,), dtype=torch.float32, device=device) prev[:active_action_dim] = torch.as_tensor(prev_action, dtype=torch.float32, device=device) best_candidates = None with torch.no_grad(): action, best_candidates, mean = cem_plan( model, z, c, mean, std, goal_t, route_points, current_pos, prev, active_action_dim, args, ) return action, best_candidates, mean def planning_cost( pred: torch.Tensor, samples: torch.Tensor, goal_t: torch.Tensor, route_points: torch.Tensor, current_pos: torch.Tensor, prev: torch.Tensor, active_action_dim: int, args, ) -> torch.Tensor: pos = pred[..., :2] goal_delta = goal_t - current_pos goal_dir = goal_delta / torch.linalg.norm(goal_delta, dim=-1, keepdim=True).clamp_min(1.0e-6) direct_progress = ((pos - current_pos[:, None]) * goal_dir[:, None]).sum(dim=-1).amax(dim=-1) alpha = torch.linspace(1.0 / pos.shape[1], 1.0, pos.shape[1], device=pos.device, dtype=pos.dtype) direct_route = current_pos[:, None] + alpha.view(1, -1, 1) * goal_delta[:, None] direct_route_error = ((pos - direct_route) ** 2).sum(dim=-1).mean(dim=-1) route_dist_sq, route_s, route_len = route_projection(pos, route_points) route_error = route_dist_sq.mean(dim=-1) scheduled_s = alpha * torch.minimum( route_len, torch.as_tensor(args.cem_route_horizon_distance, dtype=pos.dtype, device=pos.device), ) scheduled = route_points_at_s(route_points, scheduled_s).view(1, pos.shape[1], 2) lookahead_error = ((pos - scheduled) ** 2).sum(dim=-1).mean(dim=-1) route_progress = (route_s.amax(dim=-1) / route_len.clamp_min(1.0e-6)).clamp(0.0, 1.0) goal_from_pos = goal_t[:, None] - pos goal_from_pos = goal_from_pos / torch.linalg.norm(goal_from_pos, dim=-1, keepdim=True).clamp_min(1.0e-6) heading = pred[..., 2:4] heading = heading / torch.linalg.norm(heading, dim=-1, keepdim=True).clamp_min(1.0e-6) heading_error = (1.0 - (heading * goal_from_pos).sum(dim=-1)).mean(dim=-1) terminal = ((pos[:, -1] - goal_t) ** 2).sum(dim=-1) path = ((pos - goal_t[:, None]) ** 2).sum(dim=-1).mean(dim=-1) via = ((pos - goal_t[:, None]) ** 2).sum(dim=-1).amin(dim=-1) energy = (samples[..., :active_action_dim] ** 2).mean(dim=(1, 2)) smooth_prev = torch.cat([prev.view(1, 1, -1).repeat(samples.shape[0], 1, 1), samples[:, :-1]], dim=1) smooth = ((samples - smooth_prev) ** 2).mean(dim=(1, 2)) margin = args.cem_boundary_margin boundary = ( torch.relu(margin - pos[..., 0]) + torch.relu(pos[..., 0] - (10.0 - margin)) + torch.relu(margin - pos[..., 1]) + torch.relu(pos[..., 1] - (10.0 - margin)) ).mean(dim=-1) return ( args.cem_w_goal * terminal + args.cem_w_path * path + args.cem_w_route * (route_error + 0.25 * direct_route_error) + args.cem_w_lookahead * lookahead_error + args.cem_w_via * via + args.cem_w_heading_goal * heading_error + args.cem_w_action * energy + args.cem_w_smooth * smooth + args.cem_w_boundary * boundary - args.cem_w_progress * (route_progress + 0.1 * direct_progress) ) def cem_plan( model, z: torch.Tensor, c: torch.Tensor, mean: torch.Tensor, std: torch.Tensor, goal_t: torch.Tensor, route_points: torch.Tensor, current_pos: torch.Tensor, prev: torch.Tensor, active_action_dim: int, args, ) -> tuple[np.ndarray, np.ndarray | None, np.ndarray]: best_candidates = None for _ in range(args.cem_iterations): samples = sample_action_sequences(mean, std, args.cem_population, args.cem_knots) samples[0] = mean samples = samples.clamp(-1.0, 1.0) if active_action_dim < model.config.action_dim: samples[:, :, active_action_dim:] = 0.0 with autocast_context(mean.device, args.precision): pred = rollout_latent(model, z, c, samples) cost = planning_cost(pred, samples, goal_t, route_points, current_pos, prev, active_action_dim, args) elite_idx = torch.topk(cost, k=args.cem_elites, largest=False).indices elites = samples[elite_idx] mean = elites.mean(dim=0) std = elites.std(dim=0).clamp_min(0.05) if args.make_gifs: pos = pred[..., :2] best_candidates = pos[elite_idx[:12]].detach().cpu().numpy() action = mean[0, :active_action_dim].detach().cpu().numpy() return ( np.clip(action, -1.0, 1.0).astype(np.float32), best_candidates, mean.detach().cpu().numpy(), ) @torch.no_grad() def donor_context_for_flowmo(model, env: SurfaceBoatEnv, args, seed: int) -> torch.Tensor | None: if not hasattr(model, "to_c"): return None rng = np.random.default_rng(seed + 99_999) donor = SurfaceBoatEnv( boat=env.config.boat, flow_type=env.config.flow_type, boundary="terminate", episode_steps=model.config.context_len + 8, seed=seed + 99, ) donor.reset(flow_type=env.config.flow_type, random_velocity=False) image_history = deque(maxlen=args.history_len) action_history = deque(maxlen=args.history_len) action = np.zeros((model.config.action_dim,), dtype=np.float32) for _ in range(args.history_len): image_history.append(clean_observation(donor, args.image_size, args.visual_scale)) action_history.append(action.copy()) raw = rng.uniform(-0.5, 0.5, size=donor.action_dim).astype(np.float32) donor.step(raw) action = pad_action(raw, model.config.action_dim) device = next(model.parameters()).device images = torch.as_tensor(np.asarray(image_history, dtype=np.uint8), device=device).unsqueeze(0) actions = torch.as_tensor(np.asarray(action_history, dtype=np.float32), device=device).unsqueeze(0) with autocast_context(device, args.precision): return model.encode(images, actions)[1].detach() def traditional_action(method: str, image_history: deque, env: SurfaceBoatEnv, goal: np.ndarray) -> np.ndarray: evaluate_module = importlib.import_module(f"experiments.{method}.src.evaluate") image = np.transpose(image_history[-1], (1, 2, 0)) history = [np.transpose(x, (1, 2, 0)) for x in image_history] cfg = { "image": image, "history": history, "true_flow": env.last_flow_velocity.copy(), "goal": goal.astype(float).tolist(), "action_dim": env.action_dim, "boat": env.config.boat, } return evaluate_module.evaluate(cfg)[: env.action_dim].astype(np.float32) def evaluate_one_method(method: str, args) -> dict: torch.manual_seed(args.seed) learned = method in LEARNED_METHODS model = None if learned: _cfg, model = build_method(method) state = torch.load(Path("experiments") / method / "checkpoint" / args.checkpoint_name, map_location="cpu") model.load_state_dict(state) model.to(torch.device(args.device)) if torch.device(args.device).type == "cuda": model.to(memory_format=torch.channels_last) model.eval() for param in model.parameters(): param.requires_grad_(False) results = [] gif_dir = Path(args.out) / "gifs" gif_dir.mkdir(parents=True, exist_ok=True) context_modes = args.context_modes if method == "flowmo" else ["inferred"] for context_mode in context_modes: for ep in range(args.episodes): episode_seed = int(args.seed + ep) rng = np.random.default_rng(episode_seed) env = SurfaceBoatEnv( boat=args.boat, flow_type=args.flow_type, boundary="terminate", episode_steps=args.max_steps, seed=episode_seed, ) reset_task(env, args.task, args.flow_type, rng) goals = task_goals(args.task, rng) goal_idx = 0 image_history = deque(maxlen=args.history_len) action_history = deque(maxlen=args.history_len) zero = np.zeros((model.config.action_dim if learned else 3,), dtype=np.float32) first = clean_observation(env, args.image_size, args.visual_scale) for _ in range(args.history_len): image_history.append(first.copy()) action_history.append(zero.copy()) donor_context = donor_context_for_flowmo(model, env, args, episode_seed) if learned and context_mode == "shuffled" else None trajectory = [env.full_state()[:6].copy()] frames = [] prev_action = np.zeros((env.action_dim,), dtype=np.float32) energy = 0.0 reached_times: list[int] = [] min_goal_dists = np.full((len(goals),), np.inf, dtype=np.float32) planned = None learned_plan_mean = None for t in range(args.max_steps): goal = goals[goal_idx] if learned: action, planned, learned_plan_mean = learned_plan( model, image_history, action_history, goals, goal_idx, env.action_dim, args, prev_action, learned_plan_mean, context_mode, donor_context, ) else: action = traditional_action(method, image_history, env, goal) planned = None prev_action = action.copy() _obs, _reward, done, _info = env.step(action) energy += float(np.sum(action * action)) trajectory.append(env.full_state()[:6].copy()) image_history.append(clean_observation(env, args.image_size, args.visual_scale)) action_history.append(pad_action(action, len(action_history[-1]))) dists = np.linalg.norm(goals - env.state[:2], axis=1) min_goal_dists = np.minimum(min_goal_dists, dists) if ep < args.make_gifs and t % args.gif_stride == 0: frames.append( render_frame( env.full_state()[:6], env.spec, env.flow, env.workspace, trajectory=np.asarray(trajectory), goal=goal, planned=planned, t=env.time, ) ) if float(dists[goal_idx]) < args.success_radius: reached_times.append(t + 1) if args.task == "station_keeping": if t >= max(40, args.max_steps // 3): break else: goal_idx += 1 learned_plan_mean = None if goal_idx >= len(goals): break if done: break path = np.asarray(trajectory)[:, :2] final_goal = goals[min(goal_idx, len(goals) - 1)] record = { "method": method, "context_mode": context_mode, "episode": ep, "success": bool(goal_idx >= len(goals) or (args.task == "station_keeping" and np.linalg.norm(env.state[:2] - goals[0]) < args.success_radius)), "final_distance": float(np.linalg.norm(env.state[:2] - final_goal)), "mean_min_goal_distance": float(min_goal_dists.mean()), "path_length": float(np.linalg.norm(np.diff(path, axis=0), axis=-1).sum()) if len(path) > 1 else 0.0, "energy": energy, "steps": len(trajectory) - 1, "reached_times": reached_times, } results.append(record) if ep < args.make_gifs and frames: name = f"image_planning_{method}_{context_mode}_{args.boat}_{args.task}_{args.flow_type}_ep{ep:03d}.gif" save_gif(frames, gif_dir / name, duration_ms=args.gif_duration_ms) return summarize(method, args, results) def summarize(method: str, args, results: list[dict]) -> dict: groups = sorted({r["context_mode"] for r in results}) by_context = {} def success_mean(items: list[dict], key: str) -> float | None: successful = [r[key] for r in items if r["success"]] return float(np.mean(successful)) if successful else None for context in groups: items = [r for r in results if r["context_mode"] == context] by_context[context] = { "episodes": len(items), "successes": len([r for r in items if r["success"]]), "success_rate": float(np.mean([r["success"] for r in items])), "final_distance_mean": float(np.mean([r["final_distance"] for r in items])), "mean_min_goal_distance": float(np.mean([r["mean_min_goal_distance"] for r in items])), "path_length_success_mean": success_mean(items, "path_length"), "energy_success_mean": success_mean(items, "energy"), "steps_success_mean": success_mean(items, "steps"), } return { "method": method, "task": args.task, "boat": args.boat, "flow_type": args.flow_type, "by_context": by_context, "results": results, } def main() -> None: parser = argparse.ArgumentParser() parser.add_argument("--methods", nargs="+", default=LEARNED_METHODS + TRADITIONAL_METHODS) parser.add_argument("--task", choices=["reach_target", "station_keeping", "waypoint_square", "waypoint_zigzag"], default="reach_target") parser.add_argument("--boat", choices=["twin", "triangle"], default="twin") parser.add_argument("--flow-type", choices=["noflow", "uniform", "vortex_center", "double_gyre", "source_sink", "source_sink_pair", "gradient", "shear", "turbulent_patch", "random_fourier"], default="uniform") parser.add_argument("--episodes", type=int, default=50) parser.add_argument("--max-steps", type=int, default=420) parser.add_argument("--history-len", type=int, default=32) parser.add_argument("--image-size", type=int, default=160) parser.add_argument("--visual-scale", type=float, default=2.5) parser.add_argument("--checkpoint-name", default="paper.pt") parser.add_argument("--context-modes", nargs="+", default=["inferred", "zero", "shuffled"]) parser.add_argument("--cem-horizon", type=int, default=45) parser.add_argument("--cem-population", type=int, default=512) parser.add_argument("--cem-elites", type=int, default=64) parser.add_argument("--cem-iterations", type=int, default=4) parser.add_argument("--cem-action-std", type=float, default=0.5) parser.add_argument("--cem-knots", type=int, default=10) parser.add_argument("--cem-w-goal", type=float, default=6.0) parser.add_argument("--cem-w-path", type=float, default=0.2) parser.add_argument("--cem-w-route", type=float, default=6.0) parser.add_argument("--cem-w-lookahead", type=float, default=2.0) parser.add_argument("--cem-w-via", type=float, default=2.0) parser.add_argument("--cem-route-horizon-distance", type=float, default=3.0) parser.add_argument("--cem-w-heading-goal", type=float, default=0.0) parser.add_argument("--cem-w-action", type=float, default=0.08) parser.add_argument("--cem-w-smooth", type=float, default=0.08) parser.add_argument("--cem-w-boundary", type=float, default=250.0) parser.add_argument("--cem-boundary-margin", type=float, default=0.75) parser.add_argument("--cem-w-progress", type=float, default=2.0) parser.add_argument("--success-radius", type=float, default=0.65) parser.add_argument("--make-gifs", type=int, default=3) parser.add_argument("--gif-stride", type=int, default=1) parser.add_argument("--gif-duration-ms", type=int, default=55) parser.add_argument("--seed", type=int, default=33) parser.add_argument("--device", default="cuda") parser.add_argument("--precision", choices=["fp32", "bf16", "fp16"], default="fp32") parser.add_argument("--out", default="experiments/reports/paper_planning") args = parser.parse_args() out_dir = Path(args.out) out_dir.mkdir(parents=True, exist_ok=True) payload = [evaluate_one_method(method, args) for method in args.methods] out_path = out_dir / f"{args.task}_{args.boat}_{args.flow_type}.json" out_path.write_text(json.dumps(payload, indent=2)) print(json.dumps(payload, indent=2)) if __name__ == "__main__": main()