| """Standalone inference example for vjepa2-echonet-ef. |
| |
| Predicts left-ventricular ejection fraction (EF) from an apical 4-chamber |
| echocardiogram clip using a frozen V-JEPA 2 encoder + a small linear head. |
| |
| Requirements: |
| pip install torch transformers av safetensors numpy |
| |
| Run: |
| python inference_example.py /path/to/echo.avi |
| |
| Output: |
| Predicted EF (continuous, %) |
| Reduced-EF flag at the val-tuned cutoff (45.83 for HFrEF screening) |
| |
| DISCLAIMER: Research / triage-aid only. Not a diagnostic device. |
| """ |
| from __future__ import annotations |
|
|
| import json |
| import sys |
| from pathlib import Path |
|
|
| import av |
| import numpy as np |
| import torch |
| from safetensors.torch import load_file |
| from transformers import AutoModel, AutoVideoProcessor |
|
|
|
|
# Directory containing this script; config.json and head.safetensors are
# looked up next to it.
HERE = Path(__file__).resolve().parent
# Model metadata (main() reads base_model, frames_per_clip, clips_per_video,
# in_dim). JSON is UTF-8 by specification, so decode it explicitly instead of
# relying on the platform default encoding (e.g. cp1252 on Windows).
CFG = json.loads((HERE / "config.json").read_text(encoding="utf-8"))
# Weights of the linear regression head applied to the pooled embedding.
HEAD_PATH = HERE / "head.safetensors"

# Reduced-EF decision threshold in EF %: the flag is set when the predicted EF
# is below this value. Per the module docstring it was tuned on the
# validation split for HFrEF screening.
VAL_TUNED_CUTOFF_HFREF = 45.83
# Run on the first CUDA device when available, otherwise on CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
def decode_indices(path: str, frame_indices: np.ndarray) -> np.ndarray:
    """Decode selected frames of a video as RGB uint8 images.

    Args:
        path: Path to a video file PyAV can open; the first video stream is used.
        frame_indices: 0-based frame positions in decode order. Duplicates and
            arbitrary order are allowed; the output follows this order exactly.

    Returns:
        Array stacking one rgb24 frame per entry of ``frame_indices``
        (``len(frame_indices)`` frames along axis 0), dtype uint8.

    Raises:
        ValueError: if ``frame_indices`` is empty, or if any requested index
            lies past the last frame that could actually be decoded (e.g.
            because the container's frame-count metadata was too high).
    """
    order = [int(i) for i in frame_indices]
    if not order:
        raise ValueError("frame_indices must not be empty")
    wanted = set(order)

    out: dict[int, np.ndarray] = {}
    with av.open(path) as container:
        for i, frame in enumerate(container.decode(video=0)):
            if i in wanted:
                out[i] = frame.to_ndarray(format="rgb24")
                # Stop as soon as every requested frame is captured so clips
                # near the start do not force decoding the whole file.
                if len(out) == len(wanted):
                    break

    missing = sorted(wanted.difference(out))
    if missing:
        raise ValueError(
            f"{path!r}: could not decode frame(s) {missing}; "
            f"only {len(out)} of the requested frames were found"
        )

    frames = np.stack([out[i] for i in order])
    # Defensive: promote single-channel frames to 3 channels (rgb24 output
    # should already be 3-channel, so this normally does nothing).
    if frames.ndim == 3:
        frames = np.repeat(frames[..., None], 3, axis=-1)
    return frames.astype(np.uint8, copy=False)
|
|
|
|
def total_frames(path: str) -> int:
    """Return the number of frames in the first video stream of ``path``.

    Uses the stream's ``frames`` metadata when it is non-zero; otherwise
    counts frames by decoding the whole stream.
    NOTE(review): metadata counts can differ from the number of decodable
    frames for some containers — confirm for the target data.
    """
    with av.open(path) as container:
        video_stream = container.streams.video[0]
        count = video_stream.frames
        if not count:
            # No usable metadata: fall back to a full decode pass.
            count = sum(1 for _ in container.decode(video=0))
    return int(count)
|
|
|
|
def _sample_clip_indices(n: int, fpc: int, n_clips: int) -> list[np.ndarray]:
    """Return ``n_clips`` arrays of ``fpc`` frame indices into an ``n``-frame video.

    Videos shorter than one clip are looped to fill ``fpc`` frames (all clips
    are then identical). Otherwise clip starts are evenly spaced from frame 0
    to the last start that still fits a full clip.

    Raises:
        ValueError: if the video has no frames or ``n_clips`` < 1.
    """
    if n < 1:
        raise ValueError("video has no decodable frames")
    if n_clips < 1:
        raise ValueError(f"clips_per_video must be >= 1, got {n_clips}")
    if n < fpc:
        return [np.arange(fpc) % n] * n_clips
    starts = np.linspace(0, n - fpc, n_clips).round().astype(int)
    return [np.arange(s, s + fpc) for s in starts]


def main(video_path: str) -> None:
    """Predict ejection fraction for one echo video and print the result.

    Loads the base encoder and video processor named by ``CFG["base_model"]``
    and the linear head from ``HEAD_PATH``. It then samples
    ``CFG["clips_per_video"]`` clips of ``CFG["frames_per_clip"]`` frames and
    averages each clip's ``last_hidden_state`` over dim 1 (presumably the
    token axis). The per-clip vectors are averaged and passed through the head.
    Prints the predicted EF, the reduced-EF flag at ``VAL_TUNED_CUTOFF_HFREF``,
    and a disclaimer.

    Args:
        video_path: Path to the echocardiogram video.

    Raises:
        ValueError: if the video has no decodable frames, a requested frame
            cannot be decoded, or ``clips_per_video`` < 1.
    """
    fpc = CFG["frames_per_clip"]
    n_clips = CFG["clips_per_video"]

    # fp16 saves memory and time on GPU; on CPU half precision is slow and not
    # every kernel supports it, so use fp32 there.
    enc_dtype = torch.float16 if DEVICE.type == "cuda" else torch.float32
    processor = AutoVideoProcessor.from_pretrained(CFG["base_model"])
    encoder = AutoModel.from_pretrained(CFG["base_model"], dtype=enc_dtype).eval().to(DEVICE)

    head = torch.nn.Linear(CFG["in_dim"], 1)
    head.load_state_dict(load_file(str(HEAD_PATH)))
    head.eval().to(DEVICE)

    clip_idx = _sample_clip_indices(total_frames(video_path), fpc, n_clips)

    with torch.inference_mode():
        pooled = []
        for idx in clip_idx:
            frames = decode_indices(video_path, idx)
            inputs = processor(videos=list(frames), return_tensors="pt")
            # Match float inputs to the encoder's weight dtype; leave any
            # integer tensors (e.g. masks) in their own dtype.
            inputs = {
                k: v.to(DEVICE, dtype=enc_dtype) if v.is_floating_point() else v.to(DEVICE)
                for k, v in inputs.items()
            }
            hidden = encoder(**inputs).last_hidden_state
            # Pool in fp32 so the fp32 head receives matching inputs.
            pooled.append(hidden.float().mean(dim=1).squeeze(0))
        embedding = torch.stack(pooled).mean(dim=0)
        # The head must run inside inference_mode: `embedding` is an inference
        # tensor and cannot be saved for backward by an autograd-tracked Linear.
        ef = head(embedding).item()

    print(f"Predicted EF (%): {ef:.2f}")
    print(f"Reduced-EF flag (predicted < {VAL_TUNED_CUTOFF_HFREF}): {ef < VAL_TUNED_CUTOFF_HFREF}")
    print()
    print("DISCLAIMER: Research / triage-aid only. Not a diagnostic device.")
|
|
|
|
if __name__ == "__main__":
    # Expect exactly one positional argument: the video to score.
    cli_args = sys.argv[1:]
    if len(cli_args) == 1:
        main(cli_args[0])
    else:
        print(f"usage: {sys.argv[0]} /path/to/echo.avi", file=sys.stderr)
        raise SystemExit(1)
|
|