Spaces:
Sleeping
Sleeping
| """ | |
| CardioScreen AI — FastAPI Backend | |
| Serves the local AI inference engine for canine cardiac screening. | |
| """ | |
| import os | |
| import sys | |
| from contextlib import asynccontextmanager | |
| from fastapi import FastAPI, UploadFile, File | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import uvicorn | |
| from inference import ( | |
| predict_audio, | |
| _load_cnn_model, _load_finetuned_model, _load_resnet_model, _load_gru_model, | |
| _cnn_available | |
| ) | |
# Model files live in a "weights" directory next to this module.
_MODULE_DIR = os.path.dirname(os.path.abspath(__file__))
WEIGHTS_DIR = os.path.join(_MODULE_DIR, "weights")

# Checkpoints (.pt) and JSON configs for the four models, downloaded from
# Hugging Face when they are not already on disk.
WEIGHT_FILES = [
    "cnn_heart_classifier.pt",   # Joint CNN (375 KB)
    "cnn_config.json",
    "cnn_finetuned.pt",          # Fine-tuned CNN (375 KB)
    "cnn_resnet_classifier.pt",  # ImageNet ResNet-18 (43.7 MB)
    "gru_canine_finetuned.pt",   # Bi-GRU McDonald (639 KB)
    "gru_canine_config.json",
]

# Human-readable problems collected during startup; the health check returns
# this list so failures are visible without reading the server logs.
_startup_errors = []
def _weight_file_present(fpath):
    """Return True when *fpath* exists and is big enough to be usable."""
    # The 1000-byte floor is kept for ``.pt`` checkpoints only.
    # NOTE(review): presumably it rejects truncated or placeholder files —
    # confirm. JSON configs can legitimately be smaller than that, so they
    # only need to be non-empty; otherwise they would be re-downloaded on
    # every start.
    min_bytes = 1000 if fpath.endswith(".pt") else 0
    try:
        return os.path.getsize(fpath) > min_bytes
    except OSError:
        # Missing or unreadable file: treat it as absent.
        return False


def _ensure_weights():
    """Make sure every file in WEIGHT_FILES exists under WEIGHTS_DIR.

    Files already on disk are kept. Anything missing is fetched from the
    public ``mahmoud611/cardioscreen-weights`` model repo on the Hugging Face
    Hub. Failures are printed and appended to ``_startup_errors`` instead of
    being raised, so the caller decides how to proceed.

    Returns:
        bool: True when no ``.pt`` checkpoint failed to download. A failed
        ``.json`` download is recorded but does not change the result.
    """
    import shutil

    os.makedirs(WEIGHTS_DIR, exist_ok=True)
    all_ok = True
    for fname in WEIGHT_FILES:
        fpath = os.path.join(WEIGHTS_DIR, fname)
        if _weight_file_present(fpath):
            size_kb = os.path.getsize(fpath) // 1024
            print(f" {fname}: present ({size_kb} KB) β", flush=True)
            continue
        try:
            # Imported lazily so a missing huggingface_hub package is only an
            # error when a download is actually needed.
            from huggingface_hub import hf_hub_download

            print(f" Downloading {fname} from HF model repo...", flush=True)
            # Use the public model repo rather than the Space repo, which
            # requires authentication.
            dest = hf_hub_download(
                repo_id="mahmoud611/cardioscreen-weights",
                filename=fname,
                repo_type="model",
                local_dir=WEIGHTS_DIR,
            )
            # Copy into place if the hub reported a different location.
            if dest != fpath and os.path.exists(dest):
                shutil.copy2(dest, fpath)
            print(f" {fname}: downloaded β", flush=True)
        except Exception as e:
            # Startup boundary: record the failure and keep going so the
            # remaining files are still attempted.
            msg = f"Download failed for {fname}: {e}"
            print(msg, flush=True)
            _startup_errors.append(msg)
            if fname.endswith(".pt"):
                all_ok = False
    return all_ok
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: fetch weights and call the model loaders.

    Runs once before the app starts serving. Problems found along the way
    are appended to ``_startup_errors`` rather than raised, so the service
    still comes up and the failures can be inspected afterwards. Code after
    the ``yield`` runs at shutdown.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    print("=== CardioScreen AI v3.0 starting up (4-model comparison) ===", flush=True)
    print(f"Python: {sys.version}", flush=True)
    try:
        import torch
    except ImportError as e:
        _startup_errors.append(f"PyTorch not available: {e}")
    else:
        print(f"PyTorch {torch.__version__} β", flush=True)

    print("Ensuring all model weights are present...", flush=True)
    weights_ok = _ensure_weights()
    if weights_ok:
        print("Loading all 4 models...", flush=True)
        loaders = (
            _load_cnn_model,
            _load_finetuned_model,
            _load_resnet_model,
            _load_gru_model,
        )
        for loader in loaders:
            # Guard each loader separately so one failing model is recorded
            # like every other startup problem instead of aborting startup.
            try:
                loader()
            except Exception as e:
                msg = f"{loader.__name__} failed: {e}"
                print(msg, flush=True)
                _startup_errors.append(msg)
    else:
        # NOTE(review): a single missing checkpoint skips ALL loaders here,
        # although the message below speaks of "affected models" only —
        # confirm whether the loaders tolerate missing files before changing.
        _startup_errors.append("Some weights missing β affected models will be skipped")
    print(f"Startup errors: {_startup_errors}", flush=True)
    yield
    print("=== Shutting down ===", flush=True)
# The ASGI application; ``lifespan`` runs the startup and shutdown logic.
app = FastAPI(
    title="CardioScreen AI β Canine Cardiac Screening",
    lifespan=lifespan,
)

# CORS is fully open: any origin, method and header is accepted.
# NOTE(review): wildcard origins combined with allow_credentials=True is
# discouraged by the FastAPI CORS documentation — confirm whether the
# frontend really sends credentials, and list explicit origins if so.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# NOTE(review): no route decorator (e.g. ``@app.get(...)``) is visible on this
# handler in this view of the file — confirm how it is registered with ``app``.
async def health_check():
    """Report that the API is up, with model, weight and startup status.

    Returns:
        dict: service name and version, one availability flag per model,
        whether each ``.pt`` checkpoint exists on disk, and the list of
        errors collected during startup.
    """
    # Read the flags from the module at request time rather than from a name
    # imported at module load (presumably so later updates are seen).
    import inference

    checkpoint_names = [name for name in WEIGHT_FILES if name.endswith(".pt")]
    weights_status = {}
    for name in checkpoint_names:
        weights_status[name] = os.path.exists(os.path.join(WEIGHTS_DIR, name))

    model_flags = {
        "joint_cnn": inference._cnn_available,
        "finetuned_cnn": inference._finetuned_available,
        "resnet18": inference._resnet_available,
        "bigru": inference._gru_available,
    }

    return {
        "status": "ok",
        "service": "CardioScreen AI",
        "version": "3.0",
        "models": model_flags,
        "weights": weights_status,
        "startup_errors": _startup_errors,
    }
# NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible on this
# handler in this view of the file — confirm how it is registered with ``app``.
async def analyze_audio(file: UploadFile = File(...)):
    """Read an uploaded audio file and return the screening result.

    Args:
        file: The uploaded audio file (per the original docstring, sent by
            the React frontend).

    Returns:
        Whatever ``predict_audio`` returns for the uploaded bytes.
    """
    payload = await file.read()
    byte_count = len(payload)
    print(f"Received: {file.filename}, {byte_count} bytes", flush=True)
    result = predict_audio(payload)
    return result
if __name__ == "__main__":
    # Script entry point: listen on all interfaces, on $PORT when it is set
    # and on 7860 otherwise.
    host = "0.0.0.0"
    port = int(os.getenv("PORT", 7860))
    print(f"Starting CardioScreen AI server on http://{host}:{port}")
    uvicorn.run(app, host=host, port=port)