|
|
""" |
|
|
NeuroSAM 3: Medical Image Segmentation App |
|
|
A Gradio app for segmenting medical images (CT/MRI) using SAM 3 |
|
|
""" |
|
|
|
|
|
import os |
|
|
import tempfile |
|
|
import zipfile |
|
|
import io |
|
|
import json |
|
|
import time |
|
|
from datetime import datetime |
|
|
import gradio as gr |
|
|
import spaces |
|
|
import torch |
|
|
import pydicom |
|
|
import numpy as np |
|
|
from PIL import Image, ImageEnhance, ImageDraw |
|
|
try: |
|
|
from transformers import Sam3Processor, Sam3Model |
|
|
SAM3_AVAILABLE = True |
|
|
except ImportError: |
|
|
print("⚠️ Warning: Sam3Processor/Sam3Model not found in transformers.") |
|
|
print("⚠️ SAM3 requires transformers from GitHub main branch.") |
|
|
print("⚠️ Install with: pip install git+https://github.com/huggingface/transformers.git") |
|
|
SAM3_AVAILABLE = False |
|
|
|
|
|
Sam3Processor = None |
|
|
Sam3Model = None |
|
|
import matplotlib.pyplot as plt |
|
|
from matplotlib.patches import Rectangle |
|
|
from scipy import ndimage |
|
|
from huggingface_hub import login |
|
|
|
|
|
|
|
|
try: |
|
|
import nibabel as nib |
|
|
NIBABEL_AVAILABLE = True |
|
|
except ImportError: |
|
|
NIBABEL_AVAILABLE = False |
|
|
print("⚠️ nibabel not available - NIFTI export disabled") |
|
|
|
|
|
|
|
|
# --- Hugging Face authentication (runs at import time) -----------------------
# HF_TOKEN is needed to download the gated SAM 3 checkpoint below.
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    print("⚠️ WARNING: HF_TOKEN environment variable not set!")
    print("⚠️ Some features may not work. Please set HF_TOKEN in Space settings.")
    hf_token = None  # normalize empty string to None so later checks are uniform
else:
    try:
        # Best-effort hub login; the model load below passes the token
        # explicitly as well, so a failure here is deliberately non-fatal.
        login(token=hf_token, add_to_git_credential=False)
    except Exception as e:
        print(f"⚠️ Could not login to HF Hub (non-critical): {e}")
|
|
|
|
|
|
|
|
# --- One-time SAM 3 model bootstrap (runs at import time) --------------------
print("🧠 Loading SAM 3 Model...")

# Globals read by every inference/processing function below. They remain None
# when loading fails; each entry point guards against that.
model = None
processor = None

if not SAM3_AVAILABLE:
    # The Sam3Processor/Sam3Model import at the top of the file failed.
    print("❌ SAM 3 classes not available in transformers library.")
    print("❌ Install with: pip install git+https://github.com/huggingface/transformers.git")
    print("⚠️ App will start but segmentation features will be disabled.")
else:
    SAM_MODEL_ID = "facebook/sam3"

    if hf_token is None:
        # Gated checkpoint: without a token the download would fail anyway.
        print("⚠️ Cannot load model: HF_TOKEN not set")
        model = None
        processor = None
    else:
        try:
            # Load on CPU in float32; run_sam3_inference moves the model to
            # GPU (and float16) on demand inside the @spaces.GPU call.
            model = Sam3Model.from_pretrained(
                SAM_MODEL_ID,
                torch_dtype=torch.float32,
                token=hf_token
            )
            processor = Sam3Processor.from_pretrained(SAM_MODEL_ID, token=hf_token)
            model.eval()  # inference only: disables dropout/batch-norm updates
            print(f"✅ SAM 3 Model Loaded Successfully on CPU! ({SAM_MODEL_ID})")
            print("💡 Model will be moved to GPU when inference is called")
        except Exception as e:
            # Keep the app importable/startable even when the model cannot load.
            print(f"⚠️ Failed to load SAM 3 model: {e}")
            print("Ensure you have:")
            print(" 1. transformers from GitHub main branch for SAM 3 support")
            print(" Install with: pip install git+https://github.com/huggingface/transformers.git")
            print(" 2. Valid Hugging Face token with access to SAM 3")
            print(" 3. Sufficient memory for the model")
            print("⚠️ App will start but segmentation features will be disabled until model loads.")
            model = None
            processor = None
|
|
|
|
|
@spaces.GPU(duration=60)
def run_sam3_inference(pil_image, prompt_text, threshold=0.1, mask_threshold=0.0):
    """
    Run SAM 3 inference - optimized for medical imaging.

    Args:
        pil_image: PIL Image to segment
        prompt_text: Text prompt for segmentation (e.g., "brain", "tumor", "skull")
        threshold: Detection confidence threshold, range [0.0, 1.0] (default 0.1 for medical images).
            Lower values (0.0-0.3) are more permissive and better for subtle features.
            Higher values (0.5-1.0) require high confidence, may miss detections.
        mask_threshold: Mask binarization threshold, range [0.0, 1.0] (default 0.0 for medical images).
            Lower values preserve more detail. Higher values create sharper masks.
            Medical images often benefit from 0.0 to capture subtle boundaries.

    Returns:
        results dict with 'masks' and 'scores' as numpy arrays or lists, or None if failed

    Raises:
        ValueError: if the module-level model/processor failed to load at startup.

    Note:
        Default thresholds (0.1, 0.0) are optimized for medical imaging where features
        may be subtle or low-contrast. For natural images, higher thresholds (0.5, 0.5)
        may be more appropriate.
    """
    # Guard against the bootstrap having failed (see module top).
    if model is None or processor is None:
        print("❌ Model not loaded - please check HF_TOKEN and model availability")
        raise ValueError("SAM 3 model not loaded. Please check that HF_TOKEN is set correctly and the model is accessible.")

    def to_serializable(obj):
        """
        Convert all tensors to numpy arrays or Python primitives for safe serialization.
        This ensures NO PyTorch tensors (CPU or CUDA) are in the return value.

        Needed because the @spaces.GPU worker hands the return value back to the
        main process, where CUDA tensors would not be usable.
        """
        if isinstance(obj, torch.Tensor):
            # .cpu() first so CUDA tensors can be converted.
            result = obj.cpu().numpy()
            print(f"🔄 Converted tensor to numpy: shape={result.shape}, dtype={result.dtype}")
            return result
        elif isinstance(obj, dict):
            return {k: to_serializable(v) for k, v in obj.items()}
        elif isinstance(obj, list):
            return [to_serializable(item) for item in obj]
        elif isinstance(obj, tuple):
            return tuple(to_serializable(item) for item in obj)
        elif isinstance(obj, (int, float, str, bool, type(None))):
            return obj
        elif hasattr(obj, 'item'):
            # 0-d tensors / numpy scalars.
            return obj.item()
        else:
            # Last resort: stringify rather than leak an unserializable object.
            print(f"⚠️ Unknown type encountered: {type(obj)}, converting to string")
            return str(obj)

    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"🔧 Using device: {device}")

        # Half precision on GPU for memory/speed; full precision on CPU.
        # NOTE(review): .to(dtype=...) converts the weights in place, so after a
        # GPU run the CPU-resident model stays float16 until the next call
        # converts it back - confirm this round-trip is intended.
        dtype = torch.float16 if device == "cuda" else torch.float32
        model.to(device=device, dtype=dtype)
        print(f"✅ Model moved to {device} with dtype {dtype}")

        inputs = processor(images=pil_image, text=prompt_text.strip(), return_tensors="pt").to(device)

        # Match float inputs to the model's dtype (e.g. fp16 on GPU).
        for key in inputs:
            if isinstance(inputs[key], torch.Tensor) and inputs[key].dtype == torch.float32:
                inputs[key] = inputs[key].to(model.dtype)

        with torch.no_grad():
            outputs = model(**inputs)

        print(f"🧠 Inference complete, processing results...")

        # Post-process to instance masks at the original image resolution.
        # PIL size is (W, H); target_sizes expects (H, W), hence [::-1].
        results = processor.post_process_instance_segmentation(
            outputs,
            threshold=threshold,
            mask_threshold=mask_threshold,
            target_sizes=inputs.get("original_sizes").tolist() if "original_sizes" in inputs else [pil_image.size[::-1]]
        )[0]

        # Debug dump of the result structure before serialization.
        print(f"📊 Results type: {type(results)}")
        if isinstance(results, dict):
            print(f"📊 Results keys: {results.keys()}")
            for key, value in results.items():
                print(f" - {key}: type={type(value)}")
                if isinstance(value, torch.Tensor):
                    print(f" tensor device={value.device}, shape={value.shape}, dtype={value.dtype}")
                elif isinstance(value, list) and len(value) > 0:
                    print(f" list length={len(value)}, first item type={type(value[0])}")
                    if isinstance(value[0], torch.Tensor):
                        print(f" first tensor device={value[0].device}")

        print(f"🔄 Converting all tensors to numpy arrays...")
        results = to_serializable(results)

        print(f"✅ All tensors converted to serializable format")

        # Free the GPU for other ZeroGPU jobs before returning.
        model.to("cpu")
        print(f"✅ Model moved back to CPU")

        return results

    except Exception as e:
        print(f"❌ Error during SAM 3 inference: {e}")
        import traceback
        traceback.print_exc()

        # Best-effort cleanup so a failed call does not pin the GPU.
        if model is not None:
            try:
                model.to("cpu")
            except RuntimeError as cleanup_error:
                print(f"⚠️ Could not move model back to CPU: {cleanup_error}")
        return None
|
|
|
|
|
|
|
|
# --- Demo DICOM preparation (runs at import time) ----------------------------
# First choice: copy pydicom's bundled MR sample; fallback: synthesize a
# minimal MR DICOM with a bright circular "lesion" in the middle.
demo_dicom_path = "demo_brain_mri.dcm"
demo_file_available = False

try:
    from pydicom.data import get_testdata_file
    test_file = get_testdata_file("MR_small.dcm")
    if test_file and os.path.exists(test_file):
        import shutil
        shutil.copy(test_file, demo_dicom_path)
        demo_file_available = True
        print(f"✅ Demo file ready: {demo_dicom_path}")
except Exception:
    # Bundled sample unavailable or copy failed; fall through to synthesis.
    pass

# Fix: the original only built the synthetic file when the block above raised,
# so a silently-missing bundled sample left the app with no demo at all.
if not demo_file_available:
    try:
        from pydicom.dataset import FileDataset, FileMetaDataset
        from pydicom.uid import generate_uid
        from datetime import datetime

        # Random background with a brighter disc of radius 100 at the center.
        synthetic_image = np.random.randint(0, 255, (256, 256), dtype=np.uint16)
        center_x, center_y = 128, 128
        y, x = np.ogrid[:256, :256]
        mask = (x - center_x)**2 + (y - center_y)**2 <= 100**2
        synthetic_image[mask] = np.clip(synthetic_image[mask] + 50, 0, 255)

        # Minimal file meta: MR Image Storage, Explicit VR Little Endian.
        file_meta = FileMetaDataset()
        file_meta.MediaStorageSOPClassUID = '1.2.840.10008.5.1.4.1.1.4'
        file_meta.MediaStorageSOPInstanceUID = generate_uid()
        file_meta.TransferSyntaxUID = '1.2.840.10008.1.2.1'

        ds = FileDataset(demo_dicom_path, {}, file_meta=file_meta, preamble=b"\x00" * 128)
        ds.PatientName = "Demo^Patient"
        ds.PatientID = "DEMO001"
        ds.Modality = "MR"
        # Pixel module: 16-bit unsigned monochrome, matching the array above.
        ds.Rows = 256
        ds.Columns = 256
        ds.BitsAllocated = 16
        ds.BitsStored = 16
        ds.HighBit = 15
        ds.SamplesPerPixel = 1
        ds.PixelRepresentation = 0
        ds.PhotometricInterpretation = "MONOCHROME2"
        ds.PixelSpacing = [1.0, 1.0]
        # Identity rescale so the viewer's HU conversion is a no-op.
        ds.RescaleIntercept = "0"
        ds.RescaleSlope = "1"
        ds.PixelData = synthetic_image.tobytes()

        ds.save_as(demo_dicom_path, write_like_original=False)
        demo_file_available = True
        print(f"✅ Synthetic demo file created: {demo_dicom_path}")
    except Exception as e:
        print(f"⚠️ Could not create demo file: {e}")
|
|
|
|
|
def compare_with_ground_truth(pred_mask, gt_mask_path):
    """Compare SAM 3 prediction with ground truth mask and return comparison metrics."""
    try:
        gt_img = Image.open(gt_mask_path)
        gt_bool = np.array(gt_img.convert('L')) > 127

        # Resample the prediction onto the ground-truth grid when shapes differ
        # (nearest-neighbour keeps the mask binary).
        if pred_mask.shape != gt_bool.shape:
            scaled = Image.fromarray((pred_mask * 255).astype(np.uint8)).resize(gt_img.size, Image.NEAREST)
            pred_mask = np.array(scaled) > 127

        # Overlap metrics: Dice = 2|A∩B| / (|A|+|B|), IoU = |A∩B| / |A∪B|.
        overlap = np.logical_and(pred_mask, gt_bool).sum()
        combined = np.logical_or(pred_mask, gt_bool).sum()
        size_sum = pred_mask.sum() + gt_bool.sum()
        dice_score = (2.0 * overlap) / size_sum if size_sum > 0 else 0.0
        iou_score = overlap / combined if combined > 0 else 0.0

        # Three-panel figure: prediction, ground truth, color-coded agreement.
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        panels = (
            (pred_mask, 'spring', 'SAM 3 Prediction'),
            (gt_bool, 'cool', 'Ground Truth'),
        )
        for axis, (panel_data, panel_cmap, panel_title) in zip(axes, panels):
            axis.imshow(panel_data, cmap=panel_cmap)
            axis.set_title(panel_title)
            axis.axis('off')

        # RGB overlay: green = true positive, red = false positive,
        # blue = false negative.
        agreement = np.zeros((*pred_mask.shape, 3))
        agreement[pred_mask & gt_bool] = [0, 1, 0]
        agreement[pred_mask & ~gt_bool] = [1, 0, 0]
        agreement[~pred_mask & gt_bool] = [0, 0, 1]
        axes[2].imshow(agreement)
        axes[2].set_title(f'Comparison\nDice: {dice_score:.3f}, IoU: {iou_score:.3f}')
        axes[2].axis('off')

        plt.tight_layout()

        # Persist to a named temp file the caller can serve for download.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as handle:
            output_path = handle.name
        plt.savefig(output_path, bbox_inches='tight', dpi=100)
        plt.close()

        return output_path, dice_score, iou_score
    except Exception as e:
        print(f"⚠️ Error comparing with ground truth: {e}")
        return None, 0.0, 0.0
|
|
|
|
|
def process_medical_image(image_file, prompt_text, modality, window_type, return_mask=False):
    """Process a DICOM or standard image file (PNG/JPG) and perform segmentation using SAM 3.

    Args:
        image_file: Path to image file
        prompt_text: Text prompt for segmentation
        modality: CT or MRI
        window_type: Windowing strategy
        return_mask: If True, also return the binary mask array

    Returns:
        Path to output image, and optionally the mask array.
        NOTE(review): every failure path returns a bare None even when
        return_mask=True, so callers that unpack two values must check for
        None before unpacking.
    """
    # Guard: segmentation needs the globally loaded model/processor.
    if model is None or processor is None:
        print("❌ Error: Model not loaded.")
        return None

    if image_file is None:
        return None

    # Fall back to a generic prompt when the field was left blank.
    if not prompt_text or not prompt_text.strip():
        prompt_text = "brain"

    try:
        file_path = image_file if isinstance(image_file, str) else str(image_file)

        if not os.path.exists(file_path):
            print(f"❌ Error: File not found at {file_path}")
            return None

        # Only the .dcm extension is treated as DICOM; everything else goes
        # through the standard raster-image branch.
        file_ext = os.path.splitext(file_path)[1].lower()
        is_dicom = file_ext == '.dcm'

        if is_dicom:
            ds = pydicom.dcmread(file_path)

            if not hasattr(ds, 'pixel_array'):
                print("❌ Error: DICOM file does not contain pixel data.")
                return None

            # Map stored values to physical units (HU for CT) using the
            # RescaleSlope/RescaleIntercept tags (identity when absent).
            raw = ds.pixel_array.astype(np.float32)
            slope = getattr(ds, 'RescaleSlope', 1)
            intercept = getattr(ds, 'RescaleIntercept', 0)
            img_hu = raw * slope + intercept

            # Display window: fixed level/width presets for CT; a robust
            # 1st-99th percentile stretch for everything else.
            if modality == "CT":
                if window_type == "Brain (Grey Matter)":
                    level, width = 40, 80
                elif window_type == "Bone (Skull)":
                    level, width = 500, 2000
                else:
                    # Default soft-tissue window.
                    level, width = 40, 400
                img_min = level - (width / 2)
                img_max = level + (width / 2)
            else:
                img_min = np.percentile(img_hu, 1)
                img_max = np.percentile(img_hu, 99)

            # Degenerate window: retry with the full data range, then bail
            # out if the image is constant.
            img_range = img_max - img_min
            if img_range <= 0:
                img_min = np.min(img_hu)
                img_max = np.max(img_hu)
                img_range = img_max - img_min
                if img_range <= 0:
                    return None

            img_windowed = (img_hu - img_min) / img_range
            img_windowed = np.clip(img_windowed, 0, 1)

            img_uint8 = (img_windowed * 255).astype(np.uint8)

            # SAM expects a 3-channel image; grayscale gets replicated.
            if len(img_uint8.shape) == 2:
                pil_image = Image.fromarray(img_uint8).convert('RGB')
            else:
                pil_image = Image.fromarray(img_uint8)
        else:
            pil_image = Image.open(file_path)

            if pil_image.mode != 'RGB':
                pil_image = pil_image.convert('RGB')

            img_array = np.array(pil_image)

            # Defensive: stack grayscale to 3 channels (normally already RGB).
            if len(img_array.shape) == 2:
                img_array = np.stack([img_array] * 3, axis=-1)

            # Same windowing/normalization as the DICOM branch.
            # NOTE(review): the CT level/width presets assume Hounsfield
            # units, which a plain PNG/JPG does not carry - confirm this is
            # intended for non-DICOM CT inputs.
            img_float = img_array.astype(np.float32)
            if modality == "CT":
                if window_type == "Brain (Grey Matter)":
                    level, width = 40, 80
                elif window_type == "Bone (Skull)":
                    level, width = 500, 2000
                else:
                    level, width = 40, 400
                img_min = level - (width / 2)
                img_max = level + (width / 2)
            else:
                img_min = np.percentile(img_float, 1)
                img_max = np.percentile(img_float, 99)

            img_range = img_max - img_min
            if img_range <= 0:
                img_min = np.min(img_float)
                img_max = np.max(img_float)
                img_range = img_max - img_min
                if img_range <= 0:
                    return None

            img_normalized = (img_float - img_min) / img_range
            img_normalized = np.clip(img_normalized, 0, 1)
            img_uint8 = (img_normalized * 255).astype(np.uint8)

            pil_image = Image.fromarray(img_uint8.astype(np.uint8))

        # Run SAM 3 with permissive thresholds suited to medical imaging.
        results = run_sam3_inference(pil_image, prompt_text, threshold=0.1, mask_threshold=0.0)

        if results is None:
            return None

        # Render the image with the (union) mask overlaid via matplotlib.
        plt.figure(figsize=(10, 10))
        plt.imshow(pil_image)

        final_mask = None
        if 'masks' in results and results['masks'] is not None:
            masks = results['masks']
            scores = results.get('scores', [])

            if len(masks) > 0:
                # Normalize every mask to a numpy array (results should
                # already be numpy after run_sam3_inference's serialization;
                # the tensor branch is defensive).
                mask_arrays = []
                for mask in masks:
                    if isinstance(mask, torch.Tensor):
                        mask_np = mask.cpu().numpy()
                    else:
                        mask_np = np.array(mask)
                    mask_arrays.append(mask_np)

                if len(mask_arrays) > 0:
                    # Union of all instance masks, shown at 50% opacity.
                    final_mask = np.any(mask_arrays, axis=0)
                    plt.imshow(final_mask, alpha=0.5, cmap='spring')
                else:
                    print("⚠️ Warning: No valid masks found.")
            else:
                print("⚠️ Warning: No masks in results.")
        else:
            print("⚠️ Warning: No masks in results.")

        plt.axis('off')
        plt.title(f"Segmentation: {prompt_text}", fontsize=12, pad=10)

        # Persist the rendered overlay to a temp PNG for the UI to display.
        output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
        output_path = output_file.name
        output_file.close()

        plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=100)
        plt.close()

        if return_mask:
            return output_path, final_mask
        return output_path

    except pydicom.errors.InvalidDicomError as e:
        print(f"❌ Error: Invalid DICOM file format. {e}")
        return None
    except Exception as e:
        print(f"❌ Error processing image: {e}")
        import traceback
        traceback.print_exc()
        return None
|
|
|
|
|
def process_medical_image_enhanced(image_file, prompt_text, modality, window_type,
                                   brightness=1.0, contrast=1.0, colormap='spring',
                                   transparency=0.5, return_mask=False):
    """Enhanced version of process_medical_image with image adjustments and visualization options.

    Args:
        image_file: Path to image file
        prompt_text: Text prompt for segmentation
        modality: CT or MRI
        window_type: Windowing strategy
        brightness: Brightness multiplier (0.5-2.0)
        contrast: Contrast multiplier (0.5-2.0)
        colormap: Matplotlib colormap name
        transparency: Mask overlay transparency (0.0-1.0)
        return_mask: If True, also return the binary mask array

    Returns:
        Path to output image, and optionally the mask array.
        NOTE(review): failure paths return a bare None even when
        return_mask=True - callers unpacking two values must guard.
    """
    # Guard: requires the globally loaded model/processor.
    if model is None or processor is None:
        print("❌ Error: Model not loaded.")
        return None

    if image_file is None:
        return None

    # Default prompt when the field was left blank.
    if not prompt_text or not prompt_text.strip():
        prompt_text = "brain"

    try:
        file_path = image_file if isinstance(image_file, str) else str(image_file)

        if not os.path.exists(file_path):
            print(f"❌ Error: File not found at {file_path}")
            return None

        # Only .dcm is treated as DICOM; other extensions use the raster path.
        file_ext = os.path.splitext(file_path)[1].lower()
        is_dicom = file_ext == '.dcm'

        if is_dicom:
            ds = pydicom.dcmread(file_path)

            if not hasattr(ds, 'pixel_array'):
                print("❌ Error: DICOM file does not contain pixel data.")
                return None

            # Stored values -> physical units via RescaleSlope/Intercept.
            raw = ds.pixel_array.astype(np.float32)
            slope = getattr(ds, 'RescaleSlope', 1)
            intercept = getattr(ds, 'RescaleIntercept', 0)
            img_hu = raw * slope + intercept

            # CT uses fixed level/width presets; other modalities use a
            # robust percentile stretch.
            if modality == "CT":
                if window_type == "Brain (Grey Matter)":
                    level, width = 40, 80
                elif window_type == "Bone (Skull)":
                    level, width = 500, 2000
                else:
                    # Default soft-tissue window.
                    level, width = 40, 400
                img_min = level - (width / 2)
                img_max = level + (width / 2)
            else:
                img_min = np.percentile(img_hu, 1)
                img_max = np.percentile(img_hu, 99)

            # Degenerate window: fall back to full range; constant image -> None.
            img_range = img_max - img_min
            if img_range <= 0:
                img_min = np.min(img_hu)
                img_max = np.max(img_hu)
                img_range = img_max - img_min
                if img_range <= 0:
                    return None

            img_windowed = (img_hu - img_min) / img_range
            img_windowed = np.clip(img_windowed, 0, 1)

            img_uint8 = (img_windowed * 255).astype(np.uint8)

            # SAM expects 3 channels.
            if len(img_uint8.shape) == 2:
                pil_image = Image.fromarray(img_uint8).convert('RGB')
            else:
                pil_image = Image.fromarray(img_uint8)
        else:
            pil_image = Image.open(file_path)

            if pil_image.mode != 'RGB':
                pil_image = pil_image.convert('RGB')

            img_array = np.array(pil_image)

            # Defensive grayscale -> 3-channel stack.
            if len(img_array.shape) == 2:
                img_array = np.stack([img_array] * 3, axis=-1)

            # Same windowing/normalization as the DICOM branch.
            # NOTE(review): CT presets assume Hounsfield units, which plain
            # PNG/JPG pixels are not - confirm intended for non-DICOM CT.
            img_float = img_array.astype(np.float32)
            if modality == "CT":
                if window_type == "Brain (Grey Matter)":
                    level, width = 40, 80
                elif window_type == "Bone (Skull)":
                    level, width = 500, 2000
                else:
                    level, width = 40, 400
                img_min = level - (width / 2)
                img_max = level + (width / 2)
            else:
                img_min = np.percentile(img_float, 1)
                img_max = np.percentile(img_float, 99)

            img_range = img_max - img_min
            if img_range <= 0:
                img_min = np.min(img_float)
                img_max = np.max(img_float)
                img_range = img_max - img_min
                if img_range <= 0:
                    return None

            img_normalized = (img_float - img_min) / img_range
            img_normalized = np.clip(img_normalized, 0, 1)
            img_uint8 = (img_normalized * 255).astype(np.uint8)

            pil_image = Image.fromarray(img_uint8.astype(np.uint8))

        # User-controlled brightness/contrast applied BEFORE inference, so the
        # model segments the adjusted image.
        enhancer = ImageEnhance.Brightness(pil_image)
        pil_image = enhancer.enhance(brightness)
        enhancer = ImageEnhance.Contrast(pil_image)
        pil_image = enhancer.enhance(contrast)

        # Run SAM 3 with permissive thresholds suited to medical imaging.
        results = run_sam3_inference(pil_image, prompt_text, threshold=0.1, mask_threshold=0.0)

        if results is None:
            return None

        # Render the adjusted image with the union mask overlaid.
        plt.figure(figsize=(10, 10))
        plt.imshow(pil_image)

        final_mask = None
        if 'masks' in results and results['masks'] is not None:
            masks = results['masks']
            scores = results.get('scores', [])

            if len(masks) > 0:
                # Normalize each mask to numpy (tensor branch is defensive;
                # run_sam3_inference already serialized to numpy).
                mask_arrays = []
                for mask in masks:
                    if isinstance(mask, torch.Tensor):
                        mask_np = mask.cpu().numpy()
                    else:
                        mask_np = np.array(mask)
                    mask_arrays.append(mask_np)

                if len(mask_arrays) > 0:
                    # Union of all instances; user-chosen colormap/opacity.
                    final_mask = np.any(mask_arrays, axis=0)
                    plt.imshow(final_mask, alpha=transparency, cmap=colormap)
                else:
                    print("⚠️ Warning: No valid masks found.")
            else:
                print("⚠️ Warning: No masks in results.")
        else:
            print("⚠️ Warning: No masks in results.")

        plt.axis('off')
        plt.title(f"Segmentation: {prompt_text}", fontsize=12, pad=10)

        # Persist the rendered overlay to a temp PNG for the UI.
        output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
        output_path = output_file.name
        output_file.close()

        plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=100)
        plt.close()

        if return_mask:
            return output_path, final_mask
        return output_path

    except pydicom.errors.InvalidDicomError as e:
        print(f"❌ Error: Invalid DICOM file format. {e}")
        return None
    except Exception as e:
        print(f"❌ Error processing image: {e}")
        import traceback
        traceback.print_exc()
        return None
|
|
|
|
|
def process_with_progress(image_file, prompt_text, modality, window_type,
                          brightness=1.0, contrast=1.0, colormap='spring',
                          transparency=0.5, progress=gr.Progress()):
    """Run the enhanced segmentation pipeline while reporting progress to the UI.

    Returns a (result image path, status message, metrics text) triple.
    """
    # Guard clauses: model must be loaded and an input file supplied.
    if model is None or processor is None:
        return None, "❌ Error: Model not loaded.", ""
    if image_file is None:
        return None, "⚠️ Please upload a medical image file.", ""

    progress(0, desc="Starting...")
    progress(0.1, desc="Reading image...")

    overlay_path = process_medical_image_enhanced(
        image_file, prompt_text, modality, window_type,
        brightness, contrast, colormap, transparency
    )

    progress(0.8, desc="Processing segmentation...")

    if overlay_path is None:
        progress(1.0, desc="Failed!")
        return None, "❌ Processing failed. Check console for error details.", ""

    progress(1.0, desc="Complete!")

    # Short run summary shown next to the result image.
    stamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    metrics = f"Prompt: {prompt_text}\nModality: {modality}\nProcessed: {stamp}"

    return overlay_path, "✅ Segmentation complete!", metrics
|
|
|
|
|
def create_downloadable_result(output_path, prompt_text):
    """Create a downloadable file with metadata.

    Args:
        output_path: Path to the rendered segmentation PNG.
        prompt_text: Prompt used for segmentation (currently unused; kept so
            the UI callback signature stays stable).

    Returns:
        The path unchanged - the PNG itself is the downloadable artifact.
    """
    return output_path
|
|
|
|
|
def process_batch_enhanced(image_files, prompt_text, modality, window_type,
                           brightness=1.0, contrast=1.0, colormap='spring',
                           transparency=0.5, progress=gr.Progress()):
    """Process multiple images with enhanced features and create ZIP download.

    Args:
        image_files: A file path or list of file paths to process.
        prompt_text: Text prompt for segmentation.
        modality: CT or MRI.
        window_type: Windowing strategy.
        brightness, contrast, colormap, transparency: Passed through to
            process_medical_image_enhanced.
        progress: Gradio progress reporter.

    Returns:
        (list of result image paths, ZIP archive path or None, status message)
    """
    # Guard clauses: model loaded, files supplied.
    if model is None or processor is None:
        return [], None, "❌ Error: Model not loaded."
    if not image_files:
        return [], None, "⚠️ Please upload medical image files."

    # Accept a single path for convenience.
    if isinstance(image_files, str):
        image_files = [image_files]

    results = []
    total = len(image_files)

    for idx, image_file in enumerate(image_files):
        if image_file is None:
            continue

        progress((idx + 1) / total, desc=f"Processing image {idx + 1}/{total}...")

        result = process_medical_image_enhanced(
            image_file, prompt_text, modality, window_type,
            brightness, contrast, colormap, transparency
        )

        if result:
            results.append(result)

    if not results:
        return [], None, "❌ No images were processed successfully."

    # Write the archive directly to a named temp file; the original buffered
    # the whole ZIP in memory (io.BytesIO) and then copied it out, which is
    # an unnecessary extra copy for large batches.
    zip_handle = tempfile.NamedTemporaryFile(delete=False, suffix='.zip')
    zip_output_path = zip_handle.name
    zip_handle.close()

    with zipfile.ZipFile(zip_output_path, 'w', zipfile.ZIP_DEFLATED) as zip_file:
        for idx, result_path in enumerate(results):
            if os.path.exists(result_path):
                filename = f"segmentation_{idx + 1}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png"
                zip_file.write(result_path, filename)

    status = f"✅ Processed {len(results)}/{total} images successfully!\nZIP file ready for download."
    return results, zip_output_path, status
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level mutable state for auto-play controls (presumably a slice-viewer
# defined later in this file - confirm): 'running' gates playback,
# 'current_idx' tracks the last displayed index.
auto_play_state = {"running": False, "current_idx": 0}
|
|
|
|
|
def _empty_roi_stats(error_msg):
    """Return a zeroed ROI-statistics dict carrying *error_msg*."""
    return {
        "error": error_msg,
        "area_pixels": 0,
        "area_percentage": 0,
        "mean_intensity": 0,
        "std_intensity": 0,
        "min_intensity": 0,
        "max_intensity": 0,
        "centroid": (0, 0),
        "bounding_box": (0, 0, 0, 0)
    }


def calculate_roi_statistics(image_file, mask, modality):
    """Calculate ROI statistics from the segmented region.

    Args:
        image_file: Path to the source image (DICOM or standard raster).
        mask: 2-D numpy array marking the ROI (any truthy values count).
        modality: "CT" or "MRI"; CT additionally mirrors the intensity stats
            as Hounsfield-unit fields.

    Returns:
        dict: Statistics including area, mean intensity, std, min, max, centroid
    """
    if mask is None or not isinstance(mask, np.ndarray):
        return _empty_roi_stats("No valid mask available")

    try:
        file_path = image_file if isinstance(image_file, str) else str(image_file)
        file_ext = os.path.splitext(file_path)[1].lower()

        if file_ext == '.dcm':
            # Recover physical units (HU for CT) via the rescale tags.
            ds = pydicom.dcmread(file_path)
            img_array = ds.pixel_array.astype(np.float32)
            slope = getattr(ds, 'RescaleSlope', 1)
            intercept = getattr(ds, 'RescaleIntercept', 0)
            img_array = img_array * slope + intercept
        else:
            img = Image.open(file_path)
            if img.mode == 'RGB':
                img = img.convert('L')
            img_array = np.array(img).astype(np.float32)

        # Nearest-neighbour resample keeps the mask binary when it does not
        # match the image grid.
        if mask.shape != img_array.shape:
            from scipy.ndimage import zoom
            zoom_factors = (img_array.shape[0] / mask.shape[0], img_array.shape[1] / mask.shape[1])
            mask = zoom(mask.astype(float), zoom_factors, order=0) > 0.5

        mask_bool = mask.astype(bool)
        total_pixels = mask.size
        roi_pixels = np.sum(mask_bool)

        if roi_pixels == 0:
            return _empty_roi_stats("No pixels in ROI")

        roi_intensities = img_array[mask_bool]

        # Connected-component count and geometric centre of the ROI.
        # (The labelled array itself is not needed, only the count.)
        num_features = ndimage.label(mask_bool)[1]
        centroid = ndimage.center_of_mass(mask_bool)

        # Tight bounding box from the occupied rows/columns.
        rows = np.any(mask_bool, axis=1)
        cols = np.any(mask_bool, axis=0)
        rmin, rmax = np.where(rows)[0][[0, -1]]
        cmin, cmax = np.where(cols)[0][[0, -1]]

        stats = {
            "area_pixels": int(roi_pixels),
            "area_percentage": float(roi_pixels / total_pixels * 100),
            "mean_intensity": float(np.mean(roi_intensities)),
            "std_intensity": float(np.std(roi_intensities)),
            "min_intensity": float(np.min(roi_intensities)),
            "max_intensity": float(np.max(roi_intensities)),
            # center_of_mass returns (row, col); report as (x, y).
            "centroid": (float(centroid[1]), float(centroid[0])),
            "bounding_box": (int(cmin), int(rmin), int(cmax), int(rmax)),
            "num_components": num_features
        }

        if modality == "CT":
            # The DICOM rescale above already mapped values to Hounsfield
            # units, so the intensity stats can be mirrored directly.
            stats["mean_hu"] = stats["mean_intensity"]
            stats["std_hu"] = stats["std_intensity"]

        return stats

    except Exception as e:
        print(f"Error calculating ROI statistics: {e}")
        return {"error": str(e)}
|
|
|
|
|
def format_roi_statistics(stats):
    """Render an ROI-statistics dict as a human-readable markdown summary."""
    # Error placeholders (an 'error' key with zero area) get a short warning.
    if "error" in stats and stats.get("area_pixels", 0) == 0:
        return f"⚠️ {stats.get('error', 'No statistics available')}"

    lines = [
        "📊 **ROI Statistics**\n",
        f"**Area:** {stats['area_pixels']:,} pixels ({stats['area_percentage']:.2f}%)",
        f"**Intensity:** {stats['mean_intensity']:.2f} ± {stats['std_intensity']:.2f}",
        f"**Range:** [{stats['min_intensity']:.2f}, {stats['max_intensity']:.2f}]",
        f"**Centroid:** ({stats['centroid'][0]:.1f}, {stats['centroid'][1]:.1f})",
        f"**Bounding Box:** {stats['bounding_box']}",
        f"**Components:** {stats.get('num_components', 1)}",
    ]
    summary = "\n".join(lines)

    # CT results carry Hounsfield-unit fields; append them when present.
    if "mean_hu" in stats:
        summary += f"\n\n**CT (Hounsfield Units):**\n"
        summary += f"Mean HU: {stats['mean_hu']:.1f} ± {stats['std_hu']:.1f}"

    return summary
|
|
|
|
|
def process_with_roi_stats(image_file, prompt_text, modality, window_type):
    """Process image and return both segmentation and ROI statistics.

    Args:
        image_file: Path to the medical image.
        prompt_text: Text prompt for segmentation.
        modality: CT or MRI.
        window_type: Windowing strategy.

    Returns:
        (result image path or None, status message, ROI statistics text)
    """
    # Guard clauses: model loaded, file supplied.
    if model is None or processor is None:
        return None, "❌ Error: Model not loaded.", ""
    if image_file is None:
        return None, "⚠️ Please upload a medical image file.", ""

    # Fix: process_medical_image returns a bare None on any failure (even
    # with return_mask=True), so unpacking its result directly raised
    # TypeError instead of reporting the failure. Check before unpacking.
    outcome = process_medical_image(image_file, prompt_text, modality, window_type, return_mask=True)
    if outcome is None:
        return None, "❌ Processing failed.", ""
    result, mask = outcome

    if result is None:
        return None, "❌ Processing failed.", ""

    # Summarize the segmented region for display next to the overlay.
    stats = calculate_roi_statistics(image_file, mask, modality)
    stats_text = format_roi_statistics(stats)

    return result, "✅ Segmentation complete!", stats_text
|
|
|
|
|
def process_with_point_prompt(image_file, point_x, point_y, modality, window_type, colormap='spring', transparency=0.5):
    """Segment a medical image using a pixel location as the prompt.

    Note: SAM3 is driven here by a generic text prompt; the (x, y) point is
    used to *select* the returned candidate mask that contains it (falling
    back to the first candidate), so this simulates point-based prompting.

    Args:
        image_file: Path to a DICOM (.dcm) or standard image file.
        point_x, point_y: Seed-point pixel coordinates (clamped to the image).
        modality: Imaging modality label (kept for UI parity; not used here).
        window_type: CT window preset label (kept for UI parity; not used here).
        colormap: Matplotlib colormap name for the mask overlay.
        transparency: Alpha of the mask overlay.

    Returns:
        tuple: (path to the rendered PNG or None, status message).
    """
    if model is None or processor is None:
        return None, "❌ Error: Model not loaded."

    if image_file is None:
        return None, "⚠️ Please upload a medical image file."

    try:
        file_path = image_file if isinstance(image_file, str) else str(image_file)
        file_ext = os.path.splitext(file_path)[1].lower()

        if file_ext == '.dcm':
            # DICOM: apply rescale slope/intercept, then a robust 1-99
            # percentile window before converting to 8-bit RGB.
            ds = pydicom.dcmread(file_path)
            img_array = ds.pixel_array.astype(np.float32)
            slope = getattr(ds, 'RescaleSlope', 1)
            intercept = getattr(ds, 'RescaleIntercept', 0)
            img_array = img_array * slope + intercept

            img_min = np.percentile(img_array, 1)
            img_max = np.percentile(img_array, 99)
            img_norm = np.clip((img_array - img_min) / (img_max - img_min + 1e-8), 0, 1)
            img_uint8 = (img_norm * 255).astype(np.uint8)
            pil_image = Image.fromarray(img_uint8).convert('RGB')
        else:
            pil_image = Image.open(file_path).convert('RGB')

        img_array = np.array(pil_image)
        h, w = img_array.shape[:2]

        # Clamp the prompt point into the image bounds.
        point_x = max(0, min(int(point_x), w - 1))
        point_y = max(0, min(int(point_y), h - 1))

        # Generic text prompt; the point only filters the candidates below.
        # (Was an f-string with no placeholders.)
        prompt_text = "segment region at point"

        results = run_sam3_inference(pil_image, prompt_text, threshold=0.1, mask_threshold=0.0)

        final_mask = None
        if results and 'masks' in results and results['masks'] is not None:
            masks = results['masks']

            # Prefer the first candidate mask that actually covers the point.
            for mask in masks:
                if isinstance(mask, torch.Tensor):
                    mask_np = mask.cpu().numpy()
                else:
                    mask_np = np.array(mask)

                mask_resized = np.array(Image.fromarray((mask_np * 255).astype(np.uint8)).resize((w, h))) > 127
                if mask_resized[point_y, point_x]:
                    final_mask = mask_resized
                    break

            # Fallback: no candidate contains the point -> use the first one.
            if final_mask is None and len(masks) > 0:
                mask = masks[0]
                if isinstance(mask, torch.Tensor):
                    mask_np = mask.cpu().numpy()
                else:
                    mask_np = np.array(mask)
                final_mask = np.array(Image.fromarray((mask_np * 255).astype(np.uint8)).resize((w, h))) > 127

        # Render overlay; always close the figure, even when drawing/saving
        # raises, so swallowed exceptions don't leak matplotlib figures.
        plt.figure(figsize=(10, 10))
        try:
            plt.imshow(pil_image)

            if final_mask is not None:
                plt.imshow(final_mask, alpha=transparency, cmap=colormap)

            # Crosshair plus a hollow ring marking the prompt point.
            # (edgecolors/facecolors, not c= with facecolors='none', which
            # conflict and trigger a matplotlib warning.)
            plt.scatter([point_x], [point_y], c='red', s=200, marker='+', linewidths=3)
            plt.scatter([point_x], [point_y], s=100, marker='o',
                        facecolors='none', edgecolors='red', linewidths=2)

            plt.axis('off')
            plt.title(f"Point Prompt Segmentation at ({point_x}, {point_y})", fontsize=12)

            output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
            output_path = output_file.name
            output_file.close()

            plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=100)
        finally:
            plt.close()

        return output_path, f"✅ Point-based segmentation at ({point_x}, {point_y})"

    except Exception as e:
        print(f"Error in point prompt processing: {e}")
        import traceback
        traceback.print_exc()
        return None, f"❌ Error: {str(e)}"
|
|
|
|
|
def process_with_box_prompt(image_file, x1, y1, x2, y2, modality, window_type, colormap='spring', transparency=0.5):
    """Process image with a bounding box prompt for segmentation.

    SAM3 is run with a generic text prompt; the box is then applied as a
    hard spatial filter (masks are intersected with the box region), so this
    simulates true box-based prompting.

    Returns:
        tuple: (path to the rendered PNG or None, status message).
    """
    if model is None or processor is None:
        return None, "❌ Error: Model not loaded."

    if image_file is None:
        return None, "⚠️ Please upload a medical image file."

    try:
        file_path = image_file if isinstance(image_file, str) else str(image_file)
        file_ext = os.path.splitext(file_path)[1].lower()

        if file_ext == '.dcm':
            # DICOM: rescale to physical units, then a robust 1-99 percentile
            # window before converting to 8-bit RGB.
            ds = pydicom.dcmread(file_path)
            img_array = ds.pixel_array.astype(np.float32)
            slope = getattr(ds, 'RescaleSlope', 1)
            intercept = getattr(ds, 'RescaleIntercept', 0)
            img_array = img_array * slope + intercept

            img_min = np.percentile(img_array, 1)
            img_max = np.percentile(img_array, 99)
            img_norm = np.clip((img_array - img_min) / (img_max - img_min + 1e-8), 0, 1)
            img_uint8 = (img_norm * 255).astype(np.uint8)
            pil_image = Image.fromarray(img_uint8).convert('RGB')
        else:
            pil_image = Image.open(file_path).convert('RGB')

        img_array = np.array(pil_image)
        h, w = img_array.shape[:2]

        # Normalize the box: order the corners, then clamp into image bounds.
        x1, x2 = min(x1, x2), max(x1, x2)
        y1, y2 = min(y1, y2), max(y1, y2)
        x1, y1 = max(0, int(x1)), max(0, int(y1))
        x2, y2 = min(w, int(x2)), min(h, int(y2))

        prompt_text = "segment region in bounding box"

        results = run_sam3_inference(pil_image, prompt_text, threshold=0.1, mask_threshold=0.0)

        final_mask = None
        if results and 'masks' in results and results['masks'] is not None:
            masks = results['masks']

            # Resize each candidate mask to the image resolution.
            mask_arrays = []
            for mask in masks:
                if isinstance(mask, torch.Tensor):
                    mask_np = mask.cpu().numpy()
                else:
                    mask_np = np.array(mask)

                mask_resized = np.array(Image.fromarray((mask_np * 255).astype(np.uint8)).resize((w, h))) > 127
                mask_arrays.append(mask_resized)

            if len(mask_arrays) > 0:
                # Union of all candidates, restricted to the user's box.
                combined = np.any(mask_arrays, axis=0)

                box_mask = np.zeros((h, w), dtype=bool)
                box_mask[y1:y2, x1:x2] = True
                final_mask = combined & box_mask

        # Render the overlay plus the box outline.
        plt.figure(figsize=(10, 10))
        plt.imshow(pil_image)

        if final_mask is not None:
            plt.imshow(final_mask, alpha=transparency, cmap=colormap)

        rect = Rectangle((x1, y1), x2-x1, y2-y1, linewidth=3, edgecolor='red', facecolor='none')
        plt.gca().add_patch(rect)

        plt.axis('off')
        plt.title(f"Box Prompt Segmentation [{x1}, {y1}, {x2}, {y2}]", fontsize=12)

        # Persist the figure to a temp PNG (path survives the function).
        output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
        output_path = output_file.name
        output_file.close()

        plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=100)
        plt.close()

        return output_path, f"✅ Box-based segmentation at [{x1}, {y1}, {x2}, {y2}]"

    except Exception as e:
        print(f"Error in box prompt processing: {e}")
        import traceback
        traceback.print_exc()
        return None, f"❌ Error: {str(e)}"
|
|
|
|
|
def process_multi_mask(image_file, prompt_text, modality, window_type, num_masks=3):
    """Process image and return multiple mask candidates with confidence scores.

    Runs SAM3 once with the text prompt and renders up to ``num_masks``
    candidate masks as separate overlay images.

    Returns:
        tuple: (list of PNG paths, status message, per-mask info string).
    """
    if model is None or processor is None:
        return [], "❌ Error: Model not loaded.", ""

    if image_file is None:
        return [], "⚠️ Please upload a medical image file.", ""

    try:
        file_path = image_file if isinstance(image_file, str) else str(image_file)
        file_ext = os.path.splitext(file_path)[1].lower()

        if file_ext == '.dcm':
            # DICOM: rescale to physical units, then a robust 1-99 percentile
            # window before converting to 8-bit RGB.
            ds = pydicom.dcmread(file_path)
            img_array = ds.pixel_array.astype(np.float32)
            slope = getattr(ds, 'RescaleSlope', 1)
            intercept = getattr(ds, 'RescaleIntercept', 0)
            img_array = img_array * slope + intercept

            img_min = np.percentile(img_array, 1)
            img_max = np.percentile(img_array, 99)
            img_norm = np.clip((img_array - img_min) / (img_max - img_min + 1e-8), 0, 1)
            img_uint8 = (img_norm * 255).astype(np.uint8)
            pil_image = Image.fromarray(img_uint8).convert('RGB')
        else:
            pil_image = Image.open(file_path).convert('RGB')

        # Default prompt when the user leaves the field blank.
        if not prompt_text or not prompt_text.strip():
            prompt_text = "brain"

        sam_results = run_sam3_inference(pil_image, prompt_text, threshold=0.1, mask_threshold=0.0)

        results = []
        mask_info = []

        if sam_results and 'masks' in sam_results and sam_results['masks'] is not None:
            masks = sam_results['masks']
            scores = sam_results.get('scores', [])

            num_available = len(masks)
            num_to_show = min(num_masks, num_available)

            # Cycle distinct colormaps so adjacent candidates look different.
            colormaps = ['spring', 'cool', 'hot', 'viridis', 'plasma']

            for i in range(num_to_show):
                mask = masks[i]
                if isinstance(mask, torch.Tensor):
                    mask_np = mask.cpu().numpy()
                else:
                    mask_np = np.array(mask)

                # Binarize soft masks at 0.5.
                if mask_np.dtype != bool:
                    mask_np = mask_np > 0.5

                # Score: tensor -> .item(); plain number -> as-is; missing -> 0.5.
                score = scores[i].item() if i < len(scores) and isinstance(scores[i], torch.Tensor) else (scores[i] if i < len(scores) else 0.5)

                # One overlay figure per candidate mask.
                plt.figure(figsize=(8, 8))
                plt.imshow(pil_image)
                plt.imshow(mask_np, alpha=0.5, cmap=colormaps[i % len(colormaps)])
                plt.axis('off')
                plt.title(f"Mask {i+1} - Confidence: {score:.2%}", fontsize=12)

                output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
                output_path = output_file.name
                output_file.close()

                plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=100)
                plt.close()

                results.append(output_path)
                mask_info.append(f"Mask {i+1}: {score:.2%} confidence, {np.sum(mask_np):,} pixels")

        status = f"✅ Generated {len(results)} mask candidate(s)"
        info = "\n".join(mask_info) if mask_info else "No mask information available"

        return results, status, info

    except Exception as e:
        print(f"Error in multi-mask processing: {e}")
        import traceback
        traceback.print_exc()
        return [], f"❌ Error: {str(e)}", ""
|
|
|
|
|
def export_to_nifti(image_file, mask, output_name="segmentation"):
    """Export a segmentation mask to a NIFTI (.nii.gz) file.

    Args:
        image_file: Source image path; for DICOM input, PixelSpacing and
            SliceThickness populate the affine's voxel scaling (best effort).
        mask: numpy array holding the segmentation mask.
        output_name: Unused; kept for backward compatibility.

    Returns:
        tuple: (path to the .nii.gz file or None, status message).
    """
    if not NIBABEL_AVAILABLE:
        return None, "⚠️ NIFTI export not available - nibabel not installed"

    if mask is None or not isinstance(mask, np.ndarray):
        return None, "⚠️ No valid mask to export"

    try:
        mask_data = mask.astype(np.float32)

        # Identity affine by default (1 mm isotropic voxels).
        affine = np.eye(4)

        # If the source was DICOM, carry its voxel spacing into the affine.
        if image_file:
            file_path = image_file if isinstance(image_file, str) else str(image_file)
            if file_path.lower().endswith('.dcm'):
                try:
                    ds = pydicom.dcmread(file_path, stop_before_pixels=True)
                    pixel_spacing = getattr(ds, 'PixelSpacing', [1.0, 1.0])
                    slice_thickness = getattr(ds, 'SliceThickness', 1.0)
                    affine[0, 0] = float(pixel_spacing[0])
                    affine[1, 1] = float(pixel_spacing[1])
                    affine[2, 2] = float(slice_thickness)
                except Exception:
                    # Spacing metadata is optional; keep the identity affine.
                    # (Was a bare `except:`, which also swallowed SystemExit
                    # and KeyboardInterrupt.)
                    pass

        nifti_img = nib.Nifti1Image(mask_data, affine)

        # Reserve a temp path, then let nibabel write to it.
        output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.nii.gz')
        output_path = output_file.name
        output_file.close()

        nib.save(nifti_img, output_path)

        return output_path, f"✅ Exported to NIFTI: {output_path}"

    except Exception as e:
        print(f"Error exporting to NIFTI: {e}")
        return None, f"❌ Export failed: {str(e)}"
|
|
|
|
|
def save_annotation(image_file, mask, prompt_text, modality, stats=None):
    """Save an annotation (metadata JSON + compressed mask) as a .zip archive.

    Args:
        image_file: Original image path (only its basename is recorded).
        mask: numpy array holding the segmentation mask.
        prompt_text: Text prompt that produced the mask.
        modality: Imaging modality label.
        stats: Optional ROI statistics dict to embed in the metadata.

    Returns:
        tuple: (path to the .zip archive or None, status message).
    """
    if mask is None:
        return None, "⚠️ No annotation to save"

    try:
        annotation = {
            "timestamp": datetime.now().isoformat(),
            "image_file": os.path.basename(image_file) if image_file else "unknown",
            "prompt": prompt_text,
            "modality": modality,
            "mask_shape": list(mask.shape),
            "mask_sum": int(np.sum(mask)),
            "mask_base64": None,  # reserved field, not currently populated
            "statistics": stats if stats else {}
        }

        # Spill the mask to a compressed .npz scratch file.
        mask_file = tempfile.NamedTemporaryFile(delete=False, suffix='.npz')
        mask_path = mask_file.name
        mask_file.close()
        np.savez_compressed(mask_path, mask=mask)

        # Spill the metadata JSON scratch file (records the mask path too).
        json_file = tempfile.NamedTemporaryFile(delete=False, suffix='.json', mode='w')
        json_path = json_file.name
        annotation["mask_file"] = mask_path
        json.dump(annotation, json_file, indent=2)
        json_file.close()

        # Bundle both into a zip written directly at its final temp path
        # (no in-memory BytesIO round-trip needed).
        zip_file = tempfile.NamedTemporaryFile(delete=False, suffix='.zip')
        zip_path = zip_file.name
        zip_file.close()
        try:
            with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
                zf.write(json_path, 'annotation.json')
                zf.write(mask_path, 'mask.npz')
        finally:
            # Remove the scratch files; only the zip needs to survive
            # (previously they were leaked on disk).
            for tmp_path in (json_path, mask_path):
                try:
                    os.unlink(tmp_path)
                except OSError:
                    pass

        return zip_path, f"✅ Annotation saved: {os.path.basename(zip_path)}"

    except Exception as e:
        print(f"Error saving annotation: {e}")
        return None, f"❌ Save failed: {str(e)}"
|
|
|
|
|
def load_annotation(annotation_file):
    """Load a previously saved annotation .zip (metadata JSON + mask .npz).

    Args:
        annotation_file: Path to the .zip archive produced by
            ``save_annotation``.

    Returns:
        tuple: (mask ndarray or None, annotation dict or None, info string).
    """
    if annotation_file is None:
        return None, None, "⚠️ No file selected"

    try:
        file_path = annotation_file if isinstance(annotation_file, str) else str(annotation_file)

        if not file_path.endswith('.zip'):
            return None, None, "⚠️ Invalid file format. Please upload a .zip annotation file."

        with zipfile.ZipFile(file_path, 'r') as zf:
            with zf.open('annotation.json') as f:
                annotation = json.load(f)

            # np.load needs a real seekable file for .npz, so spill to disk.
            mask_temp = tempfile.NamedTemporaryFile(delete=False, suffix='.npz')
            mask_temp.write(zf.read('mask.npz'))
            mask_temp.close()

        try:
            # Context manager closes the NpzFile's handle (previously leaked),
            # and the scratch file is removed once read.
            with np.load(mask_temp.name) as mask_data:
                mask = mask_data['mask']
        finally:
            try:
                os.unlink(mask_temp.name)
            except OSError:
                pass

        info = f"📋 **Loaded Annotation**\n"
        info += f"Image: {annotation.get('image_file', 'unknown')}\n"
        info += f"Prompt: {annotation.get('prompt', 'N/A')}\n"
        info += f"Modality: {annotation.get('modality', 'N/A')}\n"
        info += f"Saved: {annotation.get('timestamp', 'N/A')}\n"
        info += f"Mask size: {annotation.get('mask_sum', 0):,} pixels"

        return mask, annotation, info

    except Exception as e:
        print(f"Error loading annotation: {e}")
        return None, None, f"❌ Load failed: {str(e)}"
|
|
|
|
|
def visualize_loaded_annotation(image_file, annotation_file, colormap='spring', transparency=0.5):
    """Visualize a loaded annotation on the original image.

    Loads the mask from the annotation .zip via ``load_annotation`` and
    overlays it on ``image_file`` (resizing the mask if shapes differ).

    Returns:
        tuple: (path to the rendered PNG or None, info/status string).
    """
    mask, annotation, info = load_annotation(annotation_file)

    if mask is None:
        # load_annotation already produced a user-facing error string.
        return None, info

    if image_file is None:
        return None, "⚠️ Please upload the original image to visualize"

    try:
        file_path = image_file if isinstance(image_file, str) else str(image_file)
        file_ext = os.path.splitext(file_path)[1].lower()

        if file_ext == '.dcm':
            # DICOM: rescale to physical units, then a robust 1-99 percentile
            # window before converting to 8-bit RGB.
            ds = pydicom.dcmread(file_path)
            img_array = ds.pixel_array.astype(np.float32)
            slope = getattr(ds, 'RescaleSlope', 1)
            intercept = getattr(ds, 'RescaleIntercept', 0)
            img_array = img_array * slope + intercept

            img_min = np.percentile(img_array, 1)
            img_max = np.percentile(img_array, 99)
            img_norm = np.clip((img_array - img_min) / (img_max - img_min + 1e-8), 0, 1)
            img_uint8 = (img_norm * 255).astype(np.uint8)
            pil_image = Image.fromarray(img_uint8).convert('RGB')
        else:
            pil_image = Image.open(file_path).convert('RGB')

        # Resize the mask to the image resolution if shapes differ
        # (PIL size is (width, height)).
        w, h = pil_image.size
        if mask.shape != (h, w):
            mask = np.array(Image.fromarray(mask.astype(np.uint8) * 255).resize((w, h))) > 127

        # Render the overlay.
        plt.figure(figsize=(10, 10))
        plt.imshow(pil_image)
        plt.imshow(mask, alpha=transparency, cmap=colormap)
        plt.axis('off')
        plt.title(f"Loaded Annotation: {annotation.get('prompt', 'N/A')}", fontsize=12)

        output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
        output_path = output_file.name
        output_file.close()

        plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=100)
        plt.close()

        return output_path, info

    except Exception as e:
        print(f"Error visualizing annotation: {e}")
        return None, f"❌ Visualization failed: {str(e)}"
|
|
|
|
|
|
|
|
# Module-level cache of the most recent segmentation, consumed by the
# export/save helpers (e.g. export_last_mask_nifti, save_last_annotation).
last_processed_mask = {"mask": None, "image_file": None, "prompt": None, "modality": None}
|
|
|
|
|
def process_and_store_mask(image_file, prompt_text, modality, window_type):
    """Run segmentation, cache the result for export/save, and report ROI stats.

    Returns a 3-tuple of (result image path, status message, formatted
    statistics string).
    """
    overlay, roi_mask = process_medical_image(
        image_file, prompt_text, modality, window_type, return_mask=True
    )

    # Bail out unless both the rendered overlay and the mask are available.
    if not overlay or roi_mask is None:
        return overlay, "❌ Processing failed.", ""

    # Remember everything the later export/save operations need.
    last_processed_mask.update(
        mask=roi_mask,
        image_file=image_file,
        prompt=prompt_text,
        modality=modality,
    )

    roi_stats = calculate_roi_statistics(image_file, roi_mask, modality)
    return overlay, "✅ Segmentation complete! Ready for export.", format_roi_statistics(roi_stats)
|
|
|
|
|
def export_last_mask_nifti():
    """Export the most recently cached segmentation mask to NIFTI format."""
    cached = last_processed_mask
    if cached["mask"] is None:
        return None, "⚠️ No mask to export. Process an image first."
    return export_to_nifti(cached["image_file"], cached["mask"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ResizeLongestSide:
    """Resize images so the longest side equals ``target_length``, preserving
    the aspect ratio.

    Inspired by SAM-Medical-Imaging transforms.py
    """

    def __init__(self, target_length: int = 1024):
        self.target_length = target_length

    def apply_image(self, image: np.ndarray) -> np.ndarray:
        """Return the image resized so its longest side is ``target_length``."""
        height, width = image.shape[:2]
        ratio = self.target_length / max(height, width)
        # PIL expects (width, height) ordering for resize.
        new_size = (int(width * ratio), int(height * ratio))
        resized = Image.fromarray(image).resize(new_size, Image.BILINEAR)
        return np.array(resized)

    def apply_coords(self, coords: np.ndarray, original_size: tuple) -> np.ndarray:
        """Scale (x, y) coordinates from ``original_size`` into the resized frame."""
        ratio = self.target_length / max(original_size)
        return coords * ratio

    def apply_boxes(self, boxes: np.ndarray, original_size: tuple) -> np.ndarray:
        """Scale xyxy bounding boxes from ``original_size`` into the resized frame."""
        scaled = boxes.copy()
        # Both corners scale by the same uniform factor.
        scaled[..., :2] = self.apply_coords(scaled[..., :2], original_size)
        scaled[..., 2:] = self.apply_coords(scaled[..., 2:], original_size)
        return scaled
|
|
|
|
|
def generate_grid_points(image_size: tuple, points_per_side: int = 32) -> np.ndarray:
    """Build a uniform grid of (x, y) prompt points covering the image.

    Inspired by SAM AMG (Automatic Mask Generator).

    Args:
        image_size: (height, width) of the image.
        points_per_side: Number of grid points along each axis.

    Returns:
        Array of shape (points_per_side**2, 2) holding (x, y) coordinates.
    """
    height, width = image_size

    # Evenly spaced coordinates that include both image borders.
    xs = np.linspace(0, width - 1, points_per_side)
    ys = np.linspace(0, height - 1, points_per_side)

    grid_x, grid_y = np.meshgrid(xs, ys)
    return np.column_stack((grid_x.ravel(), grid_y.ravel()))
|
|
|
|
|
def automatic_mask_generator(image_file, modality, window_type,
                             points_per_side=16, min_mask_area=100,
                             colormap='tab20', progress=gr.Progress()):
    """Automatic Mask Generator (AMG): segment the whole image without prompts.

    Runs SAM3 with a bank of generic text prompts, keeps every mask above
    ``min_mask_area``, de-duplicates near-identical masks by IoU, and renders
    the surviving regions in distinct colors.

    Note: the point grid is currently used only for progress reporting —
    inference is driven by the generic text prompts below, not by the points.

    Inspired by SAM-Medical-Imaging's amg.py

    Returns:
        tuple: (PNG path or None, status message, info markdown string).
    """
    if model is None or processor is None:
        return None, "❌ Error: Model not loaded.", ""

    if image_file is None:
        return None, "⚠️ Please upload a medical image file.", ""

    try:
        progress(0.1, desc="Loading image...")

        file_path = image_file if isinstance(image_file, str) else str(image_file)
        file_ext = os.path.splitext(file_path)[1].lower()

        if file_ext == '.dcm':
            # DICOM: rescale to physical units, then window per modality.
            ds = pydicom.dcmread(file_path)
            img_array = ds.pixel_array.astype(np.float32)
            slope = getattr(ds, 'RescaleSlope', 1)
            intercept = getattr(ds, 'RescaleIntercept', 0)
            img_array = img_array * slope + intercept

            if modality == "CT":
                # Standard CT window presets (level, width) in Hounsfield units.
                if window_type == "Brain (Grey Matter)":
                    level, width = 40, 80
                elif window_type == "Bone (Skull)":
                    level, width = 500, 2000
                else:
                    level, width = 40, 400
                img_min = level - (width / 2)
                img_max = level + (width / 2)
            else:
                # Non-CT: robust 1-99 percentile window.
                img_min = np.percentile(img_array, 1)
                img_max = np.percentile(img_array, 99)

            img_norm = np.clip((img_array - img_min) / (img_max - img_min + 1e-8), 0, 1)
            img_uint8 = (img_norm * 255).astype(np.uint8)
            pil_image = Image.fromarray(img_uint8).convert('RGB')
        else:
            pil_image = Image.open(file_path).convert('RGB')

        img_array = np.array(pil_image)
        h, w = img_array.shape[:2]

        progress(0.2, desc="Generating grid points...")

        grid_points = generate_grid_points((h, w), points_per_side)
        total_points = len(grid_points)

        all_masks = []
        all_scores = []

        progress(0.3, desc=f"Processing {total_points} points...")

        # Generic prompts covering common anatomical vocabulary.
        prompts = ["anatomical structure", "region", "tissue", "organ", "segment"]

        for prompt_idx, prompt in enumerate(prompts):
            progress(0.3 + 0.5 * (prompt_idx / len(prompts)), desc=f"Processing prompt: {prompt}...")

            try:
                sam_results = run_sam3_inference(pil_image, prompt, threshold=0.1, mask_threshold=0.0)

                if sam_results and 'masks' in sam_results and sam_results['masks'] is not None:
                    masks = sam_results['masks']

                    for mask in masks:
                        if isinstance(mask, torch.Tensor):
                            mask_np = mask.cpu().numpy()
                        else:
                            mask_np = np.array(mask)

                        # Binarize soft masks at 0.5.
                        if mask_np.dtype != bool:
                            mask_np = mask_np > 0.5

                        # Discard specks below the area threshold.
                        mask_area = np.sum(mask_np)
                        if mask_area >= min_mask_area:
                            mask_resized = np.array(
                                Image.fromarray((mask_np * 255).astype(np.uint8)).resize((w, h))
                            ) > 127
                            all_masks.append(mask_resized)
                            all_scores.append(mask_area)

            except Exception as e:
                # One failing prompt must not abort the whole AMG run.
                print(f"Error with prompt '{prompt}': {e}")
                continue

        progress(0.85, desc="Combining masks...")

        if not all_masks:
            return None, "⚠️ No masks generated. Try different parameters.", ""

        # Greedy IoU de-duplication: drop masks overlapping a kept one by > 0.8.
        unique_masks = []
        for mask in all_masks:
            is_duplicate = False
            for existing in unique_masks:
                intersection = np.logical_and(mask, existing).sum()
                union = np.logical_or(mask, existing).sum()
                iou = intersection / union if union > 0 else 0
                if iou > 0.8:
                    is_duplicate = True
                    break
            if not is_duplicate:
                unique_masks.append(mask)

        progress(0.9, desc="Creating visualization...")

        plt.figure(figsize=(12, 12))
        plt.imshow(pil_image)

        # plt.cm.get_cmap was removed in matplotlib 3.9; use the registry.
        import matplotlib
        cmap = matplotlib.colormaps[colormap]

        for idx, mask in enumerate(unique_masks):
            color = cmap(idx / max(len(unique_masks), 1))[:3]
            colored_mask = np.zeros((*mask.shape, 4))
            colored_mask[mask] = (*color, 0.4)
            plt.imshow(colored_mask)

        plt.axis('off')
        plt.title(f"Automatic Mask Generation: {len(unique_masks)} regions detected", fontsize=14)

        output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
        output_path = output_file.name
        output_file.close()

        plt.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=150)
        plt.close()

        progress(1.0, desc="Complete!")

        info_text = f"📊 **AMG Results**\n\n"
        info_text += f"**Total Regions:** {len(unique_masks)}\n"
        info_text += f"**Grid Points:** {points_per_side}x{points_per_side} = {points_per_side**2}\n"
        info_text += f"**Min Area Filter:** {min_mask_area} pixels\n\n"

        # Detail at most the first ten regions to keep the report short.
        for idx, mask in enumerate(unique_masks[:10]):
            area = np.sum(mask)
            percentage = area / (h * w) * 100
            info_text += f"Region {idx+1}: {area:,} pixels ({percentage:.2f}%)\n"

        if len(unique_masks) > 10:
            info_text += f"\n... and {len(unique_masks) - 10} more regions"

        return output_path, f"✅ AMG Complete! Found {len(unique_masks)} regions.", info_text

    except Exception as e:
        print(f"Error in AMG: {e}")
        import traceback
        traceback.print_exc()
        return None, f"❌ Error: {str(e)}", ""
|
|
|
|
|
def process_with_advanced_transforms(image_file, prompt_text, modality, window_type,
                                     target_size=1024, apply_clahe=False,
                                     clahe_clip=2.0, colormap='spring', transparency=0.5):
    """
    Process image with advanced transforms from SAM-Medical-Imaging.

    - ResizeLongestSide: Maintains aspect ratio
    - CLAHE: Contrast Limited Adaptive Histogram Equalization (optional)

    Produces a side-by-side figure: the original grayscale image and the
    resized image with the segmentation overlay.

    Returns:
        tuple: (path to the rendered PNG or None, status message).
    """
    if model is None or processor is None:
        return None, "❌ Error: Model not loaded."

    if image_file is None:
        return None, "⚠️ Please upload a medical image file."

    try:
        file_path = image_file if isinstance(image_file, str) else str(image_file)
        file_ext = os.path.splitext(file_path)[1].lower()

        if file_ext == '.dcm':
            # DICOM: rescale to physical units, then window per modality.
            ds = pydicom.dcmread(file_path)
            img_array = ds.pixel_array.astype(np.float32)
            slope = getattr(ds, 'RescaleSlope', 1)
            intercept = getattr(ds, 'RescaleIntercept', 0)
            img_array = img_array * slope + intercept

            if modality == "CT":
                # Standard CT window presets (level, width) in Hounsfield units.
                if window_type == "Brain (Grey Matter)":
                    level, width = 40, 80
                elif window_type == "Bone (Skull)":
                    level, width = 500, 2000
                else:
                    level, width = 40, 400
                img_min = level - (width / 2)
                img_max = level + (width / 2)
            else:
                # Non-CT: robust 1-99 percentile window.
                img_min = np.percentile(img_array, 1)
                img_max = np.percentile(img_array, 99)

            img_norm = np.clip((img_array - img_min) / (img_max - img_min + 1e-8), 0, 1)
            img_uint8 = (img_norm * 255).astype(np.uint8)
        else:
            # Non-DICOM input is flattened to a single grayscale channel.
            img = Image.open(file_path)
            if img.mode != 'L':
                img = img.convert('L')
            img_uint8 = np.array(img)

        original_size = img_uint8.shape[:2]

        if apply_clahe:
            try:
                from scipy.ndimage import uniform_filter

                # Approximate CLAHE via local mean/std normalization
                # (not true tile-based CLAHE, but dependency-free).
                local_mean = uniform_filter(img_uint8.astype(float), size=50)
                local_std = np.sqrt(uniform_filter(img_uint8.astype(float)**2, size=50) - local_mean**2 + 1e-8)

                enhanced = (img_uint8 - local_mean) / (local_std + 1e-8) * clahe_clip
                enhanced = np.clip(enhanced * 30 + 128, 0, 255).astype(np.uint8)
                img_uint8 = enhanced
            except Exception as e:
                # Enhancement is best-effort; fall back to the raw image.
                print(f"CLAHE enhancement failed: {e}")

        # Aspect-ratio preserving resize; the model expects 3-channel input.
        transform = ResizeLongestSide(target_size)
        if len(img_uint8.shape) == 2:
            img_uint8_3ch = np.stack([img_uint8] * 3, axis=-1)
        else:
            img_uint8_3ch = img_uint8

        img_resized = transform.apply_image(img_uint8_3ch)
        pil_image = Image.fromarray(img_resized)

        # Default prompt when the user leaves the field blank.
        if not prompt_text or not prompt_text.strip():
            prompt_text = "brain"

        results = run_sam3_inference(pil_image, prompt_text, threshold=0.1, mask_threshold=0.0)

        final_mask = None
        if results and 'masks' in results and results['masks'] is not None:
            masks = results['masks']

            mask_arrays = []
            for mask in masks:
                if isinstance(mask, torch.Tensor):
                    mask_np = mask.cpu().numpy()
                else:
                    mask_np = np.array(mask)
                mask_arrays.append(mask_np)

            if len(mask_arrays) > 0:
                # Union of all candidate masks.
                final_mask = np.any(mask_arrays, axis=0)

        # Two-panel figure: original vs transformed + overlay.
        plt.figure(figsize=(12, 6))

        plt.subplot(1, 2, 1)
        plt.imshow(Image.fromarray(img_uint8_3ch[:, :, 0] if len(img_uint8_3ch.shape) == 3 else img_uint8_3ch), cmap='gray')
        plt.title(f"Original ({original_size[0]}x{original_size[1]})", fontsize=10)
        plt.axis('off')

        plt.subplot(1, 2, 2)
        plt.imshow(pil_image)
        if final_mask is not None:
            # Resize the mask if it doesn't match the transformed image.
            mask_h, mask_w = final_mask.shape
            img_h, img_w = img_resized.shape[:2]
            if (mask_h, mask_w) != (img_h, img_w):
                final_mask = np.array(
                    Image.fromarray(final_mask.astype(np.uint8) * 255).resize((img_w, img_h))
                ) > 127
            plt.imshow(final_mask, alpha=transparency, cmap=colormap)
        plt.title(f"Transformed ({target_size}px) + Segmentation", fontsize=10)
        plt.axis('off')

        plt.tight_layout()

        output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
        output_path = output_file.name
        output_file.close()

        plt.savefig(output_path, bbox_inches='tight', dpi=100)
        plt.close()

        status = f"✅ Processed with ResizeLongestSide({target_size})"
        if apply_clahe:
            status += f" + CLAHE(clip={clahe_clip})"

        return output_path, status

    except Exception as e:
        print(f"Error in advanced transforms: {e}")
        import traceback
        traceback.print_exc()
        return None, f"❌ Error: {str(e)}"
|
|
|
|
|
def edge_based_segmentation(image_file, modality, window_type,
                            edge_threshold=50, dilation_size=3,
                            colormap='spring', transparency=0.5):
    """
    Edge-based automatic segmentation using Sobel/Canny-like detection.
    Useful for finding boundaries in medical images.

    Purely classical (no SAM3): Sobel gradient magnitude -> threshold ->
    dilation -> hole filling -> connected-component count. Produces a
    three-panel figure (original, edge map, overlay).

    Returns:
        tuple: (path to the rendered PNG or None, status message).
    """
    if image_file is None:
        return None, "⚠️ Please upload a medical image file."

    try:
        file_path = image_file if isinstance(image_file, str) else str(image_file)
        file_ext = os.path.splitext(file_path)[1].lower()

        if file_ext == '.dcm':
            # DICOM: rescale to physical units, then window per modality.
            ds = pydicom.dcmread(file_path)
            img_array = ds.pixel_array.astype(np.float32)
            slope = getattr(ds, 'RescaleSlope', 1)
            intercept = getattr(ds, 'RescaleIntercept', 0)
            img_array = img_array * slope + intercept

            if modality == "CT":
                # Standard CT window presets (level, width) in Hounsfield units.
                if window_type == "Brain (Grey Matter)":
                    level, width = 40, 80
                elif window_type == "Bone (Skull)":
                    level, width = 500, 2000
                else:
                    level, width = 40, 400
                img_min = level - (width / 2)
                img_max = level + (width / 2)
            else:
                # Non-CT: robust 1-99 percentile window.
                img_min = np.percentile(img_array, 1)
                img_max = np.percentile(img_array, 99)

            img_norm = np.clip((img_array - img_min) / (img_max - img_min + 1e-8), 0, 1)
            img_uint8 = (img_norm * 255).astype(np.uint8)
        else:
            img = Image.open(file_path).convert('L')
            img_uint8 = np.array(img)

        from scipy.ndimage import sobel, binary_dilation, binary_fill_holes

        # Sobel gradient magnitude along both axes.
        dx = sobel(img_uint8.astype(float), axis=1)
        dy = sobel(img_uint8.astype(float), axis=0)
        edges = np.hypot(dx, dy)

        # Threshold the gradient magnitude into a binary edge map.
        edge_mask = edges > edge_threshold

        # Thicken edges so nearby fragments connect before hole filling.
        if dilation_size > 0:
            struct = np.ones((dilation_size, dilation_size))
            edge_mask = binary_dilation(edge_mask, structure=struct)

        # Fill enclosed regions to turn edge loops into solid masks.
        filled_mask = binary_fill_holes(edge_mask)

        # Count connected regions (labeled image itself is unused).
        labeled, num_features = ndimage.label(filled_mask)

        pil_image = Image.fromarray(img_uint8).convert('RGB')

        # Three-panel figure: original, edge map, overlay.
        plt.figure(figsize=(15, 5))

        plt.subplot(1, 3, 1)
        plt.imshow(img_uint8, cmap='gray')
        plt.title("Original Image", fontsize=10)
        plt.axis('off')

        plt.subplot(1, 3, 2)
        plt.imshow(edges, cmap='hot')
        plt.title(f"Edge Detection (threshold={edge_threshold})", fontsize=10)
        plt.axis('off')

        plt.subplot(1, 3, 3)
        plt.imshow(pil_image)
        plt.imshow(filled_mask, alpha=transparency, cmap=colormap)
        plt.title(f"Segmentation ({num_features} regions)", fontsize=10)
        plt.axis('off')

        plt.tight_layout()

        output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
        output_path = output_file.name
        output_file.close()

        plt.savefig(output_path, bbox_inches='tight', dpi=100)
        plt.close()

        return output_path, f"✅ Edge-based segmentation complete! Found {num_features} regions."

    except Exception as e:
        print(f"Error in edge segmentation: {e}")
        import traceback
        traceback.print_exc()
        return None, f"❌ Error: {str(e)}"
|
|
|
|
|
def save_last_annotation():
    """Persist the most recently processed segmentation via save_annotation().

    Reads the module-level last_processed_mask cache (image path, mask,
    prompt, modality) populated by earlier processing; returns the result
    of save_annotation(), or (None, warning) when nothing has been
    processed yet.
    """
    cached = last_processed_mask
    if cached["mask"] is None:
        return None, "⚠️ No annotation to save. Process an image first."

    stats = calculate_roi_statistics(
        cached["image_file"],
        cached["mask"],
        cached["modality"],
    )

    return save_annotation(
        cached["image_file"],
        cached["mask"],
        cached["prompt"],
        cached["modality"],
        stats,
    )
|
|
|
|
|
|
|
|
# Use the bundled demo DICOM only when it was detected at startup AND still exists on disk.
demo_file_path = demo_dicom_path if demo_file_available and os.path.exists(demo_dicom_path) else None
|
|
|
|
|
def load_demo_file():
    """Return (path, status) for the bundled demo DICOM, or (None, warning) if missing."""
    demo_missing = not (demo_file_path and os.path.exists(demo_file_path))
    if demo_missing:
        return None, "⚠️ Demo file not found. Please upload a medical image file (DICOM, PNG, or JPG)."
    return demo_file_path, f"✅ Demo file loaded: {demo_file_path}\nReady to segment!"
|
|
|
|
|
def process_with_status(image_file, prompt_text, modality, window_type):
    """Run process_medical_image and pair the result with a human-readable status."""
    # Guard clauses: model must be loaded and an input file supplied.
    if model is None or processor is None:
        return None, "❌ Error: Model not loaded."
    if image_file is None:
        return None, "⚠️ Please upload a medical image file (DICOM, PNG, or JPG) or load the demo file."

    segmented = process_medical_image(image_file, prompt_text, modality, window_type)
    if segmented is None:
        return None, "❌ Processing failed. Check console for error details."
    return segmented, "✅ Segmentation complete!"
|
|
|
|
|
def process_with_ground_truth(image_file, gt_mask_file, prompt_text, modality, window_type):
    """Segment `image_file` and score the prediction against a ground-truth mask.

    Returns a 5-tuple wired to the UI:
        (segmentation_path, comparison_path, dice_score, iou_score, status_text).
    """
    # Guard clauses for missing model / inputs.
    if model is None or processor is None:
        return None, None, 0.0, 0.0, "❌ Error: Model not loaded."
    if image_file is None:
        return None, None, 0.0, 0.0, "⚠️ Please upload a medical image file."
    if gt_mask_file is None:
        return None, None, 0.0, 0.0, "⚠️ Please upload a ground truth mask file."

    result, pred_mask = process_medical_image(
        image_file, prompt_text, modality, window_type, return_mask=True
    )
    if result is None or pred_mask is None:
        return None, None, 0.0, 0.0, "❌ Processing failed. Check console for error details."

    comparison_path, dice_score, iou_score = compare_with_ground_truth(pred_mask, gt_mask_file)
    if not comparison_path:
        return result, None, 0.0, 0.0, "✅ Segmentation complete, but comparison failed."

    status = (
        f"✅ Segmentation complete!\n"
        f"Dice Score: {dice_score:.3f}\n"
        f"IoU Score: {iou_score:.3f}"
    )
    return result, comparison_path, dice_score, iou_score, status
|
|
|
|
|
def process_sequence(image_files, prompt_text, modality, window_type):
    """Segment every file in a multi-slice upload.

    Returns (result_paths, status_text); the status lists a per-image log.
    """
    if model is None or processor is None:
        return [], "❌ Error: Model not loaded."
    if not image_files:
        return [], "⚠️ Please upload medical image files (DICOM, PNG, or JPG)."

    # A single path may arrive as a bare string; normalise to a list.
    files = [image_files] if isinstance(image_files, str) else image_files

    results = []
    log = []
    for idx, path in enumerate(files):
        if path is None:
            continue

        log.append(f"Processing image {idx + 1}/{len(files)}...")
        segmented = process_medical_image(path, prompt_text, modality, window_type)
        if segmented:
            results.append(segmented)
            log.append(f"✅ Image {idx + 1} segmented successfully")
        else:
            log.append(f"❌ Failed to process image {idx + 1}")

    if not results:
        return [], "❌ No images were processed successfully. Check console for error details."

    summary = f"✅ Processed {len(results)}/{len(files)} images successfully!\n" + "\n".join(log)
    return results, summary
|
|
|
|
|
|
|
|
# Module-level cache: {cache_key: [result image paths]} reused by the slice
# viewer so navigating between slices does not re-run segmentation.
processed_results_cache = {}
|
|
|
|
|
def extract_subject_id(file_path):
    """Extract subject/patient ID from file path.

    Common patterns:
    - Folder name: /subject_001/image.png -> subject_001
    - Filename prefix: subject_001_slice_01.png -> subject_001
    - Patient ID in filename: patient_123_slice_5.dcm -> patient_123
    - Study UID in DICOM: extract from DICOM metadata

    Returns:
        tuple: (subject_id, confidence_level, source)
        confidence_level: 'high' (DICOM metadata), 'medium' (folder/filename pattern), 'low' (fallback)
        source: 'dicom_patientid', 'dicom_study', 'folder', 'filename',
                'filename_numeric', 'fallback', 'unknown'
    """
    import re

    file_path = str(file_path)
    filename = os.path.basename(file_path)
    dir_path = os.path.dirname(file_path)

    # 1) DICOM metadata is the most reliable source; stop_before_pixels
    #    skips loading the (potentially large) pixel data.
    if file_path.lower().endswith('.dcm'):
        try:
            ds = pydicom.dcmread(file_path, stop_before_pixels=True)
            patient_id = getattr(ds, 'PatientID', None)
            # str() guards against pydicom value types that are not plain str.
            if patient_id and str(patient_id).strip():
                return f"patient_{patient_id}", 'high', 'dicom_patientid'

            study_uid = getattr(ds, 'StudyInstanceUID', None)
            if study_uid:
                return f"study_{study_uid}", 'high', 'dicom_study'
        except Exception:
            # FIX: was a bare `except:`, which also swallowed SystemExit and
            # KeyboardInterrupt. Unreadable/corrupt DICOM headers simply fall
            # through to the filename heuristics below.
            pass

    # 2) Parent folder named like subject_001 / patient-12 / case3.
    folder_name = os.path.basename(dir_path.rstrip('/'))
    if folder_name and folder_name not in ['', '.', '..']:
        if re.match(r'(subject|patient|sub|pat|case|id)[_-]?\d+', folder_name, re.I):
            return folder_name, 'medium', 'folder'

    # 3) ID-like token inside the filename itself.
    patterns = [
        (r'(subject|patient|sub|pat|case|id)[_-]?(\d+)', 'medium'),
        (r'([A-Z]{2,}\d+)', 'medium'),
    ]

    for pattern, confidence in patterns:
        match = re.search(pattern, filename, re.I)
        if match:
            if len(match.groups()) > 1:
                return f"{match.group(1)}_{match.group(2)}", confidence, 'filename'
            else:
                return match.group(1), confidence, 'filename'

    # 4) Any run of 3+ digits (e.g. a bare accession/series number).
    numeric_match = re.search(r'(\d{3,})', filename)
    if numeric_match:
        return numeric_match.group(1), 'low', 'filename_numeric'

    # 5) Last resort: the filename stem.
    base_name = os.path.splitext(filename)[0]
    if len(base_name) > 0:
        return base_name, 'low', 'fallback'

    return "unknown", 'low', 'unknown'
|
|
|
|
|
def group_images_by_subject(image_files):
    """Group image files by subject/patient ID.

    Returns:
        dict: {subject_id: {'files': [...sorted paths...],
                            'confidence': 'high'/'medium'/'low',
                            'sources': [...detection sources...]}}
    """
    if not image_files:
        return {}
    if isinstance(image_files, str):
        image_files = [image_files]

    groups = {}
    for path in (p for p in image_files if p is not None):
        subject_id, confidence, source = extract_subject_id(path)

        entry = groups.get(subject_id)
        if entry is None:
            entry = {'files': [], 'confidence': confidence, 'sources': set()}
            groups[subject_id] = entry

        entry['files'].append(path)
        entry['sources'].add(source)

        # Promote the group's confidence: 'high' always wins;
        # 'medium' only replaces 'low'.
        if confidence == 'high' or (confidence == 'medium' and entry['confidence'] == 'low'):
            entry['confidence'] = confidence

    # Deterministic slice order and JSON-friendly sources.
    for entry in groups.values():
        entry['files'].sort()
        entry['sources'] = list(entry['sources'])

    return groups
|
|
|
|
|
def detect_subjects(image_files):
    """Detect subject groups in the upload and build the dropdown plus a report.

    Returns (gr.Dropdown update, markdown status string).
    """
    if not image_files:
        return gr.Dropdown(choices=[], value=None), "No files uploaded"

    groups = group_images_by_subject(image_files)
    if not groups:
        return gr.Dropdown(choices=[], value=None), "No subjects detected"

    # Icon and explanation per confidence level (anything else is 'low').
    confidence_labels = {
        'high': ("✅", "HIGH (DICOM metadata - very reliable)"),
        'medium': ("⚠️", "MEDIUM (folder/filename pattern - likely same patient)"),
    }

    choices = []
    status_msg = f"✅ Detected {len(groups)} subject(s):\n\n"

    for subject_id, info in sorted(groups.items()):
        num_files = len(info['files'])
        sources = ', '.join(info['sources'])
        icon, text = confidence_labels.get(
            info['confidence'],
            ("⚠️⚠️", "LOW (filename-based - verify manually)"),
        )

        choices.append(f"{subject_id} ({num_files} slices)")
        status_msg += f"{icon} **{subject_id}**: {num_files} slices\n"
        status_msg += f" Confidence: {text}\n"
        status_msg += f" Source: {sources}\n\n"

    # Warn the user whenever any grouping was a low-confidence guess.
    low_confidence_count = sum(1 for info in groups.values() if info['confidence'] == 'low')
    if low_confidence_count > 0:
        status_msg += f"⚠️ **Warning**: {low_confidence_count} subject(s) detected with LOW confidence.\n"
        status_msg += "Please verify these are actually the same patient before proceeding.\n"

    return gr.Dropdown(choices=choices, value=choices[0] if choices else None), status_msg
|
|
|
|
|
def process_slices_for_viewer(image_files, selected_subject, prompt_text, modality, window_type):
    """Process all slices for selected subject and cache results for interactive viewing.

    Returns a 6-tuple wired to the Gradio outputs:
        (first_slice_image, max_slider_index, status_text, slice_info_text,
         subject_dropdown_update, subject_info_text)

    Side effect: stores the per-slice result paths in the module-level
    processed_results_cache, keyed by subject/slice-count/prompt/modality,
    so navigate_slice() can serve cached slices without re-running the model.
    """
    if model is None or processor is None:
        return None, 0, "❌ Error: Model not loaded.", "No slices loaded", gr.Dropdown(choices=[], value=None), ""

    if not image_files:
        return None, 0, "⚠️ Please upload medical image files.", "No slices loaded", gr.Dropdown(choices=[], value=None), ""

    subject_groups = group_images_by_subject(image_files)

    if not subject_groups:
        return None, 0, "⚠️ Could not detect subjects in uploaded files.", "No slices loaded", gr.Dropdown(choices=[], value=None), ""

    # Dropdown entries look like "subject_001 (12 slices)" — strip the suffix.
    if selected_subject:
        subject_id = selected_subject.split(" (")[0]
    else:
        # No explicit selection: default to the first detected subject.
        subject_id = list(subject_groups.keys())[0]

    if subject_id not in subject_groups:
        return None, 0, f"⚠️ Subject '{subject_id}' not found.", "No slices loaded", gr.Dropdown(choices=[], value=None), ""

    subject_info = subject_groups[subject_id]
    subject_files = subject_info['files']
    confidence = subject_info['confidence']

    # Surface how trustworthy the automatic subject grouping was.
    confidence_warning = ""
    if confidence == 'low':
        confidence_warning = "\n⚠️ LOW CONFIDENCE: These files may not be from the same patient. Please verify!"
    elif confidence == 'medium':
        confidence_warning = "\n⚠️ MEDIUM CONFIDENCE: Likely same patient, but verify if critical."

    results = []
    status_messages = []

    # NOTE(review): status_messages is collected below but never included in
    # the returned status string — currently dead accumulation.
    for idx, image_file in enumerate(subject_files):
        status_msg = f"Processing slice {idx + 1}/{len(subject_files)}..."
        status_messages.append(status_msg)

        result = process_medical_image(image_file, prompt_text, modality, window_type)

        if result:
            results.append(result)
            status_messages.append(f"✅ Slice {idx + 1} processed")
        else:
            status_messages.append(f"❌ Failed to process slice {idx + 1}")

    if results:
        # Cache key changes whenever subject, slice count, prompt or modality
        # change, so stale results are never served by navigate_slice().
        cache_key = f"{subject_id}_{len(subject_files)}_{prompt_text}_{modality}"
        processed_results_cache[cache_key] = results

        # Slider is 0-based, hence len - 1.
        max_slices = len(results) - 1
        status = f"✅ Processed {len(results)}/{len(subject_files)} slices for {subject_id}!\nUse slider or buttons to navigate.{confidence_warning}"
        slice_info = f"Slice 1/{len(results)} ({subject_id})"

        # Rebuild the dropdown, marking the currently viewed subject with "✓".
        choices = []
        for sid, info in sorted(subject_groups.items()):
            marker = "✓" if sid == subject_id else ""
            num_files = len(info['files'])
            choices.append(f"{marker} {sid} ({num_files} slices)")

        return results[0], max_slices, status, slice_info, gr.Dropdown(choices=choices, value=choices[0] if choices else None), f"Viewing: {subject_id}"
    else:
        return None, 0, "❌ No slices were processed successfully.", "No slices loaded", gr.Dropdown(choices=[], value=None), ""
|
|
|
|
|
def navigate_slice(slice_idx, image_files, selected_subject, prompt_text, modality, window_type):
    """Return (image, info_text) for one slice, serving from the cache when possible.

    Falls back to re-running segmentation on the single requested slice when
    the cache has no entry for the current subject/prompt/modality combination.
    """
    if not image_files:
        return None, "No slices loaded"

    subject_groups = group_images_by_subject(image_files)

    # Resolve the subject: explicit dropdown selection wins, otherwise
    # default to the first detected group.
    subject_id = None
    if selected_subject:
        subject_id = selected_subject.split(" (")[0]
    elif subject_groups:
        subject_id = list(subject_groups.keys())[0]

    if not subject_id or subject_id not in subject_groups:
        return None, "No slices loaded"

    subject_files = subject_groups[subject_id]['files']
    slice_idx = int(slice_idx)
    cache_key = f"{subject_id}_{len(subject_files)}_{prompt_text}_{modality}"

    cached = processed_results_cache.get(cache_key)
    if cached is not None and 0 <= slice_idx < len(cached):
        return cached[slice_idx], f"Slice {slice_idx + 1}/{len(cached)} ({subject_id})"

    # Cache miss (or index beyond cached results): process just this slice.
    if 0 <= slice_idx < len(subject_files):
        result = process_medical_image(subject_files[slice_idx], prompt_text, modality, window_type)
        if result:
            return result, f"Slice {slice_idx + 1}/{len(subject_files)} ({subject_id})"

    return None, f"Invalid slice index: {slice_idx}"
|
|
|
|
|
with gr.Blocks() as demo: |
|
|
gr.Markdown("# 🏥 NeuroSAM 3: Medical Image Segmentation") |
|
|
|
|
|
demo_info = "" |
|
|
if demo_file_path: |
|
|
demo_info = f"\n\n**📁 Demo File Available:** A sample DICOM file is ready: `{demo_file_path}`\nClick 'Load Demo File' button below to use it!" |
|
|
|
|
|
gr.Markdown(f""" |
|
|
Upload a medical image (DICOM .dcm, PNG, or JPG) and type what you want to find (e.g., 'brain', 'tumor', 'skull'). |
|
|
{demo_info} |
|
|
|
|
|
**Instructions:** |
|
|
1. Upload a medical image file: |
|
|
- DICOM (.dcm) files for CT or MRI scans |
|
|
- PNG/JPG files (e.g., from Kaggle datasets like brain MRI images) |
|
|
2. Enter a text prompt describing what to segment |
|
|
3. Select the imaging modality (CT or MRI) |
|
|
4. Choose the windowing strategy (for CT images) |
|
|
5. Click "Segment Structure" to process |
|
|
|
|
|
**Supported Formats:** |
|
|
- DICOM (.dcm) - Standard medical imaging format |
|
|
- PNG/JPG - Standard image formats (works with Kaggle brain MRI datasets) |
|
|
""") |
|
|
|
|
|
with gr.Tabs(): |
|
|
with gr.Tab("Single Image"): |
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
file_input = gr.File( |
|
|
label="Upload Medical Image (DICOM .dcm, PNG, JPG)", |
|
|
file_types=[".dcm", ".png", ".jpg", ".jpeg"], |
|
|
type="filepath", |
|
|
value=demo_file_path |
|
|
) |
|
|
|
|
|
load_demo_btn = gr.Button( |
|
|
"📁 Load Demo File", |
|
|
variant="secondary", |
|
|
size="sm", |
|
|
visible=bool(demo_file_path) |
|
|
) |
|
|
|
|
|
text_input = gr.Textbox( |
|
|
label="Text Prompt", |
|
|
value="brain", |
|
|
placeholder="e.g. brain, tumor, skull, eyes", |
|
|
info="Describe what anatomical structure or region you want to segment" |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
modality_dropdown = gr.Dropdown( |
|
|
["CT", "MRI"], |
|
|
label="Modality", |
|
|
value="MRI", |
|
|
info="Select the imaging modality" |
|
|
) |
|
|
window_dropdown = gr.Dropdown( |
|
|
["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"], |
|
|
label="Windowing Strategy (CT only)", |
|
|
value="Brain (Grey Matter)", |
|
|
info="CT windowing preset (ignored for MRI)" |
|
|
) |
|
|
|
|
|
submit_btn = gr.Button("Segment Structure", variant="primary", size="lg") |
|
|
|
|
|
with gr.Column(): |
|
|
image_output = gr.Image( |
|
|
label="Segmentation Result", |
|
|
type="filepath" |
|
|
) |
|
|
|
|
|
gr.Markdown("### Status") |
|
|
status_text = gr.Textbox( |
|
|
label="Processing Status", |
|
|
value="Ready. Upload a medical image file (DICOM, PNG, or JPG) to begin.", |
|
|
interactive=False |
|
|
) |
|
|
|
|
|
with gr.Tab("Interactive Slice Viewer"): |
|
|
gr.Markdown("**Scroll through multiple slices/images from the same subject interactively**") |
|
|
gr.Markdown(""" |
|
|
**📋 Subject Detection:** The app automatically detects subject/patient IDs from: |
|
|
- Folder names (e.g., `subject_001/`, `patient_123/`) |
|
|
- Filenames (e.g., `subject_001_slice_01.png`, `patient_123.dcm`) |
|
|
- DICOM metadata (PatientID, StudyInstanceUID) |
|
|
|
|
|
**💡 Tip:** Upload images organized by subject folders for best results! |
|
|
""") |
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
files_input = gr.File( |
|
|
label="Upload Multiple Images/Slices (Select multiple files)", |
|
|
file_types=[".dcm", ".png", ".jpg", ".jpeg"], |
|
|
file_count="multiple", |
|
|
type="filepath" |
|
|
) |
|
|
|
|
|
subject_dropdown = gr.Dropdown( |
|
|
label="Select Subject/Patient", |
|
|
choices=[], |
|
|
value=None, |
|
|
interactive=True, |
|
|
info="Select which subject's slices to view (auto-detected from filenames/folders)" |
|
|
) |
|
|
|
|
|
text_input_batch = gr.Textbox( |
|
|
label="Text Prompt", |
|
|
value="brain", |
|
|
placeholder="e.g. brain, tumor, skull, eyes", |
|
|
info="Describe what anatomical structure or region you want to segment" |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
modality_dropdown_batch = gr.Dropdown( |
|
|
["CT", "MRI"], |
|
|
label="Modality", |
|
|
value="MRI", |
|
|
info="Select the imaging modality" |
|
|
) |
|
|
window_dropdown_batch = gr.Dropdown( |
|
|
["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"], |
|
|
label="Windowing Strategy (CT only)", |
|
|
value="Brain (Grey Matter)", |
|
|
info="CT windowing preset (ignored for MRI)" |
|
|
) |
|
|
|
|
|
detect_subjects_btn = gr.Button("🔍 Detect Subjects", variant="secondary", size="sm") |
|
|
submit_batch_btn = gr.Button("Process All Slices", variant="primary", size="lg") |
|
|
|
|
|
gr.Markdown("---") |
|
|
gr.Markdown("### 🎛️ Slice Navigator") |
|
|
slice_slider = gr.Slider( |
|
|
minimum=0, |
|
|
maximum=0, |
|
|
step=1, |
|
|
value=0, |
|
|
label="Slice Number", |
|
|
info="Use slider or arrow keys to navigate through slices", |
|
|
interactive=False |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
prev_btn = gr.Button("⬆️ Previous Slice", size="sm") |
|
|
next_btn = gr.Button("⬇️ Next Slice", size="sm") |
|
|
auto_play_btn = gr.Button("▶️ Auto-play", size="sm") |
|
|
|
|
|
with gr.Column(): |
|
|
current_slice_output = gr.Image( |
|
|
label="Current Slice Segmentation", |
|
|
type="filepath", |
|
|
height=600 |
|
|
) |
|
|
|
|
|
gr.Markdown("### Slice Info") |
|
|
slice_info_text = gr.Textbox( |
|
|
label="Current Slice", |
|
|
value="No slices loaded", |
|
|
interactive=False |
|
|
) |
|
|
|
|
|
subject_info_text = gr.Textbox( |
|
|
label="Subject Info", |
|
|
value="", |
|
|
interactive=False, |
|
|
visible=False |
|
|
) |
|
|
|
|
|
gr.Markdown("### Status") |
|
|
status_batch_text = gr.Textbox( |
|
|
label="Processing Status", |
|
|
value="Ready. Upload multiple medical image files to process a sequence.", |
|
|
interactive=False, |
|
|
lines=4 |
|
|
) |
|
|
|
|
|
with gr.Tab("Gallery View"): |
|
|
gr.Markdown("**View all segmentations in a gallery grid**") |
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
files_input_gallery = gr.File( |
|
|
label="Upload Multiple Images (Select multiple files)", |
|
|
file_types=[".dcm", ".png", ".jpg", ".jpeg"], |
|
|
file_count="multiple", |
|
|
type="filepath" |
|
|
) |
|
|
|
|
|
text_input_gallery = gr.Textbox( |
|
|
label="Text Prompt", |
|
|
value="brain", |
|
|
placeholder="e.g. brain, tumor, skull, eyes" |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
modality_dropdown_gallery = gr.Dropdown( |
|
|
["CT", "MRI"], |
|
|
label="Modality", |
|
|
value="MRI" |
|
|
) |
|
|
window_dropdown_gallery = gr.Dropdown( |
|
|
["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"], |
|
|
label="Windowing Strategy (CT only)", |
|
|
value="Brain (Grey Matter)" |
|
|
) |
|
|
|
|
|
submit_gallery_btn = gr.Button("Process & Show Gallery", variant="primary", size="lg") |
|
|
|
|
|
with gr.Column(): |
|
|
gallery_output = gr.Gallery( |
|
|
label="Segmentation Gallery", |
|
|
show_label=True, |
|
|
elem_id="gallery", |
|
|
columns=2, |
|
|
rows=2, |
|
|
height="auto" |
|
|
) |
|
|
|
|
|
status_gallery_text = gr.Textbox( |
|
|
label="Status", |
|
|
value="Ready. Upload multiple images to view in gallery.", |
|
|
interactive=False |
|
|
) |
|
|
|
|
|
with gr.Tab("Compare with Ground Truth"): |
|
|
gr.Markdown("**Compare SAM 3 segmentation with ground truth masks (e.g., from BraTS, Kaggle datasets)**") |
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
file_input_gt = gr.File( |
|
|
label="Upload Medical Image (DICOM .dcm, PNG, JPG)", |
|
|
file_types=[".dcm", ".png", ".jpg", ".jpeg"], |
|
|
type="filepath" |
|
|
) |
|
|
|
|
|
gt_mask_input = gr.File( |
|
|
label="Upload Ground Truth Mask (PNG, JPG)", |
|
|
file_types=[".png", ".jpg", ".jpeg"], |
|
|
type="filepath" |
|
|
) |
|
|
|
|
|
text_input_gt = gr.Textbox( |
|
|
label="Text Prompt", |
|
|
value="brain", |
|
|
placeholder="e.g. brain, tumor, skull", |
|
|
info="Describe what anatomical structure or region you want to segment" |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
modality_dropdown_gt = gr.Dropdown( |
|
|
["CT", "MRI"], |
|
|
label="Modality", |
|
|
value="MRI", |
|
|
info="Select the imaging modality" |
|
|
) |
|
|
window_dropdown_gt = gr.Dropdown( |
|
|
["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"], |
|
|
label="Windowing Strategy (CT only)", |
|
|
value="Brain (Grey Matter)", |
|
|
info="CT windowing preset (ignored for MRI)" |
|
|
) |
|
|
|
|
|
submit_gt_btn = gr.Button("Compare Segmentation", variant="primary", size="lg") |
|
|
|
|
|
with gr.Column(): |
|
|
image_output_gt = gr.Image( |
|
|
label="SAM 3 Segmentation", |
|
|
type="filepath" |
|
|
) |
|
|
|
|
|
comparison_output = gr.Image( |
|
|
label="Comparison: SAM 3 vs Ground Truth", |
|
|
type="filepath" |
|
|
) |
|
|
|
|
|
gr.Markdown("### Metrics") |
|
|
dice_score_text = gr.Textbox( |
|
|
label="Dice Score", |
|
|
value="--", |
|
|
interactive=False |
|
|
) |
|
|
|
|
|
iou_score_text = gr.Textbox( |
|
|
label="IoU Score", |
|
|
value="--", |
|
|
interactive=False |
|
|
) |
|
|
|
|
|
gr.Markdown("### Status") |
|
|
status_gt_text = gr.Textbox( |
|
|
label="Processing Status", |
|
|
value="Ready. Upload image and ground truth mask to compare.", |
|
|
interactive=False, |
|
|
lines=3 |
|
|
) |
|
|
|
|
|
|
|
|
with gr.Tab("✨ Enhanced Single Image"): |
|
|
gr.Markdown("**Enhanced version with image adjustments, visualization options, and progress tracking**") |
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
file_input_enh = gr.File( |
|
|
label="Upload Medical Image (DICOM .dcm, PNG, JPG)", |
|
|
file_types=[".dcm", ".png", ".jpg", ".jpeg"], |
|
|
type="filepath", |
|
|
value=demo_file_path |
|
|
) |
|
|
|
|
|
text_input_enh = gr.Textbox( |
|
|
label="Text Prompt", |
|
|
value="brain", |
|
|
placeholder="e.g. brain, tumor, skull, eyes", |
|
|
info="Describe what anatomical structure or region you want to segment" |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
modality_enh = gr.Dropdown( |
|
|
["CT", "MRI"], |
|
|
label="Modality", |
|
|
value="MRI", |
|
|
info="Select the imaging modality" |
|
|
) |
|
|
window_enh = gr.Dropdown( |
|
|
["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"], |
|
|
label="Windowing Strategy (CT only)", |
|
|
value="Brain (Grey Matter)", |
|
|
info="CT windowing preset (ignored for MRI)" |
|
|
) |
|
|
|
|
|
|
|
|
gr.Markdown("### 🎨 Image Adjustments") |
|
|
brightness_slider = gr.Slider( |
|
|
0.5, 2.0, value=1.0, step=0.1, |
|
|
label="Brightness", |
|
|
info="Adjust image brightness (0.5 = darker, 2.0 = brighter)" |
|
|
) |
|
|
contrast_slider = gr.Slider( |
|
|
0.5, 2.0, value=1.0, step=0.1, |
|
|
label="Contrast", |
|
|
info="Adjust image contrast (0.5 = lower contrast, 2.0 = higher contrast)" |
|
|
) |
|
|
|
|
|
|
|
|
gr.Markdown("### 🎭 Visualization Options") |
|
|
colormap_dropdown = gr.Dropdown( |
|
|
["spring", "cool", "hot", "viridis", "plasma", "jet", "rainbow"], |
|
|
label="Mask Colormap", |
|
|
value="spring", |
|
|
info="Color scheme for segmentation overlay" |
|
|
) |
|
|
transparency_slider = gr.Slider( |
|
|
0.0, 1.0, value=0.5, step=0.1, |
|
|
label="Mask Transparency", |
|
|
info="Transparency of segmentation overlay (0.0 = opaque, 1.0 = transparent)" |
|
|
) |
|
|
|
|
|
submit_enh_btn = gr.Button("Segment with Progress", variant="primary", size="lg") |
|
|
|
|
|
with gr.Column(): |
|
|
|
|
|
progress_text = gr.Textbox( |
|
|
label="Progress", |
|
|
value="Ready. Upload a medical image file to begin.", |
|
|
interactive=False |
|
|
) |
|
|
|
|
|
image_output_enh = gr.Image( |
|
|
label="Segmentation Result", |
|
|
type="filepath" |
|
|
) |
|
|
|
|
|
|
|
|
download_output = gr.File( |
|
|
label="Download Result", |
|
|
visible=True |
|
|
) |
|
|
|
|
|
|
|
|
gr.Markdown("### 📊 Metrics") |
|
|
metrics_text = gr.Textbox( |
|
|
label="Segmentation Info", |
|
|
value="", |
|
|
lines=3, |
|
|
interactive=False |
|
|
) |
|
|
|
|
|
with gr.Tab("✨ Enhanced Batch Processing"): |
|
|
gr.Markdown("**Enhanced batch processing with ZIP download and progress tracking**") |
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
files_input_enh_batch = gr.File( |
|
|
label="Upload Multiple Images for Batch Processing", |
|
|
file_types=[".dcm", ".png", ".jpg", ".jpeg"], |
|
|
file_count="multiple", |
|
|
type="filepath" |
|
|
) |
|
|
|
|
|
text_input_enh_batch = gr.Textbox( |
|
|
label="Text Prompt", |
|
|
value="brain", |
|
|
placeholder="e.g. brain, tumor, skull, eyes" |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
modality_enh_batch = gr.Dropdown( |
|
|
["CT", "MRI"], |
|
|
label="Modality", |
|
|
value="MRI" |
|
|
) |
|
|
window_enh_batch = gr.Dropdown( |
|
|
["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"], |
|
|
label="Windowing Strategy (CT only)", |
|
|
value="Brain (Grey Matter)" |
|
|
) |
|
|
|
|
|
|
|
|
gr.Markdown("### 🎨 Image Adjustments") |
|
|
brightness_slider_batch = gr.Slider( |
|
|
0.5, 2.0, value=1.0, step=0.1, |
|
|
label="Brightness" |
|
|
) |
|
|
contrast_slider_batch = gr.Slider( |
|
|
0.5, 2.0, value=1.0, step=0.1, |
|
|
label="Contrast" |
|
|
) |
|
|
|
|
|
|
|
|
gr.Markdown("### 🎭 Visualization Options") |
|
|
colormap_dropdown_batch = gr.Dropdown( |
|
|
["spring", "cool", "hot", "viridis", "plasma", "jet", "rainbow"], |
|
|
label="Mask Colormap", |
|
|
value="spring" |
|
|
) |
|
|
transparency_slider_batch = gr.Slider( |
|
|
0.0, 1.0, value=0.5, step=0.1, |
|
|
label="Mask Transparency" |
|
|
) |
|
|
|
|
|
submit_enh_batch_btn = gr.Button("Process Batch with Progress", variant="primary", size="lg") |
|
|
|
|
|
with gr.Column(): |
|
|
gallery_output_enh = gr.Gallery( |
|
|
label="Segmentation Gallery", |
|
|
show_label=True, |
|
|
elem_id="gallery", |
|
|
columns=2, |
|
|
rows=2, |
|
|
height="auto" |
|
|
) |
|
|
|
|
|
|
|
|
batch_download_output = gr.File( |
|
|
label="Download All Results (ZIP)", |
|
|
visible=True |
|
|
) |
|
|
|
|
|
status_enh_batch_text = gr.Textbox( |
|
|
label="Status", |
|
|
value="Ready. Upload multiple images to process in batch.", |
|
|
interactive=False, |
|
|
lines=4 |
|
|
) |
|
|
|
|
|
|
|
|
with gr.Tab("🎯 Point/Box Prompts"): |
|
|
gr.Markdown(""" |
|
|
**Interactive Point and Box-based Segmentation** |
|
|
|
|
|
Use precise point clicks or bounding boxes to guide the segmentation. |
|
|
- **Point Prompt**: Click on the region you want to segment |
|
|
- **Box Prompt**: Define a bounding box around the region of interest |
|
|
""") |
|
|
|
|
|
with gr.Tabs(): |
|
|
with gr.Tab("Point Prompt"): |
|
|
with gr.Row(): |
|
|
with gr.Column(): |
|
|
file_input_point = gr.File( |
|
|
label="Upload Medical Image", |
|
|
file_types=[".dcm", ".png", ".jpg", ".jpeg"], |
|
|
type="filepath" |
|
|
) |
|
|
|
|
|
gr.Markdown("### Point Coordinates") |
|
|
with gr.Row(): |
|
|
point_x = gr.Number(label="X coordinate", value=128, precision=0) |
|
|
point_y = gr.Number(label="Y coordinate", value=128, precision=0) |
|
|
|
|
|
with gr.Row(): |
|
|
modality_point = gr.Dropdown(["CT", "MRI"], label="Modality", value="MRI") |
|
|
window_point = gr.Dropdown( |
|
|
["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"], |
|
|
label="Windowing", value="Brain (Grey Matter)" |
|
|
) |
|
|
|
|
|
with gr.Row(): |
|
|
colormap_point = gr.Dropdown( |
|
|
["spring", "cool", "hot", "viridis", "plasma"], |
|
|
label="Colormap", value="spring" |
|
|
) |
|
|
transparency_point = gr.Slider(0.0, 1.0, value=0.5, label="Transparency") |
|
|
|
|
|
submit_point_btn = gr.Button("Segment at Point", variant="primary") |
|
|
|
|
|
with gr.Column(): |
|
|
output_point = gr.Image(label="Point Segmentation Result", type="filepath") |
|
|
status_point = gr.Textbox(label="Status", interactive=False) |
|
|
|
|
|
with gr.Tab("Box Prompt"):
    # Segment whatever lies inside a user-supplied bounding box (pixel coords).
    with gr.Row():
        with gr.Column():
            file_input_box = gr.File(
                label="Upload Medical Image",
                file_types=[".dcm", ".png", ".jpg", ".jpeg"],
                type="filepath"
            )

            gr.Markdown("### Bounding Box Coordinates")
            with gr.Row():
                # precision=0 forces whole-pixel coordinates
                box_x1 = gr.Number(label="X1 (left)", value=50, precision=0)
                box_y1 = gr.Number(label="Y1 (top)", value=50, precision=0)
            with gr.Row():
                box_x2 = gr.Number(label="X2 (right)", value=200, precision=0)
                box_y2 = gr.Number(label="Y2 (bottom)", value=200, precision=0)

            with gr.Row():
                modality_box = gr.Dropdown(["CT", "MRI"], label="Modality", value="MRI")
                window_box = gr.Dropdown(
                    ["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"],
                    label="Windowing", value="Brain (Grey Matter)"
                )

            # Overlay appearance controls
            with gr.Row():
                colormap_box = gr.Dropdown(
                    ["spring", "cool", "hot", "viridis", "plasma"],
                    label="Colormap", value="spring"
                )
                transparency_box = gr.Slider(0.0, 1.0, value=0.5, label="Transparency")

            submit_box_btn = gr.Button("Segment in Box", variant="primary")

        with gr.Column():
            # Result panel: rendered overlay plus a status line
            output_box = gr.Image(label="Box Segmentation Result", type="filepath")
            status_box = gr.Textbox(label="Status", interactive=False)
|
|
|
|
|
|
|
|
with gr.Tab("📊 ROI Statistics & Export"):
    # Text-prompt segmentation plus ROI statistics, NIFTI/annotation export,
    # and a section for re-loading previously saved annotations.
    gr.Markdown("""
    **ROI Statistics and Export Options**

    Process an image and get detailed statistics about the segmented region:
    - Area (pixels and percentage)
    - Intensity statistics (mean, std, min, max)
    - Centroid and bounding box
    - Export to NIFTI format for medical imaging software
    - Save/Load annotations for later use
    """)

    with gr.Row():
        with gr.Column():
            file_input_stats = gr.File(
                label="Upload Medical Image",
                file_types=[".dcm", ".png", ".jpg", ".jpeg"],
                type="filepath"
            )

            text_input_stats = gr.Textbox(
                label="Text Prompt", value="brain",
                placeholder="e.g. brain, tumor, skull"
            )

            with gr.Row():
                modality_stats = gr.Dropdown(["CT", "MRI"], label="Modality", value="MRI")
                window_stats = gr.Dropdown(
                    ["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"],
                    label="Windowing", value="Brain (Grey Matter)"
                )

            submit_stats_btn = gr.Button("Process & Calculate Statistics", variant="primary")

            # Export buttons act on the most recently processed mask
            gr.Markdown("### Export Options")
            with gr.Row():
                export_nifti_btn = gr.Button("📥 Export to NIFTI", size="sm")
                save_annotation_btn = gr.Button("💾 Save Annotation", size="sm")

        with gr.Column():
            output_stats = gr.Image(label="Segmentation Result", type="filepath")
            status_stats = gr.Textbox(label="Status", interactive=False)

            gr.Markdown("### 📊 ROI Statistics")
            # Filled in by the processing callback with a markdown stats table
            roi_stats_text = gr.Markdown(value="*Process an image to see statistics*")

            # Download slots populated by the export callbacks
            nifti_download = gr.File(label="Download NIFTI", visible=True)
            annotation_download = gr.File(label="Download Annotation", visible=True)

    gr.Markdown("---")
    gr.Markdown("### Load Saved Annotation")
    with gr.Row():
        with gr.Column():
            annotation_upload = gr.File(
                label="Upload Annotation (.zip)",
                file_types=[".zip"],
                type="filepath"
            )

            # Original image is needed to draw the loaded mask over something
            original_image_upload = gr.File(
                label="Upload Original Image (for visualization)",
                file_types=[".dcm", ".png", ".jpg", ".jpeg"],
                type="filepath"
            )

            load_annotation_btn = gr.Button("Load & Visualize Annotation", variant="secondary")

        with gr.Column():
            loaded_annotation_output = gr.Image(label="Loaded Annotation", type="filepath")
            loaded_annotation_info = gr.Markdown(value="*Upload an annotation file to load*")
|
|
|
|
|
|
|
|
with gr.Tab("🎭 Multi-Mask Output"):
    # Generate several candidate masks for one prompt and show them side by side.
    gr.Markdown("""
    **Generate Multiple Mask Candidates**

    SAM can generate multiple segmentation hypotheses with confidence scores.
    This is useful when the segmentation is ambiguous or you want to compare alternatives.
    """)

    with gr.Row():
        with gr.Column():
            file_input_multi = gr.File(
                label="Upload Medical Image",
                file_types=[".dcm", ".png", ".jpg", ".jpeg"],
                type="filepath"
            )

            text_input_multi = gr.Textbox(
                label="Text Prompt", value="brain",
                placeholder="e.g. brain, tumor, skull"
            )

            with gr.Row():
                modality_multi = gr.Dropdown(["CT", "MRI"], label="Modality", value="MRI")
                window_multi = gr.Dropdown(
                    ["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"],
                    label="Windowing", value="Brain (Grey Matter)"
                )

            # How many mask hypotheses to request (1-5)
            num_masks_slider = gr.Slider(1, 5, value=3, step=1, label="Number of Masks")

            submit_multi_btn = gr.Button("Generate Multiple Masks", variant="primary")

        with gr.Column():
            # 2x2 gallery of candidate masks
            gallery_multi = gr.Gallery(
                label="Mask Candidates",
                show_label=True,
                columns=2,
                rows=2,
                height="auto"
            )

            status_multi = gr.Textbox(label="Status", interactive=False)
            mask_info_multi = gr.Textbox(label="Mask Information", lines=5, interactive=False)
|
|
|
|
|
|
|
|
with gr.Tab("🔬 Automatic Mask Generator"):
    # Prompt-free whole-image segmentation driven by an internal point grid.
    gr.Markdown("""
    **Automatic Mask Generator (AMG)**

    Inspired by [SAM-Medical-Imaging](https://github.com/amine0110/SAM-Medical-Imaging) -
    automatically segment the entire image without needing specific prompts.

    - Uses multiple prompts internally to discover all regions
    - Filters and deduplicates overlapping masks
    - Great for exploratory analysis of medical images
    """)

    with gr.Row():
        with gr.Column():
            file_input_amg = gr.File(
                label="Upload Medical Image",
                file_types=[".dcm", ".png", ".jpg", ".jpeg"],
                type="filepath"
            )

            with gr.Row():
                modality_amg = gr.Dropdown(["CT", "MRI"], label="Modality", value="MRI")
                window_amg = gr.Dropdown(
                    ["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"],
                    label="Windowing", value="Brain (Grey Matter)"
                )

            gr.Markdown("### AMG Parameters")
            # Density of the internal prompt grid; denser = slower but finer
            points_per_side = gr.Slider(
                8, 32, value=16, step=4,
                label="Grid Density",
                info="Higher = more detailed but slower"
            )
            # Discard masks below this pixel area to suppress noise
            min_mask_area = gr.Slider(
                50, 1000, value=100, step=50,
                label="Minimum Mask Area (pixels)",
                info="Filter out small noise regions"
            )
            # Qualitative colormaps suit many distinct regions
            colormap_amg = gr.Dropdown(
                ["tab20", "tab10", "Set1", "Set2", "Paired", "rainbow"],
                label="Colormap for Regions",
                value="tab20"
            )

            submit_amg_btn = gr.Button("🔬 Run Automatic Segmentation", variant="primary", size="lg")

        with gr.Column():
            output_amg = gr.Image(label="AMG Result", type="filepath")
            status_amg = gr.Textbox(label="Status", interactive=False)
            # Region-by-region details written by the AMG callback
            info_amg = gr.Markdown(value="*Run AMG to see region details*")
|
|
|
|
|
with gr.Tab("🔧 Advanced Transforms"):
    # Optional preprocessing (resize / CLAHE-like contrast) before segmentation.
    gr.Markdown("""
    **Advanced Image Transforms**

    Apply preprocessing transforms from SAM-Medical-Imaging before segmentation:

    - **ResizeLongestSide**: Resize maintaining aspect ratio (optimal for SAM)
    - **CLAHE-like Enhancement**: Enhance local contrast for better visibility
    """)

    with gr.Row():
        with gr.Column():
            file_input_transform = gr.File(
                label="Upload Medical Image",
                file_types=[".dcm", ".png", ".jpg", ".jpeg"],
                type="filepath"
            )

            text_input_transform = gr.Textbox(
                label="Text Prompt", value="brain",
                placeholder="e.g. brain, tumor, skull"
            )

            with gr.Row():
                modality_transform = gr.Dropdown(["CT", "MRI"], label="Modality", value="MRI")
                window_transform = gr.Dropdown(
                    ["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"],
                    label="Windowing", value="Brain (Grey Matter)"
                )

            gr.Markdown("### Transform Settings")
            # Longest image side is resized to this many pixels (aspect preserved)
            target_size_slider = gr.Slider(
                256, 2048, value=1024, step=128,
                label="Target Size (ResizeLongestSide)",
                info="Resize image's longest side to this value"
            )

            apply_clahe_checkbox = gr.Checkbox(
                label="Apply CLAHE-like Enhancement",
                value=False,
                info="Enhance local contrast (useful for low-contrast images)"
            )
            # Only relevant when the CLAHE checkbox above is enabled
            clahe_clip_slider = gr.Slider(
                1.0, 4.0, value=2.0, step=0.5,
                label="CLAHE Clip Limit",
                info="Higher = more contrast enhancement"
            )

            with gr.Row():
                colormap_transform = gr.Dropdown(
                    ["spring", "cool", "hot", "viridis", "plasma"],
                    label="Colormap", value="spring"
                )
                transparency_transform = gr.Slider(
                    0.0, 1.0, value=0.5, step=0.1,
                    label="Transparency"
                )

            submit_transform_btn = gr.Button("🔧 Process with Transforms", variant="primary")

        with gr.Column():
            output_transform = gr.Image(label="Transform + Segmentation Result", type="filepath")
            status_transform = gr.Textbox(label="Status", interactive=False)
|
|
|
|
|
with gr.Tab("🌊 Edge-Based Segmentation"):
    # Model-free fallback: Sobel edge detection + region filling.
    gr.Markdown("""
    **Edge-Based Automatic Segmentation**

    Uses Sobel edge detection to find boundaries in medical images.
    Works without AI model - useful for comparison or when model fails.

    - Detects edges using gradient-based methods
    - Fills regions within boundaries
    - Identifies distinct anatomical structures
    """)

    with gr.Row():
        with gr.Column():
            file_input_edge = gr.File(
                label="Upload Medical Image",
                file_types=[".dcm", ".png", ".jpg", ".jpeg"],
                type="filepath"
            )

            with gr.Row():
                modality_edge = gr.Dropdown(["CT", "MRI"], label="Modality", value="MRI")
                window_edge = gr.Dropdown(
                    ["Brain (Grey Matter)", "Bone (Skull)", "Soft Tissue (Face)"],
                    label="Windowing", value="Brain (Grey Matter)"
                )

            gr.Markdown("### Edge Detection Parameters")
            # Gradient-magnitude cutoff; lower values admit more edges
            edge_threshold_slider = gr.Slider(
                10, 150, value=50, step=10,
                label="Edge Threshold",
                info="Lower = more edges detected"
            )
            # Morphological dilation to bridge small gaps between edges
            dilation_size_slider = gr.Slider(
                0, 10, value=3, step=1,
                label="Dilation Size",
                info="Connect nearby edges (0 = no dilation)"
            )

            with gr.Row():
                colormap_edge = gr.Dropdown(
                    ["spring", "cool", "hot", "viridis", "plasma"],
                    label="Colormap", value="spring"
                )
                transparency_edge = gr.Slider(
                    0.0, 1.0, value=0.5, step=0.1,
                    label="Transparency"
                )

            submit_edge_btn = gr.Button("🌊 Run Edge Segmentation", variant="primary")

        with gr.Column():
            output_edge = gr.Image(label="Edge Segmentation Result", type="filepath")
            status_edge = gr.Textbox(label="Status", interactive=False)
|
|
|
|
|
|
|
|
# --- Event wiring: single-image tab ---

# Fetch the bundled demo file into the uploader.
load_demo_btn.click(
    fn=load_demo_file,
    inputs=[],
    outputs=[file_input, status_text]
)

# Run text-prompt segmentation on the uploaded image.
submit_btn.click(
    fn=process_with_status,
    inputs=[file_input, text_input, modality_dropdown, window_dropdown],
    outputs=[image_output, status_text]
)

# --- Event wiring: batch slice viewer ---

# Group uploaded files by subject and populate the subject dropdown.
detect_subjects_btn.click(
    fn=detect_subjects,
    inputs=[files_input],
    outputs=[subject_dropdown, status_batch_text]
)

# Process all slices of the chosen subject, then unlock the slice slider.
submit_batch_btn.click(
    fn=process_slices_for_viewer,
    inputs=[files_input, subject_dropdown, text_input_batch, modality_dropdown_batch, window_dropdown_batch],
    outputs=[current_slice_output, slice_slider, status_batch_text, slice_info_text, subject_dropdown, subject_info_text]
).then(
    # NOTE(review): the lambda receives slice_slider's current *value*, not its
    # maximum — presumably process_slices_for_viewer leaves the value at the
    # slice count so this sets the proper maximum; confirm against that function.
    lambda max_val: gr.Slider(maximum=max(max_val, 1), interactive=True),
    inputs=[slice_slider],
    outputs=[slice_slider]
)
|
|
|
|
|
def update_slice(slice_num, files, selected_subject, prompt, mod, window):
    """Slider callback: render the slice at the slider's position.

    Coerces the (possibly float) slider value to an int index and delegates
    to ``navigate_slice``. Returns ``(slice image, info text)``.
    """
    return navigate_slice(int(slice_num), files, selected_subject, prompt, mod, window)
|
|
|
|
|
# Re-render whenever the user drags the slice slider.
slice_slider.change(
    fn=update_slice,
    inputs=[slice_slider, files_input, subject_dropdown, text_input_batch, modality_dropdown_batch, window_dropdown_batch],
    outputs=[current_slice_output, slice_info_text]
)
|
|
|
|
|
def prev_slice(current, files, selected_subject, prompt, mod, window):
    """'Previous' button callback: step back one slice, clamped at index 0.

    Returns ``(new slider value, slice image, info text)`` so the slider
    stays in sync with the displayed slice.
    """
    target = max(current - 1, 0)
    image, caption = navigate_slice(target, files, selected_subject, prompt, mod, window)
    return target, image, caption
|
|
|
|
|
def next_slice(current, max_val, files, selected_subject, prompt, mod, window):
    """'Next' button callback: step forward one slice, clamped at ``max_val``.

    Returns ``(new slider value, slice image, info text)`` so the slider
    stays in sync with the displayed slice.
    """
    target = min(current + 1, max_val)
    image, caption = navigate_slice(target, files, selected_subject, prompt, mod, window)
    return target, image, caption
|
|
|
|
|
# Step one slice backward.
prev_btn.click(
    fn=prev_slice,
    inputs=[slice_slider, files_input, subject_dropdown, text_input_batch, modality_dropdown_batch, window_dropdown_batch],
    outputs=[slice_slider, current_slice_output, slice_info_text]
)

# Step one slice forward.
# NOTE(review): slice_slider is listed twice, so next_slice's ``max_val``
# receives the slider's *current value* — min(max_val, current + 1) is then a
# no-op and the Next button never advances. The second input should come from
# whatever holds the true slice count; confirm and fix the wiring.
next_btn.click(
    fn=next_slice,
    inputs=[slice_slider, slice_slider, files_input, subject_dropdown, text_input_batch, modality_dropdown_batch, window_dropdown_batch],
    outputs=[slice_slider, current_slice_output, slice_info_text]
)

# Process an image sequence and show the results as a gallery.
submit_gallery_btn.click(
    fn=process_sequence,
    inputs=[files_input_gallery, text_input_gallery, modality_dropdown_gallery, window_dropdown_gallery],
    outputs=[gallery_output, status_gallery_text]
)

# Segment and compare against an uploaded ground-truth mask (Dice / IoU).
submit_gt_btn.click(
    fn=process_with_ground_truth,
    inputs=[file_input_gt, gt_mask_input, text_input_gt, modality_dropdown_gt, window_dropdown_gt],
    outputs=[image_output_gt, comparison_output, dice_score_text, iou_score_text, status_gt_text]
)
|
|
|
|
|
|
|
|
def process_enhanced_wrapper(image_file, prompt_text, modality, window_type,
                             brightness, contrast, colormap, transparency, progress=gr.Progress()):
    """Run the enhanced pipeline and also expose the result as a download.

    Thin adapter around ``process_with_progress``: forwards all arguments
    unchanged and appends the rendered result path a second time so the UI's
    download component gets it too (``None`` when nothing was produced).

    Returns:
        (result path, status text, metrics text, download path)
    """
    result, status, metrics = process_with_progress(
        image_file, prompt_text, modality, window_type,
        brightness, contrast, colormap, transparency, progress
    )
    return result, status, metrics, (result or None)
|
|
|
|
|
# Enhanced single-image processing (with progress bar and download).
submit_enh_btn.click(
    fn=process_enhanced_wrapper,
    inputs=[
        file_input_enh, text_input_enh, modality_enh, window_enh,
        brightness_slider, contrast_slider, colormap_dropdown, transparency_slider
    ],
    outputs=[image_output_enh, progress_text, metrics_text, download_output]
)

# Enhanced batch processing with a zip download of all results.
submit_enh_batch_btn.click(
    fn=process_batch_enhanced,
    inputs=[
        files_input_enh_batch, text_input_enh_batch, modality_enh_batch, window_enh_batch,
        brightness_slider_batch, contrast_slider_batch, colormap_dropdown_batch, transparency_slider_batch
    ],
    outputs=[gallery_output_enh, batch_download_output, status_enh_batch_text]
)

# Point-prompt segmentation (single (x, y) click coordinate).
submit_point_btn.click(
    fn=process_with_point_prompt,
    inputs=[file_input_point, point_x, point_y, modality_point, window_point, colormap_point, transparency_point],
    outputs=[output_point, status_point]
)

# Box-prompt segmentation (x1, y1, x2, y2 bounding box).
submit_box_btn.click(
    fn=process_with_box_prompt,
    inputs=[file_input_box, box_x1, box_y1, box_x2, box_y2, modality_box, window_box, colormap_box, transparency_box],
    outputs=[output_box, status_box]
)

# Segment and compute ROI statistics; the mask is stored for later export.
submit_stats_btn.click(
    fn=process_and_store_mask,
    inputs=[file_input_stats, text_input_stats, modality_stats, window_stats],
    outputs=[output_stats, status_stats, roi_stats_text]
)

# Export the last stored mask as a NIFTI volume.
export_nifti_btn.click(
    fn=export_last_mask_nifti,
    inputs=[],
    outputs=[nifti_download, status_stats]
)

# Save the last stored mask + metadata as an annotation archive.
save_annotation_btn.click(
    fn=save_last_annotation,
    inputs=[],
    outputs=[annotation_download, status_stats]
)

# Re-load a saved annotation and draw it over the supplied original image.
load_annotation_btn.click(
    fn=visualize_loaded_annotation,
    inputs=[original_image_upload, annotation_upload],
    outputs=[loaded_annotation_output, loaded_annotation_info]
)

# Generate several candidate masks for one prompt.
submit_multi_btn.click(
    fn=process_multi_mask,
    inputs=[file_input_multi, text_input_multi, modality_multi, window_multi, num_masks_slider],
    outputs=[gallery_multi, status_multi, mask_info_multi]
)
|
|
|
|
|
|
|
|
def auto_play_slices(files, selected_subject, prompt, mod, window):
    """Auto-play through the already-processed slices of one subject.

    Generator used as a Gradio streaming callback: yields
    ``(slice image, info text, slider position)`` roughly twice per second.
    Requires the slices to have been processed beforehand — the cache key must
    match the one built when ``processed_results_cache`` was populated;
    otherwise a single "please process first" frame is yielded and the
    generator stops.
    """
    if not files:
        yield None, "No slices loaded", 0
        return

    subject_groups = group_images_by_subject(files)
    if selected_subject:
        # Dropdown entries look like "<subject_id> (<n> slices)" — keep the id.
        subject_id = selected_subject.split(" (")[0]
    else:
        subject_id = next(iter(subject_groups), None)

    if not subject_id or subject_id not in subject_groups:
        yield None, "No slices loaded", 0
        return

    subject_files = subject_groups[subject_id]['files']
    # Must match the cache key used when the slices were processed.
    cache_key = f"{subject_id}_{len(subject_files)}_{prompt}_{mod}"

    if cache_key not in processed_results_cache:
        yield None, "Please process slices first", 0
        return

    results = processed_results_cache[cache_key]

    last = len(results) - 1
    for idx, frame in enumerate(results):
        slice_info = f"Slice {idx + 1}/{len(results)} ({subject_id}) - Auto-playing..."
        yield frame, slice_info, idx
        if idx < last:
            # Fix: don't sleep after the final frame — the playback used to
            # stall for an extra 0.5 s after the last slice was shown.
            time.sleep(0.5)
|
|
|
|
|
# Stream slices one by one; auto_play_slices is a generator, so each yield
# updates the viewer, caption, and slider position in turn.
auto_play_btn.click(
    fn=auto_play_slices,
    inputs=[files_input, subject_dropdown, text_input_batch, modality_dropdown_batch, window_dropdown_batch],
    outputs=[current_slice_output, slice_info_text, slice_slider]
)

# Prompt-free automatic mask generation over the whole image.
submit_amg_btn.click(
    fn=automatic_mask_generator,
    inputs=[file_input_amg, modality_amg, window_amg, points_per_side, min_mask_area, colormap_amg],
    outputs=[output_amg, status_amg, info_amg]
)

# Resize/CLAHE preprocessing followed by text-prompt segmentation.
submit_transform_btn.click(
    fn=process_with_advanced_transforms,
    inputs=[file_input_transform, text_input_transform, modality_transform, window_transform,
            target_size_slider, apply_clahe_checkbox, clahe_clip_slider, colormap_transform, transparency_transform],
    outputs=[output_transform, status_transform]
)

# Model-free Sobel edge-based segmentation.
submit_edge_btn.click(
    fn=edge_based_segmentation,
    inputs=[file_input_edge, modality_edge, window_edge, edge_threshold_slider, dilation_size_slider,
            colormap_edge, transparency_edge],
    outputs=[output_edge, status_edge]
)
|
|
|
|
|
if __name__ == "__main__":
    # Warn (but still launch) when the model failed to load at import time,
    # so the UI stays inspectable even without working segmentation.
    if model is None or processor is None:
        print("⚠️ WARNING: SAM 3 model failed to load!")
        print("⚠️ The app will start but segmentation features will not work.")
        print("⚠️ Please check:")
        print(" 1. HF_TOKEN environment variable is set correctly")
        print(" 2. transformers>=4.45.0 is installed")
        print(" 3. Sufficient memory/GPU available")
    else:
        print("✅ SAM 3 model ready - app starting...")

    # Bind to all interfaces on the standard Hugging Face Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860)
|
|
|