"""
Model configuration and loading for time series forecasting.
Supports multiple Chronos model variants with different architectures.
"""

import torch
from chronos import Chronos2Pipeline, ChronosPipeline

# Try to import TiRex forecasting library
try:
    from tirex import load_model as load_tirex_model
    TIREX_AVAILABLE = True
except ImportError:
    TIREX_AVAILABLE = False


class ModelConfig:
    """Registry of the forecasting models this application can load."""

    # Chronos-2 family, served through Chronos2Pipeline.
    CHRONOS_2_MODELS = {
        "Chronos-2 (Latest, 120M params)": {
            "model_id": "amazon/chronos-2",
            "pipeline_class": Chronos2Pipeline,
            "description": "Latest Chronos-2 model with 120M parameters"
        }
    }

    # Chronos-T5 family, served through ChronosPipeline; ordered small to large.
    CHRONOS_T5_MODELS = {
        "Chronos-T5 Tiny (8M params)": {
            "model_id": "amazon/chronos-t5-tiny",
            "pipeline_class": ChronosPipeline,
            "description": "Smallest Chronos-T5 model, fastest inference"
        },
        "Chronos-T5 Mini (20M params)": {
            "model_id": "amazon/chronos-t5-mini",
            "pipeline_class": ChronosPipeline,
            "description": "Mini Chronos-T5 model"
        },
        "Chronos-T5 Small (46M params)": {
            "model_id": "amazon/chronos-t5-small",
            "pipeline_class": ChronosPipeline,
            "description": "Small Chronos-T5 model"
        },
        "Chronos-T5 Base (200M params)": {
            "model_id": "amazon/chronos-t5-base",
            "pipeline_class": ChronosPipeline,
            "description": "Base Chronos-T5 model"
        },
        "Chronos-T5 Large (710M params)": {
            "model_id": "amazon/chronos-t5-large",
            "pipeline_class": ChronosPipeline,
            "description": "Largest Chronos-T5 model, best accuracy"
        }
    }

    # TiRex entry is only offered when the optional tirex library imported;
    # its "pipeline_class" is a string marker rather than a real class.
    TIREX_MODELS = {
        "TiRex (35M params)": {
            "model_id": "NX-AI/TiRex",
            "pipeline_class": "TiRex",
            "description": "TiRex xLSTM-based model, excellent for both short and long-term forecasting"
        }
    } if TIREX_AVAILABLE else {}

    @classmethod
    def get_all_models(cls):
        """Return one mapping containing every available model config."""
        return {**cls.CHRONOS_2_MODELS, **cls.CHRONOS_T5_MODELS, **cls.TIREX_MODELS}

    @classmethod
    def get_model_names(cls):
        """Return the model display names (e.g. for a dropdown)."""
        return [*cls.get_all_models()]

    @classmethod
    def get_model_config(cls, model_name):
        """Return the config dict for *model_name*, or None if unknown."""
        return cls.get_all_models().get(model_name)


def load_model_pipeline(model_name, device_map="cpu", dtype=torch.float32):
    """
    Load a forecasting model pipeline by its display name.

    Args:
        model_name: Display name of the model (a key from
            ModelConfig.get_all_models()).
        device_map: Device to load the model on (default: "cpu"). Any
            value starting with "cuda" (e.g. "cuda", "cuda:0") requests
            the GPU for TiRex models.
        dtype: Data type for model weights (default: torch.float32).

    Returns:
        A Chronos pipeline instance, or a TiRexWrapper for TiRex models.

    Raises:
        ValueError: If model_name is not a known model.
        ImportError: If a TiRex model is requested but the tirex library
            is not installed.
    """
    config = ModelConfig.get_model_config(model_name)

    if config is None:
        raise ValueError(f"Unknown model: {model_name}")

    pipeline_class = config["pipeline_class"]
    model_id = config["model_id"]

    # TiRex is not a Chronos pipeline: it is loaded through the tirex
    # library's own load_model() and adapted via TiRexWrapper.
    if pipeline_class == "TiRex":
        if not TIREX_AVAILABLE:
            raise ImportError(
                "TiRex library not installed. Install with: pip install tirex-ts\n"
                "Note: TiRex requires GPU support (CUDA-enabled GPU recommended)"
            )
        # Fall back to CPU unless CUDA was requested and is actually
        # available; startswith() also accepts forms like "cuda:0".
        use_cuda = torch.cuda.is_available() and str(device_map).startswith("cuda")
        device = "cuda" if use_cuda else "cpu"
        model = load_tirex_model(model_id, backend="torch", device=device)
        return TiRexWrapper(model)

    # Chronos / Chronos-2 pipelines share the from_pretrained() API.
    return pipeline_class.from_pretrained(
        model_id,
        device_map=device_map,
        dtype=dtype,
    )


class TiRexWrapper:
    """Adapter exposing a Chronos-style predict_df() API over a TiRex model."""

    def __init__(self, model):
        # Underlying TiRex model exposing
        # forecast(context=..., prediction_length=...).
        self.model = model

    def predict_df(self, context_df, prediction_length, quantile_levels, **kwargs):
        """
        Forecast with TiRex using the same call shape as Chronos' predict_df.

        Args:
            context_df: DataFrame with a 'target' column holding the
                historical series.
            prediction_length: Number of future steps to forecast.
            quantile_levels: Iterable of quantile levels (e.g. [0.1, 0.9]).
            **kwargs: Ignored; accepted for API compatibility with Chronos.

        Returns:
            DataFrame with a 'predictions' column (median forecast) plus
            one column per requested quantile, keyed by str(level).
        """
        import pandas as pd  # local import: pandas is not imported at module level

        # Shape (1, sequence_length): TiRex expects a batch dimension.
        context = torch.tensor(context_df['target'].values, dtype=torch.float32).unsqueeze(0)

        # TiRex.forecast() may return a (forecast, metadata) tuple or a bare tensor.
        with torch.no_grad():
            result = self.model.forecast(context=context, prediction_length=prediction_length)
        forecast = result[0] if isinstance(result, tuple) else result

        if forecast.dim() == 3:
            # Sample-based output: (batch, pred_len, samples); reduce over
            # the sample axis for median and quantiles.
            # NOTE(review): axis order assumed from the original code —
            # confirm against the tirex docs that samples are last.
            forecast = forecast[0]  # Take first batch
            quantiles = {
                str(q): torch.quantile(forecast, q, dim=-1).cpu().numpy()
                for q in quantile_levels
            }
            median = torch.median(forecast, dim=-1).values.cpu().numpy()
        elif forecast.dim() == 2:
            # Point forecast: (batch, pred_len). No distribution available,
            # so every quantile column repeats the point forecast.
            median = forecast[0].cpu().numpy()  # Take first batch
            quantiles = {str(q): median for q in quantile_levels}
        else:
            # Already squeezed to (pred_len,).
            median = forecast.cpu().numpy()
            quantiles = {str(q): median for q in quantile_levels}

        # Match the Chronos output format.
        return pd.DataFrame({
            'predictions': median,
            **quantiles
        })


def get_model_info(model_name):
    """
    Get display information about a model.

    Args:
        model_name: Display name of the model.

    Returns:
        Dictionary with name, model_id, description and pipeline name,
        or None if the model is unknown.
    """
    config = ModelConfig.get_model_config(model_name)

    if config is None:
        return None

    # pipeline_class is either a real class (Chronos pipelines) or a plain
    # string marker such as "TiRex". Only classes have __name__, so the
    # original unconditional .__name__ raised AttributeError for TiRex.
    pipeline_class = config["pipeline_class"]
    pipeline_name = (
        pipeline_class if isinstance(pipeline_class, str) else pipeline_class.__name__
    )

    return {
        "name": model_name,
        "model_id": config["model_id"],
        "description": config["description"],
        "pipeline": pipeline_name
    }