# TSF-EM / app.py
# Last commit by JavadBayazi: "Update features section to include all 7 models" (12ab12a)
import streamlit as st
import pandas as pd
import torch
import matplotlib.pyplot as plt
import numpy as np
from datetime import datetime, timedelta
from models import ModelConfig, load_model_pipeline
from data import DataConfig, process_input, fetch_data_with_fallback
# Forecasting pipeline loader, wrapped in Streamlit's resource cache.
@st.cache_resource
def load_pipeline(model_name):
    """Return the forecasting pipeline registered under ``model_name``.

    Delegates to ``load_model_pipeline`` with a fixed CPU / float32
    configuration. The ``st.cache_resource`` decorator caches the result.
    """
    loader_options = {"device_map": "cpu", "dtype": torch.float32}
    return load_model_pipeline(model_name, **loader_options)
# Data fetcher, cached by Streamlit for 3600 seconds (one hour).
@st.cache_data(ttl=3600)
def fetch_data(source_name, days_back=180):
    """Fetch data for ``source_name`` through ``fetch_data_with_fallback``.

    ``days_back`` is forwarded unchanged. The helper's result is returned
    as-is; the caller in this script unpacks it into three values
    (data text, label of the source actually used, error message).
    """
    fetched = fetch_data_with_fallback(source_name, days_back)
    return fetched
# ---- Page header: title followed by a short markdown introduction ----
APP_TITLE = "⚡ Electricity Market Price Forecasting"
APP_INTRO = """
This application uses **Amazon Chronos** pretrained models for zero-shot time series forecasting on electricity market data.
Select a model, choose your data source, and evaluate forecasting performance with backtesting on real ERCOT prices.
"""

st.title(APP_TITLE)
st.write(APP_INTRO)
# ---- Model selection ----
# The select box offers every name reported by ModelConfig; the first entry
# is preselected (noted as Chronos-2 in the original comment).
available_model_names = ModelConfig.get_model_names()
selected_model_name = st.selectbox(
    label="Select Forecasting Model:",
    options=available_model_names,
    index=0,
)

# Show a spinner while the pipeline for the chosen model is obtained.
loading_message = f"Loading {selected_model_name}..."
with st.spinner(loading_message):
    pipeline = load_pipeline(selected_model_name)
# ---- Data source selection and input parsing ----
# Offers every source reported by DataConfig plus a free-form "Custom Data"
# option, shows the series in an editable text area, and parses it into
# `time_series_data` (empty list when parsing fails).
available_sources = DataConfig.get_source_names()
data_source = st.radio(
    "Select Data Source:",
    available_sources + ["Custom Data"],
    index=0,
)

if data_source == "Custom Data":
    # The user types the series directly; nothing is fetched.
    user_input = st.text_area(
        "Enter time series data (comma-separated values):",
        "",
        height=150,
    )
    data_source_used = "Custom"
    error_msg = None
else:
    # Fetch from the selected source; fetch_data reports which source was
    # actually used and an error message when the fetch fell back.
    with st.spinner(f"Fetching data from {data_source}..."):
        default_data, data_source_used, error_msg = fetch_data(data_source)

    if error_msg:
        st.warning(f"⚠️ {error_msg}\nUsing sample data instead.")

    user_input = st.text_area(
        f"{data_source_used} - Daily Average Prices ($/MWh):",
        default_data.strip(),
        height=150,
    )

    # Only advertise live ERCOT data when the fetch reported no error;
    # otherwise the banner would contradict the "sample data" warning above.
    if "ERCOT" in data_source_used and not error_msg:
        st.info("💡 Live data from ERCOT's Day-Ahead Market (DAM SPP) - Daily average prices across all settlement points")

try:
    time_series_data = process_input(user_input)
except ValueError:
    st.error("Please make sure all values are numbers, separated by commas.")
    # Empty data prevents the forecasting section below from running.
    time_series_data = []
# ---- Backtest: horizon selection, forecast, metrics, plot and table ----
# The last `prediction_length` points of the series are held out, the model
# forecasts them from the remaining context, and the forecast is compared
# with the held-out actual values.

# A backtest needs at least one context point and one held-out point.
num_points = len(time_series_data)
can_backtest = num_points >= 2
prediction_length = 0

if can_backtest:
    # Same cap as before: at most 64 days, keeping 10 points of context
    # when the series is long enough, otherwise a single day.
    max_test_days = min(64, num_points - 10) if num_points > 10 else 1
    if max_test_days > 1:
        prediction_length = st.slider(
            "Forecast Horizon (Days to Backtest)",
            min_value=1,
            max_value=max_test_days,
            value=min(14, max_test_days),
            help="The model will use historical context to forecast the last N days, then compare predictions with actual values to evaluate performance.",
        )
    else:
        # NOTE(review): a slider with min_value == max_value is believed to
        # be rejected by recent Streamlit releases (confirm against the
        # installed version), so the only possible horizon is fixed here.
        prediction_length = 1
        st.info("Forecast horizon fixed at 1 day: not enough data points for a longer backtest.")
elif time_series_data:
    st.warning("At least 2 data points are required to run a backtest.")

if can_backtest:
    # Split data into context (historical) and held-out test values.
    context_length = num_points - prediction_length
    context_data = time_series_data[:context_length]
    test_data = time_series_data[context_length:]

    # Synthesize daily timestamps ending at the current time.
    end_date = datetime.now()
    start_date = end_date - timedelta(days=num_points - 1)
    all_dates = pd.date_range(start=start_date, periods=num_points, freq='D')
    context_dates = all_dates[:context_length]
    test_dates = all_dates[context_length:]

    # Long-format frame handed to the model (single series id).
    context_df = pd.DataFrame({
        'timestamp': context_dates,
        'target': context_data,
        'id': 'ercot_prices',
    })

    with st.spinner("Generating forecast..."):
        pred_df = pipeline.predict_df(
            context_df,
            prediction_length=prediction_length,
            quantile_levels=[0.1, 0.5, 0.9],
            id_column="id",
            timestamp_column="timestamp",
            target="target",
        )

    # Point forecast and the 10% / 90% quantile columns.
    median = pred_df["predictions"].values
    low = pred_df["0.1"].values
    high = pred_df["0.9"].values

    # Error metrics. Percentage errors are undefined where the actual price
    # is exactly zero, so those entries are left as NaN and excluded from
    # MAPE instead of producing inf/NaN from a division by zero.
    actual = np.asarray(test_data, dtype=float)
    errors = median - actual
    mae = np.mean(np.abs(errors))
    rmse = np.sqrt(np.mean(errors ** 2))
    nonzero_actual = actual != 0
    pct_errors = np.full(actual.shape, np.nan)
    np.divide(errors, actual, out=pct_errors, where=nonzero_actual)
    pct_errors *= 100
    if nonzero_actual.any():
        mape = np.mean(np.abs(pct_errors[nonzero_actual]))
    else:
        mape = float("nan")

    # Plot on an explicit figure so no global pyplot state is shared between
    # reruns, and close it after rendering so figures do not accumulate.
    fig, ax = plt.subplots(figsize=(14, 7))
    ax.plot(context_dates, context_data, color="royalblue", label="Historical Context", linewidth=2)
    ax.plot(test_dates, test_data, color="green", label="Actual Values", linewidth=2, marker='o', markersize=4)
    ax.plot(test_dates, median, color="tomato", label="Forecast", linewidth=2, linestyle='--', marker='s', markersize=4)
    ax.fill_between(test_dates, low, high, color="tomato", alpha=0.3, label="80% Prediction Interval")
    # Vertical marker at the last context point, i.e. where the forecast starts.
    ax.axvline(x=context_dates[-1], color='gray', linestyle=':', linewidth=1, alpha=0.7)
    ax.text(context_dates[-1], ax.get_ylim()[1] * 0.95, ' Forecast Window', fontsize=10, color='gray')
    ax.set_xlabel("Date")
    ax.set_ylabel("Price ($/MWh)")
    ax.set_title(f"ERCOT Electricity Price Forecast - {prediction_length} Day Test Window")
    ax.legend(loc='best')
    ax.grid(alpha=0.3)
    ax.tick_params(axis='x', labelrotation=45)
    fig.tight_layout()

    st.pyplot(fig)
    plt.close(fig)

    # Headline metrics.
    st.write("### Model Performance Metrics")
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        st.metric("MAE", f"${mae:.2f}")
    with col2:
        st.metric("RMSE", f"${rmse:.2f}")
    with col3:
        st.metric("MAPE", f"{mape:.2f}%" if np.isfinite(mape) else "N/A")
    with col4:
        st.metric("Avg Actual", f"${actual.mean():.2f}/MWh")

    # Per-day comparison of actual and forecast values.
    with st.expander("View Detailed Comparison"):
        comparison_df = pd.DataFrame({
            'Date': test_dates.strftime('%Y-%m-%d'),
            'Actual': test_data,
            'Forecast': median.round(2),
            'Error': errors.round(2),
            'Error %': pct_errors.round(2),
        })
        st.dataframe(comparison_df, use_container_width=True)
# ---- About: feature summary, model list and contact link ----
# Blank lines separate the headings, the bullet lists and the closing
# sentence so that markdown does not fold a heading or the final sentence
# into the preceding list item.
st.write("### About")
st.write("""
**Features:**

- 🤖 Multiple pretrained models (7 options: Chronos-2, Chronos-T5, TiRex)
- 📊 Real-time ERCOT electricity market data (180+ days)
- 🎯 Backtesting with error metrics (MAE, RMSE, MAPE)
- 📈 Visual comparison of forecasts vs actual values
- 🔧 Modular architecture for easy extension

**Models:**

- **Chronos-2**: Amazon's latest (46M-120M params)
- **Chronos-T5**: Original family (8M-710M params)
- **TiRex**: NX-AI's xLSTM-based model (35M params)

For questions or feedback, reach out on [LinkedIn](https://www.linkedin.com/in/javadbayazi/).
""")