# TSF-EM / models.py
# JavadBayazi: Add TiRex model support with tirex-ts integration (commit 6c98d0b)
"""
Model configuration and loading for time series forecasting.
Supports multiple Chronos model variants with different architectures.
"""
import torch
from chronos import Chronos2Pipeline, ChronosPipeline
# Try to import TiRex forecasting library
try:
from tirex import load_model as load_tirex_model
TIREX_AVAILABLE = True
except ImportError:
TIREX_AVAILABLE = False
class ModelConfig:
    """Registry of the forecasting models exposed by the app.

    Each entry maps a human-readable display name (used in UI dropdowns)
    to a config dict with ``model_id`` (HF Hub id), ``pipeline_class``
    (a Chronos pipeline class, or the string "TiRex" for TiRex models),
    and a short ``description``.
    """

    CHRONOS_2_MODELS = {
        "Chronos-2 (Latest, 120M params)": {
            "model_id": "amazon/chronos-2",
            "pipeline_class": Chronos2Pipeline,
            "description": "Latest Chronos-2 model with 120M parameters"
        }
    }

    CHRONOS_T5_MODELS = {
        "Chronos-T5 Tiny (8M params)": {
            "model_id": "amazon/chronos-t5-tiny",
            "pipeline_class": ChronosPipeline,
            "description": "Smallest Chronos-T5 model, fastest inference"
        },
        "Chronos-T5 Mini (20M params)": {
            "model_id": "amazon/chronos-t5-mini",
            "pipeline_class": ChronosPipeline,
            "description": "Mini Chronos-T5 model"
        },
        "Chronos-T5 Small (46M params)": {
            "model_id": "amazon/chronos-t5-small",
            "pipeline_class": ChronosPipeline,
            "description": "Small Chronos-T5 model"
        },
        "Chronos-T5 Base (200M params)": {
            "model_id": "amazon/chronos-t5-base",
            "pipeline_class": ChronosPipeline,
            "description": "Base Chronos-T5 model"
        },
        "Chronos-T5 Large (710M params)": {
            "model_id": "amazon/chronos-t5-large",
            "pipeline_class": ChronosPipeline,
            "description": "Largest Chronos-T5 model, best accuracy"
        }
    }

    # TiRex is only offered when the optional tirex library imported cleanly.
    TIREX_MODELS = {
        "TiRex (35M params)": {
            "model_id": "NX-AI/TiRex",
            "pipeline_class": "TiRex",
            "description": "TiRex xLSTM-based model, excellent for both short and long-term forecasting"
        }
    } if TIREX_AVAILABLE else {}

    @classmethod
    def get_all_models(cls):
        """Return one merged mapping covering every registered model family."""
        return {**cls.CHRONOS_2_MODELS, **cls.CHRONOS_T5_MODELS, **cls.TIREX_MODELS}

    @classmethod
    def get_model_names(cls):
        """Return the display names of all models (for a UI dropdown)."""
        return list(cls.get_all_models())

    @classmethod
    def get_model_config(cls, model_name):
        """Return the config dict for *model_name*, or None if unknown."""
        return cls.get_all_models().get(model_name)
def load_model_pipeline(model_name, device_map="cpu", dtype=torch.float32):
    """
    Load a forecasting model pipeline by its display name.

    Args:
        model_name: Display name of the model (a key from ModelConfig).
        device_map: Device to load the model on (default: "cpu").
        dtype: Torch dtype for the model weights (default: torch.float32).

    Returns:
        A Chronos pipeline instance, or a TiRexWrapper for TiRex models.

    Raises:
        ValueError: If model_name is not a registered model.
        ImportError: If a TiRex model is requested but tirex-ts is missing.
    """
    config = ModelConfig.get_model_config(model_name)
    if config is None:
        raise ValueError(f"Unknown model: {model_name}")

    pipeline_class = config["pipeline_class"]
    model_id = config["model_id"]

    # TiRex is not a Chronos pipeline: it is loaded via the tirex library's
    # load_model() and wrapped so callers can use the same predict_df() API.
    if pipeline_class == "TiRex":
        if not TIREX_AVAILABLE:
            raise ImportError(
                "TiRex library not installed. Install with: pip install tirex-ts\n"
                "Note: TiRex requires GPU support (CUDA-enabled GPU recommended)"
            )
        # Accept "cuda" as well as device specs like "cuda:0", and fall back
        # to CPU whenever CUDA is unavailable.  (torch is already imported at
        # module level; no local re-import needed.)
        use_cuda = torch.cuda.is_available() and str(device_map).startswith("cuda")
        device = "cuda" if use_cuda else "cpu"
        model = load_tirex_model(model_id, backend="torch", device=device)
        return TiRexWrapper(model)

    # Chronos pipelines share a common from_pretrained() loader.
    pipeline = pipeline_class.from_pretrained(
        model_id,
        device_map=device_map,
        dtype=dtype,
    )
    return pipeline
class TiRexWrapper:
    """Adapter exposing a TiRex model through the Chronos predict_df() API."""

    def __init__(self, model):
        self.model = model

    def predict_df(self, context_df, prediction_length, quantile_levels, **kwargs):
        """
        Forecast with TiRex and return a Chronos-style dataframe.

        TiRex's forecast() may return either a bare tensor or a
        (forecast, metadata) tuple, and the tensor may be shaped
        (batch, pred_len, samples), (batch, pred_len), or (pred_len,);
        all layouts are normalized here.

        Returns a DataFrame with a 'predictions' column (median / point
        forecast) plus one column per requested quantile level, keyed by
        str(level).
        """
        import pandas as pd
        import torch

        # Shape the history as (batch=1, sequence_length) for TiRex.
        history = torch.tensor(
            context_df['target'].values, dtype=torch.float32
        ).unsqueeze(0)

        with torch.no_grad():
            raw = self.model.forecast(context=history, prediction_length=prediction_length)

        # Unwrap an optional (forecast, metadata) tuple.
        preds = raw[0] if isinstance(raw, tuple) else raw

        if preds.dim() == 3:
            # (batch, pred_len, samples): reduce the sample axis to quantiles.
            samples = preds[0]  # first (only) batch element
            quantiles = {
                str(level): torch.quantile(samples, level, dim=-1).cpu().numpy()
                for level in quantile_levels
            }
            point = torch.median(samples, dim=-1).values.cpu().numpy()
        else:
            # A single trajectory — (batch, pred_len) or (pred_len,).  With no
            # sample distribution, every quantile degenerates to the point
            # forecast itself.
            point = (preds[0] if preds.dim() == 2 else preds).cpu().numpy()
            quantiles = {str(level): point for level in quantile_levels}

        # Mirror the Chronos output layout: point forecast first, then quantiles.
        return pd.DataFrame({'predictions': point, **quantiles})
def get_model_info(model_name):
    """
    Get display information about a model.

    Args:
        model_name: Display name of the model (a key from ModelConfig).

    Returns:
        Dict with "name", "model_id", "description", and "pipeline" keys,
        or None if the model is unknown.
    """
    config = ModelConfig.get_model_config(model_name)
    if config is None:
        return None

    # BUG FIX: pipeline_class is a class for Chronos models but the literal
    # string "TiRex" for TiRex entries; str has no __name__, so accessing it
    # unconditionally raised AttributeError for TiRex models.
    pipeline_class = config["pipeline_class"]
    pipeline_name = (
        pipeline_class if isinstance(pipeline_class, str) else pipeline_class.__name__
    )

    return {
        "name": model_name,
        "model_id": config["model_id"],
        "description": config["description"],
        "pipeline": pipeline_name,
    }