| """ |
| Sample Data Loader for NeuroSAM3. |
| |
| Provides curated demo images (brain tumors + healthy) from an HF Dataset repo. |
| Images are loaded on-demand via HuggingFace Hub API — no bundling in the Space repo. |
| |
| Dataset: mmrech/neurosam3-samples (private HF Dataset) |
| Source images from: |
| - Figshare: brain_tumor_dataset (3064 T1-weighted MRIs, 3 tumor types) |
| - Kaggle: brain-mri-scans-for-brain-tumor-classification |
| """ |
|
|
| from typing import Optional, List, Dict, Any, Tuple |
| import os |
| import tempfile |
| from pathlib import Path |
| from logger_config import logger |
| from config import HF_TOKEN |
|
|
| |
# HF Dataset repo (private) that holds the curated demo images.
SAMPLE_DATASET_REPO = "mmrech/neurosam3-samples"
# On-disk location for samples, placed under the OS temp directory.
SAMPLE_CACHE_DIR = str(Path(tempfile.gettempdir()) / "neurosam3_samples")
|
|
| |
# Category name -> metadata. Each row is
# (name, description, number of images in the repo, modality, has pathology);
# the comprehension expands rows into the dict shape the loaders read.
SAMPLE_CATEGORIES = {
    name: {
        "description": description,
        "count": count,
        "modality": modality,
        "pathology": pathology,
    }
    for name, description, count, modality, pathology in (
        ("glioma", "Glioma tumors (T1-weighted MRI)", 8, "MRI", True),
        ("meningioma", "Meningioma tumors (T1-weighted MRI)", 6, "MRI", True),
        ("pituitary", "Pituitary tumors (T1-weighted MRI)", 6, "MRI", True),
        ("healthy", "Normal brain (T1/T2 MRI)", 5, "MRI", False),
        ("ct_normal", "Normal brain CT scans", 3, "CT", False),
        ("ct_hemorrhage", "CT with intracranial hemorrhage", 2, "CT", True),
    )
}
|
|
|
|
def get_sample_categories() -> Dict[str, Dict[str, Any]]:
    """
    Get available sample categories and their metadata.

    Returns:
        Mapping of category name -> metadata dict with the keys
        'description', 'count', 'modality' and 'pathology'. A fresh copy is
        returned so callers cannot accidentally modify the module-level
        SAMPLE_CATEGORIES table that the loader functions read.
    """
    # The metadata values are immutable scalars, so copying each inner dict
    # fully decouples the result from the shared table.
    return {name: dict(meta) for name, meta in SAMPLE_CATEGORIES.items()}
|
|
|
|
def list_samples(category: Optional[str] = None) -> List[Dict[str, Any]]:
    """
    List available sample images, optionally filtered by category.

    Args:
        category: Restrict the listing to this category. None (or an empty
            string) lists every category; an unknown name yields an empty list.

    Returns:
        List of dicts with 'filename' (repo-relative path such as
        "glioma/glioma_001.png"), 'category', 'modality', 'description' and
        'has_pathology' (a bool, hence the Any value type).
    """
    categories = [category] if category else list(SAMPLE_CATEGORIES)

    samples: List[Dict[str, Any]] = []
    for cat in categories:
        info = SAMPLE_CATEGORIES.get(cat)
        if info is None:
            # Unknown category: silently skipped (caller gets fewer/no rows).
            continue
        for i in range(1, info["count"] + 1):
            samples.append({
                "filename": f"{cat}/{cat}_{i:03d}.png",
                "category": cat,
                "modality": info["modality"],
                "description": f"{info['description']} (sample {i})",
                "has_pathology": info["pathology"],
            })
    return samples
|
|
|
|
def load_sample_image(
    category: str,
    index: int = 1,
) -> Optional[str]:
    """
    Load a sample image from the HF Dataset repo.

    The image is downloaded to SAMPLE_CACHE_DIR/<category>/<category>_<NNN>.png,
    so later calls for the same sample are served from disk without touching
    the Hub. If the download fails for any reason, a synthetic placeholder is
    generated instead.

    Args:
        category: One of SAMPLE_CATEGORIES keys
        index: Image index (1-based)

    Returns:
        Local file path to the downloaded (or synthetic) image, or None on
        failure (unknown category, index out of range, or synthetic
        generation failed).
    """
    if category not in SAMPLE_CATEGORIES:
        logger.error(f"Unknown category: {category}. Available: {list(SAMPLE_CATEGORIES.keys())}")
        return None

    info = SAMPLE_CATEGORIES[category]
    if index < 1 or index > info["count"]:
        logger.error(f"Index {index} out of range for {category} (1-{info['count']})")
        return None

    basename = f"{category}_{index:03d}.png"
    # Path inside the dataset repo; Hub paths always use forward slashes.
    filename = f"{category}/{basename}"

    # Fast path: a previous download already placed the file here.
    cache_path = os.path.join(SAMPLE_CACHE_DIR, category, basename)
    if os.path.exists(cache_path):
        return cache_path

    try:
        from huggingface_hub import hf_hub_download

        # local_dir (rather than cache_dir) stores the file at exactly
        # SAMPLE_CACHE_DIR/<filename>, i.e. at cache_path above. With cache_dir
        # the Hub uses its own datasets--<org>--<name>/snapshots/<rev>/ layout,
        # so the fast-path check could never match a real download.
        local_path = hf_hub_download(
            repo_id=SAMPLE_DATASET_REPO,
            filename=filename,
            repo_type="dataset",
            token=HF_TOKEN,
            local_dir=SAMPLE_CACHE_DIR,
        )
        logger.info(f"Loaded sample: {filename}")
        return local_path
    except Exception as e:
        # Deliberate best-effort: a missing package, no network, a bad token or
        # a missing file all degrade to a synthetic image instead of failing.
        logger.warning(f"Could not load sample from HF Hub: {e}")

    return _generate_synthetic_sample(category, index)
|
|
|
|
def load_random_sample(
    modality: Optional[str] = None,
    pathology: Optional[bool] = None,
) -> Optional[Tuple[str, Dict[str, Any]]]:
    """
    Load a random sample image matching criteria.

    A category is chosen uniformly among those matching the filters, then an
    index is chosen uniformly within that category (so small categories are
    as likely to be picked as large ones).

    Args:
        modality: Filter by "CT" or "MRI", compared case-insensitively
            (None or "" for any)
        pathology: Filter by pathology presence (None for any)

    Returns:
        Tuple of (file_path, metadata) or None when no category matches or the
        image could not be loaded. metadata has the keys 'category', 'index',
        'modality', 'pathology' and 'description'.
    """
    import random

    # Normalise once so "mri" / "Mri" match the upper-case table values.
    wanted_modality = modality.upper() if modality else None

    candidates = [
        cat
        for cat, info in SAMPLE_CATEGORIES.items()
        if (wanted_modality is None or info["modality"].upper() == wanted_modality)
        and (pathology is None or info["pathology"] == pathology)
    ]
    if not candidates:
        return None

    category = random.choice(candidates)
    info = SAMPLE_CATEGORIES[category]
    index = random.randint(1, info["count"])

    path = load_sample_image(category, index)
    if not path:
        return None
    return path, {
        "category": category,
        "index": index,
        "modality": info["modality"],
        "pathology": info["pathology"],
        "description": info["description"],
    }
|
|
|
|
def load_category_batch(category: str) -> List[str]:
    """
    Fetch every image of a single category (used by the research pipeline demo).

    Args:
        category: Key of SAMPLE_CATEGORIES.

    Returns:
        Local paths of the images that loaded, in index order. Images whose
        load returned a falsy path are skipped. An unknown category gives [].
    """
    meta = SAMPLE_CATEGORIES.get(category)
    if meta is None:
        return []

    attempts = (load_sample_image(category, idx) for idx in range(1, meta["count"] + 1))
    return [p for p in attempts if p]
|
|
|
|
def _generate_synthetic_sample(category: str, index: int) -> Optional[str]:
    """
    Generate a synthetic sample image as fallback when HF Dataset is unavailable.

    Draws a 256x256 grayscale image with:
      * an outer filled ellipse ("skull", value 30),
      * an inner noisy ellipse ("brain", values 80-99), and
      * for categories whose metadata has pathology=True, a bright noisy
        disc (values 160-199) clipped to the brain region.

    The noise and lesion placement come from a generator seeded from
    (category, index). The same sample therefore always renders identically,
    and NumPy's global RNG state is left untouched. Images are written under
    SAMPLE_CACHE_DIR/_synthetic/ so a placeholder never occupies the path
    where a real downloaded sample is expected.

    Args:
        category: Category name (looked up in SAMPLE_CATEGORIES for 'pathology').
        index: 1-based sample index, used in the output filename and the seed.

    Returns:
        Path to the written PNG, or None if generation failed.
    """
    try:
        import zlib

        import numpy as np
        from PIL import Image

        # crc32 is stable across processes, unlike the salted built-in hash().
        rng = np.random.default_rng(zlib.crc32(f"{category}:{index}".encode("utf-8")))

        size = 256
        img = np.zeros((size, size), dtype=np.uint8)

        # Open grids broadcast to full (size, size) masks in the comparisons below.
        y, x = np.ogrid[:size, :size]
        center = size // 2

        # "Skull": filled ellipse with semi-axes 110 (x) and 120 (y).
        skull_mask = ((x - center) ** 2 / (110 ** 2) + (y - center) ** 2 / (120 ** 2)) <= 1
        img[skull_mask] = 30

        # "Brain": smaller ellipse (semi-axes 90, 100) with mild texture noise.
        brain_mask = ((x - center) ** 2 / (90 ** 2) + (y - center) ** 2 / (100 ** 2)) <= 1
        img[brain_mask] = 80 + rng.integers(
            0, 20, size=int(np.count_nonzero(brain_mask)), dtype=np.uint8
        )

        info = SAMPLE_CATEGORIES.get(category, {})
        if info.get("pathology", False):
            # Lesion: disc of random radius near the centre, clipped to the brain.
            tumor_x = center + int(rng.integers(-30, 30))
            tumor_y = center + int(rng.integers(-30, 30))
            tumor_r = int(rng.integers(15, 35))
            lesion_mask = (((x - tumor_x) ** 2 + (y - tumor_y) ** 2) <= tumor_r ** 2) & brain_mask
            img[lesion_mask] = 160 + rng.integers(
                0, 40, size=int(np.count_nonzero(lesion_mask)), dtype=np.uint8
            )

        out_dir = os.path.join(SAMPLE_CACHE_DIR, "_synthetic", category)
        os.makedirs(out_dir, exist_ok=True)
        save_path = os.path.join(out_dir, f"{category}_{index:03d}.png")
        Image.fromarray(img).save(save_path)

        logger.info(f"Generated synthetic sample: {save_path}")
        return save_path

    except Exception as e:
        # Best-effort fallback: report and let the caller handle None.
        logger.error(f"Failed to generate synthetic sample: {e}")
        return None
|
|
|
|
def get_dataset_info() -> Dict[str, Any]:
    """
    Get information about the sample dataset.

    Returns:
        Dict with the repo id, the total number of samples across categories,
        the category names, the modalities present (derived from
        SAMPLE_CATEGORIES so it cannot drift from the table), source
        attributions and a short note about the loading behaviour.
    """
    total = sum(info["count"] for info in SAMPLE_CATEGORIES.values())
    # Derived rather than hard-coded; sorted for a stable order (["CT", "MRI"]).
    modalities = sorted({info["modality"] for info in SAMPLE_CATEGORIES.values()})
    return {
        "repo": SAMPLE_DATASET_REPO,
        "total_samples": total,
        "categories": list(SAMPLE_CATEGORIES.keys()),
        "modalities": modalities,
        # NOTE(review): both sources below are MRI datasets; the provenance of
        # the ct_normal / ct_hemorrhage images is not listed — confirm and add.
        "sources": [
            "Figshare: brain_tumor_dataset (Cheng, 2017) — 3064 T1-weighted MRIs",
            "Kaggle: brain-mri-scans-for-brain-tumor-classification",
        ],
        "note": "Images loaded on-demand from HF Hub. Synthetic fallback if unavailable.",
    }
|
|