Spaces:
Sleeping
Sleeping
| """FastAPI μ±: μλ νμ΅ λ° λͺ¨λΈ λ€μ΄λ‘λ/μ λ‘λ""" | |
| from __future__ import annotations | |
| import os | |
| import threading | |
| import time | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import schedule | |
| import lightgbm as lgb | |
| import numpy as np | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.responses import FileResponse | |
| from huggingface_hub import HfApi | |
| from pydantic import BaseModel, field_validator | |
| from train_scheduler import TrainingScheduler | |
# FastAPI application exposing training, model download and prediction endpoints.
app = FastAPI(
    title="MuscleCare LightGBM Scheduler",
    description="MuscleCare-Train-AI Spaceμ λμΌν APIλ₯Ό LightGBM λͺ¨λΈλ‘ μ 곡ν©λλ€.",
)

# Singleton scheduler that owns training runs and the model-version manifest.
_scheduler = TrainingScheduler()

# Guards all reads/writes of the in-memory model cache below.
_model_lock = threading.Lock()
# Cached LightGBM booster plus its metadata; None until a model is loaded.
_current_model: Optional[lgb.Booster] = None
_current_model_path: Optional[str] = None
_current_model_version: Optional[int] = None
_model_cache_timestamp: Optional[float] = None
# Cached model is considered stale after this many seconds (1 hour).
MODEL_CACHE_TIMEOUT = 3600
class TrainResponse(BaseModel):
    """Response payload for training-related endpoints (trigger / update-model)."""

    status: str  # "trained" or a skip status reported by the scheduler
    new_data_count: int
    model_path: Optional[str] = None
    hub_url: Optional[str] = None  # set only when the model was uploaded to the Hub
    model_version: Optional[int] = None
    message: str
    new_session_count: Optional[int] = None
class ResetStateResponse(BaseModel):
    """Response payload for the (non-production) state-reset endpoint."""

    status: str
    state: Dict[str, Any]  # scheduler state snapshot after the reset
class PredictRequest(BaseModel):
    """Input features for prediction.

    Raw sensor statistics plus a fixed-size user embedding; the ratios against
    the baseline values are derived later in _build_feature_vector().
    """

    rms_acc: float
    rms_gyro: float
    mean_freq_acc: float
    mean_freq_gyro: float
    rms_base: float
    freq_base: float
    user_emb: List[float]

    # BUG FIX: without this decorator the validator was never registered with
    # pydantic, so wrong-sized embeddings reached the model and failed later.
    @field_validator("user_emb")
    @classmethod
    def validate_user_emb(cls, v: List[float]) -> List[float]:
        """Ensure the user embedding has exactly 12 dimensions."""
        if len(v) != 12:
            raise ValueError("user_emb must contain exactly 12 values.")
        return v
class PredictResponse(BaseModel):
    """Output of a prediction: fatigue score plus the serving model version."""

    fatigue: float
    model_version: Optional[int]  # None when the serving model has no recorded version
def _schedule_background_job() -> None:
    """Register the weekly training job and start a daemon polling thread."""
    schedule.clear()
    weekly_job = schedule.every().sunday.at(_scheduler.schedule_time)
    weekly_job.do(_scheduler.run_scheduled_training)

    def _poll_forever() -> None:
        # Check for due jobs once a minute, forever.
        while True:
            schedule.run_pending()
            time.sleep(60)

    worker = threading.Thread(target=_poll_forever, daemon=True)
    worker.start()
def _apply_training_result(result: Dict[str, Any]) -> None:
    """Load the model produced by a training run into the in-memory cache."""
    if result.get("status") != "trained":
        # Nothing was trained, so there is nothing to load.
        return
    trained_path = result.get("model_path")
    if not trained_path:
        print("[Model] νμ΅ κ²°κ³Όμ model_pathκ° μμ΄ λͺ¨λΈμ λ‘λνμ§ λͺ»νμ΅λλ€.")
        return
    try:
        _load_model_from_path(Path(trained_path), result.get("model_version"))
    except Exception as exc:
        print(f"[Model] μ λͺ¨λΈ λ‘λ μ€ν¨: {exc}")
def _load_model_from_path(path: Path, version: Optional[int] = None) -> None:
    """Read a LightGBM booster from *path* and install it as the active model.

    Raises:
        FileNotFoundError: when *path* does not exist.
    """
    if not path.exists():
        raise FileNotFoundError(f"λͺ¨λΈ νμΌμ μ°Ύμ μ μμ΅λλ€: {path}")
    loaded_booster = lgb.Booster(model_file=str(path))

    global _current_model, _current_model_path, _current_model_version, _model_cache_timestamp
    with _model_lock:
        # Swap all cache fields together so readers never see a mixed state.
        _current_model = loaded_booster
        _current_model_path = str(path)
        _current_model_version = version
        _model_cache_timestamp = time.time()
        print(f"[Model] Loaded LightGBM model from {path} (version={version})")
def _get_cached_model() -> Optional[lgb.Booster]:
    """Return the cached booster, or None when absent or older than the TTL."""
    global _current_model, _model_cache_timestamp
    with _model_lock:
        if _current_model is None or _model_cache_timestamp is None:
            return None
        cache_age = time.time() - _model_cache_timestamp
        if cache_age > MODEL_CACHE_TIMEOUT:
            # Expired: drop the booster so the next caller reloads from disk.
            print("[Model] λͺ¨λΈ μΊμ λ§λ£, μ¬λ‘λ νμ")
            _current_model = None
            return None
        return _current_model
def _maybe_load_latest_model() -> None:
    """Best-effort startup load of the newest manifest model (or the default file)."""
    try:
        versions = _scheduler.get_model_versions()
        newest = versions[-1] if versions else None
        chosen_path: Optional[Path] = None
        chosen_version: Optional[int] = None
        if newest:
            chosen_path = Path(newest["path"])
            chosen_version = newest.get("version")
        else:
            # No manifest yet: fall back to the conventional default location.
            fallback = Path("models/lightgbm_model.txt")
            if fallback.exists():
                chosen_path = fallback
        if chosen_path is not None and chosen_path.exists():
            try:
                _load_model_from_path(chosen_path, chosen_version)
                print(f"[Model] λͺ¨λΈ λ‘λ μ±κ³΅: {chosen_path}")
            except Exception as exc:
                print(f"[Model] λͺ¨λΈ λ‘λ μ€ν¨ (κ³μ μ§ν): {exc}")
        else:
            print("[Model] λ‘λν λͺ¨λΈμ΄ μμ§ μμ΅λλ€.")
    except Exception as exc:
        # Never let a missing/broken model stop the service from starting.
        print(f"[Model] λͺ¨λΈ λ‘λ κ³Όμ μμ μμΈ λ°μ: {exc}")
def _get_active_model() -> Tuple[lgb.Booster, Optional[int]]:
    """Return (booster, version) for serving, loading from disk if needed.

    Raises:
        HTTPException: 503 when no model file exists or loading fails.
    """
    # Fast path: a fresh cached model.
    cached_model = _get_cached_model()
    if cached_model is not None:
        return cached_model, _current_model_version
    # Cache miss/expired: locate the newest model and load it.
    try:
        manifest = _scheduler.get_model_versions()
        target_entry = manifest[-1] if manifest else None
        if target_entry:
            path = Path(target_entry["path"])
            version = target_entry.get("version")
        else:
            path = Path("models/lightgbm_model.txt")
            # BUG FIX: 'version' was previously unbound on this branch, so a
            # default model file raised NameError instead of being served.
            version = None
        if path.exists():
            _load_model_from_path(path, version)
            return _current_model, _current_model_version
        raise HTTPException(status_code=503, detail="λͺ¨λΈ νμΌμ μ°Ύμ μ μμ΅λλ€.")
    except HTTPException:
        # BUG FIX: propagate the intended 503 unchanged instead of letting the
        # blanket handler below re-wrap it and mangle the detail message.
        raise
    except Exception as exc:
        raise HTTPException(status_code=503, detail=f"λͺ¨λΈ λ‘λ μ€ν¨: {exc}")
def _build_feature_vector(payload: PredictRequest) -> np.ndarray:
    """Build the (1, 14) float32 feature matrix consumed by the booster.

    Features are [rms_ratio, freq_ratio, user_emb[0..11]], comparing measured
    values against the user's baseline.
    """
    # Guard both denominators against division by zero.
    safe_rms_base = 1e-6 if payload.rms_base == 0 else payload.rms_base
    measured_freq = (payload.mean_freq_acc + payload.mean_freq_gyro) / 2.0
    if measured_freq == 0:
        measured_freq = 1e-6
    measured_rms = (payload.rms_acc + payload.rms_gyro) / 2.0
    rms_ratio = measured_rms / safe_rms_base
    # NOTE(review): freq_ratio is baseline/measured while rms_ratio is
    # measured/baseline — presumably intentional (fatigue lowers frequency).
    freq_ratio = payload.freq_base / measured_freq
    row = [rms_ratio, freq_ratio]
    row.extend(payload.user_emb)
    return np.asarray([row], dtype=np.float32)
@app.on_event("startup")  # NOTE(review): decorator was missing, so this never ran — confirm registration
def on_startup() -> None:
    """Initialize the weekly scheduler and train an initial model if none exists."""
    print("[Startup] MuscleCare Space μμ μ€...")
    try:
        _schedule_background_job()
        print("[Startup] μ€μΌμ€λ¬ μ΄κΈ°ν μλ£")
    except Exception as exc:
        print(f"[Startup] μ€μΌμ€λ¬ μ΄κΈ°ν μ€ν¨ (κ³μ μ§ν): {exc}")
    # Try to bring a model online (train from the latest data if necessary).
    print("[Startup] λͺ¨λΈ μ λ°μ΄νΈ μλ μ€...")
    try:
        # Check whether a trained model already exists in the manifest.
        manifest = _scheduler.get_model_versions()
        has_existing_model = len(manifest) > 0
        print(f"[Startup] κΈ°μ‘΄ λͺ¨λΈ μ‘΄μ¬: {has_existing_model}")
        if not has_existing_model:
            print("[Startup] κΈ°μ‘΄ λͺ¨λΈμ΄ μμ΄ μ΄κΈ° νμ΅μ μνν©λλ€...")
            result = _scheduler.run_scheduled_training()
            if result.get("status") == "trained":
                _apply_training_result(result)
                print("[Startup] β μ΄κΈ° νμ΅ μλ£")
            else:
                print(f"[Startup] β οΈ μ΄κΈ° νμ΅ μ€ν¨: {result.get('message', 'μ μ μλ μ€λ₯')}")
        else:
            print("[Startup] κΈ°μ‘΄ λͺ¨λΈμ΄ μμ΄ μ λ°μ΄νΈλ₯Ό 건λλλλ€ (νμμ /trigger νΈμΆ)")
    except Exception as exc:
        # Startup must survive a failed training attempt.
        print(f"[Startup] λͺ¨λΈ μ λ°μ΄νΈ μ€ν¨ (κ³μ μ§ν): {exc}")
    print("[Startup] MuscleCare Space μμ μλ£")
@app.head("/health")  # NOTE(review): decorator was missing; HEAD counterpart of GET /health — confirm path
async def health_head():
    """Lightweight liveness probe; HEAD responses carry no body, hence None."""
    return None
@app.get("/health")  # NOTE(review): route decorator was missing; path taken from root()'s endpoint map — confirm
def health_check() -> dict:
    """System health check.

    Reports model/cache state, model and log file counts, latest training
    info, endpoint availability and basic process info. Always returns a
    dict; partial failures are recorded as *_error keys instead of raising.
    (Redundant function-local re-imports of time/Path were removed; the
    module-level imports are used instead.)
    """
    # Base status — always present, even if every probe below fails.
    health_status = {
        "status": "ok",
        "timestamp": time.time(),
        "environment": os.getenv("ENVIRONMENT", "development"),
        "version": "1.0.0"
    }
    try:
        # Model cache state.
        try:
            cached_model = _get_cached_model()
            health_status["model_loaded"] = cached_model is not None
            if _model_cache_timestamp:
                health_status["model_cache_age_seconds"] = int(time.time() - _model_cache_timestamp)
        except Exception as e:
            health_status["model_loaded"] = False
            health_status["model_error"] = str(e)
        # Model files on disk.
        try:
            model_files = list(Path("models").glob("*.txt"))
            health_status["model_files_count"] = len(model_files)
            if model_files:
                latest_model = max(model_files, key=lambda x: x.stat().st_mtime)
                health_status["latest_model_file"] = latest_model.name
        except Exception as e:
            health_status["model_files_count"] = 0
            health_status["model_files_error"] = str(e)
        # Log files on disk.
        try:
            log_files = list(Path("logs").glob("*.json"))
            health_status["log_files_count"] = len(log_files)
        except Exception as e:
            health_status["log_files_count"] = 0
            health_status["log_files_error"] = str(e)
        # Latest training state from the manifest.
        try:
            manifest = _scheduler.get_model_versions()
            if manifest:
                latest = manifest[-1]
                health_status["latest_model_version"] = latest.get("version")
                health_status["latest_training_time"] = latest.get("timestamp")
                health_status["total_sessions_trained"] = sum(m.get("session_count", 0) for m in manifest)
            else:
                health_status["latest_model_version"] = None
                health_status["total_sessions_trained"] = 0
        except Exception as e:
            health_status["training_status_error"] = str(e)
        # Advertised endpoint availability.
        endpoints_status = {
            "predict": "available",
            "trigger": "available",
            "model": "available",
            "update-model": "available",
            "health": "available"
        }
        # state_reset is only exposed outside production.
        if os.getenv("ENVIRONMENT") != "production":
            endpoints_status["state_reset"] = "available"
        else:
            endpoints_status["state_reset"] = "disabled_in_production"
        health_status["endpoints"] = endpoints_status
        # Basic process info.
        try:
            health_status["process_id"] = os.getpid()
            health_status["working_directory"] = os.getcwd()
        except Exception as e:
            health_status["system_error"] = str(e)
    except Exception as e:
        health_status["status"] = "degraded"
        health_status["error"] = str(e)
        # Keep whatever partial info was collected.
    return health_status
@app.get("/")  # NOTE(review): route decorator was missing — confirm "/" is the intended path
def root() -> dict:
    """API index: links to the docs and the main endpoints."""
    endpoints = {
        "trigger": "/trigger",
        "model": "/model",
        "update-model": "/update-model",
        "predict": "/predict",
        "health": "/health"
    }
    return {
        "message": "MuscleCare LightGBM Scheduler API",
        "docs": "/docs",
        "endpoints": endpoints,
        "environment": os.getenv("ENVIRONMENT", "development"),
    }
def _upload_to_hub(model_path: str) -> Optional[str]:
    """Upload the model file (and the manifest, when present) to the HF Hub.

    Returns the repository URL on success, or None when the required
    environment variables are not configured. Upload errors are re-raised.

    Raises:
        HTTPException: 404 when the model file does not exist.
    """
    token = os.getenv("HF_HYBRID_MODEL_TOKEN")
    repo_id = os.getenv("HF_HYBRID_MODEL_REPO_ID")
    print(f"[Upload] Model Hub μ λ‘λ μλ: {model_path}")
    if not (token and repo_id):
        print(f"[Upload] νκ²½λ³μ λλ½: TOKEN={'***' if token else 'None'}, REPO_ID={repo_id}")
        return None
    path = Path(model_path)
    if not path.exists():
        raise HTTPException(status_code=404, detail=f"λͺ¨λΈ νμΌμ μ°Ύμ μ μμ΅λλ€: {model_path}")
    try:
        print(f"[Upload] 리ν¬μ§ν 리 μμ±/νμΈ: {repo_id}")
        hub = HfApi(token=token)
        hub.create_repo(repo_id=repo_id, repo_type="model", private=False, exist_ok=True)
        print(f"[Upload] λͺ¨λΈ νμΌ μ λ‘λ: {path.name}")
        hub.upload_file(
            path_or_fileobj=path,
            path_in_repo=path.name,
            repo_id=repo_id,
            repo_type="model",
            commit_message=f"LightGBM model upload ({path.name})",
        )
        manifest_path = Path("logs/model_versions.json")
        if manifest_path.exists():
            # Keep the version manifest on the Hub in sync with the model.
            print(f"[Upload] λ©νλ°μ΄ν° νμΌ μ λ‘λ")
            hub.upload_file(
                path_or_fileobj=str(manifest_path),
                path_in_repo="model_versions.json",
                repo_id=repo_id,
                repo_type="model",
                commit_message="Update model manifest",
            )
        hub_url = f"https://huggingface.co/{repo_id}"
        print(f"[Upload] β μ λ‘λ μ±κ³΅: {hub_url}")
        return hub_url
    except Exception as exc:
        print(f"[Upload] β μ λ‘λ μ€ν¨: {exc}")
        raise
def _resolve_model_entry(version: Optional[int] = None) -> Dict[str, Any]:
    """Look up a manifest entry by version; the latest entry when version is None.

    Raises:
        HTTPException: 404 when no models exist or the version is unknown.
    """
    manifest = _scheduler.get_model_versions()
    if not manifest:
        raise HTTPException(status_code=404, detail="μμ§ νμ΅λ λͺ¨λΈμ΄ μμ΅λλ€.")
    if version is None:
        return manifest[-1]
    matches = [entry for entry in manifest if entry.get("version") == version]
    if matches:
        return matches[0]
    raise HTTPException(
        status_code=404,
        detail=f"λ²μ {version} λͺ¨λΈμ μ°Ύμ μ μμ΅λλ€.",
    )
@app.get("/model")  # NOTE(review): route decorator was missing; path taken from root()'s endpoint map — confirm
def download_model(version: Optional[int] = None) -> FileResponse:
    """Serve a trained model file; the latest one when *version* is omitted.

    Raises:
        HTTPException: 404 when the version or its file is missing.
    """
    entry = _resolve_model_entry(version)
    path = Path(entry["path"])
    if not path.exists():
        raise HTTPException(status_code=404, detail="λͺ¨λΈ νμΌμ μ°Ύμ μ μμ΅λλ€.")
    response = FileResponse(
        path=path,
        filename=entry["filename"],
        media_type="application/octet-stream",
    )
    # Let clients see which version they received without parsing the body.
    response.headers["X-Model-Version"] = str(entry["version"])
    return response
def download_latest_alias() -> FileResponse:
    """Alias for download_model() without a version (i.e. the latest model).

    NOTE(review): no route decorator is visible here (presumably lost in
    extraction); confirm the intended alias path before wiring it to the app.
    """
    return download_model()
# The state-reset endpoint is only defined outside production.
environment = os.getenv("ENVIRONMENT", "development")
if environment != "production":
    print(f"[Security] State reset API enabled (environment: {environment})")

    def reset_state() -> ResetStateResponse:
        """Reset the scheduler's training state (development/staging only).

        NOTE(review): no route decorator is visible here (presumably lost in
        extraction); confirm the intended path before exposing this handler.
        """
        print("[Security] State reset requested")
        state = _scheduler.reset_training_state()
        return ResetStateResponse(status="reset", state=state)
else:
    print(f"[Security] State reset API disabled in production environment")
@app.post("/trigger")  # NOTE(review): route decorator was missing; path taken from root()'s endpoint map — confirm
def trigger_training(upload: bool = False) -> TrainResponse:
    """Run a training cycle now and optionally upload the result to the Hub.

    Args:
        upload: when True and training produced a model, push it to the HF Hub.

    Raises:
        HTTPException: 500 on training or upload failure.
    """
    try:
        result = _scheduler.run_scheduled_training()
    except Exception as exc:  # pragma: no cover
        raise HTTPException(status_code=500, detail=f"νμ΅ μ€ν μ€λ₯: {exc}") from exc
    message = "λͺ¨λΈ νμ΅μ΄ μλ£λμμ΅λλ€." if result["status"] == "trained" else "νμ΅μ΄ 건λλ°μ΄μ‘μ΅λλ€."
    hub_url = None
    model_version = result.get("model_version")
    model_path = result.get("model_path")
    if upload and model_path and result["status"] == "trained":
        try:
            hub_url = _upload_to_hub(model_path)
            message = "λͺ¨λΈ νμ΅ λ° Hugging Face μ λ‘λκ° μλ£λμμ΅λλ€."
        except HTTPException:
            raise
        except Exception as exc:  # pragma: no cover
            raise HTTPException(status_code=500, detail=f"Hugging Face μ λ‘λ μ€ν¨: {exc}") from exc
    # Refresh the in-memory model with whatever this run produced.
    _apply_training_result(result)
    return TrainResponse(
        status=result["status"],
        new_data_count=result.get("new_data_count", 0),
        model_path=model_path,
        hub_url=hub_url,
        model_version=model_version,
        message=message,
        new_session_count=result.get("new_session_count"),
    )
def trigger_training_alias(upload: bool = False) -> TrainResponse:
    """Alias for trigger_training().

    NOTE(review): no route decorator is visible here (presumably lost in
    extraction); confirm the intended alias path before wiring it to the app.
    """
    return trigger_training(upload=upload)
@app.post("/update-model")  # NOTE(review): route decorator was missing; path taken from root()'s endpoint map — confirm
def update_model(force: bool = False) -> TrainResponse:
    """Re-run training and swap in the resulting model.

    Args:
        force: when True, retrain even though a model already exists;
            otherwise the scheduler only trains when new data is available.

    Raises:
        HTTPException: 500 when the update fails.
    """
    try:
        print(f"[Update] λͺ¨λΈ μ λ°μ΄νΈ μμ² (force={force})")
        if force:
            # Forced mode: retrain regardless of the existing model.
            print("[Update] κ°μ μ λ°μ΄νΈ λͺ¨λ")
            # NOTE(review): this only logs a "backup" — no copy is actually
            # made. Confirm whether backing up existing models was intended.
            manifest = _scheduler.get_model_versions()
            if manifest:
                print(f"[Update] κΈ°μ‘΄ λͺ¨λΈ {len(manifest)}κ° λ°±μ λ¨")
        result = _scheduler.run_scheduled_training()
        if result.get("status") == "trained":
            _apply_training_result(result)
            message = "β λͺ¨λΈ μ λ°μ΄νΈ μλ£"
        else:
            message = f"β οΈ λͺ¨λΈ μ λ°μ΄νΈ 건λλ: {result.get('message', 'μλ‘μ΄ λ°μ΄ν° μμ')}"
        return TrainResponse(
            status=result["status"],
            new_data_count=result.get("new_data_count", 0),
            model_path=result.get("model_path"),
            hub_url=None,  # updates never push to the Hub
            model_version=result.get("model_version"),
            message=message,
            new_session_count=result.get("new_session_count"),
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"λͺ¨λΈ μ λ°μ΄νΈ μ€ν¨: {exc}") from exc
@app.post("/predict")  # NOTE(review): route decorator was missing; path taken from root()'s endpoint map — confirm
def predict(payload: PredictRequest) -> PredictResponse:
    """Predict fatigue for one sample using the active LightGBM model.

    Raises:
        HTTPException: 503 when no model can be loaded.
    """
    booster, version = _get_active_model()
    features = _build_feature_vector(payload)
    # Booster.predict returns an array; the single input row yields one scalar.
    prediction = booster.predict(features)[0]
    return PredictResponse(fatigue=float(prediction), model_version=version)
| __all__ = ["app"] | |