# FlowMo-WM/experiments/evaluate_image_world_models.py
# Last commit ccf9f1b (verified) by cccat6: "Update FlowMo-WM code and static flow protocol"
"""Evaluate trained image-input world models on long open-loop rollouts."""
from __future__ import annotations
import argparse
import importlib
import json
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import DataLoader
from experiments.shared.src.data.image_dataset import ImageTrajectoryDataset
from experiments.shared.src.methods import PAPER_LEARNED_METHODS
from experiments.shared.src.vision.clean_renderer import render_clean_boat_history_tensor
from experiments.train_image_world_models import configure_training_runtime
from experiments.train_image_world_models import autocast_context
from experiments.train_image_world_models import decode_predictions
from experiments.train_image_world_models import required_model_history
from experiments.train_image_world_models import selected_history_indices
# Default value of ``--methods``: the learned methods imported from
# ``experiments.shared.src.methods``. Each name is used by build_method to import
# ``experiments.<name>.src.config`` / ``experiments.<name>.src.model``.
METHODS = PAPER_LEARNED_METHODS
def loader_kwargs(num_workers: int) -> dict:
    """Return extra ``DataLoader`` keyword arguments for worker-process loading.

    PyTorch's ``DataLoader`` rejects ``persistent_workers`` / ``prefetch_factor``
    when loading happens in the main process, so for ``num_workers <= 0`` an
    empty dict is returned. Otherwise the workers are started with the
    ``spawn`` context, kept alive between epochs and prefetch 4 batches each.
    """
    uses_workers = num_workers > 0
    if uses_workers:
        return {
            "multiprocessing_context": "spawn",
            "persistent_workers": True,
            "prefetch_factor": 4,
        }
    return {}
def prepare_batch(batch, args, device: torch.device):
    """Trim a collated batch to the model's history and move it to ``device``.

    ``batch`` is unpacked as ``(observation_hist, actions, future_actions,
    targets, origin, prev_origin, flow_type_id, boat_id)``.

    History selection along dim 1 of ``observation_hist`` and ``actions``:
    if ``args.history_indices`` is set (not ``None``), exactly those indices
    are kept; otherwise the last ``args.model_history_len`` steps are kept
    (all steps when the attribute is absent).

    With ``args.render_mode == "device"`` the selected observations are
    rendered on ``device`` via ``render_clean_boat_history_tensor`` (presumably
    they are raw states in that mode — confirm against the dataset). Otherwise
    they are moved to ``device`` unchanged as the image tensor.

    Returns:
        ``(images, actions, future_actions, targets, origin, prev_origin,
        flow_type_id, boat_id)``. ``prev_origin``, ``flow_type_id`` and
        ``boat_id`` are returned as received (not moved to ``device``).
    """
    (
        obs_history,
        hist_actions,
        future_actions,
        targets,
        origin,
        prev_origin,
        flow_type_id,
        boat_id,
    ) = batch

    indices = getattr(args, "history_indices", None)
    if indices is not None:
        obs_history = obs_history[:, indices]
        hist_actions = hist_actions[:, indices]
    else:
        keep = int(getattr(args, "model_history_len", obs_history.shape[1]))
        obs_history = obs_history[:, -keep:]
        hist_actions = hist_actions[:, -keep:]

    def _to_device(tensor):
        return tensor.to(device, non_blocking=True)

    hist_actions, future_actions, targets, origin = (
        _to_device(t) for t in (hist_actions, future_actions, targets, origin)
    )

    if args.render_mode != "device":
        images = _to_device(obs_history)
    else:
        images = render_clean_boat_history_tensor(
            _to_device(obs_history),
            _to_device(boat_id),
            image_size=args.image_size,
            visual_scale=args.visual_scale,
        )
    return images, hist_actions, future_actions, targets, origin, prev_origin, flow_type_id, boat_id
def build_method(method: str):
    """Create the default config and model for experiment package ``method``.

    Imports ``experiments.<method>.src.config`` and then
    ``experiments.<method>.src.model``. Calls ``default_config()`` on the former
    and ``build_model(cfg)`` on the latter.

    Returns:
        ``(cfg, model)``.
    """
    config_mod = importlib.import_module(f"experiments.{method}.src.config")
    model_mod = importlib.import_module(f"experiments.{method}.src.model")
    method_cfg = config_mod.default_config()
    built = model_mod.build_model(method_cfg)
    return method_cfg, built
def load_flow_names(source_npz: str) -> dict[int, str]:
    """Map integer flow-type ids to flow names using an ``.npz`` file's metadata.

    The archive must contain a ``metadata`` entry holding a JSON string whose
    ``"flows"`` object maps ``name -> id``. That mapping is inverted to
    ``{id: name}``.

    Raises:
        KeyError: if ``metadata`` or its ``"flows"`` key is missing.
    """
    # np.load on an .npz returns an NpzFile that keeps the file open until it
    # is closed; use it as a context manager so the handle is released.
    with np.load(source_npz, allow_pickle=False) as src:
        metadata = json.loads(str(src["metadata"]))
    return {int(v): str(k) for k, v in metadata["flows"].items()}
def load_group_names(source_npz: str, key: str) -> dict[int, str]:
    """Map integer group ids to names for metadata group ``key`` of an ``.npz``.

    The archive must contain a ``metadata`` entry holding a JSON string; its
    ``key`` object maps ``name -> id``, and this returns ``{id: name}``.

    Raises:
        KeyError: if ``metadata`` or ``key`` is missing.
    """
    # Close the NpzFile (and its underlying file handle) as soon as we are done.
    with np.load(source_npz, allow_pickle=False) as src:
        metadata = json.loads(str(src["metadata"]))
    return {int(v): str(k) for k, v in metadata[key].items()}
@torch.no_grad()
def rollout_with_context(model, images: torch.Tensor, actions: torch.Tensor, future_actions: torch.Tensor, mode: str) -> torch.Tensor:
    """Encode the history, optionally ablate the context, then roll out open-loop.

    ``mode`` controls the context returned by ``model.encode``:
      * ``"zero"``: replaced with zeros of the same shape.
      * ``"shuffled"``: rolled by one position along dim 0, so each sample gets
        its neighbour's context. A batch of size 1 is left unchanged.
      * anything else: used as encoded.

    If the model defines ``rollout_with_context`` it is called with
    ``(latent, context, future_actions)``. Otherwise the latent is advanced with
    ``model.step`` once per entry along dim 1 of ``future_actions``. Each new
    latent is decoded with ``model.decoder``, and the decoded outputs are
    stacked along dim 1.
    """
    latent, context = model.encode(images, actions)
    if mode == "zero":
        context = torch.zeros_like(context)
    elif mode == "shuffled":
        context = context.roll(shifts=1, dims=0)

    if hasattr(model, "rollout_with_context"):
        return model.rollout_with_context(latent, context, future_actions)

    decoded = []
    for step_action in future_actions.unbind(dim=1):
        latent = model.step(latent, step_action, context)
        decoded.append(model.decoder(latent))
    return torch.stack(decoded, dim=1)
class _GroupErrors:
    """Running per-group sums of position/heading error, keyed by integer group id."""

    def __init__(self, horizon: int) -> None:
        self.horizon = horizon
        self.pos: dict[int, np.ndarray] = {}
        self.heading: dict[int, np.ndarray] = {}
        self.count: dict[int, int] = {}

    def add(self, group_ids: np.ndarray, pos_np: np.ndarray, heading_np: np.ndarray) -> None:
        """Add each row of ``pos_np`` / ``heading_np`` to the sums of its group in ``group_ids``."""
        for group_id in np.unique(group_ids):
            mask = group_ids == group_id
            gid = int(group_id)
            if gid not in self.pos:
                self.pos[gid] = np.zeros(self.horizon, dtype=np.float64)
                self.heading[gid] = np.zeros(self.horizon, dtype=np.float64)
            self.count[gid] = self.count.get(gid, 0) + int(mask.sum())
            self.pos[gid] += pos_np[mask].sum(axis=0)
            self.heading[gid] += heading_np[mask].sum(axis=0)

    def summaries(self, names: dict[int, str], steps: list[int]) -> dict:
        """Per-group ``summarize`` of the mean errors, keyed by name (falls back to ``str(id)``)."""
        return {
            names.get(gid, str(gid)): summarize(self.pos[gid] / n, self.heading[gid] / n, steps)
            for gid, n in sorted(self.count.items())
        }


@torch.no_grad()
def evaluate_model(
    model,
    loader,
    device: torch.device,
    horizon: int,
    target_mode: str,
    flow_names: dict[int, str],
    traj_names: dict[int, str],
    boat_names: dict[int, str],
    context_mode: str,
    args,
) -> dict:
    """Measure open-loop rollout error of ``model`` over every batch of ``loader``.

    For each sample and rollout step two errors are computed:
      * position: Euclidean norm of ``pred[..., :2] - targets[..., :2]``;
      * heading: absolute wrapped difference of ``atan2(x[..., 3], x[..., 2])``
        between prediction and target.
    Means are reported at the steps in (1, 3, 6, 8, 10, 20, 30, 40, 60) that are
    ``<= horizon``. They cover all samples, and are also grouped by flow type,
    trajectory type and boat.

    Args:
        model: world model in eval mode (``model.eval()`` is called here).
        loader: DataLoader over a dataset exposing ``indices`` (pairs
            ``(episode, t)``), ``traj_type_ids`` and ``boat_ids``. It must
            iterate the dataset in order (``shuffle=False``).
        device: device the batch is moved to.
        horizon: number of rollout steps in the targets.
        target_mode: forwarded to ``decode_predictions``.
        flow_names, traj_names, boat_names: id -> display name for grouping.
        context_mode: ``"inferred"`` uses ``model.rollout``. Any other value is
            passed to ``rollout_with_context`` (e.g. ``"zero"``, ``"shuffled"``).
        args: namespace read by ``prepare_batch``; ``args.precision`` selects autocast.

    Returns:
        ``summarize`` output for the overall means, plus ``"by_flow"``,
        ``"by_trajectory"`` and ``"by_boat"`` sub-dicts.

    Raises:
        ValueError: if ``loader`` yields no samples (means would be NaN).
    """
    model.eval()
    steps = [s for s in (1, 3, 6, 8, 10, 20, 30, 40, 60) if s <= horizon]
    pos_sum = np.zeros(horizon, dtype=np.float64)
    heading_sum = np.zeros(horizon, dtype=np.float64)
    by_flow = _GroupErrors(horizon)
    by_traj = _GroupErrors(horizon)
    by_boat = _GroupErrors(horizon)
    dataset = loader.dataset
    count = 0
    for batch in loader:
        images, actions, future_actions, targets, origin, _prev_origin, flow_type_id, _boat_id = prepare_batch(batch, args, device)
        with autocast_context(device, args.precision):
            if context_mode == "inferred":
                encoded = model.rollout(images, actions, future_actions)
            else:
                encoded = rollout_with_context(model, images, actions, future_actions, context_mode)
        # Decode and score in fp32, outside the (possibly reduced-precision) autocast region.
        pred = decode_predictions(encoded.float(), origin, target_mode)
        pos = torch.linalg.norm(pred[..., :2] - targets[..., :2], dim=-1)
        angle_diff = torch.atan2(pred[..., 3], pred[..., 2]) - torch.atan2(targets[..., 3], targets[..., 2])
        # Wrap the angle difference into [-pi, pi] before taking its magnitude.
        heading = torch.atan2(torch.sin(angle_diff), torch.cos(angle_diff)).abs()
        pos_np = pos.cpu().numpy()
        heading_np = heading.cpu().numpy()
        batch_size = int(pos_np.shape[0])
        pos_sum += pos_np.sum(axis=0)
        heading_sum += heading_np.sum(axis=0)

        # Trajectory/boat ids are looked up through dataset.indices by position,
        # which is only valid while the loader walks the dataset in order.
        batch_indices = dataset.indices[count : count + batch_size]
        count += batch_size
        traj_np = np.array([dataset.traj_type_ids[ep] for ep, _t in batch_indices], dtype=np.int64)
        boat_np = np.array([dataset.boat_ids[ep] for ep, _t in batch_indices], dtype=np.int64)

        by_flow.add(flow_type_id.numpy(), pos_np, heading_np)
        by_traj.add(traj_np, pos_np, heading_np)
        by_boat.add(boat_np, pos_np, heading_np)

    if count == 0:
        # Dividing the float sums by 0 would silently put NaN means in the report.
        raise ValueError("evaluate_model: loader yielded no samples; cannot compute mean errors")

    result = summarize(pos_sum / count, heading_sum / count, steps)
    result["by_flow"] = by_flow.summaries(flow_names, steps)
    result["by_trajectory"] = by_traj.summaries(traj_names, steps)
    result["by_boat"] = by_boat.summaries(boat_names, steps)
    return result
def summarize(pos_mean: np.ndarray, heading_mean: np.ndarray, steps: list[int]) -> dict[str, float]:
    """Pick mean position/heading errors at selected 1-based rollout steps.

    Args:
        pos_mean: mean position error per rollout step (index 0 = step 1).
        heading_mean: mean heading error per rollout step (index 0 = step 1).
        steps: 1-based steps to report.

    Returns:
        ``{"pos<k>": float, "heading<k>": float}`` for each ``k`` in ``steps``,
        inserted in ``steps`` order with ``pos`` before ``heading``.
    """
    picked: dict[str, float] = {}
    for k in steps:
        idx = k - 1  # steps are 1-based, arrays are 0-based
        picked.update({f"pos{k}": float(pos_mean[idx]), f"heading{k}": float(heading_mean[idx])})
    return picked
def main() -> None:
    """CLI entry point: evaluate each method's checkpoint and write a JSON report.

    For every name in ``--methods`` the weights at
    ``experiments/<method>/checkpoint/<--checkpoint-name>`` are loaded into the
    model from ``build_method`` and evaluated with inferred context. For
    ``flowmo`` the ``"zero"`` and ``"shuffled"`` context ablations are also run.
    The list of per-method results is written to ``--out`` (parent directories
    are created) and echoed to stdout.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--methods", nargs="+", default=METHODS)
    parser.add_argument("--test-source", default="data/paper/test.npz")
    parser.add_argument("--test-episodes", type=int, default=256)
    parser.add_argument("--history-len", type=int, default=32)
    parser.add_argument("--horizon", type=int, default=60)
    parser.add_argument("--test-windows", type=int, default=4096)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--seed", type=int, default=20)
    parser.add_argument("--device", default="cuda")
    parser.add_argument("--target-mode", choices=["absolute_normalized", "relative_motion"], default="absolute_normalized")
    parser.add_argument("--checkpoint-name", default="image_local.pt")
    parser.add_argument("--out", default="experiments/reports/image_long_rollout_eval.json")
    parser.add_argument("--num-workers", type=int, default=4)
    parser.add_argument("--image-size", type=int, default=160)
    parser.add_argument("--visual-scale", type=float, default=2.5)
    parser.add_argument("--render-mode", choices=["device", "dataset"], default="device")
    parser.add_argument("--precision", choices=["fp32", "bf16", "fp16"], default="fp32")
    args = parser.parse_args()

    device = torch.device(args.device)
    configure_training_runtime(device)

    flow_names = load_flow_names(args.test_source)
    traj_names = load_group_names(args.test_source, "trajectories")
    boat_names = load_group_names(args.test_source, "boats")

    dataset = ImageTrajectoryDataset(
        args.test_source,
        history_len=args.history_len,
        horizon=args.horizon,
        episodes=args.test_episodes,
        max_windows=args.test_windows,
        seed=args.seed,
        image_size=args.image_size,
        visual_scale=args.visual_scale,
        return_aux=True,
        # Render in the dataset only when not rendering on the device later.
        render_images=args.render_mode == "dataset",
    )
    # shuffle=False matters: evaluate_model maps samples back to dataset.indices by position.
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=device.type == "cuda",
        **loader_kwargs(args.num_workers),
    )

    def run(model, context_mode: str) -> dict:
        return evaluate_model(
            model, loader, device, args.horizon, args.target_mode,
            flow_names, traj_names, boat_names, context_mode, args,
        )

    payload = []
    for method in args.methods:
        _cfg, model = build_method(method)
        ckpt_path = Path("experiments") / method / "checkpoint" / args.checkpoint_name
        model.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
        model.to(device)
        if device.type == "cuda":
            model.to(memory_format=torch.channels_last)
        # prepare_batch reads these from args to select this model's history frames.
        args.model_history_len = required_model_history(model, args.history_len)
        args.history_indices = selected_history_indices(model, args.history_len)

        item = {"method": method, "inferred": run(model, "inferred")}
        if method == "flowmo":
            item["context_zero"] = run(model, "zero")
            item["context_shuffled"] = run(model, "shuffled")
        payload.append(item)

    report = json.dumps(payload, indent=2)
    out = Path(args.out)
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_text(report)
    print(report)


if __name__ == "__main__":
    main()