# vjepa2-echonet-ef / inference_example.py
# Uploaded by vselvarajijay via huggingface_hub ("Upload folder using huggingface_hub"), commit 7ed3ea7.
"""Standalone inference example for vjepa2-echonet-ef.
Predicts left-ventricular ejection fraction (EF) from an apical 4-chamber
echocardiogram clip using a frozen V-JEPA 2 encoder + a small linear head.
Requirements:
pip install torch transformers av safetensors numpy
Run:
python inference_example.py /path/to/echo.avi
Output:
Predicted EF (continuous, %)
Reduced-EF flag at the val-tuned cutoff (45.83 for HFrEF screening)
DISCLAIMER: Research / triage-aid only. Not a diagnostic device.
"""
from __future__ import annotations
import json
import sys
from pathlib import Path
import av # type: ignore
import numpy as np
import torch
from safetensors.torch import load_file # type: ignore
from transformers import AutoModel, AutoVideoProcessor # type: ignore
# Directory containing this script; config.json and head.safetensors are read from here.
HERE = Path(__file__).resolve().parent
# Model / sampling configuration. main() reads "base_model", "frames_per_clip",
# "clips_per_video" and "in_dim". JSON is UTF-8 by specification, so decode it
# explicitly instead of relying on the platform's locale encoding.
CFG = json.loads((HERE / "config.json").read_text(encoding="utf-8"))
# Weights of the linear regression head applied to the pooled encoder embedding.
HEAD_PATH = HERE / "head.safetensors"
# Predicted EF (%) below which a study is flagged as reduced EF; validation-tuned.
VAL_TUNED_CUTOFF_HFREF = 45.83  # See MODEL_CARD.md operating-points table
# Run on the GPU when one is available, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def decode_indices(path: str, frame_indices: np.ndarray) -> np.ndarray:
    """Decode the requested frames of a video as RGB images.

    The first video stream is decoded sequentially from the start, stopping as
    soon as every requested index has been collected.

    Args:
        path: Video file to open with PyAV.
        frame_indices: Zero-based frame indices. Duplicates are allowed and the
            output follows this order exactly.

    Returns:
        ``uint8`` array of shape ``(len(frame_indices), H, W, 3)``.

    Raises:
        ValueError: if ``frame_indices`` is empty, or if an index lies past the
            last decodable frame (e.g. the container's frame-count metadata
            overstated the real length).
    """
    wanted = {int(i) for i in frame_indices}
    if not wanted:
        raise ValueError("frame_indices must contain at least one index")
    out: dict[int, np.ndarray] = {}
    with av.open(path) as container:
        for i, frame in enumerate(container.decode(video=0)):
            if i in wanted:
                # rgb24 always yields (H, W, 3), so no grayscale fix-up is needed later.
                out[i] = frame.to_ndarray(format="rgb24")
                if len(out) == len(wanted):
                    break
    missing = wanted.difference(out)
    if missing:
        raise ValueError(
            f"{path}: {len(missing)} requested frame(s) not decoded "
            f"(first missing index {min(missing)}); the video is shorter than expected"
        )
    return np.stack([out[int(i)] for i in frame_indices]).astype(np.uint8, copy=False)
def total_frames(path: str) -> int:
    """Return the number of frames in the first video stream of *path*.

    The stream's stored frame count is used when present; if it is missing or
    zero, the whole stream is decoded and its frames are counted instead.
    """
    with av.open(path) as container:
        count = container.streams.video[0].frames
        if not count:
            count = sum(1 for _ in container.decode(video=0))
    return int(count)
def encoder_dtype(device: torch.device) -> torch.dtype:
    """Return the compute dtype for the encoder on *device*.

    Half precision is used only on CUDA; many CPU kernels have no float16
    implementation (or are much slower in it), so CPU runs stay in float32.
    """
    return torch.float16 if device.type == "cuda" else torch.float32


def sample_clip_indices(n_frames: int, fpc: int, n_clips: int) -> list[np.ndarray]:
    """Return the frame indices of ``n_clips`` clips of ``fpc`` frames each.

    Clip start offsets are evenly spaced so the clips span the whole video. A
    video shorter than one clip is looped (indices wrap modulo ``n_frames``)
    and that same looped clip fills every clip slot.

    Raises:
        ValueError: if the video has no frames.
    """
    if n_frames < 1:
        raise ValueError("video contains no decodable frames")
    if n_frames < fpc:
        looped = np.arange(fpc) % n_frames
        return [looped] * n_clips
    starts = np.linspace(0, n_frames - fpc, n_clips).round().astype(int)
    return [np.arange(s, s + fpc) for s in starts]


def main(video_path: str) -> None:
    """Predict ejection fraction for one echo video and print the result.

    Samples ``CFG["clips_per_video"]`` clips of ``CFG["frames_per_clip"]``
    frames, embeds each clip with the frozen encoder (mean over dim 1 of
    ``last_hidden_state``), averages the clip embeddings, and applies the
    linear head. Prints the EF (%) and whether it falls below
    ``VAL_TUNED_CUTOFF_HFREF``.
    """
    fpc = CFG["frames_per_clip"]
    n_clips = CFG["clips_per_video"]

    # Read the video first so an empty/unreadable file fails before model loading.
    clip_idx = sample_clip_indices(total_frames(video_path), fpc, n_clips)
    # Decode once for all clips instead of re-decoding from frame 0 per clip;
    # decode_indices keeps order and duplicates, so the result splits back per clip.
    all_frames = decode_indices(video_path, np.concatenate(clip_idx))

    dtype = encoder_dtype(DEVICE)
    processor = AutoVideoProcessor.from_pretrained(CFG["base_model"])
    encoder = AutoModel.from_pretrained(CFG["base_model"], dtype=dtype).eval().to(DEVICE)
    head = torch.nn.Linear(CFG["in_dim"], 1)
    head.load_state_dict(load_file(str(HEAD_PATH)))
    head.eval().to(DEVICE)

    pooled = []
    with torch.inference_mode():
        for frames in np.split(all_frames, len(clip_idx)):
            inputs = processor(videos=list(frames), return_tensors="pt")
            # Cast floating-point inputs to the encoder dtype so a float16 encoder
            # never receives float32 pixels; other tensors only change device.
            inputs = {
                k: v.to(DEVICE, dtype=dtype) if torch.is_floating_point(v) else v.to(DEVICE)
                for k, v in inputs.items()
            }
            hidden = encoder(**inputs).last_hidden_state
            # Average over dim 1 (presumably the token axis of a (batch, tokens, dim)
            # output) and drop the batch dim; float32 to match the head.
            pooled.append(hidden.float().mean(dim=1).squeeze(0))
        # Kept inside inference_mode: inference tensors cannot be fed to the
        # grad-tracking head parameters outside it.
        embedding = torch.stack(pooled).mean(dim=0)
        ef = float(head(embedding).item())

    print(f"Predicted EF (%): {ef:.2f}")
    print(f"Reduced-EF flag (predicted < {VAL_TUNED_CUTOFF_HFREF}): {ef < VAL_TUNED_CUTOFF_HFREF}")
    print()
    print("DISCLAIMER: Research / triage-aid only. Not a diagnostic device.")
if __name__ == "__main__":
    # The script takes exactly one positional argument: the path to the video.
    cli_args = sys.argv[1:]
    if len(cli_args) == 1:
        main(cli_args[0])
    else:
        print(f"usage: {sys.argv[0]} /path/to/echo.avi", file=sys.stderr)
        sys.exit(1)