# Hugging Face Spaces app — status: Running
| import streamlit as st | |
| import pandas as pd | |
| import torch | |
| from chronos import Chronos2Pipeline | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from gridstatus import Ercot | |
| from datetime import datetime, timedelta | |
# Load the Chronos-2 pipeline once per server process and reuse it across reruns.
@st.cache_resource
def load_pipeline():
    """Load the Chronos-2 forecasting pipeline on CPU.

    Streamlit re-executes this whole script on every widget interaction (for
    example each time the forecast-horizon slider moves). Without caching, the
    model would be loaded from ``amazon/chronos-2`` again on every rerun.
    ``st.cache_resource`` keeps one pipeline object alive and hands the same
    instance to every rerun and session.

    NOTE(review): because the cached object is shared between concurrent
    sessions, confirm that ``predict_df`` is safe to call concurrently.

    Returns:
        The object returned by ``Chronos2Pipeline.from_pretrained``.
    """
    return Chronos2Pipeline.from_pretrained(
        "amazon/chronos-2",
        device_map="cpu",  # run on CPU
        dtype=torch.float32,  # float32 for CPU inference
    )


pipeline = load_pipeline()
# Fetch ERCOT electricity price data; successful downloads are cached for 1 hour.
@st.cache_data(ttl=3600)
def _fetch_ercot_daily_prices(days_back):
    """Download ERCOT day-ahead prices and return daily means as a CSV string.

    Any failure is raised rather than swallowed. ``st.cache_data`` only stores
    values that are actually returned, so a failed fetch is retried on the
    next rerun instead of being cached for the full hour.

    Args:
        days_back: Number of days before "now" to request.

    Returns:
        Daily average prices rounded to 2 decimals, joined with ", ".

    Raises:
        ValueError: If ERCOT returned no rows for the requested window.
        Exception: Any error raised by ``gridstatus`` or pandas while fetching
            or aggregating.
    """
    ercot = Ercot()
    end_date = datetime.now()
    start_date = end_date - timedelta(days=days_back)
    # Day-ahead hourly market settlement point prices.
    df = ercot.get_spp(
        date=start_date,
        end=end_date,
        market="DAY_AHEAD_HOURLY",
    )
    if df.empty:
        # Without this check an empty response would become "" and the caller
        # would silently fall back to sample data with no warning.
        raise ValueError(f"no day-ahead prices returned for the last {days_back} days")
    # One value per calendar day: the mean SPP over all hours and all locations.
    day = pd.to_datetime(df["Interval Start"]).dt.date
    daily_prices = df["SPP"].groupby(day).mean()
    return ", ".join(map(str, daily_prices.round(2).tolist()))


def fetch_ercot_data(days_back=60):
    """Fetch ERCOT day-ahead market prices for the last N days.

    Args:
        days_back: Number of days of history to request (default 60).

    Returns:
        A comma-separated string of daily average prices, or ``None`` if the
        fetch failed. On failure a warning is shown in the app, and the caller
        is expected to fall back to sample data.
    """
    try:
        return _fetch_ercot_daily_prices(days_back)
    except Exception as e:  # UI boundary: any fetch/parse failure -> sample data
        st.warning(f"Could not fetch live ERCOT data: {e}. Using sample data instead.")
        return None
# ---- Page header ------------------------------------------------------------
st.title("Electricity Market Price Forecasting with Chronos-2")
st.write("This demo uses **Chronos-2** to forecast electricity prices from ERCOT (Texas) market data.")

# Pull live prices before rendering the inputs; the spinner covers the network call.
with st.spinner("Fetching latest ERCOT electricity prices..."):
    ercot_data = fetch_ercot_data()

# Built-in daily prices ($/MWh) shown whenever the live fetch produced nothing.
SAMPLE_PRICES = """
25.50, 24.80, 26.30, 23.90, 25.10, 27.20, 28.50, 26.70, 24.30, 23.80, 25.40, 26.10, 27.80, 29.20, 28.40,
26.90, 25.30, 24.70, 26.50, 28.10, 29.60, 31.20, 30.50, 28.80, 27.10, 25.90, 27.30, 28.70, 30.20, 32.10,
31.40, 29.70, 28.20, 26.80, 28.40, 29.80, 31.50, 33.20, 32.60, 30.90, 29.30, 27.80, 29.40, 30.90, 32.70,
34.50, 33.80, 32.10, 30.50, 28.90, 30.50, 32.10, 33.90, 35.80, 35.10, 33.30, 31.60, 30.10, 31.70, 33.40,
35.20, 37.10, 36.40, 34.60, 32.90, 31.30, 32.90, 34.60, 36.50, 38.40, 37.70, 35.80, 34.10, 32.50, 34.20,
35.90, 37.80, 39.80, 39.10, 37.10, 35.40, 33.70, 35.40, 37.20, 39.20, 41.20, 40.50, 38.50, 36.70, 35.00,
36.70, 38.50, 40.60, 42.60, 41.90, 39.90, 38.00, 36.30, 38.00, 39.90, 42.00, 44.10, 43.40, 41.30, 39.40
"""
default_data = ercot_data or SAMPLE_PRICES
# ---- Data source selection --------------------------------------------------
# The window in the live-data label must match what is actually fetched:
# fetch_ercot_data() is called with its default days_back=60.
data_source = st.radio(
    "Select Data Source:",
    ["Live ERCOT Data (Last 60 Days)", "Custom Data"],
    index=0,
)

# Text box holding the series to forecast: empty for custom data, pre-filled
# with the live (or fallback sample) prices otherwise.
if data_source == "Custom Data":
    user_input = st.text_area(
        "Enter time series data (comma-separated values):",
        "",
    )
else:
    user_input = st.text_area(
        "ERCOT Day-Ahead Hourly Market Prices ($/MWh) - Daily Average:",
        default_data.strip(),
        height=150,
    )
    # ercot_data is None when the live fetch failed and default_data holds the
    # built-in sample series, so only claim "live data" when it really is live.
    if ercot_data:
        st.info("💡 Live data from ERCOT's Day-Ahead Hourly Market - averaged across all settlement points per day")
    else:
        st.info("💡 Live ERCOT data is unavailable - showing built-in sample prices instead")
def process_input(input_str):
    """Convert a comma-separated string of numbers into a list of floats.

    Whitespace (including newlines) around each value is ignored, and empty
    fields are skipped. As a result, an empty text box or a trailing comma
    parses cleanly instead of raising a spurious "not a number" error.

    Args:
        input_str: Text such as ``"25.5, 24.8, 26.3,"``.

    Returns:
        The parsed values in their original order, or ``[]`` when the string
        contains no values.

    Raises:
        ValueError: If a non-empty field is not a valid number.
    """
    fields = (field.strip() for field in input_str.split(","))
    return [float(field) for field in fields if field]
# Parse the text box. On malformed input, tell the user and fall back to "no data".
try:
    time_series_data = process_input(user_input)
except ValueError:
    time_series_data = []  # an empty series skips the forecast section below
    st.error("Please make sure all values are numbers, separated by commas.")
# ---- Forecast ---------------------------------------------------------------
prediction_length = st.slider("Select Forecast Horizon (Days)", min_value=1, max_value=64, value=14)

if time_series_data:
    # Chronos-2 takes a long-format frame (id, timestamp, target). The input
    # has no real dates, so a synthetic daily calendar starting 2024-01-01 is attached.
    context_df = pd.DataFrame({
        "timestamp": pd.date_range(start="2024-01-01", periods=len(time_series_data), freq="D"),
        "target": time_series_data,
        "id": "ercot_prices",
    })

    try:
        pred_df = pipeline.predict_df(
            context_df,
            prediction_length=prediction_length,
            quantile_levels=[0.1, 0.5, 0.9],
            id_column="id",
            timestamp_column="timestamp",
            target="target",
        )
    except Exception as exc:  # UI boundary: report model failures instead of crashing the page
        st.error(f"Forecasting failed: {exc}")
        pred_df = None

    if pred_df is not None:
        # Forecast points continue the x-axis right after the last historical day.
        history_len = len(time_series_data)
        forecast_index = np.arange(history_len, history_len + prediction_length)
        median = pred_df["predictions"].to_numpy()
        low = pred_df["0.1"].to_numpy()
        high = pred_df["0.9"].to_numpy()

        # Draw on an explicit Figure rather than the global pyplot state, so it
        # can be passed to st.pyplot and then closed. Otherwise every rerun
        # leaves another open figure behind.
        fig, ax = plt.subplots(figsize=(10, 5))
        ax.plot(time_series_data, color="royalblue", label="Historical Prices")
        ax.plot(forecast_index, median, color="tomato", label="Median Forecast")
        ax.fill_between(forecast_index, low, high, color="tomato", alpha=0.3, label="80% Prediction Interval")
        ax.set_xlabel("Days")
        ax.set_ylabel("Price ($/MWh)")
        ax.set_title("ERCOT Electricity Price Forecast")
        ax.legend()
        ax.grid(alpha=0.3)
        st.pyplot(fig)
        plt.close(fig)

        # Summary metrics: each value is the mean over the whole forecast horizon.
        st.write("### Forecast Summary")
        horizon_note = f"Average over the {prediction_length}-day forecast horizon."
        col1, col2, col3 = st.columns(3)
        with col1:
            st.metric("Median Forecast", f"${median.mean():.2f}/MWh", help=horizon_note)
        with col2:
            st.metric("Low (10th percentile)", f"${low.mean():.2f}/MWh", help=horizon_note)
        with col3:
            st.metric("High (90th percentile)", f"${high.mean():.2f}/MWh", help=horizon_note)
# ---- Footer: where to send feedback -----------------------------------------
st.write("### Notes")
contact_line = "For comments, feedback, or any questions, please reach out to me on [LinkedIn](https://www.linkedin.com/in/javadbayazi/)."
st.write(contact_line)