# medrax/tools/segmentation/segmentation.py
# fix: Don't invert processed DICOM files for segmentation (commit 8bb9754, by samwell)
from typing import Dict, List, Optional, Tuple, Type, Any
from pathlib import Path
import uuid
import tempfile
import numpy as np
import torch
import torchvision
import torchxrayvision as xrv
import matplotlib.pyplot as plt
import skimage.io
import skimage.measure
import skimage.transform
import traceback
from pydantic import BaseModel, Field
from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool
from medrax.utils.utils import preprocess_medical_image
class ChestXRaySegmentationInput(BaseModel):
    """Input schema for the Chest X-ray Segmentation Tool."""

    # Path must point to an image readable by skimage.io.imread (e.g. a PNG/JPEG
    # X-ray, or a DICOM already exported to such a format).
    image_path: str = Field(..., description="Path to the chest X-ray image file to be segmented")
    # Friendly organ names; must match the keys of the tool's organ_map.
    # When omitted (None), the tool segments every available organ.
    organs: Optional[List[str]] = Field(
        None,
        description="List of organs to segment. If None, all available organs will be segmented. "
        "Available organs: Left/Right Clavicle, Left/Right Scapula, Left/Right Lung, "
        "Left/Right Hilus Pulmonis, Heart, Aorta, Facies Diaphragmatica, "
        "Mediastinum, Weasand, Spine",
    )
class OrganMetrics(BaseModel):
    """Detailed metrics for a segmented organ.

    Pixel-space values are computed in the coordinate frame of the original
    (full-resolution) image after the 512x512 model mask has been aligned
    back onto it; area_cm2 additionally assumes the tool's fixed
    pixel_spacing_mm, so it is only approximate for non-DICOM input.
    """

    # Basic metrics
    area_pixels: int = Field(..., description="Area in pixels")
    area_cm2: float = Field(..., description="Approximate area in cm²")
    # Centroid follows the (row, column) = (y, x) convention of skimage regionprops.
    centroid: Tuple[float, float] = Field(..., description="(y, x) coordinates of centroid")
    bbox: Tuple[int, int, int, int] = Field(..., description="Bounding box coordinates (min_y, min_x, max_y, max_x)")
    # Size metrics (derived from the bounding box)
    width: int = Field(..., description="Width of the organ in pixels")
    height: int = Field(..., description="Height of the organ in pixels")
    aspect_ratio: float = Field(..., description="Height/width ratio")
    # Position metrics — fractions of image height/width in [0, 1]
    relative_position: Dict[str, float] = Field(..., description="Position relative to image boundaries (0-1 scale)")
    # Analysis metrics — intensities are taken from the original image under the mask
    mean_intensity: float = Field(..., description="Mean pixel intensity in the organ region")
    std_intensity: float = Field(..., description="Standard deviation of pixel intensity")
    # Mean sigmoid probability of the organ's output channel
    confidence_score: float = Field(..., description="Model confidence score for this organ")
class ChestXRaySegmentationTool(BaseTool):
    """Tool for performing detailed segmentation analysis of chest X-ray images.

    Wraps the torchxrayvision ``chestx_det`` PSPNet to produce per-organ
    binary masks, saves an overlay visualization into ``temp_dir``, and
    returns geometric/intensity metrics for each requested organ.
    """

    name: str = "chest_xray_segmentation"
    description: str = (
        "Segments chest X-ray images to specified anatomical structures. "
        "Available organs: Left/Right Clavicle (collar bones), Left/Right Scapula (shoulder blades), "
        "Left/Right Lung, Left/Right Hilus Pulmonis (lung roots), Heart, Aorta, "
        "Facies Diaphragmatica (diaphragm), Mediastinum (central cavity), Weasand (esophagus), "
        "and Spine. Returns segmentation visualization and comprehensive metrics. "
        "Let the user know the area is not accurate unless input has been DICOM."
    )
    args_schema: Type[BaseModel] = ChestXRaySegmentationInput
    model: Any = None  # PSPNet segmentation model, set in __init__
    device: Optional[str] = "cuda"  # resolved to a torch.device in __init__
    transform: Any = None  # center-crop + resize-to-512 preprocessing pipeline
    pixel_spacing_mm: float = 0.2  # assumed spacing; exact only for DICOM input
    temp_dir: Path = Path("temp")  # where overlay PNGs are written
    organ_map: Dict[str, int] = None  # friendly organ name -> model channel index

    def __init__(self, device: Optional[str] = "cuda", temp_dir: Optional[Path] = Path("temp")):
        """Initialize the segmentation tool with model and temporary directory.

        Args:
            device: Torch device string (e.g. "cuda" or "cpu"). Defaults to "cuda".
            temp_dir: Directory for saved visualizations; created if missing.
        """
        super().__init__()
        self.model = xrv.baseline_models.chestx_det.PSPNet()
        # BUGFIX: the original `torch.device(device) if device else "cuda"` left
        # self.device as the *string* "cuda" when device was falsy; always wrap
        # the resolved name in torch.device.
        self.device = torch.device(device if device else "cuda")
        self.model = self.model.to(self.device)
        self.model.eval()
        # Center crop to a square of side min(H, W), then resize to the 512x512
        # input the PSPNet expects.
        self.transform = torchvision.transforms.Compose(
            [xrv.datasets.XRayCenterCrop(), xrv.datasets.XRayResizer(512)]
        )
        self.temp_dir = temp_dir if isinstance(temp_dir, Path) else Path(temp_dir)
        # parents=True so a nested temp path (e.g. "out/temp") does not raise.
        self.temp_dir.mkdir(parents=True, exist_ok=True)
        # Map friendly names to model target indices
        self.organ_map = {
            "Left Clavicle": 0,
            "Right Clavicle": 1,
            "Left Scapula": 2,
            "Right Scapula": 3,
            "Left Lung": 4,
            "Right Lung": 5,
            "Left Hilus Pulmonis": 6,
            "Right Hilus Pulmonis": 7,
            "Heart": 8,
            "Aorta": 9,
            "Facies Diaphragmatica": 10,
            "Mediastinum": 11,
            "Weasand": 12,
            "Spine": 13,
        }

    def _organ_name(self, organ_idx: int) -> str:
        """Return the friendly organ name for a model channel index.

        Replaces the duplicated list(...).index(...) reverse lookup idiom.
        """
        for organ_name, idx in self.organ_map.items():
            if idx == organ_idx:
                return organ_name
        raise KeyError(f"Unknown organ index: {organ_idx}")

    def _align_mask_to_original(self, mask: np.ndarray, original_shape: Tuple[int, int]) -> np.ndarray:
        """Align a mask from the transformed (cropped/resized) space back to the full original image.

        Assumes that the transform does a center crop to a square of side
        min(original height, width) and then resizes to (512, 512); this
        inverts that mapping, leaving zeros outside the cropped region.

        Args:
            mask: Binary mask in model space (typically 512x512).
            original_shape: (height, width) of the original image.

        Returns:
            Mask of shape ``original_shape`` with the resized mask pasted
            into the center-cropped region.
        """
        orig_h, orig_w = original_shape
        crop_size = min(orig_h, orig_w)
        crop_top = (orig_h - crop_size) // 2
        crop_left = (orig_w - crop_size) // 2
        # order=0 (nearest neighbor) keeps the mask binary through the resize.
        resized_mask = skimage.transform.resize(
            mask, (crop_size, crop_size), order=0, preserve_range=True, anti_aliasing=False
        )
        full_mask = np.zeros(original_shape)
        full_mask[crop_top : crop_top + crop_size, crop_left : crop_left + crop_size] = resized_mask
        return full_mask

    def _compute_organ_metrics(
        self, mask: np.ndarray, original_img: np.ndarray, confidence: float
    ) -> Optional[OrganMetrics]:
        """Compute comprehensive metrics for a single organ mask.

        Args:
            mask: Binary mask in model (512x512) or original image space.
            original_img: Full-resolution grayscale image.
            confidence: Mean sigmoid probability for this organ channel.

        Returns:
            OrganMetrics for the mask, or None when the mask is empty.
        """
        # Align mask to the original image coordinates if needed
        if mask.shape != original_img.shape:
            mask = self._align_mask_to_original(mask, original_img.shape)

        props = skimage.measure.regionprops(mask.astype(int))
        if not props:
            return None
        props = props[0]

        # Pixel area -> cm^2 using the assumed (not measured) pixel spacing.
        area_cm2 = mask.sum() * (self.pixel_spacing_mm / 10) ** 2

        img_height, img_width = mask.shape
        cy, cx = props.centroid
        relative_pos = {
            "top": cy / img_height,
            "left": cx / img_width,
            "center_dist": np.sqrt(((cy / img_height - 0.5) ** 2 + (cx / img_width - 0.5) ** 2)),
        }

        # Intensity statistics over the original image pixels under the mask.
        organ_pixels = original_img[mask > 0]
        mean_intensity = organ_pixels.mean() if len(organ_pixels) > 0 else 0
        std_intensity = organ_pixels.std() if len(organ_pixels) > 0 else 0

        return OrganMetrics(
            area_pixels=int(mask.sum()),
            area_cm2=float(area_cm2),
            centroid=(float(cy), float(cx)),
            bbox=tuple(map(int, props.bbox)),
            width=int(props.bbox[3] - props.bbox[1]),
            height=int(props.bbox[2] - props.bbox[0]),
            aspect_ratio=float((props.bbox[2] - props.bbox[0]) / max(1, props.bbox[3] - props.bbox[1])),
            relative_position=relative_pos,
            mean_intensity=float(mean_intensity),
            std_intensity=float(std_intensity),
            confidence_score=float(confidence),
        )

    def _save_visualization(self, original_img: np.ndarray, pred_masks: torch.Tensor, organ_indices: List[int]) -> str:
        """Save visualization of original image with segmentation masks overlaid.

        Args:
            original_img: 2D grayscale image at its original resolution.
            pred_masks: Thresholded model output, shape (1, n_targets, 512, 512).
            organ_indices: Model channel indices to render.

        Returns:
            Path (as a string) of the saved PNG.
        """
        # Create single panel with overlay
        fig, ax = plt.subplots(figsize=(10, 10))

        # BUGFIX: tab20 supplies one distinct color per organ. The original
        # used tab10 capped at 10 entries, so zip() silently dropped any organ
        # past the tenth (Aorta, Facies Diaphragmatica, Mediastinum, Weasand,
        # Spine) from the visualization.
        colors = plt.cm.tab20(np.linspace(0, 1, max(len(organ_indices), 1)))

        # combined_mask holds one positive integer label per rendered organ;
        # the label is the organ's 1-based position in legend_items.
        combined_mask = np.zeros(original_img.shape)
        legend_items = []

        for organ_idx, color in zip(organ_indices, colors):
            mask = pred_masks[0, organ_idx].cpu().numpy()
            # Debug: print mask info
            print(f"Organ index {organ_idx}: mask sum = {mask.sum()}, mask shape = {mask.shape}")
            if mask.sum() == 0:
                continue

            original_mask_sum = mask.sum()
            # Align the mask to the original image coordinates
            if mask.shape != original_img.shape:
                aligned_mask = self._align_mask_to_original(mask, original_img.shape)
                print(f"Aligned mask shape: {aligned_mask.shape}, sum: {aligned_mask.sum()} (was {original_mask_sum})")
                # If center-crop alignment lost most of the mask (non-square
                # image with content near the edges), fall back to a plain resize.
                if aligned_mask.sum() < original_mask_sum * 0.1:
                    print(f"Warning: Alignment lost {(1 - aligned_mask.sum()/original_mask_sum)*100:.1f}% of mask, using unaligned")
                    aligned_mask = skimage.transform.resize(
                        mask, original_img.shape, order=0, preserve_range=True, anti_aliasing=False
                    )
                mask = aligned_mask

            # BUGFIX: label regions by their 1-based position in legend_items,
            # not by the loop index over all requested organs. The original
            # used the raw loop index here but enumerated only the *found*
            # organs when rendering, so any earlier empty mask desynchronized
            # the labels and organs rendered with the wrong color/name.
            legend_items.append((self._organ_name(organ_idx), color))
            combined_mask[mask > 0] = len(legend_items)

        masks_found = len(legend_items)
        print(f"Total masks found and rendered: {masks_found}")

        # Display original image
        ax.imshow(original_img, cmap="gray")

        # Overlay masks with contours
        if masks_found > 0:
            from matplotlib.patches import Patch

            for idx, (organ_name, color) in enumerate(legend_items):
                mask_region = (combined_mask == idx + 1)
                # Create colored overlay (RGBA, half-transparent)
                overlay = np.zeros((*original_img.shape, 4))
                overlay[mask_region] = [color[0], color[1], color[2], 0.5]
                ax.imshow(overlay)
                # Draw contours for clear boundaries
                contours = skimage.measure.find_contours(mask_region.astype(float), 0.5)
                for contour in contours:
                    ax.plot(contour[:, 1], contour[:, 0], color=color, linewidth=3, alpha=0.9)

            # Add legend
            patches = [Patch(facecolor=c, edgecolor=c, label=n, alpha=0.7) for n, c in legend_items]
            ax.legend(handles=patches, loc="upper right", fontsize=10, framealpha=0.9)
            ax.set_title("Segmentation Overlay", fontsize=14, color='white', pad=15)
        else:
            ax.set_title("No Masks Detected", fontsize=14, color='red', pad=15)

        ax.axis("off")
        fig.patch.set_facecolor('black')
        save_path = self.temp_dir / f"segmentation_{uuid.uuid4().hex[:8]}.png"
        plt.savefig(save_path, bbox_inches="tight", dpi=150, facecolor='black')
        plt.close(fig)
        return str(save_path)

    def _run(
        self,
        image_path: str,
        organs: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Run segmentation analysis for specified organs.

        Args:
            image_path: Path of the chest X-ray image to segment.
            organs: Friendly organ names (keys of organ_map); None means all.
            run_manager: Optional LangChain callback manager (unused).

        Returns:
            (output, metadata) tuple. On success, output carries the overlay
            image path and per-organ metrics dicts; on failure, output holds
            an "error" key and metadata the traceback.
        """
        try:
            # Validate and get organ indices
            if organs:
                organs = [o.strip() for o in organs]
                invalid_organs = [o for o in organs if o not in self.organ_map]
                if invalid_organs:
                    raise ValueError(f"Invalid organs specified: {invalid_organs}")
                organ_indices = [self.organ_map[o] for o in organs]
            else:
                organ_indices = list(self.organ_map.values())
                organs = list(self.organ_map.keys())

            # Load and process image
            original_img = skimage.io.imread(image_path)
            print(f"\n=== Image Loading Debug ===")
            print(f"Image path: {image_path}")
            print(f"Original shape: {original_img.shape}, dtype: {original_img.dtype}")
            print(f"Original range: [{original_img.min()}, {original_img.max()}]")

            # Collapse RGB(A) input to a single grayscale channel.
            if len(original_img.shape) > 2:
                original_img = original_img[:, :, 0]
                print(f"After channel extraction: {original_img.shape}")

            # TorchXRayVision models expect images in the range [-1024, 1024] (Hounsfield units)
            # NOT normalized to [0, 1]! We need to scale 8-bit images to this range.
            # IMPORTANT: PNG/JPEG X-rays are typically INVERTED compared to DICOM
            # (lungs appear bright instead of dark), so we need to invert them first
            # EXCEPTION: Processed DICOM files (saved as PNG) should NOT be inverted
            is_processed_dicom = "processed_dicom" in image_path
            if original_img.dtype == np.uint8 or original_img.max() <= 255:
                if is_processed_dicom:
                    # Processed DICOM - don't invert, just scale to HU range
                    img = (original_img.astype(np.float32) / 255.0) * 1624 - 1024
                    print(f"Processed DICOM - converted to HU range without inversion: [{img.min():.1f}, {img.max():.1f}]")
                else:
                    # Regular PNG/JPEG - invert first, then scale
                    inverted = 255 - original_img.astype(np.float32)
                    img = (inverted / 255.0) * 1624 - 1024
                    print(f"PNG/JPEG - inverted and converted to HU range: [{img.min():.1f}, {img.max():.1f}]")
            else:
                # Assume already in HU or similar range (raw DICOM)
                img = original_img.astype(np.float32)
                print(f"Kept original range (raw DICOM): [{img.min():.1f}, {img.max():.1f}]")

            img = img[None, ...]
            print(f"After adding batch dim: {img.shape}")
            img = self.transform(img)
            print(f"After transform: {img.shape}")
            img = torch.from_numpy(img)
            img = img.to(self.device)
            print(f"Final tensor: shape={img.shape}, dtype={img.dtype}, device={img.device}")
            print(f"Tensor stats: mean={img.mean():.3f}, std={img.std():.3f}, min={img.min():.3f}, max={img.max():.3f}")

            # Generate predictions
            with torch.no_grad():
                pred = self.model(img)

            print(f"\nModel output shape: {pred.shape}")
            print(f"Raw predictions: min={pred.min().item():.3f}, max={pred.max().item():.3f}, mean={pred.mean().item():.3f}")
            pred_probs = torch.sigmoid(pred)
            print(f"After sigmoid: min={pred_probs.min().item():.3f}, max={pred_probs.max().item():.3f}, mean={pred_probs.mean().item():.3f}")

            # Print probabilities for debugging
            print(f"\n=== Segmentation Probabilities ===")
            for organ_idx in organ_indices:
                organ_name = self._organ_name(organ_idx)
                prob = pred_probs[0, organ_idx].mean().item()
                max_prob = pred_probs[0, organ_idx].max().item()
                print(f"{organ_name}: mean={prob:.3f}, max={max_prob:.3f}")
            print(f"==================================\n")

            # Use lower threshold (0.3) to capture more masks, especially for non-DICOM images
            # The model tends to be less confident on non-DICOM images
            pred_masks = (pred_probs > 0.3).float()
            print(f"Masks after 0.3 threshold: {[f'{(pred_masks[0,i].sum().item())} pixels' for i in organ_indices]}")

            # Save visualization
            viz_path = self._save_visualization(original_img, pred_masks, organ_indices)

            # Compute metrics for selected organs
            results = {}
            for idx, organ_name in zip(organ_indices, organs):
                mask = pred_masks[0, idx].cpu().numpy()
                if mask.sum() > 0:
                    metrics = self._compute_organ_metrics(
                        mask, original_img, float(pred_probs[0, idx].mean().cpu())
                    )
                    if metrics:
                        results[organ_name] = metrics

            output = {
                "segmentation_image_path": viz_path,
                "metrics": {organ: metrics.dict() for organ, metrics in results.items()},
            }
            metadata = {
                "image_path": image_path,
                "segmentation_image_path": viz_path,
                "original_size": original_img.shape,
                "model_size": tuple(img.shape[-2:]),
                "pixel_spacing_mm": self.pixel_spacing_mm,
                "requested_organs": organs,
                "processed_organs": list(results.keys()),
                "analysis_status": "completed",
            }
            return output, metadata

        except Exception as e:
            error_output = {"error": str(e)}
            error_metadata = {
                "image_path": image_path,
                "analysis_status": "failed",
                "error_traceback": traceback.format_exc(),
            }
            return error_output, error_metadata

    async def _arun(
        self,
        image_path: str,
        organs: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Async version of _run (delegates to the synchronous implementation)."""
        return self._run(image_path, organs)