kr4phy's picture
Sync from GitHub
cff6ac7
Raw
History Blame Contribute Delete
12.5 kB
"""
Granite Time Series (TTM) 기반 μ‹œκ³„μ—΄ 예츑 μ—”μ§„.
IBM Granite TinyTimeMixer (TTM) λͺ¨λΈμ„ μ‚¬μš©ν•˜μ—¬ 학ꡐ별
학생 수 λ“± μ§€ν‘œμ˜ ν–₯ν›„ 3~5년을 μ˜ˆμΈ‘ν•©λ‹ˆλ‹€.
λͺ¨λΈ λ‘œλ“œ μš°μ„ μˆœμœ„:
1. tsfm_public 라이브러리 (IBM 곡식 래퍼)
2. transformers 직접 λ‘œλ“œ
3. statsforecast AutoARIMA (fallback)
μ°Έκ³ :
https://huggingface.co/ibm-granite/granite-timeseries-ttm-r2
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from typing import Any
import numpy as np
import pandas as pd
from src.config import get_settings
logger = logging.getLogger(__name__)
@dataclass
class ForecastResult:
    """Container for one forecast produced for a single school.

    ``forecast_years``, ``point_forecast``, ``lower_bound`` and
    ``upper_bound`` are per-year lists; ``context_years`` /
    ``context_values`` hold the observed history the forecast was based on.
    """

    schul_code: str  # school identifier (SD_SCHUL_CODE)
    target_col: str  # label of the forecast metric
    forecast_years: list[int]
    point_forecast: list[float]
    lower_bound: list[float]  # 10th percentile
    upper_bound: list[float]  # 90th percentile
    model_version: str
    context_years: list[int] = field(default_factory=list)
    context_values: list[float] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Return the fields as a plain dict, in declaration order.

        List values are returned as-is (not copied).
        """
        keys = (
            "schul_code",
            "target_col",
            "forecast_years",
            "point_forecast",
            "lower_bound",
            "upper_bound",
            "model_version",
            "context_years",
            "context_values",
        )
        return {key: getattr(self, key) for key in keys}
def _monthly_to_yearly(monthly_values: np.ndarray, agg: str = "last") -> np.ndarray:
"""
월별 μ˜ˆμΈ‘κ°’μ„ μ—°κ°„ μ§‘κ³„κ°’μœΌλ‘œ λ³€ν™˜ν•©λ‹ˆλ‹€.
Parameters
----------
monthly_values:
shape (n_months,) 의 μ˜ˆμΈ‘κ°’ λ°°μ—΄.
agg:
"last" β†’ 연말 κΈ°μ€€, "mean" β†’ 연평균.
"""
n_years = len(monthly_values) // 12
remainder = len(monthly_values) % 12
result = []
for i in range(n_years):
year_slice = monthly_values[i * 12 : (i + 1) * 12]
result.append(float(year_slice[-1] if agg == "last" else year_slice.mean()))
if remainder > 0:
result.append(float(monthly_values[n_years * 12 :].mean()))
return np.array(result)
class GranitePredictor:
    """
    Time-series forecaster built on IBM Granite TinyTimeMixer (TTM).

    Annual observations are linearly up-sampled to a month-start grid,
    forecast month by month, and aggregated back to one value per year.
    If no TTM model can be loaded, or a TTM forecast raises, statsforecast
    AutoARIMA is used instead.

    Example::

        predictor = GranitePredictor()
        result = predictor.predict(
            schul_code="7431234",
            timeseries=pd.Series({2018: 120, 2019: 105, 2020: 98, 2021: 87, 2022: 75}),
            horizon_years=5,
        )
    """

    _MODEL_ID_TTM = "ibm-granite/granite-timeseries-ttm-r2"
    _MODEL_VERSION_TTM = "granite-ttm-r2"
    _MODEL_VERSION_ARIMA = "arima-fallback"
    # Standard-normal z-score of the 90th percentile; the TTM band is
    # point ± _Z_90 * std, i.e. roughly the 10th..90th percentile.
    _Z_90 = 1.282

    def __init__(self) -> None:
        """Read model/forecast settings; the model itself is loaded lazily."""
        cfg = get_settings()
        # Fall back to the published TTM r2 checkpoint when no model id is configured.
        self._model_id = cfg.granite_model_id or self._MODEL_ID_TTM
        self._context_length = cfg.prediction_context_length  # in months
        # Kept for compatibility; the raw forecast length comes from the model output.
        self._horizon_months = cfg.prediction_horizon_months
        self._hf_token = cfg.huggingface_hub_token or None
        # Loaded TTM model, or None until _ensure_model() succeeds
        # (attribute name kept for compatibility).
        self._pipeline: Any = None
        self._model_version: str = ""

    # -- Model initialisation ---------------------------------------------

    def _load_ttm(self) -> bool:
        """
        Try to load the TTM model, first via tsfm_public, then via transformers.

        The loaded model is stored in ``self._pipeline`` and is called
        directly with a ``past_values`` tensor by ``_predict_ttm``.

        Returns
        -------
        bool
            True if a model was loaded, False if both loaders failed.
        """
        # 1st choice: tsfm_public (IBM's official TTM implementation).
        try:
            from tsfm_public.models.tinytimemixer import TinyTimeMixerForPrediction  # type: ignore[import]

            model = TinyTimeMixerForPrediction.from_pretrained(
                self._model_id,
                token=self._hf_token,
            )
            model.eval()
            # Keep the model itself: tsfm's TimeSeriesForecastingPipeline consumes
            # pandas DataFrames, whereas _predict_ttm feeds a raw tensor.
            self._pipeline = model
            self._model_version = self._MODEL_VERSION_TTM
            logger.info("TTM λͺ¨λΈ λ‘œλ“œ μ™„λ£Œ (tsfm_public): %s", self._model_id)
            return True
        except Exception as exc:  # noqa: BLE001
            logger.warning("tsfm_public λ‘œλ“œ μ‹€νŒ¨: %s", exc)

        # 2nd choice: plain transformers. If the loaded model only returns hidden
        # states (no prediction head), _predict_ttm raises and predict() uses ARIMA.
        try:
            from transformers import AutoConfig, AutoModel  # type: ignore[import]

            config = AutoConfig.from_pretrained(self._model_id, token=self._hf_token)
            model = AutoModel.from_pretrained(self._model_id, config=config, token=self._hf_token)
            model.eval()
            self._pipeline = model
            self._model_version = self._MODEL_VERSION_TTM + "-raw"
            logger.info("TTM λͺ¨λΈ λ‘œλ“œ μ™„λ£Œ (transformers): %s", self._model_id)
            return True
        except Exception as exc:  # noqa: BLE001
            logger.warning("transformers λ‘œλ“œ μ‹€νŒ¨: %s", exc)
        return False

    def _ensure_model(self) -> None:
        """Load the model on first use; switch to ARIMA for good if loading fails."""
        if self._pipeline is not None or self._model_version == self._MODEL_VERSION_ARIMA:
            return
        if not self._load_ttm():
            logger.warning("TTM λ‘œλ“œ μ‹€νŒ¨. statsforecast AutoARIMA fallback μ‚¬μš©.")
            self._model_version = self._MODEL_VERSION_ARIMA

    # -- Internal forecasting helpers -------------------------------------

    @staticmethod
    def _to_monthly(series: pd.Series) -> pd.Series:
        """
        Up-sample an annual series (indexed by year) to month-start frequency.

        Each value is placed on January 1 of its *actual* year, so gaps
        between observed years are filled by linear interpolation instead of
        shifting later observations onto earlier years. Duplicate years are
        averaged.
        """
        idx = pd.DatetimeIndex([pd.Timestamp(year=int(y), month=1, day=1) for y in series.index])
        annual = pd.Series(np.asarray(series.values, dtype=float), index=idx)
        # The resampler needs a unique, sorted index.
        annual = annual.groupby(level=0).mean()
        return annual.resample("MS").interpolate(method="linear")

    def _predict_ttm(
        self,
        series: pd.Series,
        horizon_months: int,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Forecast with the loaded TTM model.

        Parameters
        ----------
        series:
            Annual observations indexed by year.
        horizon_months:
            Number of monthly steps to return.

        Returns
        -------
        (point, lower, upper) — each an ndarray of shape (horizon_months,).

        Raises
        ------
        RuntimeError
            If the model output carries no forecast (e.g. a bare backbone
            that only returns hidden states) or the forecast is empty.
        """
        import torch  # lazy import: only needed on the TTM path

        monthly = self._to_monthly(series)
        # Most recent context_length months (setting is in months).
        observed = monthly.to_numpy(dtype=np.float32)[-self._context_length :]

        # TTM checkpoints accept exactly ``config.context_length`` input steps.
        required = getattr(getattr(self._pipeline, "config", None), "context_length", None)
        model_kwargs: dict[str, Any] = {}
        context = observed
        if isinstance(required, int) and required > 0:
            observed = observed[-required:]
            context = observed
            n_pad = required - len(observed)
            if n_pad > 0:
                # Left-pad with the oldest value and mark padded steps as unobserved.
                context = np.concatenate([np.full(n_pad, observed[0], dtype=np.float32), observed])
                mask = np.concatenate([np.zeros(n_pad, dtype=bool), np.ones(len(observed), dtype=bool)])
                model_kwargs["past_observed_mask"] = torch.tensor(mask.reshape(1, -1, 1))

        with torch.no_grad():
            past_values = torch.tensor(context.reshape(1, -1, 1))  # (batch=1, seq_len, channels=1)
            output = self._pipeline(past_values=past_values, **model_kwargs)

        forecast = getattr(output, "prediction_outputs", None)
        if forecast is None and isinstance(output, torch.Tensor):
            forecast = output
        if forecast is None:
            # Hidden states are features, not forecasts; never report them as predictions.
            raise RuntimeError("TTM model output has no 'prediction_outputs'; cannot forecast")
        preds = forecast.detach().cpu().numpy().reshape(-1).astype(float)
        if preds.size == 0:
            raise RuntimeError("TTM model returned an empty forecast")

        # Match the requested horizon: hold the last value flat if the model's
        # prediction length is shorter, truncate if longer.
        if len(preds) < horizon_months:
            preds = np.pad(preds, (0, horizon_months - len(preds)), mode="edge")
        preds = np.maximum(preds[:horizon_months], 0)

        # Heuristic band: point ± 1.282 * std of the observed context
        # (10th/90th percentile under a normal assumption).
        std = float(np.std(observed)) or float(np.abs(preds.mean()) * 0.1) or 1.0
        lower = np.maximum(preds - self._Z_90 * std, 0)
        upper = preds + self._Z_90 * std
        return preds, lower, upper

    def _predict_arima(
        self,
        series: pd.Series,
        horizon_months: int,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Fallback forecast with statsforecast AutoARIMA.

        Returns
        -------
        (point, lower, upper) — each an ndarray of shape (horizon_months,);
        the bounds are AutoARIMA's 80% interval (or ±10% of the point
        forecast if those columns are absent).

        Raises
        ------
        ImportError
            If statsforecast is not installed.
        """
        try:
            from statsforecast import StatsForecast  # type: ignore[import]
            from statsforecast.models import AutoARIMA  # type: ignore[import]
        except ImportError as exc:
            raise ImportError("statsforecast κ°€ μ„€μΉ˜λ˜μ§€ μ•Šμ•˜μŠ΅λ‹ˆλ‹€: pip install statsforecast") from exc

        monthly = self._to_monthly(series)
        sf_df = pd.DataFrame({
            "unique_id": "school",
            "ds": monthly.index,
            "y": monthly.values,
        })
        sf = StatsForecast(models=[AutoARIMA(season_length=12)], freq="MS")
        forecast = sf.forecast(df=sf_df, h=horizon_months, level=[80])

        point = np.maximum(forecast["AutoARIMA"].to_numpy(dtype=float), 0)
        lo_col, hi_col = "AutoARIMA-lo-80", "AutoARIMA-hi-80"
        lower = forecast[lo_col].to_numpy(dtype=float) if lo_col in forecast else point * 0.9
        upper = forecast[hi_col].to_numpy(dtype=float) if hi_col in forecast else point * 1.1
        return point, np.maximum(lower, 0), upper

    def _forecast_monthly(
        self,
        ts: pd.Series,
        horizon_months: int,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, str]:
        """Run TTM when available, else (or on TTM failure) AutoARIMA; also return the model version."""
        if self._model_version != self._MODEL_VERSION_ARIMA and self._pipeline is not None:
            try:
                point, lower, upper = self._predict_ttm(ts, horizon_months)
            except Exception as exc:  # noqa: BLE001
                logger.warning("TTM 예츑 μ‹€νŒ¨, ARIMA fallback μ‚¬μš©: %s", exc)
            else:
                return point, lower, upper, self._model_version
        # ARIMA errors propagate: there is no further fallback.
        point, lower, upper = self._predict_arima(ts, horizon_months)
        return point, lower, upper, self._MODEL_VERSION_ARIMA

    # -- Public interface -------------------------------------------------

    def predict(
        self,
        schul_code: str,
        timeseries: pd.Series,
        horizon_years: int = 5,
        target_col: str = "student_count",
    ) -> ForecastResult:
        """
        Forecast future values from a school's time series.

        Parameters
        ----------
        schul_code:
            Target school's SD_SCHUL_CODE.
        timeseries:
            Series indexed by year (int). At least 3 observations are
            recommended; at least 2 non-NaN values are required.
        horizon_years:
            Forecast horizon in years (default 5); must be >= 1.
        target_col:
            Name of the forecast metric (used only as a result label).

        Returns
        -------
        ForecastResult

        Raises
        ------
        ValueError
            If fewer than 2 non-NaN observations are given or
            ``horizon_years < 1``.
        """
        observed = timeseries.dropna()
        if len(observed) < 2:
            raise ValueError(
                f"학ꡐ '{schul_code}': μ˜ˆμΈ‘μ— ν•„μš”ν•œ μ‹œκ³„μ—΄ 데이터가 λΆ€μ‘±ν•©λ‹ˆλ‹€ "
                f"(μ΅œμ†Œ 2개 κ΄€μΈ‘κ°’ ν•„μš”, ν˜„μž¬ {len(observed)}개)."
            )
        if horizon_years < 1:
            raise ValueError(f"horizon_years must be >= 1, got {horizon_years}")

        ts = observed.sort_index().astype(float)
        horizon_months = horizon_years * 12

        self._ensure_model()
        point_m, lower_m, upper_m, model_ver = self._forecast_monthly(ts, horizon_months)

        # Monthly -> yearly. The monthly grid ends on January 1 of the last observed
        # year, so each 12th forecast month is January 1 of the following years.
        point_y = _monthly_to_yearly(point_m)
        lower_y = _monthly_to_yearly(lower_m)
        upper_y = _monthly_to_yearly(upper_m)

        last_year = int(ts.index.max())
        forecast_years = list(range(last_year + 1, last_year + 1 + len(point_y)))

        logger.info(
            "예츑 μ™„λ£Œ: schul_code=%s model=%s horizon=%dy point_mean=%.1f",
            schul_code,
            model_ver,
            horizon_years,
            float(np.mean(point_y)),
        )

        return ForecastResult(
            schul_code=schul_code,
            target_col=target_col,
            forecast_years=forecast_years,
            point_forecast=[round(float(v), 1) for v in point_y],
            lower_bound=[round(float(v), 1) for v in lower_y],
            upper_bound=[round(float(v), 1) for v in upper_y],
            model_version=model_ver,
            context_years=[int(y) for y in ts.index.tolist()],
            context_values=[round(float(v), 1) for v in ts.tolist()],
        )