# Scrape residue from the hosting UI (commit header), preserved as comments
# so the module remains valid Python:
# Merry99's picture
# add method on /health
# b18773a
"""FastAPI μ•±: μˆ˜λ™ ν•™μŠ΅ 및 λͺ¨λΈ λ‹€μš΄λ‘œλ“œ/μ—…λ‘œλ“œ"""
from __future__ import annotations
import os
import threading
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import schedule
import lightgbm as lgb
import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from huggingface_hub import HfApi
from pydantic import BaseModel, field_validator
from train_scheduler import TrainingScheduler
app = FastAPI(
    title="MuscleCare LightGBM Scheduler",
    description="MuscleCare-Train-AI Space와 λ™μΌν•œ APIλ₯Ό LightGBM λͺ¨λΈλ‘œ μ œκ³΅ν•©λ‹ˆλ‹€.",
)

# Owns training runs and the model-version manifest (see train_scheduler).
_scheduler = TrainingScheduler()

# Guards every read/write of the cached-model globals below.
_model_lock = threading.Lock()
# LightGBM booster currently serving /predict (None until first load).
_current_model: Optional[lgb.Booster] = None
# Filesystem path the current booster was loaded from.
_current_model_path: Optional[str] = None
# Manifest version number of the current booster, when known.
_current_model_version: Optional[int] = None
# time.time() of the last (re)load; compared against MODEL_CACHE_TIMEOUT.
_model_cache_timestamp: Optional[float] = None
MODEL_CACHE_TIMEOUT = 3600  # cached model considered stale after 1 hour
class TrainResponse(BaseModel):
    """Response payload shared by /trigger, /train and /update-model."""

    status: str
    new_data_count: int
    model_path: Optional[str] = None
    hub_url: Optional[str] = None
    model_version: Optional[int] = None
    message: str
    new_session_count: Optional[int] = None
class ResetStateResponse(BaseModel):
    """Response payload for /state/reset (non-production only)."""

    status: str
    state: Dict[str, Any]
class PredictRequest(BaseModel):
    """Sensor-feature payload for /predict.

    `rms_*` / `mean_freq_*` are current readings; `rms_base` / `freq_base`
    are baseline values used to form ratios in `_build_feature_vector`.
    """

    rms_acc: float
    rms_gyro: float
    mean_freq_acc: float
    mean_freq_gyro: float
    rms_base: float
    freq_base: float
    user_emb: List[float]  # fixed-length user embedding (exactly 12 values)

    @field_validator("user_emb")
    @classmethod
    def validate_user_emb(cls, v: List[float]) -> List[float]:
        """Reject embeddings that are not exactly 12 values long."""
        if len(v) != 12:
            raise ValueError("user_emb must contain exactly 12 values.")
        return v
class PredictResponse(BaseModel):
    """Response payload for /predict."""

    fatigue: float
    model_version: Optional[int]
def _schedule_background_job() -> None:
    """Register the weekly training job and start a daemon thread to run it.

    Clears any previously registered jobs first, then polls `schedule`
    once a minute from a background daemon thread.
    """
    schedule.clear()
    schedule.every().sunday.at(_scheduler.schedule_time).do(
        _scheduler.run_scheduled_training
    )

    def _poll_forever() -> None:
        while True:
            schedule.run_pending()
            time.sleep(60)

    threading.Thread(target=_poll_forever, daemon=True).start()
def _apply_training_result(result: Dict[str, Any]) -> None:
    """Install the freshly trained model described by *result*, if any.

    No-op unless the scheduler reports status == "trained"; load failures
    are logged and swallowed (best-effort).
    """
    if result.get("status") != "trained":
        return
    model_path = result.get("model_path")
    if model_path:
        try:
            _load_model_from_path(Path(model_path), result.get("model_version"))
        except Exception as exc:
            print(f"[Model] μƒˆ λͺ¨λΈ λ‘œλ“œ μ‹€νŒ¨: {exc}")
    else:
        print("[Model] ν•™μŠ΅ 결과에 model_pathκ°€ μ—†μ–΄ λͺ¨λΈμ„ λ‘œλ“œν•˜μ§€ λͺ»ν–ˆμŠ΅λ‹ˆλ‹€.")
def _load_model_from_path(path: Path, version: Optional[int] = None) -> None:
    """Load a LightGBM booster from *path* and make it the active model.

    Raises:
        FileNotFoundError: if *path* does not exist.
    """
    global _current_model, _current_model_path, _current_model_version, _model_cache_timestamp
    if not path.exists():
        raise FileNotFoundError(f"λͺ¨λΈ νŒŒμΌμ„ 찾을 수 μ—†μŠ΅λ‹ˆλ‹€: {path}")
    loaded = lgb.Booster(model_file=str(path))
    # Swap all cache globals atomically under the lock.
    with _model_lock:
        _current_model = loaded
        _current_model_path = str(path)
        _current_model_version = version
        _model_cache_timestamp = time.time()
    print(f"[Model] Loaded LightGBM model from {path} (version={version})")
def _get_cached_model() -> Optional[lgb.Booster]:
    """Return the cached booster, or None when absent or expired.

    On expiry the booster reference is dropped so the next caller
    triggers a reload.
    """
    global _current_model, _model_cache_timestamp
    with _model_lock:
        model, stamp = _current_model, _model_cache_timestamp
        if model is None or stamp is None:
            return None
        if time.time() - stamp > MODEL_CACHE_TIMEOUT:
            print("[Model] λͺ¨λΈ μΊμ‹œ 만료, μž¬λ‘œλ“œ ν•„μš”")
            _current_model = None
            return None
        return model
def _maybe_load_latest_model() -> None:
    """Best-effort load of the newest model: manifest first, default path second.

    Never raises; every failure path only logs and returns.
    """
    try:
        manifest = _scheduler.get_model_versions()
        candidate_path: Optional[Path] = None
        candidate_version: Optional[int] = None
        if manifest:
            newest = manifest[-1]
            candidate_path = Path(newest["path"])
            candidate_version = newest.get("version")
        else:
            fallback = Path("models/lightgbm_model.txt")
            if fallback.exists():
                candidate_path = fallback
        if candidate_path and candidate_path.exists():
            try:
                _load_model_from_path(candidate_path, candidate_version)
                print(f"[Model] λͺ¨λΈ λ‘œλ“œ 성곡: {candidate_path}")
            except Exception as exc:
                print(f"[Model] λͺ¨λΈ λ‘œλ“œ μ‹€νŒ¨ (계속 μ§„ν–‰): {exc}")
        else:
            print("[Model] λ‘œλ“œν•  λͺ¨λΈμ΄ 아직 μ—†μŠ΅λ‹ˆλ‹€.")
    except Exception as exc:
        print(f"[Model] λͺ¨λΈ λ‘œλ“œ κ³Όμ •μ—μ„œ μ˜ˆμ™Έ λ°œμƒ: {exc}")
def _get_active_model() -> Tuple[lgb.Booster, Optional[int]]:
    """Return (booster, version) for serving, loading from disk if needed.

    Tries the in-memory cache first; otherwise resolves the newest manifest
    entry (or the default model file) and loads it.

    Raises:
        HTTPException: 503 when no model file exists or loading fails.
    """
    cached_model = _get_cached_model()
    if cached_model is not None:
        return cached_model, _current_model_version
    try:
        manifest = _scheduler.get_model_versions()
        target_entry = manifest[-1] if manifest else None
        if target_entry:
            path = Path(target_entry["path"])
            version = target_entry.get("version")
        else:
            # BUGFIX: `version` was previously unbound on this branch,
            # raising NameError instead of a clean 503.
            path = Path("models/lightgbm_model.txt")
            version = None
        if not path.exists():
            raise HTTPException(status_code=503, detail="λͺ¨λΈ νŒŒμΌμ„ 찾을 수 μ—†μŠ΅λ‹ˆλ‹€.")
        _load_model_from_path(path, version)
        return _current_model, _current_model_version
    except HTTPException:
        # BUGFIX: do not re-wrap our own 503 inside another HTTPException.
        raise
    except Exception as exc:
        raise HTTPException(status_code=503, detail=f"λͺ¨λΈ λ‘œλ“œ μ‹€νŒ¨: {exc}") from exc
def _build_feature_vector(payload: PredictRequest) -> np.ndarray:
rms_base = payload.rms_base if payload.rms_base != 0 else 1e-6
freq_mean = (payload.mean_freq_acc + payload.mean_freq_gyro) / 2.0
if freq_mean == 0:
freq_mean = 1e-6
rms_ratio = ((payload.rms_acc + payload.rms_gyro) / 2.0) / rms_base
freq_ratio = payload.freq_base / freq_mean
feature_vector = [rms_ratio, freq_ratio, *payload.user_emb]
return np.asarray([feature_vector], dtype=np.float32)
@app.on_event("startup")
def on_startup() -> None:
print("[Startup] MuscleCare Space μ‹œμž‘ 쀑...")
try:
_schedule_background_job()
print("[Startup] μŠ€μΌ€μ€„λŸ¬ μ΄ˆκΈ°ν™” μ™„λ£Œ")
except Exception as exc:
print(f"[Startup] μŠ€μΌ€μ€„λŸ¬ μ΄ˆκΈ°ν™” μ‹€νŒ¨ (계속 μ§„ν–‰): {exc}")
# λͺ¨λΈ μ—…λ°μ΄νŠΈ μ‹œλ„ (μ΅œμ‹  λ°μ΄ν„°λ‘œ λͺ¨λΈ ν•™μŠ΅)
print("[Startup] λͺ¨λΈ μ—…λ°μ΄νŠΈ μ‹œλ„ 쀑...")
try:
# κΈ°μ‘΄ λͺ¨λΈ 확인
manifest = _scheduler.get_model_versions()
has_existing_model = len(manifest) > 0
print(f"[Startup] κΈ°μ‘΄ λͺ¨λΈ 쑴재: {has_existing_model}")
if not has_existing_model:
print("[Startup] κΈ°μ‘΄ λͺ¨λΈμ΄ μ—†μ–΄ 초기 ν•™μŠ΅μ„ μˆ˜ν–‰ν•©λ‹ˆλ‹€...")
result = _scheduler.run_scheduled_training()
if result.get("status") == "trained":
_apply_training_result(result)
print("[Startup] βœ… 초기 ν•™μŠ΅ μ™„λ£Œ")
else:
print(f"[Startup] ⚠️ 초기 ν•™μŠ΅ μ‹€νŒ¨: {result.get('message', 'μ•Œ 수 μ—†λŠ” 였λ₯˜')}")
else:
print("[Startup] κΈ°μ‘΄ λͺ¨λΈμ΄ μžˆμ–΄ μ—…λ°μ΄νŠΈλ₯Ό κ±΄λ„ˆλœλ‹ˆλ‹€ (ν•„μš”μ‹œ /trigger 호좜)")
except Exception as exc:
print(f"[Startup] λͺ¨λΈ μ—…λ°μ΄νŠΈ μ‹€νŒ¨ (계속 μ§„ν–‰): {exc}")
print("[Startup] MuscleCare Space μ‹œμž‘ μ™„λ£Œ")
@app.head("/health")
async def health_head():
return None # HEADλŠ” λ°”λ””κ°€ ν•„μš” μ—†μœΌλ―€λ‘œ None λ°˜ν™˜
@app.get("/health")
def health_check() -> dict:
"""
μ‹œμŠ€ν…œ ν—¬μŠ€μ²΄ν¬ API
- λͺ¨λΈ μƒνƒœ
- μ‹œμŠ€ν…œ λ¦¬μ†ŒμŠ€
- 파일 μƒνƒœ
- 졜근 ν•™μŠ΅ 정보
"""
import time
from pathlib import Path
# κΈ°λ³Έ μƒνƒœ
health_status = {
"status": "ok",
"timestamp": time.time(),
"environment": os.getenv("ENVIRONMENT", "development"),
"version": "1.0.0"
}
try:
# λͺ¨λΈ μƒνƒœ 확인
try:
cached_model = _get_cached_model()
health_status["model_loaded"] = cached_model is not None
if _model_cache_timestamp:
health_status["model_cache_age_seconds"] = int(time.time() - _model_cache_timestamp)
except Exception as e:
health_status["model_loaded"] = False
health_status["model_error"] = str(e)
# λͺ¨λΈ 파일 쑴재 μ—¬λΆ€
try:
model_files = list(Path("models").glob("*.txt"))
health_status["model_files_count"] = len(model_files)
if model_files:
latest_model = max(model_files, key=lambda x: x.stat().st_mtime)
health_status["latest_model_file"] = latest_model.name
except Exception as e:
health_status["model_files_count"] = 0
health_status["model_files_error"] = str(e)
# 둜그 파일 쑴재 μ—¬λΆ€
try:
log_files = list(Path("logs").glob("*.json"))
health_status["log_files_count"] = len(log_files)
except Exception as e:
health_status["log_files_count"] = 0
health_status["log_files_error"] = str(e)
# 졜근 ν•™μŠ΅ μƒνƒœ
try:
manifest = _scheduler.get_model_versions()
if manifest:
latest = manifest[-1]
health_status["latest_model_version"] = latest.get("version")
health_status["latest_training_time"] = latest.get("timestamp")
health_status["total_sessions_trained"] = sum(m.get("session_count", 0) for m in manifest)
else:
health_status["latest_model_version"] = None
health_status["total_sessions_trained"] = 0
except Exception as e:
health_status["training_status_error"] = str(e)
# API μ—”λ“œν¬μΈνŠΈ μƒνƒœ
endpoints_status = {
"predict": "available",
"trigger": "available",
"model": "available",
"update-model": "available",
"health": "available"
}
# ν™˜κ²½μ— λ”°λ₯Έ state_reset μƒνƒœ
if os.getenv("ENVIRONMENT") != "production":
endpoints_status["state_reset"] = "available"
else:
endpoints_status["state_reset"] = "disabled_in_production"
health_status["endpoints"] = endpoints_status
# μ‹œμŠ€ν…œ λ¦¬μ†ŒμŠ€ (간단 버전)
try:
# ν”„λ‘œμ„ΈμŠ€ 정보 (기본적인)
health_status["process_id"] = os.getpid()
health_status["working_directory"] = os.getcwd()
except Exception as e:
health_status["system_error"] = str(e)
except Exception as e:
health_status["status"] = "degraded"
health_status["error"] = str(e)
# μ—λŸ¬κ°€ λ°œμƒν•΄λ„ κΈ°λ³Έ μ •λ³΄λŠ” μœ μ§€
return health_status
@app.get("/")
def root() -> dict:
endpoints = {
"trigger": "/trigger",
"model": "/model",
"update-model": "/update-model",
"predict": "/predict",
"health": "/health"
}
return {
"message": "MuscleCare LightGBM Scheduler API",
"docs": "/docs",
"endpoints": endpoints,
"environment": os.getenv("ENVIRONMENT", "development"),
}
def _upload_to_hub(model_path: str) -> Optional[str]:
    """Push the model file (and manifest, if present) to the Hugging Face Hub.

    Returns the repository URL on success, or None when the required
    HF_HYBRID_MODEL_TOKEN / HF_HYBRID_MODEL_REPO_ID env vars are missing.

    Raises:
        HTTPException: 404 when the model file does not exist.
        Exception: re-raised after logging on any upload failure.
    """
    token = os.getenv("HF_HYBRID_MODEL_TOKEN")
    repo_id = os.getenv("HF_HYBRID_MODEL_REPO_ID")
    print(f"[Upload] Model Hub μ—…λ‘œλ“œ μ‹œλ„: {model_path}")
    if not token or not repo_id:
        print(f"[Upload] ν™˜κ²½λ³€μˆ˜ λˆ„λ½: TOKEN={'***' if token else 'None'}, REPO_ID={repo_id}")
        return None
    model_file = Path(model_path)
    if not model_file.exists():
        raise HTTPException(status_code=404, detail=f"λͺ¨λΈ νŒŒμΌμ„ 찾을 수 μ—†μŠ΅λ‹ˆλ‹€: {model_path}")
    try:
        print(f"[Upload] 리포지토리 생성/확인: {repo_id}")
        hub = HfApi(token=token)
        hub.create_repo(repo_id=repo_id, repo_type="model", private=False, exist_ok=True)
        print(f"[Upload] λͺ¨λΈ 파일 μ—…λ‘œλ“œ: {model_file.name}")
        hub.upload_file(
            path_or_fileobj=model_file,
            path_in_repo=model_file.name,
            repo_id=repo_id,
            repo_type="model",
            commit_message=f"LightGBM model upload ({model_file.name})",
        )
        # Also publish the version manifest when it exists locally.
        manifest_path = Path("logs/model_versions.json")
        if manifest_path.exists():
            print(f"[Upload] 메타데이터 파일 μ—…λ‘œλ“œ")
            hub.upload_file(
                path_or_fileobj=str(manifest_path),
                path_in_repo="model_versions.json",
                repo_id=repo_id,
                repo_type="model",
                commit_message="Update model manifest",
            )
        hub_url = f"https://huggingface.co/{repo_id}"
        print(f"[Upload] βœ… μ—…λ‘œλ“œ 성곡: {hub_url}")
        return hub_url
    except Exception as exc:
        print(f"[Upload] ❌ μ—…λ‘œλ“œ μ‹€νŒ¨: {exc}")
        raise
def _resolve_model_entry(version: Optional[int] = None) -> Dict[str, Any]:
    """Return the manifest entry for *version*, or the newest when None.

    Raises:
        HTTPException: 404 when no models exist or the version is unknown.
    """
    manifest = _scheduler.get_model_versions()
    if not manifest:
        raise HTTPException(status_code=404, detail="아직 ν•™μŠ΅λœ λͺ¨λΈμ΄ μ—†μŠ΅λ‹ˆλ‹€.")
    if version is None:
        return manifest[-1]
    match = next((e for e in manifest if e.get("version") == version), None)
    if match is None:
        raise HTTPException(
            status_code=404,
            detail=f"버전 {version} λͺ¨λΈμ„ 찾을 수 μ—†μŠ΅λ‹ˆλ‹€.",
        )
    return match
@app.get("/model")
@app.get("/model/{version:int}")
def download_model(version: Optional[int] = None) -> FileResponse:
entry = _resolve_model_entry(version)
path = Path(entry["path"])
if not path.exists():
raise HTTPException(status_code=404, detail="λͺ¨λΈ νŒŒμΌμ„ 찾을 수 μ—†μŠ΅λ‹ˆλ‹€.")
response = FileResponse(
path=path,
filename=entry["filename"],
media_type="application/octet-stream",
)
response.headers["X-Model-Version"] = str(entry["version"])
return response
@app.get("/download")
def download_latest_alias() -> FileResponse:
return download_model()
# ν”„λ‘œλ•μ…˜ ν™˜κ²½μ—μ„œλŠ” reset API λΉ„ν™œμ„±ν™”
environment = os.getenv("ENVIRONMENT", "development")
if environment != "production":
print(f"[Security] State reset API enabled (environment: {environment})")
@app.post("/state/reset", response_model=ResetStateResponse)
def reset_state() -> ResetStateResponse:
print("[Security] State reset requested")
state = _scheduler.reset_training_state()
return ResetStateResponse(status="reset", state=state)
else:
print(f"[Security] State reset API disabled in production environment")
@app.post("/trigger", response_model=TrainResponse)
def trigger_training(upload: bool = False) -> TrainResponse:
try:
result = _scheduler.run_scheduled_training()
except Exception as exc: # pragma: no cover
raise HTTPException(status_code=500, detail=f"ν•™μŠ΅ μ‹€ν–‰ 였λ₯˜: {exc}") from exc
message = "λͺ¨λΈ ν•™μŠ΅μ΄ μ™„λ£Œλ˜μ—ˆμŠ΅λ‹ˆλ‹€." if result["status"] == "trained" else "ν•™μŠ΅μ΄ κ±΄λ„ˆλ›°μ–΄μ‘ŒμŠ΅λ‹ˆλ‹€."
hub_url = None
model_version = result.get("model_version")
model_path = result.get("model_path")
if upload and model_path and result["status"] == "trained":
try:
hub_url = _upload_to_hub(model_path)
message = "λͺ¨λΈ ν•™μŠ΅ 및 Hugging Face μ—…λ‘œλ“œκ°€ μ™„λ£Œλ˜μ—ˆμŠ΅λ‹ˆλ‹€."
except HTTPException:
raise
except Exception as exc: # pragma: no cover
raise HTTPException(status_code=500, detail=f"Hugging Face μ—…λ‘œλ“œ μ‹€νŒ¨: {exc}") from exc
_apply_training_result(result)
return TrainResponse(
status=result["status"],
new_data_count=result.get("new_data_count", 0),
model_path=model_path,
hub_url=hub_url,
model_version=model_version,
message=message,
new_session_count=result.get("new_session_count"),
)
@app.post("/train", response_model=TrainResponse)
def trigger_training_alias(upload: bool = False) -> TrainResponse:
return trigger_training(upload=upload)
@app.post("/update-model", response_model=TrainResponse)
def update_model(force: bool = False) -> TrainResponse:
"""
λͺ¨λΈμ„ κ°•μ œλ‘œ μ—…λ°μ΄νŠΈν•©λ‹ˆλ‹€.
- force=true: κΈ°μ‘΄ λͺ¨λΈμ΄ μžˆμ–΄λ„ μ—…λ°μ΄νŠΈ
- force=false: μƒˆλ‘œμš΄ 데이터가 μžˆμ„ λ•Œλ§Œ μ—…λ°μ΄νŠΈ
"""
try:
print(f"[Update] λͺ¨λΈ μ—…λ°μ΄νŠΈ μš”μ²­ (force={force})")
if force:
# κ°•μ œ μ—…λ°μ΄νŠΈ: κΈ°μ‘΄ λͺ¨λΈ λ¬΄μ‹œν•˜κ³  μƒˆλ‘œ ν•™μŠ΅
print("[Update] κ°•μ œ μ—…λ°μ΄νŠΈ λͺ¨λ“œ")
# μž„μ‹œλ‘œ κΈ°μ‘΄ λͺ¨λΈμ„ λ°±μ—…
manifest = _scheduler.get_model_versions()
if manifest:
print(f"[Update] κΈ°μ‘΄ λͺ¨λΈ {len(manifest)}개 백업됨")
result = _scheduler.run_scheduled_training()
if result.get("status") == "trained":
_apply_training_result(result)
message = "βœ… λͺ¨λΈ μ—…λ°μ΄νŠΈ μ™„λ£Œ"
else:
message = f"⚠️ λͺ¨λΈ μ—…λ°μ΄νŠΈ κ±΄λ„ˆλœ€: {result.get('message', 'μƒˆλ‘œμš΄ 데이터 μ—†μŒ')}"
return TrainResponse(
status=result["status"],
new_data_count=result.get("new_data_count", 0),
model_path=result.get("model_path"),
hub_url=None, # μ—…λ°μ΄νŠΈ μ‹œμ—λŠ” Hub μ—…λ‘œλ“œ ν•˜μ§€ μ•ŠμŒ
model_version=result.get("model_version"),
message=message,
new_session_count=result.get("new_session_count"),
)
except Exception as exc:
raise HTTPException(status_code=500, detail=f"λͺ¨λΈ μ—…λ°μ΄νŠΈ μ‹€νŒ¨: {exc}") from exc
@app.post("/predict", response_model=PredictResponse)
def predict(payload: PredictRequest) -> PredictResponse:
booster, version = _get_active_model()
features = _build_feature_vector(payload)
prediction = booster.predict(features)[0]
return PredictResponse(fatigue=float(prediction), model_version=version)
__all__ = ["app"]