import torch

from minidreamer.models.world_model import WorldModel


def test_world_model_sequence_shapes_and_loss():
    """Smoke-test ``WorldModel`` sequence outputs and loss computation.

    Builds a small decoder-enabled model, runs ``observe_sequence`` on random
    image observations and random discrete actions, and checks the shapes of
    the reward, done, prior-mean and reconstruction outputs. Then runs
    ``compute_losses`` on an all-zero reward/done batch with an all-ones mask
    and checks that the value stored under ``"loss"`` is finite.
    """
    # Fixed seed so parameter init and the random inputs are reproducible.
    torch.manual_seed(0)

    batch_size, horizon = 4, 32
    action_dim, stoch_dim = 7, 16

    model = WorldModel(
        action_dim=action_dim,
        embedding_dim=128,
        deter_dim=128,
        stoch_dim=stoch_dim,
        hidden_dim=128,
        use_decoder=True,
    )

    # Observations carry one more frame than there are actions (33 vs 32);
    # presumably the extra frame is the initial observation -- confirm against
    # WorldModel.observe_sequence.
    obs = torch.rand(batch_size, horizon + 1, 3, 64, 64)
    actions = torch.randint(0, action_dim, (batch_size, horizon))

    outputs = model.observe_sequence(obs, actions, sample=False)

    # Per-step outputs are expected to have one entry per action.
    assert outputs.reward_pred.shape == (batch_size, horizon)
    assert outputs.done_logits.shape == (batch_size, horizon)
    assert outputs.prior_mean.shape == (batch_size, horizon, stoch_dim)
    # The model was built with use_decoder=True, so reconstructions must exist.
    assert outputs.reconstructions is not None
    assert outputs.reconstructions.shape == (batch_size, horizon, 3, 64, 64)

    batch = {
        "obs": obs,
        "actions": actions,
        "rewards": torch.zeros(batch_size, horizon),
        "done": torch.zeros(batch_size, horizon),
        "mask": torch.ones(batch_size, horizon),
    }
    config = {
        "training": {
            "beta_reward": 1.0,
            "beta_done": 1.0,
            "beta_kl": 1.0,
            "beta_recon": 1.0,
            "free_nats": 1.0,
        }
    }

    losses = model.compute_losses(batch, config)
    assert torch.isfinite(losses["loss"])