"""Streamlit demo: backtest an electricity price forecast.

The script loads a forecasting pipeline, obtains a daily price series (either
fetched from a configured source or pasted in by the user), holds out the last
N days as a test window, forecasts that window from the preceding data, and
reports the plot plus MAE / RMSE / MAPE against the held-out values.
"""
import streamlit as st
import pandas as pd
import torch
import matplotlib.pyplot as plt
import numpy as np
from datetime import datetime, timedelta

from models import ModelConfig, load_model_pipeline
from data import DataConfig, process_input, fetch_data_with_fallback


# Load the forecasting model pipeline
@st.cache_resource
def load_pipeline(model_name):
    """Load the model pipeline for *model_name* on CPU in float32.

    Wrapped in ``st.cache_resource`` so the pipeline object is reused across
    script reruns instead of being rebuilt on every widget interaction.
    """
    return load_model_pipeline(model_name, device_map="cpu", dtype=torch.float32)


# Fetch data with caching
@st.cache_data(ttl=3600)  # Cache for 1 hour
def fetch_data(source_name, days_back=180):
    """Fetch data for *source_name* via ``fetch_data_with_fallback``.

    The result is cached per ``(source_name, days_back)`` for one hour. The
    caller in this script unpacks it as
    ``(data_text, source_label, error_message)``.
    """
    return fetch_data_with_fallback(source_name, days_back)


def _error_metrics(actual, forecast):
    """Return ``(mae, rmse, mape)`` of *forecast* against *actual*.

    MAPE is expressed in percent and computed only over points whose actual
    value is non-zero; dividing by a zero price would otherwise produce
    inf/nan. If every actual value is zero, MAPE is NaN.
    """
    actual = np.asarray(actual, dtype=float)
    forecast = np.asarray(forecast, dtype=float)
    errors = forecast - actual

    mae = np.mean(np.abs(errors))
    rmse = np.sqrt(np.mean(errors ** 2))

    nonzero = actual != 0
    if nonzero.any():
        mape = np.mean(np.abs(errors[nonzero] / actual[nonzero])) * 100
    else:
        mape = float("nan")
    return mae, rmse, mape


def _percentage_errors(actual, forecast):
    """Return element-wise ``(forecast - actual) / actual * 100``.

    Entries whose actual value is zero are NaN instead of +/-inf.
    """
    actual = np.asarray(actual, dtype=float)
    errors = np.asarray(forecast, dtype=float) - actual
    pct = np.full_like(errors, np.nan)
    np.divide(errors, actual, out=pct, where=actual != 0)
    pct *= 100
    return pct


def _build_forecast_figure(train_dates, train_data, test_dates, test_data,
                           median, low, high, prediction_length):
    """Build the train / actual / forecast chart and return the Figure.

    An explicit figure is used (rather than the implicit global pyplot one)
    so the caller can hand it to Streamlit and then close it.
    """
    fig, ax = plt.subplots(figsize=(14, 7))
    ax.plot(train_dates, train_data, color="royalblue",
            label="Training Data", linewidth=2)
    ax.plot(test_dates, test_data, color="green", label="Actual Test Data",
            linewidth=2, marker='o', markersize=4)
    ax.plot(test_dates, median, color="tomato", label="Forecast",
            linewidth=2, linestyle='--', marker='s', markersize=4)
    ax.fill_between(test_dates, low, high, color="tomato", alpha=0.3,
                    label="80% Prediction Interval")

    # Mark where the training context ends and the held-out window begins.
    split_date = train_dates[-1]
    ax.axvline(x=split_date, color='gray', linestyle=':', linewidth=1, alpha=0.7)
    ax.text(split_date, ax.get_ylim()[1] * 0.95, ' Train/Test Split',
            fontsize=10, color='gray')

    ax.set_xlabel("Date")
    ax.set_ylabel("Price ($/MWh)")
    ax.set_title(
        f"ERCOT Electricity Price Forecast - {prediction_length} Day Test Window"
    )
    ax.legend(loc='best')
    ax.grid(alpha=0.3)
    ax.tick_params(axis='x', labelrotation=45)
    fig.tight_layout()
    return fig


def _render_backtest(pipeline, time_series_data, prediction_length):
    """Forecast the last *prediction_length* points and render the results.

    The series is split into a training prefix and a test suffix; the
    pipeline forecasts the suffix from the prefix, and the plot, metrics and
    comparison table are written to the Streamlit page.
    """
    # Split data into train and test
    train_length = len(time_series_data) - prediction_length
    if train_length < 1:
        # Without at least one training point there is no context to
        # forecast from and no split date to draw.
        st.error(
            "Not enough data points: the series must be longer than the "
            "selected test window."
        )
        return
    train_data = time_series_data[:train_length]
    test_data = time_series_data[train_length:]

    # Create timestamps: the input carries no dates, so the series is laid
    # out as consecutive days ending today.
    end_date = datetime.now()
    start_date = end_date - timedelta(days=len(time_series_data) - 1)
    all_dates = pd.date_range(
        start=start_date, periods=len(time_series_data), freq='D'
    )
    train_dates = all_dates[:train_length]
    test_dates = all_dates[train_length:]

    # Create a DataFrame for training
    context_df = pd.DataFrame({
        'timestamp': train_dates,
        'target': train_data,
        'id': 'ercot_prices'
    })

    # Make the forecast using the model
    with st.spinner("Generating forecast..."):
        pred_df = pipeline.predict_df(
            context_df,
            prediction_length=prediction_length,
            quantile_levels=[0.1, 0.5, 0.9],
            id_column="id",
            timestamp_column="timestamp",
            target="target",
        )

    # Extract predictions. The "predictions" column is used as the point
    # forecast; "0.1" / "0.9" bound the plotted interval.
    median = np.asarray(pred_df["predictions"].values, dtype=float)
    low = pred_df["0.1"].values
    high = pred_df["0.9"].values

    # Calculate error metrics
    actual = np.asarray(test_data, dtype=float)
    mae, rmse, mape = _error_metrics(actual, median)

    # Plot the historical and forecasted data with dates
    fig = _build_forecast_figure(
        train_dates, train_data, test_dates, test_data,
        median, low, high, prediction_length,
    )
    # Show the plot in the Streamlit app, then release the figure so that
    # reruns do not accumulate open figures.
    st.pyplot(fig)
    plt.close(fig)

    # Display forecast statistics and error metrics
    st.write("### Model Performance Metrics")
    mape_text = "N/A" if np.isnan(mape) else f"{mape:.2f}%"
    metric_cells = [
        ("MAE", f"${mae:.2f}"),
        ("RMSE", f"${rmse:.2f}"),
        ("MAPE", mape_text),
        ("Avg Actual", f"${np.mean(test_data):.2f}/MWh"),
    ]
    for column, (label, value) in zip(st.columns(4), metric_cells):
        with column:
            st.metric(label, value)

    # Show detailed comparison table
    with st.expander("View Detailed Comparison"):
        comparison_df = pd.DataFrame({
            'Date': test_dates.strftime('%Y-%m-%d'),
            'Actual': test_data,
            'Forecast': median.round(2),
            'Error': (median - actual).round(2),
            'Error %': _percentage_errors(actual, median).round(2)
        })
        st.dataframe(comparison_df, use_container_width=True)


# Streamlit app interface
st.title("Electricity Market Price Forecasting with Chronos-2")
st.write("This demo uses **Chronos-2** to forecast electricity prices from ERCOT (Texas) market data.")

# Model selection
available_model_names = ModelConfig.get_model_names()
selected_model_name = st.selectbox(
    "Select Forecasting Model:",
    options=available_model_names,
    index=0  # Default to first model (Chronos-2)
)

# Load the selected model
with st.spinner(f"Loading {selected_model_name}..."):
    pipeline = load_pipeline(selected_model_name)

# Data source selection
available_sources = DataConfig.get_source_names()
data_source = st.radio(
    "Select Data Source:",
    available_sources + ["Custom Data"],
    index=0
)

# Input field for user-provided data
if data_source == "Custom Data":
    user_input = st.text_area(
        "Enter time series data (comma-separated values):",
        "",
        height=150
    )
    data_source_used = "Custom"
    error_msg = None
else:
    # Fetch data from selected source
    with st.spinner(f"Fetching data from {data_source}..."):
        default_data, data_source_used, error_msg = fetch_data(data_source)

    if error_msg:
        st.warning(f"⚠️ {error_msg}\nUsing sample data instead.")

    user_input = st.text_area(
        f"{data_source_used} - Daily Average Prices ($/MWh):",
        default_data.strip(),
        height=150
    )
    # NOTE(review): this banner is shown even when the fetch reported an
    # error and fell back to sample data — confirm that is intended.
    st.info("💡 Live data from ERCOT's Day-Ahead Market (DAM SPP) - averaged across all settlement points per day")

try:
    time_series_data = process_input(user_input)
except ValueError:
    st.error("Please make sure all values are numbers, separated by commas.")
    time_series_data = []  # Set empty data on error to prevent further processing

# Select the number of days for testing (forecasting on known data).
# At most 64 days, and at least 10 points are kept back for training when the
# series is long enough; otherwise the window is fixed at a single day.
max_test_days = min(64, len(time_series_data) - 10) if len(time_series_data) > 10 else 1
prediction_length = st.slider(
    "Select Test Window (Days to Forecast & Compare)",
    min_value=1,
    max_value=max_test_days,
    value=min(14, max_test_days),
    help="The last N days will be used as test data. The model will forecast these days and compare with actual values."
)

# If data is valid, perform the forecast
if time_series_data:
    _render_backtest(pipeline, time_series_data, prediction_length)

# Note for comments, feedback, or questions
st.write("### Notes")
st.write("For comments, feedback, or any questions, please reach out to me on [LinkedIn](https://www.linkedin.com/in/javadbayazi/).")