Spaces:
Runtime error
Runtime error
amitke commited on
Commit ·
95aa7b7
1
Parent(s): e490837
- inference.py +1 -1
- testing.py +1 -1
- train.py +1 -1
inference.py
CHANGED
|
@@ -26,7 +26,7 @@ def _load_artifacts(symbol: str):
|
|
| 26 |
scaler = pickle.load(f)
|
| 27 |
|
| 28 |
model = StockLSTM(input_dim=1, hidden_dim=64, num_layers=2, dropout=0.2)
|
| 29 |
-
model.load_state_dict(torch.load(p["model"], map_location="
|
| 30 |
model.eval()
|
| 31 |
return model, scaler, meta
|
| 32 |
|
|
|
|
| 26 |
scaler = pickle.load(f)
|
| 27 |
|
| 28 |
model = StockLSTM(input_dim=1, hidden_dim=64, num_layers=2, dropout=0.2)
|
| 29 |
+
model.load_state_dict(torch.load(p["model"], map_location="cpu"))
|
| 30 |
model.eval()
|
| 31 |
return model, scaler, meta
|
| 32 |
|
testing.py
CHANGED
|
@@ -13,7 +13,7 @@ ARTIFACTS_DIR = "artifacts"
|
|
| 13 |
def evaluate(symbol: str):
|
| 14 |
base = os.path.join(ARTIFACTS_DIR, symbol.upper())
|
| 15 |
model = StockLSTM(input_dim=1, hidden_dim=64, num_layers=2, dropout=0.2)
|
| 16 |
-
model.load_state_dict(torch.load(os.path.join(base, "model.pt"), map_location="
|
| 17 |
model.eval()
|
| 18 |
with open(os.path.join(base, "scaler.pkl"), "rb") as f:
|
| 19 |
scaler = pickle.load(f)
|
|
|
|
| 13 |
def evaluate(symbol: str):
|
| 14 |
base = os.path.join(ARTIFACTS_DIR, symbol.upper())
|
| 15 |
model = StockLSTM(input_dim=1, hidden_dim=64, num_layers=2, dropout=0.2)
|
| 16 |
+
model.load_state_dict(torch.load(os.path.join(base, "model.pt"), map_location="cpu"))
|
| 17 |
model.eval()
|
| 18 |
with open(os.path.join(base, "scaler.pkl"), "rb") as f:
|
| 19 |
scaler = pickle.load(f)
|
train.py
CHANGED
|
@@ -44,7 +44,7 @@ def to_tensor_loader(X, y, batch_size=32):
|
|
| 44 |
|
| 45 |
def train(symbol: str, seq_len: int = 60, epochs: int = 5, batch_size: int = 32,
|
| 46 |
start: str = None, end: str = None, lr: float = 1e-3):
|
| 47 |
-
device = torch.device("
|
| 48 |
|
| 49 |
# --- data ---
|
| 50 |
df = fetch_data(symbol, start, end)
|
|
|
|
| 44 |
|
| 45 |
def train(symbol: str, seq_len: int = 60, epochs: int = 5, batch_size: int = 32,
|
| 46 |
start: str = None, end: str = None, lr: float = 1e-3):
|
| 47 |
+
device = torch.device("cpu")
|
| 48 |
|
| 49 |
# --- data ---
|
| 50 |
df = fetch_data(symbol, start, end)
|