import io
import json
import logging
import os
import tempfile

import joblib
import numpy as np
import pandas as pd
import requests
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from tremor_analysis_functions import extract_essential_features

# =====================================================
# CONFIG
# =====================================================
MODEL_REPO = "Chula-PD/tremor-post"  # 👈 replace with the actual repo name
MODEL_FILE = "tremor_rf_model.joblib"
MODEL_URL = f"https://huggingface.co/{MODEL_REPO}/resolve/main/{MODEL_FILE}"
# (connect, read) timeouts in seconds for the model download. Without a
# timeout, a stalled connection would hang application startup indefinitely.
DOWNLOAD_TIMEOUT = (10, 120)

logger = logging.getLogger(__name__)

# =====================================================
# INIT FastAPI
# =====================================================
app = FastAPI(title="CheckPD Tremor API", version="1.0")

# Allow CORS so browser frontends (e.g. React or Streamlit) can call the API.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True lets
# any site send credentialed requests (Starlette echoes the request Origin in
# this configuration). Nothing visible here uses cookies/auth, so consider
# allow_credentials=False or an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# =====================================================
# LOAD MODEL
# =====================================================
def load_model():
    """Load the joblib model bundle, downloading it from Hugging Face if absent.

    The file is cached at MODEL_FILE (relative to the current working
    directory). The download goes to a temporary file that is renamed into
    place only after it completes. A failed or partial download therefore
    never leaves a corrupt MODEL_FILE behind, which would otherwise be reused
    (and fail to load) on every later start.

    Returns:
        Whatever object the joblib file contains. The module code below
        indexes it with the keys "model", "scaler" and "features".

    Raises:
        requests.HTTPError: the download returned a non-2xx status.
        requests.RequestException: network failure or timeout.
    """
    if not os.path.exists(MODEL_FILE):
        print("⬇️ Downloading model from Hugging Face...")
        with requests.get(MODEL_URL, stream=True, timeout=DOWNLOAD_TIMEOUT) as r:
            # Without this, an HTML error page (404/5xx) would be saved as the model.
            r.raise_for_status()
            # Same directory as the target, so os.replace is an atomic rename.
            fd, tmp_path = tempfile.mkstemp(
                dir=os.path.dirname(os.path.abspath(MODEL_FILE)), suffix=".part"
            )
            try:
                with os.fdopen(fd, "wb") as f:
                    for chunk in r.iter_content(chunk_size=1024 * 1024):
                        f.write(chunk)
                os.replace(tmp_path, MODEL_FILE)
            except BaseException:
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
                raise
    # NOTE(review): joblib.load unpickles the file, which can execute arbitrary
    # code. Only load models from a repository you trust.
    model_dict = joblib.load(MODEL_FILE)
    print("✅ Model loaded successfully.")
    return model_dict


model_dict = load_model()
model = model_dict["model"]
scaler = model_dict["scaler"]
features = model_dict["features"]


# =====================================================
# HELPER: JSON Preprocessing
# =====================================================
def preprocess_json(json_data):
    """Convert a mobile-app recording JSON into a model-ready scaled feature row.

    The recording is read from ``json_data["recording"]`` or
    ``json_data["data"]["recording"]``. Its ``recordedData`` entries' ``"data"``
    values become DataFrame rows, and ``recordingFormat`` supplies the column
    names.

    Args:
        json_data: The decoded JSON payload.

    Returns:
        The output of ``scaler.transform`` for a single row whose columns are
        reordered to ``features``.

    Raises:
        ValueError: the payload is not shaped as described above, or the
            recording is empty.
    """
    if not isinstance(json_data, dict):
        raise ValueError("Invalid JSON format: expected a JSON object")
    if "recording" in json_data:
        rec = json_data["recording"]
    elif isinstance(json_data.get("data"), dict) and "recording" in json_data["data"]:
        rec = json_data["data"]["recording"]
    else:
        raise ValueError("Invalid JSON format: missing 'recording' field")
    if not isinstance(rec, dict):
        raise ValueError("Invalid JSON format: 'recording' must be an object")

    records = rec.get("recordedData", [])
    fmt = rec.get("recordingFormat", [])
    if not records or not fmt:
        raise ValueError("Incomplete recording data")

    try:
        rows = [r["data"] for r in records]
    except (KeyError, TypeError) as err:
        raise ValueError(
            "Invalid recording data: every 'recordedData' entry needs a 'data' field"
        ) from err

    df = pd.DataFrame(rows, columns=fmt)
    # Placeholder columns. Presumably extract_essential_features expects them
    # (they are dropped again below); verify against that function.
    df["label"] = "unknown"
    df["file"] = "uploaded"

    feats = extract_essential_features(df)
    feat_df = pd.DataFrame([feats]).drop(columns=["label", "file"], errors="ignore")

    # Align column order with training. Features the extractor did not produce
    # are silently filled with 0.
    X = feat_df.reindex(columns=features, fill_value=0)
    X_scaled = scaler.transform(X)
    return X_scaled


# =====================================================
# ENDPOINTS
# =====================================================
@app.get("/")
def home():
    """Liveness check."""
    return {"message": "CheckPD Tremor API is running 🚀"}


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Receive a recording JSON file from the UI and classify it as PD or Normal.

    Returns:
        A dict with these keys:
        - "prediction": "PD" if the model predicts class 1, else "Normal".
        - "probability_pd": column 1 of predict_proba, rounded to 4 places.
        - "file_name": the uploaded file's name.

    Raises:
        HTTPException(400): the upload is not UTF-8 JSON, or preprocessing
            raised ValueError (missing or invalid recording data).
        HTTPException(500): reading the upload, feature extraction or model
            inference failed for any other reason.
    """
    # Input phase. UnicodeDecodeError and json.JSONDecodeError are ValueError
    # subclasses, as are the format errors raised by preprocess_json. These are
    # client errors (400), not server errors. "utf-8-sig" also accepts a
    # leading BOM, which json.loads would otherwise reject.
    try:
        contents = await file.read()
        json_data = json.loads(contents.decode("utf-8-sig"))
        X_scaled = preprocess_json(json_data)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception("Failed to preprocess upload %r", file.filename)
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Inference phase: any failure here is a server-side problem.
    try:
        y_pred = model.predict(X_scaled)[0]
        # NOTE(review): assumes predict_proba column 1 is the PD class (true
        # when model.classes_ == [0, 1]). Confirm against the training labels.
        y_proba = model.predict_proba(X_scaled)[0][1]
    except Exception as e:
        logger.exception("Model inference failed for %r", file.filename)
        raise HTTPException(status_code=500, detail=str(e)) from e

    result = {
        "prediction": "PD" if y_pred == 1 else "Normal",
        "probability_pd": round(float(y_proba), 4),
        "file_name": file.filename,
    }
    return result