Shaikat01's picture
Update app.py
7e6a7e0 verified
raw
history blame
8.56 kB
import gradio as gr
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from datetime import datetime, timedelta
import pickle
import os
import warnings
warnings.filterwarnings('ignore')
# TensorFlow/Keras imports
from tensorflow.keras.models import load_model
from sklearn.preprocessing import MinMaxScaler
# ARIMA and Prophet
from statsmodels.tsa.arima.model import ARIMA
from prophet import Prophet
# --------------------------
# Load models safely
# --------------------------
def _load_artifact(label, loader):
    """Call *loader* and return its result, or None (after printing) if it raises."""
    try:
        return loader()
    except Exception as e:  # best-effort: one missing artifact must not stop the app
        print(f"Error loading models ({label}): {e}")
        return None


def _load_pickle(path):
    """Return the object unpickled from *path*."""
    # NOTE: unpickling executes arbitrary code; only ship artifacts you trust.
    with open(path, 'rb') as f:
        return pickle.load(f)


def _load_keras_model(path):
    """Return the Keras model stored at *path*."""
    # Imported lazily so a missing TensorFlow install only disables the LSTM.
    from tensorflow.keras.models import load_model
    return load_model(path)


def load_models(arima_path='arima_model.pkl', prophet_path='prophet_model.pkl',
                lstm_path='lstm_model.keras', scaler_path='scaler.pkl'):
    """Load the saved ARIMA, Prophet and LSTM models plus the LSTM scaler.

    Each artifact is loaded independently: a failure is printed and yields
    ``None`` for that artifact only, instead of discarding everything.

    The previous version returned a ``scaler`` it never assigned; the
    resulting NameError was swallowed, so all four values were always None.

    Args:
        arima_path: Pickle file holding the ARIMA model.
        prophet_path: Pickle file holding the Prophet model.
        lstm_path: Keras model file for the LSTM.
        scaler_path: Pickle file holding the scaler used with the LSTM.
            NOTE(review): the original code never loaded a scaler, so this
            default file name is an assumption; confirm it against the
            files actually shipped with the app.

    Returns:
        Tuple ``(arima_model, prophet_model, lstm_model, scaler)``; any
        element is ``None`` when its artifact could not be loaded.
    """
    arima_model = _load_artifact("ARIMA", lambda: _load_pickle(arima_path))
    prophet_model = _load_artifact("Prophet", lambda: _load_pickle(prophet_path))
    lstm_model = _load_artifact("LSTM", lambda: _load_keras_model(lstm_path))
    scaler = _load_artifact("scaler", lambda: _load_pickle(scaler_path))
    return arima_model, prophet_model, lstm_model, scaler
# Number of trailing time steps handed to the LSTM as its input window.
SEQ_LENGTH = 60
# Load the saved artifacts once at import time; entries are None when loading failed.
arima_model, prophet_model, lstm_model, scaler = load_models()
# --------------------------
# Fetch stock data
# --------------------------
def fetch_stock_data(ticker, days=365):
    """
    Load prices for *ticker* from a local ``<TICKER>.csv`` file.

    Data is read from a CSV in the working directory rather than downloaded
    (the original author notes that community Spaces cannot access the
    internet).

    Args:
        ticker: Ticker symbol. Whitespace is stripped and the symbol is
            upper-cased before it is used as the file stem.
        days: Maximum number of trailing rows to keep.

    Returns:
        ``(df, None)`` on success, where ``df`` has a single numeric
        ``Price`` column indexed by the CSV's first column, or
        ``(None, message)`` when no usable data could be loaded.
    """
    ticker = (ticker or "").upper().strip()
    # The ticker comes from a free-text box and becomes part of a file name,
    # so refuse anything that could point outside the working directory.
    if not ticker or any(ch in ticker for ch in ('/', '\\', '\0')):
        return None, f"Invalid ticker symbol: {ticker!r}"

    filename = f"{ticker}.csv"
    if not os.path.exists(filename):
        return None, f"No data found for {ticker}. Upload {ticker}.csv in the Space root."

    try:
        df = pd.read_csv(filename, index_col=0, parse_dates=True)
    except Exception as e:
        # Keep the (data, error) contract instead of crashing the UI callback.
        return None, f"Could not read {filename}: {e}"

    if 'Close' in df.columns:
        df = df[['Close']].copy()
    elif df.shape[1] >= 1:
        # No 'Close' column: fall back to the first data column. Renaming the
        # whole frame to ['Price'] used to raise when several columns existed.
        df = df.iloc[:, [0]].copy()
    else:
        return None, f"No valid data found in {filename} for {ticker}."
    df.columns = ['Price']

    # Non-numeric cells (e.g. stray header rows) become NaN and are dropped.
    df['Price'] = pd.to_numeric(df['Price'], errors='coerce')
    df = df.dropna().tail(days)
    if df.empty:
        return None, f"No valid data found in {filename} for {ticker}."
    return df, None
# --------------------------
# Forecasting functions
# --------------------------
def make_arima_forecast(data, days):
    """Fit an ARIMA(1,1,1) model on the ``Price`` column and forecast ahead.

    Args:
        data: DataFrame with a ``Price`` column. It is left unmodified.
        days: Number of future steps to forecast.

    Returns:
        NumPy array with one value per forecast step, or ``None`` if
        fitting or forecasting raised (the error is printed).
    """
    try:
        # Coerce into a new Series; assigning back into ``data`` would mutate
        # the caller's frame, which is reused for the other models and plot.
        prices = pd.to_numeric(data['Price'], errors='coerce').dropna()
        fitted = ARIMA(prices, order=(1, 1, 1)).fit()
        # The step count may arrive as a float from a UI widget.
        forecast = fitted.forecast(steps=int(days))
        return np.asarray(forecast)
    except Exception as e:
        print(f"ARIMA Error: {e}")
        return None
def make_prophet_forecast(data, days):
    """Fit a fresh Prophet model on *data* and return its last *days* predictions.

    Args:
        data: DataFrame whose index supplies the ``ds`` values and whose
            ``Price`` column supplies the ``y`` values.
        days: Number of periods appended to the history for prediction.

    Returns:
        Array of the trailing *days* ``yhat`` values, or ``None`` if any
        step raised (the error is printed).
    """
    try:
        history = pd.DataFrame({'ds': data.index, 'y': data['Price'].values})
        settings = dict(
            daily_seasonality=True,
            weekly_seasonality=True,
            yearly_seasonality=True,
            changepoint_prior_scale=0.05,
        )
        forecaster = Prophet(**settings)
        forecaster.fit(history)
        horizon = forecaster.make_future_dataframe(periods=days)
        predicted = forecaster.predict(horizon)
        # Only the rows beyond the history are of interest.
        yhat_tail = predicted['yhat'].tail(days)
        return yhat_tail.values
    except Exception as exc:
        print(f"Prophet Error: {exc}")
        return None
def make_lstm_forecast(data, days, model, scaler, seq_length=60):
    """Roll the LSTM forward *days* steps, feeding each prediction back in.

    Args:
        data: DataFrame with a ``Price`` column.
        days: Number of steps to predict.
        model: Object exposing ``predict(batch, verbose=0)``.
        scaler: Object exposing ``transform`` and ``inverse_transform``.
        seq_length: Number of trailing scaled rows used as the input window.

    Returns:
        1-D array of predictions mapped back through
        ``scaler.inverse_transform``, or ``None`` if any step raised (the
        error is printed).
    """
    try:
        scaled = scaler.transform(data[['Price']])
        window = scaled[-seq_length:].reshape(1, seq_length, 1).copy()
        scaled_preds = []
        for _step in range(days):
            step_pred = model.predict(window, verbose=0)
            scaled_preds.append(step_pred[0, 0])
            # Slide the window: drop the oldest step, append the new prediction.
            window = np.concatenate(
                (window[:, 1:, :], step_pred.reshape(1, 1, 1)), axis=1
            )
        restored = scaler.inverse_transform(np.array(scaled_preds).reshape(-1, 1))
        return restored.flatten()
    except Exception as exc:
        print(f"LSTM Error: {exc}")
        return None
# --------------------------
# Plot function
# --------------------------
def create_forecast_plot(historical_data, forecasts, ticker, model_names):
    """Build a Plotly figure of the price history plus one line per forecast.

    Args:
        historical_data: DataFrame with a ``Price`` column; its index is the
            x-axis and its last index value anchors the forecast dates.
        forecasts: Sequence of forecast arrays, paired by position with
            *model_names*. ``None`` entries are skipped.
        ticker: Symbol shown in the figure title.
        model_names: Display name for each forecast.

    Returns:
        The assembled ``plotly.graph_objects.Figure``.
    """
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=historical_data.index,
        y=historical_data['Price'],
        mode='lines',
        name='Historical Price',
        line=dict(color='blue', width=2)
    ))

    first_future_day = pd.to_datetime(historical_data.index[-1]) + timedelta(days=1)
    colors = ['red', 'purple', 'orange']
    for i, (forecast, name) in enumerate(zip(forecasts, model_names)):
        if forecast is None:
            continue
        # Size the date axis from this forecast, not from forecasts[0], which
        # may be None (or absent) even though the loop tolerates None entries.
        future_dates = pd.date_range(start=first_future_day, periods=len(forecast))
        fig.add_trace(go.Scatter(
            x=future_dates,
            y=forecast,
            mode='lines+markers',
            name=f'{name} Forecast',
            # Wrap around so more than three series cannot raise IndexError.
            line=dict(color=colors[i % len(colors)], width=2, dash='dash'),
            marker=dict(size=6)
        ))

    fig.update_layout(
        title=f'{ticker} Stock Price Forecast',
        xaxis_title='Date',
        yaxis_title='Price ($)',
        hovermode='x unified',
        template='plotly_white',
        height=600,
        showlegend=True
    )
    return fig
# --------------------------
# Main prediction function
# --------------------------
def predict_stock(ticker, forecast_days, model_choice):
    """Run the selected model(s) for *ticker* and build the three UI outputs.

    Args:
        ticker: Ticker symbol as typed by the user.
        forecast_days: Number of days to forecast.
        model_choice: ``"All Models"``, ``"ARIMA"``, ``"Prophet"`` or
            ``"LSTM"``.

    Returns:
        ``(figure, summary_markdown, forecast_dataframe)`` on success, or
        ``(None, message, None)`` when input is missing, data cannot be
        loaded, or no model produced a forecast.
    """
    # Normalise once so the title and summary show the same symbol that is
    # used for the data lookup; this also rejects whitespace-only input.
    ticker = (ticker or "").strip().upper()
    if not ticker:
        return None, "Please enter a stock ticker symbol", None
    # The forecasters and date_range(periods=...) need a whole number of steps.
    forecast_days = int(forecast_days)

    data, error = fetch_stock_data(ticker, days=730)
    if error:
        return None, f"Error: {error}", None

    forecasts = []
    model_names = []

    def _keep(name, forecast):
        """Record *forecast* under *name* unless the model failed (None)."""
        if forecast is not None:
            forecasts.append(forecast)
            model_names.append(name)

    if model_choice in ["All Models", "ARIMA"]:
        _keep("ARIMA", make_arima_forecast(data, forecast_days))
    if model_choice in ["All Models", "Prophet"]:
        _keep("Prophet", make_prophet_forecast(data, forecast_days))
    # The LSTM is skipped when its saved model could not be loaded.
    if model_choice in ["All Models", "LSTM"] and lstm_model is not None:
        _keep("LSTM", make_lstm_forecast(data, forecast_days, lstm_model, scaler, SEQ_LENGTH))

    if not forecasts:
        return None, "Failed to generate forecasts.", None

    fig = create_forecast_plot(data, forecasts, ticker, model_names)

    # Forecast table
    future_dates = pd.date_range(
        start=pd.to_datetime(data.index[-1]) + timedelta(days=1),
        periods=forecast_days
    )
    forecast_df = pd.DataFrame({'Date': future_dates.strftime('%Y-%m-%d')})
    for forecast, name in zip(forecasts, model_names):
        forecast_df[f'{name} Prediction ($)'] = np.round(forecast, 2)

    # Summary
    current_price = data['Price'].iloc[-1]
    summary_lines = [
        # \U0001F4CA is the bar-chart emoji; the literal in the file had been
        # corrupted into mojibake, so it is spelled as an escape here.
        f"\U0001F4CA **Forecast Summary for {ticker}**\n",
        f"- Current Price: ${current_price:.2f}",
        f"- Forecast Period: {forecast_days} days",
        f"- Models Used: {', '.join(model_names)}\n",
        f"**Predicted Price Range (Day {forecast_days}):**",
    ]
    for forecast, name in zip(forecasts, model_names):
        final_price = forecast[-1]
        change = ((final_price - current_price) / current_price) * 100
        summary_lines.append(f"- {name}: ${final_price:.2f} ({change:+.2f}%)")
    summary = "\n".join(summary_lines)

    return fig, summary, forecast_df
# --------------------------
# Gradio Interface
# --------------------------
# Emoji are written as \U escapes: the literals in this file had been corrupted
# into mojibake (U+1F4C8 chart-increasing, U+1F52E crystal ball).
# NOTE(review): the nesting below is reconstructed from a flattened source;
# in particular, placing output_table inside the right-hand column is an
# assumption - confirm against the intended layout.
with gr.Blocks(title="Stock Price Forecasting", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# \U0001F4C8 Stock Price Forecasting App\nPredict future stock prices using ARIMA, Prophet, and LSTM models.\nUpload CSV files in the Space root for offline use.")
    with gr.Row():
        # Left column: user inputs.
        with gr.Column(scale=1):
            ticker_input = gr.Textbox(label="Stock Ticker Symbol", placeholder="e.g., AAPL", value="AAPL")
            forecast_days = gr.Slider(minimum=1, maximum=90, value=30, step=1, label="Forecast Days")
            model_choice = gr.Radio(choices=["All Models", "ARIMA", "Prophet", "LSTM"], value="All Models", label="Select Model(s)")
            predict_btn = gr.Button("\U0001F52E Generate Forecast", variant="primary")
        # Right column: the three values returned by predict_stock.
        with gr.Column(scale=2):
            output_plot = gr.Plot(label="Forecast Visualization")
            output_summary = gr.Markdown(label="Forecast Summary")
            output_table = gr.Dataframe(label="Detailed Forecast", interactive=False)
    # Event wiring has to happen inside the Blocks context.
    predict_btn.click(fn=predict_stock, inputs=[ticker_input, forecast_days, model_choice],
                      outputs=[output_plot, output_summary, output_table])

# Launch
if __name__ == "__main__":
    demo.launch()