# NOTE: The three lines originally here ("Spaces:", "Runtime error",
# "Runtime error") were HuggingFace Space status-page residue captured
# along with the source, not part of the program. Kept as a comment so
# the file remains valid Python. The error itself was a NameError on
# `Dict` (see the MODELS cache declaration below).
| #!/usr/bin/env python3 | |
| """ | |
| OncoSeg Inference API - HuggingFace Space | |
| Optimized for programmatic access from oncoseg-viewer | |
| This Space provides GPU-accelerated inference for medical image segmentation. | |
| It exposes both a Gradio UI and programmatic API endpoints. | |
| Usage from viewer: | |
| POST /api/segment_slice | |
| POST /api/segment_volume | |
| """ | |
import os
import io
import base64
import tempfile
import time
import logging
from pathlib import Path
# BUGFIX: `Dict` is used below for the MODELS annotation but was not
# imported, which raised NameError at import time and crashed the Space.
from typing import Optional, List, Tuple, Any, Dict

import gradio as gr
import numpy as np
import torch
import cv2

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Check for ZeroGPU (HF Spaces)
try:
    import spaces

    ZEROGPU_AVAILABLE = True
    logger.info("ZeroGPU available")
except ImportError:
    ZEROGPU_AVAILABLE = False
    logger.info("ZeroGPU not available, using standard GPU/CPU")

# Device setup
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {DEVICE}")

# Global model cache, keyed by checkpoint name; None means "load failed,
# use fallback segmentation".
MODELS: Dict[str, Any] = {}

# Checkpoint mapping (HuggingFace Hub paths)
CHECKPOINTS = {
    "brain": "checkpoints/medsam3-task20_brats_gli-final_latest/last.ckpt",
    "liver": "checkpoints/medsam3-task03_liver-final_latest/last.ckpt",
    "breast": "checkpoints/medsam3-task25_breastdcedl-final_latest/last.ckpt",
    "lung": "checkpoints/medsam3-task06_lung-final_latest/last.ckpt",
    "kidney": "checkpoints/medsam3-task17_kits23-final_latest/last.ckpt",
    "spine": "checkpoints/medsam3-task11_lctsc-final_latest/last.ckpt",
}

# HF Repo ID for checkpoints
HF_REPO_ID = os.getenv("HF_REPO_ID", "tp53/oncoseg")

# Flag to track if we're using fallback mode
USE_FALLBACK = False
def get_model(checkpoint: str = "brain"):
    """Return the cached model for *checkpoint*, loading it on first use.

    On any load failure (SAM3 code missing, checkpoint download error),
    caches None and flips USE_FALLBACK so callers use the simple
    intensity-based segmentation instead.
    """
    global MODELS, USE_FALLBACK

    # Cache hit: nothing to do.
    if checkpoint in MODELS:
        return MODELS.get(checkpoint)

    logger.info(f"Loading model: {checkpoint}")
    try:
        from huggingface_hub import hf_hub_download

        # Unknown checkpoint names fall back to the brain checkpoint.
        ckpt_file = CHECKPOINTS.get(checkpoint, CHECKPOINTS["brain"])
        ckpt_path = hf_hub_download(
            repo_id=HF_REPO_ID,
            filename=ckpt_file,
        )
        logger.info(f"Downloaded checkpoint to: {ckpt_path}")

        # Import model (from local model/ directory in this Space).
        # MedSAM3Model builds SAM3 internally and loads our LoRA weights.
        from model.medsam3 import MedSAM3Model

        loaded = MedSAM3Model(checkpoint_path=ckpt_path)
        loaded.to(DEVICE)
        loaded.eval()
        MODELS[checkpoint] = loaded
        logger.info(f"Model {checkpoint} loaded on {DEVICE}")
    except ImportError as e:
        logger.warning(f"SAM3 not available, using fallback segmentation: {e}")
        USE_FALLBACK = True
        MODELS[checkpoint] = None
    except Exception as e:
        logger.error(f"Failed to load model {checkpoint}: {e}")
        USE_FALLBACK = True
        MODELS[checkpoint] = None

    return MODELS.get(checkpoint)
| def fallback_segment(slice_2d: np.ndarray): | |
| """ | |
| Simple intensity-based segmentation fallback when SAM3 is not available. | |
| Works well for FLAIR MRI where tumors appear hyperintense. | |
| """ | |
| from skimage.filters import threshold_otsu | |
| from skimage.morphology import binary_opening, binary_closing, disk | |
| # Normalize | |
| vmin, vmax = slice_2d.min(), slice_2d.max() | |
| if vmax - vmin < 1e-8: | |
| return np.zeros_like(slice_2d, dtype=np.uint8) | |
| normalized = (slice_2d - vmin) / (vmax - vmin) | |
| # Use percentile threshold (top 15% intensity = potential tumor) | |
| threshold = np.percentile(normalized, 85) | |
| mask = (normalized > threshold).astype(np.uint8) | |
| # Morphological cleanup | |
| try: | |
| mask = binary_opening(mask, disk(2)) | |
| mask = binary_closing(mask, disk(3)) | |
| except: | |
| pass | |
| return mask.astype(np.uint8) | |
def preprocess_slice(slice_2d: np.ndarray, target_size: int = 1024) -> torch.Tensor:
    """
    Preprocess a 2D slice for SAM3 input.

    Args:
        slice_2d: Input slice (H, W)
        target_size: Target square size for SAM3 (default 1024)

    Returns:
        Preprocessed tensor (1, 3, target_size, target_size) on DEVICE,
        values scaled to [-1, 1].
    """
    # Normalize to [0, 1]; guard against constant slices (division by ~0).
    vmin, vmax = slice_2d.min(), slice_2d.max()
    if vmax - vmin < 1e-8:
        slice_norm = np.zeros_like(slice_2d)
    else:
        slice_norm = (slice_2d - vmin) / (vmax - vmin)

    # Resize to target size. (cv2 is imported at module level; the
    # redundant function-local `import cv2` was removed.)
    slice_resized = cv2.resize(
        slice_norm.astype(np.float32), (target_size, target_size)
    )

    # Scale to [-1, 1] for SAM3
    slice_scaled = slice_resized * 2 - 1

    # Grayscale -> 3-channel tensor (B, C, H, W)
    slice_tensor = torch.from_numpy(slice_scaled).float()
    slice_tensor = slice_tensor.unsqueeze(0).unsqueeze(0)  # (1, 1, H, W)
    slice_tensor = slice_tensor.repeat(1, 3, 1, 1)  # (1, 3, H, W)
    return slice_tensor.to(DEVICE)
def find_contours(mask: np.ndarray) -> List[List[List[float]]]:
    """Extract iso-contours (level 0.5) from a binary mask.

    Returns an empty list when scikit-image is unavailable.
    """
    try:
        from skimage.measure import find_contours as sk_find_contours
    except ImportError:
        return []
    return [contour.tolist() for contour in sk_find_contours(mask, 0.5)]
def keep_largest_component(mask: np.ndarray) -> np.ndarray:
    """Return *mask* with only its largest connected component retained.

    Returns the mask unchanged when scipy is unavailable or when there
    is at most one component.
    """
    try:
        from scipy import ndimage
    except ImportError:
        return mask

    labeled, num_features = ndimage.label(mask)
    if num_features <= 1:
        # Zero or one component: nothing to prune.
        return mask
    component_sizes = ndimage.sum(mask, labeled, range(1, num_features + 1))
    biggest_label = int(np.argmax(component_sizes)) + 1
    return (labeled == biggest_label).astype(np.uint8)
| # Define the inference function with optional ZeroGPU decorator | |
def _segment_slice_impl(
    nifti_b64: str,
    slice_idx: int,
    text_prompt: str = "tumor",
    checkpoint: str = "brain",
):
    """
    Segment a single slice from a NIfTI volume.

    Args:
        nifti_b64: Base64-encoded NIfTI file bytes
        slice_idx: Slice index to segment (0-indexed, along axis 0)
        text_prompt: Text prompt for segmentation (e.g., "tumor", "lesion")
        checkpoint: Model checkpoint name

    Returns:
        dict with keys: success, backend, mask_b64, mask_shape, contours,
        slice_idx, inference_time_ms; or {"success": False, "error": ...}.
    """
    start_time = time.time()
    try:
        import nibabel as nib

        # Decode NIfTI to a temp file (nibabel needs a real path).
        nifti_bytes = base64.b64decode(nifti_b64)
        with tempfile.NamedTemporaryFile(suffix=".nii.gz", delete=False) as f:
            f.write(nifti_bytes)
            temp_path = f.name
        # BUGFIX: previously the temp file leaked whenever nib.load() or
        # get_fdata() raised; always clean it up.
        try:
            nii = nib.load(temp_path)
            volume = nii.get_fdata().astype(np.float32)
        finally:
            os.unlink(temp_path)

        logger.info(
            f"Loaded volume shape: {volume.shape}, segmenting slice {slice_idx}"
        )

        # Validate slice index
        if slice_idx < 0 or slice_idx >= volume.shape[0]:
            return {
                "success": False,
                "error": f"Slice index {slice_idx} out of range [0, {volume.shape[0]})",
            }

        # Extract slice
        slice_2d = volume[slice_idx]
        original_shape = slice_2d.shape

        # Load model (may return None if fallback mode)
        model = get_model(checkpoint)

        if model is None or USE_FALLBACK:
            # Use fallback segmentation
            logger.info("Using fallback segmentation (SAM3 not available)")
            mask = fallback_segment(slice_2d)
            backend = "fallback"
        else:
            # Use SAM3 model
            slice_tensor = preprocess_slice(
                slice_2d
            )  # (1, 3, 1024, 1024) tensor on DEVICE

            # Create full-image bounding box prompt (auto-segment entire image)
            # Format: [x_min, y_min, x_max, y_max] in pixel coordinates
            target_size = slice_tensor.shape[-1]  # 1024
            input_boxes = torch.tensor(
                [[0, 0, target_size, target_size]], dtype=torch.float32, device=DEVICE
            )

            # Run inference with text prompt for grounding
            with torch.no_grad():
                outputs = model(
                    pixel_values=slice_tensor,
                    input_boxes=input_boxes,
                    text_prompt=text_prompt,
                )

            # Extract mask from SAM3 output
            # SAM3 returns a dict with 'pred_masks' key, shape (B, 1, H, W)
            if isinstance(outputs, dict) and "pred_masks" in outputs:
                pred_mask = outputs["pred_masks"][0, 0].cpu().numpy()
            elif hasattr(outputs, "pred_masks"):
                pred_mask = outputs.pred_masks[0, 0].cpu().numpy()
            else:
                # Fallback: unknown output container, emit an empty mask.
                logger.warning(f"Unexpected output type: {type(outputs)}")
                pred_mask = np.zeros((target_size, target_size))

            # Resize mask back to original shape (cv2 size order is (W, H))
            mask = cv2.resize(pred_mask, (original_shape[1], original_shape[0]))
            backend = "sam3"

        # Threshold to binary
        mask = (mask > 0.5).astype(np.uint8)
        mask = keep_largest_component(mask)

        # Extract contours
        contours = find_contours(mask)

        # Encode mask as base64 (raw row-major uint8 bytes; see mask_shape)
        mask_b64 = base64.b64encode(mask.tobytes()).decode()

        inference_time = int((time.time() - start_time) * 1000)
        logger.info(
            f"Segmented slice {slice_idx} in {inference_time}ms, mask sum: {mask.sum()}"
        )
        return {
            "success": True,
            "backend": backend,
            "mask_b64": mask_b64,
            "mask_shape": list(mask.shape),
            "contours": contours,
            "slice_idx": slice_idx,
            "inference_time_ms": inference_time,
        }
    except Exception as e:
        logger.error(f"Segmentation failed: {e}")
        return {"success": False, "error": str(e)}
def _segment_volume_impl(
    nifti_b64: str,
    text_prompt: str = "tumor",
    checkpoint: str = "brain",
    skip_empty: bool = True,
    min_area: int = 50,
):
    """
    Segment entire volume and return contours for all slices with detections.

    Args:
        nifti_b64: Base64-encoded NIfTI file bytes
        text_prompt: Text prompt for segmentation
        checkpoint: Model checkpoint name
        skip_empty: Skip mostly-empty slices
        min_area: Minimum mask area (pixels) to report

    Returns:
        dict with keys: success, contours (dict keyed by slice index as str),
        num_slices, slices_with_tumor, inference_time_ms
    """
    start_time = time.time()
    try:
        import nibabel as nib

        # Decode NIfTI to a temp file (nibabel needs a real path).
        nifti_bytes = base64.b64decode(nifti_b64)
        with tempfile.NamedTemporaryFile(suffix=".nii.gz", delete=False) as f:
            f.write(nifti_bytes)
            temp_path = f.name
        # BUGFIX: always remove the temp file, even when loading raises.
        try:
            nii = nib.load(temp_path)
            volume = nii.get_fdata().astype(np.float32)
        finally:
            os.unlink(temp_path)

        logger.info(f"Loaded volume shape: {volume.shape}")

        # Load model (may return None if fallback mode)
        model = get_model(checkpoint)
        use_fallback = model is None or USE_FALLBACK

        num_slices = volume.shape[0]
        all_contours = {}
        target_size = 1024

        for i in range(num_slices):
            slice_2d = volume[i]
            original_shape = slice_2d.shape

            # Skip mostly-empty slices
            if skip_empty and slice_2d.max() - slice_2d.min() < 0.01:
                continue

            if use_fallback:
                # Use fallback segmentation (already binary, original shape)
                mask = fallback_segment(slice_2d)
            else:
                slice_tensor = preprocess_slice(slice_2d, target_size)

                # Create full-image bounding box
                input_boxes = torch.tensor(
                    [[0, 0, target_size, target_size]],
                    dtype=torch.float32,
                    device=DEVICE,
                )
                with torch.no_grad():
                    outputs = model(
                        pixel_values=slice_tensor,
                        input_boxes=input_boxes,
                        text_prompt=text_prompt,
                    )

                # Extract mask from SAM3 output
                if isinstance(outputs, dict) and "pred_masks" in outputs:
                    pred_mask = outputs["pred_masks"][0, 0].cpu().numpy()
                elif hasattr(outputs, "pred_masks"):
                    pred_mask = outputs.pred_masks[0, 0].cpu().numpy()
                else:
                    continue  # Skip if no valid output

                # Resize to original shape and threshold (cv2 size is (W, H))
                mask = cv2.resize(pred_mask, (original_shape[1], original_shape[0]))
                mask = (mask > 0.5).astype(np.uint8)

            # Only report detections above the minimum area.
            if mask.sum() >= min_area:
                mask = keep_largest_component(mask)
                contours = find_contours(mask)
                if contours:
                    all_contours[str(i)] = contours

        inference_time = int((time.time() - start_time) * 1000)
        logger.info(
            f"Segmented {num_slices} slices in {inference_time}ms, found tumor in {len(all_contours)} slices"
        )
        return {
            "success": True,
            "contours": all_contours,
            "num_slices": num_slices,
            "slices_with_tumor": list(all_contours.keys()),
            "inference_time_ms": inference_time,
        }
    except Exception as e:
        logger.error(f"Volume segmentation failed: {e}")
        return {"success": False, "error": str(e)}
# Apply ZeroGPU decorator if available.
# BUGFIX: the wrappers previously never applied @spaces.GPU, so on a
# ZeroGPU Space inference ran without a GPU allocation (the if/else was
# a no-op). The decorator is what requests a GPU for the call.
if ZEROGPU_AVAILABLE:

    @spaces.GPU
    def segment_slice_api(
        nifti_b64: str,
        slice_idx: int,
        text_prompt: str = "tumor",
        checkpoint: str = "brain",
    ):
        """GPU-allocated wrapper around _segment_slice_impl."""
        return _segment_slice_impl(nifti_b64, slice_idx, text_prompt, checkpoint)

    @spaces.GPU
    def segment_volume_api(
        nifti_b64: str,
        text_prompt: str = "tumor",
        checkpoint: str = "brain",
        skip_empty: bool = True,
        min_area: int = 50,
    ):
        """GPU-allocated wrapper around _segment_volume_impl."""
        return _segment_volume_impl(
            nifti_b64, text_prompt, checkpoint, skip_empty, min_area
        )
else:
    # No ZeroGPU: expose the implementations directly.
    segment_slice_api = _segment_slice_impl
    segment_volume_api = _segment_volume_impl
| # Gradio UI functions (for interactive demo) | |
def load_and_display_nifti(file):
    """Load an uploaded NIfTI file and return its middle slice for display.

    Returns (rgb_image, status_message, num_slices); on failure returns
    (None, error_message, 0).
    """
    if file is None:
        return None, "No file uploaded", 0
    try:
        import nibabel as nib

        volume = nib.load(file.name).get_fdata()
        middle_slice = volume.shape[0] // 2
        slice_2d = volume[middle_slice]

        # Min-max normalize to 8-bit for display.
        vmin, vmax = slice_2d.min(), slice_2d.max()
        if vmax - vmin > 0:
            display = ((slice_2d - vmin) / (vmax - vmin) * 255).astype(np.uint8)
        else:
            display = np.zeros_like(slice_2d, dtype=np.uint8)

        # Grayscale -> RGB for the image widget.
        display_rgb = np.stack([display, display, display], axis=-1)
        return (
            display_rgb,
            f"Loaded: {volume.shape}, showing slice {middle_slice}",
            volume.shape[0],
        )
    except Exception as e:
        return None, f"Error: {e}", 0
def segment_and_overlay(file, slice_idx: int, text_prompt: str, checkpoint: str):
    """Segment one slice via the API and return it with a red mask overlay.

    Returns (rgb_image, status_message); (None, error_message) on failure.
    """
    if file is None:
        return None, "Please upload a file first"
    try:
        # Ship the raw NIfTI bytes to the segmentation API as base64.
        with open(file.name, "rb") as f:
            nifti_b64 = base64.b64encode(f.read()).decode()

        result = segment_slice_api(nifti_b64, int(slice_idx), text_prompt, checkpoint)
        if not result["success"]:
            return None, f"Segmentation failed: {result.get('error', 'Unknown error')}"

        # Reload the requested slice locally for display.
        import nibabel as nib

        slice_2d = nib.load(file.name).get_fdata()[int(slice_idx)]

        # Min-max normalize to 8-bit.
        vmin, vmax = slice_2d.min(), slice_2d.max()
        if vmax - vmin > 0:
            display = ((slice_2d - vmin) / (vmax - vmin) * 255).astype(np.uint8)
        else:
            display = np.zeros_like(slice_2d, dtype=np.uint8)

        # Decode the raw uint8 mask returned by the API.
        mask_bytes = base64.b64decode(result["mask_b64"])
        mask = np.frombuffer(mask_bytes, dtype=np.uint8).reshape(result["mask_shape"])

        # Alpha-blend a red tint over masked pixels.
        rgb = np.stack([display, display, display], axis=-1).astype(np.float32)
        mask_bool = mask > 0
        alpha = 0.4
        for channel, tint in enumerate((255, 50, 50)):
            rgb[mask_bool, channel] = (
                rgb[mask_bool, channel] * (1 - alpha) + tint * alpha
            )

        info = f"Segmented in {result['inference_time_ms']}ms, mask area: {mask.sum()} pixels"
        return rgb.astype(np.uint8), info
    except Exception as e:
        return None, f"Error: {e}"
| # Build Gradio interface | |
def build_demo():
    """Assemble the Gradio Blocks app: interactive demo plus API docs."""
    with gr.Blocks(
        title="OncoSeg Inference API",
        theme=gr.themes.Soft(),
    ) as demo:
        # Header / API overview shown above the demo.
        gr.Markdown("""
        # OncoSeg Medical Image Segmentation API
        GPU-accelerated segmentation for CT and MRI volumes.
        **API Endpoints** (for programmatic access):
        - `POST /api/segment_slice_api` - Segment a single slice
        - `POST /api/segment_volume_api` - Segment entire volume
        **Interactive Demo** below:
        """)
        with gr.Row():
            with gr.Column(scale=1):
                # Input controls.
                file_input = gr.File(
                    label="Upload NIfTI (.nii, .nii.gz)", file_types=[".nii", ".nii.gz"]
                )
                checkpoint = gr.Dropdown(
                    choices=list(CHECKPOINTS.keys()),
                    value="brain",
                    label="Model Checkpoint",
                )
                text_prompt = gr.Textbox(
                    value="tumor",
                    label="Text Prompt",
                    placeholder="e.g., tumor, lesion, mass",
                )
                slice_idx = gr.Slider(
                    minimum=0,
                    maximum=200,
                    value=77,
                    step=1,
                    label="Slice Index",
                )
                segment_btn = gr.Button("Segment Slice", variant="primary")
            with gr.Column(scale=2):
                # Output widgets.
                output_image = gr.Image(label="Segmentation Result", type="numpy")
                status_text = gr.Textbox(label="Status", interactive=False)
        # Event handlers
        # NOTE(review): load_and_display_nifti's third return value is the
        # slice COUNT, and wiring it to `slice_idx` sets the slider's VALUE,
        # not its maximum -- confirm this is the intended behavior.
        file_input.change(
            fn=load_and_display_nifti,
            inputs=[file_input],
            outputs=[output_image, status_text, slice_idx],
        )
        segment_btn.click(
            fn=segment_and_overlay,
            inputs=[file_input, slice_idx, text_prompt, checkpoint],
            outputs=[output_image, status_text],
        )
        # Footer: copy-paste API usage example for programmatic clients.
        gr.Markdown("""
        ---
        ### API Usage Example
        ```python
        import requests
        import base64
        # Read NIfTI file
        with open("brain.nii.gz", "rb") as f:
            nifti_b64 = base64.b64encode(f.read()).decode()
        # Call API
        response = requests.post(
            "https://YOUR-SPACE.hf.space/api/segment_slice_api",
            json={
                "nifti_b64": nifti_b64,
                "slice_idx": 77,
                "text_prompt": "tumor",
                "checkpoint": "brain",
            }
        )
        result = response.json()
        # result["contours"] contains the segmentation contours
        ```
        """)
    return demo
# Launch
if __name__ == "__main__":
    demo = build_demo()
    # Enable the request queue (required for concurrent API calls).
    demo.queue()
    # Bind all interfaces on the standard HF Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860)