# Synced with https://github.com/Dizolivemint/CSC580-Motion-Encoder-Decoder/tree/main
# Commit: 4fa8bcb
import torch.nn as nn
from torchtyping import TensorType
from typeguard import typechecked
import torch
class EncoderDecoder(nn.Module):
    """Encoder-decoder for sequence prediction from static physical features.

    Encodes a per-sample feature vector (e.g. [mass, angle, friction]) with an
    MLP, then unrolls an LSTM over T time steps — conditioned on a normalized
    time signal — to predict either (x, y) coordinates or flattened frames.
    """

    def __init__(self, input_dim=3, hidden_size=128, output_seq_len=10, output_shape=None):
        """
        Encoder-Decoder model for sequence-to-sequence prediction.

        Args:
            input_dim (int): Dimension of input features (e.g., [mass, angle, friction]).
            hidden_size (int): Size of hidden representation.
            output_seq_len (int): Number of time steps (T).
            output_shape (tuple, optional): Output frame shape (H, W) for image
                prediction mode. If None, model operates in coordinate prediction mode.

        Raises:
            ValueError: in image mode, if hidden_size is not divisible by 64
                (required by the assumed [C, 8, 8] spatial reshape).
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.output_seq_len = output_seq_len  # T
        self.output_shape = output_shape      # (H, W) or None for coordinate mode
        self.is_coordinate_mode = output_shape is None

        # --- Encoder: MLP mapping static features to a hidden code ---
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size)
        )

        # --- Decoder LSTM ---
        # Per-step input is the encoded code plus a scalar time signal.
        # NOTE: the original passed dropout=0.3, but LSTM dropout applies only
        # *between* stacked layers; with the default num_layers=1 it is a
        # no-op and PyTorch emits a UserWarning, so it is omitted here.
        self.decoder_lstm = nn.LSTM(
            input_size=hidden_size + 1,
            hidden_size=hidden_size,
            batch_first=True
        )

        # --- Output Layer ---
        if self.is_coordinate_mode:
            # Coordinate mode: predict [x, y] in normalized (0, 1) coordinates.
            self.output_layer = nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, 2),
                nn.Sigmoid()
            )
        else:
            # Image mode: predict one flattened H*W frame per step.
            H, W = self.output_shape
            # Kept for compatibility with external readers: assumes hidden_size
            # can reshape to [C, 8, 8].
            self.spatial_decoder_input_size = 8
            self.channels = hidden_size // (self.spatial_decoder_input_size ** 2)
            # Raise instead of assert: asserts are stripped under `python -O`.
            if self.channels * self.spatial_decoder_input_size ** 2 != hidden_size:
                raise ValueError("hidden_size must be divisible by 64")
            self.output_layer = nn.Sequential(
                nn.Linear(hidden_size, H * W),
                nn.Sigmoid()
            )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for both coordinate and image prediction modes.

        Args:
            inputs: [B, F] input features (normalized mass, angle, friction).

        Returns:
            Coordinate mode: [B, T, 2] sequence of (x, y) coordinates.
            Image mode: [B, T, H, W] sequence of predicted frames.
        """
        # NOTE(fix): the original declared the return as a 4-D
        # TensorType["B","T","H","W"] under @typechecked, which rejects the
        # valid 3-D [B, T, 2] coordinate-mode output at runtime. The
        # annotation is relaxed to torch.Tensor and the decorator removed.
        B = inputs.shape[0]
        T = self.output_seq_len

        encoded = self.encoder(inputs)  # [B, hidden]

        # Normalized time signal in [0, 1], one scalar per step: [B, T, 1].
        timesteps = (
            torch.linspace(0, 1, T, device=inputs.device)
            .unsqueeze(0)
            .repeat(B, 1)
            .unsqueeze(-1)
        )
        # Condition every step on the repeated code plus the time signal.
        # (The original built the repeated code twice; the dead first copy
        # is removed.)
        decoder_input = torch.cat(
            [encoded.unsqueeze(1).repeat(1, T, 1), timesteps], dim=-1
        )  # [B, T, hidden + 1]

        lstm_out, _ = self.decoder_lstm(decoder_input)  # [B, T, hidden]

        if self.is_coordinate_mode:
            # Coordinate prediction mode
            output = self.output_layer(lstm_out)  # [B, T, 2]
        else:
            # Image prediction mode
            H, W = self.output_shape
            output = self.output_layer(lstm_out)  # [B, T, H*W]
            output = output.view(B, T, H, W)      # [B, T, H, W]
        return output