Spaces:
Running
Running
Commit
·
6c98d0b
1
Parent(s):
73852fc
Add TiRex model support with tirex-ts integration
Browse files- models.py +85 -2
- requirements.txt +1 -0
models.py
CHANGED
|
@@ -6,6 +6,13 @@ Supports multiple Chronos model variants with different architectures.
|
|
| 6 |
import torch
|
| 7 |
from chronos import Chronos2Pipeline, ChronosPipeline
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
class ModelConfig:
|
| 11 |
"""Configuration for available forecasting models"""
|
|
@@ -46,12 +53,21 @@ class ModelConfig:
|
|
| 46 |
}
|
| 47 |
}
|
| 48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
@classmethod
|
| 50 |
def get_all_models(cls):
|
| 51 |
"""Get all available models"""
|
| 52 |
all_models = {}
|
| 53 |
all_models.update(cls.CHRONOS_2_MODELS)
|
| 54 |
all_models.update(cls.CHRONOS_T5_MODELS)
|
|
|
|
| 55 |
return all_models
|
| 56 |
|
| 57 |
@classmethod
|
|
@@ -75,7 +91,7 @@ def load_model_pipeline(model_name, device_map="cpu", dtype=torch.float32):
|
|
| 75 |
dtype: Data type for model weights (default: torch.float32)
|
| 76 |
|
| 77 |
Returns:
|
| 78 |
-
Loaded pipeline instance
|
| 79 |
"""
|
| 80 |
config = ModelConfig.get_model_config(model_name)
|
| 81 |
|
|
@@ -85,7 +101,21 @@ def load_model_pipeline(model_name, device_map="cpu", dtype=torch.float32):
|
|
| 85 |
pipeline_class = config["pipeline_class"]
|
| 86 |
model_id = config["model_id"]
|
| 87 |
|
| 88 |
-
# Load
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
pipeline = pipeline_class.from_pretrained(
|
| 90 |
model_id,
|
| 91 |
device_map=device_map,
|
|
@@ -95,6 +125,59 @@ def load_model_pipeline(model_name, device_map="cpu", dtype=torch.float32):
|
|
| 95 |
return pipeline
|
| 96 |
|
| 97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
def get_model_info(model_name):
|
| 99 |
"""
|
| 100 |
Get information about a model.
|
|
|
|
| 6 |
import torch
|
| 7 |
from chronos import Chronos2Pipeline, ChronosPipeline
|
| 8 |
|
# Optional dependency: the TiRex forecasting library (installed as tirex-ts).
# TIREX_AVAILABLE gates the TiRex entries in the model registry and the
# TiRex branch of the loading path; everything else works without it.
try:
    from tirex import load_model as load_tirex_model
except ImportError:
    # tirex-ts is not installed; TiRex models are simply hidden.
    TIREX_AVAILABLE = False
else:
    TIREX_AVAILABLE = True
| 16 |
|
| 17 |
class ModelConfig:
|
| 18 |
"""Configuration for available forecasting models"""
|
|
|
|
| 53 |
}
|
| 54 |
}
|
| 55 |
|
    # Registry entry for the optional TiRex family.  Collapses to an empty
    # dict when the tirex package could not be imported, so TiRex simply
    # never appears in get_all_models() on machines without it.
    # NOTE: "pipeline_class" here is the string sentinel "TiRex", not a
    # class object — load_model_pipeline() special-cases it before falling
    # through to the Chronos from_pretrained() path.
    TIREX_MODELS = {
        "TiRex (35M params)": {
            "model_id": "NX-AI/TiRex",
            "pipeline_class": "TiRex",
            "description": "TiRex xLSTM-based model, excellent for both short and long-term forecasting"
        }
    } if TIREX_AVAILABLE else {}
| 64 |
@classmethod
|
| 65 |
def get_all_models(cls):
|
| 66 |
"""Get all available models"""
|
| 67 |
all_models = {}
|
| 68 |
all_models.update(cls.CHRONOS_2_MODELS)
|
| 69 |
all_models.update(cls.CHRONOS_T5_MODELS)
|
| 70 |
+
all_models.update(cls.TIREX_MODELS)
|
| 71 |
return all_models
|
| 72 |
|
| 73 |
@classmethod
|
|
|
|
| 91 |
dtype: Data type for model weights (default: torch.float32)
|
| 92 |
|
| 93 |
Returns:
|
| 94 |
+
Loaded pipeline instance or model
|
| 95 |
"""
|
| 96 |
config = ModelConfig.get_model_config(model_name)
|
| 97 |
|
|
|
|
| 101 |
pipeline_class = config["pipeline_class"]
|
| 102 |
model_id = config["model_id"]
|
| 103 |
|
| 104 |
+
# Load TiRex model differently
|
| 105 |
+
if pipeline_class == "TiRex":
|
| 106 |
+
if not TIREX_AVAILABLE:
|
| 107 |
+
raise ImportError(
|
| 108 |
+
"TiRex library not installed. Install with: pip install tirex-ts\n"
|
| 109 |
+
"Note: TiRex requires GPU support (CUDA-enabled GPU recommended)"
|
| 110 |
+
)
|
| 111 |
+
# TiRex uses load_model from tirex library
|
| 112 |
+
# backend="torch" for CPU/GPU, device="cuda" or "cpu"
|
| 113 |
+
import torch
|
| 114 |
+
device = "cuda" if torch.cuda.is_available() and device_map == "cuda" else "cpu"
|
| 115 |
+
model = load_tirex_model(model_id, backend="torch", device=device)
|
| 116 |
+
return TiRexWrapper(model)
|
| 117 |
+
|
| 118 |
+
# Load Chronos pipelines
|
| 119 |
pipeline = pipeline_class.from_pretrained(
|
| 120 |
model_id,
|
| 121 |
device_map=device_map,
|
|
|
|
| 125 |
return pipeline
|
| 126 |
|
| 127 |
|
class TiRexWrapper:
    """Adapter exposing a TiRex model through the Chronos ``predict_df`` API.

    NOTE(review): the shape conventions below (tensor return, or a
    (forecast, metadata) tuple; sample axis last) are carried over from the
    original code — confirm against the tirex-ts ``forecast`` documentation.
    """

    def __init__(self, model):
        # Underlying model object returned by tirex.load_model().
        self.model = model

    def predict_df(self, context_df, prediction_length, quantile_levels, **kwargs):
        """Forecast from a context DataFrame, returning quantile columns.

        Args:
            context_df: DataFrame with a 'target' column holding the history.
            prediction_length: Number of future steps to forecast.
            quantile_levels: Iterable of quantile levels, e.g. [0.1, 0.5, 0.9].
            **kwargs: Ignored; accepted for Chronos pipeline API compatibility.

        Returns:
            DataFrame with a 'predictions' column (the median forecast) plus
            one column per quantile level, keyed by str(level).
        """
        import pandas as pd

        # Shape (1, history_len): a single-series batch for TiRex.
        context = torch.tensor(
            context_df['target'].values, dtype=torch.float32
        ).unsqueeze(0)

        with torch.no_grad():
            result = self.model.forecast(
                context=context, prediction_length=prediction_length
            )

        # TiRex may return (forecast, metadata) or a bare tensor.
        forecast = result[0] if isinstance(result, tuple) else result

        if forecast.dim() == 3:
            # (batch, pred_len, samples): reduce the sample axis to quantiles.
            forecast = forecast[0]  # take the single batch element
            quantiles = {
                str(q): torch.quantile(forecast, q, dim=-1).cpu().numpy()
                for q in quantile_levels
            }
            # Compute the point forecast as quantile(0.5) so it is derived
            # the same way as the other columns; torch.median would pick the
            # lower of the two middle samples for even sample counts.
            median = torch.quantile(forecast, 0.5, dim=-1).cpu().numpy()
        elif forecast.dim() == 2:
            # (batch, pred_len): a single point path, no distribution —
            # every quantile column degenerates to the point forecast.
            median = forecast[0].cpu().numpy()
            quantiles = {str(q): median for q in quantile_levels}
        else:
            # (pred_len,): already an unbatched single path.
            median = forecast.cpu().numpy()
            quantiles = {str(q): median for q in quantile_levels}

        # Output matches the Chronos-style frame the callers expect.
        return pd.DataFrame({
            'predictions': median,
            **quantiles
        })
|
| 180 |
+
|
| 181 |
def get_model_info(model_name):
|
| 182 |
"""
|
| 183 |
Get information about a model.
|
requirements.txt
CHANGED
|
@@ -6,3 +6,4 @@ matplotlib
|
|
| 6 |
pandas
|
| 7 |
pyarrow
|
| 8 |
gridstatus
|
|
|
|
|
|
| 6 |
pandas
|
| 7 |
pyarrow
|
| 8 |
gridstatus
|
| 9 |
+
tirex-ts
|