| from __future__ import annotations |
|
|
| import os |
| import tempfile |
| import traceback |
| from typing import List |
|
|
| import numpy as np |
| from fastapi import FastAPI, File, UploadFile, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.staticfiles import StaticFiles |
|
|
| import cv2 |
| import joblib |
| import tensorflow as tf |
| from tensorflow.keras.applications.resnet import preprocess_input |
|
|
|
|
| |
| |
| |
# Runtime artifact locations. Each one may be overridden through an environment
# variable of the same name; the defaults match the container layout.
MODEL_PATH = os.environ.get("MODEL_PATH", "/app/model.keras")
SCALER_PATH = os.environ.get("SCALER_PATH", "/app/scaler.save")
STATIC_DIR = os.environ.get("STATIC_DIR", "/app/static")

# Process-wide caches, populated on first use by get_model() / get_scaler().
_model = None
_scaler = None
|
|
|
|
def get_model():
    """Return the Keras model, loading it from ``MODEL_PATH`` on the first call.

    The loaded model is cached in the module-level ``_model`` so later calls
    skip the disk load.

    Raises:
        RuntimeError: if no file exists at ``MODEL_PATH``.
    """
    global _model
    if _model is not None:
        return _model

    if not os.path.exists(MODEL_PATH):
        raise RuntimeError(
            f"Model file not found at {MODEL_PATH}. "
            "Place your Keras model at /app/model.keras or set MODEL_PATH."
        )

    # compile=False: in this module the model is only used for predict(), so
    # optimizer/loss configuration is not needed.
    _model = tf.keras.models.load_model(MODEL_PATH, compile=False)
    return _model
|
|
|
|
def get_scaler():
    """Return the fitted scaler, loading it from ``SCALER_PATH`` on the first call.

    The loaded object is cached in the module-level ``_scaler`` so later calls
    reuse it.

    Raises:
        RuntimeError: if no file exists at ``SCALER_PATH``.
    """
    global _scaler
    if _scaler is not None:
        return _scaler

    if not os.path.exists(SCALER_PATH):
        raise RuntimeError(
            f"Scaler file not found at {SCALER_PATH}. "
            "Place scaler.save at /app/scaler.save or set SCALER_PATH."
        )

    # joblib.load unpickles the file (which can execute code), so SCALER_PATH
    # must only ever point at a trusted artifact.
    _scaler = joblib.load(SCALER_PATH)
    return _scaler
|
|
|
|
| |
| |
| |
def load_data(image_path: str) -> tf.Tensor:
    """Load one PNG frame and turn it into a model-ready tensor.

    Steps: read the file, decode it as a 3-channel PNG, bilinearly resize to
    224x224, cast to float32, then apply ``preprocess_input`` from
    ``tensorflow.keras.applications.resnet``.
    """
    encoded = tf.io.read_file(image_path)
    pixels = tf.io.decode_png(encoded, channels=3)
    resized = tf.image.resize(pixels, [224, 224], method="bilinear")
    return preprocess_input(tf.cast(resized, tf.float32))
|
|
|
|
| |
| |
| |
def extract_frames_to_pngs(video_bytes: bytes, max_frames: int = 300) -> List[str]:
    """Decode video bytes with OpenCV and write frames as PNGs to a temp dir.

    The upload is written to ``input.mp4`` inside a fresh ``tempfile.mkdtemp``
    directory. Up to ``max_frames`` frames are then read in order and saved as
    ``frame_00000.png``, ``frame_00001.png``, ... in that same directory.

    Args:
        video_bytes: Raw bytes of the uploaded video container.
        max_frames: Upper bound on the number of frames extracted.

    Returns:
        Paths of the written PNG files, in frame order. They all live in one
        temporary directory; on success the caller owns it and should delete it.

    Raises:
        ValueError: if OpenCV cannot open the video or yields no frames.
        OSError: if a decoded frame cannot be written to disk.

    On any failure the temporary directory is removed before the exception
    propagates, so nothing is left behind on disk.
    """
    import shutil

    tmpdir = tempfile.mkdtemp(prefix="frames_")
    try:
        video_path = os.path.join(tmpdir, "input.mp4")
        with open(video_path, "wb") as f:
            f.write(video_bytes)

        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                raise ValueError("Could not open uploaded video. (Unsupported codec/container?)")

            paths: List[str] = []
            for idx in range(max_frames):
                ok, frame = cap.read()
                if not ok:
                    break
                out_path = os.path.join(tmpdir, f"frame_{idx:05d}.png")
                # cv2.imwrite signals failure through its return value, not an exception.
                if not cv2.imwrite(out_path, frame):
                    raise OSError(f"Failed to write frame {idx} to {out_path}")
                paths.append(out_path)
        finally:
            # Release the decoder even when reading/writing frames fails.
            cap.release()

        if not paths:
            raise ValueError("No frames extracted from video.")
        return paths
    except BaseException:
        # The caller never receives the paths on failure, so it cannot clean up.
        shutil.rmtree(tmpdir, ignore_errors=True)
        raise
|
|
|
|
| |
| |
| |
def moving_average(x: np.ndarray, window: int = 7) -> np.ndarray:
    """Return a centered moving average of the 1-D signal ``x``.

    The window is clamped to ``len(x)``, and the signal is edge-padded so the
    result always has the same length as the input. For odd windows the padding
    is symmetric. For even windows the extra sample comes from the right-hand
    side (padding ``(window - 1) // 2`` left and ``window // 2`` right); the
    previous symmetric padding produced ``len(x) + 1`` outputs in that case.

    Args:
        x: 1-D numeric array.
        window: Averaging window length; ``<= 1`` disables smoothing.

    Returns:
        ``x`` itself when ``window <= 1``; otherwise a float32 array of
        ``len(x)`` smoothed values (empty input gives an empty float32 array).
    """
    if window <= 1:
        return x
    n = x.shape[0]
    if n == 0:
        # np.convolve rejects empty inputs; there is nothing to smooth.
        return x.astype(np.float32)

    window = int(max(1, min(window, n)))
    kernel = np.ones(window, dtype=np.float32) / float(window)
    pad_left = (window - 1) // 2
    pad_right = window // 2
    # Padding totals window - 1 samples, so "valid" mode yields exactly n outputs.
    xpad = np.pad(x.astype(np.float32), (pad_left, pad_right), mode="edge")
    return np.convolve(xpad, kernel, mode="valid")
|
|
|
|
def compute_ef(edv: float, esv: float) -> float:
    """Return the ejection fraction, in percent, for the given volumes.

    EF = (EDV - ESV) / EDV * 100. A non-positive ``edv`` would make the ratio
    meaningless, so NaN is returned in that case.
    """
    return float("nan") if edv <= 0 else float((edv - esv) / edv * 100.0)
|
|
|
|
def classify_heart_function(ef: float) -> str:
    """Map an ejection fraction (percent) to a coarse heart-function label.

    EF >= 55 -> "normal"; 40 <= EF < 55 -> "mildly dysfunction"; anything
    lower, and any non-finite EF, -> "heart failure".

    NOTE(review): a non-finite EF means the value could not be computed, yet it
    is labelled "heart failure"; consider a distinct "indeterminate" label.
    """
    if not np.isfinite(ef):
        return "heart failure"
    # Bands are checked from the highest lower bound downward.
    for lower_bound, label in ((55.0, "normal"), (40.0, "mildly dysfunction")):
        if ef >= lower_bound:
            return label
    return "heart failure"
|
|
|
|
| def _normalize_model_output(raw, n_frames: int) -> np.ndarray: |
| """ |
| Normalize model.predict output to shape (N, 1) float array suitable for scaler.inverse_transform. |
| Handles models that return: |
| - single array: (N,), (N,1), (N,k), (N,1,1), etc. |
| - list/tuple of arrays (multi-output) |
| """ |
| |
| if isinstance(raw, (list, tuple)): |
| shapes = [np.asarray(x).shape for x in raw] |
| print("PRED LIST SHAPES:", shapes) |
|
|
| chosen = None |
| for r in raw: |
| r_arr = np.asarray(r) |
| if r_arr.ndim >= 1 and r_arr.shape[0] == n_frames: |
| chosen = r_arr |
| break |
| if chosen is None: |
| chosen = np.asarray(raw[0]) |
| raw_arr = chosen |
| else: |
| raw_arr = np.asarray(raw) |
| print("PRED SHAPE:", raw_arr.shape) |
|
|
| raw_arr = np.asarray(raw_arr) |
|
|
| |
| if raw_arr.ndim == 1: |
| raw_arr = raw_arr.reshape(-1, 1) |
| elif raw_arr.ndim == 2: |
| if raw_arr.shape[0] != n_frames: |
| |
| if raw_arr.shape[1] == n_frames: |
| raw_arr = raw_arr.T |
| |
| if raw_arr.shape[1] != 1: |
| raw_arr = raw_arr[:, :1] |
| else: |
| |
| raw_arr = raw_arr.reshape(raw_arr.shape[0], -1) |
| raw_arr = raw_arr[:, :1] |
|
|
| if raw_arr.shape[0] != n_frames: |
| raise ValueError(f"Prediction length mismatch: got {raw_arr.shape[0]} but expected {n_frames}") |
|
|
| return raw_arr.astype(np.float32) |
|
|
|
|
| |
| |
| |
app = FastAPI()

# Comma-separated list of browser origins allowed to call the API cross-origin.
# The default "*" keeps the previous allow-any-origin behaviour.
_CORS_ALLOW_ORIGINS = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_CORS_ALLOW_ORIGINS,
    # Credentials must not be combined with a wildcard origin: with "*" plus
    # allow_credentials=True the middleware echoes back any requesting origin,
    # letting every site make credentialed requests. Only enable credentials
    # when explicit origins are configured.
    allow_credentials="*" not in _CORS_ALLOW_ORIGINS,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
@app.post("/api/analyze")
async def analyze(video: UploadFile = File(...)):
    """Estimate ejection fraction from an uploaded video.

    Pipeline: extract frames to PNGs, preprocess them, run the model on the
    whole batch, map predictions back through the scaler, smooth the per-frame
    series, and take its max/min as EDV/ESV.

    Returns:
        dict with keys ``ejectionFraction``, ``heartFunction``, ``edv``,
        ``esv`` and ``numFrames``.

    Raises:
        HTTPException: 400 for an empty or undecodable upload, 422 when no
            finite ejection fraction can be computed, 500 for any other failure.
    """
    import shutil

    video_bytes = await video.read()
    if not video_bytes:
        raise HTTPException(status_code=400, detail="Empty video upload.")

    # NOTE(review): the work below is synchronous and CPU-heavy and runs on the
    # event loop, blocking other requests while it executes. Consider moving it
    # to a worker thread once concurrent model.predict calls are confirmed safe.
    frame_dir = None
    try:
        max_frames = int(os.getenv("MAX_FRAMES", "300"))
        try:
            frame_paths = extract_frames_to_pngs(video_bytes, max_frames=max_frames)
        except ValueError as e:
            # Unopenable video / zero decoded frames: the upload itself is bad.
            raise HTTPException(status_code=400, detail=str(e)) from e
        # extract_frames_to_pngs writes every frame into one fresh temp dir;
        # remember it so the finally block can delete it after inference.
        frame_dir = os.path.dirname(frame_paths[0])

        batch = tf.stack([load_data(p) for p in frame_paths], axis=0)

        model = get_model()
        raw_preds = model.predict(batch, verbose=0)
        preds_np = _normalize_model_output(raw_preds, n_frames=batch.shape[0])

        scaler = get_scaler()
        values = scaler.inverse_transform(preds_np).reshape(-1).astype(np.float32)

        smooth_window = int(os.getenv("SMOOTH_WINDOW", "7"))
        smooth = moving_average(values, window=smooth_window)

        edv = float(np.max(smooth))
        esv = float(np.min(smooth))
        ef = compute_ef(edv, esv)
        # Starlette's JSONResponse renders with allow_nan=False, so NaN/inf would
        # otherwise surface later as an opaque 500 while serializing the reply.
        if not all(np.isfinite(v) for v in (edv, esv, ef)):
            raise HTTPException(
                status_code=422,
                detail="Could not compute a finite ejection fraction from this video.",
            )
        heart_fn = classify_heart_function(ef)

        return {
            "ejectionFraction": round(float(ef), 1),
            "heartFunction": heart_fn,
            "edv": round(edv, 2),
            "esv": round(esv, 2),
            "numFrames": int(values.shape[0]),
        }

    except HTTPException:
        raise
    except Exception as e:
        print("ANALYZE ERROR TRACEBACK:\n", traceback.format_exc())
        raise HTTPException(status_code=500, detail=f"Inference error: {e}") from e
    finally:
        # Without this, every request leaves the uploaded video and all PNG
        # frames on disk.
        if frame_dir is not None:
            shutil.rmtree(frame_dir, ignore_errors=True)
|
|
|
|
| |
# Serve the optional front-end bundle from the site root, but only when the
# directory actually exists. html=True turns on StaticFiles' HTML mode.
# NOTE: this catch-all "/" mount is registered after the /api routes; Starlette
# matches routes in registration order, so keep it last.
if os.path.isdir(STATIC_DIR):
    app.mount(
        "/",
        StaticFiles(directory=STATIC_DIR, html=True),
        name="static",
    )
|
|