import torch
import torch.nn as nn


class TrajectoryPredictor(nn.Module):
    """LSTM encoder plus linear head that maps a sequence to a single vector.

    The input sequence is run through a unidirectional LSTM, and the LSTM
    output at the final time step is projected to ``output_size`` values.

    Args:
        input_size: Number of features per time step (default 2; presumably a
            2-D position per step -- TODO confirm against the data pipeline).
        hidden_size: Width of the LSTM hidden state.
        output_size: Number of predicted values (default 3; the original
            author describes them as the ``[x, y, z]`` of the projected
            trajectory).
        num_layers: Number of stacked LSTM layers. Defaults to 1, which
            reproduces the original single-layer model.
    """

    def __init__(
        self,
        input_size: int = 2,
        hidden_size: int = 64,
        output_size: int = 3,
        num_layers: int = 1,
    ) -> None:
        super().__init__()
        self.lstm = nn.LSTM(
            input_size, hidden_size, num_layers=num_layers, batch_first=True
        )
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predict ``output_size`` values from the last step of a sequence.

        Args:
            x: Batched input of shape ``(batch, seq, input_size)``
                (``batch_first=True``), or unbatched input of shape
                ``(seq, input_size)``.

        Returns:
            A tensor of shape ``(batch, output_size)`` for batched input, or
            ``(output_size,)`` for unbatched input.
        """
        h_lstm, _ = self.lstm(x)
        # Ellipsis indexing picks the final time step for both the batched
        # (batch, seq, hidden) and the unbatched (seq, hidden) LSTM output;
        # a fixed three-index slice such as [:, -1, :] fails on the 2-D case.
        last_step = h_lstm[..., -1, :]
        return self.linear(last_step)  # [x, y, z] of projected trajectory