# Source: FlowMo-WM / experiments/train_image_world_models.py
# Uploaded by cccat6.
# Commit ccf9f1b (verified): "Update FlowMo-WM code and static flow protocol"
"""Train and evaluate all image-input world models locally."""
from __future__ import annotations
import argparse
import gc
import random
import importlib
import json
from contextlib import nullcontext
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from experiments.shared.src.data.image_dataset import ImageTrajectoryDataset
from experiments.shared.src.methods import PAPER_LEARNED_METHODS
from experiments.shared.src.utils.parameter_count import save_parameter_count
from experiments.shared.src.vision.clean_renderer import render_clean_boat_history_tensor
# Methods trained when ``--methods`` is not given on the command line.
METHODS = PAPER_LEARNED_METHODS
# Scale used by the pose encode/decode helpers to normalize xy positions.
POSITION_SCALE = 5.0
def build_method(method: str):
    """Create the default config and model of one experiment package.

    ``method`` selects the package ``experiments.<method>.src``. Its
    ``config`` module must provide ``default_config()`` and its ``model``
    module must provide ``build_model(cfg)``.

    Returns:
        A ``(cfg, model)`` tuple.
    """
    package = f"experiments.{method}.src"
    # Both modules are imported (config first) before anything is constructed.
    config_module, model_module = (
        importlib.import_module(f"{package}.{leaf}") for leaf in ("config", "model")
    )
    cfg = config_module.default_config()
    return cfg, model_module.build_model(cfg)
def loader_kwargs(num_workers: int) -> dict:
    """Return the ``DataLoader`` keyword arguments that only apply with workers.

    An empty dict is returned when ``num_workers <= 0``; otherwise the dict
    selects the ``spawn`` start method, persistent workers and a prefetch
    factor of 4.
    """
    worker_options = dict(
        multiprocessing_context="spawn",
        persistent_workers=True,
        prefetch_factor=4,
    )
    return {} if num_workers <= 0 else worker_options
def dataloader_pin_memory(device: torch.device) -> bool:
    """Return True only when ``device`` is a CUDA device."""
    on_cuda = device.type == "cuda"
    return on_cuda
def configure_training_runtime(device: torch.device) -> None:
    """Apply process-wide PyTorch performance settings for training.

    Float32 matmul precision is set to ``"high"`` when this torch build
    offers the setter, and cuDNN benchmarking is switched on for CUDA devices.
    """
    if hasattr(torch, "set_float32_matmul_precision"):
        torch.set_float32_matmul_precision("high")
    # Everything below only concerns CUDA runs.
    if device.type != "cuda":
        return
    if not hasattr(torch.backends, "cudnn"):
        return
    torch.backends.cudnn.benchmark = True
def autocast_context(device: torch.device, precision: str):
    """Return the context manager to wrap forward passes in.

    A no-op ``nullcontext`` is returned for non-CUDA devices and for
    ``precision == "fp32"``. Otherwise CUDA autocast is returned with
    bfloat16 for ``"bf16"`` and float16 for any other precision string.
    """
    wants_autocast = device.type == "cuda" and precision != "fp32"
    if not wants_autocast:
        return nullcontext()
    if precision == "bf16":
        dtype = torch.bfloat16
    else:
        dtype = torch.float16
    return torch.autocast(device_type="cuda", dtype=dtype)
def method_seed(base_seed: int, method: str) -> int:
    """Derive a deterministic per-method seed from ``base_seed``.

    The offset added to the base seed is the sum of each character's code
    point multiplied by its 1-based position in ``method``.
    """
    offset = 0
    for position, char in enumerate(method, start=1):
        offset += position * ord(char)
    return int(base_seed) + offset
def set_training_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy and torch (CPU, plus CUDA if available)."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
def required_model_history(model, default_history_len: int) -> int:
    """Return how many history frames ``model`` needs, capped by the default.

    ``model.config.history_len`` and ``model.config.context_len`` are read
    when present. A missing ``history_len`` falls back to
    ``default_history_len`` and a missing ``context_len`` falls back to the
    history length. The larger of the two is returned, never exceeding
    ``default_history_len``.
    """
    model_config = getattr(model, "config", None)
    configured_history = int(getattr(model_config, "history_len", default_history_len))
    configured_context = int(getattr(model_config, "context_len", configured_history))
    if configured_context > configured_history:
        longest_needed = configured_context
    else:
        longest_needed = configured_history
    cap = int(default_history_len)
    return cap if cap < longest_needed else longest_needed
def selected_history_indices(model, default_history_len: int) -> list[int]:
    """Return the history time indices that ``model`` consumes.

    A model may provide its own ``selected_history_indices(history_len)``
    method, whose result is converted to a list of ints. Otherwise the most
    recent ``required_model_history(...)`` frames are selected.
    """
    if not hasattr(model, "selected_history_indices"):
        needed = required_model_history(model, default_history_len)
        first_index = default_history_len - needed
        return list(range(first_index, default_history_len))
    chosen = model.selected_history_indices(default_history_len)
    return list(map(int, chosen))
def prepare_batch(batch, args, device: torch.device):
    """Select the history window of a batch and move its tensors to ``device``.

    ``batch`` is the 8-tuple ``(observation_hist, actions, future_actions,
    targets, origin, prev_origin, flow_type_id, boat_id)``.

    History selection along dimension 1 of ``observation_hist`` and
    ``actions`` uses ``args.history_indices`` when it is set. Otherwise the
    last ``args.model_history_len`` entries are kept (all of them when that
    attribute is missing).

    With ``args.render_mode == "device"`` the images are rendered on
    ``device`` from the observation history and ``boat_id``. In every other
    mode the observation history itself is used as the images.

    Returns:
        ``(images, actions, future_actions, targets, origin, prev_origin,
        flow_type_id, boat_id)``. ``prev_origin`` and ``flow_type_id`` are
        passed through unchanged; ``boat_id`` is moved to ``device`` only in
        ``"device"`` render mode.
    """
    (
        obs_history,
        actions,
        future_actions,
        targets,
        origin,
        prev_origin,
        flow_type_id,
        boat_id,
    ) = batch

    picked = getattr(args, "history_indices", None)
    if picked is not None:
        time_index = picked
    else:
        keep = int(getattr(args, "model_history_len", obs_history.shape[1]))
        time_index = slice(-keep, None)
    obs_history = obs_history[:, time_index]
    actions = actions[:, time_index]

    def to_device(tensor):
        return tensor.to(device, non_blocking=True)

    actions, future_actions, targets, origin = map(
        to_device, (actions, future_actions, targets, origin)
    )

    if args.render_mode != "device":
        images = to_device(obs_history)
    else:
        states = to_device(obs_history)
        boat_id = to_device(boat_id)
        images = render_clean_boat_history_tensor(
            states,
            boat_id,
            image_size=args.image_size,
            visual_scale=args.visual_scale,
        )
    return images, actions, future_actions, targets, origin, prev_origin, flow_type_id, boat_id
def rollout_from_encoded(model, z: torch.Tensor, c: torch.Tensor, future_actions: torch.Tensor) -> torch.Tensor:
    """Roll the latent state forward and decode every step.

    Starting from ``z``, the latent is advanced with
    ``model.step(latent, action, c)`` once per slice of ``future_actions``
    along dimension 1, and each new latent is decoded with ``model.decoder``.

    Returns:
        The decoded steps stacked along a new dimension 1.
    """
    latent = z
    decoded_steps = []
    for action in future_actions.unbind(dim=1):
        latent = model.step(latent, action, c)
        decoded_steps.append(model.decoder(latent))
    return torch.stack(decoded_steps, dim=1)
def save_eval_checkpoint(model, checkpoint_dir: Path, checkpoint_name: str, step: int | None = None) -> str:
checkpoint_dir.mkdir(parents=True, exist_ok=True)
path = checkpoint_dir / checkpoint_name
if step is not None:
stem = Path(checkpoint_name).stem
suffix = Path(checkpoint_name).suffix or ".pt"
path = checkpoint_dir / f"{stem}_step_{step:06d}{suffix}"
torch.save(model.state_dict(), path)
return path.name
def train_method(method: str, args) -> dict[str, float | int | list[float]]:
    """Train, checkpoint and evaluate one world-model method.

    The run seeds the RNGs from ``args.seed`` and the method name, builds the
    method's default model, trains it for ``args.steps`` optimizer steps on
    ``args.train_source`` and evaluates it on ``args.test_source``. It writes
    checkpoints, a JSONL loss trace, a parameter count and a result JSON
    under ``experiments/<method>/``.

    Args:
        method: Name of the experiment package under ``experiments``.
        args: Parsed command-line namespace (see ``main``). It is mutated:
            ``model_history_len`` and ``history_indices`` are stored on it
            because ``prepare_batch`` reads them from there.

    Returns:
        The result dictionary that is also written to
        ``result/<checkpoint stem>_training.json``. ``final_train_loss`` is
        NaN when no training step was run.

    Raises:
        ValueError: If training steps are requested but the training loader
            yields no batches.
    """
    set_training_seed(method_seed(args.seed, method))
    _cfg, model = build_method(method)
    device = torch.device(args.device)
    configure_training_runtime(device)
    args.model_history_len = required_model_history(model, args.history_len)
    args.history_indices = selected_history_indices(model, args.history_len)
    model.to(device)
    if device.type == "cuda":
        model.to(memory_format=torch.channels_last)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
    # Loss scaling is only needed for fp16; for bf16/fp32 the scaler is disabled.
    scaler = torch.amp.GradScaler("cuda", enabled=(device.type == "cuda" and args.precision == "fp16"))

    out_dir = Path("experiments") / method
    checkpoint_dir = out_dir / "checkpoint"
    result_dir = out_dir / "result"
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    result_dir.mkdir(parents=True, exist_ok=True)
    trace_path = result_dir / f"{Path(args.checkpoint_name).stem}_training_trace.jsonl"
    # Start every run with an empty trace file; entries are appended below.
    trace_path.write_text("")

    train_ds = ImageTrajectoryDataset(
        args.train_source,
        history_len=args.history_len,
        horizon=args.horizon,
        episodes=args.train_episodes,
        image_size=args.image_size,
        visual_scale=args.visual_scale,
        max_windows=args.train_windows,
        seed=args.seed,
        return_aux=True,
        render_images=args.render_mode == "dataset",
    )
    train_loader = DataLoader(
        train_ds,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=dataloader_pin_memory(device),
        drop_last=True,
        **loader_kwargs(args.num_workers),
    )
    # With drop_last=True a dataset smaller than one batch yields no batches,
    # and the ``while`` loop below would then never advance ``step``.
    if args.steps > 0 and len(train_loader) == 0:
        raise ValueError(
            f"training loader for {method!r} yields no batches "
            f"(dataset size {len(train_ds)}, batch size {args.batch_size}, drop_last=True)"
        )

    logged_losses = []
    model.train()
    step = 0
    # The running loss stays on the device so that the host only synchronizes
    # when a mean is actually logged.
    running_loss = torch.zeros((), device=device)
    running_count = 0
    saved_checkpoints: list[str] = []
    while step < args.steps:
        for batch in train_loader:
            step += 1
            images, actions, future_actions, targets, origin, _prev_origin, _flow_type_id, _boat_id = prepare_batch(
                batch, args, device
            )
            train_targets = encode_targets(targets, origin, args.target_mode)
            with autocast_context(device, args.precision):
                z, c = model.encode(images, actions)
                pred = rollout_from_encoded(model, z, c, future_actions)
                loss = weighted_pose_loss(pred.float(), train_targets.float(), args.heading_weight)
                if args.motion_weight > 0.0:
                    pred_abs = decode_predictions(pred.float(), origin, args.target_mode)
                    loss = loss + args.motion_weight * motion_delta_loss(pred_abs, targets)
                if args.current_pose_weight > 0.0:
                    # Auxiliary term: the un-rolled latent should decode to
                    # the current (origin) pose in absolute-normalized form.
                    current_target = encode_absolute_pose(origin)
                    loss = loss + args.current_pose_weight * weighted_pose_loss(
                        model.decoder(z).float(),
                        current_target,
                        args.heading_weight,
                    )
            optimizer.zero_grad(set_to_none=True)
            if scaler.is_enabled():
                scaler.scale(loss).backward()
                # Gradients must be unscaled before clipping by norm.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
                optimizer.step()
            running_loss = running_loss + loss.detach()
            running_count += 1
            # A non-positive log interval disables periodic logging (and
            # avoids a modulo by zero), mirroring ``checkpoint_interval``.
            if args.log_every > 0 and step % args.log_every == 0:
                mean_loss = float((running_loss / max(running_count, 1)).item())
                logged_losses.append(mean_loss)
                print(f"{method} step {step:05d} loss={mean_loss:.5f}", flush=True)
                with trace_path.open("a") as f:
                    f.write(json.dumps({"method": method, "step": int(step), "loss": mean_loss}) + "\n")
                running_loss = torch.zeros((), device=device)
                running_count = 0
            if args.checkpoint_interval > 0 and step % args.checkpoint_interval == 0:
                saved_checkpoints.append(save_eval_checkpoint(model, checkpoint_dir, args.checkpoint_name, step))
            if step >= args.steps:
                break
    # Flush the steps accumulated since the last periodic log entry.
    if running_count:
        mean_loss = float((running_loss / running_count).item())
        logged_losses.append(mean_loss)
        with trace_path.open("a") as f:
            f.write(json.dumps({"method": method, "step": int(step), "loss": mean_loss}) + "\n")
    final_checkpoint = save_eval_checkpoint(model, checkpoint_dir, args.checkpoint_name)
    counts = save_parameter_count(model, result_dir / "parameter_count.json")

    # Release the training data (and its workers) before building the test set.
    del train_loader, train_ds
    gc.collect()
    test_ds = ImageTrajectoryDataset(
        args.test_source,
        history_len=args.history_len,
        horizon=args.horizon,
        episodes=args.test_episodes,
        max_windows=args.test_windows,
        seed=args.seed + 1,
        image_size=args.image_size,
        visual_scale=args.visual_scale,
        return_aux=True,
        render_images=args.render_mode == "dataset",
    )
    test_loader = DataLoader(
        test_ds,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=dataloader_pin_memory(device),
        **loader_kwargs(args.num_workers),
    )
    metrics = evaluate_model(model, test_loader, device, args.horizon, args.target_mode, args)
    # No loss exists when zero steps were run; report NaN instead of crashing.
    final_train_loss = float(logged_losses[-1]) if logged_losses else float("nan")
    result = {
        "method": method,
        "steps": int(args.steps),
        "batch_size": int(args.batch_size),
        "train_samples": int(args.steps * args.batch_size),
        "final_train_loss": final_train_loss,
        "total_parameters": int(counts["total"]),
        "target_mode": args.target_mode,
        "position_scale": POSITION_SCALE,
        "heading_weight": float(args.heading_weight),
        "current_pose_weight": float(args.current_pose_weight),
        "motion_weight": float(args.motion_weight),
        "precision": args.precision,
        "checkpoint_name": args.checkpoint_name,
        "final_checkpoint": final_checkpoint,
        "intermediate_checkpoints": saved_checkpoints,
        "checkpoint_interval": int(args.checkpoint_interval),
        "prediction": metrics,
    }
    result_name = f"{Path(args.checkpoint_name).stem}_training.json"
    (result_dir / result_name).write_text(json.dumps(result, indent=2))
    del test_loader, test_ds, model
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()
    return result
@torch.no_grad()
def evaluate_model(model, loader, device: torch.device, horizon: int, target_mode: str, args) -> dict[str, float]:
    """Compute mean position and heading errors per prediction step.

    The model is switched to eval mode and rolled out on every batch of
    ``loader``. Position error is the Euclidean distance between predicted
    and target xy; heading error is the absolute wrapped angle difference.

    Returns:
        A dict with ``pos<k>`` and ``heading<k>`` entries for each reporting
        step ``k`` in 1, 3, 6, 8, 10, 20 that does not exceed ``horizon``.
    """
    model.eval()
    position_error_total = np.zeros(horizon, dtype=np.float64)
    heading_error_total = np.zeros(horizon, dtype=np.float64)
    sample_count = 0
    for batch in loader:
        images, actions, future_actions, targets, origin, *_aux = prepare_batch(batch, args, device)
        with autocast_context(device, args.precision):
            encoded = model.rollout(images, actions, future_actions)
        pred = decode_predictions(encoded.float(), origin, target_mode)

        xy_gap = pred[..., :2] - targets[..., :2]
        position_error = torch.linalg.norm(xy_gap, dim=-1)

        pred_angle = torch.atan2(pred[..., 3], pred[..., 2])
        target_angle = torch.atan2(targets[..., 3], targets[..., 2])
        angle_gap = pred_angle - target_angle
        # atan2(sin, cos) wraps the difference into (-pi, pi] before abs().
        heading_error = torch.atan2(torch.sin(angle_gap), torch.cos(angle_gap)).abs()

        position_error_total += position_error.sum(dim=0).cpu().numpy()
        heading_error_total += heading_error.sum(dim=0).cpu().numpy()
        sample_count += int(images.shape[0])

    pos_mean = position_error_total / sample_count
    heading_mean = heading_error_total / sample_count
    report: dict[str, float] = {}
    for step in (1, 3, 6, 8, 10, 20):
        if horizon < step:
            continue
        report[f"pos{step}"] = float(pos_mean[step - 1])
        report[f"heading{step}"] = float(heading_mean[step - 1])
    return report
def encode_targets(targets: torch.Tensor, origin: torch.Tensor, target_mode: str) -> torch.Tensor:
    """Convert absolute target poses into the training target encoding.

    ``"absolute_normalized"`` delegates to ``encode_absolute_pose``.
    ``"relative_motion"`` encodes the xy offset from ``origin`` divided by
    ``POSITION_SCALE``, followed by the cos/sin of the heading change
    relative to the origin heading.

    Raises:
        ValueError: If ``target_mode`` is neither of the two modes.
    """
    if target_mode == "absolute_normalized":
        return encode_absolute_pose(targets)
    if target_mode != "relative_motion":
        raise ValueError(f"unknown target_mode: {target_mode}")

    offset_xy = (targets[..., :2] - origin[:, None, :2]) / POSITION_SCALE
    origin_angle = torch.atan2(origin[:, 3], origin[:, 2])
    target_angle = torch.atan2(targets[..., 3], targets[..., 2])
    heading_change = target_angle - origin_angle[:, None]
    cos_sin = torch.stack((torch.cos(heading_change), torch.sin(heading_change)), dim=-1)
    return torch.cat((offset_xy, cos_sin), dim=-1)
def encode_absolute_pose(obs: torch.Tensor) -> torch.Tensor:
    """Normalize the xy part of a pose and keep components 2:4 unchanged.

    The first two components are mapped with
    ``(xy - POSITION_SCALE) / POSITION_SCALE``.
    """
    position, heading = obs[..., :2], obs[..., 2:4]
    centered = (position - POSITION_SCALE) / POSITION_SCALE
    return torch.cat((centered, heading), dim=-1)
def decode_predictions(predictions: torch.Tensor, origin: torch.Tensor, target_mode: str) -> torch.Tensor:
    """Map encoded predictions back to absolute poses.

    This inverts ``encode_targets``. In ``"absolute_normalized"`` mode xy is
    rescaled and shifted by ``POSITION_SCALE`` and components 2:4 pass
    through. In ``"relative_motion"`` mode xy is rescaled and offset by the
    origin position, and the predicted heading change is added to the origin
    heading and returned as cos/sin.

    Raises:
        ValueError: If ``target_mode`` is neither of the two modes.
    """
    if target_mode not in ("absolute_normalized", "relative_motion"):
        raise ValueError(f"unknown target_mode: {target_mode}")

    scaled_xy = predictions[..., :2] * POSITION_SCALE
    if target_mode == "absolute_normalized":
        return torch.cat((scaled_xy + POSITION_SCALE, predictions[..., 2:4]), dim=-1)

    absolute_xy = scaled_xy + origin[:, None, :2]
    origin_angle = torch.atan2(origin[:, 3], origin[:, 2])
    heading_change = torch.atan2(predictions[..., 3], predictions[..., 2])
    absolute_angle = origin_angle[:, None] + heading_change
    cos_sin = torch.stack((torch.cos(absolute_angle), torch.sin(absolute_angle)), dim=-1)
    return torch.cat((absolute_xy, cos_sin), dim=-1)
def weighted_pose_loss(predictions: torch.Tensor, targets: torch.Tensor, heading_weight: float) -> torch.Tensor:
    """Return position MSE plus ``heading_weight`` times heading MSE.

    Position is components 0:2 of the last dimension; heading is
    components 2:4.
    """
    position_span, heading_span = slice(0, 2), slice(2, 4)

    def span_mse(span: slice) -> torch.Tensor:
        return F.mse_loss(predictions[..., span], targets[..., span])

    return span_mse(position_span) + heading_weight * span_mse(heading_span)
def motion_delta_loss(predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Penalize mismatched step-to-step motion between two pose sequences.

    The loss is the MSE of consecutive xy displacements plus 0.2 times the
    MSE of consecutive wrapped heading changes, both taken along dimension 1.
    """

    def step_displacements(track: torch.Tensor) -> torch.Tensor:
        return track[:, 1:, :2] - track[:, :-1, :2]

    def wrapped_turns(track: torch.Tensor) -> torch.Tensor:
        angle = torch.atan2(track[..., 3], track[..., 2])
        change = angle[:, 1:] - angle[:, :-1]
        # atan2(sin, cos) wraps the change into (-pi, pi].
        return torch.atan2(torch.sin(change), torch.cos(change))

    displacement_term = F.mse_loss(step_displacements(predictions), step_displacements(targets))
    turn_term = F.mse_loss(wrapped_turns(predictions), wrapped_turns(targets))
    return displacement_term + 0.2 * turn_term
def main() -> None:
    """Parse the command line, train every requested method and print the results.

    Each method in ``--methods`` is trained and evaluated in order with
    ``train_method``; the collected result dictionaries are printed as JSON.
    """
    # (flag, add_argument keyword arguments), registered in this order.
    option_table = (
        ("--methods", {"nargs": "+", "default": METHODS}),
        ("--train-source", {"default": "data/paper/train.npz"}),
        ("--test-source", {"default": "data/paper/test.npz"}),
        ("--train-episodes", {"type": int, "default": 512}),
        ("--test-episodes", {"type": int, "default": 256}),
        ("--history-len", {"type": int, "default": 32}),
        ("--horizon", {"type": int, "default": 20}),
        ("--train-windows", {"type": int, "default": 65536}),
        ("--test-windows", {"type": int, "default": 8192}),
        ("--batch-size", {"type": int, "default": 64}),
        ("--steps", {"type": int, "default": 16000}),
        ("--lr", {"type": float, "default": 3e-4}),
        ("--log-every", {"type": int, "default": 200}),
        ("--seed", {"type": int, "default": 19}),
        ("--device", {"default": "cuda"}),
        ("--num-workers", {"type": int, "default": 4}),
        ("--episode-chunk-size", {"type": int, "default": 64}),
        ("--image-size", {"type": int, "default": 160}),
        ("--visual-scale", {"type": float, "default": 2.5}),
        ("--render-mode", {"choices": ["device", "dataset"], "default": "device"}),
        ("--precision", {"choices": ["fp32", "bf16", "fp16"], "default": "fp32"}),
        (
            "--target-mode",
            {"choices": ["absolute_normalized", "relative_motion"], "default": "absolute_normalized"},
        ),
        ("--heading-weight", {"type": float, "default": 2.0}),
        ("--current-pose-weight", {"type": float, "default": 1.0}),
        ("--motion-weight", {"type": float, "default": 0.0}),
        ("--checkpoint-name", {"default": "paper.pt"}),
        ("--checkpoint-interval", {"type": int, "default": 2000}),
    )
    parser = argparse.ArgumentParser()
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    results = []
    for method in args.methods:
        results.append(train_method(method, args))
    print(json.dumps(results, indent=2))
# Run the training sweep only when this file is executed as a script.
if __name__ == "__main__":
    main()