NeuroSAM3 / sample_data.py
mmrech's picture
feat: transform NeuroSAM3 into agentic neuroimaging platform
a7e0222 unverified
Raw
History Blame Contribute Delete
8.35 kB
"""
Sample Data Loader for NeuroSAM3.
Provides curated demo images (brain tumors + healthy) from an HF Dataset repo.
Images are loaded on-demand via HuggingFace Hub API — no bundling in the Space repo.
Dataset: mmrech/neurosam3-samples (private HF Dataset)
Source images from:
- Figshare: brain_tumor_dataset (3064 T1-weighted MRIs, 3 tumor types)
- Kaggle: brain-mri-scans-for-brain-tumor-classification
"""
from typing import Optional, List, Dict, Any, Tuple
import os
import tempfile
from pathlib import Path
from logger_config import logger
from config import HF_TOKEN
# --- Dataset configuration -------------------------------------------------
# Hugging Face dataset repo that hosts the curated sample images.
SAMPLE_DATASET_REPO = "mmrech/neurosam3-samples"
# Local directory under the system temp dir. It is passed to the Hub download
# as its cache dir and is also where synthetic fallback images are written.
SAMPLE_CACHE_DIR = os.path.join(tempfile.gettempdir(), "neurosam3_samples")

# Catalogue of sample categories. Per category:
#   description -- human-readable label
#   count       -- number of images; they are numbered 1..count
#   modality    -- "MRI" or "CT"
#   pathology   -- True when the images contain a lesion
SAMPLE_CATEGORIES = {
    "glioma": dict(
        description="Glioma tumors (T1-weighted MRI)",
        count=8,
        modality="MRI",
        pathology=True,
    ),
    "meningioma": dict(
        description="Meningioma tumors (T1-weighted MRI)",
        count=6,
        modality="MRI",
        pathology=True,
    ),
    "pituitary": dict(
        description="Pituitary tumors (T1-weighted MRI)",
        count=6,
        modality="MRI",
        pathology=True,
    ),
    "healthy": dict(
        description="Normal brain (T1/T2 MRI)",
        count=5,
        modality="MRI",
        pathology=False,
    ),
    "ct_normal": dict(
        description="Normal brain CT scans",
        count=3,
        modality="CT",
        pathology=False,
    ),
    "ct_hemorrhage": dict(
        description="CT with intracranial hemorrhage",
        count=2,
        modality="CT",
        pathology=True,
    ),
}
def get_sample_categories() -> Dict[str, Dict[str, Any]]:
    """
    Get available sample categories and their metadata.

    Returns:
        A deep copy of SAMPLE_CATEGORIES, mapping category name to a dict with
        'description', 'count', 'modality' and 'pathology'. A copy is returned
        because list_samples() and load_sample_image() validate against the
        module-level table. Handing out the live dict would let any caller
        silently change those limits for the whole process.
    """
    import copy

    return copy.deepcopy(SAMPLE_CATEGORIES)
def list_samples(category: Optional[str] = None) -> List[Dict[str, Any]]:
    """
    List available sample images, optionally filtered by category.

    Args:
        category: Key of SAMPLE_CATEGORIES to restrict the listing to.
            None (or any falsy value) lists every category. An unknown
            category is logged and yields no entries.

    Returns:
        List of dicts, one per image, with keys:
          'filename'      -- repo-relative path, "<cat>/<cat>_NNN.png" (str)
          'category'      -- category name (str)
          'modality'      -- "MRI" or "CT" (str)
          'description'   -- category description plus sample number (str)
          'has_pathology' -- category pathology flag (bool)
    """
    categories = [category] if category else list(SAMPLE_CATEGORIES)
    samples: List[Dict[str, Any]] = []
    for cat in categories:
        info = SAMPLE_CATEGORIES.get(cat)
        if info is None:
            # Surface typos instead of silently returning nothing
            # (consistent with load_sample_image, which logs).
            logger.warning(
                f"Unknown sample category: {cat}. Available: {list(SAMPLE_CATEGORIES.keys())}"
            )
            continue
        for i in range(1, info["count"] + 1):
            samples.append({
                "filename": f"{cat}/{cat}_{i:03d}.png",
                "category": cat,
                "modality": info["modality"],
                "description": f"{info['description']} (sample {i})",
                "has_pathology": info["pathology"],
            })
    return samples
def load_sample_image(
    category: str,
    index: int = 1,
) -> Optional[str]:
    """
    Load a sample image from the HF Dataset repo.

    Resolution order:
      1. Fetch the real image with ``hf_hub_download``, which reuses its own
         cache under SAMPLE_CACHE_DIR when the file was downloaded before.
      2. If that fails, reuse a synthetic image previously written to
         SAMPLE_CACHE_DIR/<filename>, or generate a new one.

    Args:
        category: One of SAMPLE_CATEGORIES keys
        index: Image index (1-based)

    Returns:
        Local file path to the image (real or synthetic). Returns None for an
        unknown category, an out-of-range index, or when the synthetic
        fallback also fails.
    """
    if category not in SAMPLE_CATEGORIES:
        logger.error(f"Unknown category: {category}. Available: {list(SAMPLE_CATEGORIES.keys())}")
        return None
    info = SAMPLE_CATEGORIES[category]
    if index < 1 or index > info["count"]:
        logger.error(f"Index {index} out of range for {category} (1-{info['count']})")
        return None
    filename = f"{category}/{category}_{index:03d}.png"

    # hf_hub_download keeps files in the HF cache layout under cache_dir
    # (datasets--<owner>--<name>/snapshots/<rev>/<filename>). A real download
    # therefore never lives at SAMPLE_CACHE_DIR/<filename>; only the synthetic
    # fallback writes there. Checking that path *before* the Hub would pin a
    # synthetic image forever after a single transient failure, so the Hub is
    # tried first and that path is consulted only as a fallback.
    try:
        from huggingface_hub import hf_hub_download
        local_path = hf_hub_download(
            repo_id=SAMPLE_DATASET_REPO,
            filename=filename,
            repo_type="dataset",
            token=HF_TOKEN,
            cache_dir=SAMPLE_CACHE_DIR,
        )
    except Exception as e:
        # Best-effort: missing package, auth or network errors all fall
        # through to the synthetic sample below.
        logger.warning(f"Could not load sample from HF Hub: {e}")
    else:
        logger.info(f"Loaded sample: {filename}")
        return local_path

    # Fallback: reuse an earlier synthetic image for this slot, else create one.
    synthetic_path = os.path.join(SAMPLE_CACHE_DIR, filename)
    if os.path.exists(synthetic_path):
        return synthetic_path
    return _generate_synthetic_sample(category, index)
def load_random_sample(
    modality: Optional[str] = None,
    pathology: Optional[bool] = None,
) -> Optional[Tuple[str, Dict[str, Any]]]:
    """
    Pick a random image whose category matches the filters, and load it.

    Args:
        modality: Keep only categories with this modality ("CT" or "MRI").
            None or empty means any modality.
        pathology: Keep only categories whose pathology flag equals this
            value. None means either.

    Returns:
        (file_path, metadata), where metadata holds 'category', 'index',
        'modality', 'pathology' and 'description'. Returns None when no
        category matches or the image could not be loaded.
    """
    import random

    def _wanted(meta: Dict[str, Any]) -> bool:
        # Reject on modality mismatch; otherwise accept unless the
        # pathology filter is set and disagrees.
        if modality and meta["modality"] != modality:
            return False
        return pathology is None or meta["pathology"] == pathology

    eligible = [name for name, meta in SAMPLE_CATEGORIES.items() if _wanted(meta)]
    if not eligible:
        return None

    chosen = random.choice(eligible)
    meta = SAMPLE_CATEGORIES[chosen]
    idx = random.randint(1, meta["count"])

    image_path = load_sample_image(chosen, idx)
    if not image_path:
        return None
    return image_path, {
        "category": chosen,
        "index": idx,
        "modality": meta["modality"],
        "pathology": meta["pathology"],
        "description": meta["description"],
    }
def load_category_batch(category: str) -> List[str]:
    """
    Load every image of one category (used by the research pipeline demo).

    Args:
        category: A key of SAMPLE_CATEGORIES.

    Returns:
        Paths, in index order, for the images where load_sample_image()
        produced a path. Those paths may point to synthetic fallbacks. An
        unknown category gives an empty list.
    """
    meta = SAMPLE_CATEGORIES.get(category)
    if meta is None:
        return []
    attempts = (load_sample_image(category, n) for n in range(1, meta["count"] + 1))
    return [p for p in attempts if p]
def _generate_synthetic_sample(category: str, index: int) -> Optional[str]:
    """
    Generate a synthetic sample image as fallback when HF Dataset is unavailable.

    The image is a 256x256 8-bit grayscale picture containing:
      * a dim filled ellipse (value 30) standing in for the skull,
      * a brighter noisy inner ellipse (80-99) standing in for the brain,
      * for categories flagged with pathology, a bright noisy disc (160-199)
        clipped to the brain region.

    Randomness comes from a private NumPy Generator seeded from
    (category, index). The same sample slot therefore always yields the same
    image, and the global ``np.random`` state is never consumed or altered.

    Args:
        category: Sample category. Categories missing from SAMPLE_CATEGORIES
            are drawn without a lesion.
        index: 1-based sample index, used for the file name and the seed.

    Returns:
        Path of the PNG written under SAMPLE_CACHE_DIR, or None if anything
        fails (the error is logged).
    """
    try:
        import zlib

        import numpy as np
        from PIL import Image

        # crc32 is stable across processes (str hash() is randomized per run).
        seed = zlib.crc32(f"{category}/{index}".encode("utf-8"))
        rng = np.random.default_rng(seed)

        size = 256
        center = size // 2
        img = np.zeros((size, size), dtype=np.uint8)
        y, x = np.ogrid[:size, :size]

        # Skull (outer ellipse)
        skull_mask = ((x - center) ** 2 / 110**2 + (y - center) ** 2 / 120**2) <= 1
        img[skull_mask] = 30

        # Brain (inner ellipse) with mild texture; cast explicitly to uint8.
        brain_mask = ((x - center) ** 2 / 90**2 + (y - center) ** 2 / 100**2) <= 1
        brain_noise = rng.integers(0, 20, size=int(np.count_nonzero(brain_mask)))
        img[brain_mask] = (80 + brain_noise).astype(np.uint8)

        # Add a "tumor" blob if the category indicates pathology.
        info = SAMPLE_CATEGORIES.get(category, {})
        if info.get("pathology", False):
            tumor_x = center + int(rng.integers(-30, 30))
            tumor_y = center + int(rng.integers(-30, 30))
            tumor_r = int(rng.integers(15, 35))
            tumor_mask = ((x - tumor_x) ** 2 + (y - tumor_y) ** 2) <= tumor_r**2
            lesion_mask = tumor_mask & brain_mask
            lesion_noise = rng.integers(0, 40, size=int(np.count_nonzero(lesion_mask)))
            img[lesion_mask] = (160 + lesion_noise).astype(np.uint8)

        # Save at the same path load_sample_image() checks for a prior synthetic.
        os.makedirs(os.path.join(SAMPLE_CACHE_DIR, category), exist_ok=True)
        save_path = os.path.join(SAMPLE_CACHE_DIR, f"{category}/{category}_{index:03d}.png")
        Image.fromarray(img).save(save_path)
        logger.info(f"Generated synthetic sample: {save_path}")
        return save_path
    except Exception as e:
        logger.error(f"Failed to generate synthetic sample: {e}")
        return None
def get_dataset_info() -> Dict[str, Any]:
    """
    Get information about the sample dataset.

    Returns:
        Dict with:
          'repo'          -- HF dataset repo id
          'total_samples' -- sum of per-category counts
          'categories'    -- category names in definition order
          'modalities'    -- sorted distinct modalities across categories
          'sources'       -- provenance strings
          'note'          -- loading behaviour summary
    """
    category_infos = SAMPLE_CATEGORIES.values()
    return {
        "repo": SAMPLE_DATASET_REPO,
        "total_samples": sum(info["count"] for info in category_infos),
        "categories": list(SAMPLE_CATEGORIES.keys()),
        # Derived (not hard-coded) so it cannot drift from SAMPLE_CATEGORIES.
        "modalities": sorted({info["modality"] for info in category_infos}),
        "sources": [
            "Figshare: brain_tumor_dataset (Cheng, 2017) — 3064 T1-weighted MRIs",
            "Kaggle: brain-mri-scans-for-brain-tumor-classification",
        ],
        "note": "Images loaded on-demand from HF Hub. Synthetic fallback if unavailable.",
    }