# NOTE: the three lines below are Hugging Face file-viewer residue, not code —
# kept as a comment so the module remains importable:
#   yuxbox — "Upload folder using huggingface_hub" — commit a1aaf30 (verified)
"""
OLIVES Dataset Loader.
Adapted for the actual Zenodo OLIVES dataset structure:
data/
β”œβ”€β”€ OLIVES/OLIVES/
β”‚ β”œβ”€β”€ Prime_FULL/Prime_FULL/ (DR patients β€” OCT B-scans)
β”‚ β”‚ └── <patient_id>/<visit>/<eye>/*.png
β”‚ └── TREX_DME/TREX DME/ (DME patients β€” OCT B-scans)
β”‚ └── <arm>/<patient_id>/<visit>/<eye>/*.tif
└── OLIVES_Dataset_Labels/OLIVES_Dataset_Labels/
└── full_labels/Biomarker_Clinical_Data_Images.csv
Task: Biomarker profile ranking.
- Given an OCT B-scan, rank candidate biomarker profiles
- Each profile is a subset of the 16 annotated biomarkers
- Correct profile = actual biomarker vector for this eye
- Distractors = profiles from other eyes
Channels:
- Initial: single OCT B-scan (middle slice)
- Requestable: additional OCT slices, clinical measurements (BCVA/CST),
biomarker hints (fundus-visible subset), treatment history
"""
import csv
import hashlib
import json
import logging
import random
from pathlib import Path
from collections import defaultdict
import numpy as np
from .base import DatasetBase, MedicalCase, ChannelData
from api_client import encode_image_to_base64
import config
logger = logging.getLogger(__name__)
# The biomarker columns as they appear in the CSV
# Maps each raw CSV column header (16 annotated biomarkers) to a canonical
# snake_case name used throughout this module for profile vectors.
OLIVES_CSV_BIOMARKERS = {
    "Fluid (IRF)": "fluid_irf",
    "Fluid (SRF)": "fluid_srf",
    "DRT/ME": "drt_me",
    "SHRM": "shrm",
    "Preretinal tissue/hemorrhage": "preretinal_tissue",
    "Vitreous debris": "vitreous_debris",
    "DRIL": "dril",
    "Disruption of EZ": "ez_disruption",
    "IR hemorrhages": "hemorrhage",
    "IR HRF": "ir_hrf",
    "Disruption of RPE": "rpe_disruption",
    "PED (serous)": "ped_serous",
    "Atrophy / thinning of retinal layers": "atrophy",
    "VMT": "vmt",
    "Partially attached vitreous face": "partial_vitreous",
    "Fully attached vitreous face": "full_vitreous",
}
# Canonical biomarker names for profiles
# Sorted so profile comparisons (e.g. Hamming distance) iterate the keys in a
# stable, deterministic order.
OLIVES_BIOMARKERS = sorted(OLIVES_CSV_BIOMARKERS.values())
def biomarker_vector_to_profile_string(vector: dict[str, bool]) -> str:
    """Render a biomarker vector as a human-readable profile string.

    Positive markers are listed alphabetically with underscores turned
    into spaces and title-cased; an all-negative vector yields the
    fixed sentinel string.
    """
    positives = sorted(key for key, flag in vector.items() if flag)
    if not positives:
        return "No biomarkers detected"
    labels = (key.replace("_", " ").title() for key in positives)
    return "Present biomarkers: " + ", ".join(labels)
def compute_profile_distance(profile_a: dict, profile_b: dict) -> int:
    """Return the Hamming distance between two biomarker profiles.

    Counts how many of the canonical biomarkers the two profiles
    disagree on; keys missing from either dict count as False.
    """
    return sum(
        profile_a.get(key, False) != profile_b.get(key, False)
        for key in OLIVES_BIOMARKERS
    )
def _case_rng(case_id: str) -> random.Random:
seed = int(hashlib.sha256(case_id.encode()).hexdigest()[:8], 16)
return random.Random(seed)
class OLIVESDataset(DatasetBase):
    """Loader for the OLIVES ophthalmology dataset.

    Builds one biomarker-profile ranking case per eye: the initial
    channel is a single OCT B-scan (middle slice); additional OCT
    slices, clinical measurements (BCVA/CST), partial biomarker hints,
    and disease context are requestable. Candidates are profile strings,
    exactly one of which matches the eye's CSV annotations.
    """

    def __init__(
        self,
        data_dir: str | Path | None = None,
        split: str = "test",
        n_candidates: int = 5,
        n_oct_samples: int = 3,
    ):
        """Initialize the loader.

        Args:
            data_dir: Dataset root; falls back to
                ``config.DATASET_PATHS["olives"]`` when None/empty.
            split: Split name, forwarded to DatasetBase.
            n_candidates: Number of candidate profiles per case.
            n_oct_samples: Number of extra OCT slices offered as a channel.
        """
        super().__init__(data_dir or config.DATASET_PATHS["olives"], split)
        self.n_candidates = n_candidates
        self.n_oct_samples = n_oct_samples

    def get_name(self) -> str:
        """Return the dataset identifier used for config lookups."""
        return "olives"

    def load(self) -> list[MedicalCase]:
        """Load the biomarker CSV and OCT images and build cases.

        Returns:
            One MedicalCase per eye with at least one locatable,
            encodable OCT image; an empty list if the CSV or the image
            root cannot be found.
        """
        logger.info(f"Loading OLIVES dataset from {self.data_dir}")
        # ---- Find the CSV ----
        csv_path = self._find_csv()
        if csv_path is None:
            logger.error("No biomarker CSV found")
            return []
        # ---- Load records ----
        # utf-8-sig transparently strips a BOM from Excel-exported CSVs.
        with open(csv_path, newline="", encoding="utf-8-sig") as f:
            rows = list(csv.DictReader(f))
        logger.info(f"Found {len(rows)} records in {csv_path.name}")
        # ---- Find the image root ----
        image_root = self._find_image_root()
        if image_root is None:
            logger.error("No image directory found")
            return []
        logger.info(f"Image root: {image_root}")
        # ---- Group rows by (patient, eye) ----
        eye_groups = defaultdict(list)
        for r in rows:
            pid = r.get("Patient_ID", "")
            path_str = r.get(
                "Path (Trial/Arm/Folder/Visit/Eye/Image Name)", ""
            )
            parts = path_str.strip("/").split("/")
            if len(parts) >= 5:
                eye = parts[4]  # OD or OS
            else:
                eye = r.get("Eye_ID", "unknown")
            eye_key = f"{pid}_{eye}"
            r["_eye_key"] = eye_key
            r["_path_parts"] = parts
            eye_groups[eye_key].append(r)
        logger.info(f"Found {len(eye_groups)} unique eyes")
        # ---- Build one biomarker profile per eye ----
        # NOTE(review): records[-1] is treated as the latest visit, which
        # assumes the CSV rows are in chronological order — confirm.
        all_profiles = {}
        for eye_key, records in eye_groups.items():
            latest = records[-1]
            all_profiles[eye_key] = self._extract_biomarker_vector(latest)
        # ---- Build cases ----
        self.cases = []
        for eye_key, records in eye_groups.items():
            case = self._build_case(
                eye_key, records, all_profiles, image_root
            )
            if case is not None:
                self.cases.append(case)
        logger.info(f"Loaded {len(self.cases)} OLIVES cases")
        return self.cases

    def _find_csv(self) -> Path | None:
        """Find the biomarker CSV, trying known locations then globbing."""
        search_paths = [
            self.data_dir / "Biomarker_Clinical_Data_Images.csv",
            self.data_dir / "OLIVES_Dataset_Labels" / "OLIVES_Dataset_Labels" / "full_labels" / "Biomarker_Clinical_Data_Images.csv",
            self.data_dir.parent / "OLIVES_Dataset_Labels" / "OLIVES_Dataset_Labels" / "full_labels" / "Biomarker_Clinical_Data_Images.csv",
        ]
        for p in search_paths:
            if p.exists():
                return p
        # Glob fallback under the data dir, then under its parent.
        csvs = list(self.data_dir.rglob("Biomarker*Clinical*.csv"))
        if csvs:
            return csvs[0]
        csvs = list(self.data_dir.parent.rglob("Biomarker*Clinical*.csv"))
        if csvs:
            return csvs[0]
        return None

    def _find_image_root(self) -> Path | None:
        """Find the root directory containing Prime_FULL and TREX_DME."""
        search = [
            self.data_dir / "OLIVES",
            self.data_dir / "OLIVES" / "OLIVES",
            self.data_dir,
        ]
        for d in search:
            if (d / "Prime_FULL").exists() or (d / "TREX_DME").exists():
                return d
        # Fall back to the first Prime_FULL found anywhere below data_dir.
        for p in self.data_dir.rglob("Prime_FULL"):
            return p.parent
        return None

    def _extract_biomarker_vector(self, record: dict) -> dict[str, bool]:
        """Extract a {canonical_name: present} vector from a CSV row.

        A biomarker is present iff its cell parses numerically to 1
        ("1", "1.0", 1, and 1.0 all count). Missing, empty, and
        non-numeric cells count as absent — the previous
        ``int(float(...))`` path raised ValueError on non-numeric
        non-string values.
        """
        vector = {}
        for csv_col, canonical_name in OLIVES_CSV_BIOMARKERS.items():
            raw = record.get(csv_col, "0")
            try:
                present = float(str(raw).strip() or 0) == 1.0
            except (TypeError, ValueError):
                present = False
            vector[canonical_name] = present
        return vector

    def _find_oct_images(
        self, records: list[dict], image_root: Path, n: int = 3
    ) -> list[Path]:
        """Find up to ``n`` evenly spaced OCT B-scan images for an eye.

        Tries several on-disk layouts derived from the CSV path column
        (Prime and TREX trees are nested differently); returns images
        from the first directory that yields any, or [] if none do.
        """
        for r in records:
            path_str = r.get(
                "Path (Trial/Arm/Folder/Visit/Eye/Image Name)", ""
            )
            parts = path_str.strip("/").split("/")
            if len(parts) < 5:
                continue
            # Candidate directories: the CSV path minus the image filename.
            # Path format: /Trial/Arm/Folder/Visit/Eye/Image
            trial = parts[0]
            remaining = "/".join(parts[1:-1])
            search_dirs = [
                image_root / trial / remaining,
                image_root / parts[0].replace(" ", "_") / remaining,
            ]
            # Prime layout: Prime_FULL/Prime_FULL/<patient>/<visit>/<eye>/
            if "Prime" in trial or "prime" in trial:
                pid = parts[2] if len(parts) > 2 else ""
                visit = parts[3] if len(parts) > 3 else ""
                eye = parts[4] if len(parts) > 4 else ""
                search_dirs.extend([
                    image_root / "Prime_FULL" / "Prime_FULL" / pid / visit / eye,
                    image_root / "Prime_FULL" / pid / visit / eye,
                ])
            # TREX layout: TREX_DME/TREX DME/<arm>/<patient>/<visit>/<eye>/
            if "TREX" in trial:
                arm = parts[1] if len(parts) > 1 else ""
                pid = parts[2] if len(parts) > 2 else ""
                visit = parts[3] if len(parts) > 3 else ""
                eye = parts[4] if len(parts) > 4 else ""
                search_dirs.extend([
                    image_root / "TREX_DME" / "TREX DME" / arm / pid / visit / eye,
                    image_root / "TREX_DME" / trial / arm / pid / visit / eye,
                ])
            for d in search_dirs:
                if not d.exists():
                    continue
                images = sorted(
                    list(d.glob("*.png")) + list(d.glob("*.tif"))
                    + list(d.glob("*.jpg"))
                )
                if images:
                    if len(images) <= n:
                        return images
                    # Sample n evenly spaced slices across the volume.
                    indices = np.linspace(
                        0, len(images) - 1, n, dtype=int
                    )
                    return [images[i] for i in indices]
        return []

    def _build_case(
        self,
        eye_key: str,
        records: list[dict],
        all_profiles: dict[str, dict[str, bool]],
        image_root: Path,
    ) -> MedicalCase | None:
        """Convert an eye's CSV records into a MedicalCase.

        Returns None when no OCT image can be located or the primary
        scan cannot be base64-encoded.
        """
        latest = records[-1]
        # ---- Find OCT images (one primary + up to n_oct_samples extra) ----
        oct_images = self._find_oct_images(
            records, image_root, self.n_oct_samples + 1
        )
        if not oct_images:
            logger.debug(f"Skipping eye {eye_key}: no images found")
            return None
        # Build all available channels, then split by always_given flag.
        all_channels = {}
        # Middle scan is the canonical first-line OCT; the rest are extras.
        mid_idx = len(oct_images) // 2
        initial_image = oct_images[mid_idx]
        additional_images = [
            img for i, img in enumerate(oct_images) if i != mid_idx
        ]
        try:
            initial_b64 = encode_image_to_base64(initial_image)
        except Exception as e:
            logger.debug(f"Skipping eye {eye_key}: encode failed: {e}")
            return None
        oct_meta = config.get_channel_definition("olives", "oct_scan")
        all_channels["oct_scan"] = ChannelData(
            name="oct_scan",
            channel_type="image",
            description="OCT B-scan showing retinal cross-section",
            value=initial_b64,
            image_path=initial_image,
            cost=float(oct_meta.get("cost", 0.0)),
            tier=oct_meta.get("tier", "unknown"),
            always_given=bool(oct_meta.get("always_given", False)),
        )
        # Additional OCT slices (best effort; case is usable without them).
        if additional_images:
            try:
                add_b64 = [encode_image_to_base64(p) for p in additional_images]
                ch_meta = config.get_channel_definition("olives", "additional_oct")
                all_channels["additional_oct"] = ChannelData(
                    name="additional_oct",
                    channel_type="image",
                    description="Additional OCT B-scans from different retinal locations",
                    value=add_b64,
                    cost=float(ch_meta.get("cost", 0.0)),
                    tier=ch_meta.get("tier", "unknown"),
                    always_given=bool(ch_meta.get("always_given", False)),
                )
            except Exception as e:
                # Deliberately best-effort, but log instead of swallowing.
                logger.debug(f"additional_oct skipped for {eye_key}: {e}")
        # Clinical measurements (BCVA and CST), when present in the CSV.
        bcva = latest.get("BCVA", "")
        cst = latest.get("CST", "")
        if bcva or cst:
            parts = []
            if bcva:
                parts.append(f"BCVA (logMAR): {bcva}")
            if cst:
                parts.append(f"CST: {cst} um")
            ch_meta = config.get_channel_definition("olives", "clinical_measurements")
            all_channels["clinical_measurements"] = ChannelData(
                name="clinical_measurements",
                channel_type="text",
                description="Visual acuity (BCVA) and retinal thickness (CST)",
                value="; ".join(parts),
                cost=float(ch_meta.get("cost", 0.0)),
                tier=ch_meta.get("tier", "unknown"),
                always_given=bool(ch_meta.get("always_given", False)),
            )
        # Biomarker hints: only a subset of the most obvious markers.
        biomarker_vec = all_profiles[eye_key]
        obvious_markers = ["fluid_irf", "fluid_srf", "hemorrhage", "drt_me"]
        hint_parts = []
        for m in obvious_markers:
            if m in biomarker_vec:
                status = "Present" if biomarker_vec[m] else "Not detected"
                hint_parts.append(
                    f"{m.replace('_', ' ').title()}: {status}"
                )
        if hint_parts:
            ch_meta = config.get_channel_definition("olives", "biomarker_hints")
            all_channels["biomarker_hints"] = ChannelData(
                name="biomarker_hints",
                channel_type="text",
                description="Partial biomarker annotations (subset)",
                value="; ".join(hint_parts),
                cost=float(ch_meta.get("cost", 0.0)),
                tier=ch_meta.get("tier", "unknown"),
                always_given=bool(ch_meta.get("always_given", False)),
            )
        # Disease type hint: TREX arms hold DME patients, Prime holds DR.
        path_str = latest.get(
            "Path (Trial/Arm/Folder/Visit/Eye/Image Name)", ""
        )
        disease = "DME" if "TREX" in path_str else "DR"
        ch_meta = config.get_channel_definition("olives", "disease_context")
        all_channels["disease_context"] = ChannelData(
            name="disease_context",
            channel_type="text",
            description="Disease type and treatment context",
            value=f"Disease: {disease}",
            cost=float(ch_meta.get("cost", 0.0)),
            tier=ch_meta.get("tier", "unknown"),
            always_given=bool(ch_meta.get("always_given", False)),
        )
        initial_channels = {
            name: ch for name, ch in all_channels.items() if ch.always_given
        }
        requestable = {
            name: ch for name, ch in all_channels.items() if not ch.always_given
        }
        # ---- Build candidates ----
        case_id = f"olives_{eye_key}"
        correct_profile = biomarker_vector_to_profile_string(biomarker_vec)
        candidates = self._generate_profile_candidates(
            eye_key, biomarker_vec, all_profiles, case_id
        )
        # Safety net: _generate_profile_candidates always includes the
        # correct profile, but guarantee it (and re-randomize) regardless.
        if correct_profile not in candidates:
            candidates[0] = correct_profile
            rng = _case_rng(case_id)
            rng.shuffle(candidates)
        return MedicalCase(
            case_id=case_id,
            dataset="olives",
            initial_channels=initial_channels,
            requestable_channels=requestable,
            candidates=candidates,
            ground_truth=correct_profile,
            ground_truth_rank=(
                candidates.index(correct_profile)
                if correct_profile in candidates else 0
            ),
            metadata={
                "eye_id": eye_key,
                "disease": disease,
                "biomarker_vector": biomarker_vec,
            },
        )

    def _generate_profile_candidates(
        self,
        eye_id: str,
        correct_vec: dict[str, bool],
        all_profiles: dict[str, dict[str, bool]],
        case_id: str,
    ) -> list[str]:
        """Generate biomarker profile candidates for one case.

        Returns the correct profile plus up to ``n_candidates - 1``
        distractors drawn from other eyes: the closest profile by
        Hamming distance (hard), the farthest (easy), and a random
        sample of mid-distance profiles. Order is shuffled
        deterministically per case.
        """
        n = self.n_candidates
        rng = _case_rng(case_id)
        correct_str = biomarker_vector_to_profile_string(correct_vec)
        # Score every other eye's profile by distance to the truth,
        # skipping any whose rendered string equals the correct one.
        scored = []
        for eid, vec in all_profiles.items():
            if eid == eye_id:
                continue
            dist = compute_profile_distance(correct_vec, vec)
            profile_str = biomarker_vector_to_profile_string(vec)
            if profile_str != correct_str:
                scored.append((dist, profile_str, vec))
        scored.sort(key=lambda x: x[0])
        distractors = []
        if scored:
            distractors.append(scored[0][1])  # Hard distractor (closest)
        if len(scored) > 1:
            distractors.append(scored[-1][1])  # Easy distractor (farthest)
            # Medium-difficulty distractors from the middle of the range.
            mid_pool = scored[len(scored) // 4: 3 * len(scored) // 4]
            rng.shuffle(mid_pool)
            for _, prof, _ in mid_pool:
                if prof not in distractors and len(distractors) < n - 1:
                    distractors.append(prof)
        # Fill remaining slots by iterating a shuffled copy of the pool
        # ONCE. The previous `while ... rng.choice(scored)` loop never
        # terminated when the pool held fewer than n - 1 unique profile
        # strings, since it kept drawing duplicates forever.
        if len(distractors) < n - 1 and scored:
            fill_pool = scored[:]
            rng.shuffle(fill_pool)
            for _, prof, _ in fill_pool:
                if len(distractors) >= n - 1:
                    break
                if prof not in distractors:
                    distractors.append(prof)
        candidates = [correct_str] + distractors[:n - 1]
        rng.shuffle(candidates)
        return candidates