Spaces:
Running
Running
| """ | |
| Model configuration and loading for time series forecasting. | |
| Supports multiple Chronos model variants with different architectures. | |
| """ | |
| import torch | |
| from chronos import Chronos2Pipeline, ChronosPipeline | |
class ModelConfig:
    """Registry of the forecasting models this module can load.

    Each registry maps a human-readable display name (the string shown in a
    model-selection dropdown) to a config dict with the keys:

    - ``"model_id"``: identifier passed to the pipeline's ``from_pretrained``.
    - ``"pipeline_class"``: the chronos pipeline class used to load the model.
    - ``"description"``: short human-readable description.

    All accessors are classmethods, so the class is used directly
    (``ModelConfig.get_model_config(name)``) and never instantiated.
    """

    CHRONOS_2_MODELS = {
        "Chronos-2 (Latest, 120M params)": {
            "model_id": "amazon/chronos-2",
            "pipeline_class": Chronos2Pipeline,
            "description": "Latest Chronos-2 model with 120M parameters",
        },
    }

    CHRONOS_T5_MODELS = {
        "Chronos-T5 Tiny (8M params)": {
            "model_id": "amazon/chronos-t5-tiny",
            "pipeline_class": ChronosPipeline,
            "description": "Smallest Chronos-T5 model, fastest inference",
        },
        "Chronos-T5 Mini (20M params)": {
            "model_id": "amazon/chronos-t5-mini",
            "pipeline_class": ChronosPipeline,
            "description": "Mini Chronos-T5 model",
        },
        "Chronos-T5 Small (46M params)": {
            "model_id": "amazon/chronos-t5-small",
            "pipeline_class": ChronosPipeline,
            "description": "Small Chronos-T5 model",
        },
        "Chronos-T5 Base (200M params)": {
            "model_id": "amazon/chronos-t5-base",
            "pipeline_class": ChronosPipeline,
            "description": "Base Chronos-T5 model",
        },
        "Chronos-T5 Large (710M params)": {
            "model_id": "amazon/chronos-t5-large",
            "pipeline_class": ChronosPipeline,
            "description": "Largest Chronos-T5 model, best accuracy",
        },
    }

    # These accessors take ``cls`` and are called on the class itself
    # (``ModelConfig.get_model_config(...)``); without ``@classmethod`` the
    # first positional argument would be bound to ``cls`` and the call fails.
    @classmethod
    def get_all_models(cls):
        """Return every registered model, Chronos-2 entries first.

        Returns:
            A new dict mapping display name -> config dict. The outer dict is
            a fresh copy, so adding or removing keys does not affect the class
            registries. The inner config dicts are shared, not copied.
        """
        # Later registries win on a duplicate display name, matching the
        # previous ``dict.update`` order.
        return {**cls.CHRONOS_2_MODELS, **cls.CHRONOS_T5_MODELS}

    @classmethod
    def get_model_names(cls):
        """Return the display names of all models, e.g. for a dropdown.

        Returns:
            A list of display names in registry order.
        """
        return list(cls.get_all_models())

    @classmethod
    def get_model_config(cls, model_name):
        """Look up the config dict for one model.

        Args:
            model_name: Display name of the model.

        Returns:
            The config dict for ``model_name``, or ``None`` if it is unknown.
        """
        return cls.get_all_models().get(model_name)
def load_model_pipeline(model_name, device_map="cpu", dtype=torch.float32):
    """Instantiate the forecasting pipeline registered under a display name.

    Args:
        model_name: Display name of the model, as used in ``ModelConfig``.
        device_map: Device to place the model on; forwarded unchanged to
            ``from_pretrained`` (default: ``"cpu"``).
        dtype: Data type for the model weights; forwarded unchanged to
            ``from_pretrained`` (default: ``torch.float32``).

    Returns:
        Whatever the registered pipeline class's ``from_pretrained`` returns.

    Raises:
        ValueError: If ``model_name`` is not a registered model.
    """
    entry = ModelConfig.get_model_config(model_name)
    if entry is None:
        raise ValueError(f"Unknown model: {model_name}")

    # Each registry entry carries its own pipeline class, so Chronos-2 and
    # Chronos-T5 models are loaded through their respective classes.
    loader = entry["pipeline_class"]
    repo_id = entry["model_id"]
    return loader.from_pretrained(repo_id, device_map=device_map, dtype=dtype)
def get_model_info(model_name):
    """Summarize a registered model for display.

    Args:
        model_name: Display name of the model, as used in ``ModelConfig``.

    Returns:
        A dict with keys ``"name"``, ``"model_id"``, ``"description"`` and
        ``"pipeline"``. The ``"pipeline"`` value is the pipeline class's
        ``__name__``. Returns ``None`` if ``model_name`` is not registered.
    """
    entry = ModelConfig.get_model_config(model_name)
    if entry is not None:
        return {
            "name": model_name,
            "model_id": entry["model_id"],
            "description": entry["description"],
            "pipeline": entry["pipeline_class"].__name__,
        }
    return None