JavadBayazi commited on
Commit
6c98d0b
·
1 Parent(s): 73852fc

Add TiRex model support with tirex-ts integration

Browse files
Files changed (2) hide show
  1. models.py +85 -2
  2. requirements.txt +1 -0
models.py CHANGED
@@ -6,6 +6,13 @@ Supports multiple Chronos model variants with different architectures.
6
  import torch
7
  from chronos import Chronos2Pipeline, ChronosPipeline
8
 
 
 
 
 
 
 
 
9
 
10
  class ModelConfig:
11
  """Configuration for available forecasting models"""
@@ -46,12 +53,21 @@ class ModelConfig:
46
  }
47
  }
48
 
 
 
 
 
 
 
 
 
49
  @classmethod
50
  def get_all_models(cls):
51
  """Get all available models"""
52
  all_models = {}
53
  all_models.update(cls.CHRONOS_2_MODELS)
54
  all_models.update(cls.CHRONOS_T5_MODELS)
 
55
  return all_models
56
 
57
  @classmethod
@@ -75,7 +91,7 @@ def load_model_pipeline(model_name, device_map="cpu", dtype=torch.float32):
75
  dtype: Data type for model weights (default: torch.float32)
76
 
77
  Returns:
78
- Loaded pipeline instance
79
  """
80
  config = ModelConfig.get_model_config(model_name)
81
 
@@ -85,7 +101,21 @@ def load_model_pipeline(model_name, device_map="cpu", dtype=torch.float32):
85
  pipeline_class = config["pipeline_class"]
86
  model_id = config["model_id"]
87
 
88
- # Load the appropriate pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  pipeline = pipeline_class.from_pretrained(
90
  model_id,
91
  device_map=device_map,
@@ -95,6 +125,59 @@ def load_model_pipeline(model_name, device_map="cpu", dtype=torch.float32):
95
  return pipeline
96
 
97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  def get_model_info(model_name):
99
  """
100
  Get information about a model.
 
6
  import torch
7
  from chronos import Chronos2Pipeline, ChronosPipeline
8
 
9
+ # Try to import TiRex forecasting library
10
+ try:
11
+ from tirex import load_model as load_tirex_model
12
+ TIREX_AVAILABLE = True
13
+ except ImportError:
14
+ TIREX_AVAILABLE = False
15
+
16
 
17
  class ModelConfig:
18
  """Configuration for available forecasting models"""
 
53
  }
54
  }
55
 
56
+ TIREX_MODELS = {
57
+ "TiRex (35M params)": {
58
+ "model_id": "NX-AI/TiRex",
59
+ "pipeline_class": "TiRex",
60
+ "description": "TiRex xLSTM-based model, excellent for both short and long-term forecasting"
61
+ }
62
+ } if TIREX_AVAILABLE else {}
63
+
64
  @classmethod
65
  def get_all_models(cls):
66
  """Get all available models"""
67
  all_models = {}
68
  all_models.update(cls.CHRONOS_2_MODELS)
69
  all_models.update(cls.CHRONOS_T5_MODELS)
70
+ all_models.update(cls.TIREX_MODELS)
71
  return all_models
72
 
73
  @classmethod
 
91
  dtype: Data type for model weights (default: torch.float32)
92
 
93
  Returns:
94
+ Loaded pipeline instance or model
95
  """
96
  config = ModelConfig.get_model_config(model_name)
97
 
 
101
  pipeline_class = config["pipeline_class"]
102
  model_id = config["model_id"]
103
 
104
+ # Load TiRex model differently
105
+ if pipeline_class == "TiRex":
106
+ if not TIREX_AVAILABLE:
107
+ raise ImportError(
108
+ "TiRex library not installed. Install with: pip install tirex-ts\n"
109
+ "Note: TiRex requires GPU support (CUDA-enabled GPU recommended)"
110
+ )
111
+ # TiRex uses load_model from tirex library
112
+ # backend="torch" for CPU/GPU, device="cuda" or "cpu"
113
+ import torch
114
+ device = "cuda" if torch.cuda.is_available() and device_map == "cuda" else "cpu"
115
+ model = load_tirex_model(model_id, backend="torch", device=device)
116
+ return TiRexWrapper(model)
117
+
118
+ # Load Chronos pipelines
119
  pipeline = pipeline_class.from_pretrained(
120
  model_id,
121
  device_map=device_map,
 
125
  return pipeline
126
 
127
 
128
+ class TiRexWrapper:
129
+ """Wrapper to make TiRex compatible with Chronos pipeline API"""
130
+
131
+ def __init__(self, model):
132
+ self.model = model
133
+
134
+ def predict_df(self, context_df, prediction_length, quantile_levels, **kwargs):
135
+ """
136
+ Wrapper to make TiRex work with the same API as Chronos
137
+ TiRex.forecast() may return a tuple of (forecast, metadata) or just tensor
138
+ """
139
+ import pandas as pd
140
+ import torch
141
+
142
+ # Convert dataframe to tensor (batch_size=1, sequence_length)
143
+ context = torch.tensor(context_df['target'].values, dtype=torch.float32).unsqueeze(0)
144
+
145
+ # TiRex forecast may return tuple or tensor
146
+ with torch.no_grad():
147
+ result = self.model.forecast(context=context, prediction_length=prediction_length)
148
+
149
+ # Handle tuple return (forecast, metadata)
150
+ if isinstance(result, tuple):
151
+ forecast = result[0]
152
+ else:
153
+ forecast = result
154
+
155
+ # forecast is shape (batch, prediction_length) or (batch, prediction_length, samples)
156
+ if forecast.dim() == 3: # (batch, pred_len, samples)
157
+ forecast = forecast[0] # Take first batch
158
+ # Calculate quantiles from samples
159
+ quantiles = {}
160
+ for q in quantile_levels:
161
+ quantiles[str(q)] = torch.quantile(forecast, q, dim=-1).cpu().numpy()
162
+ median = torch.median(forecast, dim=-1).values.cpu().numpy()
163
+ elif forecast.dim() == 2: # (batch, pred_len) - single prediction
164
+ forecast = forecast[0].cpu().numpy() # Take first batch
165
+ median = forecast
166
+ # Use same value for all quantiles since we don't have distribution
167
+ quantiles = {str(q): median for q in quantile_levels}
168
+ else: # (pred_len,)
169
+ median = forecast.cpu().numpy()
170
+ quantiles = {str(q): median for q in quantile_levels}
171
+
172
+ # Create output dataframe matching Chronos format
173
+ result_df = pd.DataFrame({
174
+ 'predictions': median,
175
+ **quantiles
176
+ })
177
+
178
+ return result_df
179
+
180
+
181
  def get_model_info(model_name):
182
  """
183
  Get information about a model.
requirements.txt CHANGED
@@ -6,3 +6,4 @@ matplotlib
6
  pandas
7
  pyarrow
8
  gridstatus
 
 
6
  pandas
7
  pyarrow
8
  gridstatus
9
+ tirex-ts