# TSF-EM / models.py
# Commit 6c98d0b (JavadBayazi): "Add TiRex model support with tirex-ts integration"
# (Header reconstructed from Hugging Face page residue: raw / history / blame, 6.79 kB)
"""
Model configuration and loading for time series forecasting.
Supports multiple Chronos model variants with different architectures.
"""
import torch
from chronos import Chronos2Pipeline, ChronosPipeline
# Try to import TiRex forecasting library
try:
from tirex import load_model as load_tirex_model
TIREX_AVAILABLE = True
except ImportError:
TIREX_AVAILABLE = False
class ModelConfig:
    """Registry of the forecasting models this app can load.

    Each entry maps a human-readable display name to a config dict with
    ``model_id`` (HF hub id), ``pipeline_class`` (a pipeline class for
    Chronos models, or the sentinel string "TiRex"), and ``description``.
    """

    # Chronos-2 family (new unified pipeline).
    CHRONOS_2_MODELS = {
        "Chronos-2 (Latest, 120M params)": {
            "model_id": "amazon/chronos-2",
            "pipeline_class": Chronos2Pipeline,
            "description": "Latest Chronos-2 model with 120M parameters"
        }
    }

    # Chronos-T5 family: same schema for every size, so build the table
    # from a (label, param-count, description) variant list.
    CHRONOS_T5_MODELS = {
        f"Chronos-T5 {label} ({params} params)": {
            "model_id": f"amazon/chronos-t5-{label.lower()}",
            "pipeline_class": ChronosPipeline,
            "description": desc,
        }
        for label, params, desc in [
            ("Tiny", "8M", "Smallest Chronos-T5 model, fastest inference"),
            ("Mini", "20M", "Mini Chronos-T5 model"),
            ("Small", "46M", "Small Chronos-T5 model"),
            ("Base", "200M", "Base Chronos-T5 model"),
            ("Large", "710M", "Largest Chronos-T5 model, best accuracy"),
        ]
    }

    # Only advertised when the optional tirex-ts package imported cleanly.
    TIREX_MODELS = {
        "TiRex (35M params)": {
            "model_id": "NX-AI/TiRex",
            "pipeline_class": "TiRex",
            "description": "TiRex xLSTM-based model, excellent for both short and long-term forecasting"
        }
    } if TIREX_AVAILABLE else {}

    @classmethod
    def get_all_models(cls):
        """Return one dict merging every available model family."""
        return {**cls.CHRONOS_2_MODELS, **cls.CHRONOS_T5_MODELS, **cls.TIREX_MODELS}

    @classmethod
    def get_model_names(cls):
        """Return the display names, e.g. for a UI dropdown."""
        return list(cls.get_all_models())

    @classmethod
    def get_model_config(cls, model_name):
        """Return the config dict for *model_name*, or None if unknown."""
        return cls.get_all_models().get(model_name)
def load_model_pipeline(model_name, device_map="cpu", dtype=torch.float32):
    """
    Load a forecasting model pipeline.

    Args:
        model_name: Display name of the model (a key from ModelConfig).
        device_map: Device to load model on (default: "cpu").
        dtype: Data type for model weights (default: torch.float32).
            Not used by the TiRex path, which manages its own precision.

    Returns:
        A loaded Chronos pipeline instance, or a TiRexWrapper for TiRex.

    Raises:
        ValueError: If model_name is not a registered model.
        ImportError: If a TiRex model is requested but tirex-ts is not installed.
    """
    config = ModelConfig.get_model_config(model_name)
    if config is None:
        raise ValueError(f"Unknown model: {model_name}")

    pipeline_class = config["pipeline_class"]
    model_id = config["model_id"]

    # TiRex is not a Chronos pipeline: load it via tirex.load_model and
    # adapt it to the Chronos predict_df API with TiRexWrapper.
    if pipeline_class == "TiRex":
        if not TIREX_AVAILABLE:
            raise ImportError(
                "TiRex library not installed. Install with: pip install tirex-ts\n"
                "Note: TiRex requires GPU support (CUDA-enabled GPU recommended)"
            )
        # Use CUDA only when it was both requested and is actually available.
        # NOTE(review): device_map="auto" falls back to CPU here — confirm intended.
        device = "cuda" if torch.cuda.is_available() and device_map == "cuda" else "cpu"
        model = load_tirex_model(model_id, backend="torch", device=device)
        return TiRexWrapper(model)

    # Chronos-2 / Chronos-T5 path: standard from_pretrained loading.
    # (torch is imported at module level; the previous function-local
    # re-import was redundant and has been removed.)
    pipeline = pipeline_class.from_pretrained(
        model_id,
        device_map=device_map,
        dtype=dtype,
    )
    return pipeline
class TiRexWrapper:
    """Adapter exposing a Chronos-style ``predict_df`` API over a TiRex model."""

    def __init__(self, model):
        # Underlying TiRex model (as returned by tirex.load_model).
        self.model = model

    def predict_df(self, context_df, prediction_length, quantile_levels, **kwargs):
        """Forecast from a context dataframe, mimicking the Chronos API.

        ``context_df`` must have a 'target' column; the result is a DataFrame
        with a 'predictions' column (median/point forecast) plus one column
        per quantile level, keyed by str(level).

        TiRex.forecast() may return either a tensor or a (forecast, metadata)
        tuple, and the tensor may be 1-, 2-, or 3-dimensional.
        """
        import pandas as pd
        import torch

        # Shape the history as (batch=1, sequence_length) float32.
        history = torch.tensor(
            context_df['target'].values, dtype=torch.float32
        ).unsqueeze(0)

        with torch.no_grad():
            raw = self.model.forecast(context=history, prediction_length=prediction_length)

        # Unwrap a (forecast, metadata) tuple if present.
        forecast = raw[0] if isinstance(raw, tuple) else raw

        if forecast.dim() == 3:
            # (batch, pred_len, samples): derive quantiles from the samples.
            samples = forecast[0]
            quantiles = {
                str(level): torch.quantile(samples, level, dim=-1).cpu().numpy()
                for level in quantile_levels
            }
            median = torch.median(samples, dim=-1).values.cpu().numpy()
        else:
            # Point forecast only — (batch, pred_len) or bare (pred_len,).
            # Without a sample distribution, every quantile equals the point value.
            point = forecast[0] if forecast.dim() == 2 else forecast
            median = point.cpu().numpy()
            quantiles = {str(level): median for level in quantile_levels}

        # Assemble the Chronos-shaped output frame.
        return pd.DataFrame({'predictions': median, **quantiles})
def get_model_info(model_name):
    """
    Get information about a model.

    Args:
        model_name: Display name of the model

    Returns:
        Dictionary with model information (name, model_id, description,
        pipeline), or None if the model is unknown.
    """
    config = ModelConfig.get_model_config(model_name)
    if config is None:
        return None

    # BUG FIX: "pipeline_class" is a class for Chronos models but the plain
    # string "TiRex" for TiRex entries (see ModelConfig.TIREX_MODELS), and
    # str has no __name__ — accessing it raised AttributeError for TiRex.
    pipeline_class = config["pipeline_class"]
    pipeline_name = (
        pipeline_class if isinstance(pipeline_class, str) else pipeline_class.__name__
    )

    return {
        "name": model_name,
        "model_id": config["model_id"],
        "description": config["description"],
        "pipeline": pipeline_name,
    }