Spaces:
Sleeping
Sleeping
File size: 4,620 Bytes
10f2364 e4e8f94 10f2364 e4e8f94 10f2364 e4e8f94 bbcbd17 e4e8f94 bbcbd17 e4e8f94 10f2364 e4e8f94 10f2364 e4e8f94 10f2364 e4e8f94 10f2364 e4e8f94 10f2364 e4e8f94 10f2364 e4e8f94 10f2364 bbcbd17 10f2364 bbcbd17 10f2364 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 | """
CardioScreen AI — FastAPI Backend
Serves the local AI inference engine for canine cardiac screening.
"""
import os
import sys
from contextlib import asynccontextmanager
from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
from inference import (
predict_audio,
_load_cnn_model, _load_finetuned_model, _load_resnet_model, _load_gru_model,
_cnn_available
)
# Directory (next to this file) where model weights live / are downloaded to.
WEIGHTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "weights")
# All 4 model weight files to download from HF
WEIGHT_FILES = [
    "cnn_heart_classifier.pt",   # Joint CNN (375 KB)
    "cnn_config.json",
    "cnn_finetuned.pt",          # Fine-tuned CNN (375 KB)
    "cnn_resnet_classifier.pt",  # ImageNet ResNet-18 (43.7 MB)
    "gru_canine_finetuned.pt",   # Bi-GRU McDonald (639 KB)
    "gru_canine_config.json",
]
# Human-readable startup problems; surfaced to clients via the health endpoint.
_startup_errors = []
def _ensure_weights():
    """Download all model weights from the HF model repo if not already present.

    Returns:
        bool: True when every required ``.pt`` weight file is available
        locally; False when any weight download failed.  Per-file failures
        are recorded in ``_startup_errors`` instead of being raised, so
        startup can continue with whichever models did load.
    """
    os.makedirs(WEIGHTS_DIR, exist_ok=True)
    all_ok = True
    for fname in WEIGHT_FILES:
        fpath = os.path.join(WEIGHTS_DIR, fname)
        # Files at or under 1000 bytes are treated as partial/corrupt and
        # re-fetched.  NOTE(review): this assumes every valid file (including
        # the JSON configs) exceeds 1 KB -- confirm, otherwise small configs
        # get re-downloaded on every startup.
        if os.path.exists(fpath):
            size = os.path.getsize(fpath)  # stat once, reuse for the log line
            if size > 1000:
                print(f" {fname}: present ({size//1024} KB) β", flush=True)
                continue
        try:
            # Imported lazily so a missing huggingface_hub package surfaces
            # as a per-file startup error rather than crashing module import.
            from huggingface_hub import hf_hub_download
            print(f" Downloading {fname} from HF model repo...", flush=True)
            # Download from public model repo (not Space β Space requires auth)
            dest = hf_hub_download(
                repo_id="mahmoud611/cardioscreen-weights",
                filename=fname,
                repo_type="model",
                local_dir=WEIGHTS_DIR,
            )
            # hf_hub_download may place the file in a nested path; copy it to
            # the flat location the loaders expect.
            if dest != fpath and os.path.exists(dest):
                import shutil
                shutil.copy2(dest, fpath)
            print(f" {fname}: downloaded β", flush=True)
        except Exception as e:
            msg = f"Download failed for {fname}: {e}"
            print(msg, flush=True)
            _startup_errors.append(msg)
            # Missing JSON configs are tolerated; only missing .pt weights
            # mark the startup as degraded.
            if fname.endswith(".pt"):
                all_ok = False
    return all_ok
@asynccontextmanager
async def lifespan(app: FastAPI):
    """App lifespan: fetch weights and warm all four models before serving."""
    print("=== CardioScreen AI v3.0 starting up (4-model comparison) ===", flush=True)
    print(f"Python: {sys.version}", flush=True)
    try:
        import torch
        print(f"PyTorch {torch.__version__} β", flush=True)
    except ImportError as e:
        _startup_errors.append(f"PyTorch not available: {e}")
    print("Ensuring all model weights are present...", flush=True)
    if _ensure_weights():
        print("Loading all 4 models...", flush=True)
        # Warm each model once so the first request doesn't pay the load cost.
        for warm_up in (_load_cnn_model, _load_finetuned_model,
                        _load_resnet_model, _load_gru_model):
            warm_up()
    else:
        _startup_errors.append("Some weights missing β affected models will be skipped")
    print(f"Startup errors: {_startup_errors}", flush=True)
    yield
    print("=== Shutting down ===", flush=True)
# Application instance; `lifespan` performs weight download and model warm-up.
app = FastAPI(title="CardioScreen AI β Canine Cardiac Screening", lifespan=lifespan)
# NOTE(review): wildcard `allow_origins` combined with allow_credentials=True
# is rejected by browsers under the CORS spec -- confirm whether credentials
# are actually needed, or pin the frontend's explicit origin instead.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/")
async def health_check():
    """Liveness endpoint: reports per-model availability and weight presence."""
    # Imported here so the availability flags reflect their current values
    # (they are set during the lifespan warm-up, after this module loads).
    import inference
    weights_status = {}
    for fname in WEIGHT_FILES:
        if fname.endswith(".pt"):
            weights_status[fname] = os.path.exists(os.path.join(WEIGHTS_DIR, fname))
    model_flags = {
        "joint_cnn": inference._cnn_available,
        "finetuned_cnn": inference._finetuned_available,
        "resnet18": inference._resnet_available,
        "bigru": inference._gru_available,
    }
    return {
        "status": "ok",
        "service": "CardioScreen AI",
        "version": "3.0",
        "models": model_flags,
        "weights": weights_status,
        "startup_errors": _startup_errors,
    }
@app.post("/analyze")
async def analyze_audio(file: UploadFile = File(...)):
    """Accept an uploaded audio recording and return the screening results."""
    payload = await file.read()
    print(f"Received: {file.filename}, {len(payload)} bytes", flush=True)
    # Inference runs on the raw bytes; decoding happens inside predict_audio.
    return predict_audio(payload)
if __name__ == "__main__":
    # Hugging Face Spaces injects PORT; fall back to the conventional 7860.
    serve_port = int(os.getenv("PORT", 7860))
    print(f"Starting CardioScreen AI server on http://0.0.0.0:{serve_port}")
    uvicorn.run(app, host="0.0.0.0", port=serve_port)
|