Spaces:
Runtime error
Runtime error
File size: 459 Bytes
d9faf67 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
import torch
import torch.nn as nn
class TrajectoryPredictor(nn.Module):
    """LSTM regressor that maps an input sequence to a single output vector.

    The sequence is passed through a single-layer, unidirectional LSTM, and the
    LSTM output at the final time step is projected to ``output_size`` values
    by a linear layer.

    Args:
        input_size: Number of features per time step (default 2).
        hidden_size: Width of the LSTM hidden state (default 64).
        output_size: Number of values produced per sequence (default 3).
    """

    def __init__(self, input_size: int = 2, hidden_size: int = 64, output_size: int = 3):
        super().__init__()
        # batch_first=True: batched input is laid out as (batch, seq_len, input_size).
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predict one output vector per sequence.

        Args:
            x: Batched ``(batch, seq_len, input_size)`` tensor, or an unbatched
                ``(seq_len, input_size)`` tensor.

        Returns:
            ``(batch, output_size)`` for batched input, or ``(output_size,)``
            for unbatched input. With the default ``output_size=3`` the
            original author describes these as the [x, y, z] of the projected
            trajectory.
        """
        h_lstm, _ = self.lstm(x)
        # Keep only the output at the last time step. Indexing the sequence
        # axis as the second-to-last dim (via ``...``) works for both batched
        # (batch, seq_len, hidden) and unbatched (seq_len, hidden) LSTM outputs.
        last_step = h_lstm[..., -1, :]
        return self.linear(last_step)
|