File size: 4,620 Bytes
10f2364
 
 
 
 
 
 
 
 
 
 
e4e8f94
 
 
 
 
10f2364
 
e4e8f94
 
 
 
 
 
 
 
 
 
10f2364
 
 
 
e4e8f94
 
 
 
 
 
 
 
 
 
bbcbd17
 
e4e8f94
bbcbd17
 
 
 
e4e8f94
 
 
 
 
 
 
 
 
 
 
 
10f2364
 
 
e4e8f94
 
10f2364
 
 
 
 
e4e8f94
10f2364
e4e8f94
10f2364
e4e8f94
10f2364
e4e8f94
 
 
 
 
10f2364
e4e8f94
10f2364
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bbcbd17
 
 
 
10f2364
 
 
bbcbd17
 
 
 
 
 
 
 
10f2364
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""
CardioScreen AI — FastAPI Backend
Serves the local AI inference engine for canine cardiac screening.
"""
import os
import sys
from contextlib import asynccontextmanager

from fastapi import FastAPI, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
from inference import (
    predict_audio,
    _load_cnn_model, _load_finetuned_model, _load_resnet_model, _load_gru_model,
    _cnn_available
)

# Directory holding the model files: a "weights" folder next to this source file.
WEIGHTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "weights")

# Every file _ensure_weights() makes sure exists in WEIGHTS_DIR (downloading it
# from the HF model repo if needed): the four model checkpoints (.pt) plus two
# JSON config files. Only the .pt entries decide whether models get loaded and
# are reported by the "/" health check.
WEIGHT_FILES = [
    "cnn_heart_classifier.pt",   # Joint CNN (375 KB)
    "cnn_config.json",
    "cnn_finetuned.pt",          # Fine-tuned CNN (375 KB)
    "cnn_resnet_classifier.pt",  # ImageNet ResNet-18 (43.7 MB)
    "gru_canine_finetuned.pt",   # Bi-GRU McDonald (639 KB)
    "gru_canine_config.json",
]

# Human-readable problems collected during startup (failed downloads, missing
# PyTorch, missing weights); printed at the end of startup and returned by the
# "/" health check.
_startup_errors = []

def _ensure_weights():
    """Download all model weights from HF Space repo if not already present."""
    os.makedirs(WEIGHTS_DIR, exist_ok=True)
    all_ok = True
    for fname in WEIGHT_FILES:
        fpath = os.path.join(WEIGHTS_DIR, fname)
        if os.path.exists(fpath) and os.path.getsize(fpath) > 1000:
            print(f"  {fname}: present ({os.path.getsize(fpath)//1024} KB) βœ“", flush=True)
            continue
        try:
            from huggingface_hub import hf_hub_download
            print(f"  Downloading {fname} from HF model repo...", flush=True)
            # Download from public model repo (not Space β€” Space requires auth)
            dest = hf_hub_download(
                repo_id="mahmoud611/cardioscreen-weights",
                filename=fname,
                repo_type="model",
                local_dir=WEIGHTS_DIR,
            )
            # Ensure it landed in the right place
            if dest != fpath and os.path.exists(dest):
                import shutil; shutil.copy2(dest, fpath)
            print(f"  {fname}: downloaded βœ“", flush=True)
        except Exception as e:
            msg = f"Download failed for {fname}: {e}"
            print(msg, flush=True)
            _startup_errors.append(msg)
            if fname.endswith(".pt"):
                all_ok = False
    return all_ok

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: prepare weights and load models before serving.

    Startup steps:
    - Log the Python and PyTorch versions.
    - Ensure the weight files are present via ``_ensure_weights``.
    - Call each model loader.

    Problems are appended to the module-level ``_startup_errors`` list
    (returned by the "/" health check) instead of aborting startup. Code
    after ``yield`` runs at shutdown.

    Each loader is attempted independently, so one missing or broken
    checkpoint does not stop the other models from loading.
    """
    print("=== CardioScreen AI v3.0 starting up (4-model comparison) ===", flush=True)
    print(f"Python: {sys.version}", flush=True)
    try:
        import torch
        print(f"PyTorch {torch.__version__} βœ“", flush=True)
    except ImportError as e:
        _startup_errors.append(f"PyTorch not available: {e}")

    print("Ensuring all model weights are present...", flush=True)
    weights_ok = _ensure_weights()
    if not weights_ok:
        # At least one .pt checkpoint could not be obtained. The loaders below
        # are still attempted so the unaffected models come up.
        _startup_errors.append("Some weights missing β€” affected models will be skipped")

    print("Loading all 4 models...", flush=True)
    model_loaders = (
        ("joint CNN", _load_cnn_model),
        ("fine-tuned CNN", _load_finetuned_model),
        ("ResNet-18", _load_resnet_model),
        ("Bi-GRU", _load_gru_model),
    )
    for label, loader in model_loaders:
        # NOTE(review): assumes a loader whose checkpoint is missing either
        # marks its model unavailable or raises; confirm in inference.py.
        try:
            loader()
        except Exception as e:
            msg = f"Failed to load {label} model: {e}"
            print(msg, flush=True)
            _startup_errors.append(msg)

    print(f"Startup errors: {_startup_errors}", flush=True)
    yield
    print("=== Shutting down ===", flush=True)

# The ASGI application; `lifespan` (defined above) runs the startup/shutdown hooks.
app = FastAPI(title="CardioScreen AI β€” Canine Cardiac Screening", lifespan=lifespan)

# Let browser clients from any origin call this API (the /analyze handler's
# docstring mentions a React frontend).
# NOTE(review): allow_origins=["*"] together with allow_credentials=True is a
# questionable pairing. Browsers reject a literal "*" on credentialed
# requests, and Starlette (as I understand its implementation; verify)
# compensates by echoing the caller's Origin. That effectively grants
# credentialed access to every site. Nothing visible here uses cookies or
# auth headers, so allow_credentials=False or an explicit origin list is
# probably safer. Confirm how the frontend calls the API before changing it.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/")
async def health_check():
    """Liveness probe: report that the API is up plus model/weight status.

    Returns a JSON-serializable dict containing:
    - the service name and version;
    - the availability flag of each of the four models;
    - whether each ``.pt`` checkpoint exists under ``WEIGHTS_DIR``;
    - the errors collected during startup.
    """
    # Read the flags through the module object rather than names bound by
    # `from inference import ...`, so the values current at request time
    # are reported.
    import inference

    checkpoint_names = [name for name in WEIGHT_FILES if name.endswith(".pt")]
    checkpoint_presence = {}
    for name in checkpoint_names:
        checkpoint_presence[name] = os.path.exists(os.path.join(WEIGHTS_DIR, name))

    model_flags = {
        "joint_cnn": inference._cnn_available,
        "finetuned_cnn": inference._finetuned_available,
        "resnet18": inference._resnet_available,
        "bigru": inference._gru_available,
    }

    return dict(
        status="ok",
        service="CardioScreen AI",
        version="3.0",
        models=model_flags,
        weights=checkpoint_presence,
        startup_errors=_startup_errors,
    )

@app.post("/analyze")
async def analyze_audio(file: UploadFile = File(...)):
    """Receive an uploaded audio file and return the screening results.

    The multipart field ``file`` (sent by the React frontend) is read fully
    into memory and passed to ``predict_audio``. Its return value is sent
    back as the response.

    ``predict_audio`` is a plain synchronous call (the original handler never
    awaited it) that presumably runs heavy model inference. Calling it
    directly inside this ``async`` handler would block the event loop and
    stall every other request, including the health check, until it
    finished. It is therefore run in FastAPI's worker thread pool.
    """
    from fastapi.concurrency import run_in_threadpool

    audio_bytes = await file.read()
    print(f"Received: {file.filename}, {len(audio_bytes)} bytes", flush=True)
    # NOTE(review): requests can now run predict_audio concurrently in
    # separate threads. Confirm predict_audio and the loaded models tolerate
    # that.
    return await run_in_threadpool(predict_audio, audio_bytes)

if __name__ == "__main__":
    # Direct-run entry point. The PORT environment variable overrides the
    # default port 7860; the server listens on all interfaces.
    listen_port = int(os.getenv("PORT", 7860))
    print(f"Starting CardioScreen AI server on http://0.0.0.0:{listen_port}")
    uvicorn.run(app, port=listen_port, host="0.0.0.0")