import math

import torch
import torch.nn as nn


def swish(x):
    # Swish / SiLU activation: x * sigmoid(x) (equivalent to torch.nn.SiLU).
    return x * torch.sigmoid(x)


class SinusoidalPositionalEncoding(nn.Module):
    """
    Produces a sinusoidal encoding of shape (B, T, embedding_dim)
    given timesteps of shape (B, T).
    """

    def __init__(self, embedding_dim):
        super().__init__()
        # The sin and cos halves are concatenated, so the output dimension
        # only matches `embedding_dim` when it is even.
        assert embedding_dim % 2 == 0, "embedding_dim must be even"
        self.embedding_dim = embedding_dim

    def forward(self, timesteps):
        timesteps = timesteps.float()
        B, T = timesteps.shape
        device = timesteps.device

        # Geometrically spaced frequencies, as in the Transformer positional
        # encoding: exp(-i * log(10000) / half_dim) for i = 0, ..., half_dim - 1.
        half_dim = self.embedding_dim // 2
        exponent = -torch.arange(half_dim, dtype=torch.float, device=device) * (
            math.log(10000.0) / half_dim
        )

        # (B, T, 1) * (half_dim,) -> (B, T, half_dim)
        freqs = timesteps.unsqueeze(-1) * exponent.exp()

        sin = torch.sin(freqs)
        cos = torch.cos(freqs)
        enc = torch.cat([sin, cos], dim=-1)  # (B, T, embedding_dim)

        return enc


class ActionEncoder(nn.Module):
    def __init__(self, action_dim, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size

        # Three-layer MLP: W1 embeds the actions, W2 fuses the action and
        # timestep features (hence the 2x input width), W3 projects the result.
        self.W1 = nn.Linear(action_dim, hidden_size)
        self.W2 = nn.Linear(2 * hidden_size, hidden_size)
        self.W3 = nn.Linear(hidden_size, hidden_size)

        self.pos_encoding = SinusoidalPositionalEncoding(hidden_size)

    def forward(self, actions, timesteps):
        """
        actions: shape (B, T, action_dim)
        timesteps: shape (B,) -- a single scalar per batch item
        returns: shape (B, T, hidden_size)
        """
        B, T, _ = actions.shape

        # Replicate each batch item's scalar timestep across all T steps.
        if timesteps.dim() == 1 and timesteps.shape[0] == B:
            timesteps = timesteps.unsqueeze(1).expand(-1, T)
        else:
            raise ValueError(
                "Expected `timesteps` to have shape (B,) so we can replicate across T."
            )

        # Embed actions: (B, T, action_dim) -> (B, T, hidden_size).
        a_emb = self.W1(actions)

        # Sinusoidal timestep encoding, cast to the action embedding's dtype.
        tau_emb = self.pos_encoding(timesteps).to(dtype=a_emb.dtype)

        # Fuse action and timestep features, then apply the output projection.
        x = torch.cat([a_emb, tau_emb], dim=-1)  # (B, T, 2 * hidden_size)
        x = swish(self.W2(x))
        x = self.W3(x)

        return x
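

# A minimal usage sketch (not part of the original module): the dimensions
# below are illustrative assumptions, used only to exercise both modules and
# check the output shapes documented in the docstrings.
if __name__ == "__main__":
    B, T, action_dim, hidden_size = 4, 16, 7, 32

    encoder = ActionEncoder(action_dim, hidden_size)
    actions = torch.randn(B, T, action_dim)
    timesteps = torch.rand(B)  # one scalar timestep per batch item

    out = encoder(actions, timesteps)
    assert out.shape == (B, T, hidden_size)

    # The positional encoding can also be used on its own with (B, T) timesteps.
    pe = SinusoidalPositionalEncoding(hidden_size)
    enc = pe(timesteps.unsqueeze(1).expand(-1, T))
    assert enc.shape == (B, T, hidden_size)

    print("OK:", tuple(out.shape), tuple(enc.shape))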