File size: 788 Bytes
b0d8e39
514c4c0
 
 
b0d8e39
 
514c4c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import os
import torch
import torch.nn as nn

# Hide every CUDA device from this process so the model runs on the CPU.
# NOTE(review): this runs after `import torch`. It only takes effect because
# PyTorch initializes CUDA lazily, i.e. nothing may touch torch.cuda before
# this line.
os.environ.update({"CUDA_VISIBLE_DEVICES": ""})

class StockLSTM(nn.Module):
    """Sequence-to-one LSTM regressor.

    Runs a stacked, batch-first LSTM over the input sequence and maps the
    LSTM output at the final time step through dropout and a linear layer.
    Presumably used to predict the next value of a price series. This is
    inferred from the class name only; confirm against the training code.

    Args:
        input_dim: Number of features per time step.
        hidden_dim: Size of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability. It is used between LSTM layers only
            when ``num_layers > 1``, because PyTorch's LSTM dropout acts
            between layers and warns otherwise. It is always applied before
            the linear head.
        output_dim: Number of values predicted per sequence. Defaults to 1,
            the original single-output head.
    """

    def __init__(
        self,
        input_dim: int = 1,
        hidden_dim: int = 64,
        num_layers: int = 2,
        dropout: float = 0.2,
        output_dim: int = 1,
    ) -> None:
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            # Inter-layer dropout has no effect with a single layer.
            dropout=dropout if num_layers > 1 else 0.0,
            batch_first=True,
        )
        # Submodule names and order are unchanged (head.0 = Dropout,
        # head.1 = Linear), so state_dict keys match the original layout.
        self.head = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predict one output vector per sequence from its last time step.

        Args:
            x: Tensor of shape ``[B, T, input_dim]`` (batch first).

        Returns:
            Tensor of shape ``[B, output_dim]``.
        """
        out, _ = self.lstm(x)  # out: [B, T, hidden_dim]; (h_n, c_n) unused
        last = out[:, -1, :]   # final time step: [B, hidden_dim]
        return self.head(last)