| """RSSM (Recurrent State-Space Model) for Grid-JEPA.""" |
| from typing import Dict, Tuple |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.distributions import OneHotCategorical |
|
|
class RSSM(nn.Module):
    """Recurrent State-Space Model with a discrete (categorical) stochastic latent.

    The model state at each step is a pair ``(h, z)``:

    * ``h`` -- deterministic recurrent state of width ``hidden_dim``, updated by a
      ``GRUCell`` from the previous ``z`` concatenated with an embedded action.
    * ``z`` -- stochastic state made of ``latent_dim`` one-hot categorical
      variables with ``latent_classes`` classes each, flattened to a vector of
      width ``latent_dim * latent_classes``.

    Two heads produce logits for ``z``:

    * ``dynamics`` reads ``h`` alone (the prior, used when imagining).
    * ``encoder`` reads ``h`` plus an observation embedding (the posterior, used
      when observing).

    The reward and continue heads read the concatenation ``[h, z]``.

    Args:
        embed_dim: Not referenced anywhere in this module. It is kept only for
            signature compatibility; the encoder's observation width is
            controlled by ``obs_dim``.
        latent_dim: Number of categorical variables in ``z``.
        latent_classes: Number of classes per categorical variable.
        hidden_dim: Width of the GRU state and of every MLP hidden layer.
        action_dim: Width of the learned action embedding.
        num_actions: Number of discrete actions (rows of the embedding table).
        obs_dim: Width of the observation embedding passed to :meth:`observe`.
        reward_bins: Output width of the reward head.
    """

    def __init__(self, embed_dim: int = 384, latent_dim: int = 32, latent_classes: int = 32,
                 hidden_dim: int = 256, action_dim: int = 64, num_actions: int = 6,
                 obs_dim: int = 384, reward_bins: int = 41):
        super().__init__()
        self.latent_dim = latent_dim
        self.latent_classes = latent_classes
        self.hidden_dim = hidden_dim
        stoch_dim = latent_dim * latent_classes  # width of the flattened one-hot z

        # Submodules are created in the same order as before, so parameter
        # initialisation (RNG consumption) and state_dict layout are unchanged.
        self.action_embed = nn.Embedding(num_actions, action_dim)
        self.seq_model = nn.GRUCell(stoch_dim + action_dim, hidden_dim)
        # Posterior head: [observation embedding, h] -> z logits.
        self.encoder = nn.Sequential(
            nn.Linear(obs_dim + hidden_dim, hidden_dim), nn.ELU(),
            nn.Linear(hidden_dim, stoch_dim),
        )
        # Prior head: h -> z logits.
        self.dynamics = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), nn.ELU(),
            nn.Linear(hidden_dim, stoch_dim),
        )
        # Raw (un-normalised) outputs of width reward_bins. Presumably logits of
        # a discretised reward distribution -- confirm against the training loss.
        self.reward_predictor = nn.Sequential(
            nn.Linear(hidden_dim + stoch_dim, hidden_dim), nn.ELU(),
            nn.Linear(hidden_dim, reward_bins),
        )
        # A single raw output per state. Presumably a Bernoulli logit -- confirm
        # against the training loss.
        self.continue_predictor = nn.Sequential(
            nn.Linear(hidden_dim + stoch_dim, hidden_dim), nn.ELU(),
            nn.Linear(hidden_dim, 1),
        )

    def init_state(self, B: int, device) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return an all-zero initial state ``(h, z)``.

        Shapes are ``(B, hidden_dim)`` and ``(B, latent_dim * latent_classes)``.
        """
        return (
            torch.zeros(B, self.hidden_dim, device=device),
            torch.zeros(B, self.latent_dim * self.latent_classes, device=device),
        )

    def _sample_st(self, logits: torch.Tensor) -> torch.Tensor:
        """Draw one-hot samples with a straight-through gradient.

        The forward value is a one-hot sample from ``OneHotCategorical(logits)``
        over the last dimension. In the backward pass, the gradient is routed
        through the softmax probabilities (``sample + p - p.detach()``), as in
        DreamerV2's straight-through estimator.

        The previous form, ``sample + logits - logits.detach()``, forwarded the
        upstream gradient to the logits unchanged and skipped the softmax
        Jacobian.
        """
        sample = OneHotCategorical(logits=logits).sample()
        probs = F.softmax(logits, dim=-1)
        return sample + probs - probs.detach()

    def observe(self, x_t: torch.Tensor, a_tm1: torch.Tensor, h_tm1: torch.Tensor,
                z_tm1: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Advance one step using an observation (posterior step).

        Args:
            x_t: Observation embedding, ``(B, obs_dim)``.
            a_tm1: Integer action indices, ``(B,)``, embedded and fed to the GRU
                together with ``z_tm1``.
            h_tm1: Previous recurrent state, ``(B, hidden_dim)``.
            z_tm1: Previous flattened stochastic state,
                ``(B, latent_dim * latent_classes)``.

        Returns:
            ``(h_t, z_t, z_logits, z_pred_logits)``:

            * ``h_t`` -- new recurrent state, ``(B, hidden_dim)``.
            * ``z_t`` -- flattened straight-through one-hot sample from the
              posterior.
            * ``z_logits`` -- posterior logits, ``(B, latent_dim, latent_classes)``.
            * ``z_pred_logits`` -- prior logits, same shape as ``z_logits``.
        """
        B = x_t.shape[0]
        a_emb = self.action_embed(a_tm1)
        h_t = self.seq_model(torch.cat([z_tm1, a_emb], dim=-1), h_tm1)
        z_pred_logits = self.dynamics(h_t).reshape(B, self.latent_dim, self.latent_classes)
        z_logits = self.encoder(torch.cat([x_t, h_t], dim=-1)).reshape(
            B, self.latent_dim, self.latent_classes)
        z_t = self._sample_st(z_logits).reshape(B, -1)
        return h_t, z_t, z_logits, z_pred_logits

    def imagine(self, h_t: torch.Tensor, z_t: torch.Tensor,
                a_t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Advance one step without an observation (prior step).

        Args:
            h_t: Current recurrent state, ``(B, hidden_dim)``.
            z_t: Current flattened stochastic state.
            a_t: Integer action indices, ``(B,)``.

        Returns:
            ``(h_tp1, z_tp1, z_pred_logits)``: the next recurrent state, a
            flattened straight-through one-hot sample from the prior, and the
            prior logits ``(B, latent_dim, latent_classes)``.
        """
        B = h_t.shape[0]
        a_emb = self.action_embed(a_t)
        h_tp1 = self.seq_model(torch.cat([z_t, a_emb], dim=-1), h_t)
        z_pred_logits = self.dynamics(h_tp1).reshape(B, self.latent_dim, self.latent_classes)
        z_tp1 = self._sample_st(z_pred_logits).reshape(B, -1)
        return h_tp1, z_tp1, z_pred_logits

    def predict_reward(self, h: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """Return the raw reward-head output for ``[h, z]``, shape ``(B, reward_bins)``."""
        return self.reward_predictor(torch.cat([h, z], dim=-1))

    def predict_continue(self, h: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """Return the raw continue-head output for ``[h, z]``, shape ``(B, 1)``."""
        return self.continue_predictor(torch.cat([h, z], dim=-1))

    def rollout(self, h_0: torch.Tensor, z_0: torch.Tensor,
                actions: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Imagine ``H`` steps ahead from ``(h_0, z_0)`` using the prior only.

        Args:
            h_0: Starting recurrent state, ``(B, hidden_dim)``.
            z_0: Starting flattened stochastic state.
            actions: Integer action indices, ``(B, H)``; column ``t`` drives step ``t``.

        Returns:
            Dict with keys ``"h_states"``, ``"z_states"``, ``"rewards"`` and
            ``"continues"``, each stacked along dim 1 with length ``H``. The
            states are the ``H`` states produced *after* each action; the
            starting state is not included.
        """
        _, H = actions.shape
        h_states, z_states, rewards, continues = [], [], [], []
        h_t, z_t = h_0, z_0
        for t in range(H):
            h_t, z_t, _ = self.imagine(h_t, z_t, actions[:, t])
            h_states.append(h_t)
            z_states.append(z_t)
            rewards.append(self.predict_reward(h_t, z_t))
            continues.append(self.predict_continue(h_t, z_t))
        return {
            "h_states": torch.stack(h_states, 1),
            "z_states": torch.stack(z_states, 1),
            "rewards": torch.stack(rewards, 1),
            "continues": torch.stack(continues, 1),
        }
|
|
if __name__ == "__main__":
    # Smoke test: push random inputs through every public entry point of RSSM
    # and print the resulting tensor shapes.
    batch_size, obs_width = 2, 128
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    model = RSSM(obs_dim=obs_width).to(dev)

    obs_batch = torch.randn(batch_size, obs_width, device=dev)
    act_batch = torch.randint(0, 6, (batch_size,), device=dev)
    h_state, z_state = model.init_state(batch_size, dev)

    # One posterior step starting from the zero state.
    h_state, z_state, post_logits, _prior_logits = model.observe(obs_batch, act_batch, h_state, z_state)
    print(f"h: {h_state.shape}, z: {z_state.shape}, z_logits: {post_logits.shape}")

    # One prior (imagination) step that reuses the same actions.
    h_next, z_next, _ = model.imagine(h_state, z_state, act_batch)
    print(f"h2: {h_next.shape}, z2: {z_next.shape}")

    # The heads are evaluated on the posterior state, not the imagined one.
    reward_out = model.predict_reward(h_state, z_state)
    cont_out = model.predict_continue(h_state, z_state)
    print(f"r: {reward_out.shape}, c: {cont_out.shape}")

    # A 10-step imagined rollout.
    horizon_actions = torch.randint(0, 6, (batch_size, 10), device=dev)
    traj = model.rollout(h_state, z_state, horizon_actions)
    shapes = {key: traj[key].shape for key in ("h_states", "z_states", "rewards", "continues")}
    print(f"Rollout: h={shapes['h_states']}, z={shapes['z_states']}, r={shapes['rewards']}, c={shapes['continues']}")
|
|