File size: 459 Bytes
d9faf67
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch
import torch.nn as nn

class TrajectoryPredictor(nn.Module):
    """LSTM encoder followed by a linear head applied to the last time step.

    The recurrent layer is built with ``batch_first=True`` and ``forward``
    indexes the LSTM output as ``[:, -1, :]``, so the input is expected to be
    a 3-D tensor laid out as ``(batch, sequence, input_size)``.

    Args:
        input_size: Number of features per time step fed to the LSTM.
        hidden_size: Width of the LSTM hidden state, which is also the input
            width of the linear head.
        output_size: Number of values produced per sequence.
    """

    def __init__(self, input_size=2, hidden_size=64, output_size=3) -> None:
        super().__init__()
        # Submodules are created in the same order (LSTM, then Linear) so that
        # parameter initialisation under a fixed RNG seed is unchanged.
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            batch_first=True,
        )
        self.linear = nn.Linear(
            in_features=hidden_size,
            out_features=output_size,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map each input sequence to one ``output_size``-wide vector.

        Args:
            x: Tensor of shape ``(batch, sequence, input_size)``.

        Returns:
            Tensor of shape ``(batch, output_size)``. According to the
            original author's note, with the default ``output_size=3`` these
            are the ``[x, y, z]`` values of the projected trajectory.
        """
        # The LSTM's final (h_n, c_n) state tuple is not needed here.
        per_step_hidden, _final_state = self.lstm(x)
        # Keep only the hidden output at the last time step of every sequence.
        last_step_hidden = per_step_hidden[:, -1, :]
        prediction = self.linear(last_step_hidden)
        return prediction