# NOTE(review): removed pasted build-page artifacts ("Spaces:", "Build error")
# that preceded the source and would break the module at import time.
| import torch.nn as nn | |
| from torchtyping import TensorType | |
| from typeguard import typechecked | |
| import torch | |
class EncoderDecoder(nn.Module):
    """Encoder-decoder model for sequence-to-sequence prediction.

    A static feature vector is encoded with an MLP; an LSTM decoder is then
    unrolled for T steps, receiving at every step the (repeated) encoding
    concatenated with a normalized time index in [0, 1].

    Args:
        input_dim (int): Dimension of input features (e.g., [mass, angle, friction]).
        hidden_size (int): Size of hidden representation.
        output_seq_len (int): Number of time steps (T).
        output_shape (tuple, optional): Output frame shape (H, W) for image
            prediction mode. If None, model operates in coordinate prediction mode.

    Raises:
        ValueError: in image mode, if ``hidden_size`` is not divisible by 64.
    """

    def __init__(self, input_dim=3, hidden_size=128, output_seq_len=10, output_shape=None):
        super().__init__()
        self.hidden_size = hidden_size
        self.output_seq_len = output_seq_len  # T
        self.output_shape = output_shape      # (H, W) or None for coordinate mode
        self.is_coordinate_mode = output_shape is None

        # --- Encoder: MLP mapping [B, input_dim] -> [B, hidden_size] ---
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
        )

        # --- Decoder LSTM ---
        # Per-step input is the encoding plus one scalar time channel (+1).
        # NOTE: the original passed dropout=0.3 here, but nn.LSTM dropout is
        # only applied *between* stacked layers; with the default num_layers=1
        # it is a no-op and PyTorch emits a UserWarning, so it is omitted.
        self.decoder_lstm = nn.LSTM(
            input_size=hidden_size + 1,
            hidden_size=hidden_size,
            batch_first=True,
        )

        # --- Output Layer ---
        if self.is_coordinate_mode:
            # Coordinate mode: predict normalized (x, y) in [0, 1] per step.
            self.output_layer = nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, 2),
                nn.Sigmoid(),
            )
        else:
            # Image mode: predict one flattened H*W frame per step.
            H, W = self.output_shape
            # Kept for backward compatibility with external readers of these
            # attributes; forward() does not use them.
            self.spatial_decoder_input_size = 8  # assumes hidden_size can reshape to [C, 8, 8]
            self.channels = hidden_size // (self.spatial_decoder_input_size ** 2)
            if self.channels * self.spatial_decoder_input_size ** 2 != hidden_size:
                # Explicit raise instead of assert: asserts vanish under `python -O`.
                raise ValueError("hidden_size must be divisible by 64")
            self.output_layer = nn.Sequential(
                nn.Linear(hidden_size, H * W),
                nn.Sigmoid(),
            )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Forward pass for both coordinate and image prediction modes.

        Args:
            inputs: [B, F] input features (normalized mass, angle, friction).

        Returns:
            Coordinate mode: [B, T, 2] sequence of (x, y) coordinates in [0, 1].
            Image mode: [B, T, H, W] sequence of predicted frames in [0, 1].
        """
        B = inputs.shape[0]
        T = self.output_seq_len

        encoded = self.encoder(inputs)  # [B, hidden]

        # Normalized time index appended as an extra channel at every step.
        # expand() broadcasts without copying (the original repeat() copied,
        # and also built the [B, T, hidden] decoder input twice — the first
        # assignment was dead code immediately overwritten).
        timesteps = (
            torch.linspace(0, 1, T, device=inputs.device)
            .view(1, T, 1)
            .expand(B, T, 1)
        )
        decoder_input = torch.cat(
            [encoded.unsqueeze(1).expand(B, T, self.hidden_size), timesteps],
            dim=-1,
        )  # [B, T, hidden + 1]

        lstm_out, _ = self.decoder_lstm(decoder_input)  # [B, T, hidden]

        if self.is_coordinate_mode:
            return self.output_layer(lstm_out)  # [B, T, 2]

        # Image prediction mode: project to pixels, then restore frame shape.
        H, W = self.output_shape
        output = self.output_layer(lstm_out)  # [B, T, H*W]
        return output.view(B, T, H, W)