Spaces:
Running
Running
| import streamlit as st | |
| import pandas as pd | |
| import torch | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from datetime import datetime, timedelta | |
| from models import ModelConfig, load_model_pipeline | |
| from data import DataConfig, process_input, fetch_data_with_fallback | |
# Load the forecasting model pipeline once per model name. Streamlit re-executes
# the whole script on every widget interaction, so without a resource cache the
# model would be reloaded from scratch on each rerun.
@st.cache_resource
def load_pipeline(model_name):
    """Load and cache the forecasting pipeline for ``model_name``.

    The pipeline is placed on CPU in float32. ``st.cache_resource`` keys the
    cache on ``model_name`` and returns the same pipeline object on later calls
    instead of loading it again.

    Args:
        model_name: A name from ``ModelConfig.get_model_names()``.

    Returns:
        The object produced by ``load_model_pipeline``. The app later calls
        ``predict_df`` on it.
    """
    return load_model_pipeline(model_name, device_map="cpu", dtype=torch.float32)
# Fetch source data, cached for 1 hour (ttl is in seconds), so reruns triggered
# by widget interaction don't re-query the upstream source every time.
@st.cache_data(ttl=3600)
def fetch_data(source_name, days_back=180):
    """Fetch ``days_back`` days of data for ``source_name``, with caching.

    Args:
        source_name: A name from ``DataConfig.get_source_names()``.
        days_back: How many days of history to request (default 180).

    Returns:
        Whatever ``fetch_data_with_fallback`` returns. The app unpacks it as
        ``(data_text, source_label, error_message)``.

    Note:
        The result is cached per ``(source_name, days_back)`` for the full TTL.
        This includes a fallback result that carries an error message.
    """
    return fetch_data_with_fallback(source_name, days_back)
# --- Page header -------------------------------------------------------------
st.title("⚡ Electricity Market Price Forecasting")
_INTRO_MARKDOWN = (
    "\n"
    "This application uses **Amazon Chronos** pretrained models for zero-shot "
    "time series forecasting on electricity market data.\n"
    "Select a model, choose your data source, and evaluate forecasting "
    "performance with backtesting on real ERCOT prices.\n"
)
st.write(_INTRO_MARKDOWN)

# --- Model selection ---------------------------------------------------------
available_model_names = ModelConfig.get_model_names()
selected_model_name = st.selectbox(
    label="Select Forecasting Model:",
    options=available_model_names,
    index=0,  # the first name ModelConfig returns is the default choice
)

# Resolve the pipeline for the chosen model, with a spinner while it loads.
with st.spinner(f"Loading {selected_model_name}..."):
    pipeline = load_pipeline(selected_model_name)

# --- Data source selection ---------------------------------------------------
# Configured sources come first; "Custom Data" lets the user paste their own series.
available_sources = DataConfig.get_source_names()
source_options = available_sources + ["Custom Data"]
data_source = st.radio(label="Select Data Source:", options=source_options, index=0)
# --- Time-series input -------------------------------------------------------
if data_source != "Custom Data":
    # Pre-fill the text area from the selected source. The returned
    # data_source_used labels where the values actually came from.
    with st.spinner(f"Fetching data from {data_source}..."):
        default_data, data_source_used, error_msg = fetch_data(data_source)
    if error_msg:
        st.warning(f"⚠️ {error_msg}\nUsing sample data instead.")
    user_input = st.text_area(
        f"{data_source_used} - Daily Average Prices ($/MWh):",
        default_data.strip(),
        height=150,
    )
    if "ERCOT" in data_source_used:
        st.info(
            "💡 Live data from ERCOT's Day-Ahead Market (DAM SPP) - "
            "Daily average prices across all settlement points"
        )
else:
    # Free-form entry: there is no upstream source, hence no fetch error.
    data_source_used = "Custom"
    error_msg = None
    user_input = st.text_area(
        "Enter time series data (comma-separated values):",
        "",
        height=150,
    )

# Parse the comma-separated text. On malformed input, report it and fall back
# to an empty series so that no forecast is attempted.
try:
    time_series_data = process_input(user_input)
except ValueError:
    st.error("Please make sure all values are numbers, separated by commas.")
    time_series_data = []
# --- Backtest horizon --------------------------------------------------------
# With more than 10 points, allow holding out up to 64 days while keeping at
# least 10 points of context. Otherwise only a 1-day horizon is offered.
if len(time_series_data) > 10:
    max_test_days = min(64, len(time_series_data) - 10)
else:
    max_test_days = 1

prediction_length = st.slider(
    label="Forecast Horizon (Days to Backtest)",
    min_value=1,
    max_value=max_test_days,
    value=min(14, max_test_days),
    help=(
        "The model will use historical context to forecast the last N days, "
        "then compare predictions with actual values to evaluate performance."
    ),
)
# --- Backtest: forecast the held-out tail and compare it with the actuals ----


def _backtest_metrics(actual, forecast):
    """Return ``(mae, rmse, mape)`` for ``forecast`` measured against ``actual``.

    Args:
        actual: 1-D float array of observed values.
        forecast: 1-D float array of point forecasts, same length as ``actual``.

    Returns:
        A tuple of floats. MAPE (in percent) is averaged only over points whose
        actual value is non-zero, because a zero actual makes the percentage
        error infinite. It is ``nan`` when every actual value is zero.
    """
    residual = actual - forecast
    mae = float(np.mean(np.abs(residual)))
    rmse = float(np.sqrt(np.mean(residual ** 2)))
    nonzero = actual != 0
    if nonzero.any():
        mape = float(np.mean(np.abs(residual[nonzero] / actual[nonzero])) * 100)
    else:
        mape = float("nan")
    return mae, rmse, mape


def _percent_error(actual, forecast):
    """Return per-point ``(forecast - actual) / actual * 100``, ``nan`` where actual is 0."""
    # np.where evaluates both branches, so silence the divide-by-zero warning
    # for the entries that are masked out anyway.
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.where(actual != 0, (forecast - actual) / actual * 100, np.nan)


num_points = len(time_series_data)
if 0 < num_points <= prediction_length:
    # Holding out the whole horizon would leave no context for the model.
    st.warning("Not enough data to backtest: enter more values than the forecast horizon.")
elif num_points > prediction_length:
    # Split the series into context (model input) and the held-out test tail.
    context_length = num_points - prediction_length
    context_data = time_series_data[:context_length]
    test_data = time_series_data[context_length:]
    actual = np.asarray(test_data, dtype=float)

    # The input carries no dates, so assign consecutive daily timestamps ending
    # today. They are midnight-aligned so every point falls on a day boundary.
    all_dates = pd.date_range(end=pd.Timestamp.now().normalize(), periods=num_points, freq="D")
    context_dates = all_dates[:context_length]
    test_dates = all_dates[context_length:]

    # Long-format frame (id / timestamp / target) expected by predict_df.
    context_df = pd.DataFrame({
        'timestamp': context_dates,
        'target': context_data,
        'id': 'ercot_prices'
    })

    with st.spinner("Generating forecast..."):
        pred_df = pipeline.predict_df(
            context_df,
            prediction_length=prediction_length,
            quantile_levels=[0.1, 0.5, 0.9],
            id_column="id",
            timestamp_column="timestamp",
            target="target",
        )

    # Point forecast plus the 10%/90% quantiles bounding an 80% interval.
    median = pred_df["predictions"].to_numpy(dtype=float)
    low = pred_df["0.1"].to_numpy(dtype=float)
    high = pred_df["0.9"].to_numpy(dtype=float)

    mae, rmse, mape = _backtest_metrics(actual, median)

    # Use an explicit Figure and close it after rendering. Otherwise every
    # Streamlit rerun would leave another open pyplot figure behind.
    fig, ax = plt.subplots(figsize=(14, 7))
    ax.plot(context_dates, context_data, color="royalblue", label="Historical Context", linewidth=2)
    ax.plot(test_dates, test_data, color="green", label="Actual Values", linewidth=2, marker='o', markersize=4)
    ax.plot(test_dates, median, color="tomato", label="Forecast", linewidth=2, linestyle='--', marker='s', markersize=4)
    ax.fill_between(test_dates, low, high, color="tomato", alpha=0.3, label="80% Prediction Interval")
    ax.axvline(x=context_dates[-1], color='gray', linestyle=':', linewidth=1, alpha=0.7)
    # x is in data coordinates, y in axes coordinates, so the label stays inside
    # the plot even when prices are negative (0.95 * a negative upper limit
    # would land above the axes).
    ax.text(context_dates[-1], 0.95, ' Forecast Window', fontsize=10, color='gray',
            transform=ax.get_xaxis_transform())
    ax.set_xlabel("Date")
    ax.set_ylabel("Price ($/MWh)")
    ax.set_title(f"Electricity Price Forecast ({data_source_used}) - {prediction_length} Day Test Window")
    ax.legend(loc='best')
    ax.grid(alpha=0.3)
    ax.tick_params(axis="x", labelrotation=45)
    fig.tight_layout()
    st.pyplot(fig)
    plt.close(fig)

    # Error metrics for the held-out window.
    st.write("### Model Performance Metrics")
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("MAE", f"${mae:.2f}")
    with col2:
        st.metric("RMSE", f"${rmse:.2f}")
    with col3:
        st.metric("MAPE", f"{mape:.2f}%" if np.isfinite(mape) else "N/A")
    with col4:
        st.metric("Avg Actual", f"${actual.mean():.2f}/MWh")

    # Point-by-point comparison. Error is signed as forecast minus actual, and
    # "Error %" is NaN where the actual value is 0.
    with st.expander("View Detailed Comparison"):
        comparison_df = pd.DataFrame({
            'Date': test_dates.strftime('%Y-%m-%d'),
            'Actual': test_data,
            'Forecast': median.round(2),
            'Error': (median - actual).round(2),
            'Error %': _percent_error(actual, median).round(2),
        })
        st.dataframe(comparison_df, use_container_width=True)
# --- About -------------------------------------------------------------------
# Overview of features, supported model families, and a contact link.
# The model count comes from the same list the selectbox uses, so it cannot
# drift out of sync. Blank lines matter in Markdown: without them, "**Models:**"
# and the contact line are folded into the preceding bullet as lazy
# continuation text.
st.write("### About")
st.write(f"""
**Features:**

- 🤖 Multiple pretrained models ({len(available_model_names)} options: Chronos-2, Chronos-T5, TiRex)
- 📊 Live ERCOT Day-Ahead Market data (180-day lookback)
- 🎯 Backtesting with error metrics (MAE, RMSE, MAPE)
- 📈 Visual comparison of forecasts vs actual values
- 🔧 Modular architecture for easy extension

**Models:**

- **Chronos-2**: Amazon's latest (46M-120M params)
- **Chronos-T5**: Original family (8M-710M params)
- **TiRex**: NX-AI's xLSTM-based model (35M params)

For questions or feedback, reach out on [LinkedIn](https://www.linkedin.com/in/javadbayazi/).
""")