Spaces:
Running
Running
| """ | |
| Model configuration and loading for time series forecasting. | |
| Supports multiple Chronos model variants with different architectures. | |
| """ | |
| import torch | |
| from chronos import Chronos2Pipeline, ChronosPipeline | |
| # Try to import TiRex forecasting library | |
| try: | |
| from tirex import load_model as load_tirex_model | |
| TIREX_AVAILABLE = True | |
| except ImportError: | |
| TIREX_AVAILABLE = False | |
class ModelConfig:
    """Registry of the forecasting models that can be selected and loaded.

    Every registry maps a display name to a dict with three keys:

    * ``model_id`` - identifier handed to the model loader.
    * ``pipeline_class`` - a Chronos pipeline class, or the string
      ``"TiRex"`` which marks entries that use the TiRex loading path.
    * ``description`` - short human-readable summary.

    All accessors are classmethods; the class is used as a namespace and is
    never instantiated.
    """

    CHRONOS_2_MODELS = {
        "Chronos-2 (Latest, 120M params)": {
            "model_id": "amazon/chronos-2",
            "pipeline_class": Chronos2Pipeline,
            "description": "Latest Chronos-2 model with 120M parameters"
        }
    }

    CHRONOS_T5_MODELS = {
        "Chronos-T5 Tiny (8M params)": {
            "model_id": "amazon/chronos-t5-tiny",
            "pipeline_class": ChronosPipeline,
            "description": "Smallest Chronos-T5 model, fastest inference"
        },
        "Chronos-T5 Mini (20M params)": {
            "model_id": "amazon/chronos-t5-mini",
            "pipeline_class": ChronosPipeline,
            "description": "Mini Chronos-T5 model"
        },
        "Chronos-T5 Small (46M params)": {
            "model_id": "amazon/chronos-t5-small",
            "pipeline_class": ChronosPipeline,
            "description": "Small Chronos-T5 model"
        },
        "Chronos-T5 Base (200M params)": {
            "model_id": "amazon/chronos-t5-base",
            "pipeline_class": ChronosPipeline,
            "description": "Base Chronos-T5 model"
        },
        "Chronos-T5 Large (710M params)": {
            "model_id": "amazon/chronos-t5-large",
            "pipeline_class": ChronosPipeline,
            "description": "Largest Chronos-T5 model, best accuracy"
        }
    }

    # Only populated when the optional ``tirex`` import succeeded, so the
    # TiRex entry is never offered if the library is missing.
    TIREX_MODELS = {
        "TiRex (35M params)": {
            "model_id": "NX-AI/TiRex",
            "pipeline_class": "TiRex",
            "description": "TiRex xLSTM-based model, excellent for both short and long-term forecasting"
        }
    } if TIREX_AVAILABLE else {}

    # The accessors below must be classmethods: callers invoke them directly
    # on the class (``ModelConfig.get_model_config(name)``), which would
    # otherwise bind ``name`` to ``cls`` and raise ``TypeError``.

    @classmethod
    def get_all_models(cls):
        """Return a new dict merging every registry.

        Insertion order is Chronos-2, then Chronos-T5, then TiRex. The
        returned dict is a fresh object, but the per-model config dicts
        inside it are shared with the class-level registries.
        """
        all_models = {}
        all_models.update(cls.CHRONOS_2_MODELS)
        all_models.update(cls.CHRONOS_T5_MODELS)
        all_models.update(cls.TIREX_MODELS)
        return all_models

    @classmethod
    def get_model_names(cls):
        """Return the display names of all models as a list (e.g. for a dropdown)."""
        return list(cls.get_all_models())

    @classmethod
    def get_model_config(cls, model_name):
        """Return the config dict for ``model_name``, or ``None`` if it is unknown."""
        return cls.get_all_models().get(model_name)
def load_model_pipeline(model_name, device_map="cpu", dtype=torch.float32):
    """
    Build the forecasting pipeline registered under a display name.

    Args:
        model_name: Display name of the model, as listed in ModelConfig.
        device_map: Target device (default: "cpu"). Forwarded unchanged to
            Chronos pipelines. For TiRex, CUDA is chosen only when this is
            exactly "cuda" and torch reports CUDA as available; otherwise
            the model is loaded on the CPU.
        dtype: Data type for model weights (default: torch.float32).
            Forwarded to Chronos pipelines only; the TiRex path ignores it.

    Returns:
        A Chronos pipeline instance, or a TiRexWrapper around a TiRex model.

    Raises:
        ValueError: ``model_name`` is not present in the registry.
        ImportError: a TiRex entry was requested but the library is missing.
    """
    model_config = ModelConfig.get_model_config(model_name)
    if model_config is None:
        raise ValueError(f"Unknown model: {model_name}")

    loader = model_config["pipeline_class"]
    checkpoint = model_config["model_id"]
    # TiRex entries carry the marker string "TiRex" instead of a class.
    is_tirex = loader == "TiRex"

    # Chronos pipelines share one loading call; handle them first.
    if not is_tirex:
        return loader.from_pretrained(
            checkpoint,
            device_map=device_map,
            dtype=dtype,
        )

    if not TIREX_AVAILABLE:
        raise ImportError(
            "TiRex library not installed. Install with: pip install tirex-ts\n"
            "Note: TiRex requires GPU support (CUDA-enabled GPU recommended)"
        )

    # TiRex is loaded through its own load_model helper with the torch
    # backend; fall back to the CPU unless CUDA is both present and requested.
    if torch.cuda.is_available() and device_map == "cuda":
        target_device = "cuda"
    else:
        target_device = "cpu"

    tirex_model = load_tirex_model(checkpoint, backend="torch", device=target_device)
    return TiRexWrapper(tirex_model)
class TiRexWrapper:
    """Adapter that gives a TiRex model a Chronos-style ``predict_df`` method."""

    def __init__(self, model):
        # The wrapped object only needs a ``forecast(context=..., prediction_length=...)`` method.
        self.model = model

    def predict_df(self, context_df, prediction_length, quantile_levels, **kwargs):
        """
        Forecast from a dataframe and return the result as a dataframe.

        Args:
            context_df: DataFrame whose 'target' column holds the history.
            prediction_length: Number of steps, passed through to the model.
            quantile_levels: Iterable of quantile levels; each becomes an
                output column named ``str(level)``.
            **kwargs: Accepted for API compatibility and ignored.

        Returns:
            DataFrame with a 'predictions' column followed by one column per
            requested quantile level.

        The model may return either a tensor or a tuple; for a tuple only
        the first element is used. A 3-D output is treated as
        (batch, prediction_length, samples) and summarised along the last
        axis; lower-rank outputs are point forecasts that get copied into
        every quantile column.
        """
        import pandas as pd
        import torch

        # Shape the history as a single-series batch: (1, sequence_length).
        history = torch.tensor(context_df['target'].values, dtype=torch.float32)
        batched_history = history.unsqueeze(0)

        with torch.no_grad():
            raw_output = self.model.forecast(
                context=batched_history, prediction_length=prediction_length
            )

        # Tuple returns carry the forecast in position 0.
        output = raw_output[0] if isinstance(raw_output, tuple) else raw_output

        rank = output.dim()
        if rank == 3:
            # Keep only the first batch entry, then reduce over the last axis.
            first_series = output[0]
            quantile_columns = {
                str(level): torch.quantile(first_series, level, dim=-1).cpu().numpy()
                for level in quantile_levels
            }
            point_forecast = torch.median(first_series, dim=-1).values.cpu().numpy()
        else:
            # Rank 2 is (batch, pred_len): drop the batch axis. Anything else
            # is used as-is. With no distribution available, every quantile
            # column repeats the point forecast.
            flat_output = output[0] if rank == 2 else output
            point_forecast = flat_output.cpu().numpy()
            quantile_columns = {str(level): point_forecast for level in quantile_levels}

        return pd.DataFrame({'predictions': point_forecast, **quantile_columns})
def get_model_info(model_name):
    """
    Get summary information about a registered model.

    Args:
        model_name: Display name of the model

    Returns:
        ``None`` if the model is unknown; otherwise a dictionary with the
        keys "name", "model_id", "description" and "pipeline" (the name of
        the pipeline used to load the model).
    """
    config = ModelConfig.get_model_config(model_name)
    if config is None:
        return None

    pipeline_class = config["pipeline_class"]
    # Registry entries hold either a pipeline class or a plain marker string
    # (the TiRex entry uses "TiRex"). Strings have no ``__name__``, so use
    # the string itself in that case instead of raising AttributeError.
    if isinstance(pipeline_class, str):
        pipeline_name = pipeline_class
    else:
        pipeline_name = pipeline_class.__name__

    return {
        "name": model_name,
        "model_id": config["model_id"],
        "description": config["description"],
        "pipeline": pipeline_name
    }