Spaces:
Sleeping
Sleeping
| """ | |
| OLIVES Dataset Loader. | |
| Adapted for the actual Zenodo OLIVES dataset structure: | |
| data/ | |
| βββ OLIVES/OLIVES/ | |
| β βββ Prime_FULL/Prime_FULL/ (DR patients β OCT B-scans) | |
| β β βββ <patient_id>/<visit>/<eye>/*.png | |
| β βββ TREX_DME/TREX DME/ (DME patients β OCT B-scans) | |
| β βββ <arm>/<patient_id>/<visit>/<eye>/*.tif | |
| βββ OLIVES_Dataset_Labels/OLIVES_Dataset_Labels/ | |
| βββ full_labels/Biomarker_Clinical_Data_Images.csv | |
| Task: Biomarker profile ranking. | |
| - Given an OCT B-scan, rank candidate biomarker profiles | |
| - Each profile is a subset of the 16 annotated biomarkers | |
| - Correct profile = actual biomarker vector for this eye | |
| - Distractors = profiles from other eyes | |
| Channels: | |
| - Initial: single OCT B-scan (middle slice) | |
| - Requestable: additional OCT slices, clinical measurements (BCVA/CST), | |
| biomarker hints (fundus-visible subset), treatment history | |
| """ | |
| import csv | |
| import hashlib | |
| import json | |
| import logging | |
| import random | |
| from pathlib import Path | |
| from collections import defaultdict | |
| import numpy as np | |
| from .base import DatasetBase, MedicalCase, ChannelData | |
| from api_client import encode_image_to_base64 | |
| import config | |
| logger = logging.getLogger(__name__) | |
# The biomarker columns as they appear in the CSV, mapped to the canonical
# snake_case names used for profiles throughout this module.
OLIVES_CSV_BIOMARKERS = {
    "Fluid (IRF)": "fluid_irf",
    "Fluid (SRF)": "fluid_srf",
    "DRT/ME": "drt_me",
    "SHRM": "shrm",
    "Preretinal tissue/hemorrhage": "preretinal_tissue",
    "Vitreous debris": "vitreous_debris",
    "DRIL": "dril",
    "Disruption of EZ": "ez_disruption",
    "IR hemorrhages": "hemorrhage",
    "IR HRF": "ir_hrf",
    "Disruption of RPE": "rpe_disruption",
    "PED (serous)": "ped_serous",
    "Atrophy / thinning of retinal layers": "atrophy",
    "VMT": "vmt",
    "Partially attached vitreous face": "partial_vitreous",
    "Fully attached vitreous face": "full_vitreous",
}

# Canonical biomarker names for profiles (sorted for a stable, deterministic
# iteration order, e.g. in Hamming-distance computation).
OLIVES_BIOMARKERS = sorted(OLIVES_CSV_BIOMARKERS.values())
def biomarker_vector_to_profile_string(vector: dict[str, bool]) -> str:
    """Render a biomarker presence dict as a human-readable profile string.

    Names are sorted, underscores become spaces, and each word is
    title-cased. An all-absent (or empty) vector yields a fixed sentinel.
    """
    present_names = []
    for name in sorted(vector):
        if vector[name]:
            present_names.append(name.replace("_", " ").title())
    if present_names:
        return "Present biomarkers: " + ", ".join(present_names)
    return "No biomarkers detected"
def compute_profile_distance(profile_a: dict, profile_b: dict) -> int:
    """Hamming distance between two biomarker profiles.

    Compares the canonical biomarker keys only; a key missing from either
    profile is treated as False.
    """
    return sum(
        profile_a.get(key, False) != profile_b.get(key, False)
        for key in OLIVES_BIOMARKERS
    )
| def _case_rng(case_id: str) -> random.Random: | |
| seed = int(hashlib.sha256(case_id.encode()).hexdigest()[:8], 16) | |
| return random.Random(seed) | |
class OLIVESDataset(DatasetBase):
    """Loader for the OLIVES ophthalmology dataset (biomarker profile ranking)."""

    def __init__(
        self,
        data_dir: str | Path | None = None,  # fixed: implicit Optional annotation
        split: str = "test",
        n_candidates: int = 5,
        n_oct_samples: int = 3,
    ):
        """
        Args:
            data_dir: Dataset root directory; when None, falls back to
                ``config.DATASET_PATHS["olives"]``.
            split: Split name, passed through to ``DatasetBase``.
            n_candidates: Total number of candidate profiles per case
                (correct profile + distractors).
            n_oct_samples: Number of additional OCT slices to sample per eye.
        """
        super().__init__(data_dir or config.DATASET_PATHS["olives"], split)
        self.n_candidates = n_candidates
        self.n_oct_samples = n_oct_samples

    def get_name(self) -> str:
        """Dataset identifier used for config lookups and case ids."""
        return "olives"
    def load(self) -> list[MedicalCase]:
        """Load CSV labels and the image tree, building one case per eye.

        Returns:
            List of MedicalCase objects (also stored on ``self.cases``);
            empty (after logging an error) when the labels CSV or the image
            root cannot be located.
        """
        logger.info(f"Loading OLIVES dataset from {self.data_dir}")
        # ---- Find the CSV ----
        csv_path = self._find_csv()
        if csv_path is None:
            logger.error("No biomarker CSV found")
            return []
        # ---- Load records ----
        # utf-8-sig strips a BOM if the CSV came from a spreadsheet export.
        with open(csv_path, newline="", encoding="utf-8-sig") as f:
            rows = list(csv.DictReader(f))
        logger.info(f"Found {len(rows)} records in {csv_path.name}")
        # ---- Find the image root ----
        image_root = self._find_image_root()
        if image_root is None:
            logger.error("No image directory found")
            return []
        logger.info(f"Image root: {image_root}")
        # ---- Group by eye ----
        eye_groups = defaultdict(list)
        for r in rows:
            pid = r.get("Patient_ID", "")
            path_str = r.get(
                "Path (Trial/Arm/Folder/Visit/Eye/Image Name)", ""
            )
            parts = path_str.strip("/").split("/")
            if len(parts) >= 5:
                eye = parts[4]  # OD or OS
            else:
                # Malformed/short path: fall back to the explicit eye column.
                eye = r.get("Eye_ID", "unknown")
            eye_key = f"{pid}_{eye}"
            # Cache derived values on the row dict for downstream steps.
            r["_eye_key"] = eye_key
            r["_path_parts"] = parts
            eye_groups[eye_key].append(r)
        logger.info(f"Found {len(eye_groups)} unique eyes")
        # ---- Build biomarker profiles ----
        all_profiles = {}
        for eye_key, records in eye_groups.items():
            # NOTE(review): treats the last CSV row for an eye as the latest
            # visit — assumes the CSV is ordered chronologically; confirm.
            latest = records[-1]
            all_profiles[eye_key] = self._extract_biomarker_vector(latest)
        # ---- Build cases ----
        self.cases = []
        for eye_key, records in eye_groups.items():
            case = self._build_case(
                eye_key, records, all_profiles, image_root
            )
            if case is not None:
                self.cases.append(case)
        logger.info(f"Loaded {len(self.cases)} OLIVES cases")
        return self.cases
| def _find_csv(self) -> Path | None: | |
| """Find the biomarker CSV in various locations.""" | |
| search_paths = [ | |
| self.data_dir / "Biomarker_Clinical_Data_Images.csv", | |
| self.data_dir / "OLIVES_Dataset_Labels" / "OLIVES_Dataset_Labels" / "full_labels" / "Biomarker_Clinical_Data_Images.csv", | |
| self.data_dir.parent / "OLIVES_Dataset_Labels" / "OLIVES_Dataset_Labels" / "full_labels" / "Biomarker_Clinical_Data_Images.csv", | |
| ] | |
| for p in search_paths: | |
| if p.exists(): | |
| return p | |
| # Glob fallback | |
| csvs = list(self.data_dir.rglob("Biomarker*Clinical*.csv")) | |
| if csvs: | |
| return csvs[0] | |
| # Check parent | |
| csvs = list(self.data_dir.parent.rglob("Biomarker*Clinical*.csv")) | |
| if csvs: | |
| return csvs[0] | |
| return None | |
| def _find_image_root(self) -> Path | None: | |
| """Find the root directory containing Prime_FULL and TREX_DME.""" | |
| search = [ | |
| self.data_dir / "OLIVES", | |
| self.data_dir / "OLIVES" / "OLIVES", | |
| self.data_dir, | |
| ] | |
| for d in search: | |
| if (d / "Prime_FULL").exists() or (d / "TREX_DME").exists(): | |
| return d | |
| # Search deeper | |
| for p in self.data_dir.rglob("Prime_FULL"): | |
| return p.parent | |
| return None | |
| def _extract_biomarker_vector(self, record: dict) -> dict[str, bool]: | |
| """Extract biomarker vector from a CSV row.""" | |
| vector = {} | |
| for csv_col, canonical_name in OLIVES_CSV_BIOMARKERS.items(): | |
| val = record.get(csv_col, "0") | |
| if isinstance(val, str): | |
| vector[canonical_name] = val.strip() == "1" | |
| else: | |
| vector[canonical_name] = bool(int(float(val or 0))) | |
| return vector | |
    def _find_oct_images(
        self, records: list[dict], image_root: Path, n: int = 3
    ) -> list[Path]:
        """Find up to *n* OCT B-scan images for an eye.

        Walks each record's CSV path column and probes several candidate
        directory layouts (per the module docstring, the archive nests trial
        folders twice and mixes spaces with underscores). Returns *n* evenly
        spaced scans from the first directory that contains any images, or an
        empty list when nothing matches.
        """
        # Try to locate images from the path in the CSV
        for r in records:
            path_str = r.get(
                "Path (Trial/Arm/Folder/Visit/Eye/Image Name)", ""
            )
            parts = path_str.strip("/").split("/")
            if len(parts) < 5:
                # Too short to carry Trial/.../Eye components; skip record.
                continue
            # Construct search directory (without the image filename)
            # Path format: /Trial/Arm/Patient/Visit/Eye/Image
            trial = parts[0]
            remaining = "/".join(parts[1:-1])
            search_dirs = [
                image_root / trial / remaining,
                # Same path with spaces normalized to underscores.
                image_root / parts[0].replace(" ", "_") / remaining,
            ]
            # For Prime: Prime_FULL/Prime_FULL/Patient/Visit/Eye/
            if "Prime" in trial or "prime" in trial:
                pid = parts[2] if len(parts) > 2 else ""
                visit = parts[3] if len(parts) > 3 else ""
                eye = parts[4] if len(parts) > 4 else ""
                search_dirs.extend([
                    image_root / "Prime_FULL" / "Prime_FULL" / pid / visit / eye,
                    image_root / "Prime_FULL" / pid / visit / eye,
                ])
            # For TREX: TREX_DME/TREX DME/Arm/Patient/Visit/Eye/
            if "TREX" in trial:
                arm = parts[1] if len(parts) > 1 else ""
                pid = parts[2] if len(parts) > 2 else ""
                visit = parts[3] if len(parts) > 3 else ""
                eye = parts[4] if len(parts) > 4 else ""
                search_dirs.extend([
                    image_root / "TREX_DME" / "TREX DME" / arm / pid / visit / eye,
                    image_root / "TREX_DME" / trial / arm / pid / visit / eye,
                ])
            for d in search_dirs:
                if not d.exists():
                    continue
                # Sorted so the evenly-spaced sampling below is stable
                # across runs and filesystems.
                images = sorted(
                    list(d.glob("*.png")) + list(d.glob("*.tif"))
                    + list(d.glob("*.jpg"))
                )
                if images:
                    # Sample N evenly spaced scans
                    if len(images) <= n:
                        return images
                    indices = np.linspace(
                        0, len(images) - 1, n, dtype=int
                    )
                    return [images[i] for i in indices]
        return []
    def _build_case(
        self,
        eye_key: str,
        records: list[dict],
        all_profiles: dict[str, dict[str, bool]],
        image_root: Path,
    ) -> MedicalCase | None:
        """Convert an eye's records into a MedicalCase.

        Assembles all information channels (OCT scans, clinical measurements,
        biomarker hints, disease context), splits them into initial vs.
        requestable per config, and attaches a shuffled candidate list.

        Returns:
            The built case, or None when no OCT images are found or the
            initial image cannot be encoded.
        """
        # NOTE(review): assumes the last row is the most recent visit —
        # confirm the CSV ordering (same assumption as in load()).
        latest = records[-1]
        # ---- Find OCT images ----
        # +1 so one scan serves as the initial channel while n_oct_samples
        # remain available as requestable extras.
        oct_images = self._find_oct_images(records, image_root, self.n_oct_samples + 1)
        if not oct_images:
            logger.debug(f"Skipping eye {eye_key}: no images found")
            return None
        # Build all available channels, then split by config
        all_channels = {}
        # Use middle scan as canonical first-line OCT, rest as optional extras
        mid_idx = len(oct_images) // 2
        initial_image = oct_images[mid_idx]
        additional_images = [
            img for i, img in enumerate(oct_images) if i != mid_idx
        ]
        try:
            initial_b64 = encode_image_to_base64(initial_image)
        except Exception as e:
            # Without a usable initial image the case cannot be presented.
            logger.debug(f"Skipping eye {eye_key}: encode failed: {e}")
            return None
        oct_meta = config.get_channel_definition("olives", "oct_scan")
        all_channels["oct_scan"] = ChannelData(
            name="oct_scan",
            channel_type="image",
            description="OCT B-scan showing retinal cross-section",
            value=initial_b64,
            image_path=initial_image,
            cost=float(oct_meta.get("cost", 0.0)),
            tier=oct_meta.get("tier", "unknown"),
            always_given=bool(oct_meta.get("always_given", False)),
        )
        # Additional OCT slices
        if additional_images:
            try:
                add_b64 = [encode_image_to_base64(p) for p in additional_images]
                ch_meta = config.get_channel_definition("olives", "additional_oct")
                all_channels["additional_oct"] = ChannelData(
                    name="additional_oct",
                    channel_type="image",
                    description="Additional OCT B-scans from different retinal locations",
                    value=add_b64,
                    cost=float(ch_meta.get("cost", 0.0)),
                    tier=ch_meta.get("tier", "unknown"),
                    always_given=bool(ch_meta.get("always_given", False)),
                )
            except Exception:
                # Best-effort: extra slices are optional, so a failure here
                # just omits the channel rather than dropping the case.
                pass
        # Clinical measurements (BCVA and CST)
        bcva = latest.get("BCVA", "")
        cst = latest.get("CST", "")
        if bcva or cst:
            parts = []
            if bcva:
                parts.append(f"BCVA (logMAR): {bcva}")
            if cst:
                parts.append(f"CST: {cst} um")
            ch_meta = config.get_channel_definition("olives", "clinical_measurements")
            all_channels["clinical_measurements"] = ChannelData(
                name="clinical_measurements",
                channel_type="text",
                description="Visual acuity (BCVA) and retinal thickness (CST)",
                value="; ".join(parts),
                cost=float(ch_meta.get("cost", 0.0)),
                tier=ch_meta.get("tier", "unknown"),
                always_given=bool(ch_meta.get("always_given", False)),
            )
        # Biomarker hints (subset — only the most obvious ones)
        biomarker_vec = all_profiles[eye_key]
        obvious_markers = ["fluid_irf", "fluid_srf", "hemorrhage", "drt_me"]
        hint_parts = []
        for m in obvious_markers:
            if m in biomarker_vec:
                status = "Present" if biomarker_vec[m] else "Not detected"
                hint_parts.append(
                    f"{m.replace('_', ' ').title()}: {status}"
                )
        if hint_parts:
            ch_meta = config.get_channel_definition("olives", "biomarker_hints")
            all_channels["biomarker_hints"] = ChannelData(
                name="biomarker_hints",
                channel_type="text",
                description="Partial biomarker annotations (subset)",
                value="; ".join(hint_parts),
                cost=float(ch_meta.get("cost", 0.0)),
                tier=ch_meta.get("tier", "unknown"),
                always_given=bool(ch_meta.get("always_given", False)),
            )
        # Disease type hint
        path_str = latest.get(
            "Path (Trial/Arm/Folder/Visit/Eye/Image Name)", ""
        )
        # Per the module docstring: TREX_DME holds DME patients, Prime holds DR.
        disease = "DME" if "TREX" in path_str else "DR"
        ch_meta = config.get_channel_definition("olives", "disease_context")
        all_channels["disease_context"] = ChannelData(
            name="disease_context",
            channel_type="text",
            description="Disease type and treatment context",
            value=f"Disease: {disease}",
            cost=float(ch_meta.get("cost", 0.0)),
            tier=ch_meta.get("tier", "unknown"),
            always_given=bool(ch_meta.get("always_given", False)),
        )
        # Channels flagged always_given in config are shown up front; the
        # rest must be explicitly requested.
        initial_channels = {
            name: ch for name, ch in all_channels.items() if ch.always_given
        }
        requestable = {
            name: ch for name, ch in all_channels.items() if not ch.always_given
        }
        # ---- Build candidates ----
        case_id = f"olives_{eye_key}"
        correct_profile = biomarker_vector_to_profile_string(biomarker_vec)
        candidates = self._generate_profile_candidates(
            eye_key, biomarker_vec, all_profiles, case_id
        )
        # Defensive: the generator includes the correct profile already, but
        # force it into the list if that invariant is ever violated.
        if correct_profile not in candidates:
            candidates[0] = correct_profile
        rng = _case_rng(case_id)
        rng.shuffle(candidates)
        return MedicalCase(
            case_id=case_id,
            dataset="olives",
            initial_channels=initial_channels,
            requestable_channels=requestable,
            candidates=candidates,
            ground_truth=correct_profile,
            ground_truth_rank=(
                candidates.index(correct_profile)
                if correct_profile in candidates else 0
            ),
            metadata={
                "eye_id": eye_key,
                "disease": disease,
                "biomarker_vector": biomarker_vec,
            },
        )
| def _generate_profile_candidates( | |
| self, | |
| eye_id: str, | |
| correct_vec: dict[str, bool], | |
| all_profiles: dict[str, dict[str, bool]], | |
| case_id: str, | |
| ) -> list[str]: | |
| """Generate biomarker profile candidates.""" | |
| n = self.n_candidates | |
| rng = _case_rng(case_id) | |
| correct_str = biomarker_vector_to_profile_string(correct_vec) | |
| scored = [] | |
| for eid, vec in all_profiles.items(): | |
| if eid == eye_id: | |
| continue | |
| dist = compute_profile_distance(correct_vec, vec) | |
| profile_str = biomarker_vector_to_profile_string(vec) | |
| if profile_str != correct_str: | |
| scored.append((dist, profile_str, vec)) | |
| scored.sort(key=lambda x: x[0]) | |
| distractors = [] | |
| if scored: | |
| distractors.append(scored[0][1]) # Hard distractor | |
| if len(scored) > 1: | |
| distractors.append(scored[-1][1]) # Easy distractor | |
| mid_pool = scored[len(scored) // 4: 3 * len(scored) // 4] | |
| rng.shuffle(mid_pool) | |
| for dist, prof, vec in mid_pool: | |
| if prof not in distractors and len(distractors) < n - 1: | |
| distractors.append(prof) | |
| while len(distractors) < n - 1 and scored: | |
| pick = rng.choice(scored) | |
| if pick[1] not in distractors: | |
| distractors.append(pick[1]) | |
| candidates = [correct_str] + distractors[:n - 1] | |
| rng.shuffle(candidates) | |
| return candidates | |