Spaces:
Paused
Paused
| from typing import Dict, List, Optional, Tuple, Type, Any | |
| from pathlib import Path | |
| import uuid | |
| import tempfile | |
| import numpy as np | |
| import torch | |
| import torchvision | |
| import torchxrayvision as xrv | |
| import matplotlib.pyplot as plt | |
| import skimage.io | |
| import skimage.measure | |
| import skimage.transform | |
| import traceback | |
| from pydantic import BaseModel, Field | |
| from langchain_core.callbacks import ( | |
| AsyncCallbackManagerForToolRun, | |
| CallbackManagerForToolRun, | |
| ) | |
| from langchain_core.tools import BaseTool | |
| from medrax.utils.utils import preprocess_medical_image | |
class ChestXRaySegmentationInput(BaseModel):
    """Input schema for the Chest X-ray Segmentation Tool."""

    # Filesystem path of the chest X-ray image (PNG/JPEG or exported DICOM frame).
    image_path: str = Field(..., description="Path to the chest X-ray image file to be segmented")
    # Optional subset of organ names; names must match the keys of
    # ChestXRaySegmentationTool.organ_map. None means "segment all organs".
    organs: Optional[List[str]] = Field(
        None,
        description="List of organs to segment. If None, all available organs will be segmented. "
        "Available organs: Left/Right Clavicle, Left/Right Scapula, Left/Right Lung, "
        "Left/Right Hilus Pulmonis, Heart, Aorta, Facies Diaphragmatica, "
        "Mediastinum, Weasand, Spine",
    )
class OrganMetrics(BaseModel):
    """Detailed metrics for a segmented organ.

    All pixel-based quantities are measured in the original image's coordinate
    space. Note that ``area_cm2`` is only approximate: it is derived from a
    fixed assumed pixel spacing rather than from real acquisition metadata.
    """

    # Basic metrics
    area_pixels: int = Field(..., description="Area in pixels")
    area_cm2: float = Field(..., description="Approximate area in cm²")
    centroid: Tuple[float, float] = Field(..., description="(y, x) coordinates of centroid")
    bbox: Tuple[int, int, int, int] = Field(..., description="Bounding box coordinates (min_y, min_x, max_y, max_x)")
    # Size metrics
    width: int = Field(..., description="Width of the organ in pixels")
    height: int = Field(..., description="Height of the organ in pixels")
    aspect_ratio: float = Field(..., description="Height/width ratio")
    # Position metrics
    relative_position: Dict[str, float] = Field(..., description="Position relative to image boundaries (0-1 scale)")
    # Analysis metrics
    mean_intensity: float = Field(..., description="Mean pixel intensity in the organ region")
    std_intensity: float = Field(..., description="Standard deviation of pixel intensity")
    confidence_score: float = Field(..., description="Model confidence score for this organ")
class ChestXRaySegmentationTool(BaseTool):
    """Tool for performing detailed segmentation analysis of chest X-ray images.

    Wraps the torchxrayvision ``chestx_det`` PSPNet baseline. Given an image
    path, it predicts per-organ probability maps, thresholds them into binary
    masks, writes an overlay visualization into ``temp_dir``, and computes
    per-organ geometric and intensity metrics (``OrganMetrics``) in the
    original image's coordinate space.
    """

    name: str = "chest_xray_segmentation"
    description: str = (
        "Segments chest X-ray images to specified anatomical structures. "
        "Available organs: Left/Right Clavicle (collar bones), Left/Right Scapula (shoulder blades), "
        "Left/Right Lung, Left/Right Hilus Pulmonis (lung roots), Heart, Aorta, "
        "Facies Diaphragmatica (diaphragm), Mediastinum (central cavity), Weasand (esophagus), "
        "and Spine. Returns segmentation visualization and comprehensive metrics. "
        "Let the user know the area is not accurate unless input has been DICOM."
    )
    args_schema: Type[BaseModel] = ChestXRaySegmentationInput
    # The fields below are declared for pydantic (BaseTool is a pydantic model)
    # and populated in __init__.
    model: Any = None
    device: Optional[str] = "cuda"
    transform: Any = None
    # Assumed pixel pitch (mm/pixel) used to convert areas to cm². A rough
    # default only — accurate spacing would have to come from DICOM metadata.
    pixel_spacing_mm: float = 0.2
    temp_dir: Path = Path("temp")
    # FIX: annotation was `Dict[str, int] = None`, which contradicts the None
    # default; made explicitly Optional.
    organ_map: Optional[Dict[str, int]] = None

    def __init__(self, device: Optional[str] = "cuda", temp_dir: Optional[Path] = Path("temp")):
        """Initialize the segmentation tool with model and temporary directory.

        Args:
            device: Torch device spec (e.g. "cuda", "cpu"). Falls back to
                "cuda" when falsy.
            temp_dir: Directory where overlay visualizations are written.
        """
        super().__init__()
        self.model = xrv.baseline_models.chestx_det.PSPNet()
        # FIX: previously a falsy `device` assigned the bare string "cuda"
        # while an explicit one was wrapped in torch.device; normalize both
        # paths to a torch.device instance.
        self.device = torch.device(device if device else "cuda")
        self.model = self.model.to(self.device)
        self.model.eval()
        # Center-crop to a square, then resize to the 512x512 input the model expects.
        self.transform = torchvision.transforms.Compose(
            [xrv.datasets.XRayCenterCrop(), xrv.datasets.XRayResizer(512)]
        )
        self.temp_dir = temp_dir if isinstance(temp_dir, Path) else Path(temp_dir)
        # FIX: create intermediate directories too, so a nested temp path
        # (e.g. "out/run1/temp") does not raise FileNotFoundError.
        self.temp_dir.mkdir(parents=True, exist_ok=True)
        # Map friendly names to model target indices
        self.organ_map = {
            "Left Clavicle": 0,
            "Right Clavicle": 1,
            "Left Scapula": 2,
            "Right Scapula": 3,
            "Left Lung": 4,
            "Right Lung": 5,
            "Left Hilus Pulmonis": 6,
            "Right Hilus Pulmonis": 7,
            "Heart": 8,
            "Aorta": 9,
            "Facies Diaphragmatica": 10,
            "Mediastinum": 11,
            "Weasand": 12,
            "Spine": 13,
        }

    def _align_mask_to_original(self, mask: np.ndarray, original_shape: Tuple[int, int]) -> np.ndarray:
        """
        Align a mask from the transformed (cropped/resized) space back to the full original image.

        Assumes that the transform does a center crop to a square of side
        = min(original height, width) and then resizes to (512, 512) — i.e. it
        mirrors the XRayCenterCrop + XRayResizer pipeline set up in __init__.

        Args:
            mask: Binary mask in model space (512x512).
            original_shape: (height, width) of the original image.

        Returns:
            A mask of shape ``original_shape`` with the resized mask pasted
            into the centered crop region and zeros elsewhere.
        """
        orig_h, orig_w = original_shape
        crop_size = min(orig_h, orig_w)
        crop_top = (orig_h - crop_size) // 2
        crop_left = (orig_w - crop_size) // 2
        # Resize mask (from 512x512) to the cropped region size.
        # order=0 (nearest) keeps the mask binary; anti-aliasing would blur it.
        resized_mask = skimage.transform.resize(
            mask, (crop_size, crop_size), order=0, preserve_range=True, anti_aliasing=False
        )
        full_mask = np.zeros(original_shape)
        full_mask[crop_top : crop_top + crop_size, crop_left : crop_left + crop_size] = resized_mask
        return full_mask

    def _compute_organ_metrics(
        self, mask: np.ndarray, original_img: np.ndarray, confidence: float
    ) -> Optional[OrganMetrics]:
        """Compute comprehensive metrics for a single organ mask.

        Args:
            mask: Binary organ mask (model space or original space).
            original_img: 2-D grayscale image used for intensity statistics.
            confidence: Mean sigmoid probability for this organ.

        Returns:
            An OrganMetrics instance, or None when the mask contains no region.
        """
        # Align mask to the original image coordinates if needed.
        if mask.shape != original_img.shape:
            mask = self._align_mask_to_original(mask, original_img.shape)
        props = skimage.measure.regionprops(mask.astype(int))
        if not props:
            return None
        props = props[0]
        # (mm / 10) converts to cm; squared for an area.
        area_cm2 = mask.sum() * (self.pixel_spacing_mm / 10) ** 2
        img_height, img_width = mask.shape
        cy, cx = props.centroid
        # Normalized (0-1) position of the centroid and its distance from image center.
        relative_pos = {
            "top": cy / img_height,
            "left": cx / img_width,
            "center_dist": np.sqrt(((cy / img_height - 0.5) ** 2 + (cx / img_width - 0.5) ** 2)),
        }
        organ_pixels = original_img[mask > 0]
        mean_intensity = organ_pixels.mean() if len(organ_pixels) > 0 else 0
        std_intensity = organ_pixels.std() if len(organ_pixels) > 0 else 0
        return OrganMetrics(
            area_pixels=int(mask.sum()),
            area_cm2=float(area_cm2),
            centroid=(float(cy), float(cx)),
            bbox=tuple(map(int, props.bbox)),
            width=int(props.bbox[3] - props.bbox[1]),
            height=int(props.bbox[2] - props.bbox[0]),
            # max(1, ...) guards against division by zero for degenerate 1-px-wide regions.
            aspect_ratio=float((props.bbox[2] - props.bbox[0]) / max(1, props.bbox[3] - props.bbox[1])),
            relative_position=relative_pos,
            mean_intensity=float(mean_intensity),
            std_intensity=float(std_intensity),
            confidence_score=float(confidence),
        )

    def _save_visualization(self, original_img: np.ndarray, pred_masks: torch.Tensor, organ_indices: List[int]) -> str:
        """Save visualization of original image with segmentation masks overlaid.

        Args:
            original_img: 2-D grayscale image in original coordinates.
            pred_masks: Thresholded model output of shape (1, n_targets, 512, 512).
            organ_indices: Model target indices to render.

        Returns:
            Path (as str) of the PNG written into ``temp_dir``.
        """
        # Create single panel with overlay
        fig, ax = plt.subplots(figsize=(10, 10))
        # FIX: the original code generated only min(len(organ_indices), 10)
        # colors and the zip below then silently dropped every organ past the
        # 10th (the model has 14 targets). Cycle the palette instead so every
        # requested organ gets rendered; the first 10 keep their old colors.
        base_palette = plt.cm.tab10(np.linspace(0, 1, min(len(organ_indices), 10)))
        colors = [base_palette[i % len(base_palette)] for i in range(len(organ_indices))]
        # Create combined mask for visualization
        combined_mask = np.zeros(original_img.shape)
        masks_found = 0
        legend_items = []
        for idx, (organ_idx, color) in enumerate(zip(organ_indices, colors)):
            mask = pred_masks[0, organ_idx].cpu().numpy()
            # Debug: print mask info
            print(f"Organ index {organ_idx}: mask sum = {mask.sum()}, mask shape = {mask.shape}")
            if mask.sum() > 0:
                masks_found += 1
                original_mask_sum = mask.sum()
                # Align the mask to the original image coordinates
                if mask.shape != original_img.shape:
                    aligned_mask = self._align_mask_to_original(mask, original_img.shape)
                    print(f"Aligned mask shape: {aligned_mask.shape}, sum: {aligned_mask.sum()} (was {original_mask_sum})")
                    # If alignment lost too much of the mask, use unaligned
                    if aligned_mask.sum() < original_mask_sum * 0.1:
                        print(f"Warning: Alignment lost {(1 - aligned_mask.sum()/original_mask_sum)*100:.1f}% of mask, using unaligned")
                        # Resize to original without alignment
                        aligned_mask = skimage.transform.resize(
                            mask, original_img.shape, order=0, preserve_range=True, anti_aliasing=False
                        )
                    mask = aligned_mask
                # (removed dead `else: mask = mask` no-op branch)
                # Add to combined mask with organ index (idx+1 so 0 stays background).
                combined_mask[mask > 0] = idx + 1
                # Add legend entry (reverse-lookup the friendly name for this index).
                organ_name = list(self.organ_map.keys())[list(self.organ_map.values()).index(organ_idx)]
                legend_items.append((organ_name, color))
        print(f"Total masks found and rendered: {masks_found}")
        # Display original image
        ax.imshow(original_img, cmap="gray")
        # Overlay masks with contours
        if masks_found > 0:
            from matplotlib.patches import Patch

            for idx, (organ_name, color) in enumerate(legend_items):
                mask_region = (combined_mask == idx + 1)
                # Create colored overlay (RGBA, 50% alpha inside the organ).
                overlay = np.zeros((*original_img.shape, 4))
                overlay[mask_region] = [color[0], color[1], color[2], 0.5]
                ax.imshow(overlay)
                # Draw contours for clear boundaries
                contours = skimage.measure.find_contours(mask_region.astype(float), 0.5)
                for contour in contours:
                    ax.plot(contour[:, 1], contour[:, 0], color=color, linewidth=3, alpha=0.9)
            # Add legend
            patches = [Patch(facecolor=c, edgecolor=c, label=n, alpha=0.7) for n, c in legend_items]
            ax.legend(handles=patches, loc="upper right", fontsize=10, framealpha=0.9)
            ax.set_title("Segmentation Overlay", fontsize=14, color='white', pad=15)
        else:
            ax.set_title("No Masks Detected", fontsize=14, color='red', pad=15)
        ax.axis("off")
        fig.patch.set_facecolor('black')
        save_path = self.temp_dir / f"segmentation_{uuid.uuid4().hex[:8]}.png"
        plt.savefig(save_path, bbox_inches="tight", dpi=150, facecolor='black')
        plt.close(fig)
        return str(save_path)

    def _run(
        self,
        image_path: str,
        organs: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Run segmentation analysis for specified organs.

        Args:
            image_path: Path to the chest X-ray image to segment.
            organs: Friendly organ names (keys of ``organ_map``); None = all.
            run_manager: Optional LangChain callback manager (unused here).

        Returns:
            (output, metadata) tuple. On success, output holds the
            visualization path and per-organ metric dicts; on failure it holds
            an "error" key and metadata carries the traceback — exceptions are
            deliberately converted into an error payload at this tool boundary.
        """
        try:
            # Validate and get organ indices
            if organs:
                organs = [o.strip() for o in organs]
                invalid_organs = [o for o in organs if o not in self.organ_map]
                if invalid_organs:
                    raise ValueError(f"Invalid organs specified: {invalid_organs}")
                organ_indices = [self.organ_map[o] for o in organs]
            else:
                organ_indices = list(self.organ_map.values())
                organs = list(self.organ_map.keys())
            # Load and process image
            original_img = skimage.io.imread(image_path)
            print(f"\n=== Image Loading Debug ===")
            print(f"Image path: {image_path}")
            print(f"Original shape: {original_img.shape}, dtype: {original_img.dtype}")
            print(f"Original range: [{original_img.min()}, {original_img.max()}]")
            if len(original_img.shape) > 2:
                # Keep only the first channel of color/RGBA images.
                original_img = original_img[:, :, 0]
                print(f"After channel extraction: {original_img.shape}")
            # TorchXRayVision models expect images in the range [-1024, 1024] (Hounsfield units)
            # NOT normalized to [0, 1]! We need to scale 8-bit images to this range.
            # IMPORTANT: PNG/JPEG X-rays are typically INVERTED compared to DICOM
            # (lungs appear bright instead of dark), so we need to invert them first
            # EXCEPTION: Processed DICOM files (saved as PNG) should NOT be inverted
            is_processed_dicom = "processed_dicom" in image_path
            if original_img.dtype == np.uint8 or original_img.max() <= 255:
                if is_processed_dicom:
                    # Processed DICOM - don't invert, just scale to HU range
                    img = (original_img.astype(np.float32) / 255.0) * 1624 - 1024
                    print(f"Processed DICOM - converted to HU range without inversion: [{img.min():.1f}, {img.max():.1f}]")
                else:
                    # Regular PNG/JPEG - invert first, then scale
                    inverted = 255 - original_img.astype(np.float32)
                    img = (inverted / 255.0) * 1624 - 1024
                    print(f"PNG/JPEG - inverted and converted to HU range: [{img.min():.1f}, {img.max():.1f}]")
            else:
                # Assume already in HU or similar range (raw DICOM)
                img = original_img.astype(np.float32)
                print(f"Kept original range (raw DICOM): [{img.min():.1f}, {img.max():.1f}]")
            # Add channel dim: XRayCenterCrop/XRayResizer expect (1, H, W).
            img = img[None, ...]
            print(f"After adding batch dim: {img.shape}")
            img = self.transform(img)
            print(f"After transform: {img.shape}")
            img = torch.from_numpy(img)
            img = img.to(self.device)
            print(f"Final tensor: shape={img.shape}, dtype={img.dtype}, device={img.device}")
            print(f"Tensor stats: mean={img.mean():.3f}, std={img.std():.3f}, min={img.min():.3f}, max={img.max():.3f}")
            # Generate predictions
            with torch.no_grad():
                pred = self.model(img)
            print(f"\nModel output shape: {pred.shape}")
            print(f"Raw predictions: min={pred.min().item():.3f}, max={pred.max().item():.3f}, mean={pred.mean().item():.3f}")
            pred_probs = torch.sigmoid(pred)
            print(f"After sigmoid: min={pred_probs.min().item():.3f}, max={pred_probs.max().item():.3f}, mean={pred_probs.mean().item():.3f}")
            # Print probabilities for debugging
            print(f"\n=== Segmentation Probabilities ===")
            for organ_idx in organ_indices:
                organ_name = list(self.organ_map.keys())[list(self.organ_map.values()).index(organ_idx)]
                prob = pred_probs[0, organ_idx].mean().item()
                max_prob = pred_probs[0, organ_idx].max().item()
                print(f"{organ_name}: mean={prob:.3f}, max={max_prob:.3f}")
            print(f"==================================\n")
            # Use lower threshold (0.3) to capture more masks, especially for non-DICOM images
            # The model tends to be less confident on non-DICOM images
            pred_masks = (pred_probs > 0.3).float()
            print(f"Masks after 0.3 threshold: {[f'{(pred_masks[0,i].sum().item())} pixels' for i in organ_indices]}")
            # Save visualization
            viz_path = self._save_visualization(original_img, pred_masks, organ_indices)
            # Compute metrics for selected organs
            results = {}
            for idx, organ_name in zip(organ_indices, organs):
                mask = pred_masks[0, idx].cpu().numpy()
                if mask.sum() > 0:
                    metrics = self._compute_organ_metrics(mask, original_img, float(pred_probs[0, idx].mean().cpu()))
                    if metrics:
                        results[organ_name] = metrics
            output = {
                "segmentation_image_path": viz_path,
                "metrics": {organ: metrics.dict() for organ, metrics in results.items()},
            }
            metadata = {
                "image_path": image_path,
                "segmentation_image_path": viz_path,
                "original_size": original_img.shape,
                "model_size": tuple(img.shape[-2:]),
                "pixel_spacing_mm": self.pixel_spacing_mm,
                "requested_organs": organs,
                "processed_organs": list(results.keys()),
                "analysis_status": "completed",
            }
            return output, metadata
        except Exception as e:
            # Tool boundary: report failure as data instead of propagating,
            # so the agent loop can surface the error to the user.
            error_output = {"error": str(e)}
            error_metadata = {
                "image_path": image_path,
                "analysis_status": "failed",
                "error_traceback": traceback.format_exc(),
            }
            return error_output, error_metadata

    async def _arun(
        self,
        image_path: str,
        organs: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, Any], Dict]:
        """Async version of _run.

        NOTE(review): delegates to the synchronous _run, so inference still
        blocks the event loop.
        """
        return self._run(image_path, organs)