# Hugging Face file-viewer residue (not part of the program):
# uploader avatar "Merry99's picture" | commit 93e58df "change dataset config" | 8.76 kB
import json
import os
from datetime import datetime, timezone
from typing import List, Optional

import pandas as pd
from datasets import Dataset, DatasetDict, load_dataset
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
# Local development: load variables from a .env file (if there is one).
load_dotenv()

# Hugging Face settings, read from the environment at import time.
HF_DATA_REPO_ID = os.environ.get("HF_DATA_REPO_ID")
HF_DATA_TOKEN = os.environ.get("HF_DATA_TOKEN")

app = FastAPI(title="MuscleCare FastAPI Server")
# ----- Models -----
class DatasetItem(BaseModel):
    """One dataset record as accepted by the ``/upload_dataset`` endpoint.

    Only ``user_id``, ``window_id``, ``window_start_ms`` and
    ``window_end_ms`` are required; every other field is optional and
    falls back to the default declared below.
    """

    # Record / window identification.
    id: Optional[int] = None
    user_id: str
    session_id: Optional[str] = None
    window_id: int
    window_start_ms: int
    window_end_ms: int
    timestamp_utc: Optional[str] = None

    # Per-axis ``*_mean`` features.
    acc_x_mean: Optional[float] = None
    acc_y_mean: Optional[float] = None
    acc_z_mean: Optional[float] = None
    gyro_x_mean: Optional[float] = None
    gyro_y_mean: Optional[float] = None
    gyro_z_mean: Optional[float] = None
    linacc_x_mean: Optional[float] = None
    linacc_y_mean: Optional[float] = None
    linacc_z_mean: Optional[float] = None
    gravity_x_mean: Optional[float] = None
    gravity_y_mean: Optional[float] = None
    gravity_z_mean: Optional[float] = None

    # Per-axis ``*_std`` features.
    acc_x_std: Optional[float] = None
    acc_y_std: Optional[float] = None
    acc_z_std: Optional[float] = None
    gyro_x_std: Optional[float] = None
    gyro_y_std: Optional[float] = None
    gyro_z_std: Optional[float] = None

    # Aggregate features.
    rms_acc: Optional[float] = None
    rms_gyro: Optional[float] = None
    mean_freq_acc: Optional[float] = None
    mean_freq_gyro: Optional[float] = None
    entropy_acc: Optional[float] = None
    entropy_gyro: Optional[float] = None
    jerk_mean: Optional[float] = None
    jerk_std: Optional[float] = None
    stability_index: Optional[float] = None

    # Baselines and user embedding.
    # NOTE(review): the description says length 12, but nothing here
    # enforces that length.
    rms_base: Optional[float] = None
    freq_base: Optional[float] = None
    user_emb: Optional[List[float]] = Field(None, description="length=12 vector")

    # Fatigue values.
    fatigue_prev: Optional[float] = None
    fatigue: Optional[float] = None
    fatigue_level: Optional[int] = None

    # Fields with non-None defaults, applied when the client omits them.
    quality_flag: Optional[int] = 1
    window_size_ms: Optional[int] = 2000
    overlap_rate: Optional[float] = 0.5
class DatasetBatchPayload(BaseModel):
    """Request body for a batch upload: a list of ``DatasetItem`` records."""

    batch_data: List[DatasetItem]
# ----- Endpoints -----
@app.get("/")
def root():
    """Root endpoint - report server status and list the available routes."""
    # Route name -> path (with a short hint) advertised to clients.
    endpoint_index = {
        "health": "/health (๋น ๋ฅธ ์ฒดํฌ)",
        "docs": "/docs",
        "upload_dataset": "/upload_dataset (๋ฐฐ์น˜ ๋ฐ์ดํ„ฐ ์—…๋กœ๋“œ)",
        "user_dataset": "/user_dataset/{user_id}",
    }
    info = {
        "status": "running",
        "message": "MuscleCare API Server",
        "version": "1.0.0",
        "endpoints": endpoint_index,
    }
    return info
@app.get("/health")
def health():
    """Return a liveness report for the server.

    On success the response carries ``ok=True`` plus the current
    timestamp; if building the report raises, ``ok=False`` and the error
    text are returned instead.
    """
    try:
        # Simple health check: only the server state is reported, no DB
        # connection is made.
        now_iso = datetime.now().isoformat()
        report = {
            "ok": True,
            "server": "running",
            "timestamp": now_iso,
            "status": "healthy",
        }
        return report
    except Exception as exc:
        return {"ok": False, "error": str(exc)}
# Column set (and order) every per-user dataset is conformed to before the
# push.  Kept as a list because it is used directly as a DataFrame column
# selector.
_TARGET_COLS = [
    "id",
    "user_id",
    "session_id",
    "window_id",
    "window_start_ms",
    "window_end_ms",
    "timestamp_utc",
    "acc_x_mean",
    "acc_y_mean",
    "acc_z_mean",
    "gyro_x_mean",
    "gyro_y_mean",
    "gyro_z_mean",
    "linacc_x_mean",
    "linacc_y_mean",
    "linacc_z_mean",
    "gravity_x_mean",
    "gravity_y_mean",
    "gravity_z_mean",
    "acc_x_std",
    "acc_y_std",
    "acc_z_std",
    "gyro_x_std",
    "gyro_y_std",
    "gyro_z_std",
    "rms_acc",
    "rms_gyro",
    "mean_freq_acc",
    "mean_freq_gyro",
    "entropy_acc",
    "entropy_gyro",
    "jerk_mean",
    "jerk_std",
    "stability_index",
    "rms_base",
    "freq_base",
    "user_emb",
    "fatigue_prev",
    "fatigue",
    "fatigue_level",
    "quality_flag",
    "window_size_ms",
    "overlap_rate",
]


def _utc_now_iso() -> str:
    """Return the current UTC time as a naive ISO-8601 string.

    The value lands in ``timestamp_utc``, so it must really be UTC rather
    than the server's local time.  The tzinfo is stripped afterwards so
    the string keeps the same shape as ``datetime.now().isoformat()``.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()


def _parse_user_emb(value):
    """Coerce one stored ``user_emb`` cell to a plain list.

    Lists and tuples are returned as lists, JSON strings are decoded, and
    array-like cells exposing ``tolist()`` are converted.  Anything else
    (``None``, NaN, unparsable text, non-list JSON) becomes ``[]``.
    """
    if isinstance(value, (list, tuple)):
        return list(value)
    if isinstance(value, str):
        try:
            decoded = json.loads(value)
        except (ValueError, RecursionError):
            return []
        return decoded if isinstance(decoded, list) else []
    # List columns read back through Dataset.to_pandas() arrive as numpy
    # arrays, not lists; without this branch existing embeddings would be
    # replaced by [] on every merge.
    to_list = getattr(value, "tolist", None)
    if callable(to_list):
        converted = to_list()
        return converted if isinstance(converted, list) else []
    return []


def _conform_to_schema(df: pd.DataFrame) -> pd.DataFrame:
    """Add missing target columns (as None) and drop/reorder the rest."""
    for col in _TARGET_COLS:
        if col not in df.columns:
            df[col] = None
    return df[_TARGET_COLS]


def _normalize_existing_df(df: pd.DataFrame) -> pd.DataFrame:
    """Bring a previously stored per-user frame onto the target schema."""
    if "user_emb" in df.columns:
        df["user_emb"] = df["user_emb"].apply(_parse_user_emb)
    # Add a timestamp when the column is missing or entirely null.
    if "timestamp_utc" not in df.columns or df["timestamp_utc"].isnull().all():
        df["timestamp_utc"] = _utc_now_iso()
    return _conform_to_schema(df)


@app.post("/upload_dataset")
async def upload_dataset(payload: DatasetBatchPayload):
    """Merge a batch of records into the Hugging Face Hub dataset.

    Records are grouped by ``user_id``.  Each group is appended to that
    user's existing entry in the Hub ``DatasetDict`` (or creates a new
    entry), after both old and new rows are conformed to the target
    schema.  The whole ``DatasetDict`` is then pushed back to the Hub.

    Args:
        payload: Batch of ``DatasetItem`` records to store.

    Returns:
        A dict with ``processed_users``, ``total_rows`` and a per-user
        ``results`` mapping (``status`` plus ``new_rows`` or ``error``).
        An empty batch returns zero counts without contacting the Hub.

    Raises:
        HTTPException: 500 when the Hub settings are missing, when the
            push fails, or on any other unexpected error.
    """
    try:
        hf_repo_id = os.getenv("HF_DATA_REPO_ID")
        hf_token = os.getenv("HF_DATA_TOKEN")
        if not hf_repo_id or not hf_token:
            raise HTTPException(status_code=500, detail="Hugging Face ์„ค์ •์ด ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค (HF_DATA_REPO_ID, HF_DATA_TOKEN)")

        # Group the payload by user; stamp records that carry no timestamp.
        user_groups: dict[str, list[dict]] = {}
        for item in payload.batch_data:
            rec = item.model_dump()
            if not rec.get("timestamp_utc"):
                rec["timestamp_utc"] = _utc_now_iso()
            user_groups.setdefault(item.user_id, []).append(rec)

        # Nothing to store: skip the Hub download/upload round trip.
        if not user_groups:
            return {"processed_users": 0, "total_rows": 0, "results": {}}

        # Load the existing data.
        # NOTE(review): load_dataset/push_to_hub are blocking calls inside
        # an async handler, so they stall the event loop.  That also
        # happens to serialise the load-merge-push cycle within this
        # process; offloading them to a thread would need an explicit lock
        # to avoid lost updates.
        try:
            existing = load_dataset(hf_repo_id, token=hf_token)
            print("๐Ÿ“‚ ๊ธฐ์กด DatasetDict ๋กœ๋“œ ์™„๋ฃŒ")
        except Exception:
            # NOTE(review): every load failure (not only "repo missing")
            # ends up here and starts from an empty DatasetDict; confirm
            # that a transient auth/network error cannot cause stored data
            # to be replaced by the push below.
            existing = DatasetDict()
            print("๐Ÿ“‚ ๊ธฐ์กด repo ์—†์Œ โ†’ ์ƒˆ๋กœ ์ƒ์„ฑ")

        results = {}
        # Merge per user; one user's failure does not abort the others.
        for user_id, records in user_groups.items():
            try:
                new_df = _conform_to_schema(pd.DataFrame(records))
                if user_id in existing:
                    old_df = _normalize_existing_df(existing[user_id].to_pandas())
                    merged = pd.concat([old_df, new_df], ignore_index=True)
                    existing[user_id] = df_to_dataset(merged)
                    print(f"๐Ÿ“Š {user_id}: ๋ณ‘ํ•ฉ ({len(old_df)} + {len(new_df)} = {len(merged)})")
                else:
                    existing[user_id] = df_to_dataset(new_df)
                    print(f"๐Ÿ“Š {user_id}: ์‹ ๊ทœ ์ƒ์„ฑ ({len(new_df)})")
                results[user_id] = {"status": "success", "new_rows": len(records)}
            except Exception as e:
                print(f"โŒ {user_id} ์ฒ˜๋ฆฌ ์‹คํŒจ: {e}")
                results[user_id] = {"status": "failed", "error": str(e)}

        # Push the whole DatasetDict back to the Hub.
        try:
            existing.push_to_hub(hf_repo_id, token=hf_token, private=True)
            print(f"โœ… DatasetDict ํ‘ธ์‹œ ์™„๋ฃŒ: {len(existing)} users")
        except Exception as e:
            print(f"โŒ ์ „์ฒด ํ‘ธ์‹œ ์‹คํŒจ: {e}")
            raise HTTPException(status_code=500, detail=f"์ „์ฒด ํ‘ธ์‹œ ์‹คํŒจ: {str(e)}") from e

        return {
            "processed_users": len(user_groups),
            "total_rows": sum(len(v) for v in user_groups.values()),
            "results": results,
        }
    except HTTPException:
        # Already carries the intended status and detail.
        raise
    except Exception as e:
        print(f"โŒ ๋ฐฐ์น˜ ๋ฐ์ดํ„ฐ์…‹ ์—…๋กœ๋“œ ์‹คํŒจ: {e}")
        raise HTTPException(status_code=500, detail=f"๋ฐฐ์น˜ ๋ฐ์ดํ„ฐ์…‹ ์—…๋กœ๋“œ ์‹คํŒจ: {str(e)}") from e
def df_to_dataset(df):
    """Convert a pandas DataFrame into a ``Dataset``."""
    converted = Dataset.from_pandas(df)
    return converted