| """Closed-loop planning evaluation for clean-image world models and controllers.""" | |
| from __future__ import annotations | |
| import argparse | |
| import importlib | |
| import json | |
| from collections import deque | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from driftwm.sim.env import SurfaceBoatEnv | |
| from driftwm.sim.flow import sample_flow | |
| from driftwm.sim.render import render_frame, save_gif | |
| from experiments.shared.src.methods import PAPER_LEARNED_METHODS, TRADITIONAL_METHODS | |
| from experiments.shared.src.vision.clean_renderer import render_clean_boat_array | |
| from experiments.train_image_world_models import autocast_context | |
# Methods that evaluate_one_method runs through the learned world-model planner
# (alias of the shared PAPER_LEARNED_METHODS list).
LEARNED_METHODS = PAPER_LEARNED_METHODS
# Scale and offset applied by decode_absolute to the first two prediction
# channels: xy * POSITION_SCALE + POSITION_SCALE.
POSITION_SCALE = 5.0
def build_method(method: str):
    """Build the default config and model of ``experiments.<method>.src``.

    Imports the method's ``config`` and ``model`` modules by name and returns
    ``(cfg, model)`` where ``cfg`` comes from ``config.default_config()`` and
    ``model`` from ``model.build_model(cfg)``.
    """
    package = f"experiments.{method}.src"
    config_module = importlib.import_module(f"{package}.config")
    model_module = importlib.import_module(f"{package}.model")
    cfg = config_module.default_config()
    model = model_module.build_model(cfg)
    return cfg, model
def decode_absolute(prediction: torch.Tensor) -> torch.Tensor:
    """Rescale the first two channels of ``prediction`` and keep channels 2:4.

    Channels 0:2 are mapped with ``x * POSITION_SCALE + POSITION_SCALE``;
    channels 2:4 are passed through unchanged. Any channels past index 3 are
    dropped, so the result has four channels on the last axis.
    """
    position = prediction[..., :2] * POSITION_SCALE + POSITION_SCALE
    passthrough = prediction[..., 2:4]
    return torch.cat((position, passthrough), dim=-1)
def clean_observation(env: SurfaceBoatEnv, image_size: int, visual_scale: float) -> np.ndarray:
    """Render the first six state entries of ``env`` and move the last image axis first.

    The rendered array is transposed with axes ``(2, 0, 1)``
    (presumably HWC -> CHW; confirm against ``render_clean_boat_array``).
    """
    boat_state = env.full_state()[:6]
    rendered = render_clean_boat_array(
        boat_state,
        env.spec,
        image_size=image_size,
        visual_scale=visual_scale,
    )
    return np.transpose(rendered, (2, 0, 1))
def pad_action(action: np.ndarray, action_dim: int) -> np.ndarray:
    """Return ``action`` as a float32 vector of exactly ``action_dim`` entries.

    Shorter actions are zero-padded on the right; longer ones are truncated.
    """
    padded = np.zeros((action_dim,), dtype=np.float32)
    values = np.asarray(action, dtype=np.float32)
    count = min(len(values), action_dim)
    padded[:count] = values[:count]
    return padded
def task_goals(task: str, rng: np.random.Generator) -> np.ndarray:
    """Return the ordered goal waypoints for ``task`` as a float32 ``(n, 2)`` array.

    Any task without a dedicated table falls back to the single goal
    ``(8, 8)``. ``rng`` is accepted for interface symmetry and is not used.
    """
    waypoint_tables = {
        "waypoint_square": [[2.5, 2.5], [7.5, 2.5], [7.5, 7.5], [2.5, 7.5]],
        "waypoint_zigzag": [[2.5, 7.0], [4.2, 3.0], [5.8, 7.0], [7.5, 3.0]],
        "station_keeping": [[5.0, 5.0]],
    }
    waypoints = waypoint_tables.get(task, [[8.0, 8.0]])
    return np.array(waypoints, dtype=np.float32)
def set_task_state(env: SurfaceBoatEnv, state: np.ndarray) -> None:
    """Overwrite the first six entries of ``env.state`` and refresh the cached flow.

    After writing the state, ``env.last_flow_velocity`` is recomputed from
    ``env.flow_at`` at the new first two state entries and stored as float32.
    """
    env.state[:6] = np.asarray(state, dtype=np.float32)
    flow_here = env.flow_at(env.state[:2])
    env.last_flow_velocity = flow_here.astype(np.float32)
def reset_task(env: SurfaceBoatEnv, task: str, flow_type: str, rng: np.random.Generator) -> None:
    """Reset ``env`` with a freshly sampled flow and place the boat for ``task``.

    Station keeping starts at (5, 5) with heading 0.3; every other task starts
    at (2, 2) with a heading drawn uniformly from [-pi, pi). The last three
    state entries are zeroed in both cases.
    """
    # Draw order matters for reproducibility: flow id, then the flow itself,
    # and only afterwards (for non-station tasks) the start heading.
    flow_id = 10_000 + int(rng.integers(1, 1_000_000))
    flow = sample_flow(flow_type, rng, flow_id=flow_id, workspace=env.workspace)
    env.reset(flow_type=flow_type, flow=flow, random_velocity=False)
    if task == "station_keeping":
        x, y, heading = 5.0, 5.0, 0.3
    else:
        x, y, heading = 2.0, 2.0, float(rng.uniform(-np.pi, np.pi))
    start_state = np.array([x, y, heading, 0.0, 0.0, 0.0], dtype=np.float32)
    set_task_state(env, start_state)
def rollout_latent(model, z: torch.Tensor, c: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
    """Roll the latent ``z`` forward through every action sequence in ``actions``.

    ``z`` (and ``c`` when non-empty) is repeated once per sequence, stepped
    with ``model.step`` along axis 1 of ``actions``, decoded after every step
    and returned as float32 after ``decode_absolute``.
    """
    batch, horizon = actions.shape[0], actions.shape[1]
    latent = z.repeat(batch, 1)
    # An empty context tensor is forwarded untouched.
    context = c.repeat(batch, 1) if c.numel() else c
    decoded = []
    for step in range(horizon):
        latent = model.step(latent, actions[:, step], context)
        decoded.append(model.decoder(latent))
    stacked = torch.stack(decoded, dim=1)
    return decode_absolute(stacked).float()
def warm_start_mean(
    previous_mean: np.ndarray | None,
    horizon: int,
    action_dim: int,
    active_action_dim: int,
    device: torch.device,
) -> torch.Tensor:
    """Build the initial CEM mean, shifted one step from the previous plan.

    With no previous plan the mean is all zeros. Otherwise the previous plan is
    advanced by one step, the tail is filled with its last action, only the
    first ``active_action_dim`` columns are written, and the result is clamped
    to [-1, 1].
    """
    mean = torch.zeros((horizon, action_dim), dtype=torch.float32, device=device)
    if previous_mean is None:
        return mean

    previous = torch.as_tensor(previous_mean, dtype=torch.float32, device=device)
    available = previous.shape[0]
    active = slice(None, active_action_dim)
    # Number of steps that can be copied after dropping the executed first step.
    shifted = min(horizon, max(0, available - 1))
    if shifted > 0:
        mean[:shifted, active] = previous[1 : 1 + shifted, active]
    if available > 0 and shifted < horizon:
        # Hold the last known action for the rest of the horizon.
        mean[shifted:, active] = previous[-1, active]
    return mean.clamp(-1.0, 1.0)
def sample_action_sequences(mean: torch.Tensor, std: torch.Tensor, population: int, knots: int) -> torch.Tensor:
    """Draw ``population`` Gaussian action sequences around ``mean``.

    When ``knots`` covers the whole horizon every step is sampled
    independently. Otherwise noise is sampled only at ``max(2, knots)`` knot
    steps spread over the horizon and linearly interpolated back to full
    length. The result has shape ``(population, horizon, action_dim)`` and is
    not clamped.
    """
    horizon, action_dim = mean.shape
    device = mean.device
    if knots >= horizon:
        noise = torch.randn(population, horizon, action_dim, device=device)
        return mean[None] + std[None] * noise

    knot_count = max(2, knots)
    knot_idx = torch.linspace(0, horizon - 1, knot_count, device=device).round().long()
    knot_noise = torch.randn(population, knot_count, action_dim, device=device)
    knot_samples = mean[knot_idx][None] + std[knot_idx][None] * knot_noise
    # F.interpolate works along the last axis, so time is moved there and back.
    channels_first = knot_samples.permute(0, 2, 1)
    dense = F.interpolate(channels_first, size=horizon, mode="linear", align_corners=True)
    return dense.permute(0, 2, 1)
def route_points_tensor(
    current_pos: torch.Tensor,
    goals: np.ndarray,
    goal_idx: int,
) -> torch.Tensor:
    """Stack the current position on top of the goals from ``goal_idx`` onward.

    Returns a float32 tensor on ``current_pos``'s device whose first row is the
    detached current position and whose remaining rows are the pending goals.
    """
    upcoming = torch.as_tensor(goals[goal_idx:], dtype=torch.float32, device=current_pos.device)
    origin = current_pos.detach().reshape(1, 2)
    return torch.cat((origin, upcoming), dim=0)
def route_projection(pos: torch.Tensor, route_points: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Project every position onto the polyline through ``route_points``.

    ``pos`` is indexed as ``(batch, step, 2)``. Returns the squared distance to
    the nearest segment, the arc length along the route at that nearest
    projection, and the total route length.
    """
    starts, ends = route_points[:-1], route_points[1:]
    seg = ends - starts
    seg_len = torch.linalg.norm(seg, dim=-1).clamp_min(1.0e-6)
    seg_len_sq = (seg_len * seg_len).clamp_min(1.0e-6)

    # Broadcast segments against positions: (batch, step, segment, xy).
    starts_b = starts[None, None]
    seg_b = seg[None, None]
    points = pos[:, :, None, :]
    offset = points - starts_b
    frac = ((offset * seg_b).sum(dim=-1) / seg_len_sq[None, None]).clamp(0.0, 1.0)
    foot = starts_b + frac[..., None] * seg_b
    dist_sq = ((points - foot) ** 2).sum(dim=-1)
    nearest_sq, nearest_seg = dist_sq.min(dim=-1)

    # Arc length at the start of each segment, then at each projection.
    arc_start = torch.cat([torch.zeros(1, device=pos.device), seg_len.cumsum(dim=0)[:-1]], dim=0)
    arc = arc_start[None, None] + frac * seg_len[None, None]
    route_s = torch.gather(arc, -1, nearest_seg[..., None]).squeeze(-1)
    return nearest_sq, route_s, seg_len.sum()
def route_points_at_s(route_points: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
    """Return the points on the polyline at arc lengths ``s``.

    Arc lengths are clamped to ``[0, total route length]``. The output has the
    shape of ``s`` with a trailing axis of size 2.
    """
    starts = route_points[:-1]
    seg = route_points[1:] - starts
    seg_len = torch.linalg.norm(seg, dim=-1).clamp_min(1.0e-6)
    cum_end = seg_len.cumsum(dim=0)
    cum_start = cum_end - seg_len
    total = float(cum_end[-1].detach().cpu())

    arc = s.reshape(-1).clamp(0.0, total)
    # First segment whose end is at or beyond each arc length.
    seg_idx = torch.searchsorted(cum_end, arc, right=False).clamp(max=seg_len.numel() - 1)
    frac = ((arc - cum_start[seg_idx]) / seg_len[seg_idx]).clamp(0.0, 1.0)
    points = starts[seg_idx] + frac.unsqueeze(-1) * seg[seg_idx]
    return points.reshape(*s.shape, 2)
def learned_plan(
    model,
    image_history: deque,
    action_history: deque,
    goals: np.ndarray,
    goal_idx: int,
    active_action_dim: int,
    args,
    prev_action: np.ndarray,
    previous_mean: np.ndarray | None,
    context_mode: str,
    donor_context: torch.Tensor | None,
) -> tuple[np.ndarray, np.ndarray | None, np.ndarray]:
    """Encode the observation history and plan the next action with CEM.

    The history is encoded into a latent ``z`` and context ``c``. A non-empty
    context is replaced by zeros for ``context_mode == "zero"`` or by
    ``donor_context`` for ``"shuffled"`` (when a donor is given). Returns the
    tuple produced by ``cem_plan``: the action to execute, optional candidate
    trajectories for visualisation, and the refined mean plan.
    """
    device = next(model.parameters()).device
    image_batch = torch.as_tensor(np.asarray(image_history, dtype=np.uint8), device=device).unsqueeze(0)
    action_batch = torch.as_tensor(np.asarray(action_history, dtype=np.float32), device=device).unsqueeze(0)
    with torch.no_grad(), autocast_context(device, args.precision):
        z, c = model.encode(image_batch, action_batch)

    # Context ablations only apply when the model actually produces a context.
    if c.numel():
        if context_mode == "zero":
            c = torch.zeros_like(c)
        elif context_mode == "shuffled" and donor_context is not None:
            c = donor_context.to(device=device, dtype=torch.float32)
    z, c = z.detach(), c.detach()

    goal_t = torch.as_tensor(goals[goal_idx], dtype=torch.float32, device=device).view(1, 2)
    with torch.no_grad(), autocast_context(device, args.precision):
        # The planner works from the model's own decoded position estimate.
        current_pos = decode_absolute(model.decoder(z)).float().detach()[..., :2]
    route_points = route_points_tensor(current_pos[0], goals, goal_idx)

    full_action_dim = model.config.action_dim
    mean = warm_start_mean(previous_mean, args.cem_horizon, full_action_dim, active_action_dim, device)
    std = torch.full_like(mean, args.cem_action_std)
    prev = torch.zeros((full_action_dim,), dtype=torch.float32, device=device)
    prev[:active_action_dim] = torch.as_tensor(prev_action, dtype=torch.float32, device=device)

    with torch.no_grad():
        action, best_candidates, refined_mean = cem_plan(
            model,
            z,
            c,
            mean,
            std,
            goal_t,
            route_points,
            current_pos,
            prev,
            active_action_dim,
            args,
        )
    return action, best_candidates, refined_mean
def planning_cost(
    pred: torch.Tensor,
    samples: torch.Tensor,
    goal_t: torch.Tensor,
    route_points: torch.Tensor,
    current_pos: torch.Tensor,
    prev: torch.Tensor,
    active_action_dim: int,
    args,
) -> torch.Tensor:
    """Score every candidate rollout; lower is better.

    ``pred`` and ``samples`` are indexed as ``(candidate, step, channel)``. The
    cost is a weighted sum (weights from ``args.cem_w_*``) of goal distance
    (terminal, mean and closest approach), route tracking and look-ahead
    errors, heading misalignment, action energy and smoothness, and a boundary
    penalty, minus a progress reward. One value is returned per candidate.
    """
    pos = pred[..., :2]
    horizon = pos.shape[1]
    alpha = torch.linspace(1.0 / horizon, 1.0, horizon, device=pos.device, dtype=pos.dtype)

    # Straight line from the current position to the active goal.
    goal_delta = goal_t - current_pos
    goal_dir = goal_delta / torch.linalg.norm(goal_delta, dim=-1, keepdim=True).clamp_min(1.0e-6)
    direct_progress = ((pos - current_pos[:, None]) * goal_dir[:, None]).sum(dim=-1).amax(dim=-1)
    direct_route = current_pos[:, None] + alpha.view(1, -1, 1) * goal_delta[:, None]
    direct_route_error = ((pos - direct_route) ** 2).sum(dim=-1).mean(dim=-1)

    # Polyline through the remaining waypoints.
    route_dist_sq, route_s, route_len = route_projection(pos, route_points)
    route_error = route_dist_sq.mean(dim=-1)
    horizon_distance = torch.as_tensor(args.cem_route_horizon_distance, dtype=pos.dtype, device=pos.device)
    scheduled_s = alpha * torch.minimum(route_len, horizon_distance)
    scheduled = route_points_at_s(route_points, scheduled_s).view(1, horizon, 2)
    lookahead_error = ((pos - scheduled) ** 2).sum(dim=-1).mean(dim=-1)
    route_progress = (route_s.amax(dim=-1) / route_len.clamp_min(1.0e-6)).clamp(0.0, 1.0)

    # Heading channels (2:4) compared against the direction to the goal.
    to_goal = goal_t[:, None] - pos
    to_goal = to_goal / torch.linalg.norm(to_goal, dim=-1, keepdim=True).clamp_min(1.0e-6)
    heading = pred[..., 2:4]
    heading = heading / torch.linalg.norm(heading, dim=-1, keepdim=True).clamp_min(1.0e-6)
    heading_error = (1.0 - (heading * to_goal).sum(dim=-1)).mean(dim=-1)

    # Squared goal distance per step, shared by the three goal terms.
    goal_dist_sq = ((pos - goal_t[:, None]) ** 2).sum(dim=-1)
    terminal = goal_dist_sq[:, -1]
    path = goal_dist_sq.mean(dim=-1)
    via = goal_dist_sq.amin(dim=-1)

    # Control effort and step-to-step change, seeded with the last applied action.
    energy = (samples[..., :active_action_dim] ** 2).mean(dim=(1, 2))
    first_prev = prev.view(1, 1, -1).expand(samples.shape[0], 1, -1)
    shifted = torch.cat([first_prev, samples[:, :-1]], dim=1)
    smooth = ((samples - shifted) ** 2).mean(dim=(1, 2))

    # Penalise leaving the [margin, 10 - margin] box on either axis.
    margin = args.cem_boundary_margin
    upper = 10.0 - margin
    x, y = pos[..., 0], pos[..., 1]
    boundary = (
        torch.relu(margin - x)
        + torch.relu(x - upper)
        + torch.relu(margin - y)
        + torch.relu(y - upper)
    ).mean(dim=-1)

    return (
        args.cem_w_goal * terminal
        + args.cem_w_path * path
        + args.cem_w_route * (route_error + 0.25 * direct_route_error)
        + args.cem_w_lookahead * lookahead_error
        + args.cem_w_via * via
        + args.cem_w_heading_goal * heading_error
        + args.cem_w_action * energy
        + args.cem_w_smooth * smooth
        + args.cem_w_boundary * boundary
        - args.cem_w_progress * (route_progress + 0.1 * direct_progress)
    )
def cem_plan(
    model,
    z: torch.Tensor,
    c: torch.Tensor,
    mean: torch.Tensor,
    std: torch.Tensor,
    goal_t: torch.Tensor,
    route_points: torch.Tensor,
    current_pos: torch.Tensor,
    prev: torch.Tensor,
    active_action_dim: int,
    args,
) -> tuple[np.ndarray, np.ndarray | None, np.ndarray]:
    """Refine an action plan with the cross-entropy method.

    Each iteration samples ``args.cem_population`` sequences around ``mean``,
    keeps the current mean as candidate 0, clamps to [-1, 1], zeroes inactive
    action columns, rolls the candidates out in latent space, scores them with
    ``planning_cost`` and refits ``mean``/``std`` to the lowest-cost elites.

    Returns ``(action, best_candidates, mean)``: the clipped first action of
    the final mean (active columns only), the predicted xy paths of up to 12
    elites from the last iteration when ``args.make_gifs`` is truthy (else
    ``None``), and the final mean plan as a NumPy array.
    """
    full_action_dim = model.config.action_dim
    best_candidates = None
    pred = None
    elite_idx = None
    for _ in range(args.cem_iterations):
        samples = sample_action_sequences(mean, std, args.cem_population, args.cem_knots)
        # Keep the incumbent plan in the population so it can be selected again.
        samples[0] = mean
        samples = samples.clamp(-1.0, 1.0)
        if active_action_dim < full_action_dim:
            samples[:, :, active_action_dim:] = 0.0
        with autocast_context(mean.device, args.precision):
            pred = rollout_latent(model, z, c, samples)
        cost = planning_cost(pred, samples, goal_t, route_points, current_pos, prev, active_action_dim, args)

        # topk raises if k exceeds the number of candidates, so cap it; keep at
        # least one elite so the mean stays defined.
        elite_count = max(1, min(int(args.cem_elites), cost.shape[0]))
        elite_idx = torch.topk(cost, k=elite_count, largest=False).indices
        elites = samples[elite_idx]
        mean = elites.mean(dim=0)
        if elite_count > 1:
            std = elites.std(dim=0).clamp_min(0.05)
        else:
            # The sample std of a single elite is NaN and clamp_min does not
            # repair NaN; fall back to the minimum exploration noise.
            std = torch.full_like(mean, 0.05)

    # Only the last iteration's elites are reported, so extract them once.
    if args.make_gifs and pred is not None:
        best_candidates = pred[..., :2][elite_idx[:12]].detach().cpu().numpy()

    action = mean[0, :active_action_dim].detach().cpu().numpy()
    return (
        np.clip(action, -1.0, 1.0).astype(np.float32),
        best_candidates,
        mean.detach().cpu().numpy(),
    )
def donor_context_for_flowmo(model, env: SurfaceBoatEnv, args, seed: int) -> torch.Tensor | None:
    """Encode a context from an unrelated donor episode for the "shuffled" ablation.

    Returns ``None`` when ``model`` has no ``to_c`` attribute. Otherwise a
    separate environment (same boat and flow type, different seed) is driven
    with uniform random actions in [-0.5, 0.5) for ``args.history_len`` steps
    and the second output of ``model.encode`` on that history is returned,
    detached.
    """
    if not hasattr(model, "to_c"):
        return None

    rng = np.random.default_rng(seed + 99_999)
    flow_type = env.config.flow_type
    donor = SurfaceBoatEnv(
        boat=env.config.boat,
        flow_type=flow_type,
        boundary="terminate",
        episode_steps=model.config.context_len + 8,
        seed=seed + 99,
    )
    donor.reset(flow_type=flow_type, random_velocity=False)

    history_len = args.history_len
    full_action_dim = model.config.action_dim
    image_history = deque(maxlen=history_len)
    action_history = deque(maxlen=history_len)
    last_action = np.zeros((full_action_dim,), dtype=np.float32)
    for _ in range(history_len):
        # Each observation is paired with the action that led to it.
        image_history.append(clean_observation(donor, args.image_size, args.visual_scale))
        action_history.append(last_action.copy())
        raw = rng.uniform(-0.5, 0.5, size=donor.action_dim).astype(np.float32)
        donor.step(raw)
        last_action = pad_action(raw, full_action_dim)

    device = next(model.parameters()).device
    image_batch = torch.as_tensor(np.asarray(image_history, dtype=np.uint8), device=device).unsqueeze(0)
    action_batch = torch.as_tensor(np.asarray(action_history, dtype=np.float32), device=device).unsqueeze(0)
    with autocast_context(device, args.precision):
        _latent, context = model.encode(image_batch, action_batch)[:2]
        return context.detach()
def traditional_action(method: str, image_history: deque, env: SurfaceBoatEnv, goal: np.ndarray) -> np.ndarray:
    """Ask the ``experiments.<method>.src.evaluate`` controller for the next action.

    The controller's ``evaluate`` receives a dict with the latest image, the
    full image history (each transposed with axes ``(1, 2, 0)``), a copy of
    ``env.last_flow_velocity``, the goal as a list of floats, the action
    dimension and the boat type. Its output is truncated to
    ``env.action_dim`` entries and returned as float32.
    """
    controller = importlib.import_module(f"experiments.{method}.src.evaluate")
    latest = np.transpose(image_history[-1], (1, 2, 0))
    frames = [np.transpose(frame, (1, 2, 0)) for frame in image_history]
    request = {
        "image": latest,
        "history": frames,
        "true_flow": env.last_flow_velocity.copy(),
        "goal": goal.astype(float).tolist(),
        "action_dim": env.action_dim,
        "boat": env.config.boat,
    }
    raw_action = controller.evaluate(request)
    return raw_action[: env.action_dim].astype(np.float32)
def evaluate_one_method(method: str, args) -> dict:
    """Run closed-loop episodes for one method and return its summary.

    Learned methods (those in LEARNED_METHODS) load a checkpoint and plan with
    ``learned_plan``; all other methods call ``traditional_action``. Only the
    method named "flowmo" is evaluated under every entry of
    ``args.context_modes``; the rest use the single mode "inferred". One record
    per (context mode, episode) is collected and handed to ``summarize``.
    """
    torch.manual_seed(args.seed)
    learned = method in LEARNED_METHODS
    model = None
    if learned:
        _cfg, model = build_method(method)
        # Checkpoint path: experiments/<method>/checkpoint/<checkpoint_name>, loaded on CPU first.
        # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
        state = torch.load(Path("experiments") / method / "checkpoint" / args.checkpoint_name, map_location="cpu")
        model.load_state_dict(state)
        model.to(torch.device(args.device))
        if torch.device(args.device).type == "cuda":
            model.to(memory_format=torch.channels_last)
        model.eval()
        # Planning never needs gradients, so freeze every parameter.
        for param in model.parameters():
            param.requires_grad_(False)
    results = []
    gif_dir = Path(args.out) / "gifs"
    gif_dir.mkdir(parents=True, exist_ok=True)
    context_modes = args.context_modes if method == "flowmo" else ["inferred"]
    for context_mode in context_modes:
        for ep in range(args.episodes):
            # The episode seed depends only on ep, so every context mode sees the same episodes.
            episode_seed = int(args.seed + ep)
            rng = np.random.default_rng(episode_seed)
            env = SurfaceBoatEnv(
                boat=args.boat,
                flow_type=args.flow_type,
                boundary="terminate",
                episode_steps=args.max_steps,
                seed=episode_seed,
            )
            reset_task(env, args.task, args.flow_type, rng)
            goals = task_goals(args.task, rng)
            goal_idx = 0
            image_history = deque(maxlen=args.history_len)
            action_history = deque(maxlen=args.history_len)
            # NOTE(review): non-learned methods use a hard-coded action width of 3 - confirm this covers every boat.
            zero = np.zeros((model.config.action_dim if learned else 3,), dtype=np.float32)
            first = clean_observation(env, args.image_size, args.visual_scale)
            # Pre-fill the history with the first frame and zero actions so it is full from step 0.
            for _ in range(args.history_len):
                image_history.append(first.copy())
                action_history.append(zero.copy())
            donor_context = donor_context_for_flowmo(model, env, args, episode_seed) if learned and context_mode == "shuffled" else None
            trajectory = [env.full_state()[:6].copy()]
            frames = []
            prev_action = np.zeros((env.action_dim,), dtype=np.float32)
            energy = 0.0
            reached_times: list[int] = []
            # Closest approach to each goal over the whole episode.
            min_goal_dists = np.full((len(goals),), np.inf, dtype=np.float32)
            planned = None
            learned_plan_mean = None
            for t in range(args.max_steps):
                goal = goals[goal_idx]
                if learned:
                    # The previous mean plan is passed back in to warm-start the next CEM call.
                    action, planned, learned_plan_mean = learned_plan(
                        model,
                        image_history,
                        action_history,
                        goals,
                        goal_idx,
                        env.action_dim,
                        args,
                        prev_action,
                        learned_plan_mean,
                        context_mode,
                        donor_context,
                    )
                else:
                    action = traditional_action(method, image_history, env, goal)
                    planned = None
                prev_action = action.copy()
                _obs, _reward, done, _info = env.step(action)
                # Energy is the running sum of squared action components.
                energy += float(np.sum(action * action))
                trajectory.append(env.full_state()[:6].copy())
                image_history.append(clean_observation(env, args.image_size, args.visual_scale))
                # Pad to the width already used by the stored history entries.
                action_history.append(pad_action(action, len(action_history[-1])))
                dists = np.linalg.norm(goals - env.state[:2], axis=1)
                min_goal_dists = np.minimum(min_goal_dists, dists)
                # Frames are recorded only for the first args.make_gifs episodes.
                if ep < args.make_gifs and t % args.gif_stride == 0:
                    frames.append(
                        render_frame(
                            env.full_state()[:6],
                            env.spec,
                            env.flow,
                            env.workspace,
                            trajectory=np.asarray(trajectory),
                            goal=goal,
                            planned=planned,
                            t=env.time,
                        )
                    )
                if float(dists[goal_idx]) < args.success_radius:
                    reached_times.append(t + 1)
                    if args.task == "station_keeping":
                        # Station keeping ends only once the boat is inside the radius late enough in the episode.
                        if t >= max(40, args.max_steps // 3):
                            break
                    else:
                        # Advance to the next waypoint and drop the stale plan.
                        goal_idx += 1
                        learned_plan_mean = None
                        if goal_idx >= len(goals):
                            break
                if done:
                    break
            path = np.asarray(trajectory)[:, :2]
            # goal_idx can equal len(goals) after the last waypoint, so clamp it for reporting.
            final_goal = goals[min(goal_idx, len(goals) - 1)]
            record = {
                "method": method,
                "context_mode": context_mode,
                "episode": ep,
                # Success: all waypoints reached, or (station keeping) the boat ends inside the radius.
                "success": bool(goal_idx >= len(goals) or (args.task == "station_keeping" and np.linalg.norm(env.state[:2] - goals[0]) < args.success_radius)),
                "final_distance": float(np.linalg.norm(env.state[:2] - final_goal)),
                "mean_min_goal_distance": float(min_goal_dists.mean()),
                "path_length": float(np.linalg.norm(np.diff(path, axis=0), axis=-1).sum()) if len(path) > 1 else 0.0,
                "energy": energy,
                "steps": len(trajectory) - 1,
                "reached_times": reached_times,
            }
            results.append(record)
            if ep < args.make_gifs and frames:
                name = f"image_planning_{method}_{context_mode}_{args.boat}_{args.task}_{args.flow_type}_ep{ep:03d}.gif"
                save_gif(frames, gif_dir / name, duration_ms=args.gif_duration_ms)
    return summarize(method, args, results)
def summarize(method: str, args, results: list[dict]) -> dict:
    """Aggregate per-episode records into per-context statistics.

    Records are grouped by ``context_mode`` (groups in sorted order). Success
    rate and distance metrics average over all episodes of a group; path
    length, energy and step count average over successful episodes only and
    are ``None`` when the group has no success. The raw records are returned
    alongside under ``"results"``.
    """

    def mean_of(records: list[dict], key: str) -> float:
        return float(np.mean([record[key] for record in records]))

    def mean_over_successes(records: list[dict], key: str) -> float | None:
        values = [record[key] for record in records if record["success"]]
        if not values:
            return None
        return float(np.mean(values))

    by_context = {}
    for context in sorted({record["context_mode"] for record in results}):
        group = [record for record in results if record["context_mode"] == context]
        by_context[context] = {
            "episodes": len(group),
            "successes": sum(1 for record in group if record["success"]),
            "success_rate": mean_of(group, "success"),
            "final_distance_mean": mean_of(group, "final_distance"),
            "mean_min_goal_distance": mean_of(group, "mean_min_goal_distance"),
            "path_length_success_mean": mean_over_successes(group, "path_length"),
            "energy_success_mean": mean_over_successes(group, "energy"),
            "steps_success_mean": mean_over_successes(group, "steps"),
        }

    return {
        "method": method,
        "task": args.task,
        "boat": args.boat,
        "flow_type": args.flow_type,
        "by_context": by_context,
        "results": results,
    }
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser; options are registered in table order."""
    tasks = ["reach_target", "station_keeping", "waypoint_square", "waypoint_zigzag"]
    flow_types = [
        "noflow",
        "uniform",
        "vortex_center",
        "double_gyre",
        "source_sink",
        "source_sink_pair",
        "gradient",
        "shear",
        "turbulent_patch",
        "random_fourier",
    ]
    option_specs = [
        # What to evaluate.
        ("--methods", {"nargs": "+", "default": LEARNED_METHODS + TRADITIONAL_METHODS}),
        ("--task", {"choices": tasks, "default": "reach_target"}),
        ("--boat", {"choices": ["twin", "triangle"], "default": "twin"}),
        ("--flow-type", {"choices": flow_types, "default": "uniform"}),
        ("--episodes", {"type": int, "default": 50}),
        ("--max-steps", {"type": int, "default": 420}),
        # Observation and model inputs.
        ("--history-len", {"type": int, "default": 32}),
        ("--image-size", {"type": int, "default": 160}),
        ("--visual-scale", {"type": float, "default": 2.5}),
        ("--checkpoint-name", {"default": "paper.pt"}),
        ("--context-modes", {"nargs": "+", "default": ["inferred", "zero", "shuffled"]}),
        # CEM search settings and cost weights.
        ("--cem-horizon", {"type": int, "default": 45}),
        ("--cem-population", {"type": int, "default": 512}),
        ("--cem-elites", {"type": int, "default": 64}),
        ("--cem-iterations", {"type": int, "default": 4}),
        ("--cem-action-std", {"type": float, "default": 0.5}),
        ("--cem-knots", {"type": int, "default": 10}),
        ("--cem-w-goal", {"type": float, "default": 6.0}),
        ("--cem-w-path", {"type": float, "default": 0.2}),
        ("--cem-w-route", {"type": float, "default": 6.0}),
        ("--cem-w-lookahead", {"type": float, "default": 2.0}),
        ("--cem-w-via", {"type": float, "default": 2.0}),
        ("--cem-route-horizon-distance", {"type": float, "default": 3.0}),
        ("--cem-w-heading-goal", {"type": float, "default": 0.0}),
        ("--cem-w-action", {"type": float, "default": 0.08}),
        ("--cem-w-smooth", {"type": float, "default": 0.08}),
        ("--cem-w-boundary", {"type": float, "default": 250.0}),
        ("--cem-boundary-margin", {"type": float, "default": 0.75}),
        ("--cem-w-progress", {"type": float, "default": 2.0}),
        # Episode bookkeeping and outputs.
        ("--success-radius", {"type": float, "default": 0.65}),
        ("--make-gifs", {"type": int, "default": 3}),
        ("--gif-stride", {"type": int, "default": 1}),
        ("--gif-duration-ms", {"type": int, "default": 55}),
        ("--seed", {"type": int, "default": 33}),
        ("--device", {"default": "cuda"}),
        ("--precision", {"choices": ["fp32", "bf16", "fp16"], "default": "fp32"}),
        ("--out", {"default": "experiments/reports/paper_planning"}),
    ]
    parser = argparse.ArgumentParser()
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser


def main() -> None:
    """Evaluate every requested method and write/print the combined JSON report.

    The report is saved as ``<out>/<task>_<boat>_<flow_type>.json`` and the
    same JSON text is printed to stdout.
    """
    args = _build_arg_parser().parse_args()
    out_dir = Path(args.out)
    out_dir.mkdir(parents=True, exist_ok=True)
    payload = [evaluate_one_method(method, args) for method in args.methods]
    report = json.dumps(payload, indent=2)
    out_path = out_dir / f"{args.task}_{args.boat}_{args.flow_type}.json"
    out_path.write_text(report)
    print(report)


if __name__ == "__main__":
    main()