# File size: 3,784 Bytes
# 4fa8bcb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch.nn as nn
from torchtyping import TensorType
from typeguard import typechecked
import torch

class EncoderDecoder(nn.Module):
    """Map a static feature vector to a sequence of ``output_seq_len`` outputs.

    An MLP encoder embeds the input features. The embedding is repeated for
    every time step, concatenated with a time index in ``[0, 1]``, and decoded
    by a single-layer LSTM. A per-step head then produces either an ``(x, y)``
    pair (coordinate mode, ``output_shape is None``) or a flattened ``H * W``
    frame that is reshaped to ``[H, W]`` (image mode). Both heads end in a
    sigmoid.
    """

    def __init__(self, input_dim=3, hidden_size=128, output_seq_len=10, output_shape=None):
        """
        Encoder-Decoder model for sequence-to-sequence prediction.

        Args:
            input_dim (int): Dimension of input features (e.g., [mass, angle, friction])
            hidden_size (int): Size of hidden representation
            output_seq_len (int): Number of time steps (T)
            output_shape (tuple, optional): Output frame shape (H, W) for image prediction mode.
                                          If None, model operates in coordinate prediction mode.

        Raises:
            AssertionError: In image mode, if ``hidden_size`` is not a multiple of 64.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.output_seq_len = output_seq_len  # T
        self.output_shape = output_shape  # (H, W) or None for coordinate mode
        self.is_coordinate_mode = output_shape is None

        # --- Encoder: three-layer MLP, [B, input_dim] -> [B, hidden_size] ---
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size)
        )

        # --- Decoder LSTM ---
        # The "+ 1" input feature is the per-step time index appended in forward().
        # No `dropout` argument: nn.LSTM only applies dropout between stacked
        # layers, so on this single-layer LSTM a non-zero value has no effect
        # and merely triggers a UserWarning at construction time.
        self.decoder_lstm = nn.LSTM(
            input_size=hidden_size + 1,
            hidden_size=hidden_size,
            batch_first=True,
        )

        # --- Output Layer ---
        if self.is_coordinate_mode:
            # For coordinate prediction: output [x, y] coordinates
            self.output_layer = nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, 2),  # Predict [x, y] in normalized coordinates (0 to 1)
                nn.Sigmoid()
            )
        else:
            # For image prediction: output frames
            H, W = self.output_shape
            # NOTE: the two attributes below are not consumed by the linear
            # output head defined here. They, and the divisibility check, are
            # retained so construction accepts exactly the same arguments and
            # exposes the same attributes as before.
            self.spatial_decoder_input_size = 8  # assumes hidden_size can reshape to [C, 8, 8]
            self.channels = hidden_size // (self.spatial_decoder_input_size ** 2)
            assert self.channels * self.spatial_decoder_input_size ** 2 == hidden_size, "hidden_size must be divisible by 64"

            self.output_layer = nn.Sequential(
                nn.Linear(hidden_size, H * W),
                nn.Sigmoid()
            )

    @typechecked
    def forward(self, inputs: TensorType["B", "F"]) -> torch.Tensor:
        """
        Forward pass for both coordinate and image prediction modes.

        Args:
            inputs: [B, F] → input features (normalized mass, angle, friction)

        Returns:
            If coordinate mode: [B, T, 2] → sequence of (x,y) coordinates
            If image mode: [B, T, H, W] → sequence of predicted frames

        The return annotation is a plain ``torch.Tensor`` because the rank of
        the result depends on the mode; a fixed 4-D shape annotation would be
        wrong for coordinate mode.
        """
        B = inputs.shape[0]
        T = self.output_seq_len

        encoded = self.encoder(inputs)  # [B, hidden]

        # Time index per step, evenly spaced over [0, 1]. Cast to the encoder
        # output's dtype so the concatenation below stays in the model's
        # precision when the module has been converted with .half()/.double().
        timesteps = torch.linspace(0, 1, T, device=inputs.device).to(encoded.dtype)
        timesteps = timesteps.view(1, T, 1).repeat(B, 1, 1)  # [B, T, 1]

        # Every step sees the same encoding plus its own time index.
        decoder_input = torch.cat(
            [encoded.unsqueeze(1).repeat(1, T, 1), timesteps], dim=-1
        )  # [B, T, hidden + 1]

        lstm_out, _ = self.decoder_lstm(decoder_input)  # [B, T, hidden]

        if self.is_coordinate_mode:
            # Coordinate prediction mode
            return self.output_layer(lstm_out)  # [B, T, 2]

        # Image prediction mode
        H, W = self.output_shape
        frames = self.output_layer(lstm_out)  # [B, T, H*W]
        return frames.view(B, T, H, W)  # Reshape to [B, T, H, W]