import os
from typing import ClassVar, Dict, List, Optional, Tuple, Type

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from pydantic import BaseModel, Field
from timm.models.swin_transformer import SwinTransformer

from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool


class OmniSwinTransformer(SwinTransformer):
    """Swin Transformer with multiple classification heads and an optional projector.

    One head is created per entry of ``num_classes_list``; an entry <= 0 yields an
    ``nn.Identity`` head that passes the (projected) features through unchanged.
    """

    def __init__(self, num_classes_list, projector_features=None, use_mlp=False, *args, **kwargs):
        """Build the backbone, optional projector, and per-dataset heads.

        Args:
            num_classes_list: Number of output classes for each head (required).
            projector_features: If truthy, project backbone features to this size
                and make it the new ``num_features``.
            use_mlp: Use a two-layer MLP projector instead of a single Linear.
            *args: Forwarded to ``SwinTransformer.__init__``.
            **kwargs: Forwarded to ``SwinTransformer.__init__``.

        Raises:
            ValueError: If ``num_classes_list`` is None.
        """
        super().__init__(*args, **kwargs)
        # A bare `assert` is stripped under `python -O`; validate explicitly.
        if num_classes_list is None:
            raise ValueError("num_classes_list must be provided")

        self.projector = None
        if projector_features:
            encoder_features = self.num_features
            self.num_features = projector_features
            if use_mlp:
                self.projector = nn.Sequential(
                    nn.Linear(encoder_features, self.num_features),
                    nn.ReLU(inplace=True),
                    nn.Linear(self.num_features, self.num_features),
                )
            else:
                self.projector = nn.Linear(encoder_features, self.num_features)

        self.omni_heads = nn.ModuleList(
            nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
            for num_classes in num_classes_list
        )

    def forward(self, x, head_n=None):
        """Run the backbone and either a single selected head or all heads.

        Args:
            x: Input image batch.
            head_n: Optional head index.

        Returns:
            ``(features, logits)`` for head ``head_n`` when given, otherwise a
            list containing the logits of every head.
        """
        x = self.forward_features(x)
        if self.projector is not None:
            x = self.projector(x)
        if head_n is not None:
            return x, self.omni_heads[head_n](x)
        return [head(x) for head in self.omni_heads]

    def generate_embeddings(self, x, after_proj=True):
        """Return backbone features, optionally passed through the projector."""
        x = self.forward_features(x)
        if after_proj and self.projector is not None:
            x = self.projector(x)
        return x


class ArcPlusInput(BaseModel):
    """Input for ArcPlus chest X-ray analysis tool.

    Only supports JPG or PNG images.
    """

    image_path: str = Field(..., description="Path to the radiology image file, only supports JPG or PNG images")


class ArcPlusClassifierTool(BaseTool):
    """Tool that classifies chest X-ray images using the ArcPlus OmniSwinTransformer model.

    This tool uses a pre-trained OmniSwinTransformer model (ArcPlus) to analyze chest X-ray
    images and predict the likelihood of various pathologies across multiple medical datasets.
    The model employs a Swin Transformer architecture with multiple classification heads,
    each specialized for different medical datasets and conditions.

    The ArcPlus model is trained on 6 different medical datasets:
    - MIMIC-CXR: 14 pathologies including common chest conditions
    - CheXpert: 14 pathologies with standardized labeling
    - NIH ChestX-ray14: 14 pathologies from large-scale dataset
    - RSNA: 3 classes for pneumonia detection
    - VinDr-CXR: 6 categories including tuberculosis and lung tumors
    - Shenzhen: 1 class for tuberculosis detection

    Key Features:
    - Multi-head architecture with 6 specialized classification heads
    - 768x768 input resolution for high-detail analysis
    - Projector layer with 1376 features for enhanced representation
    - Sigmoid activation for multi-label classification
    - Covers 52+ distinct pathology categories across datasets

    The model outputs probabilities (0 to 1) for each condition, with higher values
    indicating higher likelihood of the pathology being present in the image.
    """

    name: str = "arcplus_classifier"
    description: str = (
        "Advanced chest X-ray classification tool using ArcPlus OmniSwinTransformer with multi-dataset training. "
        "Analyzes chest X-ray images and provides probability predictions for 52+ pathologies across 6 medical datasets. "
        "Input: Path to chest X-ray image file (JPG/PNG). "
        "Output: Dictionary mapping pathology names to probabilities (0-1). "
        "Features: Multi-head architecture, 768px resolution, projector layer, specialized for medical imaging. "
        "Pathologies include: Atelectasis, Cardiomegaly, Consolidation, Edema, Enlarged Cardiomediastinum, "
        "Fracture, Lung Lesion, Lung Opacity, Pleural Effusion, Pneumonia, Pneumothorax, Mass, Nodule, "
        "Emphysema, Fibrosis, PE, Lung Tumor, Tuberculosis, and many more across MIMIC, CheXpert, NIH, "
        "RSNA, VinDr, and Shenzhen datasets. Higher probabilities indicate higher likelihood of condition presence."
    )
    args_schema: Type[BaseModel] = ArcPlusInput
    model: OmniSwinTransformer = None
    device: Optional[str] = "cuda"
    normalize: transforms.Normalize = None
    disease_list: List[str] = None
    num_classes_list: List[int] = None

    # Disease mappings from the analysis; concatenation order must match the
    # head order of the model ([MIMIC, CheXpert, NIH, RSNA, VinDr, Shenzhen]).
    mimic_diseases: ClassVar[List[str]] = [
        "Atelectasis",
        "Cardiomegaly",
        "Consolidation",
        "Edema",
        "Enlarged Cardiomediastinum",
        "Fracture",
        "Lung Lesion",
        "Lung Opacity",
        "No Finding",
        "Pleural Effusion",
        "Pleural Other",
        "Pneumonia",
        "Pneumothorax",
        "Support Devices",
    ]
    chexpert_diseases: ClassVar[List[str]] = [
        "No Finding",
        "Enlarged Cardiomediastinum",
        "Cardiomegaly",
        "Lung Opacity",
        "Lung Lesion",
        "Edema",
        "Consolidation",
        "Pneumonia",
        "Atelectasis",
        "Pneumothorax",
        "Pleural Effusion",
        "Pleural Other",
        "Fracture",
        "Support Devices",
    ]
    nih14_diseases: ClassVar[List[str]] = [
        "Atelectasis",
        "Cardiomegaly",
        "Effusion",
        "Infiltration",
        "Mass",
        "Nodule",
        "Pneumonia",
        "Pneumothorax",
        "Consolidation",
        "Edema",
        "Emphysema",
        "Fibrosis",
        "Pleural_Thickening",
        "Hernia",
    ]
    rsna_diseases: ClassVar[List[str]] = ["No Lung Opacity/Not Normal", "Normal", "Lung Opacity"]
    vindr_diseases: ClassVar[List[str]] = [
        "PE",
        "Lung tumor",
        "Pneumonia",
        "Tuberculosis",
        "Other diseases",
        "No finding",
    ]
    shenzhen_diseases: ClassVar[List[str]] = ["TB"]

    def __init__(self, cache_dir: str = None, device: Optional[str] = "cuda"):
        """Initialize the ArcPlus Classifier Tool.

        Args:
            cache_dir (str, optional): Directory containing the pre-trained ArcPlus model
                checkpoint. The tool will automatically look for
                'Ark6_swinLarge768_ep50.pth.tar' in this directory. If None, the model is
                initialized with random weights (not recommended for inference).
                Default: None.
            device (str, optional): Device to run the model on ('cuda' for GPU, 'cpu'
                for CPU). GPU is recommended for better performance. Default: "cuda".

        Model Architecture Details:
            - OmniSwinTransformer with 6 classification heads
            - Input resolution: 768x768 pixels
            - Projector features: 1376 dimensions
            - Multi-head configuration: [14, 14, 14, 3, 6, 1] classes per head
            - Total pathologies: 52+ across 6 medical datasets
            - Preprocessing: ImageNet normalization
              (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        Raises:
            FileNotFoundError: If cache_dir is provided but the model file doesn't exist.
            RuntimeError: If model loading fails or the device is unavailable.
        """
        super().__init__()

        # Combined disease list; order mirrors the model's head order.
        self.disease_list = (
            self.mimic_diseases
            + self.chexpert_diseases
            + self.nih14_diseases
            + self.rsna_diseases
            + self.vindr_diseases
            + self.shenzhen_diseases
        )

        # Multi-head configuration: [MIMIC, CheXpert, NIH, RSNA, VinDr, Shenzhen]
        self.num_classes_list = [14, 14, 14, 3, 6, 1]

        # OmniSwinTransformer with the ArcPlus (Swin-Large, 768px) architecture.
        self.model = OmniSwinTransformer(
            num_classes_list=self.num_classes_list,
            projector_features=1376,  # Enhanced feature representation
            use_mlp=False,  # Linear projector (not MLP)
            img_size=768,  # High-resolution input
            patch_size=4,
            window_size=12,
            embed_dim=192,
            depths=(2, 2, 18, 2),  # Swin-Large configuration
            num_heads=(6, 12, 24, 48),
        )

        # Load pre-trained weights if a cache directory was provided.
        if cache_dir:
            model_path = os.path.join(cache_dir, "Ark6_swinLarge768_ep50.pth.tar")
            self._load_checkpoint(model_path)

        self.model.eval()
        # Fix: previously a falsy `device` left self.device as the *string* "cuda";
        # always store a torch.device for consistency.
        self.device = torch.device(device if device else "cuda")
        self.model = self.model.to(self.device)

        # ImageNet normalization parameters for optimal performance.
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def _load_checkpoint(self, model_path: str) -> None:
        """Load the ArcPlus model checkpoint.

        Args:
            model_path (str): Path to the model checkpoint file.
        """
        # weights_only=False is required for PyTorch 2.6+ to load this legacy checkpoint.
        checkpoint = torch.load(model_path, map_location=torch.device("cpu"), weights_only=False)
        state_dict = checkpoint["teacher"]  # Weights live under the 'teacher' key.

        # Strip a DataParallel/DDP "module." prefix where present. Unlike the old
        # logic, keys *without* the prefix are kept (previously they were dropped,
        # silently losing weights from checkpoints with mixed key styles), and only
        # a leading "module." is removed (replace() would mangle mid-name matches).
        if any(k.startswith("module.") for k in state_dict):
            state_dict = {
                (k[len("module."):] if k.startswith("module.") else k): v
                for k, v in state_dict.items()
            }

        # strict=False tolerates head/projector naming drift; surface what didn't match.
        load_result = self.model.load_state_dict(state_dict, strict=False)
        if load_result.missing_keys or load_result.unexpected_keys:
            print(
                f"Checkpoint partially loaded. Missing keys: {load_result.missing_keys}; "
                f"unexpected keys: {load_result.unexpected_keys}"
            )

    def _process_image(self, image_path: str) -> torch.Tensor:
        """Process the input chest X-ray image for model inference.

        Loads the image, applies necessary transformations, and prepares it as a
        torch.Tensor for model input.

        Args:
            image_path (str): The file path to the chest X-ray image.

        Returns:
            torch.Tensor: A processed image tensor of shape (1, 3, 768, 768) on
            ``self.device``, ready for model inference.

        Raises:
            ValueError: If the image cannot be properly loaded or processed
                (includes missing files, wrapped from the underlying error).
        """
        try:
            image = Image.open(image_path)

            # Handle 16-bit grayscale images (common in medical imaging) by
            # rescaling to the 0-255 uint8 range.
            if image.mode == "I;16":
                img_array = np.array(image)
                lo, hi = int(img_array.min()), int(img_array.max())
                if hi > lo:
                    img_normalized = ((img_array - lo) / (hi - lo) * 255).astype(np.uint8)
                else:
                    # Fix: a constant image previously caused a division by zero.
                    img_normalized = np.zeros_like(img_array, dtype=np.uint8)
                image = Image.fromarray(img_normalized, mode='L')

            image = image.convert("RGB").resize((768, 768))

            # Scale to [0, 1] before applying ImageNet normalization.
            image_array = np.array(image) / 255.0

            image_tensor = torch.from_numpy(image_array).float()
            image_tensor = image_tensor.permute(2, 0, 1)  # HWC to CHW
            image_tensor = self.normalize(image_tensor)

            # Add batch dimension and move to device.
            image_tensor = image_tensor.unsqueeze(0).to(self.device)
            return image_tensor

        except Exception as e:
            raise ValueError(f"Error processing image {image_path}: {str(e)}")

    def _run(
        self,
        image_path: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, float], Dict]:
        """Classify the chest X-ray image using ArcPlus SwinTransformer.

        Args:
            image_path (str): The path to the chest X-ray image file.
            run_manager (Optional[CallbackManagerForToolRun]): The callback manager
                for the tool run.

        Returns:
            Tuple[Dict[str, float], Dict]: A tuple containing the classification
            results (pathologies and their probabilities from 0 to 1) and metadata.
            On failure, the first element is ``{"error": ...}`` and the metadata
            reports the failure — errors are returned, not raised.
        """
        try:
            image_tensor = self._process_image(image_path)

            # Run inference; the model returns one logit tensor per head.
            with torch.no_grad():
                pre_logits = self.model(image_tensor)
                # Sigmoid per head: multi-label classification, not softmax.
                preds = [torch.sigmoid(out) for out in pre_logits]
                preds = torch.cat(preds, dim=1)

            predictions = preds.cpu().numpy().flatten()

            # Sanity check: prediction count should match the combined disease list.
            if len(predictions) != len(self.disease_list):
                print(f"Warning: Expected {len(self.disease_list)} predictions, got {len(predictions)}")
                # Pad or truncate so zip() below covers every disease name.
                if len(predictions) < len(self.disease_list):
                    predictions = np.pad(predictions, (0, len(self.disease_list) - len(predictions)))
                else:
                    predictions = predictions[: len(self.disease_list)]

            # Map disease names to native Python floats for proper serialization.
            output = dict(zip(self.disease_list, [float(pred) for pred in predictions]))

            metadata = {
                "image_path": image_path,
                "model": "ArcPlus OmniSwinTransformer",
                "analysis_status": "completed",
                "num_predictions": len(predictions),
                "num_heads": len(self.num_classes_list),
                "projector_features": 1376,
                "note": "Probabilities range from 0 to 1, with higher values indicating higher likelihood of the condition.",
            }

            return output, metadata

        except Exception as e:
            # Best-effort contract: report the failure in-band rather than raising.
            return {"error": str(e)}, {
                "image_path": image_path,
                "analysis_status": "failed",
                "error_details": str(e),
            }

    async def _arun(
        self,
        image_path: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, float], Dict]:
        """Asynchronously classify the chest X-ray image using ArcPlus SwinTransformer.

        This method currently calls the synchronous version, as the model inference
        is not inherently asynchronous. For true asynchronous behavior, consider
        running inference in a separate thread or process.

        Args:
            image_path (str): The path to the chest X-ray image file.
            run_manager (Optional[AsyncCallbackManagerForToolRun]): The async
                callback manager for the tool run.

        Returns:
            Tuple[Dict[str, float], Dict]: Same contract as :meth:`_run`.
        """
        return self._run(image_path)