"""Standalone inference example for vjepa2-echonet-ef. Predicts left-ventricular ejection fraction (EF) from an apical 4-chamber echocardiogram clip using a frozen V-JEPA 2 encoder + a small linear head. Requirements: pip install torch transformers av safetensors numpy Run: python inference_example.py /path/to/echo.avi Output: Predicted EF (continuous, %) Reduced-EF flag at the val-tuned cutoff (45.83 for HFrEF screening) DISCLAIMER: Research / triage-aid only. Not a diagnostic device. """ from __future__ import annotations import json import sys from pathlib import Path import av # type: ignore import numpy as np import torch from safetensors.torch import load_file # type: ignore from transformers import AutoModel, AutoVideoProcessor # type: ignore HERE = Path(__file__).resolve().parent CFG = json.loads((HERE / "config.json").read_text()) HEAD_PATH = HERE / "head.safetensors" VAL_TUNED_CUTOFF_HFREF = 45.83 # See MODEL_CARD.md operating-points table DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") def decode_indices(path: str, frame_indices: np.ndarray) -> np.ndarray: wanted = set(int(i) for i in frame_indices) out: dict[int, np.ndarray] = {} container = av.open(path) try: for i, frame in enumerate(container.decode(video=0)): if i in wanted: out[i] = frame.to_ndarray(format="rgb24") if len(out) == len(wanted): break finally: container.close() frames = np.stack([out[int(i)] for i in frame_indices]) if frames.ndim == 3: frames = np.repeat(frames[..., None], 3, axis=-1) return frames.astype(np.uint8) def total_frames(path: str) -> int: container = av.open(path) try: s = container.streams.video[0] n = s.frames or sum(1 for _ in container.decode(video=0)) finally: container.close() return int(n) def main(video_path: str) -> None: fpc = CFG["frames_per_clip"] n_clips = CFG["clips_per_video"] processor = AutoVideoProcessor.from_pretrained(CFG["base_model"]) encoder = AutoModel.from_pretrained(CFG["base_model"], dtype=torch.float16).eval().to(DEVICE) head = torch.nn.Linear(CFG["in_dim"], 1) sd = load_file(str(HEAD_PATH)) head.load_state_dict(sd) head.eval().to(DEVICE) n = total_frames(video_path) if n < fpc: idx_template = np.arange(fpc) % max(n, 1) clip_idx = [idx_template] * n_clips else: starts = np.linspace(0, n - fpc, n_clips).round().astype(int) clip_idx = [np.arange(s, s + fpc) for s in starts] pooled = [] with torch.inference_mode(): for idx in clip_idx: frames = decode_indices(video_path, idx) inputs = processor(videos=list(frames), return_tensors="pt") inputs = {k: v.to(DEVICE) for k, v in inputs.items()} hidden = encoder(**inputs).last_hidden_state pooled.append(hidden.float().mean(dim=1).squeeze(0)) embedding = torch.stack(pooled).mean(dim=0) ef = float(head(embedding).item()) print(f"Predicted EF (%): {ef:.2f}") print(f"Reduced-EF flag (predicted < {VAL_TUNED_CUTOFF_HFREF}): {ef < VAL_TUNED_CUTOFF_HFREF}") print() print("DISCLAIMER: Research / triage-aid only. Not a diagnostic device.") if __name__ == "__main__": if len(sys.argv) != 2: print(f"usage: {sys.argv[0]} /path/to/echo.avi", file=sys.stderr) sys.exit(1) main(sys.argv[1])