Spaces:
Build error
Build error
| import yfinance as yf | |
| import pandas as pd | |
| import numpy as np | |
| import plotly.graph_objects as go | |
| import gradio as gr | |
| import io | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| from datetime import datetime | |
| import plotly.express as px | |
| import warnings | |
| import timesfm | |
| from prophet import Prophet | |
class StockDataFetcher:
    """Fetch daily OHLCV price history through ``yfinance``."""

    # Columns the rest of this file reads, in the order the fetcher has
    # always returned them.
    REQUIRED_COLUMNS = ['Close', 'High', 'Low', 'Open', 'Volume']

    @staticmethod
    def _standardize_columns(stock_data):
        """Return a copy of ``stock_data`` holding only the required columns.

        Columns are selected *by name*.  Assigning a fixed list of names by
        position mislabels the data whenever the provider returns the columns
        in a different order, and raises a length-mismatch error when an extra
        column (for example 'Adj Close') is present.

        Args:
            stock_data: DataFrame as returned by ``yf.download``; columns may
                be flat or a two-level MultiIndex.

        Returns:
            A new DataFrame with columns ``REQUIRED_COLUMNS`` in that order.

        Raises:
            ValueError: if any required column is absent.
        """
        if isinstance(stock_data.columns, pd.MultiIndex):
            # Work on a copy so the caller's frame is not relabelled in place.
            stock_data = stock_data.copy()
            stock_data.columns = stock_data.columns.droplevel(level=1)

        required = StockDataFetcher.REQUIRED_COLUMNS
        missing = [name for name in required if name not in stock_data.columns]
        if missing:
            raise ValueError(
                f"Downloaded data is missing expected column(s): {missing}"
            )
        return stock_data.loc[:, required].copy()

    @staticmethod
    def fetch_stock_data(ticker, start_date, end_date):
        """Download and normalise price history for ``ticker``.

        Args:
            ticker: symbol passed to ``yf.download``.
            start_date: passed to ``yf.download`` as ``start``.
            end_date: passed to ``yf.download`` as ``end``.

        Returns:
            DataFrame with columns Close, High, Low, Open, Volume.

        Raises:
            ValueError: if nothing was downloaded or a required column is
                missing.
        """
        stock_data = yf.download(ticker, start=start_date, end=end_date)
        # Fail early with a readable message instead of letting an empty
        # frame surface later as an obscure indicator or plotting error.
        if stock_data is None or stock_data.empty:
            raise ValueError(
                f"No price data returned for '{ticker}' "
                f"between {start_date} and {end_date}"
            )
        return StockDataFetcher._standardize_columns(stock_data)
# Function for TimesFM forecasting
# Lazily created TimesFM model, shared by every forecast request so the
# checkpoint is loaded once per process instead of once per button click.
_TIMESFM_MODEL = None


def _get_timesfm_model():
    """Return the shared TimesFM model, building it on first use."""
    global _TIMESFM_MODEL
    if _TIMESFM_MODEL is None:
        _TIMESFM_MODEL = timesfm.TimesFm(
            hparams=timesfm.TimesFmHparams(
                backend="pytorch",
                per_core_batch_size=32,
                horizon_len=30,  # Predicting the next 30 days
                input_patch_len=32,
                output_patch_len=128,
                num_layers=50,
                model_dims=1280,
                use_positional_embedding=False,
            ),
            checkpoint=timesfm.TimesFmCheckpoint(
                huggingface_repo_id="google/timesfm-2.0-500m-pytorch"
            ),
        )
    return _TIMESFM_MODEL


def timesfm_forecast(ticker, start_date, end_date):
    """Forecast closing prices with TimesFM and plot them against history.

    Args:
        ticker: stock symbol; also used as the series ``unique_id``.
        start_date: start of the history window handed to the fetcher.
        end_date: end of the history window handed to the fetcher.

    Returns:
        A Plotly figure with the actual and forecast close prices, or the
        string ``"Error: <message>"`` if any step raises.
    """
    try:
        # Fetch historical data using the StockDataFetcher class
        stock_data = StockDataFetcher.fetch_stock_data(ticker, start_date, end_date)

        # Reshape into the (ds, y, unique_id) layout passed to forecast_on_df
        stock_data.reset_index(inplace=True)
        df = stock_data[['Date', 'Close']].rename(columns={'Date': 'ds', 'Close': 'y'})
        df['ds'] = pd.to_datetime(df['ds'])
        df['unique_id'] = ticker

        # Reuse the cached model rather than rebuilding it for every request
        tfm = _get_timesfm_model()

        forecast_df = tfm.forecast_on_df(
            inputs=df,
            freq="D",  # Daily frequency
            value_name="y",
            num_jobs=-1,
        )
        # Expose the point forecast under a model-agnostic column name
        forecast_df.rename(columns={"timesfm": "forecast"}, inplace=True)

        # Interactive plot: history (solid cyan) vs forecast (dashed magenta)
        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=df["ds"], y=df["y"],
            mode="lines", name="Actual Prices",
            line=dict(color="#00FFFF", width=2),
        ))
        fig.add_trace(go.Scatter(
            x=forecast_df["ds"], y=forecast_df["forecast"],
            mode="lines", name="Forecasted Prices",
            line=dict(color="#FF00FF", width=2, dash="dash"),
        ))

        fig.update_layout(
            title=f"{ticker} Stock Price Forecast (TimesFM)",
            xaxis_title="Date",
            yaxis_title="Price",
            template="plotly_dark",  # Dark Theme
            hovermode="x unified",  # Show all values on hover
            legend=dict(bgcolor="rgba(0,0,0,0.8)", bordercolor="white", borderwidth=1),
            plot_bgcolor="#111111",  # Slightly lighter than black for contrast
            paper_bgcolor="#111111",
            font=dict(color="white", size=12),
            margin=dict(l=40, r=40, t=50, b=40),
        )
        # Faint grid lines for readability on the dark background
        fig.update_xaxes(showgrid=True, gridcolor="rgba(255,255,255,0.1)")
        fig.update_yaxes(showgrid=True, gridcolor="rgba(255,255,255,0.1)")
        return fig  # Return the Plotly figure for Gradio
    except Exception as e:
        # NOTE(review): this string is routed to a gr.Plot output by the UI;
        # confirm the component can display it, or surface the failure
        # through gr.Error instead.
        return f"Error: {str(e)}"
# Function for Prophet forecasting
def prophet_forecast(ticker, start_date, end_date):
    """Fit Prophet on closing prices and build three Plotly figures.

    The previous implementation also rendered a matplotlib components image
    "for backwards compatibility".  That image was never returned, the
    rendering changed global matplotlib style settings, and the figure was
    only closed on the success path.  That section has been removed; the
    returned figures are unchanged.

    Args:
        ticker: stock symbol, used for fetching and in the plot titles.
        start_date: start of the history window handed to the fetcher.
        end_date: end of the history window handed to the fetcher.

    Returns:
        ``(price_and_trend_fig, forecast_fig, components_fig)`` on success, or
        ``("Error: <message>", "Error: <message>", None)`` if any step raises.
    """
    try:
        # Download stock market data using the StockDataFetcher class
        df = StockDataFetcher.fetch_stock_data(ticker, start_date, end_date)

        # Prophet is fitted on a two-column frame: 'ds' (date) and 'y' (value)
        df_plot = df.reset_index()
        df1 = df_plot[['Date', 'Close']].rename(columns={'Date': 'ds', 'Close': 'y'})

        m = Prophet()
        m.fit(df1)

        # Extend the timeline 30 calendar days past the history and predict
        future = m.make_future_dataframe(periods=30, freq='D')
        forecast = m.predict(future)

        # Dark styling shared by all three figures
        dark_layout = dict(
            plot_bgcolor='#111111',
            paper_bgcolor='#111111',
            font=dict(color='white', size=12),
            legend=dict(bgcolor="rgba(0,0,0,0.8)", bordercolor="white", borderwidth=1),
            margin=dict(l=40, r=40, t=50, b=40),
            xaxis=dict(showgrid=True, gridcolor="rgba(255,255,255,0.1)"),
            yaxis=dict(showgrid=True, gridcolor="rgba(255,255,255,0.1)"),
        )

        # Figure 1: actual closing prices with the fitted trend
        fig1 = go.Figure()
        fig1.add_trace(go.Scatter(
            x=df1['ds'],
            y=df1['y'],
            mode='lines',
            name='Actual Price',
            line=dict(color='#36D7B7', width=2)
        ))
        fig1.add_trace(go.Scatter(
            x=forecast['ds'],
            y=forecast['trend'],
            mode='lines',
            name='Trend',
            line=dict(color='#FF6B6B', width=2)
        ))
        fig1.update_layout(title=f'{ticker} Price and Trend', **dark_layout)

        # Figure 2: last 40 forecast rows with their uncertainty band
        forecast_40 = forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']].tail(40)
        fig2 = go.Figure()
        fig2.add_trace(go.Scatter(
            x=forecast_40['ds'],
            y=forecast_40['yhat'],
            mode='lines',
            name='Forecast',
            line=dict(color='#FF6B6B', width=2)
        ))
        # Closed polygon: upper bound left-to-right, lower bound right-to-left
        band_dates = forecast_40["ds"].tolist()
        fig2.add_trace(go.Scatter(
            x=band_dates + band_dates[::-1],
            y=forecast_40["yhat_upper"].tolist() + forecast_40["yhat_lower"].tolist()[::-1],
            fill="toself",
            fillcolor="rgba(78, 205, 196, 0.2)",
            line=dict(color="rgba(255,255,255,0)"),
            name="Confidence Interval"
        ))
        fig2.update_layout(title=f'{ticker} 30 Days Forecast (Prophet)', **dark_layout)

        # Figure 3: yearly component, drawn only when the forecast has one
        components_fig = go.Figure()
        if 'yearly' in forecast.columns:
            # tail() already returns the whole frame when it has <= 365 rows
            yearly_pattern = forecast.tail(365)
            components_fig.add_trace(go.Scatter(
                x=yearly_pattern['ds'],
                y=yearly_pattern['yearly'],
                mode='lines',
                name='Yearly Pattern',
                line=dict(color='#4ECDC4', width=2)
            ))
        components_fig.update_layout(
            title=f'{ticker} Forecast Components',
            xaxis_title='Date',
            yaxis_title='Value',
            **dark_layout
        )

        return fig1, fig2, components_fig
    except Exception as e:
        # NOTE(review): the first two values are routed to gr.Plot outputs by
        # the UI; confirm those components can display a plain string.
        return f"Error: {str(e)}", f"Error: {str(e)}", None
# Functions for technical analysis
def smooth_moving_average(series: pd.Series, window: int) -> pd.Series:
    """Return a recursively smoothed moving average of ``series``.

    The first ``window`` entries are set to the mean of those entries; each
    later entry is ``(previous * (window - 1) + current) / window``.  When
    ``window`` is non-positive or exceeds the series length, a constant
    series holding the overall mean is returned instead.  Remaining gaps are
    forward-filled, back-filled, then filled with the overall mean.
    """
    count = len(series)
    if window <= 0 or count < window:
        return pd.Series(series.mean(), index=series.index)

    raw = series.to_numpy(dtype=float)
    smoothed = np.empty(count, dtype=float)
    # Seed: every warm-up slot carries the mean of the first `window` values.
    smoothed[:window] = series.iloc[:window].mean()
    for pos in range(window, count):
        smoothed[pos] = (smoothed[pos - 1] * (window - 1) + raw[pos]) / window

    filled = pd.Series(smoothed, index=series.index).ffill().bfill()
    return filled.fillna(series.mean())
def calculate_rsi(close: pd.Series, window: int = 14) -> pd.Series:
    """Return the Relative Strength Index of ``close``.

    Gains and losses are the positive and negative parts of the one-step
    difference, each smoothed with ``smooth_moving_average``.  When the series
    has ``window`` or fewer points, a constant 50.0 is returned.  Where the
    average loss is zero the ratio is treated as infinite, which evaluates
    to an RSI of 100.
    """
    if len(close) <= window:
        return pd.Series(50.0, index=close.index)

    change = close.diff()
    ups = change.where(change > 0, 0.0)
    downs = -change.where(change < 0, 0.0)

    avg_gain = smooth_moving_average(ups, window)
    avg_loss = smooth_moving_average(downs, window)

    rs = np.where(avg_loss != 0, avg_gain / avg_loss, np.inf)
    rsi = pd.Series(100.0 - (100.0 / (1.0 + rs)), index=close.index)

    # Drop any infinities, then fill gaps; 50.0 is the last-resort default.
    rsi = rsi.replace([np.inf, -np.inf], np.nan)
    return rsi.ffill().bfill().fillna(50.0)
def calculate_stochastic(high: pd.Series, low: pd.Series, close: pd.Series, k_window=14, d_window=3):
    """Return the stochastic oscillator as a ``(%K, %D)`` pair of series.

    %K is where ``close`` sits between the rolling low and the rolling high
    over ``k_window`` rows, scaled to 0-100 and clipped to that range.  %D is
    the rolling mean of %K over ``d_window`` rows.  Both rolling windows use
    ``min_periods=1``.  When fewer than ``k_window`` rows are available, both
    series are a constant 50.0.
    """
    if len(close) < k_window:
        neutral_k = pd.Series(50.0, index=close.index)
        neutral_d = pd.Series(50.0, index=close.index)
        return neutral_k, neutral_d

    window_low = low.rolling(k_window, min_periods=1).min()
    window_high = high.rolling(k_window, min_periods=1).max()

    # The small epsilon keeps the division finite when high equals low.
    span = window_high - window_low + 1e-10
    k_pct = ((close - window_low) / span) * 100
    k_pct = k_pct.clip(0, 100)
    d_pct = k_pct.rolling(d_window, min_periods=1).mean()

    k_filled = k_pct.ffill().bfill().fillna(50.0)
    d_filled = d_pct.ffill().bfill().fillna(50.0)
    return k_filled, d_filled
def calculate_cci(high: pd.Series, low: pd.Series, close: pd.Series, window=20):
    """Return the Commodity Channel Index computed from high/low/close.

    The typical price is the mean of high, low and close.  Its deviation
    from its own rolling mean is divided by 0.015 times the rolling mean of
    the absolute deviations; each row's deviation is measured against that
    row's own rolling mean.  Rolling windows use ``min_periods=1``.  A
    constant 0.0 is returned when fewer than ``window`` rows are available.
    """
    if len(close) < window:
        return pd.Series(0.0, index=close.index)

    typical = (high + low + close) / 3.0
    deviation = typical - typical.rolling(window, min_periods=1).mean()
    mean_abs_deviation = deviation.abs().rolling(window, min_periods=1).mean()

    # The epsilon keeps the ratio finite when the deviation is exactly zero.
    cci = deviation / (0.015 * mean_abs_deviation + 1e-10)
    return cci.ffill().bfill().fillna(0.0)
# --- New Robust Helper Functions ---
def calculate_sma_robust(series: pd.Series, window: int) -> pd.Series:
    """Return a simple moving average of ``series`` with no missing values.

    A full ``window`` of rows is required for each average; the resulting
    gaps are forward-filled, back-filled, then filled with the overall
    mean.  When ``window`` is non-positive or longer than the series, a
    constant series holding the overall mean is returned.
    """
    overall_mean = series.mean()
    if window <= 0 or len(series) < window:
        return pd.Series(overall_mean, index=series.index)

    averaged = series.rolling(window=window, min_periods=window).mean()
    return averaged.ffill().bfill().fillna(overall_mean)
def calculate_ema_robust(series: pd.Series, span: int) -> pd.Series:
    """Return an exponential moving average of ``series`` with no missing values.

    Uses ``ewm(span=span, adjust=False, min_periods=span)``; the resulting
    gaps are forward-filled, back-filled, then filled with the overall
    mean.  When ``span`` is non-positive or longer than the series, a
    constant series holding the overall mean is returned.
    """
    overall_mean = series.mean()
    if span <= 0 or len(series) < span:
        return pd.Series(overall_mean, index=series.index)

    weighted = series.ewm(span=span, adjust=False, min_periods=span).mean()
    return weighted.ffill().bfill().fillna(overall_mean)
def calculate_macd_robust(close: pd.Series):
    """Return ``(macd_line, signal_line)`` for ``close``.

    The MACD line is the span-12 EMA minus the span-26 EMA; the signal line
    is the span-9 EMA of the MACD line.  All three averages come from
    ``calculate_ema_robust``.
    """
    fast_ema = calculate_ema_robust(close, 12)
    slow_ema = calculate_ema_robust(close, 26)
    macd_line = fast_ema - slow_ema
    return macd_line, calculate_ema_robust(macd_line, 9)
def calculate_bollinger_bands_robust(close: pd.Series, window=20, num_std=2.0):
    """Return ``(middle, upper, lower)`` Bollinger Bands for ``close``.

    The middle band comes from ``calculate_sma_robust``; the outer bands sit
    ``num_std`` rolling standard deviations away from it.

    Warm-up rows (the first ``window - 1``) have no rolling standard
    deviation of their own.  They inherit the first available one, mirroring
    how the middle band is back-filled.  Filling them with a near-zero width
    would collapse both bands onto the middle band, so almost any close
    would register as touching a band.

    Args:
        close: closing prices.
        window: rolling window length for the mean and standard deviation.
        num_std: band half-width in standard deviations.

    Returns:
        Three series indexed like ``close``.  When fewer than ``window``
        rows are available, all three are the same constant series holding
        the mean of ``close``.
    """
    if len(close) < window:
        mid = pd.Series(close.mean(), index=close.index)
        return mid, mid, mid

    sma = calculate_sma_robust(close, window)

    std = close.rolling(window=window, min_periods=window).std()
    # Fill gaps the same way the SMA is filled; 1e-10 remains the last-resort
    # value when no standard deviation could be computed at all.
    std = std.ffill().bfill().fillna(1e-10)

    upper = sma + num_std * std
    lower = sma - num_std * std
    return sma.ffill().bfill(), upper.ffill().bfill(), lower.ffill().bfill()
# --- The Core Integration: generate_trading_signals ---
def generate_trading_signals(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with per-indicator and combined signal columns.

    Requires a 'Close' column.  'High'/'Low' fall back to 'Close' and
    'Volume' falls back to a constant 1.0 when absent.  Six integer columns
    are added (MACD_Signal, RSI_Signal, BB_Signal, Stochastic_Signal,
    CMF_Signal, CCI_Signal), each 1 for buy, -1 for sell and 0 otherwise,
    plus 'Combined_Signal', the row-wise sum of the six.  When a row meets
    both the buy and the sell condition of one indicator, sell wins because
    it is written last.
    """
    df = df.copy()
    close = df['Close']

    has_hl = 'High' in df.columns and 'Low' in df.columns
    has_vol = 'Volume' in df.columns
    high, low = (df['High'], df['Low']) if has_hl else (close, close)
    volume = df['Volume'] if has_vol else pd.Series(1.0, index=close.index)

    # Indicators
    rsi = calculate_rsi(close, window=14)
    stoch_k, stoch_d = calculate_stochastic(high, low, close, k_window=14, d_window=3)
    cci = calculate_cci(high, low, close, window=20)
    sma30 = calculate_sma_robust(close, 30)
    sma100 = calculate_sma_robust(close, 100)
    macd_line, macd_signal_line = calculate_macd_robust(close)
    _, bb_upper, bb_lower = calculate_bollinger_bands_robust(close, window=20, num_std=2.5)

    # Chaikin Money Flow over 20 rows; needs real high/low and volume data
    if has_hl and has_vol:
        mfv = ((close - low) - (high - close)) / (high - low + 1e-10) * volume
        rolling_mfv = mfv.rolling(window=20, min_periods=20).sum()
        rolling_volume = volume.rolling(window=20, min_periods=20).sum()
        cmf = (rolling_mfv / (rolling_volume + 1e-10)).ffill().bfill().fillna(0.0)
    else:
        cmf = pd.Series(0.0, index=close.index)

    # MACD: a crossover on this row that also clears the size thresholds
    prev_macd = macd_line.shift(1)
    prev_signal = macd_signal_line.shift(1)
    macd_bull = (
        (macd_line > macd_signal_line)
        & (prev_macd <= prev_signal)
        & (macd_line > 0.5)
        & ((macd_line - macd_signal_line) > 0.8)
    )
    macd_bear = (
        (macd_line < macd_signal_line)
        & (prev_macd >= prev_signal)
        & (macd_line < -0.5)
        & ((macd_signal_line - macd_line) > 0.8)
    )

    # column -> (buy mask, sell mask); insertion order fixes the column order
    rules = {
        'MACD_Signal': (macd_bull, macd_bear),
        'RSI_Signal': (rsi < 15, rsi > 85),
        'BB_Signal': (close <= bb_lower, close >= bb_upper),
        'Stochastic_Signal': (
            (stoch_k < 5) & (stoch_d < 5),
            (stoch_k > 95) & (stoch_d > 95),
        ),
        'CMF_Signal': (cmf < -0.5, cmf > 0.5),
        'CCI_Signal': (cci < -250, cci > 250),
    }

    # Create every signal column first so they all start at 0
    for column in rules:
        df[column] = 0
    for column, (buy_mask, sell_mask) in rules.items():
        df.loc[buy_mask, column] = 1
        df.loc[sell_mask, column] = -1  # written last, so sell overrides buy

    df['Combined_Signal'] = df[list(rules)].sum(axis=1)
    return df
def plot_combined_signals(df, ticker):
    """Return a bar chart of ``df['Combined_Signal']`` over ``df.index``.

    Bars with a value of zero or above are green; negative bars are red.
    The chart title is the ticker.
    """
    buy_color, sell_color = '#2ECC71', '#E74C3C'
    scores = df['Combined_Signal']
    bar_colors = [buy_color if score >= 0 else sell_color for score in scores]

    strength_bars = go.Bar(
        x=df.index,
        y=scores,
        name='Signal Strength',
        marker_color=bar_colors,
        # Hover shows the date and the score only
        hovertemplate='<b>Date</b>: %{x}<br><b>Signal</b>: %{y}<extra></extra>'
    )

    fig = go.Figure()
    fig.add_trace(strength_bars)
    fig.update_layout(
        title=f'{ticker}',
        template='plotly_dark',
        xaxis_title='Date',
        yaxis_title='Signal Strength Score',
        yaxis=dict(zeroline=True, zerolinewidth=2, zerolinecolor='gray'),
        showlegend=False  # Not needed for a single trace
    )
    return fig
def plot_individual_signals(df, ticker, x_range=None):
    """Return the closing-price line with buy/sell markers per indicator.

    For each of the six signal columns, rows equal to 1 get an upward
    triangle and rows equal to -1 a downward triangle, both placed at the
    row's closing price.  ``x_range`` is passed through as the x-axis range.
    """
    buy_color, sell_color = '#39FF14', '#FF073A'
    signal_columns = ('MACD_Signal', 'RSI_Signal', 'BB_Signal',
                      'Stochastic_Signal', 'CMF_Signal', 'CCI_Signal')
    # (legend suffix, signal value, marker symbol, marker colour)
    marker_styles = (
        ('Buy', 1, 'triangle-up', buy_color),
        ('Sell', -1, 'triangle-down', sell_color),
    )

    fig = go.Figure()

    # Closing price
    fig.add_trace(go.Scatter(
        x=df.index, y=df['Close'],
        mode='lines',
        name='Closing Price',
        line=dict(color='#36A2EB', width=2)
    ))

    # One buy trace followed by one sell trace for each indicator
    for column in signal_columns:
        for suffix, value, symbol, color in marker_styles:
            hits = df[df[column] == value]
            fig.add_trace(go.Scatter(
                x=hits.index, y=hits['Close'],
                mode='markers',
                marker=dict(symbol=symbol, size=12, color=color),
                name=f'{column} {suffix}'
            ))

    grid_color = "rgba(255,255,255,0.1)"
    fig.update_layout(
        title=f'{ticker}',
        xaxis=dict(
            title='Date',
            showgrid=True,
            gridcolor=grid_color,
            range=x_range  # Shared x-axis range
        ),
        yaxis=dict(
            title='Price',
            side='left',
            showgrid=True,
            gridcolor=grid_color
        ),
        plot_bgcolor='#111111',
        paper_bgcolor='#111111',
        font=dict(color='white', size=12),
        legend=dict(
            orientation='h',  # Horizontal legend
            yanchor='bottom',
            y=1.02,  # Just above the plot
            xanchor='right',
            x=1,
            bgcolor="rgba(0,0,0,0.8)",
            bordercolor="white",
            borderwidth=1
        ),
        margin=dict(l=40, r=40, t=80, b=40)  # Extra top margin for legend
    )
    return fig
def technical_analysis(ticker, start_date, end_date):
    """Fetch prices, compute trading signals and plot the last 120 rows.

    Returns:
        ``(combined_signal_fig, individual_signals_fig)`` on success, or a
        pair of identical ``"Error: <message>"`` strings if any step raises.
    """
    try:
        prices = StockDataFetcher.fetch_stock_data(ticker, start_date, end_date)
        signals = generate_trading_signals(prices)

        # Only the most recent 120 rows are plotted
        recent = signals.tail(120)

        fig_signals = plot_combined_signals(recent, ticker)
        fig_individual_signals = plot_individual_signals(recent, ticker)
    except Exception as e:
        # NOTE(review): both values are routed to gr.Plot outputs by the UI;
        # confirm those components can display a plain string.
        message = f"Error: {str(e)}"
        return message, message
    return fig_signals, fig_individual_signals
# Custom CSS for better appearance
# NOTE(review): no component below sets elem_id "analyze-btn", so the last
# rule appears to match nothing in this file - confirm whether it is needed.
custom_css = """
.gradio-container {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
.container {
max-width: 1200px;
margin: auto;
}
button#analyze-btn {
background-color: #003366;
color: white;
border: none;
}
"""

# Create Gradio interface
with gr.Blocks(theme=gr.themes.Monochrome(), css=custom_css) as demo:
    gr.Markdown("# Advanced Stock Analysis & Forecasting App")
    gr.Markdown("Enter a stock ticker, start date, and end date to analyze and forecast stock prices.")

    # Inputs shared by every tab
    with gr.Row():
        ticker_input = gr.Textbox(label="Enter Stock Ticker", value="NVDA")
        start_date_input = gr.Textbox(label="Enter Start Date (YYYY-MM-DD)", value="2025-01-01")
        end_date_input = gr.Textbox(label="Enter End Date (YYYY-MM-DD)", value="2027-01-01")
    query_inputs = [ticker_input, start_date_input, end_date_input]

    # One tab per analysis type
    with gr.Tabs() as tabs:
        with gr.TabItem("TimesFM Forecast"):
            timesfm_button = gr.Button("Generate TimesFM Forecast")
            timesfm_plot = gr.Plot(label="TimesFM Stock Price Forecast")
            timesfm_button.click(
                timesfm_forecast,
                inputs=query_inputs,
                outputs=timesfm_plot
            )

        with gr.TabItem("Prophet Forecast"):
            prophet_button = gr.Button("Generate Prophet Forecast")
            prophet_recent_plot = gr.Plot(label="Recent Stock Prices")
            prophet_forecast_plot = gr.Plot(label="Prophet 30-Day Forecast")
            prophet_components = gr.Plot(label="Forecast Components")

        with gr.TabItem("Technical Analysis"):
            analysis_button = gr.Button("Generate Technical Analysis")
            # Displayed individual-first; the handler returns combined-first,
            # which is the order used in `outputs` below.
            individual_signals = gr.Plot(label="Individual Trading Signals")
            combined_signals = gr.Plot(label="Combined Trading Signals")
            analysis_button.click(
                technical_analysis,
                inputs=query_inputs,
                outputs=[combined_signals, individual_signals]
            )

    # Registered after the technical-analysis handler, as before, so the
    # order in which events are registered is unchanged.
    prophet_button.click(
        prophet_forecast,
        inputs=query_inputs,
        outputs=[prophet_recent_plot, prophet_forecast_plot, prophet_components]
    )

# Launch the app
demo.launch()