umer6016 commited on
Commit
5f3dfad
Β·
verified Β·
1 Parent(s): ce3d808

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. streamlit_app.py +113 -30
streamlit_app.py CHANGED
@@ -32,7 +32,7 @@ def load_models_local(symbol):
32
  try:
33
  models['regression'] = joblib.load(f"{model_path}/regression_model.pkl")
34
  models['classification'] = joblib.load(f"{model_path}/classification_model.pkl")
35
- # specific clustering/pca models might be needed too if visualizing
36
  return models
37
  except Exception as e:
38
  # Fallback to AAPL if specific model missing (for robustness)
@@ -94,6 +94,10 @@ def fetch_live_data(symbol):
94
  rs = gain / loss
95
  data['rsi'] = 100 - (100 / (1 + rs))
96
 
 
 
 
 
97
  # MACD (12, 26, 9)
98
  exp1 = data['close'].ewm(span=12, adjust=False).mean()
99
  exp2 = data['close'].ewm(span=26, adjust=False).mean()
@@ -114,6 +118,7 @@ def fetch_live_data(symbol):
114
  "sma_50": float(latest['sma_50']),
115
  "rsi": float(latest['rsi']),
116
  "macd": float(latest['macd']),
 
117
  "is_mock": False
118
  }
119
 
@@ -160,41 +165,119 @@ if st.sidebar.button("πŸ”„ Refresh Data"):
160
  with st.spinner(f"Fetching Live Data for {symbol}..."):
161
  data = fetch_live_data(symbol)
162
 
163
- # 2. visual Header
164
- col_head1, col_head2, col_head3 = st.columns(3)
165
- with col_head1:
166
- st.metric("Current Price", f"${data['price']:.2f}", f"{data['change']:.2f}%")
167
- with col_head2:
168
- st.metric("RSI (Momentum)", f"{data['rsi']:.1f}", "Overbought" if data['rsi']>70 else "Oversold" if data['rsi']<30 else "Neutral", delta_color="off")
169
- with col_head3:
170
- source = "πŸ”΄ Mock Data (Check API Key)" if data['is_mock'] else "🟒 Live Alpha Vantage Data"
171
- st.caption(f"Data Source: {source}")
172
- st.caption(f"Last Updated: {datetime.now().strftime('%H:%M:%S')}")
173
-
174
- # 3. AI Prediction
175
- st.markdown("---")
176
- st.subheader(f"πŸ€– AI Analysis for {symbol}")
 
 
177
 
178
- features = np.array([[data['sma_20'], data['sma_50'], data['rsi'], data['macd']]])
179
- models = load_models_local(symbol)
180
 
181
- if models:
182
- col_pred1, col_pred2 = st.columns(2)
183
 
184
- # Regression
185
- pred_price = models['regression'].predict(features)[0]
186
 
187
- # Classification
188
- pred_direction_prob = models['classification'].predict_proba(features)[0]
189
- direction = "UP πŸš€" if pred_direction_prob[1] > 0.5 else "DOWN πŸ”»"
190
- confidence = max(pred_direction_prob)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
 
192
- with col_pred1:
193
- st.info(f"**Predicted Direction:** {direction}")
194
- st.progress(float(confidence), text=f"Confidence: {confidence*100:.1f}%")
195
 
196
- with col_pred2:
197
- st.success(f"**Target Price (Next Close):** ${pred_price:.2f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
 
199
  # --- Sidebar Notification ---
200
  st.sidebar.markdown("---")
 
32
  try:
33
  models['regression'] = joblib.load(f"{model_path}/regression_model.pkl")
34
  models['classification'] = joblib.load(f"{model_path}/classification_model.pkl")
35
+ models['clustering'] = joblib.load(f"{model_path}/clustering_model.pkl")
36
  return models
37
  except Exception as e:
38
  # Fallback to AAPL if specific model missing (for robustness)
 
94
  rs = gain / loss
95
  data['rsi'] = 100 - (100 / (1 + rs))
96
 
97
+ # Volatility (20-day std dev of Returns)
98
+ data['returns'] = data['close'].pct_change()
99
+ data['volatility'] = data['returns'].rolling(window=20).std()
100
+
101
  # MACD (12, 26, 9)
102
  exp1 = data['close'].ewm(span=12, adjust=False).mean()
103
  exp2 = data['close'].ewm(span=26, adjust=False).mean()
 
118
  "sma_50": float(latest['sma_50']),
119
  "rsi": float(latest['rsi']),
120
  "macd": float(latest['macd']),
121
+ "volatility": float(latest['volatility']) if not np.isnan(latest['volatility']) else 0.0,
122
  "is_mock": False
123
  }
124
 
 
165
  with st.spinner(f"Fetching Live Data for {symbol}..."):
166
  data = fetch_live_data(symbol)
167
 
168
+ # --- Layout: Tabs ---
169
+ tab1, tab2, tab3 = st.tabs(["πŸš€ Dashboard", "🧠 Deep Dive", "πŸ“Š Raw Data"])
170
+
171
+ # === TAB 1: DASHBOARD ===
172
+ with tab1:
173
+ # A. Header Metrics
174
+ col_head1, col_head2, col_head3, col_head4 = st.columns(4)
175
+ with col_head1:
176
+ st.metric("Current Price", f"${data['price']:.2f}", f"{data['change']:.2f}%")
177
+ with col_head2:
178
+ st.metric("RSI (Momentum)", f"{data['rsi']:.1f}", "Overbought" if data['rsi']>70 else "Oversold" if data['rsi']<30 else "Neutral", delta_color="off")
179
+ with col_head3:
180
+ st.metric("Volatility", f"{data.get('volatility', 0):.4f}", help="20-Day Std Dev of Returns")
181
+ with col_head4:
182
+ source = "πŸ”΄ Mock" if data['is_mock'] else "🟒 Live"
183
+ st.metric("Data Source", source)
184
 
185
+ st.markdown("---")
 
186
 
187
+ # B. AI Prediction Section
188
+ st.subheader(f"πŸ€– AI Prediction for {symbol}")
189
 
190
+ features = np.array([[data['sma_20'], data['sma_50'], data['rsi'], data['macd']]])
191
+ models = load_models_local(symbol)
192
 
193
+ if models:
194
+ col_pred1, col_pred2 = st.columns(2)
195
+
196
+ # Regression
197
+ pred_price = models['regression'].predict(features)[0]
198
+
199
+ # Classification
200
+ pred_direction_prob = models['classification'].predict_proba(features)[0]
201
+ direction = "UP πŸš€" if pred_direction_prob[1] > 0.5 else "DOWN πŸ”»"
202
+ confidence = max(pred_direction_prob)
203
+
204
+ with col_pred1:
205
+ st.info(f"**Predicted Direction:** {direction}")
206
+ st.progress(float(confidence), text=f"Confidence: {confidence*100:.1f}%")
207
+
208
+ with col_pred2:
209
+ st.success(f"**Target Price (Next Close):** ${pred_price:.2f}")
210
+
211
+ # C. Price Chart (Candlestick)
212
+ st.subheader("πŸ“‰ Price History")
213
+ # Note: fetch_live_data only returns the LAST row's calculated metrics + latest meta,
214
+ # but for charts we need the full dataframe.
215
+ # To fix this without breaking the cache, we'll fetch full history purely for charting here.
216
+ # Ideally, fetch_live_data should return the full DF, but let's do a quick fetch for charts:
217
+ try:
218
+ if not data['is_mock'] and ALPHA_VANTAGE_KEY:
219
+ ts = TimeSeries(key=ALPHA_VANTAGE_KEY, output_format='pandas')
220
+ hist_data, _ = ts.get_daily(symbol=symbol, outputsize='compact')
221
+ hist_data = hist_data.sort_index()
222
+ hist_data.columns = ['open', 'high', 'low', 'close', 'volume']
223
+
224
+ fig = go.Figure(data=[go.Candlestick(x=hist_data.index,
225
+ open=hist_data['open'],
226
+ high=hist_data['high'],
227
+ low=hist_data['low'],
228
+ close=hist_data['close'])])
229
+ fig.update_layout(title=f"{symbol} Daily Price", xaxis_title="Date", yaxis_title="Price", template="plotly_dark")
230
+ st.plotly_chart(fig, use_container_width=True)
231
+ else:
232
+ st.warning("Charts unavailable in Mock Data mode (Add API Key to see charts).")
233
+ except Exception as e:
234
+ st.error(f"Could not load chart: {e}")
235
+
236
+
237
+ # === TAB 2: DEEP DIVE (Unsupervised & Technicals) ===
238
+ with tab2:
239
+ st.header("🧠 Advanced Analysis")
240
 
241
+ # Clustering / Market Regime
242
+ if models and 'clustering' in models:
243
+ st.subheader("🧐 Market Regime (Clustering)")
244
 
245
+ clus_features = np.array([[data.get('volatility', 0), data['rsi']]])
246
+ cluster_id = models['clustering'].predict(clus_features)[0]
247
+
248
+ regime_labels = {
249
+ 0: "Regime 0 (Watch) πŸ‘οΈ",
250
+ 1: "Regime 1 (Accumulate) πŸ’°",
251
+ 2: "Regime 2 (Risk/Volatile) ⚠️"
252
+ }
253
+ regime_name = regime_labels.get(cluster_id, f"Cluster {cluster_id}")
254
+
255
+ st.info(f"**Current State:** {regime_name}")
256
+ st.caption("We use K-Means Clustering on Volatility & RSI to identify the market state.")
257
+
258
+ st.markdown("---")
259
+
260
+ # Technical Indicators Chart
261
+ st.subheader("πŸ“Š Technical Indicators")
262
+ if not data['is_mock'] and 'hist_data' in locals():
263
+ # Calculate Indicators on history for plotting
264
+ # Simple RSI calculation for plotting
265
+ delta = hist_data['close'].diff()
266
+ gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
267
+ loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
268
+ rs = gain / loss
269
+ hist_data['rsi_plot'] = 100 - (100 / (1 + rs))
270
+
271
+ fig_rsi = px.line(hist_data, x=hist_data.index, y='rsi_plot', title="Relative Strength Index (14)")
272
+ fig_rsi.add_hline(y=70, line_dash="dash", line_color="red")
273
+ fig_rsi.add_hline(y=30, line_dash="dash", line_color="green")
274
+ fig_rsi.update_layout(template="plotly_dark")
275
+ st.plotly_chart(fig_rsi, use_container_width=True)
276
+
277
+ # === TAB 3: RAW DATA ===
278
+ with tab3:
279
+ st.subheader("Raw Data View")
280
+ st.json(data)
281
 
282
  # --- Sidebar Notification ---
283
  st.sidebar.markdown("---")