# TSF-EM / models.py
# Commit 92f4bb2 (JavadBayazi): "Add modular model architecture with dropdown selector"
"""
Model configuration and loading for time series forecasting.
Supports multiple Chronos model variants with different architectures.
"""
import torch
from chronos import Chronos2Pipeline, ChronosPipeline
class ModelConfig:
    """Registry of the forecasting models this module can load.

    Every registry entry maps a human-readable display name to a dict with:
      - ``model_id``: Hugging Face checkpoint id handed to ``from_pretrained``
      - ``pipeline_class``: the chronos pipeline class that loads the checkpoint
      - ``description``: short free-text summary of the model
    """

    # Chronos-2 checkpoints are loaded with Chronos2Pipeline.
    CHRONOS_2_MODELS = {
        "Chronos-2 (Latest, 120M params)": {
            "model_id": "amazon/chronos-2",
            "pipeline_class": Chronos2Pipeline,
            "description": "Latest Chronos-2 model with 120M parameters",
        },
    }

    # Chronos-T5 checkpoints (tiny -> large) all share ChronosPipeline.
    CHRONOS_T5_MODELS = {
        "Chronos-T5 Tiny (8M params)": {
            "model_id": "amazon/chronos-t5-tiny",
            "pipeline_class": ChronosPipeline,
            "description": "Smallest Chronos-T5 model, fastest inference",
        },
        "Chronos-T5 Mini (20M params)": {
            "model_id": "amazon/chronos-t5-mini",
            "pipeline_class": ChronosPipeline,
            "description": "Mini Chronos-T5 model",
        },
        "Chronos-T5 Small (46M params)": {
            "model_id": "amazon/chronos-t5-small",
            "pipeline_class": ChronosPipeline,
            "description": "Small Chronos-T5 model",
        },
        "Chronos-T5 Base (200M params)": {
            "model_id": "amazon/chronos-t5-base",
            "pipeline_class": ChronosPipeline,
            "description": "Base Chronos-T5 model",
        },
        "Chronos-T5 Large (710M params)": {
            "model_id": "amazon/chronos-t5-large",
            "pipeline_class": ChronosPipeline,
            "description": "Largest Chronos-T5 model, best accuracy",
        },
    }

    @classmethod
    def get_all_models(cls):
        """Return a fresh dict combining every model family.

        Chronos-2 entries are listed first, then Chronos-T5 entries. If a
        display name appeared in both families, the Chronos-T5 entry would
        take precedence. The per-model config dicts are shared, not copied.
        """
        return {**cls.CHRONOS_2_MODELS, **cls.CHRONOS_T5_MODELS}

    @classmethod
    def get_model_names(cls):
        """Return every registered display name, in registry order (used to fill a dropdown)."""
        return [*cls.get_all_models()]

    @classmethod
    def get_model_config(cls, model_name):
        """Return the registry entry for ``model_name``, or ``None`` if it is not registered.

        The returned dict is the registry's own entry object, not a copy.
        """
        registry = cls.get_all_models()
        try:
            return registry[model_name]
        except KeyError:
            return None
def load_model_pipeline(model_name, device_map="cpu", dtype=torch.float32):
    """Load the forecasting pipeline registered under a display name.

    Args:
        model_name: Display name of the model, as returned by
            ``ModelConfig.get_model_names()``.
        device_map: Device placement forwarded to ``from_pretrained``
            (default: "cpu").
        dtype: Weight dtype forwarded to ``from_pretrained``
            (default: torch.float32).

    Returns:
        Whatever the registered pipeline class's ``from_pretrained`` returns
        for the model's checkpoint id.

    Raises:
        ValueError: If ``model_name`` is not in the registry.
    """
    entry = ModelConfig.get_model_config(model_name)
    if entry is None:
        raise ValueError(f"Unknown model: {model_name}")

    # Each entry names its own loader class, because Chronos-2 and Chronos-T5
    # checkpoints are registered with different pipeline classes.
    pipeline_cls, checkpoint = entry["pipeline_class"], entry["model_id"]
    return pipeline_cls.from_pretrained(
        checkpoint,
        device_map=device_map,
        dtype=dtype,
    )
def get_model_info(model_name):
    """Describe a registered model without loading it.

    Args:
        model_name: Display name of the model.

    Returns:
        A dict with keys ``name``, ``model_id``, ``description`` and
        ``pipeline`` (the ``__name__`` of the registered pipeline class).
        Returns ``None`` when ``model_name`` is not registered; unlike
        ``load_model_pipeline``, no exception is raised.
    """
    entry = ModelConfig.get_model_config(model_name)
    if entry is not None:
        loader_name = entry["pipeline_class"].__name__
        return {
            "name": model_name,
            "model_id": entry["model_id"],
            "description": entry["description"],
            "pipeline": loader_name,
        }
    return None