# nba_plusminus / model.py
# Commit fae0184 (aggtamv): "Deploy NBA predictor model"
# model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class LSTMModel(nn.Module):
    """Stacked bidirectional LSTM regressor that emits one value per sequence.

    Pipeline:
        BiLSTM(hidden1) -> LayerNorm -> ReLU -> [Dropout]
        -> BiLSTM(hidden2) -> LayerNorm -> ReLU
        -> Linear(2 * hidden2 -> 1) on the final timestep's features.

    Submodule names and parameter shapes are identical to the original
    definition, so previously saved ``state_dict``s load unchanged.

    Args:
        input_size: Number of features per timestep.
        hidden1: Hidden size (per direction) of the first LSTM. Default 64.
        hidden2: Hidden size (per direction) of the second LSTM. Default 256.
        dropout: Drop probability applied between the two recurrent blocks.
            Defaults to 0.0, which reproduces the original numerical
            behaviour. The original passed ``dropout=0.2`` to single-layer
            ``nn.LSTM`` modules, where PyTorch ignores it and emits a
            UserWarning. Set this to 0.2 to actually get that regularization
            when retraining.
    """

    def __init__(
        self,
        input_size: int,
        *,
        hidden1: int = 64,
        hidden2: int = 256,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        # No `dropout=` on the LSTMs: with num_layers=1 it is a no-op (and a warning).
        self.lstm1 = nn.LSTM(
            input_size=input_size, hidden_size=hidden1,
            batch_first=True, bidirectional=True,
        )
        self.ln1 = nn.LayerNorm(hidden1 * 2)  # *2: forward + reverse halves
        # Parameter-free, so it adds no state_dict keys; identity when p=0 or in eval().
        self.inter_dropout = nn.Dropout(dropout)
        self.lstm2 = nn.LSTM(
            input_size=hidden1 * 2, hidden_size=hidden2,
            batch_first=True, bidirectional=True,
        )
        self.ln2 = nn.LayerNorm(hidden2 * 2)
        self.fc = nn.Linear(hidden2 * 2, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network.

        Args:
            x: Tensor of shape (batch, seq_len, input_size). The LSTMs use
                ``batch_first=True``, and the ``[:, -1, :]`` readout needs a
                3-D tensor.

        Returns:
            Tensor of shape (batch, 1).
        """
        x, _ = self.lstm1(x)
        x = F.relu(self.ln1(x))
        x = self.inter_dropout(x)
        x, _ = self.lstm2(x)
        x = F.relu(self.ln2(x))
        # NOTE(review): for a bidirectional LSTM, output[:, -1] holds the
        # forward direction's final state but the reverse direction's *first*
        # step, which has seen only the last timestep. Kept as-is because the
        # deployed weights were trained with this readout.
        return self.fc(x[:, -1, :])