FlowMo-WM / experiments /evaluate_image_planning.py
cccat6's picture
Include flow type in planning GIF names
8093443 verified
"""Closed-loop planning evaluation for clean-image world models and controllers."""
from __future__ import annotations
import argparse
import importlib
import json
from collections import deque
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from driftwm.sim.env import SurfaceBoatEnv
from driftwm.sim.flow import sample_flow
from driftwm.sim.render import render_frame, save_gif
from experiments.shared.src.methods import PAPER_LEARNED_METHODS, TRADITIONAL_METHODS
from experiments.shared.src.vision.clean_renderer import render_clean_boat_array
from experiments.train_image_world_models import autocast_context
# Learned world-model methods; alias of the shared paper list. Methods in this
# collection are planned with CEM, everything else goes through traditional_action.
LEARNED_METHODS = PAPER_LEARNED_METHODS
# Scale and offset applied by decode_absolute to turn the decoded xy channels
# into absolute positions (xy * POSITION_SCALE + POSITION_SCALE).
POSITION_SCALE = 5.0
def build_method(method: str):
    """Build the default config and model for a learned method.

    The config comes from ``experiments.<method>.src.config.default_config()``
    and the model from ``experiments.<method>.src.model.build_model(config)``.

    Returns:
        The ``(config, model)`` pair.
    """
    package = f"experiments.{method}.src"
    config_module = importlib.import_module(f"{package}.config")
    model_module = importlib.import_module(f"{package}.model")
    config = config_module.default_config()
    model = model_module.build_model(config)
    return config, model
def decode_absolute(prediction: torch.Tensor) -> torch.Tensor:
    """Convert a decoded prediction into absolute coordinates.

    The first two channels are rescaled as ``xy * POSITION_SCALE + POSITION_SCALE``;
    channels 2 and 3 are passed through unchanged. Any further channels are dropped.
    """
    position = prediction[..., :2] * POSITION_SCALE + POSITION_SCALE
    passthrough = prediction[..., 2:4]
    return torch.cat((position, passthrough), dim=-1)
def clean_observation(env: SurfaceBoatEnv, image_size: int, visual_scale: float) -> np.ndarray:
    """Render the current boat state as a clean image and move the last axis to the front.

    The first six entries of ``env.full_state()`` are rendered with
    ``render_clean_boat_array``; the result is transposed with axes ``(2, 0, 1)``
    (presumably HWC -> CHW; confirm against the renderer).
    """
    boat_state = env.full_state()[:6]
    rendered = render_clean_boat_array(
        boat_state,
        env.spec,
        image_size=image_size,
        visual_scale=visual_scale,
    )
    return np.transpose(rendered, (2, 0, 1))
def pad_action(action: np.ndarray, action_dim: int) -> np.ndarray:
    """Return ``action`` as a float32 vector of length ``action_dim``.

    Shorter inputs are zero-padded on the right; longer inputs are truncated.
    """
    padded = np.zeros((action_dim,), dtype=np.float32)
    source = np.asarray(action, dtype=np.float32)
    count = min(len(source), action_dim)
    padded[:count] = source[:count]
    return padded
def task_goals(task: str, rng: np.random.Generator) -> np.ndarray:
    """Return the ordered goal positions for ``task`` as a float32 ``(n, 2)`` array.

    Any task name without a dedicated entry falls back to the single goal at
    ``(8.0, 8.0)``. ``rng`` is accepted for interface compatibility and is not used.
    """
    waypoints_by_task = {
        "waypoint_square": [[2.5, 2.5], [7.5, 2.5], [7.5, 7.5], [2.5, 7.5]],
        "waypoint_zigzag": [[2.5, 7.0], [4.2, 3.0], [5.8, 7.0], [7.5, 3.0]],
        "station_keeping": [[5.0, 5.0]],
    }
    waypoints = waypoints_by_task.get(task, [[8.0, 8.0]])
    return np.array(waypoints, dtype=np.float32)
def set_task_state(env: SurfaceBoatEnv, state: np.ndarray) -> None:
    """Overwrite the first six state entries of ``env`` and refresh its cached flow.

    After the state is written in place, ``env.last_flow_velocity`` is recomputed
    from ``env.flow_at`` evaluated at the new first two state entries.
    """
    env.state[:6] = np.asarray(state, dtype=np.float32)
    position = env.state[:2]
    env.last_flow_velocity = env.flow_at(position).astype(np.float32)
def reset_task(env: SurfaceBoatEnv, task: str, flow_type: str, rng: np.random.Generator) -> None:
    """Reset ``env`` with a freshly sampled flow and place the boat at the task start.

    ``station_keeping`` starts at ``(5, 5)`` with third state entry ``0.3``; every
    other task starts at ``(2, 2)`` with the third state entry drawn uniformly
    from ``[-pi, pi)``. The remaining three state entries start at zero.
    """
    # Draw order on ``rng`` matters for reproducibility: flow id, then whatever
    # sample_flow consumes, then (for non-station tasks) the third state entry.
    flow_id = 10_000 + int(rng.integers(1, 1_000_000))
    flow = sample_flow(flow_type, rng, flow_id=flow_id, workspace=env.workspace)
    env.reset(flow_type=flow_type, flow=flow, random_velocity=False)

    if task == "station_keeping":
        start = [5.0, 5.0, 0.3, 0.0, 0.0, 0.0]
    else:
        start = [2.0, 2.0, float(rng.uniform(-np.pi, np.pi)), 0.0, 0.0, 0.0]
    set_task_state(env, np.array(start, dtype=np.float32))
def rollout_latent(model, z: torch.Tensor, c: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
    """Roll the latent dynamics forward for a batch of action sequences.

    ``z`` (and ``c`` when it is non-empty) is tiled along dim 0 once per action
    sequence. At each step along dim 1 of ``actions`` the latent is advanced with
    ``model.step`` and decoded with ``model.decoder``.

    Returns:
        The per-step decodes stacked along dim 1, passed through
        ``decode_absolute`` and cast to float32.
    """
    population, horizon = actions.shape[0], actions.shape[1]
    latent = z.repeat(population, 1)
    context = c.repeat(population, 1) if c.numel() else c

    decoded_steps = []
    for step in range(horizon):
        latent = model.step(latent, actions[:, step], context)
        decoded_steps.append(model.decoder(latent))

    stacked = torch.stack(decoded_steps, dim=1)
    return decode_absolute(stacked).float()
def warm_start_mean(
    previous_mean: np.ndarray | None,
    horizon: int,
    action_dim: int,
    active_action_dim: int,
    device: torch.device,
) -> torch.Tensor:
    """Build the initial CEM mean by shifting the previous plan one step forward.

    Without a previous plan the result is all zeros. Otherwise the previous plan
    (minus its first, already-executed step) fills the start of the horizon, any
    remaining steps repeat the previous plan's last row, and the result is clamped
    to ``[-1, 1]``. Only the first ``active_action_dim`` columns are copied.

    Returns:
        A float32 ``(horizon, action_dim)`` tensor on ``device``.
    """
    plan = torch.zeros((horizon, action_dim), dtype=torch.float32, device=device)
    if previous_mean is None:
        return plan

    last_plan = torch.as_tensor(previous_mean, dtype=torch.float32, device=device)
    available = last_plan.shape[0]
    shifted = min(horizon, max(0, available - 1))

    if shifted > 0:
        plan[:shifted, :active_action_dim] = last_plan[1 : shifted + 1, :active_action_dim]
    if available > 0 and shifted < horizon:
        # Hold the final planned action for the steps the old plan did not cover.
        plan[shifted:, :active_action_dim] = last_plan[-1, :active_action_dim]
    return plan.clamp(-1.0, 1.0)
def sample_action_sequences(mean: torch.Tensor, std: torch.Tensor, population: int, knots: int) -> torch.Tensor:
    """Draw ``population`` Gaussian action sequences around ``mean``.

    When ``knots`` covers the whole horizon, every step gets independent noise.
    Otherwise noise is drawn only at ``max(2, knots)`` knot steps spread evenly
    over the horizon and the sequences are linearly interpolated between them,
    which yields smoother candidates.

    Returns:
        A ``(population, horizon, action_dim)`` tensor (not clamped).
    """
    horizon, action_dim = mean.shape
    device = mean.device

    if knots >= horizon:
        eps = torch.randn(population, horizon, action_dim, device=device)
        return mean.unsqueeze(0) + std.unsqueeze(0) * eps

    num_knots = max(2, knots)
    knot_steps = torch.linspace(0, horizon - 1, num_knots, device=device).round().long()
    eps = torch.randn(population, num_knots, action_dim, device=device)
    coarse = mean[knot_steps].unsqueeze(0) + std[knot_steps].unsqueeze(0) * eps

    # F.interpolate works on (batch, channels, length), so time goes last.
    dense = F.interpolate(
        coarse.permute(0, 2, 1),
        size=horizon,
        mode="linear",
        align_corners=True,
    )
    return dense.permute(0, 2, 1)
def route_points_tensor(
    current_pos: torch.Tensor,
    goals: np.ndarray,
    goal_idx: int,
) -> torch.Tensor:
    """Return the polyline from the current position through the remaining goals.

    The first row is ``current_pos`` (detached, reshaped to ``(1, 2)``); the rest
    are ``goals[goal_idx:]`` as float32 on the same device.
    """
    upcoming = torch.as_tensor(goals[goal_idx:], dtype=torch.float32, device=current_pos.device)
    start = current_pos.reshape(1, 2).detach()
    return torch.cat((start, upcoming), dim=0)
def route_projection(pos: torch.Tensor, route_points: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Project positions onto a polyline route.

    Args:
        pos: Positions indexed as ``(batch, time, 2)``.
        route_points: ``(n_points, 2)`` polyline vertices.

    Returns:
        ``(min_dist_sq, route_s, route_len)``: the squared distance from each
        position to its nearest segment, the arc length along the route of that
        nearest point, and the total route length. Segment lengths are floored
        at ``1e-6`` to avoid dividing by zero on degenerate segments.
    """
    seg_start = route_points[:-1]
    seg_vec = route_points[1:] - seg_start
    seg_len = torch.linalg.norm(seg_vec, dim=-1).clamp_min(1.0e-6)
    seg_len_sq = (seg_len * seg_len).clamp_min(1.0e-6)

    # Broadcast shapes: positions (B, T, 1, 2) against segments (1, 1, S, 2).
    start_b = seg_start.view(1, 1, -1, 2)
    vec_b = seg_vec.view(1, 1, -1, 2)
    query = pos[:, :, None, :]

    # Fractional position of the closest point on each segment, kept inside it.
    frac = ((query - start_b) * vec_b).sum(dim=-1) / seg_len_sq.view(1, 1, -1)
    frac = frac.clamp(0.0, 1.0)
    closest = start_b + frac[..., None] * vec_b
    dist_sq = ((query - closest) ** 2).sum(dim=-1)
    min_dist_sq, nearest = dist_sq.min(dim=-1)

    # Arc length at the start of every segment, then at the projected point.
    seg_offset = torch.cat([torch.zeros(1, device=pos.device), seg_len.cumsum(dim=0)[:-1]], dim=0)
    arc = seg_offset.view(1, 1, -1) + frac * seg_len.view(1, 1, -1)
    route_s = arc.gather(dim=-1, index=nearest[..., None]).squeeze(-1)
    return min_dist_sq, route_s, seg_len.sum()
def route_points_at_s(route_points: torch.Tensor, s: torch.Tensor) -> torch.Tensor:
    """Return the points on a polyline route at arc lengths ``s``.

    ``s`` is clamped to ``[0, route length]``. The output has shape
    ``(*s.shape, 2)``. Segment lengths are floored at ``1e-6``.
    """
    seg_start = route_points[:-1]
    seg_vec = route_points[1:] - seg_start
    seg_len = torch.linalg.norm(seg_vec, dim=-1).clamp_min(1.0e-6)
    arc_end = seg_len.cumsum(dim=0)
    arc_start = arc_end - seg_len

    total_length = float(arc_end[-1].detach().cpu())
    query = s.reshape(-1).clamp(0.0, total_length)

    # First segment whose end arc length is >= the query; clamp guards the last one.
    segment = torch.searchsorted(arc_end, query, right=False).clamp(max=seg_len.numel() - 1)
    frac = ((query - arc_start[segment]) / seg_len[segment]).clamp(0.0, 1.0)
    points = seg_start[segment] + frac[:, None] * seg_vec[segment]
    return points.reshape(*s.shape, 2)
def learned_plan(
    model,
    image_history: deque,
    action_history: deque,
    goals: np.ndarray,
    goal_idx: int,
    active_action_dim: int,
    args,
    prev_action: np.ndarray,
    previous_mean: np.ndarray | None,
    context_mode: str,
    donor_context: torch.Tensor | None,
) -> tuple[np.ndarray, np.ndarray | None, np.ndarray]:
    """Plan the next action with CEM on top of a learned world model.

    The image/action history is encoded into a latent ``z`` and a context ``c``.
    A non-empty context is replaced by zeros when ``context_mode == "zero"`` or by
    ``donor_context`` when ``context_mode == "shuffled"`` and a donor is given.
    The current position is decoded from ``z``, a route through the remaining
    goals is built, and ``cem_plan`` is run from a warm-started mean.

    Returns:
        Whatever ``cem_plan`` returns: ``(action, best_candidates, mean)``.
    """
    device = next(model.parameters()).device
    action_dim = model.config.action_dim

    frames = np.asarray(image_history, dtype=np.uint8)
    past_actions = np.asarray(action_history, dtype=np.float32)
    images = torch.as_tensor(frames, device=device).unsqueeze(0)
    actions = torch.as_tensor(past_actions, device=device).unsqueeze(0)

    with torch.no_grad(), autocast_context(device, args.precision):
        z, c = model.encode(images, actions)
        if c.numel():
            # Context ablations only make sense when the model has a context at all.
            if context_mode == "zero":
                c = torch.zeros_like(c)
            elif context_mode == "shuffled" and donor_context is not None:
                c = donor_context.to(device=device, dtype=torch.float32)
    z, c = z.detach(), c.detach()

    goal_t = torch.as_tensor(goals[goal_idx], dtype=torch.float32, device=device).view(1, 2)
    with torch.no_grad(), autocast_context(device, args.precision):
        current_pos = decode_absolute(model.decoder(z)).float().detach()[..., :2]
    route_points = route_points_tensor(current_pos[0], goals, goal_idx)

    mean = warm_start_mean(previous_mean, args.cem_horizon, action_dim, active_action_dim, device)
    std = torch.full_like(mean, args.cem_action_std)

    # Previously executed action, zero-padded to the model's action width.
    prev = torch.zeros((action_dim,), dtype=torch.float32, device=device)
    prev[:active_action_dim] = torch.as_tensor(prev_action, dtype=torch.float32, device=device)

    with torch.no_grad():
        action, best_candidates, mean = cem_plan(
            model,
            z,
            c,
            mean,
            std,
            goal_t,
            route_points,
            current_pos,
            prev,
            active_action_dim,
            args,
        )
    return action, best_candidates, mean
def planning_cost(
    pred: torch.Tensor,
    samples: torch.Tensor,
    goal_t: torch.Tensor,
    route_points: torch.Tensor,
    current_pos: torch.Tensor,
    prev: torch.Tensor,
    active_action_dim: int,
    args,
) -> torch.Tensor:
    """Score candidate rollouts for CEM; lower is better.

    Args:
        pred: Predicted rollouts; the code indexes it as (candidate, step, channel)
            with xy in channels 0:2 and a heading vector in channels 2:4.
        samples: Candidate action sequences, (candidate, step, action).
        goal_t: Current goal position (callers pass a (1, 2) tensor).
        route_points: Polyline from the current position through remaining goals.
        current_pos: Decoded current position (assumed (1, 2) -- confirm with caller).
        prev: Previously executed action, padded to the full action width.
        active_action_dim: Number of leading action channels counted in the energy term.
        args: Namespace providing the ``cem_w_*`` weights, ``cem_boundary_margin``
            and ``cem_route_horizon_distance``.

    Returns:
        One cost per candidate: a weighted sum of goal, path, route, lookahead,
        via, heading, action-energy, smoothness and boundary penalties minus a
        weighted progress reward.
    """
    pos = pred[..., :2]
    goal_delta = goal_t - current_pos
    goal_dir = goal_delta / torch.linalg.norm(goal_delta, dim=-1, keepdim=True).clamp_min(1.0e-6)
    # Best displacement along the straight goal direction reached at any step.
    direct_progress = ((pos - current_pos[:, None]) * goal_dir[:, None]).sum(dim=-1).amax(dim=-1)
    # Per-step fractions 1/H .. 1 used to schedule reference points over the horizon.
    alpha = torch.linspace(1.0 / pos.shape[1], 1.0, pos.shape[1], device=pos.device, dtype=pos.dtype)
    # Straight-line reference from the current position to the goal.
    direct_route = current_pos[:, None] + alpha.view(1, -1, 1) * goal_delta[:, None]
    direct_route_error = ((pos - direct_route) ** 2).sum(dim=-1).mean(dim=-1)
    route_dist_sq, route_s, route_len = route_projection(pos, route_points)
    # Mean squared distance to the nearest point on the waypoint route.
    route_error = route_dist_sq.mean(dim=-1)
    # Target arc lengths along the route, capped at cem_route_horizon_distance.
    scheduled_s = alpha * torch.minimum(
        route_len,
        torch.as_tensor(args.cem_route_horizon_distance, dtype=pos.dtype, device=pos.device),
    )
    scheduled = route_points_at_s(route_points, scheduled_s).view(1, pos.shape[1], 2)
    lookahead_error = ((pos - scheduled) ** 2).sum(dim=-1).mean(dim=-1)
    # Furthest fraction of the route reached at any step, in [0, 1].
    route_progress = (route_s.amax(dim=-1) / route_len.clamp_min(1.0e-6)).clamp(0.0, 1.0)
    # Unit vector from each predicted position towards the goal.
    goal_from_pos = goal_t[:, None] - pos
    goal_from_pos = goal_from_pos / torch.linalg.norm(goal_from_pos, dim=-1, keepdim=True).clamp_min(1.0e-6)
    # Channels 2:4 are normalised and compared with the goal direction;
    # presumably a (cos, sin) heading encoding -- confirm against the decoder.
    heading = pred[..., 2:4]
    heading = heading / torch.linalg.norm(heading, dim=-1, keepdim=True).clamp_min(1.0e-6)
    heading_error = (1.0 - (heading * goal_from_pos).sum(dim=-1)).mean(dim=-1)
    # Squared goal distance at the last step, averaged over steps, and at the closest step.
    terminal = ((pos[:, -1] - goal_t) ** 2).sum(dim=-1)
    path = ((pos - goal_t[:, None]) ** 2).sum(dim=-1).mean(dim=-1)
    via = ((pos - goal_t[:, None]) ** 2).sum(dim=-1).amin(dim=-1)
    energy = (samples[..., :active_action_dim] ** 2).mean(dim=(1, 2))
    # Step-to-step action change, with the previously executed action prepended.
    smooth_prev = torch.cat([prev.view(1, 1, -1).repeat(samples.shape[0], 1, 1), samples[:, :-1]], dim=1)
    smooth = ((samples - smooth_prev) ** 2).mean(dim=(1, 2))
    margin = args.cem_boundary_margin
    # Linear penalty for leaving [margin, 10 - margin] on either axis; the 10.0
    # is hard-coded here (presumably the workspace side length -- confirm).
    boundary = (
        torch.relu(margin - pos[..., 0])
        + torch.relu(pos[..., 0] - (10.0 - margin))
        + torch.relu(margin - pos[..., 1])
        + torch.relu(pos[..., 1] - (10.0 - margin))
    ).mean(dim=-1)
    return (
        args.cem_w_goal * terminal
        + args.cem_w_path * path
        + args.cem_w_route * (route_error + 0.25 * direct_route_error)
        + args.cem_w_lookahead * lookahead_error
        + args.cem_w_via * via
        + args.cem_w_heading_goal * heading_error
        + args.cem_w_action * energy
        + args.cem_w_smooth * smooth
        + args.cem_w_boundary * boundary
        - args.cem_w_progress * (route_progress + 0.1 * direct_progress)
    )
def cem_plan(
    model,
    z: torch.Tensor,
    c: torch.Tensor,
    mean: torch.Tensor,
    std: torch.Tensor,
    goal_t: torch.Tensor,
    route_points: torch.Tensor,
    current_pos: torch.Tensor,
    prev: torch.Tensor,
    active_action_dim: int,
    args,
) -> tuple[np.ndarray, np.ndarray | None, np.ndarray]:
    """Run the cross-entropy method over action sequences in latent space.

    Each iteration samples ``args.cem_population`` sequences around ``mean``
    (with the mean itself injected as candidate 0), rolls them out with the
    world model, scores them with ``planning_cost`` and refits ``mean``/``std``
    to the lowest-cost elites.

    Args:
        model: World model exposing ``config.action_dim`` plus whatever
            ``rollout_latent`` needs.
        z, c: Encoded latent state and context for the current observation.
        mean, std: Initial ``(horizon, action_dim)`` sampling distribution.
        goal_t, route_points, current_pos, prev: Forwarded to ``planning_cost``.
        active_action_dim: Leading action channels the environment actually uses;
            the remaining channels are forced to zero in every candidate.
        args: Namespace with ``cem_iterations``, ``cem_population``, ``cem_knots``,
            ``cem_elites``, ``precision`` and ``make_gifs``.

    Returns:
        ``(action, best_candidates, mean)``: the first step of the final mean
        (clipped to ``[-1, 1]``, float32), up to 12 elite xy trajectories from
        the last iteration for rendering (``None`` unless ``args.make_gifs`` is
        truthy and at least one iteration ran), and the final mean as a NumPy
        array for warm-starting the next call.
    """
    best_candidates = None
    pred = None
    elite_idx = None
    for _ in range(args.cem_iterations):
        samples = sample_action_sequences(mean, std, args.cem_population, args.cem_knots)
        samples[0] = mean  # always keep the current mean as a candidate
        samples = samples.clamp(-1.0, 1.0)
        if active_action_dim < model.config.action_dim:
            samples[:, :, active_action_dim:] = 0.0
        with autocast_context(mean.device, args.precision):
            pred = rollout_latent(model, z, c, samples)
        cost = planning_cost(pred, samples, goal_t, route_points, current_pos, prev, active_action_dim, args)
        # torch.topk raises when k exceeds the number of candidates, so cap it.
        num_elites = max(1, min(int(args.cem_elites), samples.shape[0]))
        elite_idx = torch.topk(cost, k=num_elites, largest=False).indices
        elites = samples[elite_idx]
        mean = elites.mean(dim=0)
        if num_elites > 1:
            std = elites.std(dim=0).clamp_min(0.05)
        else:
            # The sample std of a single elite is NaN, which clamp_min would keep
            # and which would poison every later iteration; use the floor instead.
            std = torch.full_like(mean, 0.05)
    # Only the last iteration's elites are returned, so copy them to the host
    # once here instead of on every iteration.
    if args.make_gifs and pred is not None and elite_idx is not None:
        pos = pred[..., :2]
        best_candidates = pos[elite_idx[:12]].detach().cpu().numpy()
    action = mean[0, :active_action_dim].detach().cpu().numpy()
    return (
        np.clip(action, -1.0, 1.0).astype(np.float32),
        best_candidates,
        mean.detach().cpu().numpy(),
    )
@torch.no_grad()
def donor_context_for_flowmo(model, env: SurfaceBoatEnv, args, seed: int) -> torch.Tensor | None:
    """Encode a context from an unrelated donor episode for the "shuffled" ablation.

    A separate environment with the same boat and flow type is driven with
    random actions in ``[-0.5, 0.5)`` for ``args.history_len`` steps and the
    resulting history is encoded with the model.

    Returns:
        The second output of ``model.encode`` (detached), or ``None`` when the
        model has no ``to_c`` attribute.
    """
    if not hasattr(model, "to_c"):
        return None

    action_dim = model.config.action_dim
    rng = np.random.default_rng(seed + 99_999)
    donor = SurfaceBoatEnv(
        boat=env.config.boat,
        flow_type=env.config.flow_type,
        boundary="terminate",
        episode_steps=model.config.context_len + 8,
        seed=seed + 99,
    )
    donor.reset(flow_type=env.config.flow_type, random_velocity=False)

    image_history = deque(maxlen=args.history_len)
    action_history = deque(maxlen=args.history_len)
    last_action = np.zeros((action_dim,), dtype=np.float32)
    for _ in range(args.history_len):
        # Each observation is paired with the action that was applied before it.
        image_history.append(clean_observation(donor, args.image_size, args.visual_scale))
        action_history.append(last_action.copy())
        raw_action = rng.uniform(-0.5, 0.5, size=donor.action_dim).astype(np.float32)
        donor.step(raw_action)
        last_action = pad_action(raw_action, action_dim)

    device = next(model.parameters()).device
    frames = np.asarray(image_history, dtype=np.uint8)
    past_actions = np.asarray(action_history, dtype=np.float32)
    images = torch.as_tensor(frames, device=device).unsqueeze(0)
    actions = torch.as_tensor(past_actions, device=device).unsqueeze(0)
    with autocast_context(device, args.precision):
        encoded = model.encode(images, actions)
        return encoded[1].detach()
def traditional_action(method: str, image_history: deque, env: SurfaceBoatEnv, goal: np.ndarray) -> np.ndarray:
    """Query a traditional controller for the next action.

    The controller lives in ``experiments.<method>.src.evaluate`` and receives a
    dict holding the latest image, the full image history (each transposed with
    axes ``(1, 2, 0)``), the environment's last flow velocity, the goal, the
    action dimension and the boat type.

    Returns:
        The first ``env.action_dim`` entries of the controller output as float32.
    """
    controller = importlib.import_module(f"experiments.{method}.src.evaluate")
    latest_frame = np.transpose(image_history[-1], (1, 2, 0))
    frame_history = [np.transpose(frame, (1, 2, 0)) for frame in image_history]
    request = {
        "image": latest_frame,
        "history": frame_history,
        "true_flow": env.last_flow_velocity.copy(),
        "goal": goal.astype(float).tolist(),
        "action_dim": env.action_dim,
        "boat": env.config.boat,
    }
    command = controller.evaluate(request)
    return command[: env.action_dim].astype(np.float32)
def evaluate_one_method(method: str, args) -> dict:
    """Run closed-loop episodes for one method and return its summary.

    Learned methods (members of ``LEARNED_METHODS``) load a checkpoint from
    ``experiments/<method>/checkpoint/<args.checkpoint_name>`` and plan with
    ``learned_plan``; every other method is driven by ``traditional_action``.
    Only the method named ``"flowmo"`` is evaluated under every entry of
    ``args.context_modes``; all others run once with mode ``"inferred"``.

    Args:
        method: Method name, also used as the ``experiments/<method>`` package name.
        args: Parsed CLI namespace (see ``main``).

    Returns:
        The dict produced by ``summarize`` for this method.
    """
    torch.manual_seed(args.seed)
    learned = method in LEARNED_METHODS
    model = None
    if learned:
        _cfg, model = build_method(method)
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state = torch.load(Path("experiments") / method / "checkpoint" / args.checkpoint_name, map_location="cpu")
        model.load_state_dict(state)
        model.to(torch.device(args.device))
        if torch.device(args.device).type == "cuda":
            model.to(memory_format=torch.channels_last)
        model.eval()
        # Inference only: freeze all parameters.
        for param in model.parameters():
            param.requires_grad_(False)
    results = []
    gif_dir = Path(args.out) / "gifs"
    gif_dir.mkdir(parents=True, exist_ok=True)
    context_modes = args.context_modes if method == "flowmo" else ["inferred"]
    for context_mode in context_modes:
        for ep in range(args.episodes):
            # Seeds depend only on the episode index, so every context mode
            # is evaluated on the same sequence of episode seeds.
            episode_seed = int(args.seed + ep)
            rng = np.random.default_rng(episode_seed)
            env = SurfaceBoatEnv(
                boat=args.boat,
                flow_type=args.flow_type,
                boundary="terminate",
                episode_steps=args.max_steps,
                seed=episode_seed,
            )
            reset_task(env, args.task, args.flow_type, rng)
            goals = task_goals(args.task, rng)
            goal_idx = 0
            image_history = deque(maxlen=args.history_len)
            action_history = deque(maxlen=args.history_len)
            # Traditional methods have no model config; their history actions are 3 wide.
            zero = np.zeros((model.config.action_dim if learned else 3,), dtype=np.float32)
            first = clean_observation(env, args.image_size, args.visual_scale)
            # Pre-fill the history with copies of the first frame and zero actions.
            for _ in range(args.history_len):
                image_history.append(first.copy())
                action_history.append(zero.copy())
            donor_context = donor_context_for_flowmo(model, env, args, episode_seed) if learned and context_mode == "shuffled" else None
            trajectory = [env.full_state()[:6].copy()]
            frames = []
            prev_action = np.zeros((env.action_dim,), dtype=np.float32)
            energy = 0.0
            reached_times: list[int] = []
            # Closest approach to every goal over the whole episode.
            min_goal_dists = np.full((len(goals),), np.inf, dtype=np.float32)
            planned = None
            learned_plan_mean = None
            for t in range(args.max_steps):
                goal = goals[goal_idx]
                if learned:
                    action, planned, learned_plan_mean = learned_plan(
                        model,
                        image_history,
                        action_history,
                        goals,
                        goal_idx,
                        env.action_dim,
                        args,
                        prev_action,
                        learned_plan_mean,
                        context_mode,
                        donor_context,
                    )
                else:
                    action = traditional_action(method, image_history, env, goal)
                    planned = None
                prev_action = action.copy()
                _obs, _reward, done, _info = env.step(action)
                energy += float(np.sum(action * action))
                trajectory.append(env.full_state()[:6].copy())
                image_history.append(clean_observation(env, args.image_size, args.visual_scale))
                # Pad the executed action to the width already stored in the history.
                action_history.append(pad_action(action, len(action_history[-1])))
                dists = np.linalg.norm(goals - env.state[:2], axis=1)
                min_goal_dists = np.minimum(min_goal_dists, dists)
                # Only the first args.make_gifs episodes are rendered.
                if ep < args.make_gifs and t % args.gif_stride == 0:
                    frames.append(
                        render_frame(
                            env.full_state()[:6],
                            env.spec,
                            env.flow,
                            env.workspace,
                            trajectory=np.asarray(trajectory),
                            goal=goal,
                            planned=planned,
                            t=env.time,
                        )
                    )
                if float(dists[goal_idx]) < args.success_radius:
                    reached_times.append(t + 1)
                    if args.task == "station_keeping":
                        # Station keeping ends once the boat is inside the radius
                        # at or after step max(40, max_steps // 3).
                        if t >= max(40, args.max_steps // 3):
                            break
                    else:
                        goal_idx += 1
                        # The old plan targeted the previous waypoint; drop the warm start.
                        learned_plan_mean = None
                        if goal_idx >= len(goals):
                            break
                if done:
                    break
            path = np.asarray(trajectory)[:, :2]
            final_goal = goals[min(goal_idx, len(goals) - 1)]
            record = {
                "method": method,
                "context_mode": context_mode,
                "episode": ep,
                # Success: all waypoints reached, or (station keeping) final position inside the radius.
                "success": bool(goal_idx >= len(goals) or (args.task == "station_keeping" and np.linalg.norm(env.state[:2] - goals[0]) < args.success_radius)),
                "final_distance": float(np.linalg.norm(env.state[:2] - final_goal)),
                "mean_min_goal_distance": float(min_goal_dists.mean()),
                "path_length": float(np.linalg.norm(np.diff(path, axis=0), axis=-1).sum()) if len(path) > 1 else 0.0,
                "energy": energy,
                "steps": len(trajectory) - 1,
                "reached_times": reached_times,
            }
            results.append(record)
            if ep < args.make_gifs and frames:
                name = f"image_planning_{method}_{context_mode}_{args.boat}_{args.task}_{args.flow_type}_ep{ep:03d}.gif"
                save_gif(frames, gif_dir / name, duration_ms=args.gif_duration_ms)
    return summarize(method, args, results)
def summarize(method: str, args, results: list[dict]) -> dict:
    """Aggregate per-episode records into per-context-mode statistics.

    Path length, energy and step means are computed over successful episodes
    only and are ``None`` when a context mode has no successes.

    Returns:
        A dict with the method, task, boat and flow type, the per-context
        statistics under ``"by_context"`` (keys in sorted order) and the raw
        ``results`` list.
    """

    def mean_over_successes(records: list[dict], key: str) -> float | None:
        values = [record[key] for record in records if record["success"]]
        if not values:
            return None
        return float(np.mean(values))

    def mean_over_all(records: list[dict], key: str) -> float:
        return float(np.mean([record[key] for record in records]))

    by_context = {}
    for context in sorted({record["context_mode"] for record in results}):
        records = [record for record in results if record["context_mode"] == context]
        by_context[context] = {
            "episodes": len(records),
            "successes": sum(1 for record in records if record["success"]),
            "success_rate": mean_over_all(records, "success"),
            "final_distance_mean": mean_over_all(records, "final_distance"),
            "mean_min_goal_distance": mean_over_all(records, "mean_min_goal_distance"),
            "path_length_success_mean": mean_over_successes(records, "path_length"),
            "energy_success_mean": mean_over_successes(records, "energy"),
            "steps_success_mean": mean_over_successes(records, "steps"),
        }

    return {
        "method": method,
        "task": args.task,
        "boat": args.boat,
        "flow_type": args.flow_type,
        "by_context": by_context,
        "results": results,
    }
def main() -> None:
    """Parse CLI options, evaluate every requested method and save the results.

    The summaries are written as JSON to
    ``<out>/<task>_<boat>_<flow_type>.json`` and echoed to stdout.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # Experiment selection.
    add("--methods", nargs="+", default=LEARNED_METHODS + TRADITIONAL_METHODS)
    add("--task", choices=["reach_target", "station_keeping", "waypoint_square", "waypoint_zigzag"], default="reach_target")
    add("--boat", choices=["twin", "triangle"], default="twin")
    add(
        "--flow-type",
        choices=[
            "noflow",
            "uniform",
            "vortex_center",
            "double_gyre",
            "source_sink",
            "source_sink_pair",
            "gradient",
            "shear",
            "turbulent_patch",
            "random_fourier",
        ],
        default="uniform",
    )
    add("--episodes", type=int, default=50)
    add("--max-steps", type=int, default=420)

    # Observation and model inputs.
    add("--history-len", type=int, default=32)
    add("--image-size", type=int, default=160)
    add("--visual-scale", type=float, default=2.5)
    add("--checkpoint-name", default="paper.pt")
    add("--context-modes", nargs="+", default=["inferred", "zero", "shuffled"])

    # CEM sampler.
    add("--cem-horizon", type=int, default=45)
    add("--cem-population", type=int, default=512)
    add("--cem-elites", type=int, default=64)
    add("--cem-iterations", type=int, default=4)
    add("--cem-action-std", type=float, default=0.5)
    add("--cem-knots", type=int, default=10)

    # CEM cost weights.
    add("--cem-w-goal", type=float, default=6.0)
    add("--cem-w-path", type=float, default=0.2)
    add("--cem-w-route", type=float, default=6.0)
    add("--cem-w-lookahead", type=float, default=2.0)
    add("--cem-w-via", type=float, default=2.0)
    add("--cem-route-horizon-distance", type=float, default=3.0)
    add("--cem-w-heading-goal", type=float, default=0.0)
    add("--cem-w-action", type=float, default=0.08)
    add("--cem-w-smooth", type=float, default=0.08)
    add("--cem-w-boundary", type=float, default=250.0)
    add("--cem-boundary-margin", type=float, default=0.75)
    add("--cem-w-progress", type=float, default=2.0)

    # Success criterion, rendering and runtime.
    add("--success-radius", type=float, default=0.65)
    add("--make-gifs", type=int, default=3)
    add("--gif-stride", type=int, default=1)
    add("--gif-duration-ms", type=int, default=55)
    add("--seed", type=int, default=33)
    add("--device", default="cuda")
    add("--precision", choices=["fp32", "bf16", "fp16"], default="fp32")
    add("--out", default="experiments/reports/paper_planning")

    args = parser.parse_args()

    out_dir = Path(args.out)
    out_dir.mkdir(parents=True, exist_ok=True)

    payload = [evaluate_one_method(method, args) for method in args.methods]
    report = json.dumps(payload, indent=2)

    out_path = out_dir / f"{args.task}_{args.boat}_{args.flow_type}.json"
    out_path.write_text(report)
    print(report)
# Run the evaluation CLI only when this file is executed as a script.
if __name__ == "__main__":
    main()