import functools
import io
import warnings
from datetime import datetime

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import timesfm
import yfinance as yf
from PIL import Image
from prophet import Prophet

# NOTE(review): io, warnings, datetime, plt, px and Image are not used by the code
# below; they are left in place so nothing that relies on them breaks.

# Shared dark-theme styling used by the charts in this app.
_BG_COLOR = "#111111"
_GRID_COLOR = "rgba(255,255,255,0.1)"
_LEGEND_STYLE = dict(bgcolor="rgba(0,0,0,0.8)", bordercolor="white", borderwidth=1)

# Number of days both forecasting models predict past the last observation.
FORECAST_HORIZON_DAYS = 30

# Per-indicator signal columns written by generate_trading_signals
# (+1 = buy, -1 = sell, 0 = no signal).
SIGNAL_COLUMNS = [
    "MACD_Signal",
    "RSI_Signal",
    "BB_Signal",
    "Stochastic_Signal",
    "CMF_Signal",
    "CCI_Signal",
]


def _apply_dark_layout(fig, title, **layout_overrides):
    """Apply the app's common dark styling to ``fig`` and return it.

    ``layout_overrides`` are passed to ``fig.update_layout`` after the defaults,
    so they take precedence (nested dicts such as ``legend``/``margin`` merge).
    """
    fig.update_layout(
        title=title,
        plot_bgcolor=_BG_COLOR,
        paper_bgcolor=_BG_COLOR,
        font=dict(color="white", size=12),
        margin=dict(l=40, r=40, t=50, b=40),
        legend=_LEGEND_STYLE,
    )
    fig.update_xaxes(showgrid=True, gridcolor=_GRID_COLOR)
    fig.update_yaxes(showgrid=True, gridcolor=_GRID_COLOR)
    if layout_overrides:
        fig.update_layout(**layout_overrides)
    return fig


class StockDataFetcher:
    """Handles fetching and preprocessing stock data."""

    # Columns returned by fetch_stock_data, in this order.
    COLUMNS = ["Close", "High", "Low", "Open", "Volume"]

    @staticmethod
    def fetch_stock_data(ticker, start_date, end_date):
        """Download price data for ``ticker`` and normalise its columns.

        Args:
            ticker: Symbol passed to ``yf.download``.
            start_date: Start date string (``YYYY-MM-DD``) passed to ``yf.download``.
            end_date: End date string (``YYYY-MM-DD``) passed to ``yf.download``.

        Returns:
            DataFrame whose index is named ``"Date"`` and whose columns are
            ``Close, High, Low, Open, Volume``.

        Raises:
            ValueError: If no rows are returned or a required column is missing.
        """
        stock_data = yf.download(ticker, start=start_date, end=end_date)
        if stock_data is None or stock_data.empty:
            raise ValueError(
                f"No price data returned for {ticker!r} between {start_date} and {end_date}."
            )

        # MultiIndex columns: keep level 0 (the price field, e.g. "Close").
        if isinstance(stock_data.columns, pd.MultiIndex):
            stock_data.columns = stock_data.columns.droplevel(level=1)

        # Select columns by name rather than assigning names positionally, so a
        # different column order or an extra column (e.g. "Adj Close") cannot
        # mislabel the data.
        missing = [col for col in StockDataFetcher.COLUMNS if col not in stock_data.columns]
        if missing:
            raise ValueError(f"Downloaded data for {ticker!r} is missing columns: {missing}")

        prices = stock_data[StockDataFetcher.COLUMNS].copy()
        # The forecasting functions call reset_index() and expect a 'Date' column.
        prices.index.name = "Date"
        return prices


@functools.lru_cache(maxsize=1)
def _load_timesfm_model():
    """Build the TimesFM model once and reuse it across requests.

    Constructing the model loads the checkpoint from the Hugging Face repo
    below, so it is cached instead of being rebuilt on every button click.
    A failed load is not cached and will be retried on the next call.
    """
    return timesfm.TimesFm(
        hparams=timesfm.TimesFmHparams(
            backend="pytorch",
            per_core_batch_size=32,
            horizon_len=FORECAST_HORIZON_DAYS,
            input_patch_len=32,
            output_patch_len=128,
            num_layers=50,
            model_dims=1280,
            use_positional_embedding=False,
        ),
        checkpoint=timesfm.TimesFmCheckpoint(
            huggingface_repo_id="google/timesfm-2.0-500m-pytorch"
        ),
    )


def timesfm_forecast(ticker, start_date, end_date):
    """Forecast closing prices with TimesFM and return a Plotly figure.

    Raises:
        gr.Error: Wraps any failure so Gradio shows it to the user (a plain
            string cannot be rendered by a ``gr.Plot`` output).
    """
    try:
        stock_data = StockDataFetcher.fetch_stock_data(ticker, start_date, end_date).reset_index()

        # TimesFM's dataframe API expects ds / y / unique_id columns.
        df = stock_data[["Date", "Close"]].rename(columns={"Date": "ds", "Close": "y"})
        df["ds"] = pd.to_datetime(df["ds"])
        df["unique_id"] = ticker

        forecast_df = _load_timesfm_model().forecast_on_df(
            inputs=df,
            freq="D",  # Daily frequency
            value_name="y",
            num_jobs=-1,
        )
        forecast_df = forecast_df.rename(columns={"timesfm": "forecast"})

        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=df["ds"], y=df["y"], mode="lines", name="Actual Prices",
            line=dict(color="#00FFFF", width=2),
        ))
        fig.add_trace(go.Scatter(
            x=forecast_df["ds"], y=forecast_df["forecast"], mode="lines",
            name="Forecasted Prices", line=dict(color="#FF00FF", width=2, dash="dash"),
        ))
        return _apply_dark_layout(
            fig,
            f"{ticker} Stock Price Forecast (TimesFM)",
            xaxis_title="Date",
            yaxis_title="Price",
            template="plotly_dark",
            hovermode="x unified",  # Show all values on hover
        )
    except Exception as e:
        raise gr.Error(f"Error: {e}") from e


def _prophet_components_figure(forecast, ticker):
    """Plot Prophet's seasonal components over the last 365 rows of ``forecast``.

    Prophet only adds a component column (``yearly``/``weekly``) when it fitted
    that seasonality, so each is plotted only if present; an annotation is shown
    when neither exists.
    """
    recent = forecast.tail(365)
    fig = go.Figure()
    components = (
        ("yearly", "Yearly Pattern", "#4ECDC4"),
        ("weekly", "Weekly Pattern", "#C678DD"),
    )
    for column, label, color in components:
        if column in recent.columns:
            fig.add_trace(go.Scatter(
                x=recent["ds"], y=recent[column], mode="lines", name=label,
                line=dict(color=color, width=2),
            ))
    if not fig.data:
        fig.add_annotation(
            text="No seasonal components were fitted for this date range",
            xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False,
            font=dict(color="white"),
        )
    return _apply_dark_layout(
        fig, f"{ticker} Forecast Components", xaxis_title="Date", yaxis_title="Value"
    )


def prophet_forecast(ticker, start_date, end_date):
    """Fit Prophet to closing prices and return three Plotly figures.

    Returns:
        ``(price_and_trend, forecast_with_interval, components)``.

    Raises:
        gr.Error: Wraps any failure so Gradio shows it to the user.
    """
    try:
        df = StockDataFetcher.fetch_stock_data(ticker, start_date, end_date)
        history = df.reset_index()[["Date", "Close"]].rename(columns={"Date": "ds", "Close": "y"})

        model = Prophet()
        model.fit(history)
        future = model.make_future_dataframe(periods=FORECAST_HORIZON_DAYS, freq="D")
        forecast = model.predict(future)

        # Figure 1: actual closing prices with the fitted trend.
        fig1 = go.Figure()
        fig1.add_trace(go.Scatter(
            x=history["ds"], y=history["y"], mode="lines", name="Actual Price",
            line=dict(color="#36D7B7", width=2),
        ))
        fig1.add_trace(go.Scatter(
            x=forecast["ds"], y=forecast["trend"], mode="lines", name="Trend",
            line=dict(color="#FF6B6B", width=2),
        ))
        _apply_dark_layout(fig1, f"{ticker} Price and Trend")

        # Figure 2: the last 40 rows (horizon plus a little history) with the
        # uncertainty band drawn as a closed polygon (upper forward, lower back).
        tail = forecast[["ds", "yhat", "yhat_lower", "yhat_upper"]].tail(40)
        fig2 = go.Figure()
        fig2.add_trace(go.Scatter(
            x=tail["ds"], y=tail["yhat"], mode="lines", name="Forecast",
            line=dict(color="#FF6B6B", width=2),
        ))
        fig2.add_trace(go.Scatter(
            x=tail["ds"].tolist() + tail["ds"].tolist()[::-1],
            y=tail["yhat_upper"].tolist() + tail["yhat_lower"].tolist()[::-1],
            fill="toself",
            fillcolor="rgba(78, 205, 196, 0.2)",
            line=dict(color="rgba(255,255,255,0)"),
            name="Confidence Interval",
        ))
        _apply_dark_layout(fig2, f"{ticker} {FORECAST_HORIZON_DAYS} Days Forecast (Prophet)")

        return fig1, fig2, _prophet_components_figure(forecast, ticker)
    except Exception as e:
        raise gr.Error(f"Error: {e}") from e


# Functions for technical analysis
def smooth_moving_average(series: pd.Series, window: int) -> pd.Series:
    """Smooth ``series`` with Wilder's recursive average.

    The first ``window`` outputs are seeded with the mean of the first
    ``window`` inputs; afterwards ``s[i] = (s[i-1] * (window - 1) + x[i]) / window``.
    Inputs that are too short (or a non-positive window) yield a constant
    series equal to the overall mean.
    """
    if len(series) < window or window <= 0:
        return pd.Series(series.mean(), index=series.index)
    values = series.to_numpy(dtype=float)
    smoothed = np.empty(len(values))
    smoothed[:window] = series.iloc[:window].mean()
    for i in range(window, len(values)):
        smoothed[i] = (smoothed[i - 1] * (window - 1) + values[i]) / window
    result = pd.Series(smoothed, index=series.index)
    return result.ffill().bfill().fillna(series.mean())


def calculate_rsi(close: pd.Series, window: int = 14) -> pd.Series:
    """Relative Strength Index (0-100) using Wilder smoothing.

    Returns a constant 50 when there are not more than ``window`` prices.
    When the average loss is zero the RSI is 100 if there were gains, and a
    neutral 50 if the price did not move at all (rather than a spurious 100).
    """
    if len(close) <= window:
        return pd.Series(50.0, index=close.index)
    delta = close.diff()
    gain = delta.where(delta > 0, 0.0)
    loss = -delta.where(delta < 0, 0.0)
    avg_gain = smooth_moving_average(gain, window).to_numpy()
    avg_loss = smooth_moving_average(loss, window).to_numpy()
    with np.errstate(divide="ignore", invalid="ignore"):
        rsi = 100.0 - (100.0 / (1.0 + avg_gain / avg_loss))
    rsi = np.where(avg_loss == 0, np.where(avg_gain > 0, 100.0, 50.0), rsi)
    return pd.Series(rsi, index=close.index).ffill().bfill().fillna(50.0)


def calculate_stochastic(high: pd.Series, low: pd.Series, close: pd.Series, k_window=14, d_window=3):
    """Stochastic oscillator: returns ``(%K, %D)``, both clipped to 0-100.

    A window with no high/low range reports a neutral 50 instead of 0, which
    would otherwise read as deeply oversold.
    """
    if len(close) < k_window:
        return pd.Series(50.0, index=close.index), pd.Series(50.0, index=close.index)
    lowest = low.rolling(k_window, min_periods=1).min()
    highest = high.rolling(k_window, min_periods=1).max()
    price_range = highest - lowest
    k_pct = ((close - lowest) / (price_range + 1e-10)) * 100
    k_pct = k_pct.clip(0, 100).mask(price_range <= 0, 50.0)
    d_pct = k_pct.rolling(d_window, min_periods=1).mean()
    return k_pct.ffill().bfill().fillna(50.0), d_pct.ffill().bfill().fillna(50.0)


def calculate_cci(high: pd.Series, low: pd.Series, close: pd.Series, window=20):
    """Commodity Channel Index over ``window`` periods (0 for too-short input)."""
    if len(close) < window:
        return pd.Series(0.0, index=close.index)
    typical_price = (high + low + close) / 3.0
    sma = typical_price.rolling(window, min_periods=1).mean()
    mean_deviation = (typical_price - sma).abs().rolling(window, min_periods=1).mean()
    cci = (typical_price - sma) / (0.015 * mean_deviation + 1e-10)
    return cci.ffill().bfill().fillna(0.0)


# --- Robust helper functions ---
def calculate_sma_robust(series: pd.Series, window: int) -> pd.Series:
    """Simple moving average with the warm-up period back-filled."""
    if len(series) < window or window <= 0:
        return pd.Series(series.mean(), index=series.index)
    return series.rolling(window=window, min_periods=window).mean().ffill().bfill().fillna(series.mean())


def calculate_ema_robust(series: pd.Series, span: int) -> pd.Series:
    """Exponential moving average with the warm-up period back-filled."""
    if len(series) < span or span <= 0:
        return pd.Series(series.mean(), index=series.index)
    return series.ewm(span=span, adjust=False, min_periods=span).mean().ffill().bfill().fillna(series.mean())


def calculate_macd_robust(close: pd.Series):
    """MACD (EMA12 - EMA26) and its 9-period EMA signal line."""
    ema12 = calculate_ema_robust(close, 12)
    ema26 = calculate_ema_robust(close, 26)
    macd_line = ema12 - ema26
    signal_line = calculate_ema_robust(macd_line, 9)
    return macd_line, signal_line


def calculate_bollinger_bands_robust(close: pd.Series, window=20, num_std=2.0):
    """Bollinger Bands: returns ``(middle, upper, lower)``.

    The warm-up period reuses the first full-window standard deviation, matching
    how the middle band is back-filled; filling it with ~0 instead would collapse
    the bands onto the SMA and flag almost every warm-up bar as a breakout.
    """
    if len(close) < window:
        mid = pd.Series(close.mean(), index=close.index)
        return mid, mid, mid
    sma = calculate_sma_robust(close, window)
    std = close.rolling(window=window, min_periods=window).std().ffill().bfill().fillna(1e-10)
    upper = sma + num_std * std
    lower = sma - num_std * std
    return sma.ffill().bfill(), upper.ffill().bfill(), lower.ffill().bfill()


# --- The core integration: generate_trading_signals ---
def generate_trading_signals(df: pd.DataFrame) -> pd.DataFrame:
    """
    Generates trading signals using strict thresholds to minimize false positives.

    Returns a copy of ``df`` with one column per indicator in SIGNAL_COLUMNS
    (+1 buy, -1 sell, 0 none) and ``Combined_Signal``, their row-wise sum.
    Only ``Close`` is required; ``High``/``Low`` and ``Volume`` are used when present.
    """
    df = df.copy()
    close = df["Close"]
    has_hl = all(col in df.columns for col in ["High", "Low"])
    has_vol = "Volume" in df.columns
    high = df["High"] if has_hl else close
    low = df["Low"] if has_hl else close
    volume = df["Volume"] if has_vol else pd.Series(1.0, index=close.index)

    rsi = calculate_rsi(close, window=14)
    stoch_k, stoch_d = calculate_stochastic(high, low, close, k_window=14, d_window=3)
    cci = calculate_cci(high, low, close, window=20)
    macd_line, macd_signal_line = calculate_macd_robust(close)
    _, bb_upper, bb_lower = calculate_bollinger_bands_robust(close, window=20, num_std=2.5)

    # Chaikin Money Flow (needs both a price range and volume).
    if has_hl and has_vol:
        mfv = ((close - low) - (high - close)) / (high - low + 1e-10) * volume
        cmf = mfv.rolling(window=20, min_periods=20).sum() / (
            volume.rolling(window=20, min_periods=20).sum() + 1e-10
        )
        cmf = cmf.ffill().bfill().fillna(0.0)
    else:
        cmf = pd.Series(0.0, index=close.index)

    for col in SIGNAL_COLUMNS:
        df[col] = 0

    # 1. MACD: a fresh crossover, away from zero, with a meaningful gap.
    # NOTE(review): the 0.5 / 0.8 thresholds are in price units, so their
    # strictness depends on the stock's price level.
    macd_bull = (
        (macd_line > macd_signal_line)
        & (macd_line.shift(1) <= macd_signal_line.shift(1))
        & (macd_line > 0.5)
        & ((macd_line - macd_signal_line) > 0.8)
    )
    macd_bear = (
        (macd_line < macd_signal_line)
        & (macd_line.shift(1) >= macd_signal_line.shift(1))
        & (macd_line < -0.5)
        & ((macd_signal_line - macd_line) > 0.8)
    )
    df.loc[macd_bull, "MACD_Signal"] = 1
    df.loc[macd_bear, "MACD_Signal"] = -1

    # 2. RSI
    df.loc[rsi < 15, "RSI_Signal"] = 1
    df.loc[rsi > 85, "RSI_Signal"] = -1

    # 3. Bollinger Bands
    df.loc[close <= bb_lower, "BB_Signal"] = 1
    df.loc[close >= bb_upper, "BB_Signal"] = -1

    # 4. Stochastic
    df.loc[(stoch_k < 5) & (stoch_d < 5), "Stochastic_Signal"] = 1
    df.loc[(stoch_k > 95) & (stoch_d > 95), "Stochastic_Signal"] = -1

    # 5. CMF
    df.loc[cmf < -0.5, "CMF_Signal"] = 1
    df.loc[cmf > 0.5, "CMF_Signal"] = -1

    # 6. CCI
    df.loc[cci < -250, "CCI_Signal"] = 1
    df.loc[cci > 250, "CCI_Signal"] = -1

    df["Combined_Signal"] = df[SIGNAL_COLUMNS].sum(axis=1)
    return df


def plot_combined_signals(df, ticker):
    """
    Creates a focused plot of JUST the combined signal strength.
    Bars are colored green for non-negative (buy) scores and red for negative (sell) scores.
    """
    colors = ["#2ECC71" if val >= 0 else "#E74C3C" for val in df["Combined_Signal"]]
    fig = go.Figure(go.Bar(
        x=df.index,
        y=df["Combined_Signal"],
        name="Signal Strength",
        marker_color=colors,
        hovertemplate="Date: %{x}<br>Signal: %{y}",
    ))
    fig.update_layout(
        title=f"{ticker}",
        template="plotly_dark",
        xaxis_title="Date",
        yaxis_title="Signal Strength Score",
        yaxis=dict(zeroline=True, zerolinewidth=2, zerolinecolor="gray"),
        showlegend=False,  # Not needed for a single trace
    )
    return fig


def plot_individual_signals(df, ticker, x_range=None):
    """Plot the closing price with a buy/sell marker for every indicator signal.

    Args:
        df: Output of ``generate_trading_signals`` (needs ``Close`` and SIGNAL_COLUMNS).
        ticker: Used as the chart title.
        x_range: Optional ``[start, end]`` for the x-axis, to share a range with other charts.
    """
    buy_color, sell_color = "#39FF14", "#FF073A"
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=df.index, y=df["Close"], mode="lines", name="Closing Price",
        line=dict(color="#36A2EB", width=2),
    ))

    sides = ((1, "Buy", "triangle-up", buy_color), (-1, "Sell", "triangle-down", sell_color))
    for signal in SIGNAL_COLUMNS:
        for value, side, symbol, color in sides:
            hits = df[df[signal] == value]
            if hits.empty:
                continue  # keep the legend limited to signals that actually fired
            fig.add_trace(go.Scatter(
                x=hits.index, y=hits["Close"], mode="markers",
                marker=dict(symbol=symbol, size=12, color=color),
                name=f"{signal} {side}",
            ))

    return _apply_dark_layout(
        fig,
        f"{ticker}",
        xaxis=dict(title="Date", range=x_range),
        yaxis=dict(title="Price", side="left"),
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
        margin=dict(t=80),  # Extra top margin for the horizontal legend
    )


def technical_analysis(ticker, start_date, end_date):
    """Compute indicator signals and return ``(combined_fig, individual_fig)``.

    Raises:
        gr.Error: Wraps any failure so Gradio shows it to the user.
    """
    try:
        df = generate_trading_signals(StockDataFetcher.fetch_stock_data(ticker, start_date, end_date))
        recent = df.tail(120)  # last 120 rows keep the charts readable
        return plot_combined_signals(recent, ticker), plot_individual_signals(recent, ticker)
    except Exception as e:
        raise gr.Error(f"Error: {e}") from e


# Custom CSS for better appearance
custom_css = """
.gradio-container { font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; }
.container { max-width: 1200px; margin: auto; }
button#analyze-btn { background-color: #003366; color: white; border: none; }
"""

# Create Gradio interface
with gr.Blocks(theme=gr.themes.Monochrome(), css=custom_css) as demo:
    gr.Markdown("# Advanced Stock Analysis & Forecasting App")
    gr.Markdown("Enter a stock ticker, start date, and end date to analyze and forecast stock prices.")

    with gr.Row():
        ticker_input = gr.Textbox(label="Enter Stock Ticker", value="NVDA")
        start_date_input = gr.Textbox(label="Enter Start Date (YYYY-MM-DD)", value="2025-01-01")
        end_date_input = gr.Textbox(label="Enter End Date (YYYY-MM-DD)", value="2027-01-01")
    shared_inputs = [ticker_input, start_date_input, end_date_input]

    with gr.Tabs():
        with gr.TabItem("TimesFM Forecast"):
            timesfm_button = gr.Button("Generate TimesFM Forecast")
            timesfm_plot = gr.Plot(label="TimesFM Stock Price Forecast")
            timesfm_button.click(timesfm_forecast, inputs=shared_inputs, outputs=timesfm_plot)

        with gr.TabItem("Prophet Forecast"):
            prophet_button = gr.Button("Generate Prophet Forecast")
            prophet_recent_plot = gr.Plot(label="Recent Stock Prices")
            prophet_forecast_plot = gr.Plot(label="Prophet 30-Day Forecast")
            prophet_components = gr.Plot(label="Forecast Components")
            prophet_button.click(
                prophet_forecast,
                inputs=shared_inputs,
                outputs=[prophet_recent_plot, prophet_forecast_plot, prophet_components],
            )

        with gr.TabItem("Technical Analysis"):
            analysis_button = gr.Button("Generate Technical Analysis")
            individual_signals = gr.Plot(label="Individual Trading Signals")
            combined_signals = gr.Plot(label="Combined Trading Signals")
            analysis_button.click(
                technical_analysis,
                inputs=shared_inputs,
                outputs=[combined_signals, individual_signals],
            )

# Launch the app only when run as a script, so importing this module has no side effects.
if __name__ == "__main__":
    demo.launch()