"""
Sample Data Loader for NeuroSAM3.

Provides curated demo images (brain tumors + healthy) from an HF Dataset repo.
Images are loaded on-demand via the HuggingFace Hub API — nothing is bundled
in the Space repo.

Dataset: mmrech/neurosam3-samples (private HF Dataset)

Source images from:
- Figshare: brain_tumor_dataset (3064 T1-weighted MRIs, 3 tumor types)
- Kaggle: brain-mri-scans-for-brain-tumor-classification
"""

from typing import Optional, List, Dict, Any, Tuple
import os
import tempfile
from pathlib import Path

from logger_config import logger
from config import HF_TOKEN

# Dataset configuration
SAMPLE_DATASET_REPO = "mmrech/neurosam3-samples"
SAMPLE_CACHE_DIR = os.path.join(tempfile.gettempdir(), "neurosam3_samples")

# Synthetic fallback images live in their own subdirectory of the cache so they
# can never be mistaken for (or shadow) real images downloaded from the Hub.
_SYNTHETIC_SUBDIR = "synthetic"

# Sample categories available in the dataset
SAMPLE_CATEGORIES = {
    "glioma": {
        "description": "Glioma tumors (T1-weighted MRI)",
        "count": 8,
        "modality": "MRI",
        "pathology": True,
    },
    "meningioma": {
        "description": "Meningioma tumors (T1-weighted MRI)",
        "count": 6,
        "modality": "MRI",
        "pathology": True,
    },
    "pituitary": {
        "description": "Pituitary tumors (T1-weighted MRI)",
        "count": 6,
        "modality": "MRI",
        "pathology": True,
    },
    "healthy": {
        "description": "Normal brain (T1/T2 MRI)",
        "count": 5,
        "modality": "MRI",
        "pathology": False,
    },
    "ct_normal": {
        "description": "Normal brain CT scans",
        "count": 3,
        "modality": "CT",
        "pathology": False,
    },
    "ct_hemorrhage": {
        "description": "CT with intracranial hemorrhage",
        "count": 2,
        "modality": "CT",
        "pathology": True,
    },
}


def _sample_basename(category: str, index: int) -> str:
    """Return the bare file name of sample ``index`` in ``category``."""
    return f"{category}_{index:03d}.png"


def _sample_filename(category: str, index: int) -> str:
    """Return the repo-relative path (always '/'-separated) of a sample image."""
    return f"{category}/{_sample_basename(category, index)}"


def get_sample_categories() -> Dict[str, Dict[str, Any]]:
    """
    Get available sample categories and their metadata.

    Returns a copy, so callers cannot accidentally mutate the module-level
    ``SAMPLE_CATEGORIES`` table that every other loader function reads.
    """
    return {name: dict(info) for name, info in SAMPLE_CATEGORIES.items()}


def list_samples(category: Optional[str] = None) -> List[Dict[str, Any]]:
    """
    List available sample images, optionally filtered by category.

    Args:
        category: Category name to restrict the listing to. ``None`` (or any
            falsy value) lists every category; an unknown name yields ``[]``.

    Returns:
        List of dicts with 'filename', 'category', 'modality', 'description'
        (all str) and 'has_pathology' (bool).
    """
    categories = [category] if category else list(SAMPLE_CATEGORIES)
    samples: List[Dict[str, Any]] = []
    for cat in categories:
        info = SAMPLE_CATEGORIES.get(cat)
        if info is None:
            continue
        for i in range(1, info["count"] + 1):
            samples.append({
                "filename": _sample_filename(cat, i),
                "category": cat,
                "modality": info["modality"],
                "description": f"{info['description']} (sample {i})",
                "has_pathology": info["pathology"],
            })
    return samples


def load_sample_image(
    category: str,
    index: int = 1,
) -> Optional[str]:
    """
    Load a sample image from the HF Dataset repo.

    Lookup order:
      1. A copy already present in the local Hub cache (no network access).
      2. A fresh download via ``hf_hub_download``.
      3. A synthetic placeholder, if the Hub is unreachable or the download
         fails for any reason.

    Args:
        category: One of SAMPLE_CATEGORIES keys
        index: Image index (1-based)

    Returns:
        Local file path to the image, or None on failure (unknown category,
        out-of-range index, or both download and synthetic fallback failed).
    """
    if category not in SAMPLE_CATEGORIES:
        logger.error(f"Unknown category: {category}. Available: {list(SAMPLE_CATEGORIES.keys())}")
        return None

    info = SAMPLE_CATEGORIES[category]
    if index < 1 or index > info["count"]:
        logger.error(f"Index {index} out of range for {category} (1-{info['count']})")
        return None

    filename = _sample_filename(category, index)

    try:
        from huggingface_hub import hf_hub_download, try_to_load_from_cache

        # hf_hub_download stores files in its own repo/snapshot layout under
        # cache_dir, so a plain os.path.join(SAMPLE_CACHE_DIR, filename) never
        # finds real downloads. Ask the Hub cache API instead; it returns a
        # str path only when the file is actually cached.
        cached = try_to_load_from_cache(
            repo_id=SAMPLE_DATASET_REPO,
            filename=filename,
            cache_dir=SAMPLE_CACHE_DIR,
            repo_type="dataset",
        )
        if isinstance(cached, str):
            return cached

        local_path = hf_hub_download(
            repo_id=SAMPLE_DATASET_REPO,
            filename=filename,
            repo_type="dataset",
            token=HF_TOKEN,
            cache_dir=SAMPLE_CACHE_DIR,
        )
        logger.info(f"Loaded sample: {filename}")
        return local_path
    except Exception as e:
        # Best-effort: any Hub/network/import failure degrades to a placeholder.
        logger.warning(f"Could not load sample from HF Hub: {e}")

    # Fallback: try to generate a synthetic sample
    return _generate_synthetic_sample(category, index)


def load_random_sample(
    modality: Optional[str] = None,
    pathology: Optional[bool] = None,
) -> Optional[Tuple[str, Dict[str, Any]]]:
    """
    Load a random sample image matching criteria.

    Args:
        modality: Filter by "CT" or "MRI" (None for any)
        pathology: Filter by pathology presence (None for any)

    Returns:
        Tuple of (file_path, metadata) or None when no category matches the
        filters or the chosen image could not be loaded.
    """
    import random

    candidates = [
        cat
        for cat, info in SAMPLE_CATEGORIES.items()
        if not (modality and info["modality"] != modality)
        and not (pathology is not None and info["pathology"] != pathology)
    ]
    if not candidates:
        return None

    category = random.choice(candidates)
    info = SAMPLE_CATEGORIES[category]
    index = random.randint(1, info["count"])  # inclusive on both ends

    path = load_sample_image(category, index)
    if not path:
        return None
    return path, {
        "category": category,
        "index": index,
        "modality": info["modality"],
        "pathology": info["pathology"],
        "description": info["description"],
    }


def load_category_batch(category: str) -> List[str]:
    """
    Load all images from a category (for research pipeline demo).

    Args:
        category: Category name

    Returns:
        List of file paths; images that fail to load are skipped. An unknown
        category yields an empty list.
    """
    info = SAMPLE_CATEGORIES.get(category)
    if info is None:
        return []
    paths = [load_sample_image(category, i) for i in range(1, info["count"] + 1)]
    return [p for p in paths if p]


def _generate_synthetic_sample(category: str, index: int) -> Optional[str]:
    """
    Generate a synthetic sample image as fallback when HF Dataset is unavailable.

    Creates a 256x256 grayscale image with an outer "skull" ellipse, an inner
    "brain" ellipse with noisy intensities, and, for pathology categories, a
    bright circular "tumor" blob inside the brain.

    The random generator is seeded from (category, index), so the same sample
    slot always produces the same image. Images are written under a dedicated
    ``synthetic/`` subdirectory of SAMPLE_CACHE_DIR and reused if present.

    Returns:
        Path to the PNG file, or None if generation/saving failed.
    """
    save_path = os.path.join(
        SAMPLE_CACHE_DIR, _SYNTHETIC_SUBDIR, category, _sample_basename(category, index)
    )
    if os.path.exists(save_path):
        return save_path

    try:
        import numpy as np
        from PIL import Image

        # Deterministic per-slot RNG (SeedSequence accepts a list of ints).
        rng = np.random.default_rng([index, *category.encode("utf-8")])

        # Create synthetic brain-like image
        size = 256
        img = np.zeros((size, size), dtype=np.uint8)

        y, x = np.ogrid[:size, :size]
        center = size // 2

        # Skull (outer ellipse)
        skull_mask = ((x - center) ** 2 / (110 ** 2) + (y - center) ** 2 / (120 ** 2)) <= 1
        img[skull_mask] = 30

        # Brain (inner ellipse) with textured intensities in [80, 100)
        brain_mask = ((x - center) ** 2 / (90 ** 2) + (y - center) ** 2 / (100 ** 2)) <= 1
        img[brain_mask] = (80 + rng.integers(0, 20, size=int(brain_mask.sum()))).astype(np.uint8)

        # Add pathology if category indicates it
        info = SAMPLE_CATEGORIES.get(category, {})
        if info.get("pathology", False):
            # Add a "tumor" blob, clipped to the brain region
            tumor_x = center + rng.integers(-30, 30)
            tumor_y = center + rng.integers(-30, 30)
            tumor_r = rng.integers(15, 35)
            tumor_mask = ((x - tumor_x) ** 2 + (y - tumor_y) ** 2) <= tumor_r ** 2
            region = tumor_mask & brain_mask
            img[region] = (160 + rng.integers(0, 40, size=int(region.sum()))).astype(np.uint8)

        # Save
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        Image.fromarray(img).save(save_path)
        logger.info(f"Generated synthetic sample: {save_path}")
        return save_path

    except Exception as e:
        logger.error(f"Failed to generate synthetic sample: {e}")
        return None


def get_dataset_info() -> Dict[str, Any]:
    """Get information about the sample dataset (repo, counts, categories, sources)."""
    total = sum(info["count"] for info in SAMPLE_CATEGORIES.values())
    return {
        "repo": SAMPLE_DATASET_REPO,
        "total_samples": total,
        "categories": list(SAMPLE_CATEGORIES.keys()),
        # Derived from the category table so it stays in sync if categories change.
        "modalities": sorted({info["modality"] for info in SAMPLE_CATEGORIES.values()}),
        "sources": [
            "Figshare: brain_tumor_dataset (Cheng, 2017) — 3064 T1-weighted MRIs",
            "Kaggle: brain-mri-scans-for-brain-tumor-classification",
        ],
        "note": "Images loaded on-demand from HF Hub. Synthetic fallback if unavailable.",
    }