"""
Time-series forecasting engine built on Granite Time Series (TTM).

Uses the IBM Granite TinyTimeMixer (TTM) model to forecast per-school
indicators, such as student counts, for the next 3-5 years.

Model loading priority:
    1. tsfm_public library (official IBM wrapper)
    2. direct load through transformers
    3. statsforecast AutoARIMA (fallback)

Reference: https://huggingface.co/ibm-granite/granite-timeseries-ttm-r2
"""

from __future__ import annotations

import logging
from dataclasses import dataclass, field
from typing import Any

import numpy as np
import pandas as pd

from src.config import get_settings

logger = logging.getLogger(__name__)

# z-score of the 90th percentile of a standard normal distribution; used to
# derive a symmetric 10th-90th percentile band around the TTM point forecast.
_Z_SCORE_P90 = 1.282


@dataclass
class ForecastResult:
    """Container for a single forecast."""

    schul_code: str
    target_col: str
    forecast_years: list[int]
    point_forecast: list[float]
    lower_bound: list[float]  # 10th percentile
    upper_bound: list[float]  # 90th percentile
    model_version: str
    context_years: list[int] = field(default_factory=list)
    context_values: list[float] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Return the result as a plain dictionary keyed by field name."""
        return {
            "schul_code": self.schul_code,
            "target_col": self.target_col,
            "forecast_years": self.forecast_years,
            "point_forecast": self.point_forecast,
            "lower_bound": self.lower_bound,
            "upper_bound": self.upper_bound,
            "model_version": self.model_version,
            "context_years": self.context_years,
            "context_values": self.context_values,
        }


def _annual_to_monthly(series: pd.Series) -> pd.Series:
    """
    Upsample a year-indexed series to month-start frequency.

    Each observation is placed on January 1st of its own year label, so gaps
    in the observed years are preserved on the time axis and filled by linear
    interpolation instead of being silently compressed into consecutive years.

    Parameters
    ----------
    series:
        Series whose index holds year numbers convertible with ``int``.

    Returns
    -------
    pd.Series
        Float series indexed by month-start timestamps, running from January
        of the first year to January of the last year.
    """
    stamps = pd.DatetimeIndex(
        [pd.Timestamp(year=int(year), month=1, day=1) for year in series.index]
    )
    annual = pd.Series(series.to_numpy(dtype=float), index=stamps).sort_index()
    # Resampling cannot reindex duplicate timestamps; keep the latest
    # observation when a year appears more than once.
    annual = annual[~annual.index.duplicated(keep="last")]
    return annual.resample("MS").interpolate(method="linear")


def _monthly_to_yearly(monthly_values: np.ndarray, agg: str = "last") -> np.ndarray:
    """
    Collapse monthly forecast values into yearly values.

    Parameters
    ----------
    monthly_values:
        Forecast array of shape (n_months,).
    agg:
        "last" takes the final month of each year; any other value
        (conventionally "mean") takes the average of the year.

    Returns
    -------
    np.ndarray
        One value per 12-month chunk. A trailing partial year, if any, is
        aggregated with the same rule as the full years.
    """
    values = np.asarray(monthly_values, dtype=float)

    def _collapse(chunk: np.ndarray) -> float:
        return float(chunk[-1] if agg == "last" else chunk.mean())

    return np.array(
        [_collapse(values[start : start + 12]) for start in range(0, len(values), 12)]
    )


class GranitePredictor:
    """
    Time-series forecaster backed by IBM Granite TTM.

    Example::

        predictor = GranitePredictor()
        result = predictor.predict(
            schul_code="7431234",
            timeseries=pd.Series({2018: 120, 2019: 105, 2020: 98, 2021: 87, 2022: 75}),
            horizon_years=5,
        )
    """

    _MODEL_ID_TTM = "ibm-granite/granite-timeseries-ttm-r2"
    _MODEL_VERSION_TTM = "granite-ttm-r2"
    _MODEL_VERSION_ARIMA = "arima-fallback"

    def __init__(self) -> None:
        """Read model settings from the project configuration; no model is loaded yet."""
        cfg = get_settings()
        self._model_id = cfg.granite_model_id
        self._context_length = cfg.prediction_context_length  # in months
        self._horizon_months = cfg.prediction_horizon_months
        self._hf_token = cfg.huggingface_hub_token or None
        self._pipeline: Any = None  # lazy load
        self._model_version: str = ""

    # ── Model initialization ─────────────────────────────────────────────

    def _load_ttm(self) -> bool:
        """
        Try to load the TTM model through tsfm_public, then through transformers.

        Returns
        -------
        bool
            True when either loader succeeded (``self._pipeline`` and
            ``self._model_version`` are then set), False otherwise.
        """
        # First choice: tsfm_public
        try:
            from tsfm_public.models.tinytimemixer import TinyTimeMixerForPrediction  # type: ignore[import]
            from tsfm_public.toolkit.time_series_forecasting_pipeline import (  # type: ignore[import]
                TimeSeriesForecastingPipeline,
            )

            model = TinyTimeMixerForPrediction.from_pretrained(
                self._model_id,
                token=self._hf_token,
            )
            self._pipeline = TimeSeriesForecastingPipeline(
                model=model,
                context_length=self._context_length,
                prediction_length=self._horizon_months,
            )
            self._model_version = self._MODEL_VERSION_TTM
            logger.info("TTM 모델 로드 완료 (tsfm_public): %s", self._model_id)
            return True
        except Exception as exc:  # noqa: BLE001
            # Best effort: any failure (missing package, download error, ...)
            # simply moves on to the next loader.
            logger.warning("tsfm_public 로드 실패: %s", exc)

        # Second choice: transformers AutoModel
        try:
            from transformers import AutoConfig, AutoModel  # type: ignore[import]

            config = AutoConfig.from_pretrained(self._model_id, token=self._hf_token)
            model = AutoModel.from_pretrained(self._model_id, config=config, token=self._hf_token)
            model.eval()
            self._pipeline = model
            self._model_version = self._MODEL_VERSION_TTM + "-raw"
            logger.info("TTM 모델 로드 완료 (transformers): %s", self._model_id)
            return True
        except Exception as exc:  # noqa: BLE001
            logger.warning("transformers 로드 실패: %s", exc)

        return False

    def _ensure_model(self) -> None:
        """Load the model on first use; remember the ARIMA fallback if loading fails."""
        if self._pipeline is not None or self._model_version == self._MODEL_VERSION_ARIMA:
            return
        if not self._load_ttm():
            logger.warning("TTM 로드 실패. statsforecast AutoARIMA fallback 사용.")
            self._model_version = self._MODEL_VERSION_ARIMA

    # ── Internal forecasting methods ─────────────────────────────────────

    def _predict_ttm(
        self,
        series: pd.Series,
        horizon_months: int,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Forecast with the loaded TTM pipeline/model.

        Parameters
        ----------
        series:
            Year-indexed observations.
        horizon_months:
            Number of monthly steps to return.

        Returns
        -------
        (point, lower, upper) — each an ndarray of shape (horizon_months,).
        """
        # Upsample the annual data to monthly (linear interpolation) and keep
        # the most recent ``context_length`` months.
        # NOTE(review): when fewer months are available the context is shorter
        # than ``context_length``; presumably the model rejects it and the
        # caller falls back to ARIMA — confirm.
        monthly = _annual_to_monthly(series)
        context = monthly.to_numpy()[-self._context_length :].astype(np.float32)
        model_input = context.reshape(1, -1, 1)  # (batch=1, seq_len, channels=1)

        import torch  # lazy import

        with torch.no_grad():
            tensor_input = torch.tensor(model_input)
            try:
                # Call in the style of the tsfm_public TimeSeriesForecastingPipeline.
                output = self._pipeline(tensor_input)
                if hasattr(output, "prediction_outputs"):
                    raw = output.prediction_outputs
                elif hasattr(output, "last_hidden_state"):
                    raw = output.last_hidden_state
                else:
                    raw = output
                preds = raw.squeeze().numpy()
            except TypeError:
                # Direct call of a raw transformers model.
                # NOTE(review): a hidden state is not a forecast head output;
                # verify that this path yields meaningful values.
                output = self._pipeline(inputs_embeds=tensor_input)
                preds = output.last_hidden_state.mean(dim=-1).squeeze().numpy()

        # Normalise to a 1-D float array: squeeze() yields a 0-d array for a
        # single step, and a hidden state may still carry a feature axis,
        # which is averaged away as in the raw-model branch above.
        preds = np.atleast_1d(np.asarray(preds, dtype=float))
        if preds.ndim > 1:
            preds = preds.reshape(preds.shape[0], -1).mean(axis=1)

        # Fit the requested horizon: repeat the last value if the model
        # produced fewer steps, truncate if it produced more.
        if len(preds) < horizon_months:
            preds = np.pad(preds, (0, horizon_months - len(preds)), mode="edge")
        preds = np.maximum(preds[:horizon_months], 0)

        # Uncertainty band: +/- 1.282 standard deviations of the context
        # window (10th/90th percentile under a normal assumption). Falls back
        # to 10% of the mean forecast, then to 1.0, when the spread is zero.
        std = float(np.std(context)) or float(np.abs(preds.mean()) * 0.1) or 1.0
        lower = np.maximum(preds - _Z_SCORE_P90 * std, 0)  # 10th percentile
        upper = preds + _Z_SCORE_P90 * std  # 90th percentile

        return preds, lower, upper

    def _predict_arima(
        self,
        series: pd.Series,
        horizon_months: int,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Forecast with the statsforecast AutoARIMA fallback.

        Parameters
        ----------
        series:
            Year-indexed observations.
        horizon_months:
            Number of monthly steps to forecast.

        Returns
        -------
        (point, lower, upper) — each an ndarray of shape (horizon_months,).

        Raises
        ------
        ImportError
            If statsforecast is not installed.
        """
        try:
            from statsforecast import StatsForecast  # type: ignore[import]
            from statsforecast.models import AutoARIMA  # type: ignore[import]
        except ImportError as exc:
            raise ImportError("statsforecast 가 설치되지 않았습니다: pip install statsforecast") from exc

        # Annual -> monthly upsampling
        monthly = _annual_to_monthly(series)

        sf_df = pd.DataFrame({
            "unique_id": "school",
            "ds": monthly.index,
            "y": monthly.values,
        })

        sf = StatsForecast(models=[AutoARIMA(season_length=12)], freq="MS")
        forecast = sf.forecast(df=sf_df, h=horizon_months, level=[80])

        point = np.maximum(forecast["AutoARIMA"].values, 0)
        # Use the 80% interval columns when present; otherwise +/-10% of the
        # point forecast.
        lower = np.maximum(forecast.get("AutoARIMA-lo-80", pd.Series(point * 0.9)).values, 0)
        upper = forecast.get("AutoARIMA-hi-80", pd.Series(point * 1.1)).values

        return point, lower, upper

    # ── Public interface ─────────────────────────────────────────────────

    def predict(
        self,
        schul_code: str,
        timeseries: pd.Series,
        horizon_years: int = 5,
        target_col: str = "student_count",
    ) -> ForecastResult:
        """
        Forecast future values from a school's historical time series.

        Parameters
        ----------
        schul_code:
            SD_SCHUL_CODE of the target school.
        timeseries:
            Series indexed by year (int). At least 2 non-null observations are
            required; 3 or more are recommended.
        horizon_years:
            Forecast horizon in years (default 5); must be at least 1.
        target_col:
            Name of the forecast target column (used only as a result label).

        Returns
        -------
        ForecastResult

        Raises
        ------
        ValueError
            If fewer than 2 non-null observations are given, or if
            ``horizon_years`` is less than 1.
        """
        observed = timeseries.dropna()
        if len(observed) < 2:
            raise ValueError(
                f"학교 '{schul_code}': 예측에 필요한 시계열 데이터가 부족합니다 "
                f"(최소 2개 관측값 필요, 현재 {len(observed)}개)."
            )
        if horizon_years < 1:
            raise ValueError(
                f"horizon_years 는 1 이상이어야 합니다 (입력값: {horizon_years})."
            )

        ts = observed.sort_index().astype(float)
        horizon_months = horizon_years * 12

        self._ensure_model()

        forecast: tuple[np.ndarray, np.ndarray, np.ndarray] | None = None
        model_ver = self._MODEL_VERSION_ARIMA

        ttm_available = (
            self._model_version != self._MODEL_VERSION_ARIMA and self._pipeline is not None
        )
        if ttm_available:
            # Only the TTM call is guarded: an ARIMA failure must surface
            # directly rather than being retried and mislabelled as a TTM error.
            try:
                forecast = self._predict_ttm(ts, horizon_months)
                model_ver = self._model_version
            except Exception as exc:  # noqa: BLE001
                logger.warning("TTM 예측 실패, ARIMA fallback 사용: %s", exc)

        if forecast is None:
            forecast = self._predict_arima(ts, horizon_months)
            model_ver = self._MODEL_VERSION_ARIMA

        point_m, lower_m, upper_m = forecast

        # Monthly -> yearly aggregation
        point_y = _monthly_to_yearly(point_m)
        lower_y = _monthly_to_yearly(lower_m)
        upper_y = _monthly_to_yearly(upper_m)

        last_year = int(ts.index.max())
        forecast_years = list(range(last_year + 1, last_year + 1 + len(point_y)))

        logger.info(
            "예측 완료: schul_code=%s model=%s horizon=%dy point_mean=%.1f",
            schul_code,
            model_ver,
            horizon_years,
            float(np.mean(point_y)),
        )

        return ForecastResult(
            schul_code=schul_code,
            target_col=target_col,
            forecast_years=forecast_years,
            point_forecast=[round(float(v), 1) for v in point_y],
            lower_bound=[round(float(v), 1) for v in lower_y],
            upper_bound=[round(float(v), 1) for v in upper_y],
            model_version=model_ver,
            context_years=[int(y) for y in ts.index.tolist()],
            context_values=[round(float(v), 1) for v in ts.tolist()],
        )