phoner45's picture
Upload 4 files
f2eef96 verified
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import joblib, requests, os, json, io, tempfile
import pandas as pd
import numpy as np
from tremor_analysis_functions import extract_essential_features
# =====================================================
# CONFIG
# =====================================================
# Hugging Face repository hosting the trained model.  👈 change to the actual repo name
MODEL_REPO = "Chula-PD/tremor-post"
# File name of the model inside the repo; also used as the local cache path.
MODEL_FILE = "tremor_rf_model.joblib"
# Download URL for MODEL_FILE at revision "main" of MODEL_REPO.
MODEL_URL = "https://huggingface.co/{}/resolve/main/{}".format(MODEL_REPO, MODEL_FILE)
# =====================================================
# INIT FastAPI
# =====================================================
app = FastAPI(title="CheckPD Tremor API", version="1.0")

# Allow cross-origin requests so a browser frontend (e.g. React or Streamlit)
# can call this API.
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site send credentialed requests; if no cookies/auth are used, consider
# allow_credentials=False or an explicit origin list — confirm with frontend.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
# =====================================================
# LOAD MODEL
# =====================================================
def _download_atomic(url, path, timeout):
    """Stream ``url`` into ``path`` through a temporary file, then rename it into place.

    A failed or interrupted download therefore never leaves a partial or
    garbage file at ``path``. Such a file would otherwise pass the existence
    check in ``load_model`` and be treated as a valid cached model on every
    later start.
    """
    target_dir = os.path.dirname(os.path.abspath(path))
    with requests.get(url, stream=True, timeout=timeout) as resp:
        # Fail loudly on 4xx/5xx instead of saving an error page as the model.
        resp.raise_for_status()
        # Temp file in the same directory so os.replace is a same-filesystem rename.
        fd, tmp_path = tempfile.mkstemp(dir=target_dir, suffix=".part")
        try:
            with os.fdopen(fd, "wb") as f:
                for chunk in resp.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
            os.replace(tmp_path, path)
        except BaseException:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise


def load_model(url=None, path=None, *, timeout=60):
    """Load the joblib model bundle, downloading it from Hugging Face if not cached.

    Args:
        url: Download URL; defaults to the module-level ``MODEL_URL``.
        path: Local cache file; defaults to the module-level ``MODEL_FILE``.
        timeout: Timeout in seconds passed to ``requests.get`` for the download.

    Returns:
        The object stored in the joblib file. Callers in this module index it
        with the keys "model", "scaler" and "features".

    Raises:
        requests.RequestException: the download failed (including HTTP 4xx/5xx
            via ``raise_for_status``); nothing is left at ``path`` in that case.
        Exception: whatever ``joblib.load`` raises if the file cannot be loaded.
    """
    url = MODEL_URL if url is None else url
    path = MODEL_FILE if path is None else path
    if not os.path.exists(path):
        print("⬇️ Downloading model from Hugging Face...")
        _download_atomic(url, path, timeout)
    # SECURITY: joblib.load unpickles the file, which can execute arbitrary
    # code; only point url/path at a trusted source.
    model_dict = joblib.load(path)
    print("✅ Model loaded successfully.")
    return model_dict
model_dict = load_model()
# Unpack the bundle: the classifier (used via predict/predict_proba), the fitted
# scaler (used via transform) and the ordered list of feature column names.
model, scaler, features = (model_dict[key] for key in ("model", "scaler", "features"))
# =====================================================
# HELPER: JSON Preprocessing
# =====================================================
def _find_recording(json_data):
    """Return the ``recording`` object from either supported payload layout.

    Accepts ``{"recording": {...}}`` or ``{"data": {"recording": {...}}}`` and
    raises ValueError for anything else. This includes non-object JSON: a bare
    string would otherwise pass a ``"recording" in ...`` test as a substring
    match and then fail with an unrelated TypeError.
    """
    if not isinstance(json_data, dict):
        raise ValueError("Invalid JSON format: expected a JSON object at the top level")
    if "recording" in json_data:
        rec = json_data["recording"]
    else:
        nested = json_data.get("data")
        if not (isinstance(nested, dict) and "recording" in nested):
            raise ValueError("Invalid JSON format: missing 'recording' field")
        rec = nested["recording"]
    if not isinstance(rec, dict):
        raise ValueError("Invalid JSON format: 'recording' must be a JSON object")
    return rec


def preprocess_json(json_data):
    """Convert a mobile-app recording JSON payload into a scaled feature matrix.

    Args:
        json_data: Parsed JSON payload. The recording object must be at
            ``json_data["recording"]`` or ``json_data["data"]["recording"]``.
            It must contain ``recordedData`` (a list of entries, each with a
            ``data`` row) and ``recordingFormat`` (the column names for those rows).

    Returns:
        ``scaler.transform`` applied to a one-row DataFrame of extracted
        features whose columns follow the module-level ``features`` order.

    Raises:
        ValueError: the payload lacks the structure above or the recording is
            empty. Errors from pandas, the feature extractor or the scaler
            propagate unchanged.
    """
    rec = _find_recording(json_data)
    records = rec.get("recordedData", [])
    fmt = rec.get("recordingFormat", [])
    if not records or not fmt:
        raise ValueError("Incomplete recording data")
    try:
        rows = [r["data"] for r in records]
    except (KeyError, TypeError) as err:
        raise ValueError(
            "Invalid recording: each item in 'recordedData' must be an object with a 'data' field"
        ) from err
    df = pd.DataFrame(rows, columns=fmt)
    # Placeholder metadata columns; presumably extract_essential_features
    # expects them to exist — TODO confirm.
    df["label"] = "unknown"
    df["file"] = "uploaded"
    feats = extract_essential_features(df)
    # Metadata columns are not model inputs; drop them if passed through.
    feat_df = pd.DataFrame([feats]).drop(columns=["label", "file"], errors="ignore")
    # Align to the model's feature order. Features missing from the extractor
    # output are filled with 0, and columns not in `features` are dropped.
    X = feat_df.reindex(columns=features, fill_value=0)
    return scaler.transform(X)
# =====================================================
# ENDPOINTS
# =====================================================
@app.get("/")
def home():
    """Root endpoint: return a fixed JSON message saying the API is running."""
    payload = {"message": "CheckPD Tremor API is running 🚀"}
    return payload
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded recording JSON file from the UI as "PD" or "Normal".

    Returns:
        dict with "prediction" ("PD" when the model predicts 1, else "Normal"),
        "probability_pd" (rounded to 4 decimals) and "file_name" (the upload's
        filename).

    Raises:
        HTTPException(400): the upload is not UTF-8 JSON, or preprocess_json
            rejects it with a ValueError. This is a bad request, not a server fault.
        HTTPException(500): any other failure (reading the upload, inference).
    """
    try:
        contents = await file.read()
        try:
            json_data = json.loads(contents.decode("utf-8"))
            X_scaled = preprocess_json(json_data)
        except ValueError as e:
            # UnicodeDecodeError and json.JSONDecodeError both subclass
            # ValueError, as do the structural errors from preprocess_json.
            raise HTTPException(status_code=400, detail=str(e)) from e
        y_pred = model.predict(X_scaled)[0]
        # Column 1 of predict_proba is reported as the PD probability;
        # presumably model.classes_ == [0, 1] with 1 = PD — TODO confirm.
        y_proba = model.predict_proba(X_scaled)[0][1]
        result = {
            "prediction": "PD" if y_pred == 1 else "Normal",
            "probability_pd": round(float(y_proba), 4),
            "file_name": file.filename,
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return result