import os
import logging
from typing import Any, Dict, Optional

import torch
import pandas as pd
import numpy as np

from utils.tracing import Tracer
from utils.config import AppConfig
from transformers import AutoModel, AutoConfig

logger = logging.getLogger(__name__)

MIN_SERIES_LENGTH = 2
MAX_SERIES_LENGTH = 10000
MIN_HORIZON = 1
MAX_HORIZON = 365
DEFAULT_MODEL_ID = "ibm-granite/granite-timeseries-ttm-r1"


class ForecastToolError(Exception):
    """Custom exception for forecast tool errors."""
    pass


class TimeseriesForecastTool:
    """
    Lightweight wrapper around Granite Time Series models for zero-shot forecasting.

    This wrapper:
    - Loads the model with AutoModel.from_pretrained
    - Validates the input series and forecast horizon
    - Attempts multiple inference methods (predict, then forward with prediction_length)
    - Returns a pandas DataFrame with a 'forecast' column
    - Provides error handling and logging throughout

    Expected input:
    - series: pd.Series with a DatetimeIndex (regular frequency recommended)
    - horizon: int, number of future steps to forecast
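
    Example (illustrative sketch; assumes ``sales`` is a numeric pd.Series with a
    DatetimeIndex and that no AppConfig or Tracer is needed, so both are None):

        tool = TimeseriesForecastTool(cfg=None, tracer=None)
        forecast_df = tool.zeroshot_forecast(sales, horizon=96)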
| """ |
|
|

    def __init__(
        self,
        cfg: Optional[AppConfig],
        tracer: Optional[Tracer],
        model_id: str = DEFAULT_MODEL_ID,
        device: Optional[str] = None,
    ):
        self.cfg = cfg
        self.tracer = tracer
        self.model_id = model_id
        self.model = None
        self.config = None

        # Prefer an explicit device; otherwise use CUDA when available.
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"TimeseriesForecastTool initialized with device: {self.device}")

        # Model loading is deferred until the first forecast call.
        self._initialized = False

    def _ensure_loaded(self):
        """Lazily load the model and configuration."""
        if self._initialized:
            return

        try:
            logger.info(f"Loading Granite time series model: {self.model_id}")

            # The config is optional; failing to load it is not fatal.
            try:
                self.config = AutoConfig.from_pretrained(self.model_id)
                logger.info(f"Model config loaded: {type(self.config).__name__}")
            except Exception as e:
                logger.warning(f"Could not load model config: {e}")
                self.config = None

            # The model itself is required; failures here abort initialization.
            try:
                self.model = AutoModel.from_pretrained(
                    self.model_id,
                    trust_remote_code=True,
                )
                self.model.to(self.device)
                self.model.eval()
                logger.info(f"Model loaded successfully: {type(self.model).__name__}")

            except Exception as e:
                raise ForecastToolError(
                    f"Failed to load model '{self.model_id}': {e}\n"
                    "Ensure the model is available and transformers is up to date."
                ) from e

            self._initialized = True

        except ForecastToolError:
            raise
        except Exception as e:
            raise ForecastToolError(f"Model initialization failed: {e}") from e

    def _validate_series(self, series: pd.Series) -> tuple[bool, str]:
        """
        Validate the input time series.
        Returns (is_valid, error_message).
        """
        if not isinstance(series, pd.Series):
            return False, "Input must be a pandas Series"

        if series.empty:
            return False, "Series is empty"

        if len(series) < MIN_SERIES_LENGTH:
            return False, f"Series too short (min {MIN_SERIES_LENGTH} points required)"

        if len(series) > MAX_SERIES_LENGTH:
            return False, f"Series too long (max {MAX_SERIES_LENGTH} points allowed)"

        # Check the dtype first: the NaN/inf checks below assume numeric data.
        if not pd.api.types.is_numeric_dtype(series):
            return False, f"Series must be numeric, got dtype: {series.dtype}"

        if series.isnull().any():
            null_count = series.isnull().sum()
            return False, f"Series contains {null_count} null values. Please handle missing data first."

        if not np.isfinite(series).all():
            return False, "Series contains infinite values"

        return True, ""

    def _validate_horizon(self, horizon: int) -> tuple[bool, str]:
        """
        Validate the forecast horizon.
        Returns (is_valid, error_message).
        """
        try:
            h = int(horizon)
        except (TypeError, ValueError):
            return False, f"Horizon must be an integer, got: {horizon}"

        if h < MIN_HORIZON:
            return False, f"Horizon too small (min {MIN_HORIZON})"

        if h > MAX_HORIZON:
            return False, f"Horizon too large (max {MAX_HORIZON})"

        return True, ""

    def _prepare_input_tensor(self, series: pd.Series) -> torch.Tensor:
        """
        Convert a pandas Series to a PyTorch tensor.
        Handles type conversion and device placement.
        """
        try:
            # Cast to float32 to match the dtype the model expects.
            values = series.astype("float32").to_numpy()

            tensor = torch.tensor(values, dtype=torch.float32, device=self.device)

            # Add a batch dimension: (seq_len,) -> (1, seq_len).
            tensor = tensor.unsqueeze(0)

            logger.debug(f"Input tensor shape: {tensor.shape}, device: {tensor.device}")

            return tensor

        except Exception as e:
            raise ForecastToolError(f"Failed to prepare input tensor: {e}") from e

    def _try_predict_method(self, x: torch.Tensor, horizon: int) -> Optional[np.ndarray]:
        """
        Try the model's .predict() method.
        Returns None if the method doesn't exist or fails.
        """
        if not hasattr(self.model, "predict"):
            logger.debug("Model has no 'predict' method")
            return None

        try:
            logger.info("Attempting forecast with .predict() method")
            preds = self.model.predict(x, prediction_length=horizon)

            # Some implementations return numpy arrays or lists instead of tensors.
            if not isinstance(preds, torch.Tensor):
                preds = torch.tensor(preds, device=self.device)

            output = preds.squeeze().detach().cpu().numpy()

            if output.shape[-1] != horizon:
                logger.warning(
                    f"Prediction length mismatch: expected {horizon}, got {output.shape[-1]}"
                )

            logger.info(f"Forecast successful via .predict(): {output.shape}")
            return output

        except Exception as e:
            logger.warning(f"predict() method failed: {e}")
            return None

    def _try_forward_method(self, x: torch.Tensor, horizon: int) -> Optional[np.ndarray]:
        """
        Try the model's forward() method with a prediction_length parameter.
        Returns None if the call fails.
        """
        try:
            logger.info("Attempting forecast with forward(prediction_length=...)")
            outputs = self.model(x, prediction_length=horizon)

            prediction_tensor = None

            # Output objects vary across models; look for predictions under common attribute names.
            for attr in ("predictions", "prediction", "logits", "forecast", "output"):
                if hasattr(outputs, attr):
                    candidate = getattr(outputs, attr)

                    if isinstance(candidate, (tuple, list)):
                        candidate = candidate[0]

                    if not isinstance(candidate, torch.Tensor):
                        candidate = torch.tensor(candidate, device=self.device)

                    prediction_tensor = candidate
                    logger.debug(f"Found predictions in attribute: {attr}")
                    break

            # Fall back to the raw output if the model returned a bare tensor.
            if prediction_tensor is None and isinstance(outputs, torch.Tensor):
                prediction_tensor = outputs
                logger.debug("Using raw tensor output")

            if prediction_tensor is None:
                logger.warning("Could not extract predictions from forward() output")
                return None

            output = prediction_tensor.squeeze().detach().cpu().numpy()

            # Reduce any remaining extra dimensions to a 1-D forecast.
            if output.ndim > 1:
                if output.shape[0] == horizon:
                    output = output.flatten()
                else:
                    output = output[-1] if output.shape[0] < output.shape[1] else output.flatten()

            # Align the output length with the requested horizon.
            if len(output) != horizon:
                logger.warning(
                    f"Output length {len(output)} doesn't match horizon {horizon}. Truncating/padding."
                )
                if len(output) > horizon:
                    output = output[:horizon]
                else:
                    output = np.pad(output, (0, horizon - len(output)), mode="edge")

            logger.info(f"Forecast successful via forward(): {output.shape}")
            return output

        except TypeError as e:
            logger.warning(f"forward() doesn't accept prediction_length: {e}")
            return None
        except Exception as e:
            logger.warning(f"forward() method failed: {e}")
            return None

    def zeroshot_forecast(self, series: pd.Series, horizon: int = 96) -> pd.DataFrame:
        """
        Generate a zero-shot forecast for the input time series.

        Args:
            series: Input time series (pd.Series with numeric values)
            horizon: Number of periods to forecast (default: 96)

        Returns:
            DataFrame with a 'forecast' column containing the predictions

        Raises:
            ForecastToolError: If forecasting fails
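
        Example (illustrative sketch; assumes ``tool`` is an initialized
        TimeseriesForecastTool and uses synthetic values as a stand-in for real data):

            series = pd.Series(np.random.rand(512).astype("float32"))
            forecast_df = tool.zeroshot_forecast(series, horizon=24)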
| """ |
| try: |
| |
| is_valid, error_msg = self._validate_series(series) |
| if not is_valid: |
| raise ForecastToolError(f"Invalid series: {error_msg}") |
| |
| is_valid, error_msg = self._validate_horizon(horizon) |
| if not is_valid: |
| raise ForecastToolError(f"Invalid horizon: {error_msg}") |
| |
| |
| self._ensure_loaded() |
| |
| |
| logger.info( |
| f"Forecasting: series_length={len(series)}, " |
| f"horizon={horizon}, " |
| f"series_mean={series.mean():.2f}, " |
| f"series_std={series.std():.2f}" |
| ) |
| |
| |
| x = self._prepare_input_tensor(series) |
| |
| |
| output = None |
| |
| with torch.no_grad(): |
| |
| output = self._try_predict_method(x, horizon) |
| |
| |
| if output is None: |
| output = self._try_forward_method(x, horizon) |
| |
| |
| if output is None: |
| raise ForecastToolError( |
| "Could not generate forecast using available model methods.\n" |
| "The model may not support zero-shot forecasting with this interface.\n" |
| "Suggestions:\n" |
| " • Check model documentation for correct usage\n" |
| " • Ensure transformers library is up to date\n" |
| " • Try a different model or use traditional forecasting (ARIMA, Prophet)\n" |
| f" • Model type: {type(self.model).__name__}" |
| ) |
| |
| |
| result_df = pd.DataFrame({"forecast": output}) |
| |
| |
| logger.info( |
| f"Forecast complete: " |
| f"mean={output.mean():.2f}, " |
| f"std={output.std():.2f}, " |
| f"min={output.min():.2f}, " |
| f"max={output.max():.2f}" |
| ) |
| |
| |
| if self.tracer: |
| self.tracer.trace_event("forecast", { |
| "series_length": len(series), |
| "horizon": horizon, |
| "forecast_mean": float(output.mean()), |
| "forecast_std": float(output.std()) |
| }) |
| |
| return result_df |
| |
| except ForecastToolError: |
| raise |
| except Exception as e: |
| error_msg = f"Forecasting failed unexpectedly: {str(e)}" |
| logger.error(error_msg) |
| if self.tracer: |
| self.tracer.trace_event("forecast_error", {"error": error_msg}) |
| raise ForecastToolError(error_msg) from e |
| |

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the loaded model."""
        self._ensure_loaded()

        return {
            "model_id": self.model_id,
            "model_type": type(self.model).__name__,
            "device": str(self.device),
            "has_predict": hasattr(self.model, "predict"),
            "config": str(self.config) if self.config else None,
        }
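

# ---------------------------------------------------------------------------
# Illustrative usage sketch, not part of the tool itself. Assumes the Granite
# checkpoint can be downloaded and uses a synthetic sine-wave series as a
# stand-in for real data.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Synthetic hourly series with a 24-step seasonal pattern plus noise.
    index = pd.date_range("2024-01-01", periods=512, freq="h")
    values = np.sin(np.arange(512) * 2 * np.pi / 24) + np.random.normal(0, 0.1, 512)
    demo_series = pd.Series(values, index=index)

    tool = TimeseriesForecastTool(cfg=None, tracer=None)
    print(tool.get_model_info())

    forecast_df = tool.zeroshot_forecast(demo_series, horizon=96)
    print(forecast_df.head())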