"""CheckPD Voice API.

FastAPI service that accepts an uploaded audio file, renders its
mel-spectrogram as an RGB image, and classifies that image with a
MobileNetV3-Small network into one of the labels in ``classes``.
"""

import contextlib
import io
import os
import tempfile

import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np
import requests
import soundfile as sf
import torch
import torch.nn.functional as F
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from pydantic import BaseModel
from torchvision import models, transforms

from feature_extract import AudioFeatureExtractor

# === CONFIG ===
MODEL_REPO = "Chula-PD/voice-mobilenet-pd"
MODEL_FILE = "MobileNet_Model.pth"
MODEL_URL = f"https://huggingface.co/{MODEL_REPO}/resolve/main/{MODEL_FILE}"
# Seconds to wait on the weights download before giving up, so a stalled
# connection cannot hang application start-up forever.
DOWNLOAD_TIMEOUT_S = 60
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# === FastAPI Init ===
app = FastAPI(title="CheckPD Voice API", version="1.0")

# Allow CORS (for the React frontend).
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; restrict allow_origins to the real frontend domain(s) later.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # can be narrowed to specific domains later
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# === Load Model ===
def load_model():
    """Build the classifier and load its weights, downloading them if absent.

    If ``MODEL_FILE`` does not exist locally it is fetched from ``MODEL_URL``.
    The download is written to a ``.part`` file and renamed only after it
    completes, so a failed or interrupted download never leaves a corrupt
    ``MODEL_FILE`` behind to be picked up on the next start.

    Returns:
        A MobileNetV3-Small network with a 2-output final layer, placed on
        ``device`` and switched to eval mode.

    Raises:
        requests.RequestException: if the download fails, times out, or the
            server answers with an HTTP error status.
    """
    if not os.path.exists(MODEL_FILE):
        print("Downloading model weights from Hugging Face...")
        response = requests.get(MODEL_URL, timeout=DOWNLOAD_TIMEOUT_S)
        # Without this an HTML error page would be saved as the weights file.
        response.raise_for_status()
        partial_path = f"{MODEL_FILE}.part"
        with open(partial_path, "wb") as f:
            f.write(response.content)
        os.replace(partial_path, MODEL_FILE)

    net = models.mobilenet_v3_small(weights=None)
    in_features = net.classifier[-1].in_features
    net.classifier[-1] = torch.nn.Linear(in_features, 2)
    # NOTE(review): torch.load unpickles the downloaded file; consider
    # weights_only=True if the installed torch version supports it.
    net.load_state_dict(torch.load(MODEL_FILE, map_location=device))
    # load_state_dict copies values into the existing (CPU) parameters, so the
    # network must be moved explicitly to match the inputs sent to `device`.
    net.to(device)
    net.eval()
    return net


model = load_model()
classes = ["HC", "PD"]

# === Image Transform ===
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        [0.485, 0.456, 0.406],
        [0.229, 0.224, 0.225],
    ),
])


def _melspectrogram_to_image(mel_db):
    """Render a mel-spectrogram as an axis-free RGB PIL image.

    The figure is always closed, even when plotting or saving fails, so a
    long-running server does not accumulate open matplotlib figures.
    """
    fig, ax = plt.subplots(figsize=(6, 3))
    try:
        # NOTE(review): hop_length=51 looks like it may be a typo for 512;
        # confirm against the hop length used by AudioFeatureExtractor.
        librosa.display.specshow(
            mel_db, sr=16000, hop_length=51, cmap="viridis", ax=ax
        )
        ax.axis("off")
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
    finally:
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf).convert("RGB")


@app.get("/")
def home():
    """Health-check endpoint confirming the API is up."""
    return {"message": "CheckPD Voice API is running."}


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded audio recording.

    Args:
        file: The uploaded audio file; its bytes are written to a temporary
            ``.wav`` file that is handed to ``AudioFeatureExtractor``.

    Returns:
        A dict with ``"label"`` (an entry of ``classes``) and ``"confidence"``
        (softmax probability of that label, rounded to 4 decimals).

    Raises:
        HTTPException: 400 if the upload is empty; 500 if feature extraction,
            rendering, or inference fails.
    """
    wav_path = None
    try:
        audio_bytes = await file.read()
        if not audio_bytes:
            raise HTTPException(status_code=400, detail="Uploaded file is empty.")

        # delete=False so the file can be reopened by path after it is closed
        # (required on Windows); it is removed in the `finally` clause below.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            wav_path = tmp.name
            tmp.write(audio_bytes)

        # Load and preprocess audio
        extractor = AudioFeatureExtractor(wav_path, sr=16000)
        mel_db = extractor.get_melspectrogram()

        # Convert mel to image
        image = _melspectrogram_to_image(mel_db)

        # Predict
        input_tensor = transform(image).unsqueeze(0).to(device)
        with torch.no_grad():
            outputs = model(input_tensor)
            probs = F.softmax(outputs, dim=1)
        pred_idx = torch.argmax(probs, dim=1).item()
        confidence = probs[0][pred_idx].item()

        return {
            "label": classes[pred_idx],
            "confidence": round(confidence, 4),
        }
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 above) must not become 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Always remove the temporary upload so requests do not leak files.
        if wav_path is not None:
            with contextlib.suppress(OSError):
                os.remove(wav_path)