Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, UploadFile, File, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| import joblib, requests, os, json, io, tempfile | |
| import pandas as pd | |
| import numpy as np | |
| from tremor_analysis_functions import extract_essential_features | |
# =====================================================
# CONFIG
# =====================================================
# Hugging Face model repository. NOTE: set this to the actual model repo —
# preferably via the MODEL_REPO environment variable rather than editing code.
MODEL_REPO = os.environ.get("MODEL_REPO", "Chula-PD/tremor-post")
# Artifact name in the repo; also used as the local cache path.
MODEL_FILE = "tremor_rf_model.joblib"
MODEL_URL = f"https://huggingface.co/{MODEL_REPO}/resolve/main/{MODEL_FILE}"
# =====================================================
# INIT FastAPI
# =====================================================
app = FastAPI(title="CheckPD Tremor API", version="1.0")
# Allow CORS so a React or Streamlit frontend on another origin can call the API.
# allow_credentials is False on purpose: the endpoints in this file read no
# cookies or auth headers, and credentials may not be combined with a wildcard
# origin ("*"). Starlette would otherwise reflect the caller's Origin, letting
# any website make credentialed cross-origin requests.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# =====================================================
# LOAD MODEL
# =====================================================
def load_model():
    """Load the joblib model bundle, downloading it from Hugging Face if absent.

    The artifact is cached at ``MODEL_FILE`` and downloaded from ``MODEL_URL``
    only when that path does not exist yet.

    Returns:
        The object deserialized by ``joblib.load``; the module-level code that
        calls this indexes it with "model", "scaler" and "features".

    Raises:
        requests.HTTPError: if the download responds with an error status.
        requests.RequestException: on network failure or timeout.
    """
    if not os.path.exists(MODEL_FILE):
        print("⬇️ Downloading model from Hugging Face...")
        # Download into a temp file beside the target and rename it into place
        # only on success, so neither an HTTP error page nor a truncated file
        # is ever cached as MODEL_FILE (which would break every later start).
        target_dir = os.path.dirname(os.path.abspath(MODEL_FILE))
        fd, tmp_path = tempfile.mkstemp(dir=target_dir, suffix=".part")
        try:
            with os.fdopen(fd, "wb") as f, requests.get(
                MODEL_URL, stream=True, timeout=60
            ) as r:
                r.raise_for_status()
                for chunk in r.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
            os.replace(tmp_path, MODEL_FILE)
        except BaseException:
            # Remove the partial download, then propagate the original error.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise
    # SECURITY NOTE: joblib.load unpickles the file and can execute arbitrary
    # code; only load artifacts from a repository you trust.
    model_dict = joblib.load(MODEL_FILE)
    print("✅ Model loaded successfully.")
    return model_dict
# Load the bundle once at import time, then pull out its three components
# (accessed in the order model -> scaler -> features).
model_dict = load_model()
model, scaler, features = (model_dict[key] for key in ("model", "scaler", "features"))
# =====================================================
# HELPER: JSON Preprocessing
# =====================================================
def preprocess_json(json_data):
    """Convert a recording JSON from the mobile app into a scaled feature vector.

    Accepted layouts are ``{"recording": {...}}`` and
    ``{"data": {"recording": {...}}}``. The recording object must contain
    ``recordedData`` (a non-empty list of objects, each with a ``data`` row)
    and ``recordingFormat`` (the non-empty list of column names for those rows).

    Returns:
        The result of ``scaler.transform`` on a single-row frame whose columns
        follow the model's trained ``features`` order.

    Raises:
        ValueError: if the payload does not match the layout above.
    """
    # Guard the type first: for a JSON string, `"recording" in json_data`
    # would be a substring test, and for numbers/null it would raise TypeError.
    if not isinstance(json_data, dict):
        raise ValueError("Invalid JSON format: top-level value must be an object")
    if "recording" in json_data:
        rec = json_data["recording"]
    elif isinstance(json_data.get("data"), dict) and "recording" in json_data["data"]:
        rec = json_data["data"]["recording"]
    else:
        raise ValueError("Invalid JSON format: missing 'recording' field")
    if not isinstance(rec, dict):
        raise ValueError("Invalid JSON format: 'recording' must be an object")

    records = rec.get("recordedData", [])
    fmt = rec.get("recordingFormat", [])
    if not records or not fmt:
        raise ValueError("Incomplete recording data")

    try:
        rows = [r["data"] for r in records]
    except (KeyError, TypeError) as e:
        raise ValueError(
            "Invalid JSON format: each 'recordedData' entry needs a 'data' field"
        ) from e
    df = pd.DataFrame(rows, columns=fmt)
    # Placeholder metadata columns, dropped again after feature extraction;
    # presumably expected by extract_essential_features — TODO confirm.
    df["label"] = "unknown"
    df["file"] = "uploaded"

    feats = extract_essential_features(df)
    feat_df = pd.DataFrame([feats]).drop(columns=["label", "file"], errors="ignore")
    # Align to the training column order; features missing here are filled
    # with 0. NOTE(review): the silent 0-fill can hide a feature-name mismatch
    # between extract_essential_features and the trained model.
    X = feat_df.reindex(columns=features, fill_value=0)
    return scaler.transform(X)
# =====================================================
# ENDPOINTS
# =====================================================
# Without the route decorator FastAPI never registers this handler.
@app.get("/")
def home():
    """Return a static message confirming the API process is running."""
    return {"message": "CheckPD Tremor API is running 🚀"}
# Without the route decorator FastAPI never registers this handler.
# NOTE(review): "/predict" path assumed — confirm it matches the frontend.
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded recording JSON file as PD or Normal.

    Returns:
        A dict with ``prediction`` ("PD" or "Normal"), ``probability_pd``
        (rounded to 4 decimals) and ``file_name`` (the uploaded file's name).

    Raises:
        HTTPException: 400 when the upload is not UTF-8 JSON or does not match
            the expected recording layout; 500 for any other failure.
    """
    # Stage 1: parse and featurize the upload. UnicodeDecodeError and
    # json.JSONDecodeError are ValueError subclasses, as are the format errors
    # preprocess_json raises, so ValueError here means bad client input.
    try:
        contents = await file.read()
        json_data = json.loads(contents.decode("utf-8"))
        X_scaled = preprocess_json(json_data)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Stage 2: inference. Failures here are server-side.
    try:
        y_pred = model.predict(X_scaled)[0]
        # NOTE(review): assumes predict_proba column 1 is the PD class
        # (model.classes_ == [0, 1]) — confirm against the training code.
        y_proba = model.predict_proba(X_scaled)[0][1]
        return {
            "prediction": "PD" if y_pred == 1 else "Normal",
            "probability_pd": round(float(y_proba), 4),
            "file_name": file.filename,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e