"""RSSM (Recurrent State-Space Model) for Grid-JEPA."""

from typing import Dict, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import OneHotCategorical


class RSSM(nn.Module):
    """Recurrent state-space model with a discrete (one-hot) latent.

    The model state is a pair ``(h, z)``:

    * ``h`` -- deterministic GRU hidden state, shape ``(B, hidden_dim)``.
    * ``z`` -- flattened stochastic latent, shape
      ``(B, latent_dim * latent_classes)``: ``latent_dim`` one-hot vectors of
      ``latent_classes`` entries each, concatenated.

    ``observe`` performs a step whose latent is sampled from the ``encoder``
    head (conditioned on the observation ``x_t`` and ``h``). ``imagine``
    performs a step whose latent is sampled from the ``dynamics`` head
    (conditioned on ``h`` only). ``rollout`` chains ``imagine`` over a
    sequence of actions and applies the reward / continue heads at every step.

    Args:
        embed_dim: Not used by this class; accepted so existing callers that
            pass it keep working.
        latent_dim: Number of categorical variables in ``z``.
        latent_classes: Number of classes per categorical variable.
        hidden_dim: Size of the GRU state and of every MLP hidden layer.
        action_dim: Size of the learned action embedding.
        num_actions: Number of discrete actions (rows of the embedding table).
        obs_dim: Feature size of the observation input ``x_t``.
        reward_bins: Output width of the reward head.
    """

    def __init__(
        self,
        embed_dim: int = 384,
        latent_dim: int = 32,
        latent_classes: int = 32,
        hidden_dim: int = 256,
        action_dim: int = 64,
        num_actions: int = 6,
        obs_dim: int = 384,
        reward_bins: int = 41,
    ) -> None:
        super().__init__()
        self.latent_dim = latent_dim
        self.latent_classes = latent_classes
        self.hidden_dim = hidden_dim
        z_flat = latent_dim * latent_classes

        # NOTE: submodule names, creation order and Sequential layouts are kept
        # unchanged so state_dict keys stay compatible with saved checkpoints.
        self.action_embed = nn.Embedding(num_actions, action_dim)
        self.seq_model = nn.GRUCell(z_flat + action_dim, hidden_dim)
        # Latent logits from observation features concatenated with h.
        self.encoder = nn.Sequential(
            nn.Linear(obs_dim + hidden_dim, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, z_flat),
        )
        # Latent logits from h alone (no observation needed).
        self.dynamics = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, z_flat),
        )
        self.reward_predictor = nn.Sequential(
            nn.Linear(hidden_dim + z_flat, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, reward_bins),
        )
        self.continue_predictor = nn.Sequential(
            nn.Linear(hidden_dim + z_flat, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, 1),
        )

    def init_state(self, B: int, device) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return an all-zero ``(h, z)`` state for a batch of size ``B``.

        Shapes are ``(B, hidden_dim)`` and ``(B, latent_dim * latent_classes)``.
        """
        h0 = torch.zeros(B, self.hidden_dim, device=device)
        z0 = torch.zeros(B, self.latent_dim * self.latent_classes, device=device)
        return h0, z0

    def _sample_st(self, logits: torch.Tensor) -> torch.Tensor:
        """Draw a one-hot sample with a straight-through gradient.

        Args:
            logits: Unnormalised class scores; the last dim indexes classes.

        Returns:
            A tensor shaped like ``logits`` whose forward value is exactly a
            one-hot sample from ``OneHotCategorical(logits=logits)``. Its
            gradient is that of the softmax probabilities.
        """
        sample = OneHotCategorical(logits=logits).sample()  # no gradient
        probs = F.softmax(logits, dim=-1)
        # The correction term is parenthesised so it is exactly 0 in the
        # forward pass. The former ``sample + logits - logits.detach()`` was
        # evaluated left to right and could round the sample away when
        # |logits| is large. Routing the gradient through the probabilities
        # (rather than raw logits) keeps the softmax Jacobian in the estimator.
        return sample + (probs - probs.detach())

    def _advance(
        self, h_prev: torch.Tensor, z_prev: torch.Tensor, a_prev: torch.Tensor
    ) -> torch.Tensor:
        """Run one GRU step on ``[z_prev, embed(a_prev)]`` from state ``h_prev``."""
        gru_in = torch.cat([z_prev, self.action_embed(a_prev)], dim=-1)
        return self.seq_model(gru_in, h_prev)

    def _as_grid(self, flat: torch.Tensor, batch: int) -> torch.Tensor:
        """Reshape flat latent logits to ``(batch, latent_dim, latent_classes)``."""
        return flat.reshape(batch, self.latent_dim, self.latent_classes)

    def observe(
        self,
        x_t: torch.Tensor,
        a_tm1: torch.Tensor,
        h_tm1: torch.Tensor,
        z_tm1: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Advance one step, sampling the latent from the observation encoder.

        Args:
            x_t: Observation features, shape ``(B, obs_dim)``.
            a_tm1: Integer action indices of the previous step, shape ``(B,)``.
            h_tm1: Previous hidden state, shape ``(B, hidden_dim)``.
            z_tm1: Previous flat latent, shape ``(B, latent_dim * latent_classes)``.

        Returns:
            ``(h_t, z_t, z_logits, z_pred_logits)``. Here ``h_t`` is the new
            hidden state and ``z_t`` the flat one-hot latent sampled from the
            encoder logits. ``z_logits`` (encoder) and ``z_pred_logits``
            (dynamics head) both have shape
            ``(B, latent_dim, latent_classes)``.
        """
        B = x_t.shape[0]
        h_t = self._advance(h_tm1, z_tm1, a_tm1)
        z_pred_logits = self._as_grid(self.dynamics(h_t), B)
        z_logits = self._as_grid(self.encoder(torch.cat([x_t, h_t], dim=-1)), B)
        z_t = self._sample_st(z_logits).reshape(B, -1)
        return h_t, z_t, z_logits, z_pred_logits

    def imagine(
        self, h_t: torch.Tensor, z_t: torch.Tensor, a_t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Advance one step without an observation, sampling from ``dynamics``.

        Args:
            h_t: Current hidden state, shape ``(B, hidden_dim)``.
            z_t: Current flat latent, shape ``(B, latent_dim * latent_classes)``.
            a_t: Integer action indices, shape ``(B,)``.

        Returns:
            ``(h_tp1, z_tp1, z_pred_logits)``: the next hidden state, the next
            flat one-hot latent and the dynamics logits of shape
            ``(B, latent_dim, latent_classes)`` it was sampled from.
        """
        B = h_t.shape[0]
        h_tp1 = self._advance(h_t, z_t, a_t)
        z_pred_logits = self._as_grid(self.dynamics(h_tp1), B)
        z_tp1 = self._sample_st(z_pred_logits).reshape(B, -1)
        return h_tp1, z_tp1, z_pred_logits

    def predict_reward(self, h: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """Return the reward head's raw output, shape ``(B, reward_bins)``.

        NOTE(review): presumably logits over discretised reward bins; confirm
        against the loss that consumes them.
        """
        return self.reward_predictor(torch.cat([h, z], dim=-1))

    def predict_continue(self, h: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """Return the continue head's raw (un-squashed) output, shape ``(B, 1)``.

        NOTE(review): presumably a Bernoulli logit; confirm against the loss.
        """
        return self.continue_predictor(torch.cat([h, z], dim=-1))

    def rollout(
        self, h_0: torch.Tensor, z_0: torch.Tensor, actions: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """Imagine ``H`` steps from ``(h_0, z_0)`` under a fixed action sequence.

        Args:
            h_0: Start hidden state, shape ``(B, hidden_dim)``.
            z_0: Start flat latent, shape ``(B, latent_dim * latent_classes)``.
            actions: Integer action indices, shape ``(B, H)``.

        Returns:
            A dict with time on dim 1. ``"h_states"`` is ``(B, H, hidden_dim)``,
            ``"z_states"`` is ``(B, H, latent_dim * latent_classes)``,
            ``"rewards"`` is ``(B, H, reward_bins)`` and ``"continues"`` is
            ``(B, H, 1)``. Every entry describes the state *after* each action.
            When ``H == 0`` every entry is an empty tensor with a time dim of 0
            (stacking an empty list would otherwise raise).
        """
        B, H = actions.shape
        if H == 0:
            return {
                "h_states": h_0.new_zeros(B, 0, h_0.shape[-1]),
                "z_states": z_0.new_zeros(B, 0, z_0.shape[-1]),
                "rewards": h_0.new_zeros(
                    B, 0, self.reward_predictor[-1].out_features
                ),
                "continues": h_0.new_zeros(
                    B, 0, self.continue_predictor[-1].out_features
                ),
            }

        h_states, z_states, rewards, continues = [], [], [], []
        h_t, z_t = h_0, z_0
        for a_t in actions.unbind(dim=1):
            h_t, z_t, _ = self.imagine(h_t, z_t, a_t)
            h_states.append(h_t)
            z_states.append(z_t)
            rewards.append(self.predict_reward(h_t, z_t))
            continues.append(self.predict_continue(h_t, z_t))
        return {
            "h_states": torch.stack(h_states, 1),
            "z_states": torch.stack(z_states, 1),
            "rewards": torch.stack(rewards, 1),
            "continues": torch.stack(continues, 1),
        }


if __name__ == "__main__":
    B, obs = 2, 128
    device = "cuda" if torch.cuda.is_available() else "cpu"
    rssm = RSSM(obs_dim=obs).to(device)
    x = torch.randn(B, obs, device=device)
    a = torch.randint(0, 6, (B,), device=device)
    h, z = rssm.init_state(B, device)
    h, z, z_logits, z_pred = rssm.observe(x, a, h, z)
    print(f"h: {h.shape}, z: {z.shape}, z_logits: {z_logits.shape}")
    h2, z2, zpl = rssm.imagine(h, z, a)
    print(f"h2: {h2.shape}, z2: {z2.shape}")
    print(f"r: {rssm.predict_reward(h, z).shape}, c: {rssm.predict_continue(h, z).shape}")
    actions = torch.randint(0, 6, (B, 10), device=device)
    out = rssm.rollout(h, z, actions)
    print(f"Rollout: h={out['h_states'].shape}, z={out['z_states'].shape}, r={out['rewards'].shape}, c={out['continues'].shape}")