""" Model configuration and loading for time series forecasting. Supports multiple Chronos model variants with different architectures. """ import torch from chronos import Chronos2Pipeline, ChronosPipeline class ModelConfig: """Configuration for available forecasting models""" CHRONOS_2_MODELS = { "Chronos-2 (Latest, 120M params)": { "model_id": "amazon/chronos-2", "pipeline_class": Chronos2Pipeline, "description": "Latest Chronos-2 model with 120M parameters" } } CHRONOS_T5_MODELS = { "Chronos-T5 Tiny (8M params)": { "model_id": "amazon/chronos-t5-tiny", "pipeline_class": ChronosPipeline, "description": "Smallest Chronos-T5 model, fastest inference" }, "Chronos-T5 Mini (20M params)": { "model_id": "amazon/chronos-t5-mini", "pipeline_class": ChronosPipeline, "description": "Mini Chronos-T5 model" }, "Chronos-T5 Small (46M params)": { "model_id": "amazon/chronos-t5-small", "pipeline_class": ChronosPipeline, "description": "Small Chronos-T5 model" }, "Chronos-T5 Base (200M params)": { "model_id": "amazon/chronos-t5-base", "pipeline_class": ChronosPipeline, "description": "Base Chronos-T5 model" }, "Chronos-T5 Large (710M params)": { "model_id": "amazon/chronos-t5-large", "pipeline_class": ChronosPipeline, "description": "Largest Chronos-T5 model, best accuracy" } } @classmethod def get_all_models(cls): """Get all available models""" all_models = {} all_models.update(cls.CHRONOS_2_MODELS) all_models.update(cls.CHRONOS_T5_MODELS) return all_models @classmethod def get_model_names(cls): """Get list of model names for dropdown""" return list(cls.get_all_models().keys()) @classmethod def get_model_config(cls, model_name): """Get configuration for a specific model""" return cls.get_all_models().get(model_name) def load_model_pipeline(model_name, device_map="cpu", dtype=torch.float32): """ Load a forecasting model pipeline. Args: model_name: Display name of the model device_map: Device to load model on (default: "cpu") dtype: Data type for model weights (default: torch.float32) Returns: Loaded pipeline instance """ config = ModelConfig.get_model_config(model_name) if config is None: raise ValueError(f"Unknown model: {model_name}") pipeline_class = config["pipeline_class"] model_id = config["model_id"] # Load the appropriate pipeline pipeline = pipeline_class.from_pretrained( model_id, device_map=device_map, dtype=dtype, ) return pipeline def get_model_info(model_name): """ Get information about a model. Args: model_name: Display name of the model Returns: Dictionary with model information """ config = ModelConfig.get_model_config(model_name) if config is None: return None return { "name": model_name, "model_id": config["model_id"], "description": config["description"], "pipeline": config["pipeline_class"].__name__ }