| from __future__ import annotations |
|
|
| import importlib |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
|
|
| from experiments.shared.src.methods import PAPER_LEARNED_METHODS |
| from experiments.shared.src.vision.clean_renderer import render_clean_boat_array |
|
|
|
|
def make_batch(
    batch_size: int = 2,
    history_len: int = 32,
    horizon: int = 4,
    image_size: int = 160,
    action_dim: int = 3,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build a small, deterministic synthetic batch for world-model smoke tests.

    Each sample starts from a random pose ``(x0, y0, theta0)``. History frame
    ``t`` renders the state ``[x0 + 0.01 * t, y0, theta0 + 0.005 * t, 0, 0, 0]``
    with ``render_clean_boat_array``. Target step ``h`` continues that same
    linear trajectory at time ``history_len + h``.

    Returns:
        ``(images, actions, future_actions, targets)``:

        - ``images``: uint8, shape ``(batch_size, history_len, 3, image_size, image_size)``.
        - ``actions``: float32, shape ``(batch_size, history_len, action_dim)``,
          uniform in ``[-0.3, 0.3)``.
        - ``future_actions``: float32, shape ``(batch_size, horizon, action_dim)``,
          uniform in ``[-0.3, 0.3)``.
        - ``targets``: float32, shape ``(batch_size, horizon, 4)``, laid out as
          ``[x, y, cos(theta), sin(theta)]``.
    """
    generator = np.random.default_rng(11)
    # The draw order (past actions, future actions, then per-sample poses)
    # fixes the exact values produced for the seed above; keep it stable.
    frames = np.zeros((batch_size, history_len, 3, image_size, image_size), dtype=np.uint8)
    past_actions = generator.uniform(-0.3, 0.3, size=(batch_size, history_len, action_dim)).astype(np.float32)
    next_actions = generator.uniform(-0.3, 0.3, size=(batch_size, horizon, action_dim)).astype(np.float32)
    poses = np.zeros((batch_size, horizon, 4), dtype=np.float32)

    for sample in range(batch_size):
        x0 = float(generator.uniform(2.0, 8.0))
        y0 = float(generator.uniform(2.0, 8.0))
        heading0 = float(generator.uniform(-np.pi, np.pi))

        for step in range(history_len):
            frame_state = np.array(
                [x0 + 0.01 * step, y0, heading0 + 0.005 * step, 0.0, 0.0, 0.0],
                dtype=np.float32,
            )
            rendered = render_clean_boat_array(frame_state, "twin", image_size=image_size, visual_scale=2.5)
            # Presumably the renderer returns (H, W, C); move channels first.
            frames[sample, step] = np.transpose(rendered, (2, 0, 1))

        for offset in range(horizon):
            elapsed = history_len + offset
            heading = heading0 + 0.005 * elapsed
            poses[sample, offset] = (x0 + 0.01 * elapsed, y0, np.cos(heading), np.sin(heading))

    return tuple(torch.from_numpy(arr) for arr in (frames, past_actions, next_actions, poses))
|
|
|
|
def test_all_paper_image_world_models_train_step() -> None:
    """Smoke-test one forward/backward pass for every paper world model.

    For each method in ``PAPER_LEARNED_METHODS``, the test:

    - builds the model from the method's default config;
    - rolls it out on the synthetic batch from ``make_batch``;
    - backpropagates an MSE loss against the target poses.

    It checks that the prediction shape matches the targets, that the loss is
    finite, and that every trainable parameter receives a gradient. Failure
    messages name the offending method.
    """
    # Guard against a vacuous pass if the method registry is ever empty.
    assert PAPER_LEARNED_METHODS, "PAPER_LEARNED_METHODS is empty; no model was tested"
    images, actions, future_actions, targets = make_batch()
    for method in PAPER_LEARNED_METHODS:
        # Seed per model so a failing initialisation is reproducible.
        torch.manual_seed(0)
        cfg = importlib.import_module(f"experiments.{method}.src.config").default_config()
        model = importlib.import_module(f"experiments.{method}.src.model").build_model(cfg)

        pred = model.rollout(images, actions, future_actions)
        assert pred.shape == targets.shape, (
            f"{method}: prediction shape {tuple(pred.shape)} != target shape {tuple(targets.shape)}"
        )

        loss = F.mse_loss(pred, targets)
        assert torch.isfinite(loss), f"{method}: non-finite loss {loss.item()}"
        loss.backward()

        trainable = [(name, param) for name, param in model.named_parameters() if param.requires_grad]
        # all() over an empty list is True, so check explicitly that something is trainable.
        assert trainable, f"{method}: model has no trainable parameters"
        missing = [name for name, param in trainable if param.grad is None]
        assert not missing, f"{method}: no gradient reached parameters {missing}"
|
|
|
|
def test_flowmo_context_residual_is_action_separated() -> None:
    """Check that FlowMo's context contribution to ``step`` is independent of the action.

    The context residual for an action ``a`` is
    ``step(z, a, c) - step(z, a, 0)``. The test asserts that this residual is
    the same (within ``atol=1e-6``) for two different random actions.

    NOTE: the check is trivially satisfied if ``step`` ignores ``c`` entirely.
    It tests separation, not that the context has any effect.
    """
    # Seed so that model init and the random probes are reproducible on failure.
    torch.manual_seed(0)
    cfg = importlib.import_module("experiments.flowmo.src.config").default_config()
    model = importlib.import_module("experiments.flowmo.src.model").build_model(cfg)

    z = torch.randn(4, cfg.z_dim)
    c = torch.randn(4, cfg.c_dim)
    no_context = torch.zeros_like(c)
    a1 = torch.randn(4, cfg.action_dim).clamp(-1.0, 1.0)
    a2 = torch.randn(4, cfg.action_dim).clamp(-1.0, 1.0)

    residual_a1 = model.step(z, a1, c) - model.step(z, a1, no_context)
    residual_a2 = model.step(z, a2, c) - model.step(z, a2, no_context)
    assert torch.allclose(residual_a1, residual_a2, atol=1e-6), (
        "context residual depends on the action "
        f"(max abs diff {(residual_a1 - residual_a2).abs().max().item():.3e})"
    )
|
|