"""
Time-series forecasting engine based on Granite Time Series (TTM).

Uses the IBM Granite TinyTimeMixer (TTM) model to forecast per-school
indicators, such as student counts, for the next 3-5 years.

Model loading priority:
    1. tsfm_public library (official IBM wrapper)
    2. Direct load through transformers
    3. statsforecast AutoARIMA (fallback)

Reference:
    https://huggingface.co/ibm-granite/granite-timeseries-ttm-r2
"""
|
|
| from __future__ import annotations |
|
|
| import logging |
| from dataclasses import dataclass, field |
| from typing import Any |
|
|
| import numpy as np |
| import pandas as pd |
|
|
| from src.config import get_settings |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
@dataclass
class ForecastResult:
    """Container for the outcome of one forecast run."""

    # What was forecast: the school identifier and the indicator label.
    schul_code: str
    target_col: str
    # Forecast horizon with the point estimate and its bounds.
    forecast_years: list[int]
    point_forecast: list[float]
    lower_bound: list[float]
    upper_bound: list[float]
    # Identifier of the model that produced the forecast.
    model_version: str
    # Historical observations the forecast was based on (optional).
    context_years: list[int] = field(default_factory=list)
    context_values: list[float] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Return all fields as a plain dict keyed by field name.

        The list values are the same objects held by the instance, not copies.
        """
        exported = (
            "schul_code",
            "target_col",
            "forecast_years",
            "point_forecast",
            "lower_bound",
            "upper_bound",
            "model_version",
            "context_years",
            "context_values",
        )
        return {name: getattr(self, name) for name in exported}
|
|
|
|
def _monthly_to_yearly(monthly_values: np.ndarray, agg: str = "last") -> np.ndarray:
    """
    Collapse a monthly series into one value per year.

    The input is cut into consecutive 12-month windows starting at index 0.
    A trailing window shorter than 12 months is treated as a partial year and
    aggregated with the same rule as the full years.

    Parameters
    ----------
    monthly_values:
        1-D array-like of monthly values, shape (n_months,).
    agg:
        "last" takes the final month of each window (year-end basis);
        any other value takes the window mean.

    Returns
    -------
    np.ndarray
        Float array with one entry per (possibly partial) year; empty when
        the input is empty.
    """
    # Accept plain lists as well as ndarrays; the original dtype is kept so
    # the per-window arithmetic is unchanged for ndarray callers.
    values = np.asarray(monthly_values)
    take_last = agg == "last"

    yearly = []
    for start in range(0, len(values), 12):
        window = values[start : start + 12]
        # The partial trailing window honours `agg` as well, so "last" never
        # silently switches to a mean for the final year.
        yearly.append(float(window[-1] if take_last else window.mean()))
    return np.array(yearly, dtype=float)
|
|
|
|
class GranitePredictor:
    """
    Time-series forecaster built on IBM Granite TTM.

    The model is loaded lazily on the first ``predict`` call. When neither
    ``tsfm_public`` nor ``transformers`` can load it, or when TTM inference
    fails, forecasting falls back to statsforecast AutoARIMA.

    Example::

        predictor = GranitePredictor()
        result = predictor.predict(
            schul_code="7431234",
            timeseries=pd.Series({2018: 120, 2019: 105, 2020: 98, 2021: 87, 2022: 75}),
            horizon_years=5,
        )
    """

    # NOTE(review): the non-ASCII log/error message text in this class appears
    # to be mis-encoded; restore it from the original source encoding.

    _MODEL_ID_TTM = "ibm-granite/granite-timeseries-ttm-r2"
    _MODEL_VERSION_TTM = "granite-ttm-r2"
    _MODEL_VERSION_ARIMA = "arima-fallback"

    def __init__(self) -> None:
        """Read model settings from the project configuration; no model is loaded here."""
        cfg = get_settings()
        self._model_id = cfg.granite_model_id
        self._context_length = cfg.prediction_context_length
        self._horizon_months = cfg.prediction_horizon_months
        # An empty token is normalised to None.
        self._hf_token = cfg.huggingface_hub_token or None
        # Loaded model/pipeline object; stays None until a TTM load succeeds.
        self._pipeline: Any = None
        # Empty until the first load attempt has been made.
        self._model_version: str = ""

    def _load_ttm(self) -> bool:
        """
        Try to load the TTM model, first via tsfm_public, then via transformers.

        Returns
        -------
        bool
            True when one of the two loaders succeeded and ``self._pipeline``
            is set; False when both failed (each failure is logged).
        """
        # Attempt 1: the tsfm_public wrapper.
        try:
            from tsfm_public.models.tinytimemixer import TinyTimeMixerForPrediction
            from tsfm_public.toolkit.time_series_forecasting_pipeline import (
                TimeSeriesForecastingPipeline,
            )

            model = TinyTimeMixerForPrediction.from_pretrained(
                self._model_id,
                token=self._hf_token,
            )
            # NOTE(review): `_predict_ttm` calls this object with a raw tensor;
            # confirm the pipeline wrapper accepts that, otherwise this path
            # always ends in the ARIMA fallback.
            self._pipeline = TimeSeriesForecastingPipeline(
                model=model,
                context_length=self._context_length,
                prediction_length=self._horizon_months,
            )
            self._model_version = self._MODEL_VERSION_TTM
            logger.info("TTM λͺ¨λΈ λ‘λ μλ£ (tsfm_public): %s", self._model_id)
            return True
        except Exception as exc:
            # Best effort: any import or load problem moves on to the next loader.
            logger.warning("tsfm_public λ‘λ μ€ν¨: %s", exc)

        # Attempt 2: load the raw model directly through transformers.
        try:
            from transformers import AutoConfig, AutoModel

            config = AutoConfig.from_pretrained(self._model_id, token=self._hf_token)
            model = AutoModel.from_pretrained(self._model_id, config=config, token=self._hf_token)
            model.eval()
            self._pipeline = model
            self._model_version = self._MODEL_VERSION_TTM + "-raw"
            logger.info("TTM λͺ¨λΈ λ‘λ μλ£ (transformers): %s", self._model_id)
            return True
        except Exception as exc:
            logger.warning("transformers λ‘λ μ€ν¨: %s", exc)

        return False

    def _ensure_model(self) -> None:
        """Load the model once; remember a failed load so it is not retried."""
        if self._pipeline is not None or self._model_version == self._MODEL_VERSION_ARIMA:
            return
        if not self._load_ttm():
            logger.warning("TTM λ‘λ μ€ν¨. statsforecast AutoARIMA fallback μ¬μ©.")
            self._model_version = self._MODEL_VERSION_ARIMA

    @staticmethod
    def _annual_to_monthly(series: pd.Series) -> pd.Series:
        """
        Upsample a year-indexed series to month-start frequency.

        Each observation is placed on 1 January of its own year and the months
        in between are filled by linear interpolation. Using the actual year
        labels, rather than assuming consecutive years, keeps the timeline
        correct when a year is missing from the input.
        """
        stamps = pd.to_datetime([f"{int(year):04d}-01-01" for year in series.index])
        annual = pd.Series(series.values, index=stamps, dtype=float).sort_index()
        return annual.resample("MS").interpolate(method="linear")

    def _predict_ttm(
        self,
        series: pd.Series,
        horizon_months: int,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Run a forecast through the loaded TTM model/pipeline.

        Returns
        -------
        (point, lower, upper), each an ndarray of shape (horizon_months,).

        Raises
        ------
        ValueError
            If the model output cannot be reduced to a non-empty 1-D forecast.
        """
        # torch is imported lazily so the module works without it (ARIMA path).
        import torch

        monthly = self._annual_to_monthly(series)

        # Keep at most the last `context_length` months, shaped (1, n_months, 1).
        context = monthly.values[-self._context_length :].astype(np.float32)
        context = context.reshape(1, -1, 1)

        with torch.no_grad():
            tensor_input = torch.tensor(context)
            try:
                output = self._pipeline(tensor_input)
                # Pick whichever output attribute the loaded object provides.
                if hasattr(output, "prediction_outputs"):
                    raw = output.prediction_outputs
                elif hasattr(output, "last_hidden_state"):
                    raw = output.last_hidden_state
                else:
                    raw = output
                preds = raw.squeeze().numpy()
            except TypeError:
                # Retry with the keyword form if the positional call is rejected.
                output = self._pipeline(inputs_embeds=tensor_input)
                preds = output.last_hidden_state.mean(dim=-1).squeeze().numpy()

        # A single-step output squeezes to 0-d; anything that is not a plain
        # 1-D forecast is rejected here so the caller can fall back cleanly.
        preds = np.atleast_1d(preds)
        if preds.ndim != 1 or preds.size == 0:
            raise ValueError(
                f"unexpected TTM output shape {preds.shape}; expected a non-empty 1-D forecast"
            )

        # Fit the forecast to the requested horizon: repeat the last value when
        # the model returned fewer steps, truncate when it returned more.
        if len(preds) < horizon_months:
            preds = np.pad(preds, (0, horizon_months - len(preds)), mode="edge")
        preds = preds[:horizon_months]
        preds = np.maximum(preds, 0)

        # Heuristic band: +/- 1.282 * std of the context window, with
        # fallbacks for a zero std (flat history) so the band never collapses.
        std = float(np.std(context)) or float(np.abs(preds.mean()) * 0.1) or 1.0
        lower = np.maximum(preds - 1.282 * std, 0)
        upper = preds + 1.282 * std

        return preds, lower, upper

    def _predict_arima(
        self,
        series: pd.Series,
        horizon_months: int,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Fallback forecast using statsforecast AutoARIMA.

        Returns
        -------
        (point, lower, upper), each an ndarray of shape (horizon_months,).

        Raises
        ------
        ImportError
            If statsforecast is not installed.
        """
        try:
            from statsforecast import StatsForecast
            from statsforecast.models import AutoARIMA
        except ImportError as exc:
            raise ImportError("statsforecast κ° μ€μΉλμ§ μμμ΅λλ€: pip install statsforecast") from exc

        monthly = self._annual_to_monthly(series)

        # Long-format frame with a single series id.
        sf_df = pd.DataFrame({
            "unique_id": "school",
            "ds": monthly.index,
            "y": monthly.values,
        })

        sf = StatsForecast(models=[AutoARIMA(season_length=12)], freq="MS")
        forecast = sf.forecast(df=sf_df, h=horizon_months, level=[80])

        # Point and lower values are clipped at zero; if the interval columns
        # are absent, a +/-10% band around the point forecast is used instead.
        point = np.maximum(forecast["AutoARIMA"].values, 0)
        lower = np.maximum(forecast.get("AutoARIMA-lo-80", pd.Series(point * 0.9)).values, 0)
        upper = forecast.get("AutoARIMA-hi-80", pd.Series(point * 1.1)).values

        return point, lower, upper

    def predict(
        self,
        schul_code: str,
        timeseries: pd.Series,
        horizon_years: int = 5,
        target_col: str = "student_count",
    ) -> "ForecastResult":
        """
        Forecast future yearly values for one school.

        Parameters
        ----------
        schul_code:
            SD_SCHUL_CODE of the target school.
        timeseries:
            Series indexed by year (int). NaN entries are dropped. At least
            2 observations are required; 3 or more are recommended.
        horizon_years:
            Forecast length in years (default 5); must be at least 1.
        target_col:
            Name of the forecast target column (used only as a result label).

        Returns
        -------
        ForecastResult

        Raises
        ------
        ValueError
            If fewer than 2 non-NaN observations are given, or if
            ``horizon_years`` is smaller than 1.
        """
        observed = timeseries.dropna()
        if len(observed) < 2:
            raise ValueError(
                f"νκ΅ '{schul_code}': μμΈ‘μ νμν μκ³μ΄ λ°μ΄ν°κ° λΆμ‘±ν©λλ€ "
                f"(μ΅μ 2κ° κ΄μΈ‘κ° νμ, νμ¬ {len(observed)}κ°)."
            )
        if horizon_years < 1:
            raise ValueError(f"horizon_years must be at least 1, got {horizon_years}.")

        ts = observed.sort_index().astype(float)
        horizon_months = horizon_years * 12

        self._ensure_model()

        # Only the TTM attempt is guarded: an ARIMA failure must surface as-is
        # rather than being reported as a TTM failure and run a second time.
        forecast = None
        if self._model_version != self._MODEL_VERSION_ARIMA and self._pipeline is not None:
            try:
                forecast = self._predict_ttm(ts, horizon_months)
            except Exception as exc:
                logger.warning("TTM μμΈ‘ μ€ν¨, ARIMA fallback μ¬μ©: %s", exc)

        if forecast is not None:
            model_ver = self._model_version
        else:
            forecast = self._predict_arima(ts, horizon_months)
            model_ver = self._MODEL_VERSION_ARIMA
        point_m, lower_m, upper_m = forecast

        # Reduce the monthly forecast to one value per year.
        point_y = _monthly_to_yearly(point_m)
        lower_y = _monthly_to_yearly(lower_m)
        upper_y = _monthly_to_yearly(upper_m)

        last_year = int(ts.index.max())
        forecast_years = list(range(last_year + 1, last_year + 1 + len(point_y)))

        logger.info(
            "μμΈ‘ μλ£: schul_code=%s model=%s horizon=%dy point_mean=%.1f",
            schul_code,
            model_ver,
            horizon_years,
            float(np.mean(point_y)),
        )

        return ForecastResult(
            schul_code=schul_code,
            target_col=target_col,
            forecast_years=forecast_years,
            point_forecast=[round(float(v), 1) for v in point_y],
            lower_bound=[round(float(v), 1) for v in lower_y],
            upper_bound=[round(float(v), 1) for v in upper_y],
            model_version=model_ver,
            context_years=[int(y) for y in ts.index.tolist()],
            context_values=[round(float(v), 1) for v in ts.tolist()],
        )
|
|