amitke commited on
Commit
f3bda49
·
1 Parent(s): 1e90053
Files changed (1) hide show
  1. train.py +1 -1
train.py CHANGED
@@ -75,7 +75,7 @@ def train(symbol: str, seq_len: int = 60, epochs: int = 5, batch_size: int = 32,
75
  # --- model ---
76
  model = StockLSTM(input_dim=1, hidden_dim=64, num_layers=2, dropout=0.2).to(device)
77
  criterion = nn.MSELoss()
78
- optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay = 0.1)
79
  # optimizer = optim.SGD(model.parameters(), lr=lr)
80
 
81
  # --- training ---
 
75
  # --- model ---
76
  model = StockLSTM(input_dim=1, hidden_dim=64, num_layers=2, dropout=0.2).to(device)
77
  criterion = nn.MSELoss()
78
+ optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay = 0.05)
79
  # optimizer = optim.SGD(model.parameters(), lr=lr)
80
 
81
  # --- training ---