# api.py — CardioScreen AI backend (uploaded via huggingface_hub, commit bbcbd17)
"""
CardioScreen AI — FastAPI Backend
Serves the local AI inference engine for canine cardiac screening.
"""
import os
import sys
from contextlib import asynccontextmanager
from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
from inference import (
predict_audio,
_load_cnn_model, _load_finetuned_model, _load_resnet_model, _load_gru_model,
_cnn_available
)
# Directory next to this file where all model checkpoints are stored.
WEIGHTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "weights")
# All 4 model weight files to download from HF
WEIGHT_FILES = [
    "cnn_heart_classifier.pt",   # Joint CNN (375 KB)
    "cnn_config.json",
    "cnn_finetuned.pt",          # Fine-tuned CNN (375 KB)
    "cnn_resnet_classifier.pt",  # ImageNet ResNet-18 (43.7 MB)
    "gru_canine_finetuned.pt",   # Bi-GRU McDonald (639 KB)
    "gru_canine_config.json",
]
# Human-readable startup problems; surfaced to clients via the health endpoint.
_startup_errors = []
def _ensure_weights():
    """Download all model weights from the public HF model repo if not already present.

    Returns:
        bool: True when every required ``.pt`` weight file is available on disk.
        Missing/failed JSON config files are logged into ``_startup_errors`` but
        do not fail the check — only ``.pt`` checkpoints are hard requirements.
    """
    os.makedirs(WEIGHTS_DIR, exist_ok=True)
    all_ok = True
    for fname in WEIGHT_FILES:
        fpath = os.path.join(WEIGHTS_DIR, fname)
        # A file under ~1 KB is assumed to be a placeholder/truncated download
        # and is fetched again. Stat once and reuse the size in the log line.
        if os.path.exists(fpath):
            size_bytes = os.path.getsize(fpath)
            if size_bytes > 1000:
                print(f" {fname}: present ({size_bytes//1024} KB) ✓", flush=True)
                continue
        try:
            # Imported lazily so a missing huggingface_hub package is recorded
            # as a per-file startup error instead of crashing module import.
            from huggingface_hub import hf_hub_download
            print(f" Downloading {fname} from HF model repo...", flush=True)
            # Download from public model repo (not Space — Space requires auth)
            dest = hf_hub_download(
                repo_id="mahmoud611/cardioscreen-weights",
                filename=fname,
                repo_type="model",
                local_dir=WEIGHTS_DIR,
            )
            # hf_hub_download may return a path in a nested/cache layout;
            # copy it to the flat path the model loaders expect.
            if dest != fpath and os.path.exists(dest):
                import shutil
                shutil.copy2(dest, fpath)
            print(f" {fname}: downloaded ✓", flush=True)
        except Exception as e:
            msg = f"Download failed for {fname}: {e}"
            print(msg, flush=True)
            _startup_errors.append(msg)
            # Only model checkpoints are mandatory; configs are best-effort.
            if fname.endswith(".pt"):
                all_ok = False
    return all_ok
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup: pre-download all weights and warm up all 4 models.

    Loader exceptions are collected into ``_startup_errors`` instead of
    aborting startup, so the API can still serve the health endpoint and
    report which models failed — consistent with the "affected models will
    be skipped" behavior used when weights are missing.
    """
    print("=== CardioScreen AI v3.0 starting up (4-model comparison) ===", flush=True)
    print(f"Python: {sys.version}", flush=True)
    try:
        import torch
        print(f"PyTorch {torch.__version__} ✓", flush=True)
    except ImportError as e:
        _startup_errors.append(f"PyTorch not available: {e}")
    print("Ensuring all model weights are present...", flush=True)
    weights_ok = _ensure_weights()
    if weights_ok:
        print("Loading all 4 models...", flush=True)
        # One corrupt checkpoint must not take the whole service down.
        for loader in (_load_cnn_model, _load_finetuned_model,
                       _load_resnet_model, _load_gru_model):
            try:
                loader()
            except Exception as e:
                _startup_errors.append(f"{loader.__name__} failed: {e}")
    else:
        _startup_errors.append("Some weights missing — affected models will be skipped")
    print(f"Startup errors: {_startup_errors}", flush=True)
    yield  # application serves requests here
    print("=== Shutting down ===", flush=True)
# FastAPI application; `lifespan` pre-downloads weights and warms up the models.
app = FastAPI(title="CardioScreen AI β€” Canine Cardiac Screening", lifespan=lifespan)
# Permit cross-origin calls from the React frontend.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers for credentialed requests — confirm whether credentials
# are actually needed, or pin the frontend origin explicitly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/")
async def health_check():
    """Health check — reports service status, model availability, and weights."""
    # Imported here so each call reads the module's *current* availability
    # flags rather than values captured at import time.
    import inference

    weights_status = {}
    for weight_name in WEIGHT_FILES:
        if weight_name.endswith(".pt"):
            weights_status[weight_name] = os.path.exists(
                os.path.join(WEIGHTS_DIR, weight_name)
            )

    model_flags = {
        "joint_cnn": inference._cnn_available,
        "finetuned_cnn": inference._finetuned_available,
        "resnet18": inference._resnet_available,
        "bigru": inference._gru_available,
    }

    return {
        "status": "ok",
        "service": "CardioScreen AI",
        "version": "3.0",
        "models": model_flags,
        "weights": weights_status,
        "startup_errors": _startup_errors,
    }
@app.post("/analyze")
async def analyze_audio(file: UploadFile = File(...)):
    """Receives audio from the React frontend and returns screening results."""
    # Read the whole upload into memory; the inference engine expects raw bytes.
    payload = await file.read()
    print(f"Received: {file.filename}, {len(payload)} bytes", flush=True)
    return predict_audio(payload)
if __name__ == "__main__":
    # Local/dev entry point; HF Spaces sets PORT, otherwise default to 7860.
    server_port = int(os.environ.get("PORT", 7860))
    print(f"Starting CardioScreen AI server on http://0.0.0.0:{server_port}")
    uvicorn.run(app, host="0.0.0.0", port=server_port)