Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from statsmodels.tsa.arima.model import ARIMA | |
| from tensorflow import keras # Use full TensorFlow Keras for custom_objects | |
| import joblib | |
| from datetime import datetime, timedelta | |
| import warnings | |
# NOTE(review): this silences every warning process-wide (pandas, matplotlib,
# TensorFlow, ...), presumably to quiet ARIMA fitting noise. Consider narrowing
# it with a `category=` / `module=` filter.
warnings.filterwarnings('ignore')

# Load the saved artifacts. These files must be present next to this script
# (in Colab, upload them or fetch them with !wget first).
# The pickle is read once and shared by the ARIMA fit and the LSTM forecaster.
final_df = pd.read_pickle('final_df.pkl')

# ARIMA is re-fitted at startup on the last 750 'Target' values rather than
# loaded from disk.
arima_model = ARIMA(final_df['Target'].tail(750), order=(5, 1, 0)).fit()

# custom_objects maps the 'mse' name stored in the .h5 file to a loss object,
# which resolves the deserialization error for that name.
lstm_model = keras.models.load_model(
    'lstm_model.h5',
    custom_objects={'mse': keras.losses.MeanSquaredError()},
)
target_scaler = joblib.load('target_scaler.pkl')
def forecast_aapl(start_date_str, model_type='LSTM', days=30):
    """Forecast AAPL prices and return a plot plus a plain-text table.

    Args:
        start_date_str: Date string parsed with ``pd.to_datetime``. It only
            labels the x-axis and the ``Date`` column.
        model_type: ``'ARIMA'`` uses the module-level ARIMA model. Any other
            value falls through to the LSTM model.
        days: Number of daily steps to forecast. Gradio sliders can deliver
            this as a float, so it is coerced to ``int``.

    Returns:
        ``(figure, table_text)`` on success, or ``(None, "Error: ...")`` if
        anything fails, so the UI shows the message instead of crashing.

    NOTE(review): the forecast values do not depend on ``start_date_str``.
    ARIMA forecasts from the end of its fitted series and the LSTM from the
    last rows of ``final_df``; the requested date only relabels them.
    """
    try:
        start_date = pd.to_datetime(start_date_str)
        # np.full / pd.date_range / ARIMA.forecast all require an integer count.
        days = int(days)
        if days < 1:
            raise ValueError(f"days must be >= 1, got {days}")

        forecast = _forecast_values(model_type, days)
        future_dates = pd.date_range(start=start_date, periods=days, freq='D')
        results_df = pd.DataFrame({'Date': future_dates, 'Forecast': forecast})
        fig = _plot_forecast(results_df, model_type, days, start_date_str)
        return fig, results_df.to_string(index=False)
    except Exception as e:  # UI boundary: surface the error as text
        return None, f"Error: {str(e)}"


def _forecast_values(model_type, days):
    """Return a 1-D array of ``days`` forecast prices from the chosen model."""
    if model_type == 'ARIMA':
        # Drop the statsmodels result index; the caller assigns its own dates.
        return np.asarray(arima_model.forecast(steps=days))

    # LSTM: one prediction from the last n_lags rows, repeated across the
    # horizon. This is a placeholder; extend it to a real multi-step loop.
    n_lags = 30
    recent_data = final_df.tail(n_lags).values
    x_input = recent_data.reshape(1, n_lags, recent_data.shape[1])
    # NOTE(review): rows are fed as stored in final_df; confirm they are
    # already scaled the way the model was trained.
    preds_scaled = lstm_model.predict(x_input, verbose=0)[0][0]
    next_price = target_scaler.inverse_transform([[preds_scaled]])[0][0]
    return np.full(days, next_price)


def _plot_forecast(results_df, model_type, days, start_date_str):
    """Draw the forecast line chart and return a figure detached from pyplot."""
    fig, ax = plt.subplots(figsize=(10, 5))
    try:
        ax.plot(results_df['Date'], results_df['Forecast'], marker='o',
                label=f'{model_type} Forecast')
        ax.set_title(f'AAPL {days}-Day Forecast from {start_date_str}')
        ax.set_xlabel('Date')
        ax.set_ylabel('Price ($)')
        ax.legend()
        ax.grid(True)
        # Use this figure's own axes rather than pyplot's global "current"
        # figure, which is shared between concurrent requests.
        ax.tick_params(axis='x', labelrotation=45)
        fig.tight_layout()
    finally:
        # Unregister from pyplot so figures don't pile up across requests.
        # The Figure object stays renderable for gr.Plot.
        plt.close(fig)
    return fig
# Gradio UI: start date, model choice and horizon in; plot and text table out.
iface = gr.Interface(
    fn=forecast_aapl,
    inputs=[
        gr.Textbox(label="Start Date (YYYY-MM-DD)", value="2020-03-01"),
        gr.Dropdown(choices=['LSTM', 'ARIMA'], label="Model", value='LSTM'),
        # step=1 keeps the horizon a whole number of days.
        gr.Slider(1, 90, value=30, step=1, label="Forecast Days"),
    ],
    outputs=[gr.Plot(label="Forecast Plot"), gr.Textbox(label="Forecast Table")],
    title="AAPL Stock Price Forecaster",
    description="Enter a start date to get future AAPL price forecasts using ARIMA or LSTM.",
)

if __name__ == "__main__":
    # Launch only when run as a script, not when the module is imported.
    # share=True gives a public link for testing (e.g. from Colab).
    iface.launch(share=True, debug=True)