# FlowMo-WM / tests / test_image_world_model_interfaces.py
# cccat6's picture
# Initial FlowMo-WM public code release
# 604e535 verified
from __future__ import annotations
import importlib
import numpy as np
import torch
import torch.nn.functional as F
from experiments.shared.src.methods import PAPER_LEARNED_METHODS
from experiments.shared.src.vision.clean_renderer import render_clean_boat_array
def make_batch(
    batch_size: int = 2,
    history_len: int = 32,
    horizon: int = 4,
    image_size: int = 160,
    action_dim: int = 3,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build a small, deterministic synthetic batch for the interface tests.

    Every sample starts from a random pose (x, y, heading) drawn from a
    generator seeded with 11.  Along the trajectory x advances by 0.01 and
    the heading by 0.005 per step while y stays fixed; the history part is
    rendered to images and the part after it becomes the regression target.

    Returns:
        A tuple ``(images, actions, future_actions, targets)`` of tensors:

        * images: uint8, ``(batch_size, history_len, 3, image_size, image_size)``
        * actions: float32, ``(batch_size, history_len, action_dim)``,
          uniform in [-0.3, 0.3)
        * future_actions: float32, ``(batch_size, horizon, action_dim)``,
          uniform in [-0.3, 0.3)
        * targets: float32, ``(batch_size, horizon, 4)`` holding
          ``(x, y, cos(heading), sin(heading))`` per future step
    """
    generator = np.random.default_rng(11)

    def sample_actions(steps: int) -> np.ndarray:
        # Same distribution for past and future actions; only the length differs.
        drawn = generator.uniform(-0.3, 0.3, size=(batch_size, steps, action_dim))
        return drawn.astype(np.float32)

    frames = np.zeros(
        (batch_size, history_len, 3, image_size, image_size), dtype=np.uint8
    )
    # Draw order matters for reproducibility: past actions, future actions,
    # then the per-sample start poses.
    past_actions = sample_actions(history_len)
    upcoming_actions = sample_actions(horizon)
    goals = np.zeros((batch_size, horizon, 4), dtype=np.float32)

    for sample_idx in range(batch_size):
        x0 = float(generator.uniform(2.0, 8.0))
        y0 = float(generator.uniform(2.0, 8.0))
        heading0 = float(generator.uniform(-np.pi, np.pi))

        for step in range(history_len):
            # Six-element state; the last three entries are always zero
            # (presumably velocities -- confirm against the renderer).
            pose = np.array(
                [x0 + 0.01 * step, y0, heading0 + 0.005 * step, 0.0, 0.0, 0.0],
                dtype=np.float32,
            )
            rendered = render_clean_boat_array(
                pose, "twin", image_size=image_size, visual_scale=2.5
            )
            # Move the last axis first so the frame fits the (3, H, W) slot.
            frames[sample_idx, step] = np.transpose(rendered, (2, 0, 1))

        # Targets continue the same motion right after the history window.
        future_steps = range(history_len, history_len + horizon)
        for offset, step in enumerate(future_steps):
            heading = heading0 + 0.005 * step
            goals[sample_idx, offset] = np.array(
                [x0 + 0.01 * step, y0, np.cos(heading), np.sin(heading)],
                dtype=np.float32,
            )

    arrays = (frames, past_actions, upcoming_actions, goals)
    images, actions, future_actions, targets = (
        torch.from_numpy(array) for array in arrays
    )
    return images, actions, future_actions, targets
def test_all_paper_image_world_models_train_step() -> None:
    """Check that every paper method supports a rollout plus backward pass.

    For each entry of ``PAPER_LEARNED_METHODS`` the method's default config
    and model are loaded from ``experiments.<method>.src``.  The model must
    produce a rollout shaped like the targets, and an MSE loss on that
    rollout must deliver a gradient to every trainable parameter.

    All methods share one loop, so each assertion message names the method
    (and, for the gradient check, the parameters) that failed.
    """
    images, actions, future_actions, targets = make_batch()

    for method in PAPER_LEARNED_METHODS:
        cfg = importlib.import_module(f"experiments.{method}.src.config").default_config()
        model = importlib.import_module(f"experiments.{method}.src.model").build_model(cfg)

        pred = model.rollout(images, actions, future_actions)
        assert pred.shape == targets.shape, (
            f"{method}: rollout returned shape {tuple(pred.shape)}, "
            f"expected {tuple(targets.shape)}"
        )

        loss = F.mse_loss(pred, targets)
        loss.backward()

        # Collect names rather than a bare all(): a failure then says which
        # parameters were left out of the computation graph.
        missing_grad = [
            name
            for name, param in model.named_parameters()
            if param.requires_grad and param.grad is None
        ]
        assert not missing_grad, (
            f"{method}: no gradient reached trainable parameters {missing_grad}"
        )
def test_flowmo_context_residual_is_action_separated() -> None:
    """Check that the context's effect on a FlowMo step ignores the action.

    The context residual is ``step(z, a, c) - step(z, a, 0)``.  It is
    evaluated for two independently drawn actions on the same ``z`` and
    ``c`` and must agree within ``atol=1e-6``.

    Inputs come from a locally seeded generator so a failure can be replayed
    with the same tensors without touching the global RNG state.  Model
    weights are still whatever ``build_model`` initialises them to.
    """
    cfg = importlib.import_module("experiments.flowmo.src.config").default_config()
    model = importlib.import_module("experiments.flowmo.src.model").build_model(cfg)

    gen = torch.Generator().manual_seed(0)
    z = torch.randn(4, cfg.z_dim, generator=gen)
    c = torch.randn(4, cfg.c_dim, generator=gen)
    a1 = torch.randn(4, cfg.action_dim, generator=gen).clamp(-1.0, 1.0)
    a2 = torch.randn(4, cfg.action_dim, generator=gen).clamp(-1.0, 1.0)
    no_context = torch.zeros_like(c)

    residual_a1 = model.step(z, a1, c) - model.step(z, a1, no_context)
    residual_a2 = model.step(z, a2, c) - model.step(z, a2, no_context)

    # The message is only evaluated on failure, so the max() never runs on
    # a passing (possibly empty) comparison.
    assert torch.allclose(residual_a1, residual_a2, atol=1e-6), (
        "context residual depends on the action: max abs deviation "
        f"{(residual_a1 - residual_a2).abs().max().item():.3e}"
    )