JavadBayazi commited on
Commit
92f4bb2
·
1 Parent(s): da5122f

Add modular model architecture with dropdown selector

Browse files

- Create models.py for centralized model configuration
- Support both Chronos-2 and Chronos-T5 pipeline classes
- Add model selector dropdown with 6 model options
- Separate model loading logic from app logic
- Easy to extend with new models in the future
- Fix compatibility with different Chronos variants

Files changed (3) hide show
  1. __pycache__/models.cpython-313.pyc +0 -0
  2. app.py +18 -11
  3. models.py +118 -0
__pycache__/models.cpython-313.pyc ADDED
Binary file (3.73 kB). View file
 
app.py CHANGED
@@ -1,23 +1,17 @@
1
  import streamlit as st
2
  import pandas as pd
3
  import torch
4
- from chronos import Chronos2Pipeline
5
  import matplotlib.pyplot as plt
6
  import numpy as np
7
  from gridstatus import Ercot
8
  from datetime import datetime, timedelta
 
9
 
10
- # Load the Chronos Pipeline model
11
  @st.cache_resource
12
- def load_pipeline():
13
- pipeline = Chronos2Pipeline.from_pretrained(
14
- "amazon/chronos-2",
15
- device_map="cpu", # Change to CPU
16
- dtype=torch.float32, # Use float32 for CPU
17
- )
18
- return pipeline
19
-
20
- pipeline = load_pipeline()
21
 
22
  # Function to fetch ERCOT electricity price data
23
  @st.cache_data(ttl=3600) # Cache for 1 hour
@@ -49,6 +43,19 @@ def fetch_ercot_data(days_back=180):
49
  st.title("Electricity Market Price Forecasting with Chronos-2")
50
  st.write("This demo uses **Chronos-2** to forecast electricity prices from ERCOT (Texas) market data.")
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  # Fetch default ERCOT data
53
  with st.spinner("Fetching latest ERCOT electricity prices..."):
54
  ercot_data = fetch_ercot_data()
 
1
  import streamlit as st
2
  import pandas as pd
3
  import torch
 
4
  import matplotlib.pyplot as plt
5
  import numpy as np
6
  from gridstatus import Ercot
7
  from datetime import datetime, timedelta
8
+ from models import ModelConfig, load_model_pipeline
9
 
10
+ # Load the forecasting model pipeline
11
  @st.cache_resource
12
+ def load_pipeline(model_name):
13
+ """Load and cache the model pipeline"""
14
+ return load_model_pipeline(model_name, device_map="cpu", dtype=torch.float32)
 
 
 
 
 
 
15
 
16
  # Function to fetch ERCOT electricity price data
17
  @st.cache_data(ttl=3600) # Cache for 1 hour
 
43
  st.title("Electricity Market Price Forecasting with Chronos-2")
44
  st.write("This demo uses **Chronos-2** to forecast electricity prices from ERCOT (Texas) market data.")
45
 
46
+ # Model selection
47
+ available_model_names = ModelConfig.get_model_names()
48
+
49
+ selected_model_name = st.selectbox(
50
+ "Select Forecasting Model:",
51
+ options=available_model_names,
52
+ index=0 # Default to first model (Chronos-2)
53
+ )
54
+
55
+ # Load the selected model
56
+ with st.spinner(f"Loading {selected_model_name}..."):
57
+ pipeline = load_pipeline(selected_model_name)
58
+
59
  # Fetch default ERCOT data
60
  with st.spinner("Fetching latest ERCOT electricity prices..."):
61
  ercot_data = fetch_ercot_data()
models.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model configuration and loading for time series forecasting.
3
+ Supports multiple Chronos model variants with different architectures.
4
+ """
5
+
6
+ import torch
7
+ from chronos import Chronos2Pipeline, ChronosPipeline
8
+
9
+
10
+ class ModelConfig:
11
+ """Configuration for available forecasting models"""
12
+
13
+ CHRONOS_2_MODELS = {
14
+ "Chronos-2 (Latest, 120M params)": {
15
+ "model_id": "amazon/chronos-2",
16
+ "pipeline_class": Chronos2Pipeline,
17
+ "description": "Latest Chronos-2 model with 120M parameters"
18
+ }
19
+ }
20
+
21
+ CHRONOS_T5_MODELS = {
22
+ "Chronos-T5 Tiny (8M params)": {
23
+ "model_id": "amazon/chronos-t5-tiny",
24
+ "pipeline_class": ChronosPipeline,
25
+ "description": "Smallest Chronos-T5 model, fastest inference"
26
+ },
27
+ "Chronos-T5 Mini (20M params)": {
28
+ "model_id": "amazon/chronos-t5-mini",
29
+ "pipeline_class": ChronosPipeline,
30
+ "description": "Mini Chronos-T5 model"
31
+ },
32
+ "Chronos-T5 Small (46M params)": {
33
+ "model_id": "amazon/chronos-t5-small",
34
+ "pipeline_class": ChronosPipeline,
35
+ "description": "Small Chronos-T5 model"
36
+ },
37
+ "Chronos-T5 Base (200M params)": {
38
+ "model_id": "amazon/chronos-t5-base",
39
+ "pipeline_class": ChronosPipeline,
40
+ "description": "Base Chronos-T5 model"
41
+ },
42
+ "Chronos-T5 Large (710M params)": {
43
+ "model_id": "amazon/chronos-t5-large",
44
+ "pipeline_class": ChronosPipeline,
45
+ "description": "Largest Chronos-T5 model, best accuracy"
46
+ }
47
+ }
48
+
49
+ @classmethod
50
+ def get_all_models(cls):
51
+ """Get all available models"""
52
+ all_models = {}
53
+ all_models.update(cls.CHRONOS_2_MODELS)
54
+ all_models.update(cls.CHRONOS_T5_MODELS)
55
+ return all_models
56
+
57
+ @classmethod
58
+ def get_model_names(cls):
59
+ """Get list of model names for dropdown"""
60
+ return list(cls.get_all_models().keys())
61
+
62
+ @classmethod
63
+ def get_model_config(cls, model_name):
64
+ """Get configuration for a specific model"""
65
+ return cls.get_all_models().get(model_name)
66
+
67
+
68
+ def load_model_pipeline(model_name, device_map="cpu", dtype=torch.float32):
69
+ """
70
+ Load a forecasting model pipeline.
71
+
72
+ Args:
73
+ model_name: Display name of the model
74
+ device_map: Device to load model on (default: "cpu")
75
+ dtype: Data type for model weights (default: torch.float32)
76
+
77
+ Returns:
78
+ Loaded pipeline instance
79
+ """
80
+ config = ModelConfig.get_model_config(model_name)
81
+
82
+ if config is None:
83
+ raise ValueError(f"Unknown model: {model_name}")
84
+
85
+ pipeline_class = config["pipeline_class"]
86
+ model_id = config["model_id"]
87
+
88
+ # Load the appropriate pipeline
89
+ pipeline = pipeline_class.from_pretrained(
90
+ model_id,
91
+ device_map=device_map,
92
+ dtype=dtype,
93
+ )
94
+
95
+ return pipeline
96
+
97
+
98
+ def get_model_info(model_name):
99
+ """
100
+ Get information about a model.
101
+
102
+ Args:
103
+ model_name: Display name of the model
104
+
105
+ Returns:
106
+ Dictionary with model information
107
+ """
108
+ config = ModelConfig.get_model_config(model_name)
109
+
110
+ if config is None:
111
+ return None
112
+
113
+ return {
114
+ "name": model_name,
115
+ "model_id": config["model_id"],
116
+ "description": config["description"],
117
+ "pipeline": config["pipeline_class"].__name__
118
+ }