"""FastAPI 앱: 수동 학습 및 Hugging Face 업로드 트리거""" from __future__ import annotations import json import os import threading import time from pathlib import Path from typing import Any, Dict, Optional import schedule from fastapi import FastAPI, HTTPException from fastapi.responses import FileResponse from huggingface_hub import HfApi, hf_hub_download try: from huggingface_hub.utils import HfHubHTTPError except ImportError: # pragma: no cover HfHubHTTPError = Exception # type: ignore from pydantic import BaseModel from train_scheduler import TrainingScheduler app = FastAPI( title="MuscleCare Train Scheduler API", description="수동으로 모델 학습 및 Hugging Face 업로드를 트리거합니다.", ) _scheduler = TrainingScheduler() class TrainResponse(BaseModel): status: str new_data_count: int model_path: Optional[str] = None hub_url: Optional[str] = None model_version: Optional[int] = None message: str @app.on_event("startup") def startup_training() -> None: """서버 시작 시 자동으로 모델 학습을 실행합니다.""" try: print("🚀 서버 시작: 자동 모델 학습을 시작합니다...") result = _scheduler.run_scheduled_training() if result["status"] == "trained": print(f"✅ 서버 시작 시 학습 완료: {result['new_data_count']}개 데이터로 학습됨") else: print(f"ℹ️ 서버 시작 시 학습 건너뜀: {result.get('message', '새로운 데이터 없음')}") except Exception as exc: print(f"⚠️ 서버 시작 시 학습 실패: {exc}") # 기존 스케줄링 설정 schedule.clear() schedule.every().sunday.at("00:00").do(_scheduler.run_scheduled_training) def _run_schedule() -> None: while True: schedule.run_pending() time.sleep(60) threading.Thread(target=_run_schedule, daemon=True).start() @app.head("/health") async def health_head(): return None # HEAD는 바디가 필요 없으므로 None 반환 @app.get("/health") def health_check() -> dict: return {"status": "ok"} @app.get("/") def root() -> dict: return { "message": "MuscleCare Train Scheduler API가 실행 중입니다.", "endpoints": { "health": "/health", "trigger": "/trigger", }, "docs": "/docs", } def _upload_to_hub(model_path: str) -> Optional[str]: token = os.getenv("HF_E2E_MODEL_TOKEN") repo_id = os.getenv("HF_E2E_MODEL_REPO_ID") if not token or not repo_id: raise HTTPException( status_code=400, detail="환경 변수 HF_E2E_MODEL_TOKEN / HF_E2E_MODEL_REPO_ID가 설정되어 있지 않습니다.", ) path = Path(model_path) if not path.exists(): raise HTTPException(status_code=404, detail=f"모델 파일을 찾을 수 없습니다: {model_path}") api = HfApi(token=token) api.create_repo(repo_id=repo_id, repo_type="model", private=False, exist_ok=True) api.upload_file( path_or_fileobj=path, path_in_repo=path.name, repo_id=repo_id, repo_type="model", commit_message="Manual scheduler trigger upload", ) return f"https://huggingface.co/{repo_id}" # TODO: include version info in response body @app.get("/model") @app.get("/model/{version:int}") def download_model( version: Optional[int] = None, filename: Optional[str] = None ) -> FileResponse: repo_id = os.getenv("HF_E2E_MODEL_REPO_ID") token = os.getenv("HF_E2E_MODEL_TOKEN") default_filename = os.getenv("HF_E2E_MODEL_FILE", "cnn_gru_fatigue.tflite") if not repo_id: raise HTTPException( status_code=400, detail="환경 변수 HF_E2E_MODEL_REPO_ID가 설정되어 있지 않습니다." ) current_state = _scheduler.load_training_state() current_version = int(current_state.get("model_version", 0) or 0) try: if not version: target_filename = filename or default_filename local_path = hf_hub_download( repo_id=repo_id, filename=target_filename, repo_type="model", token=token, local_dir="./model_cache", local_dir_use_symlinks=False, ) actual_version = current_version else: if version > current_version: raise HTTPException( status_code=404, detail=f"현재 모델 버전은 {current_version}입니다. 버전 {version}은 존재하지 않습니다." ) manifest_path = hf_hub_download( repo_id=repo_id, filename="model_versions.json", repo_type="model", token=token, local_dir="./model_cache", local_dir_use_symlinks=False, ) with open(manifest_path, "r", encoding="utf-8") as f: manifest = json.load(f) version_entry = next( (entry for entry in manifest if entry.get("version") == version), None ) if version_entry is None: raise HTTPException( status_code=404, detail=f"버전 {version}에 해당하는 모델을 찾을 수 없습니다." ) target_filename = filename or version_entry.get("filename") target_revision = version_entry.get("commit") if not target_filename or not target_revision: raise HTTPException( status_code=500, detail=f"버전 {version} 메타데이터가 올바르지 않습니다." ) local_path = hf_hub_download( repo_id=repo_id, filename=target_filename, repo_type="model", token=token, local_dir="./model_cache", local_dir_use_symlinks=False, revision=target_revision, ) actual_version = version except Exception as exc: status = getattr(getattr(exc, "response", None), "status_code", None) if status == 404: raise HTTPException( status_code=404, detail="허깅페이스에서 지정한 모델 파일을 찾을 수 없습니다." ) from exc raise HTTPException( status_code=500, detail=f"Hugging Face Hub 다운로드 실패: {exc}" ) from exc response = FileResponse( path=local_path, filename=Path(target_filename).name, media_type="application/octet-stream" ) response.headers["X-Model-Version"] = str(actual_version) response.headers["X-Model-Filename"] = Path(target_filename).name return response class ResetStateResponse(BaseModel): status: str state: Dict[str, Any] @app.post("/state/reset", response_model=ResetStateResponse) def reset_training_state() -> ResetStateResponse: try: state = _scheduler.reset_training_state() return ResetStateResponse( status="reset", state=state, ) except Exception as exc: # pylint: disable=broad-except raise HTTPException(status_code=500, detail=f"학습 상태 초기화에 실패했습니다: {exc}") from exc @app.post("/trigger", response_model=TrainResponse) def trigger_training(upload: bool = True) -> TrainResponse: try: result = _scheduler.run_scheduled_training() except Exception as exc: # pylint: disable=broad-except raise HTTPException(status_code=500, detail=f"학습 실행 중 오류가 발생했습니다: {exc}") from exc message = "새로운 데이터가 없어 학습을 건너뜁니다." hub_url = None if result["status"] == "trained": message = "모델 학습이 완료되었습니다." model_path = result.get("model_path") if upload and model_path: try: hub_url = _upload_to_hub(model_path) message = "모델 학습 및 업로드가 완료되었습니다." except HTTPException: raise except Exception as exc: # pylint: disable=broad-except raise HTTPException(status_code=500, detail=f"Hugging Face 업로드 실패: {exc}") from exc return TrainResponse( status=result["status"], new_data_count=result["new_data_count"], model_path=result.get("model_path"), hub_url=hub_url, message=message, ) __all__ = ["app"]