import json
import os
from datetime import datetime
from typing import List, Optional

import pandas as pd
from datasets import Dataset, DatasetDict, load_dataset
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

# Local development: load a .env file if one is present
load_dotenv()

# Hugging Face settings
HF_DATA_REPO_ID = os.getenv("HF_DATA_REPO_ID")
HF_DATA_TOKEN = os.getenv("HF_DATA_TOKEN")

app = FastAPI(title="MuscleCare FastAPI Server")


# ----- Models -----
class DatasetItem(BaseModel):
    user_id: str
    session_id: Optional[str] = None
    window_id: int
    window_start_ms: int
    window_end_ms: int
    timestamp_utc: Optional[str] = None
    acc_x_mean: Optional[float] = None
    acc_y_mean: Optional[float] = None
    acc_z_mean: Optional[float] = None
    gyro_x_mean: Optional[float] = None
    gyro_y_mean: Optional[float] = None
    gyro_z_mean: Optional[float] = None
    linacc_x_mean: Optional[float] = None
    linacc_y_mean: Optional[float] = None
    linacc_z_mean: Optional[float] = None
    gravity_x_mean: Optional[float] = None
    gravity_y_mean: Optional[float] = None
    gravity_z_mean: Optional[float] = None
    acc_x_std: Optional[float] = None
    acc_y_std: Optional[float] = None
    acc_z_std: Optional[float] = None
    gyro_x_std: Optional[float] = None
    gyro_y_std: Optional[float] = None
    gyro_z_std: Optional[float] = None
    rms_acc: Optional[float] = None
    rms_gyro: Optional[float] = None
    mean_freq_acc: Optional[float] = None
    mean_freq_gyro: Optional[float] = None
    entropy_acc: Optional[float] = None
    entropy_gyro: Optional[float] = None
    jerk_mean: Optional[float] = None
    jerk_std: Optional[float] = None
    stability_index: Optional[float] = None
    rms_base: Optional[float] = None
    freq_base: Optional[float] = None
    user_emb: Optional[List[float]] = Field(default=None, description="length=12 vector")
    fatigue_prev: Optional[float] = None
    fatigue: Optional[float] = None
    fatigue_level: Optional[int] = None
    quality_flag: Optional[int] = 1
    window_size_ms: Optional[int] = 2000
    overlap_rate: Optional[float] = 0.5


class DatasetBatchPayload(BaseModel):
    batch_data: List[DatasetItem]


# ----- Endpoints -----
@app.get("/")
def root():
    """Root endpoint - server status check."""
    return {
        "status": "running",
        "message": "MuscleCare API Server",
        "version": "1.0.0",
        "endpoints": {
            "health": "/health (quick check)",
            "docs": "/docs",
            "upload_dataset": "/upload_dataset (batch data upload)",
            "user_dataset": "/user_dataset/{user_id}",
        },
    }


@app.head("/health")
async def health_head():
    return None  # HEAD responses need no body, so return None


@app.get("/health")
def health():
    try:
        # Lightweight health check - report server status only, no DB connection
        return {
            "ok": True,
            "server": "running",
            "timestamp": datetime.now().isoformat(),
            "status": "healthy",
        }
    except Exception as e:
        return {"ok": False, "error": str(e)}


@app.post("/upload_dataset")
async def upload_dataset(payload: DatasetBatchPayload):
    """Push a batch of dataset rows to the Hugging Face Hub (with schema normalization)."""
    try:
        hf_repo_id = os.getenv("HF_DATA_REPO_ID")
        hf_token = os.getenv("HF_DATA_TOKEN")
        if not hf_repo_id or not hf_token:
            raise HTTPException(
                status_code=500,
                detail="Hugging Face configuration is required (HF_DATA_REPO_ID, HF_DATA_TOKEN)",
            )

        # Target schema definition
        target_cols = [
            "user_id", "session_id", "window_id", "window_start_ms", "window_end_ms",
            "timestamp_utc",
            "acc_x_mean", "acc_y_mean", "acc_z_mean",
            "gyro_x_mean", "gyro_y_mean", "gyro_z_mean",
            "linacc_x_mean", "linacc_y_mean", "linacc_z_mean",
            "gravity_x_mean", "gravity_y_mean", "gravity_z_mean",
            "acc_x_std", "acc_y_std", "acc_z_std",
            "gyro_x_std", "gyro_y_std", "gyro_z_std",
            "rms_acc", "rms_gyro", "mean_freq_acc", "mean_freq_gyro",
            "entropy_acc", "entropy_gyro", "jerk_mean", "jerk_std",
            "stability_index", "rms_base", "freq_base", "user_emb",
            "fatigue_prev", "fatigue", "fatigue_level",
            "quality_flag", "window_size_ms", "overlap_rate",
        ]

        # Load existing data
        try:
            existing = load_dataset(hf_repo_id, token=hf_token)
            print("📂 Loaded existing DatasetDict")
        except Exception:
            existing = DatasetDict()
            print("📂 No existing repo → creating a new one")

        # Normalize an existing split's schema: drop extra columns, add missing ones
        def normalize_existing_df(df: pd.DataFrame) -> pd.DataFrame:
            # If user_emb was stored as a string, try to parse it back into a list
            if "user_emb" in df.columns:
                def _parse_emb(x):
                    if isinstance(x, list):
                        return x
                    if isinstance(x, str):
                        try:
                            v = json.loads(x)
                            return v if isinstance(v, list) else []
                        except Exception:
                            return []
                    return []

                df["user_emb"] = df["user_emb"].apply(_parse_emb)

            # Add a timestamp column if it is missing or entirely null
            if "timestamp_utc" not in df.columns or df["timestamp_utc"].isnull().all():
                df["timestamp_utc"] = datetime.now().isoformat()

            # Add any missing target columns
            for c in target_cols:
                if c not in df.columns:
                    df[c] = None

            # Drop extra columns and fix the column order
            df = df[target_cols]
            return df

        # Group the payload by user
        user_groups: dict[str, list[dict]] = {}
        for item in payload.batch_data:
            rec = item.model_dump()
            if not rec.get("timestamp_utc"):
                rec["timestamp_utc"] = datetime.now().isoformat()
            user_groups.setdefault(item.user_id, []).append(rec)

        results = {}

        # Merge per user
        for user_id, records in user_groups.items():
            try:
                new_df = pd.DataFrame(records)
                # Align the new DataFrame to the target schema as well
                for c in target_cols:
                    if c not in new_df.columns:
                        new_df[c] = None
                new_df = new_df[target_cols]

                if user_id in existing:
                    old_df = existing[user_id].to_pandas()
                    old_df = normalize_existing_df(old_df)
                    merged = pd.concat([old_df, new_df], ignore_index=True)
                    existing[user_id] = df_to_dataset(merged)
                    print(f"📊 {user_id}: merged ({len(old_df)} + {len(new_df)} = {len(merged)})")
                else:
                    existing[user_id] = df_to_dataset(new_df)
                    print(f"📊 {user_id}: created new split ({len(new_df)})")

                results[user_id] = {"status": "success", "new_rows": len(records)}
            except Exception as e:
                print(f"❌ Failed to process {user_id}: {e}")
                results[user_id] = {"status": "failed", "error": str(e)}

        # Push everything to the Hub
        try:
            existing.push_to_hub(hf_repo_id, token=hf_token, private=True)
            print(f"✅ Pushed DatasetDict: {len(existing)} users")
        except Exception as e:
            print(f"❌ Full push failed: {e}")
            raise HTTPException(status_code=500, detail=f"Full push failed: {str(e)}")

        return {
            "processed_users": len(user_groups),
            "total_rows": sum(len(v) for v in user_groups.values()),
            "results": results,
        }
    except HTTPException:
        raise
    except Exception as e:
        print(f"❌ Batch dataset upload failed: {e}")
        raise HTTPException(status_code=500, detail=f"Batch dataset upload failed: {str(e)}")


def df_to_dataset(df: pd.DataFrame) -> Dataset:
    """Convert a DataFrame into a Dataset."""
    return Dataset.from_pandas(df)