"""LangChain tool that segments chest X-rays with the torchxrayvision chestx_det PSPNet."""

import logging
import tempfile
import traceback
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Type

import matplotlib.pyplot as plt
import numpy as np
import skimage.io
import skimage.measure
import skimage.transform
import torch
import torchvision
import torchxrayvision as xrv
from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool
from matplotlib.patches import Patch
from pydantic import BaseModel, Field

from medrax.utils.utils import preprocess_medical_image

logger = logging.getLogger(__name__)


class ChestXRaySegmentationInput(BaseModel):
    """Input schema for the Chest X-ray Segmentation Tool."""

    image_path: str = Field(..., description="Path to the chest X-ray image file to be segmented")
    organs: Optional[List[str]] = Field(
        None,
        description="List of organs to segment. If None, all available organs will be segmented. "
        "Available organs: Left/Right Clavicle, Left/Right Scapula, Left/Right Lung, "
        "Left/Right Hilus Pulmonis, Heart, Aorta, Facies Diaphragmatica, "
        "Mediastinum, Weasand, Spine",
    )


class OrganMetrics(BaseModel):
    """Detailed metrics for a segmented organ."""

    # Basic metrics
    area_pixels: int = Field(..., description="Area in pixels")
    area_cm2: float = Field(..., description="Approximate area in cm²")
    centroid: Tuple[float, float] = Field(..., description="(y, x) coordinates of centroid")
    bbox: Tuple[int, int, int, int] = Field(
        ..., description="Bounding box coordinates (min_y, min_x, max_y, max_x)"
    )

    # Size metrics
    width: int = Field(..., description="Width of the organ in pixels")
    height: int = Field(..., description="Height of the organ in pixels")
    aspect_ratio: float = Field(..., description="Height/width ratio")

    # Position metrics
    relative_position: Dict[str, float] = Field(
        ..., description="Position relative to image boundaries (0-1 scale)"
    )

    # Analysis metrics
    mean_intensity: float = Field(..., description="Mean pixel intensity in the organ region")
    std_intensity: float = Field(..., description="Standard deviation of pixel intensity")
    confidence_score: float = Field(..., description="Model confidence score for this organ")


class ChestXRaySegmentationTool(BaseTool):
    """Tool for performing detailed segmentation analysis of chest X-ray images.

    Runs the torchxrayvision ``chestx_det`` PSPNet, thresholds its sigmoid outputs
    into per-organ binary masks, saves an overlay PNG to ``temp_dir`` and returns
    per-organ :class:`OrganMetrics`.
    """

    name: str = "chest_xray_segmentation"
    description: str = (
        "Segments chest X-ray images to specified anatomical structures. "
        "Available organs: Left/Right Clavicle (collar bones), Left/Right Scapula (shoulder blades), "
        "Left/Right Lung, Left/Right Hilus Pulmonis (lung roots), Heart, Aorta, "
        "Facies Diaphragmatica (diaphragm), Mediastinum (central cavity), Weasand (esophagus), "
        "and Spine. Returns segmentation visualization and comprehensive metrics. "
        "Let the user know the area is not accurate unless input has been DICOM."
    )
    args_schema: Type[BaseModel] = ChestXRaySegmentationInput

    model: Any = None
    device: Optional[str] = "cuda"
    transform: Any = None
    # Fixed assumed spacing used for area_cm2; it is not read from the image file,
    # which is why the tool description warns that areas are approximate.
    pixel_spacing_mm: float = 0.2
    temp_dir: Path = Path("temp")
    organ_map: Optional[Dict[str, int]] = None
    # Cut-off applied to the sigmoid outputs to form binary masks. Kept below 0.5
    # to capture more masks, as the model is less confident on non-DICOM images.
    mask_threshold: float = 0.3

    def __init__(self, device: Optional[str] = "cuda", temp_dir: Optional[Path] = Path("temp")):
        """Load the segmentation model and prepare the output directory.

        Args:
            device: Torch device string; falls back to ``"cuda"`` when falsy.
            temp_dir: Directory for visualization PNGs; created if missing.
                ``None`` falls back to ``Path("temp")``.
        """
        super().__init__()
        self.model = xrv.baseline_models.chestx_det.PSPNet()
        # Always store a torch.device (the original stored a bare string for None).
        self.device = torch.device(device if device else "cuda")
        self.model = self.model.to(self.device)
        self.model.eval()
        self.transform = torchvision.transforms.Compose(
            [xrv.datasets.XRayCenterCrop(), xrv.datasets.XRayResizer(512)]
        )
        self.temp_dir = Path(temp_dir) if temp_dir is not None else Path("temp")
        self.temp_dir.mkdir(parents=True, exist_ok=True)

        # Map friendly names to model target indices
        self.organ_map = {
            "Left Clavicle": 0,
            "Right Clavicle": 1,
            "Left Scapula": 2,
            "Right Scapula": 3,
            "Left Lung": 4,
            "Right Lung": 5,
            "Left Hilus Pulmonis": 6,
            "Right Hilus Pulmonis": 7,
            "Heart": 8,
            "Aorta": 9,
            "Facies Diaphragmatica": 10,
            "Mediastinum": 11,
            "Weasand": 12,
            "Spine": 13,
        }

    def _align_mask_to_original(self, mask: np.ndarray, original_shape: Tuple[int, int]) -> np.ndarray:
        """Map a mask from model space (center-cropped, resized) back onto the original image.

        Mirrors ``self.transform``: a center crop to a square of side
        ``min(height, width)`` followed by a resize to 512x512.

        Args:
            mask: Mask in model space.
            original_shape: ``(height, width)`` of the original 2-D image.

        Returns:
            Float array of ``original_shape``; pixels outside the crop are 0.
        """
        orig_h, orig_w = original_shape
        crop_size = min(orig_h, orig_w)
        # Same offset arithmetic as xrv.datasets.XRayCenterCrop (side // 2 - crop // 2).
        # The previous (side - crop) // 2 is one pixel off when side is even and crop odd.
        crop_top = orig_h // 2 - crop_size // 2
        crop_left = orig_w // 2 - crop_size // 2

        # Nearest-neighbour resize keeps the mask binary.
        resized_mask = skimage.transform.resize(
            mask, (crop_size, crop_size), order=0, preserve_range=True, anti_aliasing=False
        )

        full_mask = np.zeros(original_shape)
        full_mask[crop_top : crop_top + crop_size, crop_left : crop_left + crop_size] = resized_mask
        return full_mask

    def _compute_organ_metrics(
        self, mask: np.ndarray, original_img: np.ndarray, confidence: float
    ) -> Optional[OrganMetrics]:
        """Compute size, position and intensity metrics for one organ mask.

        Args:
            mask: Binary mask, in model space or already aligned to ``original_img``.
            original_img: 2-D original image used for intensity statistics.
            confidence: Value stored as ``confidence_score``.

        Returns:
            An :class:`OrganMetrics`, or ``None`` if ``regionprops`` finds no region.
        """
        if mask.shape != original_img.shape:
            mask = self._align_mask_to_original(mask, original_img.shape)

        props = skimage.measure.regionprops(mask.astype(int))
        if not props:
            return None
        region = props[0]

        area_pixels = mask.sum()
        area_cm2 = area_pixels * (self.pixel_spacing_mm / 10) ** 2

        img_height, img_width = mask.shape
        cy, cx = region.centroid
        rel_y, rel_x = cy / img_height, cx / img_width
        relative_pos = {
            "top": float(rel_y),
            "left": float(rel_x),
            "center_dist": float(np.sqrt((rel_y - 0.5) ** 2 + (rel_x - 0.5) ** 2)),
        }

        organ_pixels = original_img[mask > 0]
        has_pixels = organ_pixels.size > 0
        mean_intensity = organ_pixels.mean() if has_pixels else 0
        std_intensity = organ_pixels.std() if has_pixels else 0

        min_y, min_x, max_y, max_x = region.bbox
        height = max_y - min_y
        width = max_x - min_x

        return OrganMetrics(
            area_pixels=int(area_pixels),
            area_cm2=float(area_cm2),
            centroid=(float(cy), float(cx)),
            bbox=tuple(map(int, region.bbox)),
            width=int(width),
            height=int(height),
            aspect_ratio=float(height / max(1, width)),
            relative_position=relative_pos,
            mean_intensity=float(mean_intensity),
            std_intensity=float(std_intensity),
            confidence_score=float(confidence),
        )

    def _save_visualization(self, original_img: np.ndarray, pred_masks: torch.Tensor, organ_indices: List[int]) -> str:
        """Save a PNG of the original image with the predicted organ masks overlaid.

        Each non-empty organ mask is aligned to ``original_img`` coordinates and drawn
        as its own semi-transparent fill plus a contour, with one legend entry per
        drawn organ.

        Args:
            original_img: 2-D image the masks are drawn on.
            pred_masks: Binary masks indexed as ``pred_masks[0, organ_idx]``.
            organ_indices: Model channel indices of the organs to draw.

        Returns:
            Path (as ``str``) of the saved PNG inside ``self.temp_dir``.
        """
        index_to_name = {idx: name for name, idx in self.organ_map.items()}
        # tab10 keeps the original palette for up to 10 organs; tab20 beyond that so
        # zip() below never silently drops organs for lack of colours.
        cmap = plt.cm.tab10 if len(organ_indices) <= 10 else plt.cm.tab20
        colors = cmap(np.linspace(0, 1, len(organ_indices)))

        fig, ax = plt.subplots(figsize=(10, 10))
        try:
            ax.imshow(original_img, cmap="gray")

            legend_handles = []
            for organ_idx, color in zip(organ_indices, colors):
                mask = pred_masks[0, organ_idx].cpu().numpy()
                if mask.sum() == 0:
                    continue
                # Same alignment as the metrics path, so figure and numbers agree.
                if mask.shape != original_img.shape:
                    mask = self._align_mask_to_original(mask, original_img.shape)
                region = mask > 0
                if not region.any():
                    continue

                # One overlay per organ, so overlapping masks remain visible instead
                # of overwriting each other in a shared label image.
                overlay = np.zeros((*original_img.shape, 4))
                overlay[region] = [color[0], color[1], color[2], 0.5]
                ax.imshow(overlay)

                # Contours give clear boundaries on top of the translucent fill.
                for contour in skimage.measure.find_contours(region.astype(float), 0.5):
                    ax.plot(contour[:, 1], contour[:, 0], color=color, linewidth=3, alpha=0.9)

                legend_handles.append(
                    Patch(facecolor=color, edgecolor=color, label=index_to_name[organ_idx], alpha=0.7)
                )

            logger.debug("Total masks rendered: %d", len(legend_handles))

            if legend_handles:
                ax.legend(handles=legend_handles, loc="upper right", fontsize=10, framealpha=0.9)
                ax.set_title("Segmentation Overlay", fontsize=14, color='white', pad=15)
            else:
                ax.set_title("No Masks Detected", fontsize=14, color='red', pad=15)

            ax.axis("off")
            fig.patch.set_facecolor('black')

            save_path = self.temp_dir / f"segmentation_{uuid.uuid4().hex[:8]}.png"
            fig.savefig(save_path, bbox_inches="tight", dpi=150, facecolor='black')
        finally:
            # Release the figure even when plotting fails.
            plt.close(fig)

        return str(save_path)

    def _resolve_organs(self, organs: Optional[List[str]]) -> Tuple[List[str], List[int]]:
        """Validate requested organ names and return ``(names, model_indices)``.

        An empty or ``None`` request selects every organ in ``self.organ_map``.

        Raises:
            ValueError: If any requested name is not in ``self.organ_map``.
        """
        if not organs:
            return list(self.organ_map.keys()), list(self.organ_map.values())

        names = [o.strip() for o in organs]
        invalid_organs = [o for o in names if o not in self.organ_map]
        if invalid_organs:
            raise ValueError(f"Invalid organs specified: {invalid_organs}")
        return names, [self.organ_map[o] for o in names]

    def _prepare_input(self, original_img: np.ndarray, image_path: str) -> torch.Tensor:
        """Scale a 2-D image to the model's intensity range and apply the crop/resize transform.

        Returns the transformed image as a tensor on ``self.device``. Given the
        ``XRayResizer(512)`` transform, the tensor is (1, 512, 512).
        """
        # torchxrayvision models expect intensities in [-1024, 1024]
        # (xrv.datasets.normalize computes (2 * x / maxval - 1) * 1024), not [0, 1].
        # The previous factor of 1624 mapped 8-bit images to [-1024, 600].
        # Paths containing "processed_dicom" are DICOM exports and are not inverted;
        # other 8-bit images are inverted before scaling.
        # NOTE(review): torchxrayvision's own examples normalize 8-bit PNG/JPEG
        # radiographs without inversion -- confirm inversion suits the inputs this
        # tool receives.
        if original_img.dtype == np.uint8 or original_img.max() <= 255:
            scaled = original_img.astype(np.float32)
            if "processed_dicom" not in image_path:
                scaled = 255 - scaled
            img = (scaled / 255.0) * 2048 - 1024
        else:
            # Assume already in HU or a similar range (raw DICOM).
            img = original_img.astype(np.float32)
        logger.debug("Model-range input: [%.1f, %.1f]", img.min(), img.max())

        # Add the channel axis expected by the xrv transforms.
        img = self.transform(img[None, ...])
        return torch.from_numpy(img).to(self.device)

    def _run(
        self,
        image_path: str,
        organs: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Run segmentation analysis for the specified organs.

        Args:
            image_path: Path of the image to segment.
            organs: Organ names to analyse; ``None``/empty means all organs.
            run_manager: Unused; part of the LangChain tool interface.

        Returns:
            ``(output, metadata)``. On success, ``output`` has
            ``segmentation_image_path`` and ``metrics``. On any failure,
            ``output`` is ``{"error": ...}`` and ``metadata`` carries the traceback.
            Exceptions are not propagated.
        """
        try:
            organs, organ_indices = self._resolve_organs(organs)

            original_img = skimage.io.imread(image_path)
            logger.debug(
                "Loaded %s: shape=%s dtype=%s range=[%s, %s]",
                image_path, original_img.shape, original_img.dtype,
                original_img.min(), original_img.max(),
            )
            # Multi-channel input: keep only the first channel.
            if original_img.ndim > 2:
                original_img = original_img[:, :, 0]

            img = self._prepare_input(original_img, image_path)

            with torch.no_grad():
                pred = self.model(img)
            pred_probs = torch.sigmoid(pred)
            pred_masks = (pred_probs > self.mask_threshold).float()

            if logger.isEnabledFor(logging.DEBUG):
                index_to_name = {idx: name for name, idx in self.organ_map.items()}
                for organ_idx in organ_indices:
                    logger.debug(
                        "%s: mean_prob=%.3f max_prob=%.3f mask_pixels=%d",
                        index_to_name[organ_idx],
                        pred_probs[0, organ_idx].mean().item(),
                        pred_probs[0, organ_idx].max().item(),
                        int(pred_masks[0, organ_idx].sum().item()),
                    )

            viz_path = self._save_visualization(original_img, pred_masks, organ_indices)

            results = {}
            for idx, organ_name in zip(organ_indices, organs):
                organ_mask = pred_masks[0, idx] > 0
                if not organ_mask.any():
                    continue
                # Confidence is the mean probability over the organ's own predicted
                # pixels. An image-wide mean would mostly reflect organ size.
                confidence = float(pred_probs[0, idx][organ_mask].mean().cpu())
                metrics = self._compute_organ_metrics(
                    pred_masks[0, idx].cpu().numpy(), original_img, confidence
                )
                if metrics:
                    results[organ_name] = metrics

            output = {
                "segmentation_image_path": viz_path,
                "metrics": {organ: metrics.dict() for organ, metrics in results.items()},
            }

            metadata = {
                "image_path": image_path,
                "segmentation_image_path": viz_path,
                "original_size": original_img.shape,
                "model_size": tuple(img.shape[-2:]),
                "pixel_spacing_mm": self.pixel_spacing_mm,
                "requested_organs": organs,
                "processed_organs": list(results.keys()),
                "analysis_status": "completed",
            }

            return output, metadata

        except Exception as e:
            # Tool boundary: report the failure to the caller instead of raising.
            logger.exception("Chest X-ray segmentation failed for %s", image_path)
            error_output = {"error": str(e)}
            error_metadata = {
                "image_path": image_path,
                "analysis_status": "failed",
                "error_traceback": traceback.format_exc(),
            }
            return error_output, error_metadata

    async def _arun(
        self,
        image_path: str,
        organs: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Async version of _run (delegates to the synchronous implementation)."""
        return self._run(image_path, organs)