Upload folder using huggingface_hub
Browse files- streamlit_app.py +113 -30
streamlit_app.py
CHANGED
|
@@ -32,7 +32,7 @@ def load_models_local(symbol):
|
|
| 32 |
try:
|
| 33 |
models['regression'] = joblib.load(f"{model_path}/regression_model.pkl")
|
| 34 |
models['classification'] = joblib.load(f"{model_path}/classification_model.pkl")
|
| 35 |
-
|
| 36 |
return models
|
| 37 |
except Exception as e:
|
| 38 |
# Fallback to AAPL if specific model missing (for robustness)
|
|
@@ -94,6 +94,10 @@ def fetch_live_data(symbol):
|
|
| 94 |
rs = gain / loss
|
| 95 |
data['rsi'] = 100 - (100 / (1 + rs))
|
| 96 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
# MACD (12, 26, 9)
|
| 98 |
exp1 = data['close'].ewm(span=12, adjust=False).mean()
|
| 99 |
exp2 = data['close'].ewm(span=26, adjust=False).mean()
|
|
@@ -114,6 +118,7 @@ def fetch_live_data(symbol):
|
|
| 114 |
"sma_50": float(latest['sma_50']),
|
| 115 |
"rsi": float(latest['rsi']),
|
| 116 |
"macd": float(latest['macd']),
|
|
|
|
| 117 |
"is_mock": False
|
| 118 |
}
|
| 119 |
|
|
@@ -160,41 +165,119 @@ if st.sidebar.button("π Refresh Data"):
|
|
| 160 |
with st.spinner(f"Fetching Live Data for {symbol}..."):
|
| 161 |
data = fetch_live_data(symbol)
|
| 162 |
|
| 163 |
-
#
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
with
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
st.
|
| 176 |
-
|
|
|
|
|
|
|
| 177 |
|
| 178 |
-
|
| 179 |
-
models = load_models_local(symbol)
|
| 180 |
|
| 181 |
-
|
| 182 |
-
|
| 183 |
|
| 184 |
-
|
| 185 |
-
|
| 186 |
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
st.
|
| 195 |
|
| 196 |
-
|
| 197 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 198 |
|
| 199 |
# --- Sidebar Notification ---
|
| 200 |
st.sidebar.markdown("---")
|
|
|
|
| 32 |
try:
|
| 33 |
models['regression'] = joblib.load(f"{model_path}/regression_model.pkl")
|
| 34 |
models['classification'] = joblib.load(f"{model_path}/classification_model.pkl")
|
| 35 |
+
models['clustering'] = joblib.load(f"{model_path}/clustering_model.pkl")
|
| 36 |
return models
|
| 37 |
except Exception as e:
|
| 38 |
# Fallback to AAPL if specific model missing (for robustness)
|
|
|
|
| 94 |
rs = gain / loss
|
| 95 |
data['rsi'] = 100 - (100 / (1 + rs))
|
| 96 |
|
| 97 |
+
# Volatility (20-day std dev of Returns)
|
| 98 |
+
data['returns'] = data['close'].pct_change()
|
| 99 |
+
data['volatility'] = data['returns'].rolling(window=20).std()
|
| 100 |
+
|
| 101 |
# MACD (12, 26, 9)
|
| 102 |
exp1 = data['close'].ewm(span=12, adjust=False).mean()
|
| 103 |
exp2 = data['close'].ewm(span=26, adjust=False).mean()
|
|
|
|
| 118 |
"sma_50": float(latest['sma_50']),
|
| 119 |
"rsi": float(latest['rsi']),
|
| 120 |
"macd": float(latest['macd']),
|
| 121 |
+
"volatility": float(latest['volatility']) if not np.isnan(latest['volatility']) else 0.0,
|
| 122 |
"is_mock": False
|
| 123 |
}
|
| 124 |
|
|
|
|
| 165 |
with st.spinner(f"Fetching Live Data for {symbol}..."):
|
| 166 |
data = fetch_live_data(symbol)
|
| 167 |
|
| 168 |
+
# --- Layout: Tabs ---
|
| 169 |
+
tab1, tab2, tab3 = st.tabs(["π Dashboard", "π§ Deep Dive", "π Raw Data"])
|
| 170 |
+
|
| 171 |
+
# === TAB 1: DASHBOARD ===
|
| 172 |
+
with tab1:
|
| 173 |
+
# A. Header Metrics
|
| 174 |
+
col_head1, col_head2, col_head3, col_head4 = st.columns(4)
|
| 175 |
+
with col_head1:
|
| 176 |
+
st.metric("Current Price", f"${data['price']:.2f}", f"{data['change']:.2f}%")
|
| 177 |
+
with col_head2:
|
| 178 |
+
st.metric("RSI (Momentum)", f"{data['rsi']:.1f}", "Overbought" if data['rsi']>70 else "Oversold" if data['rsi']<30 else "Neutral", delta_color="off")
|
| 179 |
+
with col_head3:
|
| 180 |
+
st.metric("Volatility", f"{data.get('volatility', 0):.4f}", help="20-Day Std Dev of Returns")
|
| 181 |
+
with col_head4:
|
| 182 |
+
source = "π΄ Mock" if data['is_mock'] else "π’ Live"
|
| 183 |
+
st.metric("Data Source", source)
|
| 184 |
|
| 185 |
+
st.markdown("---")
|
|
|
|
| 186 |
|
| 187 |
+
# B. AI Prediction Section
|
| 188 |
+
st.subheader(f"π€ AI Prediction for {symbol}")
|
| 189 |
|
| 190 |
+
features = np.array([[data['sma_20'], data['sma_50'], data['rsi'], data['macd']]])
|
| 191 |
+
models = load_models_local(symbol)
|
| 192 |
|
| 193 |
+
if models:
|
| 194 |
+
col_pred1, col_pred2 = st.columns(2)
|
| 195 |
+
|
| 196 |
+
# Regression
|
| 197 |
+
pred_price = models['regression'].predict(features)[0]
|
| 198 |
+
|
| 199 |
+
# Classification
|
| 200 |
+
pred_direction_prob = models['classification'].predict_proba(features)[0]
|
| 201 |
+
direction = "UP π" if pred_direction_prob[1] > 0.5 else "DOWN π»"
|
| 202 |
+
confidence = max(pred_direction_prob)
|
| 203 |
+
|
| 204 |
+
with col_pred1:
|
| 205 |
+
st.info(f"**Predicted Direction:** {direction}")
|
| 206 |
+
st.progress(float(confidence), text=f"Confidence: {confidence*100:.1f}%")
|
| 207 |
+
|
| 208 |
+
with col_pred2:
|
| 209 |
+
st.success(f"**Target Price (Next Close):** ${pred_price:.2f}")
|
| 210 |
+
|
| 211 |
+
# C. Price Chart (Candlestick)
|
| 212 |
+
st.subheader("π Price History")
|
| 213 |
+
# Note: fetch_live_data only returns the LAST row's calculated metrics + latest meta,
|
| 214 |
+
# but for charts we need the full dataframe.
|
| 215 |
+
# To fix this without breaking the cache, we'll fetch full history purely for charting here.
|
| 216 |
+
# Ideally, fetch_live_data should return the full DF, but let's do a quick fetch for charts:
|
| 217 |
+
try:
|
| 218 |
+
if not data['is_mock'] and ALPHA_VANTAGE_KEY:
|
| 219 |
+
ts = TimeSeries(key=ALPHA_VANTAGE_KEY, output_format='pandas')
|
| 220 |
+
hist_data, _ = ts.get_daily(symbol=symbol, outputsize='compact')
|
| 221 |
+
hist_data = hist_data.sort_index()
|
| 222 |
+
hist_data.columns = ['open', 'high', 'low', 'close', 'volume']
|
| 223 |
+
|
| 224 |
+
fig = go.Figure(data=[go.Candlestick(x=hist_data.index,
|
| 225 |
+
open=hist_data['open'],
|
| 226 |
+
high=hist_data['high'],
|
| 227 |
+
low=hist_data['low'],
|
| 228 |
+
close=hist_data['close'])])
|
| 229 |
+
fig.update_layout(title=f"{symbol} Daily Price", xaxis_title="Date", yaxis_title="Price", template="plotly_dark")
|
| 230 |
+
st.plotly_chart(fig, use_container_width=True)
|
| 231 |
+
else:
|
| 232 |
+
st.warning("Charts unavailable in Mock Data mode (Add API Key to see charts).")
|
| 233 |
+
except Exception as e:
|
| 234 |
+
st.error(f"Could not load chart: {e}")
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
# === TAB 2: DEEP DIVE (Unsupervised & Technicals) ===
|
| 238 |
+
with tab2:
|
| 239 |
+
st.header("π§ Advanced Analysis")
|
| 240 |
|
| 241 |
+
# Clustering / Market Regime
|
| 242 |
+
if models and 'clustering' in models:
|
| 243 |
+
st.subheader("π§ Market Regime (Clustering)")
|
| 244 |
|
| 245 |
+
clus_features = np.array([[data.get('volatility', 0), data['rsi']]])
|
| 246 |
+
cluster_id = models['clustering'].predict(clus_features)[0]
|
| 247 |
+
|
| 248 |
+
regime_labels = {
|
| 249 |
+
0: "Regime 0 (Watch) ποΈ",
|
| 250 |
+
1: "Regime 1 (Accumulate) π°",
|
| 251 |
+
2: "Regime 2 (Risk/Volatile) β οΈ"
|
| 252 |
+
}
|
| 253 |
+
regime_name = regime_labels.get(cluster_id, f"Cluster {cluster_id}")
|
| 254 |
+
|
| 255 |
+
st.info(f"**Current State:** {regime_name}")
|
| 256 |
+
st.caption("We use K-Means Clustering on Volatility & RSI to identify the market state.")
|
| 257 |
+
|
| 258 |
+
st.markdown("---")
|
| 259 |
+
|
| 260 |
+
# Technical Indicators Chart
|
| 261 |
+
st.subheader("π Technical Indicators")
|
| 262 |
+
if not data['is_mock'] and 'hist_data' in locals():
|
| 263 |
+
# Calculate Indicators on history for plotting
|
| 264 |
+
# Simple RSI calculation for plotting
|
| 265 |
+
delta = hist_data['close'].diff()
|
| 266 |
+
gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
|
| 267 |
+
loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
|
| 268 |
+
rs = gain / loss
|
| 269 |
+
hist_data['rsi_plot'] = 100 - (100 / (1 + rs))
|
| 270 |
+
|
| 271 |
+
fig_rsi = px.line(hist_data, x=hist_data.index, y='rsi_plot', title="Relative Strength Index (14)")
|
| 272 |
+
fig_rsi.add_hline(y=70, line_dash="dash", line_color="red")
|
| 273 |
+
fig_rsi.add_hline(y=30, line_dash="dash", line_color="green")
|
| 274 |
+
fig_rsi.update_layout(template="plotly_dark")
|
| 275 |
+
st.plotly_chart(fig_rsi, use_container_width=True)
|
| 276 |
+
|
| 277 |
+
# === TAB 3: RAW DATA ===
|
| 278 |
+
with tab3:
|
| 279 |
+
st.subheader("Raw Data View")
|
| 280 |
+
st.json(data)
|
| 281 |
|
| 282 |
# --- Sidebar Notification ---
|
| 283 |
st.sidebar.markdown("---")
|