File size: 2,983 Bytes
604e535
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from __future__ import annotations

import importlib

import numpy as np
import torch
import torch.nn.functional as F

from experiments.shared.src.methods import PAPER_LEARNED_METHODS
from experiments.shared.src.vision.clean_renderer import render_clean_boat_array


def make_batch(
    batch_size: int = 2,
    history_len: int = 32,
    horizon: int = 4,
    image_size: int = 160,
    action_dim: int = 3,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build a small, seeded synthetic batch for world-model tests.

    Each trajectory starts at a random (x, y, heading). Over time x grows by
    0.01 per step and the heading by 0.005 per step, while y stays fixed.
    The first ``history_len`` poses are rendered into frames. The next
    ``horizon`` poses become the regression targets.

    Args:
        batch_size: Number of independent trajectories.
        history_len: Number of rendered context frames per trajectory.
        horizon: Number of future steps to predict.
        image_size: Side length in pixels of each square frame.
        action_dim: Width of the random action vectors.

    Returns:
        ``(images, actions, future_actions, targets)``:
            - ``images``: uint8 tensor of shape ``(B, history_len, 3, S, S)``.
            - ``actions``: float32 tensor of shape ``(B, history_len, action_dim)``.
            - ``future_actions``: float32 tensor of shape ``(B, horizon, action_dim)``.
            - ``targets``: float32 tensor of shape ``(B, horizon, 4)``.
        Each target row is ``[x, y, cos(heading), sin(heading)]``.
        Action entries are drawn uniformly from ``[-0.3, 0.3)``.
    """
    generator = np.random.default_rng(11)
    frames = np.zeros((batch_size, history_len, 3, image_size, image_size), dtype=np.uint8)
    # Draw order is significant: it fixes which seeded values each sample receives.
    past_actions = generator.uniform(-0.3, 0.3, size=(batch_size, history_len, action_dim)).astype(np.float32)
    next_actions = generator.uniform(-0.3, 0.3, size=(batch_size, horizon, action_dim)).astype(np.float32)
    goals = np.zeros((batch_size, horizon, 4), dtype=np.float32)

    for sample in range(batch_size):
        x0 = float(generator.uniform(2.0, 8.0))
        y0 = float(generator.uniform(2.0, 8.0))
        heading0 = float(generator.uniform(-np.pi, np.pi))

        for step in range(history_len):
            pose = np.array(
                [x0 + 0.01 * step, y0, heading0 + 0.005 * step, 0.0, 0.0, 0.0],
                dtype=np.float32,
            )
            rendered = render_clean_boat_array(pose, "twin", image_size=image_size, visual_scale=2.5)
            # Move the last axis to the front so the frame fits the (3, S, S) slot.
            frames[sample, step] = np.transpose(rendered, (2, 0, 1))

        # Future steps continue the same drift, indexed from the end of the history.
        for offset, step in enumerate(range(history_len, history_len + horizon)):
            heading = heading0 + 0.005 * step
            goals[sample, offset] = np.array(
                [x0 + 0.01 * step, y0, np.cos(heading), np.sin(heading)],
                dtype=np.float32,
            )

    return tuple(torch.from_numpy(arr) for arr in (frames, past_actions, next_actions, goals))


def test_all_paper_image_world_models_train_step() -> None:
    """Smoke-test one forward/backward pass for every paper image world model.

    For each method in ``PAPER_LEARNED_METHODS`` the test:
        1. builds the model from that method's ``default_config()``;
        2. rolls it out on a shared synthetic batch;
        3. scores the rollout with an MSE loss against the targets;
        4. back-propagates and checks that every trainable parameter got a
           gradient.

    Every assertion message names the method under test, so a failure inside
    the loop can be traced to its method directly.
    """
    # An empty registry would otherwise make the loop below pass vacuously.
    assert PAPER_LEARNED_METHODS, "PAPER_LEARNED_METHODS is empty; no model was exercised"
    images, actions, future_actions, targets = make_batch()
    for method in PAPER_LEARNED_METHODS:
        # Seed per method so weight init (and any failure) is reproducible and
        # does not depend on the order in which methods are visited.
        torch.manual_seed(0)
        cfg = importlib.import_module(f"experiments.{method}.src.config").default_config()
        model = importlib.import_module(f"experiments.{method}.src.model").build_model(cfg)
        pred = model.rollout(images, actions, future_actions)
        assert pred.shape == targets.shape, (
            f"{method}: rollout shape {tuple(pred.shape)} != target shape {tuple(targets.shape)}"
        )
        loss = F.mse_loss(pred, targets)
        # A NaN/inf loss would still back-propagate and populate .grad, hiding the problem.
        assert torch.isfinite(loss), f"{method}: non-finite loss {loss.item()}"
        loss.backward()
        missing = [
            name
            for name, param in model.named_parameters()
            if param.requires_grad and param.grad is None
        ]
        assert not missing, f"{method}: trainable parameters without gradient: {missing}"


def test_flowmo_context_residual_is_action_separated() -> None:
    """Check that FlowMo's context contribution to ``step`` does not depend on the action.

    The context residual is ``step(z, a, c) - step(z, a, 0)``. The test
    requires that residual to be the same (within ``atol=1e-6``) for two
    different random actions.
    """
    # Fixed seed: weights and inputs are reproducible, so the tight tolerance is not flaky.
    torch.manual_seed(0)
    cfg = importlib.import_module("experiments.flowmo.src.config").default_config()
    model = importlib.import_module("experiments.flowmo.src.model").build_model(cfg)
    # Switch off train-mode stochasticity (e.g. dropout) so the four step() calls
    # below are computed under identical conditions and can be compared.
    model.eval()
    z = torch.randn(4, cfg.z_dim)
    c = torch.randn(4, cfg.c_dim)
    no_context = torch.zeros_like(c)
    a1 = torch.randn(4, cfg.action_dim).clamp(-1.0, 1.0)
    a2 = torch.randn(4, cfg.action_dim).clamp(-1.0, 1.0)
    residual_a1 = model.step(z, a1, c) - model.step(z, a1, no_context)
    residual_a2 = model.step(z, a2, c) - model.step(z, a2, no_context)
    assert torch.allclose(residual_a1, residual_a2, atol=1e-6), (
        "context residual depends on the action "
        f"(max abs diff {(residual_a1 - residual_a2).abs().max().item():.3e})"
    )