import streamlit as st
import pandas as pd
import torch
import matplotlib.pyplot as plt
import numpy as np
from datetime import datetime, timedelta
from models import ModelConfig, load_model_pipeline
from data import DataConfig, process_input, fetch_data_with_fallback


# Load the forecasting model pipeline
@st.cache_resource
def load_pipeline(model_name):
    """Load and cache the model pipeline.

    Wrapped in ``st.cache_resource`` so the model is not reloaded on every
    Streamlit rerun. The pipeline is requested on CPU in float32.
    """
    return load_model_pipeline(model_name, device_map="cpu", dtype=torch.float32)


# Fetch data with caching
@st.cache_data(ttl=3600)  # Cache for 1 hour
def fetch_data(source_name, days_back=180):
    """Fetch data from the specified source, cached for one hour.

    Returns whatever ``fetch_data_with_fallback`` returns; the script below
    unpacks it as ``(data_text, source_label, error_msg)``.
    """
    return fetch_data_with_fallback(source_name, days_back)


def _percentage_errors(actual, forecast):
    """Return signed per-point percentage errors ``(forecast - actual) / actual * 100``.

    Points whose actual value is exactly zero yield NaN instead of an
    infinite/undefined ratio (zero prices can occur in electricity markets).
    """
    actual = np.asarray(actual, dtype=float)
    forecast = np.asarray(forecast, dtype=float)
    errors = forecast - actual
    # `where=` skips the division for zero actuals; those slots keep the NaN fill.
    ratios = np.divide(errors, actual, out=np.full_like(errors, np.nan), where=actual != 0)
    return ratios * 100


# Streamlit app interface
st.title("⚡ Electricity Market Price Forecasting")
st.write("""
This application uses **Amazon Chronos** pretrained models for zero-shot time series forecasting on electricity market data.
Select a model, choose your data source, and evaluate forecasting performance with backtesting on real ERCOT prices.
""")

# Model selection
available_model_names = ModelConfig.get_model_names()
selected_model_name = st.selectbox(
    "Select Forecasting Model:",
    options=available_model_names,
    index=0,  # Default to first model (Chronos-2)
)

# Load the selected model
with st.spinner(f"Loading {selected_model_name}..."):
    pipeline = load_pipeline(selected_model_name)

# Data source selection
available_sources = DataConfig.get_source_names()
data_source = st.radio(
    "Select Data Source:",
    available_sources + ["Custom Data"],
    index=0,
)

# Input field for user-provided data
if data_source == "Custom Data":
    user_input = st.text_area(
        "Enter time series data (comma-separated values):",
        "",
        height=150,
    )
    data_source_used = "Custom"
    error_msg = None
else:
    # Fetch data from selected source
    with st.spinner(f"Fetching data from {data_source}..."):
        default_data, data_source_used, error_msg = fetch_data(data_source)

    if error_msg:
        st.warning(f"⚠️ {error_msg}\nUsing sample data instead.")

    user_input = st.text_area(
        f"{data_source_used} - Daily Average Prices ($/MWh):",
        default_data.strip(),
        height=150,
    )

    if "ERCOT" in data_source_used:
        st.info("💡 Live data from ERCOT's Day-Ahead Market (DAM SPP) - Daily average prices across all settlement points")

try:
    time_series_data = process_input(user_input)
except ValueError:
    st.error("Please make sure all values are numbers, separated by commas.")
    time_series_data = []  # Set empty data on error to prevent further processing

n_points = len(time_series_data)

# Select the forecast window for backtesting (keep >= 10 context points when possible)
max_test_days = min(64, n_points - 10) if n_points > 10 else 1
if max_test_days > 1:
    prediction_length = st.slider(
        "Forecast Horizon (Days to Backtest)",
        min_value=1,
        max_value=max_test_days,
        value=min(14, max_test_days),
        help="The model will use historical context to forecast the last N days, then compare predictions with actual values to evaluate performance.",
    )
else:
    # With this little data the only possible horizon is 1 day; a slider whose
    # min equals its max offers no choice, so it is not rendered.
    prediction_length = 1

# Forecast only when at least one point remains as context after holding out the test window.
if n_points > prediction_length:
    # Split data into context (historical) and test
    context_length = n_points - prediction_length
    context_data = time_series_data[:context_length]
    test_data = time_series_data[context_length:]

    # Create daily timestamps ending today (the input carries no dates of its own)
    end_date = datetime.now()
    start_date = end_date - timedelta(days=n_points - 1)
    all_dates = pd.date_range(start=start_date, periods=n_points, freq='D')
    context_dates = all_dates[:context_length]
    test_dates = all_dates[context_length:]

    # Create a DataFrame with context for the model
    context_df = pd.DataFrame({
        'timestamp': context_dates,
        'target': context_data,
        'id': 'ercot_prices',
    })

    # Make the forecast using the model
    with st.spinner("Generating forecast..."):
        pred_df = pipeline.predict_df(
            context_df,
            prediction_length=prediction_length,
            quantile_levels=[0.1, 0.5, 0.9],
            id_column="id",
            timestamp_column="timestamp",
            target="target",
        )

    # Extract predictions (median point forecast plus 10%/90% quantiles)
    median = pred_df["predictions"].values
    low = pred_df["0.1"].values
    high = pred_df["0.9"].values

    # Calculate error metrics
    actual = np.asarray(test_data, dtype=float)
    pct_errors = _percentage_errors(actual, median)
    mae = np.mean(np.abs(actual - median))
    rmse = np.sqrt(np.mean((actual - median) ** 2))
    # MAPE is undefined where the actual price is zero: average the remaining
    # points, or report NaN (shown as "N/A") when none remain.
    defined_pct = np.abs(pct_errors[~np.isnan(pct_errors)])
    mape = defined_pct.mean() if defined_pct.size else float("nan")

    # Plot the historical and forecasted data with dates. An explicit Figure is
    # used so it can be closed after rendering; otherwise every rerun leaks one.
    fig, ax = plt.subplots(figsize=(14, 7))
    ax.plot(context_dates, context_data, color="royalblue", label="Historical Context", linewidth=2)
    ax.plot(test_dates, test_data, color="green", label="Actual Values", linewidth=2, marker='o', markersize=4)
    ax.plot(test_dates, median, color="tomato", label="Forecast", linewidth=2, linestyle='--', marker='s', markersize=4)
    ax.fill_between(test_dates, low, high, color="tomato", alpha=0.3, label="80% Prediction Interval")
    ax.axvline(x=context_dates[-1], color='gray', linestyle=':', linewidth=1, alpha=0.7)
    # y is given in axes-fraction coordinates (0.95 = near the top), so the label
    # stays inside the plot even when all prices are negative.
    ax.text(context_dates[-1], 0.95, ' Forecast Window', fontsize=10, color='gray',
            transform=ax.get_xaxis_transform())
    ax.set_xlabel("Date")
    ax.set_ylabel("Price ($/MWh)")
    ax.set_title(f"ERCOT Electricity Price Forecast - {prediction_length} Day Test Window")
    ax.legend(loc='best')
    ax.grid(alpha=0.3)
    ax.tick_params(axis='x', labelrotation=45)
    fig.tight_layout()

    # Show the plot in the Streamlit app, then release the figure
    st.pyplot(fig)
    plt.close(fig)

    # Display forecast statistics and error metrics
    st.write("### Model Performance Metrics")
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("MAE", f"${mae:.2f}")
    with col2:
        st.metric("RMSE", f"${rmse:.2f}")
    with col3:
        st.metric("MAPE", f"{mape:.2f}%" if np.isfinite(mape) else "N/A")
    with col4:
        st.metric("Avg Actual", f"${np.mean(actual):.2f}/MWh")

    # Show detailed comparison table
    with st.expander("View Detailed Comparison"):
        comparison_df = pd.DataFrame({
            'Date': test_dates.strftime('%Y-%m-%d'),
            'Actual': test_data,
            'Forecast': median.round(2),
            'Error': (median - actual).round(2),
            'Error %': pct_errors.round(2),
        })
        st.dataframe(comparison_df, use_container_width=True)
elif n_points == 1:
    st.warning("At least 2 values are needed: one for context and one to forecast.")

# Note for comments, feedback, or questions
st.write("### About")
st.write("""
**Features:**
- 🤖 Multiple pretrained models (7 options: Chronos-2, Chronos-T5, TiRex)
- 📊 Real-time ERCOT electricity market data (180+ days)
- 🎯 Backtesting with error metrics (MAE, RMSE, MAPE)
- 📈 Visual comparison of forecasts vs actual values
- 🔧 Modular architecture for easy extension

**Models:**
- **Chronos-2**: Amazon's latest (46M-120M params)
- **Chronos-T5**: Original family (8M-710M params)
- **TiRex**: NX-AI's xLSTM-based model (35M params)

For questions or feedback, reach out on [LinkedIn](https://www.linkedin.com/in/javadbayazi/).
""")