""" 종합 진단 엔진 모듈. 예측(GranitePredictor), XAI(ShapExplainer), 점수(SustainabilityScorer) 를 통합하여 단일 API 호출로 학교 전체 진단 결과를 반환합니다. 또한 정책 시뮬레이션 기능을 제공합니다: - 교사 1인당 학생 수 개선 시나리오 - 복합시설 도입 시나리오 - 기간제 교원 정규직 전환 시나리오 """ from __future__ import annotations import logging from dataclasses import dataclass, field from typing import Any import pandas as pd from src.analytics.predictor import ForecastResult, GranitePredictor from src.analytics.scorer import SustainabilityScore, SustainabilityScorer from src.analytics.xai import ShapExplainer, ShapResult logger = logging.getLogger(__name__) # 시나리오 정의 (key: 시나리오ID, value: 변경 항목) SIMULATION_SCENARIOS: dict[str, dict[str, Any]] = { "improve_teacher_ratio": { "label": "교사 1인당 학생 수 개선", "description": "기간제 교원을 정규직으로 전환하고 교원을 1명 추가 배치합니다.", "delta": {"teacher_count": 1, "temp_teacher_count": -1}, }, "introduce_complex_facility": { "label": "복합시설 도입", "description": "도서관·체육관 등 복합시설을 도입하여 지역 유입 효과를 반영합니다.", "delta": {"established_year_delta": 20, "transfer_net_avg": 3}, }, "community_revitalization": { "label": "지역 연계 강화", "description": "지역 재생 거점으로 전환하여 소멸위험지수를 개선합니다.", "delta": {"population_risk_index": 0.3}, }, } @dataclass class DiagnosticsResult: """종합 진단 결과 컨테이너.""" schul_code: str school_name: str status_label: str status_code: int sustainability_score: float forecast: ForecastResult shap_result: ShapResult score_detail: SustainabilityScore simulations: dict[str, dict[str, Any]] = field(default_factory=dict) metadata: dict[str, Any] = field(default_factory=dict) def to_dict(self) -> dict[str, Any]: return { "schul_code": self.schul_code, "school_name": self.school_name, "status_label": self.status_label, "status_code": self.status_code, "sustainability_score": round(self.sustainability_score, 2), "forecast": self.forecast.to_dict(), "shap_result": self.shap_result.to_dict(), "score_detail": self.score_detail.to_dict(), "simulations": self.simulations, "metadata": self.metadata, } class DiagnosticsEngine: """ 학교 종합 진단 엔진. 사용 예:: engine = DiagnosticsEngine() result = engine.diagnose( schul_code="7431234", school_data={...}, timeseries=pd.Series({2018: 120, ...}), ) """ def __init__( self, predictor: GranitePredictor | None = None, shap_explainer: ShapExplainer | None = None, scorer: SustainabilityScorer | None = None, ) -> None: self._predictor = predictor or GranitePredictor() self._shap = shap_explainer or ShapExplainer() self._scorer = scorer or SustainabilityScorer() def _run_simulation( self, scenario_id: str, base_data: dict[str, Any], ) -> dict[str, Any]: """ 단일 정책 시나리오를 적용한 지속 가능성 점수를 반환합니다. Parameters ---------- scenario_id: SIMULATION_SCENARIOS 의 키. base_data: 원본 학교 데이터 딕셔너리. Returns ------- dict 시나리오 레이블, 설명, 변경 전/후 점수 비교. """ scenario = SIMULATION_SCENARIOS[scenario_id] delta = scenario["delta"] modified = dict(base_data) for key, change in delta.items(): if key == "established_year_delta": # 시설 노후도를 개선하는 것을 시뮬레이션 orig = int(modified.get("established_year", 0) or 0) if orig > 0: modified["established_year"] = orig + int(change) elif key in modified: modified[key] = (float(modified[key] or 0)) + float(change) else: modified[key] = float(change) base_score = self._scorer.compute(base_data.get("sd_schul_code", ""), base_data) sim_score = self._scorer.compute(base_data.get("sd_schul_code", ""), modified) return { "label": scenario["label"], "description": scenario["description"], "base_total_score": round(base_score.total_score, 2), "sim_total_score": round(sim_score.total_score, 2), "delta_score": round(sim_score.total_score - base_score.total_score, 2), "sim_curriculum_score": round(sim_score.curriculum_score, 2), "sim_personnel_score": round(sim_score.personnel_score, 2), "sim_facility_score": round(sim_score.facility_score, 2), "sim_community_score": round(sim_score.community_score, 2), } def diagnose( self, schul_code: str, school_data: dict[str, Any], timeseries: pd.Series, horizon_years: int = 5, run_simulations: bool = True, ) -> DiagnosticsResult: """ 학교 종합 진단을 수행합니다. Parameters ---------- schul_code: 대상 학교 SD_SCHUL_CODE. school_data: 학교 지표 딕셔너리 (SustainabilityScorer, ShapExplainer 공통 입력). timeseries: 연도(int) 인덱스의 학생 수 시계열 Series. horizon_years: 예측 기간 (년). run_simulations: 정책 시뮬레이션 실행 여부. Returns ------- DiagnosticsResult """ school_name = str(school_data.get("school_name", schul_code)) logger.info("진단 시작: schul_code=%s school_name=%s", schul_code, school_name) # 1. 시계열 예측 try: forecast = self._predictor.predict( schul_code=schul_code, timeseries=timeseries, horizon_years=horizon_years, target_col="student_count", ) except Exception as exc: logger.error("예측 실패: schul_code=%s error=%s", schul_code, exc) raise # 히스토리를 school_data 에 주입 (SHAP 피처 계산용) school_data_enriched = dict(school_data) school_data_enriched["student_count_history"] = forecast.context_values school_data_enriched["sd_schul_code"] = schul_code # 2. SHAP 분류 및 기여도 분석 try: shap_result = self._shap.explain(schul_code=schul_code, school_data=school_data_enriched) except Exception as exc: logger.error("SHAP 분석 실패: schul_code=%s error=%s", schul_code, exc) raise # 3. 지속 가능성 점수 산출 try: score_detail = self._scorer.compute(schul_code=schul_code, school_data=school_data_enriched) except Exception as exc: logger.error("점수 산출 실패: schul_code=%s error=%s", schul_code, exc) raise # 4. 정책 시뮬레이션 simulations: dict[str, dict[str, Any]] = {} if run_simulations: for sid in SIMULATION_SCENARIOS: try: simulations[sid] = self._run_simulation(sid, school_data_enriched) except Exception as exc: # noqa: BLE001 logger.warning("시뮬레이션 실패: scenario=%s error=%s", sid, exc) simulations[sid] = {"error": str(exc)} result = DiagnosticsResult( schul_code=schul_code, school_name=school_name, status_label=shap_result.status_label, status_code=shap_result.status_code, sustainability_score=score_detail.total_score, forecast=forecast, shap_result=shap_result, score_detail=score_detail, simulations=simulations, metadata={ "diagnosed_at": pd.Timestamp.now().isoformat(), "model_version": forecast.model_version, "data_quality": school_data_enriched.get("data_quality_score", 1.0), }, ) logger.info( "진단 완료: schul_code=%s status=%s score=%.1f", schul_code, shap_result.status_label, score_detail.total_score, ) return result