amitke commited on
Commit
95aa7b7
·
1 Parent(s): e490837
Files changed (3) hide show
  1. inference.py +1 -1
  2. testing.py +1 -1
  3. train.py +1 -1
inference.py CHANGED
@@ -26,7 +26,7 @@ def _load_artifacts(symbol: str):
26
  scaler = pickle.load(f)
27
 
28
  model = StockLSTM(input_dim=1, hidden_dim=64, num_layers=2, dropout=0.2)
29
- model.load_state_dict(torch.load(p["model"], map_location="gpu"))
30
  model.eval()
31
  return model, scaler, meta
32
 
 
26
  scaler = pickle.load(f)
27
 
28
  model = StockLSTM(input_dim=1, hidden_dim=64, num_layers=2, dropout=0.2)
29
+ model.load_state_dict(torch.load(p["model"], map_location="cpu"))
30
  model.eval()
31
  return model, scaler, meta
32
 
testing.py CHANGED
@@ -13,7 +13,7 @@ ARTIFACTS_DIR = "artifacts"
13
  def evaluate(symbol: str):
14
  base = os.path.join(ARTIFACTS_DIR, symbol.upper())
15
  model = StockLSTM(input_dim=1, hidden_dim=64, num_layers=2, dropout=0.2)
16
- model.load_state_dict(torch.load(os.path.join(base, "model.pt"), map_location="gpu"))
17
  model.eval()
18
  with open(os.path.join(base, "scaler.pkl"), "rb") as f:
19
  scaler = pickle.load(f)
 
13
  def evaluate(symbol: str):
14
  base = os.path.join(ARTIFACTS_DIR, symbol.upper())
15
  model = StockLSTM(input_dim=1, hidden_dim=64, num_layers=2, dropout=0.2)
16
+ model.load_state_dict(torch.load(os.path.join(base, "model.pt"), map_location="cpu"))
17
  model.eval()
18
  with open(os.path.join(base, "scaler.pkl"), "rb") as f:
19
  scaler = pickle.load(f)
train.py CHANGED
@@ -44,7 +44,7 @@ def to_tensor_loader(X, y, batch_size=32):
44
 
45
  def train(symbol: str, seq_len: int = 60, epochs: int = 5, batch_size: int = 32,
46
  start: str = None, end: str = None, lr: float = 1e-3):
47
- device = torch.device("gpu")
48
 
49
  # --- data ---
50
  df = fetch_data(symbol, start, end)
 
44
 
45
  def train(symbol: str, seq_len: int = 60, epochs: int = 5, batch_size: int = 32,
46
  start: str = None, end: str = None, lr: float = 1e-3):
47
+ device = torch.device("cpu")
48
 
49
  # --- data ---
50
  df = fetch_data(symbol, start, end)