File size: 11,701 Bytes
068b3e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce3d808
068b3e6
 
 
 
 
 
 
 
ce3d808
 
 
068b3e6
 
ce3d808
 
5f3dfad
068b3e6
 
ce3d808
 
 
 
 
 
 
 
068b3e6
 
15ebc49
 
068b3e6
15ebc49
068b3e6
 
 
 
15ebc49
0504309
15ebc49
 
 
 
 
a102391
 
 
068b3e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f3dfad
 
 
 
068b3e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f3dfad
068b3e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0504309
 
 
60823cc
068b3e6
 
 
 
 
 
5f3dfad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
068b3e6
5f3dfad
068b3e6
5f3dfad
 
068b3e6
5f3dfad
 
068b3e6
5f3dfad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
068b3e6
5f3dfad
 
 
068b3e6
5f3dfad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
068b3e6
a4a7fe2
fc6f7dc
a4a7fe2
fc6f7dc
 
 
a4a7fe2
 
15ebc49
fc6f7dc
 
a4a7fe2
fc6f7dc
 
15ebc49
dd28e45
068b3e6
a4a7fe2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
import streamlit as st
import pandas as pd
import numpy as np
import requests
import joblib
import os
import plotly.express as px
import plotly.graph_objects as go
from alpha_vantage.techindicators import TechIndicators
from alpha_vantage.timeseries import TimeSeries
from datetime import datetime
from dotenv import load_dotenv

# Load variables from a local .env file (if one exists) into the process
# environment so the os.getenv() calls below can see them during local runs.
load_dotenv()

# --- Config ---
st.set_page_config(page_title="Stock Prediction System", layout="wide", page_icon="📈")

# --- Secrets ---
# Both values come from environment variables only (a local .env via
# load_dotenv above, or variables configured on the host).
ALPHA_VANTAGE_KEY = os.getenv("ALPHA_VANTAGE_API_KEY")
# Not read anywhere else in this file; notify_discord appears to resolve the
# webhook on its own. NOTE(review): confirm before removing.
WEBHOOK_URL = os.getenv("WEBHOOK_URL")

# --- Helper Functions ---
@st.cache_resource
def load_models_local(symbol):
    """Load the trained models for ``symbol`` from ``models/<symbol>/``.

    Three artifacts are read with ``joblib.load``: ``regression_model.pkl``,
    ``classification_model.pkl`` and ``clustering_model.pkl``. If any of them
    cannot be loaded and ``symbol`` is not ``"AAPL"``, the AAPL models are
    returned instead so the dashboard can still render a prediction.

    Args:
        symbol: Ticker whose model directory should be loaded.

    Returns:
        A dict with keys ``'regression'``, ``'classification'`` and
        ``'clustering'``, or ``None`` when neither the requested models nor
        the AAPL fallback could be loaded (an ``st.error`` is displayed).
    """
    model_dir = os.path.join("models", symbol)
    try:
        # NOTE: joblib.load unpickles arbitrary objects -- only point it at
        # trusted, locally produced model files.
        return {
            kind: joblib.load(os.path.join(model_dir, f"{kind}_model.pkl"))
            for kind in ("regression", "classification", "clustering")
        }
    except Exception as e:
        if symbol != "AAPL":
            # Fallback for robustness: reuse the AAPL models (presumably
            # trained on AAPL data) when this ticker has none of its own.
            # The recursive call reports its own failure and returns None.
            fallback = load_models_local("AAPL")
            if fallback is not None:
                return fallback
        st.error(f"Failed to load models for {symbol}: {e}")
        return None

from src.orchestration.notifications import notify_discord

def send_discord_notification(symbol, price, change_percent, prediction_dir):
    """Format a stock update and hand it to ``notify_discord``.

    Args:
        symbol: Ticker symbol shown in the message.
        price: Latest price, rendered with two decimals.
        change_percent: Daily change in percent. Positive values get 🚀;
            zero or negative values get 🔻.
        prediction_dir: Direction label from the dashboard (for example
            ``"UP 🚀"`` or ``"DOWN 🔻"``). Labels containing neither ``"UP"``
            nor ``"DOWN"`` (such as ``"N/A"``) get a neutral marker.

    Returns:
        Whatever ``notify_discord`` returns. NOTE(review): the sidebar caller
        unpacks it as ``(success, status_msg)`` -- confirm against the
        notifications module.
    """
    emoji = "🚀" if change_percent > 0 else "🔻"
    if "UP" in prediction_dir:
        pred_emoji = "🟢"
    elif "DOWN" in prediction_dir:
        pred_emoji = "🔴"
    else:
        # No prediction available: don't imply a bearish call.
        pred_emoji = "⚪"

    message = (f"**Stock Update** 🕒\n"
               f"**{symbol}**: ${price:.2f} {emoji} ({change_percent:.2f}%)\n"
               f"**AI Prediction:** {prediction_dir} {pred_emoji}")

    # Delivery (webhook lookup, error reporting) is delegated to the shared
    # notifications module.
    return notify_discord(message)

def _add_indicators(prices):
    """Return a copy of ``prices`` with technical-indicator columns added.

    ``prices`` must be sorted oldest-first and have a ``'close'`` column.
    Added columns:
    - ``sma_20`` and ``sma_50``: simple moving averages.
    - ``rsi``: 14-period RSI, rolling-mean variant rather than Wilder-smoothed.
    - ``returns``: daily percentage returns.
    - ``volatility``: 20-period standard deviation of ``returns``.
    - ``macd``: EMA12 minus EMA26, with the signal line omitted.
    """
    df = prices.copy()
    close = df['close']

    df['sma_20'] = close.rolling(window=20).mean()
    df['sma_50'] = close.rolling(window=50).mean()

    delta = close.diff()
    gain = delta.where(delta > 0, 0).rolling(window=14).mean()
    loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
    rs = gain / loss  # loss == 0 with gain > 0 gives inf, i.e. RSI 100
    rsi = 100 - (100 / (1 + rs))
    # A completely flat window has gain == loss == 0 (0/0 -> NaN); treat it
    # as neutral instead of passing NaN on to the models.
    df['rsi'] = rsi.mask((gain == 0) & (loss == 0), 50.0)

    df['returns'] = close.pct_change()
    df['volatility'] = df['returns'].rolling(window=20).std()

    ema_12 = close.ewm(span=12, adjust=False).mean()
    ema_26 = close.ewm(span=26, adjust=False).mean()
    df['macd'] = ema_12 - ema_26
    return df


@st.cache_data(ttl=3600)  # cache for 1 hour to stay within API limits
def fetch_live_data(symbol):
    """Fetch daily prices for ``symbol`` and return the latest indicator values.

    Only Alpha Vantage's daily time series is requested. SMA, RSI,
    volatility and MACD are computed locally by ``_add_indicators``.

    Falls back to ``get_mock_data`` (with an ``st.warning``) in three cases:
    - no API key is configured;
    - the request or parsing fails;
    - the history is too short to compute every model feature.

    Returns:
        dict with these keys:
        - ``price``
        - ``change`` (daily % change)
        - ``sma_20``, ``sma_50``
        - ``rsi``
        - ``macd``
        - ``volatility`` (0.0 when undefined)
        - ``is_mock``
    """
    if not ALPHA_VANTAGE_KEY:
        st.warning("⚠️ ALPHA_VANTAGE_API_KEY not found. Using Mock Data.")
        return get_mock_data(symbol)

    try:
        ts = TimeSeries(key=ALPHA_VANTAGE_KEY, output_format='pandas')
        # 'compact' = latest 100 daily points, enough for the 50-day SMA.
        raw, _ = ts.get_daily(symbol=symbol, outputsize='compact')
        raw = raw.sort_index()  # ensure chronological order
        raw.columns = ['open', 'high', 'low', 'close', 'volume']

        data = _add_indicators(raw)
        latest = data.iloc[-1]
        prev = data.iloc[-2]

        # The models cannot accept NaN features; a short history leaves the
        # rolling windows undefined, so treat that as a failed fetch.
        missing = [name for name in ('sma_20', 'sma_50', 'rsi', 'macd')
                   if pd.isna(latest[name])]
        if missing:
            raise ValueError(f"insufficient history to compute {', '.join(missing)}")

        change_percent = (latest['close'] - prev['close']) / prev['close'] * 100
        volatility = latest['volatility']

        return {
            "price": float(latest['close']),
            "change": float(change_percent),
            "sma_20": float(latest['sma_20']),
            "sma_50": float(latest['sma_50']),
            "rsi": float(latest['rsi']),
            "macd": float(latest['macd']),
            "volatility": 0.0 if pd.isna(volatility) else float(volatility),
            "is_mock": False,
        }

    except Exception as e:
        # Top-level boundary: log for the server console, degrade to mock data.
        print(f"Fetch failed: {e}")
        st.warning(f"Could not fetch data for {symbol} (API Limit?). Showing Mock Data.")
        return get_mock_data(symbol)

def get_mock_data(symbol):
    """Return randomized placeholder market data for ``symbol``.

    Used when no API key is configured or a live fetch fails. The dict has
    the same keys as the live result (``price``, ``change``, ``sma_20``,
    ``sma_50``, ``rsi``, ``macd``, ``volatility``, ``is_mock``), so callers
    don't need to special-case missing fields.

    Args:
        symbol: Ticker; unknown tickers use a base price of 100.

    Returns:
        dict of mock values with ``is_mock`` set to ``True``.
    """
    base_price = {"AAPL": 150, "GOOGL": 2800, "MSFT": 300, "AMZN": 3400, "TSLA": 900, "NVDA": 400}
    price = base_price.get(symbol, 100) + np.random.uniform(-5, 5)
    return {
        "price": price,
        "change": np.random.uniform(-2, 2),
        "sma_20": price * 0.95,
        "sma_50": price * 0.90,
        "rsi": np.random.uniform(30, 70),
        "macd": np.random.uniform(-1, 1),
        # Plausible daily-return standard deviation; the live path reports
        # its 20-day return std under this same key.
        "volatility": np.random.uniform(0.01, 0.03),
        "is_mock": True,
    }

# --- UI Layout ---
st.title("📈 AI Stock Prediction System")

# Sidebar
st.sidebar.header("Control Panel")
available_stocks = ["AAPL", "GOOGL", "MSFT", "AMZN", "TSLA", "NVDA"]
symbol = st.sidebar.selectbox("Select Stock", available_stocks)

if st.sidebar.button("🔄 Refresh Data"):
    # Drop every st.cache_data entry (including cached API responses), then
    # rerun the script so the page is rebuilt from fresh data.
    st.cache_data.clear()
    st.rerun()

# --- Main Logic ---

# Fetch data (fetch_live_data is wrapped in st.cache_data).
with st.spinner(f"Fetching Live Data for {symbol}..."):
    data = fetch_live_data(symbol)

# --- Layout: Tabs ---
tab1, tab2, tab3 = st.tabs(["🚀 Dashboard", "🧠 Deep Dive", "📊 Raw Data"])

@st.cache_data(ttl=3600)
def _fetch_price_history(symbol):
    """Return the compact daily OHLCV history for ``symbol``, oldest first.

    Cached for an hour (matching the live-data cache), so a script rerun
    triggered by any widget interaction does not issue another Alpha Vantage
    request.
    """
    ts = TimeSeries(key=ALPHA_VANTAGE_KEY, output_format='pandas')
    history, _ = ts.get_daily(symbol=symbol, outputsize='compact')
    history = history.sort_index()
    history.columns = ['open', 'high', 'low', 'close', 'volume']
    return history


# === TAB 1: DASHBOARD ===
with tab1:
    # A. Header Metrics
    col_head1, col_head2, col_head3, col_head4 = st.columns(4)
    with col_head1:
        st.metric("Current Price", f"${data['price']:.2f}", f"{data['change']:.2f}%")
    with col_head2:
        st.metric(
            "RSI (Momentum)",
            f"{data['rsi']:.1f}",
            "Overbought" if data['rsi'] > 70 else "Oversold" if data['rsi'] < 30 else "Neutral",
            delta_color="off",
        )
    with col_head3:
        # Mock data may not carry a volatility value; show 0 in that case.
        st.metric("Volatility", f"{data.get('volatility', 0):.4f}", help="20-Day Std Dev of Returns")
    with col_head4:
        source = "🔴 Mock" if data['is_mock'] else "🟢 Live"
        st.metric("Data Source", source)

    st.markdown("---")

    # B. AI Prediction Section
    st.subheader(f"🤖 AI Prediction for {symbol}")

    # NOTE(review): assumes this column order matches the training features.
    features = np.array([[data['sma_20'], data['sma_50'], data['rsi'], data['macd']]])
    models = load_models_local(symbol)

    if models:
        try:
            pred_price = models['regression'].predict(features)[0]
            pred_direction_prob = models['classification'].predict_proba(features)[0]
        except Exception as e:
            # Invalid input (e.g. NaN or wrong feature count) should not take
            # down the rest of the page.
            st.error(f"Prediction failed: {e}")
        else:
            # NOTE(review): column 1 of predict_proba is treated as P(up).
            # `direction` stays unbound on failure; the sidebar button relies on that.
            direction = "UP 🚀" if pred_direction_prob[1] > 0.5 else "DOWN 🔻"
            confidence = max(pred_direction_prob)

            col_pred1, col_pred2 = st.columns(2)
            with col_pred1:
                st.info(f"**Predicted Direction:** {direction}")
                st.progress(float(confidence), text=f"Confidence: {confidence*100:.1f}%")
            with col_pred2:
                st.success(f"**Target Price (Next Close):** ${pred_price:.2f}")

    # C. Price Chart (Candlestick)
    st.subheader("📉 Price History")
    # fetch_live_data only returns the latest indicator values, so the full
    # history for charting comes from the separately cached helper above.
    # `hist_data` is also read by the Deep Dive tab.
    try:
        if not data['is_mock'] and ALPHA_VANTAGE_KEY:
            hist_data = _fetch_price_history(symbol)

            fig = go.Figure(data=[go.Candlestick(
                x=hist_data.index,
                open=hist_data['open'],
                high=hist_data['high'],
                low=hist_data['low'],
                close=hist_data['close'],
            )])
            fig.update_layout(title=f"{symbol} Daily Price", xaxis_title="Date", yaxis_title="Price", template="plotly_dark")
            st.plotly_chart(fig, use_container_width=True)
        else:
            st.warning("Charts unavailable in Mock Data mode (Add API Key to see charts).")
    except Exception as e:
        st.error(f"Could not load chart: {e}")


# === TAB 2: DEEP DIVE (Unsupervised & Technicals) ===
with tab2:
    st.header("🧠 Advanced Analysis")

    # Clustering / Market Regime
    if models and 'clustering' in models:
        st.subheader("🧐 Market Regime (Clustering)")

        # NOTE(review): assumes the clustering model was fit on (volatility, RSI) in this order.
        clus_features = np.array([[data.get('volatility', 0), data['rsi']]])
        try:
            cluster_id = int(models['clustering'].predict(clus_features)[0])
        except Exception as e:
            st.error(f"Regime detection failed: {e}")
        else:
            regime_labels = {
                0: "Regime 0 (Watch) 👁️",
                1: "Regime 1 (Accumulate) 💰",
                2: "Regime 2 (Risk/Volatile) ⚠️"
            }
            regime_name = regime_labels.get(cluster_id, f"Cluster {cluster_id}")

            st.info(f"**Current State:** {regime_name}")
            st.caption("We use K-Means Clustering on Volatility & RSI to identify the market state.")

    st.markdown("---")

    # Technical Indicators Chart
    st.subheader("📊 Technical Indicators")
    # `hist_data` is only bound at module level when the Dashboard tab
    # successfully loaded live price history during this run.
    if data['is_mock']:
        st.info("Technical indicator charts require live data.")
    elif 'hist_data' not in globals():
        st.info("Price history is unavailable, so indicators cannot be plotted.")
    else:
        # 14-period RSI over the whole history (rolling-mean variant).
        delta = hist_data['close'].diff()
        gain = delta.where(delta > 0, 0).rolling(window=14).mean()
        loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
        rsi_series = 100 - (100 / (1 + gain / loss))
        # A flat window (no gains, no losses) is 0/0; plot it as neutral 50.
        hist_data['rsi_plot'] = rsi_series.mask((gain == 0) & (loss == 0), 50.0)

        fig_rsi = px.line(hist_data, x=hist_data.index, y='rsi_plot', title="Relative Strength Index (14)")
        fig_rsi.add_hline(y=70, line_dash="dash", line_color="red")
        fig_rsi.add_hline(y=30, line_dash="dash", line_color="green")
        fig_rsi.update_layout(template="plotly_dark")
        st.plotly_chart(fig_rsi, use_container_width=True)

# === TAB 3: RAW DATA ===
# Render the fetched data dict (live or mock) unchanged, for inspection.
tab3.subheader("Raw Data View")
tab3.json(data)

# --- Sidebar Notification ---
st.sidebar.markdown("---")
if st.sidebar.button("🔔 Send Discord Update"):
    # `direction` is only bound when the Dashboard tab produced a prediction
    # during this run; otherwise send a neutral placeholder.
    prediction_label = globals().get('direction', "N/A")

    # NOTE(review): assumes send_discord_notification returns (success, status_msg).
    success, status_msg = send_discord_notification(
        symbol,
        data.get('price', 0.0),
        data.get('change', 0.0),
        prediction_label,
    )
    if success:
        st.sidebar.success("Sent!")
    else:
        st.sidebar.error(f"Failed: {status_msg}")

# --- Footer ---
st.markdown("---")
st.caption("AI Stock Prediction System | Deployed on Hugging Face Spaces")