| from __future__ import annotations |
|
|
| import warnings |
| from typing import Any |
|
|
| import numpy as np |
| import pandas as pd |
| import torch |
|
|
| from models.base_model import BaseForecastModel |
|
|
|
|
class Chronos2Model(BaseForecastModel):
    """Adapter exposing Amazon's Chronos-2 pipeline through ``BaseForecastModel``.

    The wrapper lazily loads ``Chronos2Pipeline``, reshapes a wide pandas
    context frame (datetime index, one column per series) into the long
    ``(id, timestamp, target)`` layout that ``predict_df`` expects, and pivots
    the long-format predictions back into ``(prediction_length, num_targets)``
    numpy arrays.
    """

    def __init__(self, model_id: str = "amazon/chronos-2", device_map: str | None = None) -> None:
        """Store configuration only; the heavy model load is deferred to ``load_model``.

        Args:
            model_id: Hugging Face model identifier to load.
            device_map: Explicit device string ("cuda"/"cpu"). ``None`` means
                auto-detect at load time.
        """
        self.model_id = model_id
        self.device_map = device_map
        self.pipeline = None  # populated by load_model()

    @staticmethod
    def _is_cuda_runtime_error(exc: RuntimeError) -> bool:
        """Heuristically detect CUDA-related failures from the exception message.

        Used to decide whether a load failure warrants a CPU fallback rather
        than being re-raised.
        """
        message = str(exc).lower()
        cuda_signals = (
            "cuda error",
            "cublas",
            "cudnn",
            "device-side assert",
            "invalid configuration argument",
            "illegal memory access",
            "out of memory",
        )
        return any(signal in message for signal in cuda_signals)

    def _resolve_device_map(self) -> str:
        """Return the configured device, else "cuda" when available, else "cpu"."""
        if self.device_map:
            return self.device_map
        try:
            return "cuda" if torch.cuda.is_available() else "cpu"
        except Exception:
            # A broken or partial torch/CUDA install must not block CPU inference.
            return "cpu"

    def load_model(self) -> None:
        """Instantiate the Chronos-2 pipeline, falling back to CPU on CUDA errors.

        On success, ``self.pipeline`` is set and ``self.device_map`` records the
        device actually used.

        Raises:
            ModuleNotFoundError: if ``chronos-forecasting`` is not installed.
            RuntimeError: re-raised when the load failure is not CUDA-related.
        """
        try:
            from chronos import Chronos2Pipeline
        except ModuleNotFoundError as exc:
            raise ModuleNotFoundError(
                "Missing dependency 'chronos-forecasting'. Install it with: "
                "`pip install chronos-forecasting` "
                "(if pip resolver conflicts with toto-ts, try `pip install --no-deps chronos-forecasting`)."
            ) from exc

        chosen_device = self._resolve_device_map()
        try:
            self.pipeline = Chronos2Pipeline.from_pretrained(self.model_id, device_map=chosen_device)
            self.device_map = chosen_device
        except RuntimeError as exc:
            # Only swallow the error when it clearly came from CUDA and a CPU
            # retry is plausible; anything else propagates to the caller.
            if chosen_device == "cuda" and self._is_cuda_runtime_error(exc):
                warnings.warn(
                    "CUDA failed while loading Chronos-2. Falling back to CPU.",
                    RuntimeWarning,
                    stacklevel=2,
                )
                self.pipeline = Chronos2Pipeline.from_pretrained(self.model_id, device_map="cpu")
                self.device_map = "cpu"
            else:
                raise

    @staticmethod
    def _to_long_context(context_df: pd.DataFrame) -> pd.DataFrame:
        """Convert a wide context frame into long format.

        Args:
            context_df: DataFrame with a datetime-like index and one column per
                target series.

        Returns:
            DataFrame with exactly the columns ``["id", "timestamp", "target"]``,
            sorted by timestamp, with non-numeric targets and unparseable
            timestamps dropped.

        Raises:
            ValueError: on empty input, unparseable timestamps, or no numeric data.
        """
        if context_df.empty:
            raise ValueError("Chronos-2 received an empty context dataframe.")

        # reset_index() names an unnamed index "index"; normalize it to "timestamp".
        ts_col = context_df.index.name or "index"
        prepared = context_df.copy().reset_index().rename(columns={ts_col: "timestamp"})
        prepared["timestamp"] = pd.to_datetime(prepared["timestamp"], errors="coerce")
        prepared = prepared.dropna(subset=["timestamp"]).sort_values("timestamp")
        if prepared.empty:
            raise ValueError("Chronos-2 could not parse timestamps from context data.")

        id_column = "id"
        target_column = "target"
        long_df = prepared.melt(
            id_vars=["timestamp"],
            value_vars=list(context_df.columns),
            var_name=id_column,
            value_name=target_column,
        )
        long_df[target_column] = pd.to_numeric(long_df[target_column], errors="coerce")
        long_df = long_df.dropna(subset=[target_column])
        if long_df.empty:
            raise ValueError("Chronos-2 context data has no numeric values after preprocessing.")
        return long_df[[id_column, "timestamp", target_column]]

    @staticmethod
    def _find_point_column(pred_df: pd.DataFrame) -> Any:
        """Pick the column of ``pred_df`` holding the point forecast.

        Preference order: an explicit point column ("predictions"/"prediction"),
        the exact median quantile ("0.5" or 0.5), then whichever valid quantile
        column is closest to 0.5.

        Raises:
            ValueError: if no usable point-forecast column exists.
        """
        for candidate in ("predictions", "prediction", "0.5", 0.5):
            if candidate in pred_df.columns:
                return candidate

        # Fall back to the quantile nearest the median. Delegating to
        # _extract_quantile_columns bounds candidates to [0, 1] — previously any
        # numeric-named column (e.g. "2") could have been selected by mistake.
        quantile_cols = Chronos2Model._extract_quantile_columns(pred_df)
        if quantile_cols:
            closest_q = min(quantile_cols, key=lambda q: abs(q - 0.5))
            return quantile_cols[closest_q]

        raise ValueError("Chronos-2 output did not include a point forecast column.")

    @staticmethod
    def _extract_quantile_columns(pred_df: pd.DataFrame) -> dict[float, Any]:
        """Map each quantile level in [0, 1] to its column label in ``pred_df``.

        Columns whose labels do not parse as floats are ignored.
        """
        quantile_cols: dict[float, Any] = {}
        for col in pred_df.columns:
            try:
                q = float(str(col))
            except (TypeError, ValueError):
                continue
            if 0.0 <= q <= 1.0:
                quantile_cols[q] = col
        return quantile_cols

    @staticmethod
    def _pivot_forecast(
        pred_df: pd.DataFrame,
        value_column: Any,
        target_cols: list[str],
        prediction_length: int,
    ) -> tuple[np.ndarray, pd.DatetimeIndex]:
        """Pivot long-format predictions into a ``(prediction_length, len(target_cols))`` array.

        Args:
            pred_df: Long-format predictions with "id"/"timestamp" columns.
            value_column: Column of ``pred_df`` to extract (point or quantile).
            target_cols: Desired column order of the output matrix.
            prediction_length: Number of forecast rows expected.

        Returns:
            Tuple of (values array, matching DatetimeIndex of forecast steps).

        Raises:
            ValueError: when required columns are missing, the pivot is empty,
                or fewer than ``prediction_length`` rows remain.
        """
        timestamp_col = "timestamp"
        id_col = "id"

        required = {id_col, timestamp_col, value_column}
        if not required.issubset(set(pred_df.columns)):
            missing = required.difference(set(pred_df.columns))
            raise ValueError(f"Chronos-2 output missing columns: {sorted(missing)}")

        prepared = pred_df[[id_col, timestamp_col, value_column]].copy()
        prepared[timestamp_col] = pd.to_datetime(prepared[timestamp_col], errors="coerce")
        prepared[value_column] = pd.to_numeric(prepared[value_column], errors="coerce")
        prepared = prepared.dropna(subset=[timestamp_col, value_column]).sort_values([id_col, timestamp_col])

        # aggfunc="last" makes duplicate (timestamp, id) pairs deterministic.
        pivot = prepared.pivot_table(index=timestamp_col, columns=id_col, values=value_column, aggfunc="last")
        pivot = pivot.sort_index()
        pivot = pivot.reindex(columns=target_cols)

        if pivot.empty:
            raise ValueError("Chronos-2 returned empty forecasts after pivot.")

        # Keep only the final horizon in case the model echoed extra rows.
        pivot = pivot.tail(prediction_length)
        if len(pivot) < prediction_length:
            raise ValueError(
                f"Chronos-2 returned {len(pivot)} rows, expected at least {prediction_length}."
            )

        return pivot.values.astype(float), pd.DatetimeIndex(pivot.index)

    def predict(
        self,
        context_data: Any,
        prediction_length: int,
        num_samples: int = 256,
        quantile_levels: tuple[float, ...] = (0.1, 0.5, 0.9),
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Forecast ``prediction_length`` steps for every target column.

        Args:
            context_data: Either a wide DataFrame, or a dict with keys
                "context" (DataFrame), optional "frequency", and optional
                "target_cols".
            prediction_length: Number of future steps; must be >= 1.
            num_samples: Accepted for interface parity; Chronos-2's
                ``predict_df`` is quantile-based, so this is unused.
            quantile_levels: Quantiles requested from the pipeline.
            **kwargs: Passed through to ``Chronos2Pipeline.predict_df``.

        Returns:
            Dict with "mean", "quantiles", "samples" (always None),
            "forecast_index", "frequency", and diagnostic metadata.

        Raises:
            ValueError: for invalid inputs or unusable pipeline output.
        """
        _ = num_samples  # unused: predict_df is quantile-based, not sampling-based
        if prediction_length < 1:
            raise ValueError("prediction_length must be a positive integer.")
        if self.pipeline is None:
            self.load_model()

        if isinstance(context_data, dict):
            context_df = context_data.get("context")
            frequency = context_data.get("frequency")
            target_cols = context_data.get("target_cols") or (list(context_df.columns) if context_df is not None else [])
        else:
            context_df = context_data
            frequency = None
            target_cols = list(context_df.columns) if isinstance(context_df, pd.DataFrame) else []

        if context_df is None or not isinstance(context_df, pd.DataFrame):
            raise ValueError("Chronos-2 expects context_data to include a pandas DataFrame under 'context'.")
        if not target_cols:
            raise ValueError("Chronos-2 requires at least one target column.")

        long_context = self._to_long_context(context_df)
        prediction_df = self.pipeline.predict_df(
            long_context,
            prediction_length=prediction_length,
            quantile_levels=list(quantile_levels),
            id_column="id",
            timestamp_column="timestamp",
            target="target",
            **kwargs,
        )

        if prediction_df is None or len(prediction_df) == 0:
            raise ValueError("Chronos-2 returned an empty prediction dataframe.")

        point_col = self._find_point_column(prediction_df)
        mean_values, forecast_index = self._pivot_forecast(
            prediction_df,
            value_column=point_col,
            target_cols=target_cols,
            prediction_length=prediction_length,
        )

        # Best-effort quantile extraction: a malformed quantile column should
        # not invalidate the point forecast, so failures are skipped.
        quantiles: dict[float, np.ndarray] = {}
        quantile_columns = self._extract_quantile_columns(prediction_df)
        for q, col in quantile_columns.items():
            try:
                q_values, _ = self._pivot_forecast(
                    prediction_df,
                    value_column=col,
                    target_cols=target_cols,
                    prediction_length=prediction_length,
                )
                quantiles[q] = q_values
            except Exception:
                continue

        return {
            "mean": mean_values,
            "quantiles": quantiles,
            "samples": None,
            "forecast_index": forecast_index,
            "frequency": frequency,
            "device_used": self.device_map or "auto",
            "samples_per_batch_used": "n/a",
            "context_length_used": len(context_df),
            "num_variates_used": context_df.shape[1],
        }

    def get_model_info(self) -> dict[str, Any]:
        """Return static descriptive metadata about this model wrapper."""
        return {
            "name": "Chronos-2",
            "full_name": "amazon/chronos-2",
            "description": "Universal foundation model by Amazon for univariate/multivariate forecasting.",
            "supports_multivariate": True,
            "supports_covariates": True,
            "max_context_length": "Auto",
            "parameters": "120M",
            "device": self.device_map or "auto",
        }
|
|