from __future__ import annotations import importlib import numpy as np import torch import torch.nn.functional as F from experiments.shared.src.methods import PAPER_LEARNED_METHODS from experiments.shared.src.vision.clean_renderer import render_clean_boat_array def make_batch( batch_size: int = 2, history_len: int = 32, horizon: int = 4, image_size: int = 160, action_dim: int = 3, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: rng = np.random.default_rng(11) images = np.zeros((batch_size, history_len, 3, image_size, image_size), dtype=np.uint8) actions = rng.uniform(-0.3, 0.3, size=(batch_size, history_len, action_dim)).astype(np.float32) future_actions = rng.uniform(-0.3, 0.3, size=(batch_size, horizon, action_dim)).astype(np.float32) targets = np.zeros((batch_size, horizon, 4), dtype=np.float32) for b in range(batch_size): x = float(rng.uniform(2.0, 8.0)) y = float(rng.uniform(2.0, 8.0)) theta = float(rng.uniform(-np.pi, np.pi)) for t in range(history_len): state = np.array([x + 0.01 * t, y, theta + 0.005 * t, 0.0, 0.0, 0.0], dtype=np.float32) image = render_clean_boat_array(state, "twin", image_size=image_size, visual_scale=2.5) images[b, t] = np.transpose(image, (2, 0, 1)) for h in range(horizon): theta_h = theta + 0.005 * (history_len + h) targets[b, h] = np.array( [x + 0.01 * (history_len + h), y, np.cos(theta_h), np.sin(theta_h)], dtype=np.float32, ) return ( torch.from_numpy(images), torch.from_numpy(actions), torch.from_numpy(future_actions), torch.from_numpy(targets), ) def test_all_paper_image_world_models_train_step() -> None: batch = make_batch() images, actions, future_actions, targets = batch for method in PAPER_LEARNED_METHODS: cfg = importlib.import_module(f"experiments.{method}.src.config").default_config() model = importlib.import_module(f"experiments.{method}.src.model").build_model(cfg) pred = model.rollout(images, actions, future_actions) assert pred.shape == targets.shape loss = F.mse_loss(pred, targets) loss.backward() assert all(p.grad is not None for p in model.parameters() if p.requires_grad) def test_flowmo_context_residual_is_action_separated() -> None: cfg = importlib.import_module("experiments.flowmo.src.config").default_config() model = importlib.import_module("experiments.flowmo.src.model").build_model(cfg) z = torch.randn(4, cfg.z_dim) c = torch.randn(4, cfg.c_dim) a1 = torch.randn(4, cfg.action_dim).clamp(-1.0, 1.0) a2 = torch.randn(4, cfg.action_dim).clamp(-1.0, 1.0) z1 = model.step(z, a1, c) - model.step(z, a1, torch.zeros_like(c)) z2 = model.step(z, a2, c) - model.step(z, a2, torch.zeros_like(c)) assert torch.allclose(z1, z2, atol=1e-6)