Spaces:
Running
Running
File size: 3,476 Bytes
92f4bb2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
"""
Model configuration and loading for time series forecasting.
Supports multiple Chronos model variants with different architectures.
"""
import torch
from chronos import Chronos2Pipeline, ChronosPipeline
class ModelConfig:
    """Registry of the forecasting models that can be selected.

    Every registry maps a human-readable display name to a dict with keys:

    * ``"model_id"`` -- identifier handed to the pipeline's ``from_pretrained``,
    * ``"pipeline_class"`` -- the Chronos pipeline class used to load the model,
    * ``"description"`` -- a short free-text description.
    """

    # Chronos-2 family; these entries are loaded with Chronos2Pipeline.
    CHRONOS_2_MODELS = {
        "Chronos-2 (Latest, 120M params)": {
            "model_id": "amazon/chronos-2",
            "pipeline_class": Chronos2Pipeline,
            "description": "Latest Chronos-2 model with 120M parameters",
        },
    }

    # Chronos-T5 family, listed from smallest to largest; loaded with ChronosPipeline.
    CHRONOS_T5_MODELS = {
        "Chronos-T5 Tiny (8M params)": {
            "model_id": "amazon/chronos-t5-tiny",
            "pipeline_class": ChronosPipeline,
            "description": "Smallest Chronos-T5 model, fastest inference",
        },
        "Chronos-T5 Mini (20M params)": {
            "model_id": "amazon/chronos-t5-mini",
            "pipeline_class": ChronosPipeline,
            "description": "Mini Chronos-T5 model",
        },
        "Chronos-T5 Small (46M params)": {
            "model_id": "amazon/chronos-t5-small",
            "pipeline_class": ChronosPipeline,
            "description": "Small Chronos-T5 model",
        },
        "Chronos-T5 Base (200M params)": {
            "model_id": "amazon/chronos-t5-base",
            "pipeline_class": ChronosPipeline,
            "description": "Base Chronos-T5 model",
        },
        "Chronos-T5 Large (710M params)": {
            "model_id": "amazon/chronos-t5-large",
            "pipeline_class": ChronosPipeline,
            "description": "Largest Chronos-T5 model, best accuracy",
        },
    }

    @classmethod
    def get_all_models(cls):
        """Return a fresh dict combining every model registry.

        Chronos-2 entries come first, followed by the Chronos-T5 entries.
        The per-model config dicts are shared with the class attributes,
        not copied.
        """
        return {**cls.CHRONOS_2_MODELS, **cls.CHRONOS_T5_MODELS}

    @classmethod
    def get_model_names(cls):
        """Return the display names of all models, in registry order (used for the dropdown)."""
        return list(cls.get_all_models())

    @classmethod
    def get_model_config(cls, model_name):
        """Return the config dict registered under ``model_name``, or ``None`` if unknown."""
        registry = cls.get_all_models()
        return registry.get(model_name)
def load_model_pipeline(model_name, device_map="cpu", dtype=torch.float32):
    """Instantiate the forecasting pipeline registered under ``model_name``.

    Args:
        model_name: Display name of the model (a key of
            ``ModelConfig.get_all_models()``).
        device_map: Device placement forwarded to ``from_pretrained``
            (default: ``"cpu"``).
        dtype: Weight data type forwarded to ``from_pretrained``
            (default: ``torch.float32``).

    Returns:
        Whatever the registered pipeline class's ``from_pretrained`` returns.

    Raises:
        ValueError: If ``model_name`` is not in the registry.
    """
    entry = ModelConfig.get_model_config(model_name)
    if entry is None:
        raise ValueError(f"Unknown model: {model_name}")

    # Each registry entry names the pipeline class able to load its checkpoint
    # (Chronos-2 and Chronos-T5 entries use different classes).
    loader_cls, repo_id = entry["pipeline_class"], entry["model_id"]
    return loader_cls.from_pretrained(repo_id, device_map=device_map, dtype=dtype)
def get_model_info(model_name):
    """Describe a registered model as a plain dict.

    Args:
        model_name: Display name of the model.

    Returns:
        A dict with keys ``"name"``, ``"model_id"``, ``"description"`` and
        ``"pipeline"`` (the ``__name__`` of the pipeline class), or ``None``
        when ``model_name`` is not registered.
    """
    entry = ModelConfig.get_model_config(model_name)
    if entry is None:
        return None

    return dict(
        name=model_name,
        model_id=entry["model_id"],
        description=entry["description"],
        pipeline=entry["pipeline_class"].__name__,
    )
|