Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -110,55 +110,41 @@ st.markdown("""
|
|
| 110 |
""", unsafe_allow_html=True)
|
| 111 |
|
| 112 |
# --- Python Backend Functions ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
@st.cache_resource(ttl=3600)
|
| 115 |
-
def
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
class GRUModel(nn.Module):
|
| 120 |
-
def __init__(self):
|
| 121 |
-
super(GRUModel, self).__init__()
|
| 122 |
-
self.gru = nn.GRU(input_dim, hidden_dim, num_layers, batch_first=True, dropout=dropout_prob)
|
| 123 |
-
self.fc = nn.Linear(hidden_dim, output_dim)
|
| 124 |
-
|
| 125 |
-
def forward(self, x):
|
| 126 |
-
h0 = torch.zeros(num_layers, x.size(0), hidden_dim).to(x.device)
|
| 127 |
-
out, _ = self.gru(x, h0)
|
| 128 |
-
return self.fc(out[:, -1, :])
|
| 129 |
-
|
| 130 |
-
class BiLSTMModel(nn.Module):
|
| 131 |
-
def __init__(self):
|
| 132 |
-
super(BiLSTMModel, self).__init__()
|
| 133 |
-
self.lstm = nn.LSTM(
|
| 134 |
-
input_size=1,
|
| 135 |
-
hidden_size=100,
|
| 136 |
-
num_layers=1, # <- match saved model
|
| 137 |
-
batch_first=True,
|
| 138 |
-
dropout=0.2,
|
| 139 |
-
bidirectional=True
|
| 140 |
-
)
|
| 141 |
-
self.fc = nn.Linear(200, 1) # 2 * hidden_size because of bidirectional
|
| 142 |
-
|
| 143 |
-
def forward(self, x):
|
| 144 |
-
h0 = torch.zeros(2 * 1, x.size(0), 100)
|
| 145 |
-
c0 = torch.zeros(2 * 1, x.size(0), 100)
|
| 146 |
-
out, _ = self.lstm(x, (h0, c0))
|
| 147 |
-
return self.fc(out[:, -1, :])
|
| 148 |
-
model_class = BiLSTMModel if model_type == 'Bi-Directional LSTM' else GRUModel
|
| 149 |
-
model = model_class()
|
| 150 |
-
|
| 151 |
-
checkpoint = torch.load(path, map_location=torch.device('cpu'))
|
| 152 |
-
|
| 153 |
-
# If full checkpoint was saved with keys like 'model_state_dict'
|
| 154 |
-
if 'model_state_dict' in checkpoint:
|
| 155 |
-
model.load_state_dict(checkpoint['model_state_dict'])
|
| 156 |
-
else:
|
| 157 |
-
model.load_state_dict(checkpoint) # Just raw state_dict
|
| 158 |
-
|
| 159 |
model.eval()
|
| 160 |
return model
|
| 161 |
-
|
| 162 |
@st.cache_data(ttl=900) # Cache data for 15 minutes
|
| 163 |
def get_stock_data(ticker):
|
| 164 |
"""Fetches historical stock data from Yahoo Finance for the last 4 years."""
|
|
@@ -175,14 +161,8 @@ def get_stock_data(ticker):
|
|
| 175 |
print(f"Successfully fetched {len(data)} rows for {ticker}")
|
| 176 |
return data
|
| 177 |
|
| 178 |
-
def predict_with_model(data: pd.DataFrame, n_days: int,
|
| 179 |
-
import torch
|
| 180 |
|
| 181 |
-
try:
|
| 182 |
-
model = load_pytorch_model(model_path, model_type=model_type)
|
| 183 |
-
except FileNotFoundError as e:
|
| 184 |
-
raise e
|
| 185 |
-
print("model:",model)
|
| 186 |
close_prices = data['Close'].values.reshape(-1, 1)
|
| 187 |
scaler = MinMaxScaler(feature_range=(0, 1))
|
| 188 |
scaled_prices = scaler.fit_transform(close_prices)
|
|
@@ -253,9 +233,15 @@ with st.sidebar:
|
|
| 253 |
st.session_state.last_ticker = ticker
|
| 254 |
st.session_state.last_model_type = model_type
|
| 255 |
st.session_state.last_prediction_days = prediction_days
|
| 256 |
-
st.rerun()
|
| 257 |
# --- Main Application Logic ---
|
| 258 |
if st.session_state.run_button_clicked:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
print(f"Inside main logic block. Current loading state: {st.session_state.loading}")
|
| 260 |
try:
|
| 261 |
if os.path.exists("AMZN_data.csv"):
|
|
@@ -268,7 +254,7 @@ if st.session_state.run_button_clicked:
|
|
| 268 |
st.session_state.error = f"Could not fetch data for ticker '{ticker}'. It may be an invalid symbol or network issue."
|
| 269 |
else:
|
| 270 |
model_path = "best_bilstm_model.pth" if model_type == "Bi-Directional LSTM" else "best_gru_model.pth"
|
| 271 |
-
st.session_state.predictions = predict_with_model(st.session_state.data, prediction_days,
|
| 272 |
st.session_state.error = None
|
| 273 |
|
| 274 |
except FileNotFoundError as e:
|
|
|
|
| 110 |
""", unsafe_allow_html=True)
|
| 111 |
|
| 112 |
# --- Python Backend Functions ---
|
| 113 |
+
# Outside of any function
|
| 114 |
+
import torch.nn as nn
|
| 115 |
+
import torch
|
| 116 |
+
|
| 117 |
+
class GRUModel(nn.Module):
|
| 118 |
+
def __init__(self, input_dim=1, hidden_dim=100, num_layers=2, output_dim=1, dropout_prob=0.2):
|
| 119 |
+
super(GRUModel, self).__init__()
|
| 120 |
+
self.gru = nn.GRU(input_dim, hidden_dim, num_layers, batch_first=True, dropout=dropout_prob)
|
| 121 |
+
self.fc = nn.Linear(hidden_dim, output_dim)
|
| 122 |
+
|
| 123 |
+
def forward(self, x):
|
| 124 |
+
h0 = torch.zeros(2, x.size(0), 100).to(x.device)
|
| 125 |
+
out, _ = self.gru(x, h0)
|
| 126 |
+
return self.fc(out[:, -1, :])
|
| 127 |
+
|
| 128 |
+
class BiLSTMModel(nn.Module):
|
| 129 |
+
def __init__(self):
|
| 130 |
+
super(BiLSTMModel, self).__init__()
|
| 131 |
+
self.lstm = nn.LSTM(input_size=1, hidden_size=100, num_layers=1, batch_first=True, dropout=0.2, bidirectional=True)
|
| 132 |
+
self.fc = nn.Linear(200, 1)
|
| 133 |
+
|
| 134 |
+
def forward(self, x):
|
| 135 |
+
h0 = torch.zeros(2, x.size(0), 100)
|
| 136 |
+
c0 = torch.zeros(2, x.size(0), 100)
|
| 137 |
+
out, _ = self.lstm(x, (h0, c0))
|
| 138 |
+
return self.fc(out[:, -1, :])
|
| 139 |
|
| 140 |
@st.cache_resource(ttl=3600)
|
| 141 |
+
def load_model_from_disk(path, model_type):
|
| 142 |
+
model = BiLSTMModel() if model_type == "Bi-Directional LSTM" else GRUModel()
|
| 143 |
+
state = torch.load(path, map_location=torch.device("cpu"))
|
| 144 |
+
model.load_state_dict(state['model_state_dict'] if 'model_state_dict' in state else state)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
model.eval()
|
| 146 |
return model
|
| 147 |
+
|
| 148 |
@st.cache_data(ttl=900) # Cache data for 15 minutes
|
| 149 |
def get_stock_data(ticker):
|
| 150 |
"""Fetches historical stock data from Yahoo Finance for the last 4 years."""
|
|
|
|
| 161 |
print(f"Successfully fetched {len(data)} rows for {ticker}")
|
| 162 |
return data
|
| 163 |
|
| 164 |
+
def predict_with_model(data: pd.DataFrame, n_days: int, model, model_type: str) -> pd.DataFrame:
|
|
|
|
| 165 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 166 |
close_prices = data['Close'].values.reshape(-1, 1)
|
| 167 |
scaler = MinMaxScaler(feature_range=(0, 1))
|
| 168 |
scaled_prices = scaler.fit_transform(close_prices)
|
|
|
|
| 233 |
st.session_state.last_ticker = ticker
|
| 234 |
st.session_state.last_model_type = model_type
|
| 235 |
st.session_state.last_prediction_days = prediction_days
|
|
|
|
| 236 |
# --- Main Application Logic ---
|
| 237 |
if st.session_state.run_button_clicked:
|
| 238 |
+
model_key = "bilstm_model" if model_type == "Bi-Directional LSTM" else "gru_model"
|
| 239 |
+
|
| 240 |
+
if model_key not in st.session_state:
|
| 241 |
+
model_path = "best_bilstm_model.pth" if model_type == "Bi-Directional LSTM" else "best_gru_model.pth"
|
| 242 |
+
st.session_state[model_key] = load_model_from_disk(model_path, model_type)
|
| 243 |
+
|
| 244 |
+
model = st.session_state[model_key]
|
| 245 |
print(f"Inside main logic block. Current loading state: {st.session_state.loading}")
|
| 246 |
try:
|
| 247 |
if os.path.exists("AMZN_data.csv"):
|
|
|
|
| 254 |
st.session_state.error = f"Could not fetch data for ticker '{ticker}'. It may be an invalid symbol or network issue."
|
| 255 |
else:
|
| 256 |
model_path = "best_bilstm_model.pth" if model_type == "Bi-Directional LSTM" else "best_gru_model.pth"
|
| 257 |
+
st.session_state.predictions = predict_with_model(st.session_state.data, prediction_days, model, model_type)
|
| 258 |
st.session_state.error = None
|
| 259 |
|
| 260 |
except FileNotFoundError as e:
|