Shaikat01's picture
Update app.py
fe4ed55 verified
raw
history blame
13 kB
import gradio as gr
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from datetime import datetime, timedelta
import yfinance as yf
from statsmodels.tsa.arima.model import ARIMA
from prophet import Prophet
import warnings
# Suppress every warning process-wide. Presumably this is meant to quiet the
# forecasting libraries while they fit; note it also hides genuine problems.
warnings.simplefilter('ignore')
# No pre-trained models are shipped: each request trains fresh models on the
# data it has just downloaded (this avoids the 50GB storage limit issue).
def fetch_stock_data(ticker, days=730):
    """Download recent closing prices for *ticker* from Yahoo Finance.

    Args:
        ticker: Symbol passed unchanged to ``yf.download``.
        days: Length of the look-back window in calendar days, ending now.

    Returns:
        ``(frame, None)`` on success, where ``frame`` has a single ``'Price'``
        column holding the downloaded ``'Close'`` values with NaN rows
        dropped; otherwise ``(None, message)``.
    """
    try:
        window_end = datetime.now()
        window_start = window_end - timedelta(days=days)
        raw = yf.download(ticker, start=window_start, end=window_end, progress=False)
        if raw.empty:
            return None, f"No data found for ticker: {ticker}"
        closes = raw[['Close']].copy()
        closes.columns = ['Price']
        return closes.dropna(), None
    except Exception as exc:
        # Any exception raised above is handed back as its message string
        # rather than propagated to the caller.
        return None, str(exc)
def make_arima_forecast(data, days):
    """Fit an ARIMA(1, 1, 1) model to ``data['Price']`` and forecast ahead.

    The model is fitted from scratch on every call; nothing is cached.

    Args:
        data: DataFrame with a ``'Price'`` column (as built by fetch_stock_data).
        days: Number of steps to forecast.

    Returns:
        The forecast's ``.values`` (``days`` predicted prices), or ``None`` if
        fitting or forecasting raised (the error is printed).
    """
    try:
        arima = ARIMA(data['Price'], order=(1, 1, 1)).fit()
        predicted = arima.forecast(steps=days)
        return predicted.values
    except Exception as err:
        print(f"ARIMA Error: {err}")
        return None
def make_prophet_forecast(data, days):
    """Fit a Prophet model to the price history and forecast ``days`` ahead.

    The model uses weekly and yearly (but not daily) seasonality in
    multiplicative mode, and is fitted from scratch on every call.

    Args:
        data: DataFrame indexed by date with a ``'Price'`` column.
        days: Number of future rows requested from ``make_future_dataframe``
            (called with Prophet's default ``freq``).

    Returns:
        The last ``days`` ``yhat`` values as an array, or ``None`` if fitting
        or prediction raised (the error is printed).
    """
    try:
        # Prophet expects a frame with a 'ds' (date) and a 'y' (value) column.
        history = pd.DataFrame({'ds': data.index, 'y': data['Price'].values})
        prophet = Prophet(
            daily_seasonality=False,
            weekly_seasonality=True,
            yearly_seasonality=True,
            changepoint_prior_scale=0.05,
            seasonality_mode='multiplicative',
        )
        prophet.fit(history)
        horizon = prophet.make_future_dataframe(periods=days)
        prediction = prophet.predict(horizon)
        # The prediction frame also covers the history rows; keep only the
        # trailing `days` future rows.
        return prediction['yhat'].tail(days).values
    except Exception as exc:
        print(f"Prophet Error: {exc}")
        return None
def make_simple_ml_forecast(data, days):
    """Forecast with Holt-Winters exponential smoothing.

    A lightweight stand-in for an LSTM. The model has an additive trend and
    an additive seasonal component whose period is 30 rows of ``data``
    (presumably trading days when fed by fetch_stock_data).

    Args:
        data: DataFrame with a ``'Price'`` column.
        days: Number of steps to forecast.

    Returns:
        The forecast's ``.values`` (``days`` points), or ``None`` if the
        import, fit, or forecast raised (the error is printed).
    """
    try:
        # Imported inside the try, so an import failure is reported the same
        # way as any fitting error.
        from statsmodels.tsa.holtwinters import ExponentialSmoothing

        smoother = ExponentialSmoothing(
            data['Price'],
            trend='add',
            seasonal='add',
            seasonal_periods=30,
        )
        return smoother.fit().forecast(steps=days).values
    except Exception as exc:
        print(f"ML Forecast Error: {exc}")
        return None
def calculate_moving_average_forecast(data, days, window=20):
    """Extrapolate a linear trend from the trailing moving average.

    The forecast starts from the mean of the last ``window`` prices and adds
    a constant per-step slope. The slope is the change between the first and
    last price of that window divided by the ``window - 1`` steps between
    them, and is zero when ``window == 1``.

    Args:
        data: DataFrame with a ``'Price'`` column, oldest row first.
        days: Number of future steps to produce.
        window: Number of trailing rows used for both the average and the slope.

    Returns:
        NumPy array of ``days`` forecast values, or ``None`` if ``window`` is
        not positive or exceeds the available history (the reason is printed).
    """
    try:
        prices = data['Price']
        if window < 1:
            raise ValueError(f"window must be positive, got {window}")
        if len(prices) < window:
            raise ValueError(f"need at least {window} prices, got {len(prices)}")
        recent = prices.iloc[-window:]
        ma = recent.mean()
        # `window` points span `window - 1` intervals, so the per-step slope
        # divides by `window - 1` (dividing by `window` understates it).
        trend = (recent.iloc[-1] - recent.iloc[0]) / (window - 1) if window > 1 else 0.0
        # Step i (1-based) of the horizon is ma + trend * i.
        return ma + trend * np.arange(1, days + 1)
    except Exception as e:
        print(f"MA Error: {e}")
        return None
def create_forecast_plot(historical_data, forecasts, ticker, model_names):
    """Build an interactive Plotly chart of recent history plus forecasts.

    Args:
        historical_data: DataFrame indexed by date with a ``'Price'`` column.
        forecasts: Sequence of forecast arrays, parallel to ``model_names``;
            ``None`` entries are skipped.
        ticker: Symbol used in the chart title.
        model_names: Legend label for each entry of ``forecasts``.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    fig = go.Figure()

    # Only the last 90 rows of history are drawn so the forecast stays readable.
    recent_data = historical_data.tail(90)
    fig.add_trace(go.Scatter(
        x=recent_data.index,
        y=recent_data['Price'],
        mode='lines',
        name='Historical Price',
        line=dict(color='blue', width=2),
    ))

    # Forecast x-values are consecutive calendar days after the last observation.
    # Take the horizon from the first real forecast so a leading None is harmless.
    last_date = historical_data.index[-1]
    horizon = next((len(f) for f in forecasts if f is not None), 0)
    future_dates = pd.date_range(start=last_date + timedelta(days=1), periods=horizon)

    colors = ['red', 'purple', 'orange', 'green']
    for i, (forecast, name) in enumerate(zip(forecasts, model_names)):
        if forecast is None:
            continue
        fig.add_trace(go.Scatter(
            x=future_dates,
            y=forecast,
            mode='lines+markers',
            name=f'{name} Forecast',
            # Cycle the palette so more than four series cannot raise IndexError.
            line=dict(color=colors[i % len(colors)], width=2, dash='dash'),
            marker=dict(size=4),
        ))

    # Draw the divider and its label separately: add_vline(annotation_text=...)
    # positions the label by averaging x values with sum(), which raises
    # TypeError for Timestamp x-values in affected Plotly versions.
    fig.add_vline(x=last_date, line_dash="dash", line_color="gray")
    fig.add_annotation(
        x=last_date,
        y=1,
        xref="x",
        yref="paper",
        text="Forecast Start",
        showarrow=False,
        xanchor="left",
        yanchor="top",
    )

    fig.update_layout(
        title=f'{ticker} Stock Price Forecast',
        xaxis_title='Date',
        yaxis_title='Price ($)',
        hovermode='x unified',
        template='plotly_white',
        height=600,
        showlegend=True,
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="left",
            x=0.01,
            bgcolor="rgba(255, 255, 255, 0.8)",
        ),
    )
    return fig
def predict_stock(ticker, forecast_days, model_choice):
    """Gradio callback: fetch prices, run the chosen model(s), and report.

    Args:
        ticker: Text from the ticker box. Surrounding whitespace is ignored
            and the symbol is upper-cased.
        forecast_days: Forecast horizon in days. The slider may deliver a
            float, so it is converted to ``int``.
        model_choice: ``"All Models"`` or a single model label
            (``"ARIMA"``, ``"Prophet"``, ``"Exp. Smoothing"``,
            ``"Moving Average"``).

    Returns:
        ``(figure, summary_markdown, forecast_table)``. When validation, the
        download, or every model fails, the figure and table are ``None`` and
        the summary is an error message starting with "❌".
    """
    # Strip before the emptiness check so whitespace-only input is rejected
    # here rather than sent to Yahoo Finance as an empty symbol.
    ticker = (ticker or "").strip().upper()
    if not ticker:
        return None, "❌ Please enter a stock ticker symbol", None
    # Model step counts and pd.date_range periods must be integers.
    forecast_days = int(forecast_days)
    if forecast_days < 1:
        return None, "❌ Forecast period must be at least 1 day", None

    # Fetch data (2 years for better training).
    data, error = fetch_stock_data(ticker, days=730)
    # Check `data` as well: an exception with an empty message yields error == "".
    if error or data is None:
        return None, f"❌ Error: {error or 'unknown error'}", None
    if len(data) < 60:
        return None, f"❌ Insufficient data for {ticker}. Need at least 60 days of history.", None

    forecasts = []
    model_names = []
    # Fixed order keeps table columns and chart colours stable across runs.
    for name, forecaster in (
        ("ARIMA", make_arima_forecast),
        ("Prophet", make_prophet_forecast),
        ("Exp. Smoothing", make_simple_ml_forecast),
        ("Moving Average", calculate_moving_average_forecast),
    ):
        if model_choice not in ("All Models", name):
            continue
        result = forecaster(data, forecast_days)
        # Each forecaster returns None on failure; skip it rather than abort.
        if result is not None:
            forecasts.append(result)
            model_names.append(name)

    if not forecasts:
        return None, "❌ Failed to generate forecasts. Please try again.", None

    fig = create_forecast_plot(data, forecasts, ticker, model_names)

    # One row per forecast day, one rounded price column per successful model.
    future_dates = pd.date_range(
        start=data.index[-1] + timedelta(days=1),
        periods=forecast_days,
    )
    forecast_df = pd.DataFrame({'Date': future_dates.strftime('%Y-%m-%d')})
    for forecast, name in zip(forecasts, model_names):
        forecast_df[f'{name} ($)'] = np.round(forecast, 2)

    summary = _build_forecast_summary(
        ticker, data, forecasts, model_names, forecast_days, future_dates
    )
    return fig, summary, forecast_df


def _build_forecast_summary(ticker, data, forecasts, model_names, forecast_days, future_dates):
    """Render the Markdown summary: each model's final-day price and the average.

    Percent changes are relative to the last observed ``data['Price']`` value.
    """
    current_price = data['Price'].iloc[-1]
    final_prices = [f[-1] for f in forecasts]
    avg_forecast = np.mean(final_prices)
    avg_change = ((avg_forecast - current_price) / current_price) * 100

    summary = f"""
## 📊 Forecast Summary for **{ticker}**
### Current Information
- **Current Price**: ${current_price:.2f}
- **Data Points**: {len(data)} days
- **Last Updated**: {data.index[-1].strftime('%Y-%m-%d')}
### Forecast Details
- **Forecast Period**: {forecast_days} days
- **Models Used**: {', '.join(model_names)}
- **End Date**: {future_dates[-1].strftime('%Y-%m-%d')}
### Predicted Prices (Day {forecast_days})
"""
    for final_price, name in zip(final_prices, model_names):
        change = ((final_price - current_price) / current_price) * 100
        emoji = "📈" if change > 0 else "📉"
        # "- " makes each prediction its own list item. Plain lines separated
        # by single newlines would be merged into one Markdown paragraph.
        summary += f"\n- {emoji} **{name}**: ${final_price:.2f} ({change:+.2f}%)"
    summary += f"""
### Average Prediction
- **Average Price**: ${avg_forecast:.2f}
- **Expected Change**: {avg_change:+.2f}%
---
⚠️ **Risk Warning**: Past performance does not guarantee future results. Use for research only.
"""
    return summary
# Create Gradio Interface.
# Blocks lays components out in the order they are created inside the nested
# `with` contexts, so the statement order below defines the page layout.
with gr.Blocks(title="Stock Price Forecasting", theme=gr.themes.Soft()) as demo:
    # Page header and feature overview.
    gr.Markdown(
        """
# 📈 AI Stock Price Forecasting
### Predict future stock prices using multiple time-series models
This app trains models **in real-time** using the latest stock data. No pre-trained models needed!
**✨ Features:**
- Real-time data from Yahoo Finance
- Multiple forecasting algorithms
- Interactive visualizations
- No storage limits - models train on demand
---
"""
    )
    # Top row: input controls in a narrower left column (scale=1) and the
    # forecast chart in a wider right column (scale=2).
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 🎯 Input Parameters")
            ticker_input = gr.Textbox(
                label="📊 Stock Ticker Symbol",
                placeholder="e.g., AAPL, GOOGL, TSLA, MSFT",
                value="AAPL",
                info="Enter any valid stock ticker"
            )
            # Forecast horizon; predict_stock passes this value to every model.
            forecast_days = gr.Slider(
                minimum=7,
                maximum=90,
                value=30,
                step=1,
                label="📅 Forecast Period (Days)",
                info="Number of days to forecast"
            )
            # These labels must match the strings predict_stock compares against.
            model_choice = gr.Radio(
                choices=["All Models", "ARIMA", "Prophet", "Exp. Smoothing", "Moving Average"],
                value="All Models",
                label="🤖 Select Model(s)",
                info="Choose which forecasting model to use"
            )
            predict_btn = gr.Button(
                "🔮 Generate Forecast",
                variant="primary",
                size="lg",
                scale=1
            )
            gr.Markdown(
                """
### 💡 Quick Tips
- Use 30 days for short-term
- Use 60-90 days for trends
- "All Models" shows comparison
"""
            )
        with gr.Column(scale=2):
            output_plot = gr.Plot(label="📈 Forecast Visualization")
    # Markdown summary returned by predict_stock.
    with gr.Row():
        with gr.Column():
            output_summary = gr.Markdown(label="📋 Analysis Summary")
    # Forecast table returned by predict_stock (a Date column plus one price
    # column per model that produced a forecast); read-only for the user.
    with gr.Row():
        output_table = gr.Dataframe(
            label="📊 Detailed Forecast Table",
            wrap=True,
            interactive=False,
            height=400
        )
    # Examples.
    # No fn/outputs are given, so clicking an example only fills the three
    # inputs; the user still presses the button to run the forecast.
    gr.Markdown("### 🎯 Try These Examples")
    gr.Examples(
        examples=[
            ["AAPL", 30, "All Models"],
            ["GOOGL", 14, "Prophet"],
            ["TSLA", 60, "ARIMA"],
            ["MSFT", 45, "Exp. Smoothing"],
            ["NVDA", 30, "All Models"],
        ],
        inputs=[ticker_input, forecast_days, model_choice],
        label="Popular Stocks"
    )
    # Connect the button to the function. The outputs list lines up with
    # predict_stock's (figure, summary, table) return tuple.
    predict_btn.click(
        fn=predict_stock,
        inputs=[ticker_input, forecast_days, model_choice],
        outputs=[output_plot, output_summary, output_table]
    )
    # Footer: model comparison table, disclaimer, and privacy notes.
    gr.Markdown(
        """
---
## 📚 About the Models
| Model | Best For | Speed | Accuracy |
|-------|----------|-------|----------|
| **ARIMA** | Short-term, stationary data | ⚡⚡⚡ Fast | ⭐⭐⭐ |
| **Prophet** | Seasonality, trends | ⚡⚡ Medium | ⭐⭐⭐⭐ |
| **Exp. Smoothing** | Smooth trends | ⚡⚡⚡ Fast | ⭐⭐⭐ |
| **Moving Average** | Simple baseline | ⚡⚡⚡⚡ Very Fast | ⭐⭐ |
## ⚠️ Important Disclaimer
**This tool is for educational and research purposes only.**
- Stock predictions are inherently uncertain
- Past performance ≠ future results
- Always do your own research
- Consult financial advisors before investing
- Never invest more than you can afford to lose
## 🔒 Privacy & Data
- No data is stored permanently
- Models train fresh for each prediction
- Stock data fetched from Yahoo Finance API
- No personal information collected
---
**Made with ❤️ using Gradio & Python**
"""
    )
# Launch the app only when this file is executed directly, not on import.
if __name__ == "__main__":
    launch_options = {
        "share": False,            # no public share link
        "show_error": True,        # surface callback errors in the browser UI
        "server_name": "0.0.0.0",  # listen on all network interfaces
        "server_port": 7860,
    }
    demo.launch(**launch_options)