# Page/commit header residue, kept as comments so the module parses:
# guychuk's picture
# Add RSSM dynamics model
# fe76377 verified
"""RSSM (Recurrent State-Space Model) for Grid-JEPA."""
from typing import Dict, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import OneHotCategorical
class RSSM(nn.Module):
    """Recurrent state-space model with a GRU state and categorical latents.

    The model keeps two pieces of state per batch element:

    * ``h`` -- the GRU hidden state, shape ``(B, hidden_dim)``.
    * ``z`` -- ``latent_dim`` one-hot categorical variables with
      ``latent_classes`` classes each, flattened to
      ``(B, latent_dim * latent_classes)``.

    Args:
        embed_dim: Accepted but not used anywhere in this class.
        latent_dim: Number of categorical latent variables.
        latent_classes: Number of classes per categorical latent.
        hidden_dim: Size of the GRU hidden state and width of every MLP head.
        action_dim: Size of the learned action embedding.
        num_actions: Number of discrete actions (rows of the embedding table).
        obs_dim: Size of the observation feature vector given to ``observe``.
        reward_bins: Number of outputs of the reward head.
    """

    def __init__(self, embed_dim: int = 384, latent_dim: int = 32, latent_classes: int = 32,
                 hidden_dim: int = 256, action_dim: int = 64, num_actions: int = 6,
                 obs_dim: int = 384, reward_bins: int = 41) -> None:
        super().__init__()
        self.latent_dim = latent_dim
        self.latent_classes = latent_classes
        self.hidden_dim = hidden_dim

        stoch_dim = latent_dim * latent_classes  # flattened size of z
        feat_dim = hidden_dim + stoch_dim        # size of cat([h, z])

        # Submodule names and creation order are kept stable so checkpoints
        # (state_dict keys) and seeded initialisation stay unchanged.
        self.action_embed = nn.Embedding(num_actions, action_dim)
        self.seq_model = nn.GRUCell(stoch_dim + action_dim, hidden_dim)
        self.encoder = self._mlp(obs_dim + hidden_dim, hidden_dim, stoch_dim)
        self.dynamics = self._mlp(hidden_dim, hidden_dim, stoch_dim)
        self.reward_predictor = self._mlp(feat_dim, hidden_dim, reward_bins)
        self.continue_predictor = self._mlp(feat_dim, hidden_dim, 1)

    @staticmethod
    def _mlp(in_dim: int, hidden_dim: int, out_dim: int) -> nn.Sequential:
        """Build the Linear -> ELU -> Linear stack shared by every head."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim), nn.ELU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def init_state(self, B: int, device) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return all-zero initial ``(h, z)`` for a batch of size ``B`` on ``device``."""
        h_0 = torch.zeros(B, self.hidden_dim, device=device)
        z_0 = torch.zeros(B, self.latent_dim * self.latent_classes, device=device)
        return h_0, z_0

    def _sample_st(self, logits: torch.Tensor) -> torch.Tensor:
        """Sample one-hot vectors over the last dim with a straight-through gradient.

        The forward value is exactly the one-hot sample; the backward pass
        treats the output as ``logits`` (identity gradient).
        """
        sample = OneHotCategorical(logits=logits).sample()
        # Parenthesised so the zero-valued gradient carrier is formed first.
        # ``(sample + logits) - logits`` rounds in floating point and yields
        # values such as 0.99999994 instead of an exact 1.0.
        # NOTE(review): PyTorch's OneHotCategoricalStraightThrough routes the
        # gradient through ``probs`` rather than raw logits; the logits path
        # is kept here to preserve existing training behaviour -- confirm.
        return sample + (logits - logits.detach())

    def _advance(self, h: torch.Tensor, z: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
        """Run one GRU step on ``cat([z, embed(a)])`` starting from ``h``."""
        gru_in = torch.cat([z, self.action_embed(a)], dim=-1)
        return self.seq_model(gru_in, h)

    def _unflatten(self, flat: torch.Tensor, B: int) -> torch.Tensor:
        """Reshape flat latent logits to ``(B, latent_dim, latent_classes)``."""
        return flat.reshape(B, self.latent_dim, self.latent_classes)

    def observe(self, x_t: torch.Tensor, a_tm1: torch.Tensor, h_tm1: torch.Tensor,
                z_tm1: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Advance one step using the current observation.

        Args:
            x_t: Observation features, ``(B, obs_dim)``.
            a_tm1: Previous action indices, ``(B,)``, integer dtype
                (they index the action embedding).
            h_tm1: Previous hidden state, ``(B, hidden_dim)``.
            z_tm1: Previous flattened latent, ``(B, latent_dim * latent_classes)``.

        Returns:
            ``(h_t, z_t, z_logits, z_pred_logits)`` where ``z_t`` is sampled
            from the encoder logits ``z_logits`` (which see ``x_t``), and
            ``z_pred_logits`` come from the dynamics head (hidden state only).
            Both logit tensors have shape ``(B, latent_dim, latent_classes)``.
        """
        B = x_t.shape[0]
        h_t = self._advance(h_tm1, z_tm1, a_tm1)
        z_pred_logits = self._unflatten(self.dynamics(h_t), B)
        z_logits = self._unflatten(self.encoder(torch.cat([x_t, h_t], dim=-1)), B)
        z_t = self._sample_st(z_logits).reshape(B, -1)
        return h_t, z_t, z_logits, z_pred_logits

    def imagine(self, h_t: torch.Tensor, z_t: torch.Tensor,
                a_t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Advance one step without an observation.

        The next latent is sampled from the dynamics head instead of the
        encoder.

        Returns:
            ``(h_tp1, z_tp1, z_pred_logits)``; the logits have shape
            ``(B, latent_dim, latent_classes)``.
        """
        B = h_t.shape[0]
        h_tp1 = self._advance(h_t, z_t, a_t)
        z_pred_logits = self._unflatten(self.dynamics(h_tp1), B)
        z_tp1 = self._sample_st(z_pred_logits).reshape(B, -1)
        return h_tp1, z_tp1, z_pred_logits

    def predict_reward(self, h: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """Return the raw reward-head output for ``cat([h, z])``, last dim ``reward_bins``.

        NOTE(review): presumably logits over discretised reward bins --
        confirm against the loss that consumes them.
        """
        return self.reward_predictor(torch.cat([h, z], dim=-1))

    def predict_continue(self, h: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """Return the raw continue-head output for ``cat([h, z])``, last dim 1.

        No activation is applied here.
        """
        return self.continue_predictor(torch.cat([h, z], dim=-1))

    def rollout(self, h_0: torch.Tensor, z_0: torch.Tensor,
                actions: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Imagine ``H`` steps from ``(h_0, z_0)`` following ``actions``.

        Args:
            h_0: Starting hidden state, ``(B, hidden_dim)``.
            z_0: Starting flattened latent.
            actions: Action indices, ``(B, H)``; column ``t`` is applied at step ``t``.

        Returns:
            Dict with keys ``"h_states"``, ``"z_states"``, ``"rewards"`` and
            ``"continues"``; each value stacks the per-step tensors along
            dim 1 (the starting state itself is not included).

        Raises:
            RuntimeError: If ``H == 0`` (``torch.stack`` rejects an empty list).
        """
        B, H = actions.shape
        h_states, z_states, rewards, continues = [], [], [], []
        h_t, z_t = h_0, z_0
        for t in range(H):
            h_t, z_t, _ = self.imagine(h_t, z_t, actions[:, t])
            h_states.append(h_t)
            z_states.append(z_t)
            rewards.append(self.predict_reward(h_t, z_t))
            continues.append(self.predict_continue(h_t, z_t))
        return {
            "h_states": torch.stack(h_states, 1),
            "z_states": torch.stack(z_states, 1),
            "rewards": torch.stack(rewards, 1),
            "continues": torch.stack(continues, 1),
        }
if __name__ == "__main__":
    # Smoke test: push random tensors through every public method and print
    # the resulting shapes. Only shapes are reported, not values.
    B = 2
    obs = 128
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"

    rssm = RSSM(obs_dim=obs).to(device)

    # One random observation and one random action index per batch element.
    x = torch.randn(B, obs, device=device)
    a = torch.randint(0, 6, (B,), device=device)

    # Single observed step starting from the zero state.
    state = rssm.init_state(B, device)
    h, z, z_logits, z_pred = rssm.observe(x, a, state[0], state[1])
    print(f"h: {h.shape}, z: {z.shape}, z_logits: {z_logits.shape}")

    # Single imagined step from the observed state.
    h2, z2, zpl = rssm.imagine(h, z, a)
    print(f"h2: {h2.shape}, z2: {z2.shape}")

    # Prediction heads on the observed state.
    print(f"r: {rssm.predict_reward(h, z).shape}, c: {rssm.predict_continue(h, z).shape}")

    # Ten-step imagined rollout with random actions.
    actions = torch.randint(0, 6, (B, 10), device=device)
    out = rssm.rollout(h, z, actions)
    print(f"Rollout: h={out['h_states'].shape}, z={out['z_states'].shape}, r={out['rewards'].shape}, c={out['continues'].shape}")