"""
Model configuration and loading for time series forecasting.

Supports multiple Chronos model variants with different architectures and,
when the optional ``tirex`` package can be imported, the TiRex model.
"""

import torch
from chronos import Chronos2Pipeline, ChronosPipeline

# Try to import TiRex forecasting library. It is optional: when the import
# fails, the TiRex entry is simply left out of the model registry below.
try:
    from tirex import load_model as load_tirex_model
    TIREX_AVAILABLE = True
except ImportError:
    TIREX_AVAILABLE = False


class ModelConfig:
    """Configuration for available forecasting models.

    Each registry maps a human-readable display name to a dict with the keys
    ``model_id``, ``pipeline_class`` and ``description``. ``pipeline_class`` is
    a Chronos pipeline class for Chronos models and the plain string
    ``"TiRex"`` for the TiRex model.
    """

    CHRONOS_2_MODELS = {
        "Chronos-2 (Latest, 120M params)": {
            "model_id": "amazon/chronos-2",
            "pipeline_class": Chronos2Pipeline,
            "description": "Latest Chronos-2 model with 120M parameters",
        },
    }

    CHRONOS_T5_MODELS = {
        "Chronos-T5 Tiny (8M params)": {
            "model_id": "amazon/chronos-t5-tiny",
            "pipeline_class": ChronosPipeline,
            "description": "Smallest Chronos-T5 model, fastest inference",
        },
        "Chronos-T5 Mini (20M params)": {
            "model_id": "amazon/chronos-t5-mini",
            "pipeline_class": ChronosPipeline,
            "description": "Mini Chronos-T5 model",
        },
        "Chronos-T5 Small (46M params)": {
            "model_id": "amazon/chronos-t5-small",
            "pipeline_class": ChronosPipeline,
            "description": "Small Chronos-T5 model",
        },
        "Chronos-T5 Base (200M params)": {
            "model_id": "amazon/chronos-t5-base",
            "pipeline_class": ChronosPipeline,
            "description": "Base Chronos-T5 model",
        },
        "Chronos-T5 Large (710M params)": {
            "model_id": "amazon/chronos-t5-large",
            "pipeline_class": ChronosPipeline,
            "description": "Largest Chronos-T5 model, best accuracy",
        },
    }

    # Only advertised when the tirex package imported successfully.
    TIREX_MODELS = {
        "TiRex (35M params)": {
            "model_id": "NX-AI/TiRex",
            "pipeline_class": "TiRex",
            "description": "TiRex xLSTM-based model, excellent for both short and long-term forecasting",
        },
    } if TIREX_AVAILABLE else {}

    @classmethod
    def get_all_models(cls):
        """Return a new dict merging every registry (Chronos-2, Chronos-T5, TiRex)."""
        return {
            **cls.CHRONOS_2_MODELS,
            **cls.CHRONOS_T5_MODELS,
            **cls.TIREX_MODELS,
        }

    @classmethod
    def get_model_names(cls):
        """Return the list of model display names (e.g. for a dropdown)."""
        return list(cls.get_all_models())

    @classmethod
    def get_model_config(cls, model_name):
        """Return the configuration dict for ``model_name``, or None if unknown."""
        return cls.get_all_models().get(model_name)


def load_model_pipeline(model_name, device_map="cpu", dtype=torch.float32):
    """
    Load a forecasting model pipeline.

    Args:
        model_name: Display name of the model
        device_map: Device to load model on (default: "cpu")
        dtype: Data type for model weights (default: torch.float32).
            Only forwarded to Chronos pipelines; the TiRex branch ignores it.

    Returns:
        Loaded Chronos pipeline instance, or a TiRexWrapper around the
        loaded TiRex model.

    Raises:
        ValueError: If ``model_name`` is not a registered model.
        ImportError: If a TiRex model is requested but tirex is not installed.
    """
    config = ModelConfig.get_model_config(model_name)
    if config is None:
        raise ValueError(f"Unknown model: {model_name}")

    pipeline_class = config["pipeline_class"]
    model_id = config["model_id"]

    # Load TiRex model differently: it is identified by a string marker
    # rather than a pipeline class and has its own loader function.
    if pipeline_class == "TiRex":
        if not TIREX_AVAILABLE:
            raise ImportError(
                "TiRex library not installed. Install with: pip install tirex-ts\n"
                "Note: TiRex requires GPU support (CUDA-enabled GPU recommended)"
            )
        # Use CUDA only when the caller asked for it AND it is actually
        # available; every other combination falls back to CPU.
        use_cuda = torch.cuda.is_available() and device_map == "cuda"
        device = "cuda" if use_cuda else "cpu"
        model = load_tirex_model(model_id, backend="torch", device=device)
        return TiRexWrapper(model)

    # Load Chronos pipelines
    return pipeline_class.from_pretrained(
        model_id,
        device_map=device_map,
        dtype=dtype,
    )


class TiRexWrapper:
    """Wrapper to make TiRex compatible with Chronos pipeline API"""

    def __init__(self, model):
        """Store the loaded TiRex model that ``predict_df`` delegates to."""
        self.model = model

    def predict_df(self, context_df, prediction_length, quantile_levels, **kwargs):
        """
        Forecast with TiRex through a Chronos-like ``predict_df`` call.

        Args:
            context_df: DataFrame whose 'target' column holds the history.
            prediction_length: Number of future steps to forecast.
            quantile_levels: Iterable of quantile levels; each becomes an
                output column named ``str(level)``.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            DataFrame with a 'predictions' column plus one column per
            requested quantile level.
        """
        # pandas is not imported at module level, so import it lazily here.
        import pandas as pd

        # Convert dataframe to tensor (batch_size=1, sequence_length)
        context = torch.tensor(
            context_df['target'].values, dtype=torch.float32
        ).unsqueeze(0)

        with torch.no_grad():
            result = self.model.forecast(
                context=context, prediction_length=prediction_length
            )

        # The call may return a tuple; only its first element is used.
        # NOTE(review): upstream TiRex appears to return (quantiles, mean), in
        # which case the last axis below holds quantile levels rather than
        # samples - confirm against the installed tirex version.
        forecast = result[0] if isinstance(result, tuple) else result

        if forecast.dim() == 3:
            # (batch, pred_len, samples): keep the first batch entry and
            # reduce over the last axis.
            forecast = forecast[0]
            quantiles = {
                str(q): torch.quantile(forecast, q, dim=-1).cpu().numpy()
                for q in quantile_levels
            }
            median = torch.median(forecast, dim=-1).values.cpu().numpy()
        else:
            if forecast.dim() == 2:
                # (batch, pred_len) - single prediction, take first batch
                median = forecast[0].cpu().numpy()
            else:
                # (pred_len,)
                median = forecast.cpu().numpy()
            # No distribution available: reuse the point forecast for
            # every requested quantile.
            quantiles = {str(q): median for q in quantile_levels}

        # Create output dataframe matching Chronos format
        return pd.DataFrame({
            'predictions': median,
            **quantiles,
        })


def get_model_info(model_name):
    """
    Get information about a model.

    Args:
        model_name: Display name of the model

    Returns:
        Dictionary with the keys "name", "model_id", "description" and
        "pipeline" (the pipeline's name), or None if the model is unknown.
    """
    config = ModelConfig.get_model_config(model_name)
    if config is None:
        return None

    pipeline_class = config["pipeline_class"]
    # The TiRex entry stores the plain string "TiRex" instead of a class, and
    # str objects have no __name__ attribute, so use the string itself.
    if isinstance(pipeline_class, str):
        pipeline_name = pipeline_class
    else:
        pipeline_name = pipeline_class.__name__

    return {
        "name": model_name,
        "model_id": config["model_id"],
        "description": config["description"],
        "pipeline": pipeline_name,
    }