# Author: hrid0yyy
# Last change: commit 41ce399 "Update app.py" (verified)
import gradio as gr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from statsmodels.tsa.arima.model import ARIMA
from tensorflow import keras # Use full TensorFlow Keras for custom_objects
import joblib
from datetime import datetime, timedelta
import warnings
# Suppress every warning process-wide.
# NOTE(review): this also hides useful warnings; consider narrowing the filter.
warnings.filterwarnings('ignore')

# Load the saved data and models once at import time. The files must sit next
# to this script (in Colab, fetch them first via `!wget` or the upload UI).
# The pickle is read a single time and shared, instead of being deserialized
# separately for the ARIMA fit and for `final_df`.
final_df = pd.read_pickle('final_df.pkl')

# ARIMA is re-fitted here (not loaded from disk) on the last 750 'Target'
# observations with order=(p=5, d=1, q=0).
arima_model = ARIMA(final_df['Target'].tail(750), order=(5, 1, 0)).fit()

# Map the loss name 'mse' explicitly so the saved .h5 model deserializes;
# without custom_objects, loading hit an 'mse' deserialization error.
lstm_model = keras.models.load_model(
    'lstm_model.h5',
    custom_objects={'mse': keras.losses.MeanSquaredError()},
)

# Scaler used to map the LSTM's scaled output back to price units.
target_scaler = joblib.load('target_scaler.pkl')
def _predict_prices(model_type, n_days):
    """Return ``n_days`` forecast prices from the selected model as a 1-D array."""
    if model_type == 'ARIMA':
        # NOTE(review): the ARIMA forecast always continues from the end of the
        # series it was fitted on; the requested start date only labels the
        # x-axis and does not influence these values.
        return np.asarray(arima_model.forecast(steps=n_days))

    # LSTM: use the most recent n_lags rows as one input window of shape
    # (1, n_lags, n_columns).
    # NOTE(review): final_df rows are fed as-is (no feature scaling is applied
    # here); confirm this matches how the LSTM was trained.
    n_lags = 30
    recent_data = final_df.tail(n_lags).values
    x_input = recent_data.reshape(1, n_lags, recent_data.shape[1])
    pred_scaled = lstm_model.predict(x_input)[0][0]
    pred_price = target_scaler.inverse_transform([[pred_scaled]])[0][0]
    # Placeholder: the single one-step prediction is repeated for every day.
    return np.full(n_days, pred_price)


def _plot_forecast(results_df, model_type, n_days, start_date_str):
    """Draw the forecast line chart and return its matplotlib Figure."""
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(results_df['Date'], results_df['Forecast'], marker='o',
            label=f'{model_type} Forecast')
    ax.set_title(f'AAPL {n_days}-Day Forecast from {start_date_str}')
    ax.set_xlabel('Date')
    ax.set_ylabel('Price ($)')
    ax.legend()
    ax.grid(True)
    # Use the Axes/Figure objects rather than plt.xticks/plt.tight_layout,
    # which act on pyplot's "current" figure and can hit the wrong one when
    # requests are served concurrently.
    ax.tick_params(axis='x', labelrotation=45)
    fig.tight_layout()
    # Detach from pyplot's figure manager so figures do not accumulate across
    # calls; the returned Figure can still be rendered/saved by the caller.
    plt.close(fig)
    return fig


def forecast_aapl(start_date_str, model_type='LSTM', days=30):
    """Forecast AAPL prices and return a plot plus a text table.

    Args:
        start_date_str: Date string parsed with ``pd.to_datetime``. It only
            labels the horizon: the first row gets this date and each later
            row advances by one calendar day.
        model_type: ``'ARIMA'`` uses the module-level fitted ARIMA model; any
            other value uses the LSTM.
        days: Number of forecast steps. Coerced with ``int`` (fractions are
            truncated) because UI sliders may deliver floats; must be >= 1.

    Returns:
        ``(fig, table)``. ``fig`` is a matplotlib Figure and ``table`` is the
        forecast DataFrame rendered with ``to_string(index=False)``. On any
        failure, returns ``(None, "Error: <message>")`` so the UI can show it.
    """
    try:
        start_date = pd.to_datetime(start_date_str)
        n_days = int(days)
        if n_days < 1:
            raise ValueError(f"days must be >= 1, got {days!r}")

        forecast = _predict_prices(model_type, n_days)
        future_dates = pd.date_range(start=start_date, periods=n_days, freq='D')
        results_df = pd.DataFrame({'Date': future_dates, 'Forecast': forecast})

        fig = _plot_forecast(results_df, model_type, n_days, start_date_str)
        return fig, results_df.to_string(index=False)
    except Exception as e:
        # UI boundary: report the failure as text instead of crashing the app.
        return None, f"Error: {str(e)}"
# Gradio interface: the three input widgets map positionally onto
# forecast_aapl(start_date_str, model_type, days), and its two return values
# are shown as a plot and a text table.
iface = gr.Interface(
    fn=forecast_aapl,
    inputs=[
        gr.Textbox(label="Start Date (YYYY-MM-DD)", value="2020-03-01"),
        gr.Dropdown(choices=['LSTM', 'ARIMA'], label="Model", value='LSTM'),
        # step=1 keeps the value a whole number of days. Without an explicit
        # step, Gradio derives a fractional one from the range, which would
        # let users request non-integer day counts.
        gr.Slider(1, 90, value=30, step=1, label="Forecast Days"),
    ],
    outputs=[gr.Plot(label="Forecast Plot"), gr.Textbox(label="Forecast Table")],
    title="AAPL Stock Price Forecaster",
    description="Enter a start date to get future AAPL price forecasts using ARIMA or LSTM.",
)
# Launch the app. share=True requests a temporary public link (handy when
# testing from Colab); debug=True keeps the call blocking so errors are printed.
iface.launch(share=True, debug=True)