JavadBayazi commited on
Commit
06412fb
·
1 Parent(s): 0043c7f

Add modular data architecture and backtesting features

Browse files

- Create data.py for centralized data fetching
- Implement ERCOTDataSource and SampleDataSource classes
- Add train/test split for model evaluation
- Display actual vs forecast comparison on plot
- Add error metrics: MAE, RMSE, MAPE
- Show detailed comparison table with day-by-day errors
- Visual train/test split marker on plot
- Easy to extend with new data sources (CAISO, PJM, etc.)

Files changed (2) hide show
  1. app.py +85 -78
  2. data.py +155 -0
app.py CHANGED
@@ -3,9 +3,9 @@ import pandas as pd
3
  import torch
4
  import matplotlib.pyplot as plt
5
  import numpy as np
6
- from gridstatus import Ercot
7
  from datetime import datetime, timedelta
8
  from models import ModelConfig, load_model_pipeline
 
9
 
10
  # Load the forecasting model pipeline
11
  @st.cache_resource
@@ -13,31 +13,11 @@ def load_pipeline(model_name):
13
  """Load and cache the model pipeline"""
14
  return load_model_pipeline(model_name, device_map="cpu", dtype=torch.float32)
15
 
16
- # Function to fetch ERCOT electricity price data
17
  @st.cache_data(ttl=3600) # Cache for 1 hour
18
- def fetch_ercot_data(days_back=180):
19
- """Fetch ERCOT day-ahead market prices for the current year"""
20
- try:
21
- ercot = Ercot()
22
- current_year = datetime.now().year
23
-
24
- # Get day-ahead market settlement point prices for the year
25
- df = ercot.get_dam_spp(year=current_year)
26
-
27
- # Get average price per day across all locations
28
- df['Date'] = pd.to_datetime(df['Interval Start']).dt.date
29
- daily_prices = df.groupby('Date')['SPP'].mean()
30
-
31
- # Get the last N days
32
- if len(daily_prices) > days_back:
33
- daily_prices = daily_prices.tail(days_back)
34
-
35
- # Convert to comma-separated string
36
- price_list = daily_prices.round(2).tolist()
37
- return ", ".join(map(str, price_list))
38
- except Exception as e:
39
- st.warning(f"Could not fetch live ERCOT data: {e}. Using sample data instead.")
40
- return None
41
 
42
  # Streamlit app interface
43
  st.title("Electricity Market Price Forecasting with Chronos-2")
@@ -56,25 +36,11 @@ selected_model_name = st.selectbox(
56
  with st.spinner(f"Loading {selected_model_name}..."):
57
  pipeline = load_pipeline(selected_model_name)
58
 
59
- # Fetch default ERCOT data
60
- with st.spinner("Fetching latest ERCOT electricity prices..."):
61
- ercot_data = fetch_ercot_data()
62
-
63
- # Fallback to sample data if fetching fails
64
- default_data = ercot_data if ercot_data else """
65
- 25.50, 24.80, 26.30, 23.90, 25.10, 27.20, 28.50, 26.70, 24.30, 23.80, 25.40, 26.10, 27.80, 29.20, 28.40,
66
- 26.90, 25.30, 24.70, 26.50, 28.10, 29.60, 31.20, 30.50, 28.80, 27.10, 25.90, 27.30, 28.70, 30.20, 32.10,
67
- 31.40, 29.70, 28.20, 26.80, 28.40, 29.80, 31.50, 33.20, 32.60, 30.90, 29.30, 27.80, 29.40, 30.90, 32.70,
68
- 34.50, 33.80, 32.10, 30.50, 28.90, 30.50, 32.10, 33.90, 35.80, 35.10, 33.30, 31.60, 30.10, 31.70, 33.40,
69
- 35.20, 37.10, 36.40, 34.60, 32.90, 31.30, 32.90, 34.60, 36.50, 38.40, 37.70, 35.80, 34.10, 32.50, 34.20,
70
- 35.90, 37.80, 39.80, 39.10, 37.10, 35.40, 33.70, 35.40, 37.20, 39.20, 41.20, 40.50, 38.50, 36.70, 35.00,
71
- 36.70, 38.50, 40.60, 42.60, 41.90, 39.90, 38.00, 36.30, 38.00, 39.90, 42.00, 44.10, 43.40, 41.30, 39.40
72
- """
73
-
74
  # Data source selection
 
75
  data_source = st.radio(
76
  "Select Data Source:",
77
- ["Live ERCOT Data (Last 180 Days)", "Custom Data"],
78
  index=0
79
  )
80
 
@@ -82,68 +48,96 @@ data_source = st.radio(
82
  if data_source == "Custom Data":
83
  user_input = st.text_area(
84
  "Enter time series data (comma-separated values):",
85
- ""
 
86
  )
 
 
87
  else:
 
 
 
 
 
 
 
88
  user_input = st.text_area(
89
- "ERCOT Day-Ahead Hourly Market Prices ($/MWh) - Daily Average:",
90
  default_data.strip(),
91
  height=150
92
  )
93
  st.info("💡 Live data from ERCOT's Day-Ahead Market (DAM SPP) - averaged across all settlement points per day")
94
 
95
- # Convert user input into a list of numbers
96
- def process_input(input_str):
97
- return [float(x.strip()) for x in input_str.split(",")]
98
-
99
  try:
100
  time_series_data = process_input(user_input)
101
  except ValueError:
102
  st.error("Please make sure all values are numbers, separated by commas.")
103
  time_series_data = [] # Set empty data on error to prevent further processing
104
 
105
- # Select the number of days for forecasting
106
- prediction_length = st.slider("Select Forecast Horizon (Days)", min_value=1, max_value=64, value=14)
 
 
 
 
 
 
 
107
 
108
  # If data is valid, perform the forecast
109
  if time_series_data:
110
- # Create timestamps starting from today going backwards
 
 
 
 
 
111
  end_date = datetime.now()
112
  start_date = end_date - timedelta(days=len(time_series_data) - 1)
113
- historical_dates = pd.date_range(start=start_date, periods=len(time_series_data), freq='D')
 
 
114
 
115
- # Create a DataFrame for Chronos-2
116
  context_df = pd.DataFrame({
117
- 'timestamp': historical_dates,
118
- 'target': time_series_data,
119
  'id': 'ercot_prices'
120
  })
121
 
122
- # Make the forecast using Chronos-2 API
123
- pred_df = pipeline.predict_df(
124
- context_df,
125
- prediction_length=prediction_length,
126
- quantile_levels=[0.1, 0.5, 0.9],
127
- id_column="id",
128
- timestamp_column="timestamp",
129
- target="target",
130
- )
131
-
132
- # Prepare forecast data for plotting with actual dates
133
- forecast_dates = pd.date_range(start=end_date + timedelta(days=1), periods=prediction_length, freq='D')
134
  median = pred_df["predictions"].values
135
  low = pred_df["0.1"].values
136
  high = pred_df["0.9"].values
137
 
 
 
 
 
 
138
  # Plot the historical and forecasted data with dates
139
- plt.figure(figsize=(12, 6))
140
- plt.plot(historical_dates, time_series_data, color="royalblue", label="Historical Prices", linewidth=2)
141
- plt.plot(forecast_dates, median, color="tomato", label="Median Forecast", linewidth=2)
142
- plt.fill_between(forecast_dates, low, high, color="tomato", alpha=0.3, label="80% Prediction Interval")
 
 
 
143
  plt.xlabel("Date")
144
  plt.ylabel("Price ($/MWh)")
145
- plt.title("ERCOT Electricity Price Forecast")
146
- plt.legend()
147
  plt.grid(alpha=0.3)
148
  plt.xticks(rotation=45)
149
  plt.tight_layout()
@@ -151,15 +145,28 @@ if time_series_data:
151
  # Show the plot in the Streamlit app
152
  st.pyplot(plt)
153
 
154
- # Display forecast statistics
155
- st.write("### Forecast Summary")
156
- col1, col2, col3 = st.columns(3)
157
  with col1:
158
- st.metric("Median Forecast", f"${median.mean():.2f}/MWh")
159
  with col2:
160
- st.metric("Low (10th percentile)", f"${low.mean():.2f}/MWh")
161
  with col3:
162
- st.metric("High (90th percentile)", f"${high.mean():.2f}/MWh")
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
  # Note for comments, feedback, or questions
165
  st.write("### Notes")
 
3
  import torch
4
  import matplotlib.pyplot as plt
5
  import numpy as np
 
6
  from datetime import datetime, timedelta
7
  from models import ModelConfig, load_model_pipeline
8
+ from data import DataConfig, process_input, fetch_data_with_fallback
9
 
10
  # Load the forecasting model pipeline
11
  @st.cache_resource
 
13
  """Load and cache the model pipeline"""
14
  return load_model_pipeline(model_name, device_map="cpu", dtype=torch.float32)
15
 
16
+ # Fetch data with caching
17
  @st.cache_data(ttl=3600) # Cache for 1 hour
18
+ def fetch_data(source_name, days_back=180):
19
+ """Fetch data from specified source with caching"""
20
+ return fetch_data_with_fallback(source_name, days_back)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  # Streamlit app interface
23
  st.title("Electricity Market Price Forecasting with Chronos-2")
 
36
  with st.spinner(f"Loading {selected_model_name}..."):
37
  pipeline = load_pipeline(selected_model_name)
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  # Data source selection
40
+ available_sources = DataConfig.get_source_names()
41
  data_source = st.radio(
42
  "Select Data Source:",
43
+ available_sources + ["Custom Data"],
44
  index=0
45
  )
46
 
 
48
  if data_source == "Custom Data":
49
  user_input = st.text_area(
50
  "Enter time series data (comma-separated values):",
51
+ "",
52
+ height=150
53
  )
54
+ data_source_used = "Custom"
55
+ error_msg = None
56
  else:
57
+ # Fetch data from selected source
58
+ with st.spinner(f"Fetching data from {data_source}..."):
59
+ default_data, data_source_used, error_msg = fetch_data(data_source)
60
+
61
+ if error_msg:
62
+ st.warning(f"⚠️ {error_msg}\nUsing sample data instead.")
63
+
64
  user_input = st.text_area(
65
+ f"{data_source_used} - Daily Average Prices ($/MWh):",
66
  default_data.strip(),
67
  height=150
68
  )
69
  st.info("💡 Live data from ERCOT's Day-Ahead Market (DAM SPP) - averaged across all settlement points per day")
70
 
 
 
 
 
71
  try:
72
  time_series_data = process_input(user_input)
73
  except ValueError:
74
  st.error("Please make sure all values are numbers, separated by commas.")
75
  time_series_data = [] # Set empty data on error to prevent further processing
76
 
77
+ # Select the number of days for testing (forecasting on known data)
78
+ max_test_days = min(64, len(time_series_data) - 10) if len(time_series_data) > 10 else 1
79
+ prediction_length = st.slider(
80
+ "Select Test Window (Days to Forecast & Compare)",
81
+ min_value=1,
82
+ max_value=max_test_days,
83
+ value=min(14, max_test_days),
84
+ help="The last N days will be used as test data. The model will forecast these days and compare with actual values."
85
+ )
86
 
87
  # If data is valid, perform the forecast
88
  if time_series_data:
89
+ # Split data into train and test
90
+ train_length = len(time_series_data) - prediction_length
91
+ train_data = time_series_data[:train_length]
92
+ test_data = time_series_data[train_length:]
93
+
94
+ # Create timestamps
95
  end_date = datetime.now()
96
  start_date = end_date - timedelta(days=len(time_series_data) - 1)
97
+ all_dates = pd.date_range(start=start_date, periods=len(time_series_data), freq='D')
98
+ train_dates = all_dates[:train_length]
99
+ test_dates = all_dates[train_length:]
100
 
101
+ # Create a DataFrame for training
102
  context_df = pd.DataFrame({
103
+ 'timestamp': train_dates,
104
+ 'target': train_data,
105
  'id': 'ercot_prices'
106
  })
107
 
108
+ # Make the forecast using the model
109
+ with st.spinner("Generating forecast..."):
110
+ pred_df = pipeline.predict_df(
111
+ context_df,
112
+ prediction_length=prediction_length,
113
+ quantile_levels=[0.1, 0.5, 0.9],
114
+ id_column="id",
115
+ timestamp_column="timestamp",
116
+ target="target",
117
+ )
118
+
119
+ # Extract predictions
120
  median = pred_df["predictions"].values
121
  low = pred_df["0.1"].values
122
  high = pred_df["0.9"].values
123
 
124
+ # Calculate error metrics
125
+ mae = np.mean(np.abs(np.array(test_data) - median))
126
+ mape = np.mean(np.abs((np.array(test_data) - median) / np.array(test_data))) * 100
127
+ rmse = np.sqrt(np.mean((np.array(test_data) - median) ** 2))
128
+
129
  # Plot the historical and forecasted data with dates
130
+ plt.figure(figsize=(14, 7))
131
+ plt.plot(train_dates, train_data, color="royalblue", label="Training Data", linewidth=2)
132
+ plt.plot(test_dates, test_data, color="green", label="Actual Test Data", linewidth=2, marker='o', markersize=4)
133
+ plt.plot(test_dates, median, color="tomato", label="Forecast", linewidth=2, linestyle='--', marker='s', markersize=4)
134
+ plt.fill_between(test_dates, low, high, color="tomato", alpha=0.3, label="80% Prediction Interval")
135
+ plt.axvline(x=train_dates[-1], color='gray', linestyle=':', linewidth=1, alpha=0.7)
136
+ plt.text(train_dates[-1], plt.ylim()[1]*0.95, ' Train/Test Split', fontsize=10, color='gray')
137
  plt.xlabel("Date")
138
  plt.ylabel("Price ($/MWh)")
139
+ plt.title(f"ERCOT Electricity Price Forecast - {prediction_length} Day Test Window")
140
+ plt.legend(loc='best')
141
  plt.grid(alpha=0.3)
142
  plt.xticks(rotation=45)
143
  plt.tight_layout()
 
145
  # Show the plot in the Streamlit app
146
  st.pyplot(plt)
147
 
148
+ # Display forecast statistics and error metrics
149
+ st.write("### Model Performance Metrics")
150
+ col1, col2, col3, col4 = st.columns(4)
151
  with col1:
152
+ st.metric("MAE", f"${mae:.2f}")
153
  with col2:
154
+ st.metric("RMSE", f"${rmse:.2f}")
155
  with col3:
156
+ st.metric("MAPE", f"{mape:.2f}%")
157
+ with col4:
158
+ st.metric("Avg Actual", f"${np.mean(test_data):.2f}/MWh")
159
+
160
+ # Show detailed comparison table
161
+ with st.expander("View Detailed Comparison"):
162
+ comparison_df = pd.DataFrame({
163
+ 'Date': test_dates.strftime('%Y-%m-%d'),
164
+ 'Actual': test_data,
165
+ 'Forecast': median.round(2),
166
+ 'Error': (median - np.array(test_data)).round(2),
167
+ 'Error %': ((median - np.array(test_data)) / np.array(test_data) * 100).round(2)
168
+ })
169
+ st.dataframe(comparison_df, use_container_width=True)
170
 
171
  # Note for comments, feedback, or questions
172
  st.write("### Notes")
data.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Data fetching and processing for electricity market price forecasting.
3
+ Handles data retrieval from various sources and preprocessing.
4
+ """
5
+
6
+ import pandas as pd
7
+ from datetime import datetime, timedelta
8
+ from gridstatus import Ercot
9
+
10
+
11
+ class DataSource:
12
+ """Base class for data sources"""
13
+
14
+ def fetch_data(self, days_back=180):
15
+ """
16
+ Fetch data from the source.
17
+
18
+ Args:
19
+ days_back: Number of days of historical data to fetch
20
+
21
+ Returns:
22
+ Comma-separated string of prices, or None on error
23
+ """
24
+ raise NotImplementedError
25
+
26
+
27
+ class ERCOTDataSource(DataSource):
28
+ """Fetch electricity price data from ERCOT"""
29
+
30
+ def __init__(self):
31
+ self.name = "ERCOT (Texas)"
32
+ self.description = "Electric Reliability Council of Texas - Day-Ahead Market"
33
+
34
+ def fetch_data(self, days_back=180):
35
+ """
36
+ Fetch ERCOT day-ahead market prices for the current year.
37
+
38
+ Args:
39
+ days_back: Number of days to fetch (default: 180)
40
+
41
+ Returns:
42
+ Comma-separated string of daily average prices
43
+ """
44
+ try:
45
+ ercot = Ercot()
46
+ current_year = datetime.now().year
47
+
48
+ # Get day-ahead market settlement point prices for the year
49
+ df = ercot.get_dam_spp(year=current_year)
50
+
51
+ # Get average price per day across all locations
52
+ df['Date'] = pd.to_datetime(df['Interval Start']).dt.date
53
+ daily_prices = df.groupby('Date')['SPP'].mean()
54
+
55
+ # Get the last N days
56
+ if len(daily_prices) > days_back:
57
+ daily_prices = daily_prices.tail(days_back)
58
+
59
+ # Convert to comma-separated string
60
+ price_list = daily_prices.round(2).tolist()
61
+ return ", ".join(map(str, price_list))
62
+
63
+ except Exception as e:
64
+ raise Exception(f"Could not fetch ERCOT data: {e}")
65
+
66
+
67
+ class SampleDataSource(DataSource):
68
+ """Fallback sample electricity price data"""
69
+
70
+ def __init__(self):
71
+ self.name = "Sample Data"
72
+ self.description = "Sample electricity price data for demonstration"
73
+
74
+ def fetch_data(self, days_back=180):
75
+ """
76
+ Return sample electricity price data.
77
+
78
+ Returns:
79
+ Comma-separated string of sample prices
80
+ """
81
+ sample_data = """
82
+ 25.50, 24.80, 26.30, 23.90, 25.10, 27.20, 28.50, 26.70, 24.30, 23.80, 25.40, 26.10, 27.80, 29.20, 28.40,
83
+ 26.90, 25.30, 24.70, 26.50, 28.10, 29.60, 31.20, 30.50, 28.80, 27.10, 25.90, 27.30, 28.70, 30.20, 32.10,
84
+ 31.40, 29.70, 28.20, 26.80, 28.40, 29.80, 31.50, 33.20, 32.60, 30.90, 29.30, 27.80, 29.40, 30.90, 32.70,
85
+ 34.50, 33.80, 32.10, 30.50, 28.90, 30.50, 32.10, 33.90, 35.80, 35.10, 33.30, 31.60, 30.10, 31.70, 33.40,
86
+ 35.20, 37.10, 36.40, 34.60, 32.90, 31.30, 32.90, 34.60, 36.50, 38.40, 37.70, 35.80, 34.10, 32.50, 34.20,
87
+ 35.90, 37.80, 39.80, 39.10, 37.10, 35.40, 33.70, 35.40, 37.20, 39.20, 41.20, 40.50, 38.50, 36.70, 35.00,
88
+ 36.70, 38.50, 40.60, 42.60, 41.90, 39.90, 38.00, 36.30, 38.00, 39.90, 42.00, 44.10, 43.40, 41.30, 39.40
89
+ """
90
+ return sample_data.strip()
91
+
92
+
93
+ class DataConfig:
94
+ """Configuration for available data sources"""
95
+
96
+ AVAILABLE_SOURCES = {
97
+ "Live ERCOT Data (Last 180 Days)": ERCOTDataSource,
98
+ "Sample Data": SampleDataSource,
99
+ }
100
+
101
+ @classmethod
102
+ def get_source_names(cls):
103
+ """Get list of available data source names"""
104
+ return list(cls.AVAILABLE_SOURCES.keys())
105
+
106
+ @classmethod
107
+ def get_source(cls, source_name):
108
+ """
109
+ Get a data source instance by name.
110
+
111
+ Args:
112
+ source_name: Name of the data source
113
+
114
+ Returns:
115
+ DataSource instance
116
+ """
117
+ source_class = cls.AVAILABLE_SOURCES.get(source_name)
118
+ if source_class is None:
119
+ raise ValueError(f"Unknown data source: {source_name}")
120
+ return source_class()
121
+
122
+
123
+ def process_input(input_str):
124
+ """
125
+ Convert comma-separated string to list of floats.
126
+
127
+ Args:
128
+ input_str: Comma-separated string of numbers
129
+
130
+ Returns:
131
+ List of float values
132
+ """
133
+ return [float(x.strip()) for x in input_str.split(",") if x.strip()]
134
+
135
+
136
+ def fetch_data_with_fallback(source_name, days_back=180):
137
+ """
138
+ Fetch data from specified source with fallback to sample data.
139
+
140
+ Args:
141
+ source_name: Name of the data source
142
+ days_back: Number of days to fetch
143
+
144
+ Returns:
145
+ Tuple of (data_string, source_used, error_message)
146
+ """
147
+ try:
148
+ source = DataConfig.get_source(source_name)
149
+ data = source.fetch_data(days_back)
150
+ return data, source.name, None
151
+ except Exception as e:
152
+ # Fallback to sample data
153
+ sample_source = SampleDataSource()
154
+ data = sample_source.fetch_data()
155
+ return data, sample_source.name, str(e)