---
license: cc-by-4.0
library_name: erdes
tags:
  - ocular-ultrasound
  - medical-imaging
  - 3d-classification
  - retinal-detachment
pipeline_tag: image-classification
---

# V-Net — Normal vs. RD

Trained model weights for **retinal detachment classification (normal vs. RD)** using ocular ultrasound videos.

| Resource | Link |
|----------|------|
| Paper | [![arXiv](https://img.shields.io/badge/arXiv-2508.04735-b31b1b.svg)](https://arxiv.org/abs/2508.04735) |
| Dataset | [![HF Dataset](https://img.shields.io/badge/🤗-Dataset-yellow)](https://huggingface.co/datasets/pcvlab/erdes) [![Zenodo](https://img.shields.io/badge/Zenodo-Dataset-blue)](https://zenodo.org/records/18644370) |
| Checkpoints | [![Zenodo](https://img.shields.io/badge/Zenodo-Checkpoints-blue)](https://zenodo.org/records/18821031) |
| Code | [![GitHub](https://img.shields.io/badge/GitHub-OSUPCVLab/ERDES-black?logo=github)](https://github.com/OSUPCVLab/ERDES) |

## Model Details

| Property | Value |
|----------|-------|
| Architecture | V-Net |
| Input modality | 3D ocular ultrasound video |
| Input shape | `[1, 96, 128, 128]` (C, D, H, W) |
| Pooling | Global Average Pooling |
| Output | Binary classification (sigmoid) |
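
The tensor layout in the table can be illustrated with a short NumPy sketch (the frame count and resolution here are hypothetical stand-ins; the real pipeline uses torch and the `resize` helper from `erdes` to reach the final `[1, 96, 128, 128]` shape):

```python
import numpy as np

# Hypothetical raw clip: 40 frames of 300x400 BGR ultrasound, [D, H, W, C]
frames = np.zeros((40, 300, 400, 3), dtype=np.float32)

# Reorder to the channels-first video layout the model expects: [C, D, H, W]
video = frames.transpose(3, 0, 1, 2)

# Collapse the BGR channels to a single grayscale channel: [1, D, H, W]
gray = video.mean(axis=0, keepdims=True)
```

After this, resizing the `[1, D, H, W]` volume to `[1, 96, 128, 128]` and adding a batch dimension yields the model input.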

## Labels

| Label | Class |
|-------|-------|
| 0 | Normal |
| 1 | Retinal Detachment |
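
The model's sigmoid output is the probability of class 1 (Retinal Detachment). A minimal helper (hypothetical; `prob_to_label` is not part of the `erdes` API) shows how to map it to a label with a confidence score that is meaningful for both classes:

```python
def prob_to_label(prob: float, threshold: float = 0.5):
    """Map a sigmoid output to a class name and a symmetric confidence score."""
    labels = {0: "Normal", 1: "Retinal Detachment"}
    pred = int(prob > threshold)
    # Report confidence in the *predicted* class, not just P(class 1)
    confidence = prob if pred == 1 else 1.0 - prob
    return labels[pred], confidence
```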

## Usage

```bash
pip install git+https://github.com/OSUPCVLab/ERDES.git ultralytics
```

```python
import torch
import numpy as np
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from ultralytics import YOLO
from erdes.models.components.cls_model import VNetClassifier
from erdes.data.components.utils import resize

# --- 1. Load YOLO for ocular globe detection ---
yolo = YOLO(hf_hub_download("pcvlab/yolov8_ocular_ultrasound_globe_detection", "yolov8_ocular_ultrasound_globe_detection.pt"))

# --- 2. Crop your POCUS ultrasound video using YOLO (finds largest globe bbox across all frames) ---
def crop_video(video_path, model, conf=0.8):
    # First pass: find the largest bounding box across all frames
    area_max, cropping_bbox = 0, None
    for frame in model.predict(video_path, stream=True, verbose=False, conf=conf):
        if len(frame.boxes.xywhn):
            bbox = frame.boxes.xywhn[0].cpu().numpy()
            area = bbox[2] * bbox[3]
            if area > area_max:
                area_max, cropping_bbox = area, bbox

    if cropping_bbox is None:
        raise ValueError("YOLO could not detect ocular globe in video.")

    # Second pass: crop every frame with the largest bbox
    frames = []
    for frame in model.predict(video_path, stream=True, verbose=False, conf=conf):
        img = frame.orig_img                                    # [H, W, C] BGR
        h, w, _ = img.shape
        x_c, y_c, bw, bh = cropping_bbox
        x1, y1 = int((x_c - bw/2) * w), int((y_c - bh/2) * h)
        x2, y2 = int((x_c + bw/2) * w), int((y_c + bh/2) * h)
        frames.append(img[y1:y2, x1:x2])

    return np.stack(frames)                                     # [D, H, W, C]

frames = crop_video("your_video.mp4", yolo)                    # [D, H, W, C]

# --- 3. Preprocess ---
video = torch.from_numpy(frames).float()                       # [D, H, W, C]
video = video.permute(3, 0, 1, 2)                              # [C, D, H, W]
if video.shape[0] == 3:
    video = video.mean(dim=0, keepdim=True)                    # grayscale [1, D, H, W]
video = resize((96, 128, 128))(video) / 255.0                  # pad + resize + normalize
video = video.unsqueeze(0)                                      # [1, 1, 96, 128, 128]

# --- 4. Load model and run inference ---
model = VNetClassifier(in_channels=1, num_classes=1, pooling="avg")
weights = load_file(hf_hub_download("pcvlab/vnet_normal_vs_rd", "model.safetensors"))
model.load_state_dict(weights)
model.eval()

with torch.no_grad():
    logit = model(video)
    prob = torch.sigmoid(logit).item()                          # P(Retinal Detachment)
    pred = int(prob > 0.5)

labels = {0: 'Normal', 1: 'Retinal Detachment'}
confidence = prob if pred == 1 else 1.0 - prob                  # confidence in the predicted class
print(f"Prediction: {labels[pred]} (confidence: {confidence:.3f})")
```

## Citation

If you use this model, please cite the ERDES paper:

```bibtex
@misc{ozkut2026erdes,
  title={ERDES: A Benchmark Video Dataset for Retinal Detachment and Macular Status Classification in Ocular Ultrasound},
  author={Yasemin Ozkut and Pouyan Navard and Srikar Adhikari and Elaine Situ-LaCasse and Josie Acu{\~n}a and Adrienne Yarnish and Alper Yilmaz},
  year={2026},
  eprint={2508.04735},
  archivePrefix={arXiv},
  primaryClass={cs.CV},
  url={https://arxiv.org/abs/2508.04735}
}
```