# oncoseg-api / app.py
# (Hugging Face file-viewer residue preserved as a comment so the module
# parses: "tp53's picture / Upload folder using huggingface_hub /
# 4eaaaee verified / raw / history blame / 20.2 kB")
#!/usr/bin/env python3
"""
OncoSeg Inference API - HuggingFace Space
Optimized for programmatic access from oncoseg-viewer
This Space provides GPU-accelerated inference for medical image segmentation.
It exposes both a Gradio UI and programmatic API endpoints.
Usage from viewer:
POST /api/segment_slice
POST /api/segment_volume
"""
import os
import io
import base64
import tempfile
import time
import logging
from pathlib import Path
from typing import Optional, List, Tuple, Any
import gradio as gr
import numpy as np
import torch
import cv2
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Check for ZeroGPU (HF Spaces): import success of `spaces` is used as the
# capability probe — presumably it is only importable inside a Space with
# ZeroGPU hardware (TODO confirm).
try:
    import spaces
    ZEROGPU_AVAILABLE = True
    logger.info("ZeroGPU available")
except ImportError:
    ZEROGPU_AVAILABLE = False
    logger.info("ZeroGPU not available, using standard GPU/CPU")
# Device setup: prefer CUDA whenever torch can see a GPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {DEVICE}")
# Global model cache keyed by checkpoint name; populated lazily by get_model().
# NOTE: this was annotated `Dict[str, Any]` but `typing.Dict` was never
# imported, which raises NameError at import time. The builtin generic
# (PEP 585, Python 3.9+) needs no import.
MODELS: dict[str, Any] = {}
# Checkpoint mapping (HuggingFace Hub paths, relative to HF_REPO_ID)
CHECKPOINTS = {
    "brain": "checkpoints/medsam3-task20_brats_gli-final_latest/last.ckpt",
    "liver": "checkpoints/medsam3-task03_liver-final_latest/last.ckpt",
    "breast": "checkpoints/medsam3-task25_breastdcedl-final_latest/last.ckpt",
    "lung": "checkpoints/medsam3-task06_lung-final_latest/last.ckpt",
    "kidney": "checkpoints/medsam3-task17_kits23-final_latest/last.ckpt",
    "spine": "checkpoints/medsam3-task11_lctsc-final_latest/last.ckpt",
}
# HF Repo ID for checkpoints (overridable via the HF_REPO_ID env var)
HF_REPO_ID = os.getenv("HF_REPO_ID", "tp53/oncoseg")
# Flag to track if we're using fallback mode (set once any model load fails)
USE_FALLBACK = False
def get_model(checkpoint: str = "brain"):
    """Return the cached model for *checkpoint*, loading it on first use.

    Downloads the checkpoint file from the HF Hub and builds a MedSAM3Model.
    On any failure (SAM3 package missing, download error, ...) the cache
    entry is set to None and the global fallback flag is raised, so callers
    switch to intensity-based segmentation.
    """
    global MODELS, USE_FALLBACK
    if checkpoint in MODELS:
        return MODELS.get(checkpoint)
    logger.info(f"Loading model: {checkpoint}")
    try:
        from huggingface_hub import hf_hub_download

        # Unknown checkpoint names silently fall back to the brain model.
        ckpt_path = hf_hub_download(
            repo_id=HF_REPO_ID,
            filename=CHECKPOINTS.get(checkpoint, CHECKPOINTS["brain"]),
        )
        logger.info(f"Downloaded checkpoint to: {ckpt_path}")
        # Local model/ package shipped with this Space; MedSAM3Model builds
        # SAM3 internally and loads our LoRA weights from the checkpoint.
        from model.medsam3 import MedSAM3Model

        loaded = MedSAM3Model(checkpoint_path=ckpt_path)
        loaded.to(DEVICE)
        loaded.eval()
        MODELS[checkpoint] = loaded
        logger.info(f"Model {checkpoint} loaded on {DEVICE}")
    except ImportError as e:
        logger.warning(f"SAM3 not available, using fallback segmentation: {e}")
        USE_FALLBACK = True
        MODELS[checkpoint] = None
    except Exception as e:
        logger.error(f"Failed to load model {checkpoint}: {e}")
        USE_FALLBACK = True
        MODELS[checkpoint] = None
    return MODELS.get(checkpoint)
def fallback_segment(slice_2d: np.ndarray):
"""
Simple intensity-based segmentation fallback when SAM3 is not available.
Works well for FLAIR MRI where tumors appear hyperintense.
"""
from skimage.filters import threshold_otsu
from skimage.morphology import binary_opening, binary_closing, disk
# Normalize
vmin, vmax = slice_2d.min(), slice_2d.max()
if vmax - vmin < 1e-8:
return np.zeros_like(slice_2d, dtype=np.uint8)
normalized = (slice_2d - vmin) / (vmax - vmin)
# Use percentile threshold (top 15% intensity = potential tumor)
threshold = np.percentile(normalized, 85)
mask = (normalized > threshold).astype(np.uint8)
# Morphological cleanup
try:
mask = binary_opening(mask, disk(2))
mask = binary_closing(mask, disk(3))
except:
pass
return mask.astype(np.uint8)
def preprocess_slice(slice_2d: np.ndarray, target_size: int = 1024) -> torch.Tensor:
    """
    Preprocess a 2D slice for SAM3 input.

    Args:
        slice_2d: Input slice (H, W)
        target_size: Target square size for SAM3 (default 1024)

    Returns:
        Preprocessed tensor (1, 3, target_size, target_size) on DEVICE,
        with values scaled to [-1, 1].
    """
    # NOTE: the redundant function-local `import cv2` was removed; cv2 is
    # already imported at module level.
    # Min-max normalize to [0, 1]; constant slices map to all zeros.
    vmin, vmax = slice_2d.min(), slice_2d.max()
    if vmax - vmin < 1e-8:
        slice_norm = np.zeros_like(slice_2d)
    else:
        slice_norm = (slice_2d - vmin) / (vmax - vmin)
    # Resize to the square input size SAM3 expects.
    slice_resized = cv2.resize(
        slice_norm.astype(np.float32), (target_size, target_size)
    )
    # Scale [0, 1] -> [-1, 1] for SAM3.
    slice_scaled = slice_resized * 2 - 1
    # Grayscale -> 3-channel batched tensor (1, 3, H, W).
    slice_tensor = torch.from_numpy(slice_scaled).float()
    slice_tensor = slice_tensor.unsqueeze(0).unsqueeze(0)  # (1, 1, H, W)
    slice_tensor = slice_tensor.repeat(1, 3, 1, 1)  # (1, 3, H, W)
    return slice_tensor.to(DEVICE)
def find_contours(mask: np.ndarray) -> List[List[List[float]]]:
    """Extract iso-level 0.5 contours from a binary mask as nested lists.

    Returns an empty list when scikit-image is unavailable.
    """
    try:
        from skimage.measure import find_contours as sk_find_contours
    except ImportError:
        return []
    return [contour.tolist() for contour in sk_find_contours(mask, 0.5)]
def keep_largest_component(mask: np.ndarray) -> np.ndarray:
    """Reduce a binary mask to its single largest connected component.

    Masks with at most one component — and environments without scipy —
    are returned unchanged.
    """
    try:
        from scipy import ndimage
    except ImportError:
        return mask
    labeled, num_features = ndimage.label(mask)
    if num_features <= 1:
        return mask
    component_sizes = ndimage.sum(mask, labeled, range(1, num_features + 1))
    biggest = int(np.argmax(component_sizes)) + 1
    return (labeled == biggest).astype(np.uint8)
# Define the inference function with optional ZeroGPU decorator
def _segment_slice_impl(
    nifti_b64: str,
    slice_idx: int,
    text_prompt: str = "tumor",
    checkpoint: str = "brain",
):
    """
    Segment a single slice from a NIfTI volume.

    Args:
        nifti_b64: Base64-encoded NIfTI file bytes
        slice_idx: Slice index to segment (0-indexed, along axis 0)
        text_prompt: Text prompt for segmentation (e.g., "tumor", "lesion")
        checkpoint: Model checkpoint name (key of CHECKPOINTS)

    Returns:
        dict with keys: success, backend, mask_b64, mask_shape, contours,
        slice_idx, inference_time_ms — or {"success": False, "error": ...}
        on any failure (errors are reported, never raised).
    """
    start_time = time.time()
    try:
        import nibabel as nib

        # Decode the upload to a temp file so nibabel can read it. The
        # try/finally guarantees cleanup even when loading fails (previously
        # a failed nib.load leaked the temp file).
        nifti_bytes = base64.b64decode(nifti_b64)
        with tempfile.NamedTemporaryFile(suffix=".nii.gz", delete=False) as f:
            f.write(nifti_bytes)
            temp_path = f.name
        try:
            nii = nib.load(temp_path)
            volume = nii.get_fdata().astype(np.float32)
        finally:
            os.unlink(temp_path)
        logger.info(
            f"Loaded volume shape: {volume.shape}, segmenting slice {slice_idx}"
        )
        # Validate slice index against axis 0
        if slice_idx < 0 or slice_idx >= volume.shape[0]:
            return {
                "success": False,
                "error": f"Slice index {slice_idx} out of range [0, {volume.shape[0]})",
            }
        # Extract slice
        slice_2d = volume[slice_idx]
        original_shape = slice_2d.shape
        # Load model (may return None if fallback mode)
        model = get_model(checkpoint)
        if model is None or USE_FALLBACK:
            # Use fallback segmentation
            logger.info("Using fallback segmentation (SAM3 not available)")
            mask = fallback_segment(slice_2d)
            backend = "fallback"
        else:
            # Use SAM3 model
            slice_tensor = preprocess_slice(
                slice_2d
            )  # (1, 3, 1024, 1024) tensor on DEVICE
            # Create full-image bounding box prompt (auto-segment entire image)
            # Format: [x_min, y_min, x_max, y_max] in pixel coordinates
            target_size = slice_tensor.shape[-1]  # 1024
            input_boxes = torch.tensor(
                [[0, 0, target_size, target_size]], dtype=torch.float32, device=DEVICE
            )
            # Run inference with text prompt for grounding
            with torch.no_grad():
                outputs = model(
                    pixel_values=slice_tensor,
                    input_boxes=input_boxes,
                    text_prompt=text_prompt,
                )
            # Extract mask from SAM3 output.
            # NOTE(review): assumes 'pred_masks' has shape (B, 1, H, W) —
            # confirm against the MedSAM3Model implementation.
            if isinstance(outputs, dict) and "pred_masks" in outputs:
                pred_mask = outputs["pred_masks"][0, 0].cpu().numpy()
            elif hasattr(outputs, "pred_masks"):
                pred_mask = outputs.pred_masks[0, 0].cpu().numpy()
            else:
                # Fallback: try to extract from tuple/list
                logger.warning(f"Unexpected output type: {type(outputs)}")
                pred_mask = np.zeros((target_size, target_size))
            # Resize mask back to original shape (cv2 takes (width, height))
            mask = cv2.resize(pred_mask, (original_shape[1], original_shape[0]))
            backend = "sam3"
        # Threshold to binary and drop speckle, keeping the largest blob
        mask = (mask > 0.5).astype(np.uint8)
        mask = keep_largest_component(mask)
        # Extract contours
        contours = find_contours(mask)
        # Encode mask as base64 (raw row-major uint8 bytes; callers reshape
        # using mask_shape)
        mask_b64 = base64.b64encode(mask.tobytes()).decode()
        inference_time = int((time.time() - start_time) * 1000)
        logger.info(
            f"Segmented slice {slice_idx} in {inference_time}ms, mask sum: {mask.sum()}"
        )
        return {
            "success": True,
            "backend": backend,
            "mask_b64": mask_b64,
            "mask_shape": list(mask.shape),
            "contours": contours,
            "slice_idx": slice_idx,
            "inference_time_ms": inference_time,
        }
    except Exception as e:
        logger.error(f"Segmentation failed: {e}")
        return {"success": False, "error": str(e)}
def _segment_volume_impl(
    nifti_b64: str,
    text_prompt: str = "tumor",
    checkpoint: str = "brain",
    skip_empty: bool = True,
    min_area: int = 50,
):
    """
    Segment entire volume and return contours for all slices with detections.

    Args:
        nifti_b64: Base64-encoded NIfTI file bytes
        text_prompt: Text prompt for segmentation
        checkpoint: Model checkpoint name (key of CHECKPOINTS)
        skip_empty: Skip slices whose intensity range is (near) zero
        min_area: Minimum mask area (pixels) to report a detection

    Returns:
        dict with keys: success, contours (slice-index string -> contours),
        num_slices, slices_with_tumor, inference_time_ms — or
        {"success": False, "error": ...} on any failure.
    """
    start_time = time.time()
    try:
        import nibabel as nib

        # Decode the upload to a temp file so nibabel can read it. The
        # try/finally guarantees cleanup even when loading fails (previously
        # a failed nib.load leaked the temp file).
        nifti_bytes = base64.b64decode(nifti_b64)
        with tempfile.NamedTemporaryFile(suffix=".nii.gz", delete=False) as f:
            f.write(nifti_bytes)
            temp_path = f.name
        try:
            nii = nib.load(temp_path)
            volume = nii.get_fdata().astype(np.float32)
        finally:
            os.unlink(temp_path)
        logger.info(f"Loaded volume shape: {volume.shape}")
        # Load model (may return None if fallback mode)
        model = get_model(checkpoint)
        use_fallback = model is None or USE_FALLBACK
        num_slices = volume.shape[0]
        all_contours = {}
        target_size = 1024
        for i in range(num_slices):
            slice_2d = volume[i]
            original_shape = slice_2d.shape
            # Skip mostly-empty slices
            if skip_empty and slice_2d.max() - slice_2d.min() < 0.01:
                continue
            if use_fallback:
                # Use fallback segmentation (already a binary uint8 mask)
                mask = fallback_segment(slice_2d)
            else:
                slice_tensor = preprocess_slice(slice_2d, target_size)
                # Create full-image bounding box
                input_boxes = torch.tensor(
                    [[0, 0, target_size, target_size]],
                    dtype=torch.float32,
                    device=DEVICE,
                )
                with torch.no_grad():
                    outputs = model(
                        pixel_values=slice_tensor,
                        input_boxes=input_boxes,
                        text_prompt=text_prompt,
                    )
                # Extract mask from SAM3 output
                if isinstance(outputs, dict) and "pred_masks" in outputs:
                    pred_mask = outputs["pred_masks"][0, 0].cpu().numpy()
                elif hasattr(outputs, "pred_masks"):
                    pred_mask = outputs.pred_masks[0, 0].cpu().numpy()
                else:
                    continue  # Skip if no valid output
                # Resize to original shape and threshold (cv2 takes (w, h))
                mask = cv2.resize(pred_mask, (original_shape[1], original_shape[0]))
                mask = (mask > 0.5).astype(np.uint8)
            # Only report detections above the minimum area
            if mask.sum() >= min_area:
                mask = keep_largest_component(mask)
                contours = find_contours(mask)
                if contours:
                    all_contours[str(i)] = contours
        inference_time = int((time.time() - start_time) * 1000)
        logger.info(
            f"Segmented {num_slices} slices in {inference_time}ms, found tumor in {len(all_contours)} slices"
        )
        return {
            "success": True,
            "contours": all_contours,
            "num_slices": num_slices,
            "slices_with_tumor": list(all_contours.keys()),
            "inference_time_ms": inference_time,
        }
    except Exception as e:
        logger.error(f"Volume segmentation failed: {e}")
        return {"success": False, "error": str(e)}
# Apply ZeroGPU decorator if available: on ZeroGPU Spaces the wrappers
# request a GPU slot per call; otherwise the implementations are exposed
# directly under the same names.
if ZEROGPU_AVAILABLE:
    @spaces.GPU(duration=60)
    def segment_slice_api(
        nifti_b64: str,
        slice_idx: int,
        text_prompt: str = "tumor",
        checkpoint: str = "brain",
    ):
        """GPU-slot wrapper for _segment_slice_impl (60 s budget per slice)."""
        return _segment_slice_impl(nifti_b64, slice_idx, text_prompt, checkpoint)

    @spaces.GPU(duration=300)
    def segment_volume_api(
        nifti_b64: str,
        text_prompt: str = "tumor",
        checkpoint: str = "brain",
        skip_empty: bool = True,
        min_area: int = 50,
    ):
        """GPU-slot wrapper for _segment_volume_impl (300 s budget per volume)."""
        return _segment_volume_impl(
            nifti_b64, text_prompt, checkpoint, skip_empty, min_area
        )
else:
    # No ZeroGPU: call the implementations directly (standard GPU/CPU path).
    segment_slice_api = _segment_slice_impl
    segment_volume_api = _segment_volume_impl
# Gradio UI functions (for interactive demo)
def load_and_display_nifti(file):
    """Load NIfTI and return middle slice for display.

    Returns a (rgb_image_or_None, status_message, num_slices) triple for the
    Gradio outputs; errors are reported in the status message.
    """
    if file is None:
        return None, "No file uploaded", 0
    try:
        import nibabel as nib

        data = nib.load(file.name).get_fdata()
        mid = data.shape[0] // 2
        img = data[mid]
        # Scale the slice to 8-bit for display; flat slices render black.
        lo, hi = img.min(), img.max()
        gray = (
            ((img - lo) / (hi - lo) * 255).astype(np.uint8)
            if hi - lo > 0
            else np.zeros_like(img, dtype=np.uint8)
        )
        # Replicate the grayscale channel into RGB.
        rgb = np.stack([gray] * 3, axis=-1)
        return rgb, f"Loaded: {data.shape}, showing slice {mid}", data.shape[0]
    except Exception as e:
        return None, f"Error: {e}", 0
def segment_and_overlay(file, slice_idx: int, text_prompt: str, checkpoint: str):
    """Segment one slice via the API and blend the mask over the image.

    Returns an (rgb_overlay_or_None, status_message) pair for the Gradio
    outputs; errors are reported in the status message.
    """
    if file is None:
        return None, "Please upload a file first"
    try:
        # Send the whole file to the segmentation API as base64.
        with open(file.name, "rb") as fh:
            payload = base64.b64encode(fh.read()).decode()
        result = segment_slice_api(payload, int(slice_idx), text_prompt, checkpoint)
        if not result["success"]:
            return None, f"Segmentation failed: {result.get('error', 'Unknown error')}"

        # Reload the requested slice locally for display.
        import nibabel as nib

        img = nib.load(file.name).get_fdata()[int(slice_idx)]
        lo, hi = img.min(), img.max()
        if hi - lo > 0:
            gray = ((img - lo) / (hi - lo) * 255).astype(np.uint8)
        else:
            gray = np.zeros_like(img, dtype=np.uint8)

        # Rebuild the binary mask from the API's raw-byte encoding.
        raw = base64.b64decode(result["mask_b64"])
        mask = np.frombuffer(raw, dtype=np.uint8).reshape(result["mask_shape"])

        # Alpha-blend a red tint (255, 50, 50) over masked pixels.
        overlay = np.stack([gray] * 3, axis=-1).astype(np.float32)
        hit = mask > 0
        alpha = 0.4
        for channel, tint in enumerate((255, 50, 50)):
            overlay[hit, channel] = overlay[hit, channel] * (1 - alpha) + tint * alpha

        info = f"Segmented in {result['inference_time_ms']}ms, mask area: {mask.sum()} pixels"
        return overlay.astype(np.uint8), info
    except Exception as e:
        return None, f"Error: {e}"
# Build Gradio interface
def build_demo():
    """Build the interactive Gradio Blocks UI for this Space.

    Returns the Blocks object; the caller is responsible for queue()/launch().
    """
    with gr.Blocks(
        title="OncoSeg Inference API",
        theme=gr.themes.Soft(),
    ) as demo:
        gr.Markdown("""
# OncoSeg Medical Image Segmentation API
GPU-accelerated segmentation for CT and MRI volumes.
**API Endpoints** (for programmatic access):
- `POST /api/segment_slice_api` - Segment a single slice
- `POST /api/segment_volume_api` - Segment entire volume
**Interactive Demo** below:
""")
        with gr.Row():
            with gr.Column(scale=1):
                # Input controls: volume upload, model choice, prompt, slice.
                file_input = gr.File(
                    label="Upload NIfTI (.nii, .nii.gz)", file_types=[".nii", ".nii.gz"]
                )
                checkpoint = gr.Dropdown(
                    choices=list(CHECKPOINTS.keys()),
                    value="brain",
                    label="Model Checkpoint",
                )
                text_prompt = gr.Textbox(
                    value="tumor",
                    label="Text Prompt",
                    placeholder="e.g., tumor, lesion, mass",
                )
                slice_idx = gr.Slider(
                    minimum=0,
                    maximum=200,
                    value=77,
                    step=1,
                    label="Slice Index",
                )
                segment_btn = gr.Button("Segment Slice", variant="primary")
            with gr.Column(scale=2):
                output_image = gr.Image(label="Segmentation Result", type="numpy")
                status_text = gr.Textbox(label="Status", interactive=False)
        # Event handlers
        # NOTE(review): load_and_display_nifti returns the slice *count* as its
        # third output, which lands in the slider's value rather than its
        # maximum — confirm this is the intended behavior.
        file_input.change(
            fn=load_and_display_nifti,
            inputs=[file_input],
            outputs=[output_image, status_text, slice_idx],
        )
        segment_btn.click(
            fn=segment_and_overlay,
            inputs=[file_input, slice_idx, text_prompt, checkpoint],
            outputs=[output_image, status_text],
        )
        gr.Markdown("""
---
### API Usage Example
```python
import requests
import base64
# Read NIfTI file
with open("brain.nii.gz", "rb") as f:
    nifti_b64 = base64.b64encode(f.read()).decode()
# Call API
response = requests.post(
    "https://YOUR-SPACE.hf.space/api/segment_slice_api",
    json={
        "nifti_b64": nifti_b64,
        "slice_idx": 77,
        "text_prompt": "tumor",
        "checkpoint": "brain",
    }
)
result = response.json()
# result["contours"] contains the segmentation contours
```
""")
    return demo
# Launch the Space: build the UI, enable request queuing, and serve on the
# standard HF Spaces port.
if __name__ == "__main__":
    demo = build_demo()
    # NOTE(review): queue() serializes concurrent requests — presumably also
    # required for ZeroGPU slot scheduling; confirm against Spaces docs.
    demo.queue()
    demo.launch(server_name="0.0.0.0", server_port=7860)