"""
Model configuration and loading for time series forecasting.
Supports multiple Chronos model variants with different architectures.
"""
import torch
from chronos import Chronos2Pipeline, ChronosPipeline
# Try to import TiRex forecasting library
try:
from tirex import load_model as load_tirex_model
TIREX_AVAILABLE = True
except ImportError:
TIREX_AVAILABLE = False
class ModelConfig:
    """Registry of the forecasting models this app can load."""

    # Chronos-2 family: loaded through Chronos2Pipeline.
    CHRONOS_2_MODELS = {
        "Chronos-2 (Latest, 120M params)": {
            "model_id": "amazon/chronos-2",
            "pipeline_class": Chronos2Pipeline,
            "description": "Latest Chronos-2 model with 120M parameters",
        },
    }

    # Chronos-T5 family: all share ChronosPipeline, listed smallest to largest.
    CHRONOS_T5_MODELS = {
        display_name: {
            "model_id": hub_id,
            "pipeline_class": ChronosPipeline,
            "description": blurb,
        }
        for display_name, hub_id, blurb in (
            ("Chronos-T5 Tiny (8M params)", "amazon/chronos-t5-tiny", "Smallest Chronos-T5 model, fastest inference"),
            ("Chronos-T5 Mini (20M params)", "amazon/chronos-t5-mini", "Mini Chronos-T5 model"),
            ("Chronos-T5 Small (46M params)", "amazon/chronos-t5-small", "Small Chronos-T5 model"),
            ("Chronos-T5 Base (200M params)", "amazon/chronos-t5-base", "Base Chronos-T5 model"),
            ("Chronos-T5 Large (710M params)", "amazon/chronos-t5-large", "Largest Chronos-T5 model, best accuracy"),
        )
    }

    # TiRex is optional: only offered when the tirex package imported cleanly.
    if TIREX_AVAILABLE:
        TIREX_MODELS = {
            "TiRex (35M params)": {
                "model_id": "NX-AI/TiRex",
                "pipeline_class": "TiRex",
                "description": "TiRex xLSTM-based model, excellent for both short and long-term forecasting",
            },
        }
    else:
        TIREX_MODELS = {}

    @classmethod
    def get_all_models(cls):
        """Return every registered model config, keyed by display name."""
        return {**cls.CHRONOS_2_MODELS, **cls.CHRONOS_T5_MODELS, **cls.TIREX_MODELS}

    @classmethod
    def get_model_names(cls):
        """Return the display names (in registry order) for a dropdown."""
        return [*cls.get_all_models()]

    @classmethod
    def get_model_config(cls, model_name):
        """Return the config dict for *model_name*, or None if unknown."""
        return cls.get_all_models().get(model_name)
def load_model_pipeline(model_name, device_map="cpu", dtype=torch.float32):
    """
    Load a forecasting model pipeline.

    Args:
        model_name: Display name of the model (a key from ModelConfig).
        device_map: Device to load the model on (default: "cpu").
        dtype: Data type for model weights (default: torch.float32).

    Returns:
        A loaded Chronos pipeline instance, or a TiRexWrapper for TiRex models.

    Raises:
        ValueError: If model_name is not in the ModelConfig registry.
        ImportError: If a TiRex model is requested but tirex is not installed.
    """
    config = ModelConfig.get_model_config(model_name)
    if config is None:
        raise ValueError(f"Unknown model: {model_name}")

    pipeline_class = config["pipeline_class"]
    model_id = config["model_id"]

    # TiRex is loaded via its own load_model() helper instead of a
    # from_pretrained() pipeline class (its registry entry stores the
    # sentinel string "TiRex" rather than a class).
    if pipeline_class == "TiRex":
        if not TIREX_AVAILABLE:
            raise ImportError(
                "TiRex library not installed. Install with: pip install tirex-ts\n"
                "Note: TiRex requires GPU support (CUDA-enabled GPU recommended)"
            )
        # Use CUDA only when it was both requested and is actually available;
        # otherwise fall back to CPU. (torch is already imported at module top.)
        device = "cuda" if torch.cuda.is_available() and device_map == "cuda" else "cpu"
        model = load_tirex_model(model_id, backend="torch", device=device)
        return TiRexWrapper(model)

    # Chronos and Chronos-2 pipelines share the from_pretrained() interface.
    pipeline = pipeline_class.from_pretrained(
        model_id,
        device_map=device_map,
        dtype=dtype,
    )
    return pipeline
class TiRexWrapper:
    """Adapter that gives a TiRex model the Chronos pipeline predict_df API."""

    def __init__(self, model):
        # Underlying TiRex model (as returned by tirex.load_model).
        self.model = model

    def predict_df(self, context_df, prediction_length, quantile_levels, **kwargs):
        """
        Forecast with TiRex using a Chronos-style predict_df interface.

        Args:
            context_df: DataFrame whose 'target' column holds the history.
            prediction_length: Number of future steps to forecast.
            quantile_levels: Iterable of quantile levels (e.g. [0.1, 0.5, 0.9]).
            **kwargs: Ignored; accepted only for Chronos API compatibility.

        Returns:
            DataFrame with a 'predictions' column (median / point forecast)
            plus one column per quantile level, named str(level).
        """
        import pandas as pd

        # Shape the history as (batch=1, sequence_length) for the model.
        # torch is already imported at module level; no local import needed.
        context = torch.tensor(context_df['target'].values, dtype=torch.float32).unsqueeze(0)

        # TiRex.forecast() may return (forecast, metadata) or just a tensor.
        with torch.no_grad():
            result = self.model.forecast(context=context, prediction_length=prediction_length)
        forecast = result[0] if isinstance(result, tuple) else result

        if forecast.dim() == 3:
            # (batch, pred_len, samples): derive quantiles and the median
            # from the sample dimension of the first batch element.
            forecast = forecast[0]
            quantiles = {
                str(q): torch.quantile(forecast, q, dim=-1).cpu().numpy()
                for q in quantile_levels
            }
            median = torch.median(forecast, dim=-1).values.cpu().numpy()
        elif forecast.dim() == 2:
            # (batch, pred_len): a single point prediction with no sample
            # distribution, so every quantile column repeats the point value.
            median = forecast[0].cpu().numpy()
            quantiles = {str(q): median for q in quantile_levels}
        else:
            # (pred_len,): already-unbatched point prediction.
            median = forecast.cpu().numpy()
            quantiles = {str(q): median for q in quantile_levels}

        # Match the Chronos output format: point forecast + quantile columns.
        return pd.DataFrame({'predictions': median, **quantiles})
def get_model_info(model_name):
    """
    Get information about a model.

    Args:
        model_name: Display name of the model

    Returns:
        Dictionary with model information, or None if the model is unknown.
    """
    config = ModelConfig.get_model_config(model_name)
    if config is None:
        return None

    # BUG FIX: pipeline_class is a pipeline class for Chronos models but the
    # plain string "TiRex" for TiRex models; a string has no __name__, so the
    # old unconditional config["pipeline_class"].__name__ raised AttributeError.
    pipeline_class = config["pipeline_class"]
    pipeline_name = pipeline_class if isinstance(pipeline_class, str) else pipeline_class.__name__

    return {
        "name": model_name,
        "model_id": config["model_id"],
        "description": config["description"],
        "pipeline": pipeline_name,
    }