# (extraction artifact removed: "Spaces: / Paused / Paused" — not part of the module)
| import os | |
| from typing import ClassVar, Dict, List, Optional, Tuple, Type | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| from pydantic import BaseModel, Field | |
| from timm.models.swin_transformer import SwinTransformer | |
| from langchain_core.callbacks import ( | |
| AsyncCallbackManagerForToolRun, | |
| CallbackManagerForToolRun, | |
| ) | |
| from langchain_core.tools import BaseTool | |
class OmniSwinTransformer(SwinTransformer):
    """Swin Transformer backbone with multiple classification heads and an optional projector.

    Each entry in ``num_classes_list`` creates one head: an ``nn.Linear`` when the
    entry is > 0, otherwise an ``nn.Identity`` that passes features through.

    Args:
        num_classes_list: Number of output classes for each head. Required.
        projector_features: If set, project backbone features to this dimension
            before the heads (``self.num_features`` is updated accordingly).
        use_mlp: When projecting, use a two-layer MLP instead of a single linear layer.
        *args, **kwargs: Forwarded to :class:`SwinTransformer`.
    """

    def __init__(self, num_classes_list, projector_features=None, use_mlp=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if num_classes_list is None:
            # Was an `assert`, which is silently stripped under `python -O`.
            raise ValueError("num_classes_list must be provided")
        self.projector = None
        if projector_features:
            encoder_features = self.num_features
            self.num_features = projector_features
            if use_mlp:
                self.projector = nn.Sequential(
                    nn.Linear(encoder_features, self.num_features),
                    nn.ReLU(inplace=True),
                    nn.Linear(self.num_features, self.num_features),
                )
            else:
                self.projector = nn.Linear(encoder_features, self.num_features)
        # Build the ModuleList directly instead of appending to a plain list first.
        self.omni_heads = nn.ModuleList(
            nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
            for num_classes in num_classes_list
        )

    def forward(self, x, head_n=None):
        """Run the backbone and classification heads.

        Args:
            x: Input image batch.
            head_n: Optional head index. When given, returns ``(features, head_output)``
                for that head only; otherwise returns a list with every head's output.
        """
        x = self.forward_features(x)
        # `is not None` avoids nn.Module truthiness pitfalls: an empty
        # nn.Sequential is falsy via __len__, which would silently skip projection.
        if self.projector is not None:
            x = self.projector(x)
        if head_n is not None:
            return x, self.omni_heads[head_n](x)
        return [head(x) for head in self.omni_heads]

    def generate_embeddings(self, x, after_proj=True):
        """Return backbone features, optionally passed through the projector."""
        x = self.forward_features(x)
        if after_proj and self.projector is not None:
            x = self.projector(x)
        return x
class ArcPlusInput(BaseModel):
    """Input for ArcPlus chest X-ray analysis tool. Only supports JPG or PNG images."""

    # Required string field; pydantic validates presence via the `...` sentinel.
    # The description is surfaced to the LLM agent as part of the tool schema.
    image_path: str = Field(..., description="Path to the radiology image file, only supports JPG or PNG images")
class ArcPlusClassifierTool(BaseTool):
    """Tool that classifies chest X-ray images using the ArcPlus OmniSwinTransformer model.

    This tool uses a pre-trained OmniSwinTransformer model (ArcPlus) to analyze chest X-ray images
    and predict the likelihood of various pathologies across multiple medical datasets. The model
    employs a Swin Transformer architecture with multiple classification heads, each specialized
    for different medical datasets and conditions.

    The ArcPlus model is trained on 6 different medical datasets:
    - MIMIC-CXR: 14 pathologies including common chest conditions
    - CheXpert: 14 pathologies with standardized labeling
    - NIH ChestX-ray14: 14 pathologies from large-scale dataset
    - RSNA: 3 classes for pneumonia detection
    - VinDr-CXR: 6 categories including tuberculosis and lung tumors
    - Shenzhen: 1 class for tuberculosis detection

    Key Features:
    - Multi-head architecture with 6 specialized classification heads
    - 768x768 input resolution for high-detail analysis
    - Projector layer with 1376 features for enhanced representation
    - Sigmoid activation for multi-label classification
    - Covers 52+ distinct pathology categories across datasets

    The model outputs probabilities (0 to 1) for each condition, with higher values
    indicating higher likelihood of the pathology being present in the image.
    """

    name: str = "arcplus_classifier"
    description: str = (
        "Advanced chest X-ray classification tool using ArcPlus OmniSwinTransformer with multi-dataset training. "
        "Analyzes chest X-ray images and provides probability predictions for 52+ pathologies across 6 medical datasets. "
        "Input: Path to chest X-ray image file (JPG/PNG). "
        "Output: Dictionary mapping pathology names to probabilities (0-1). "
        "Features: Multi-head architecture, 768px resolution, projector layer, specialized for medical imaging. "
        "Pathologies include: Atelectasis, Cardiomegaly, Consolidation, Edema, Enlarged Cardiomediastinum, "
        "Fracture, Lung Lesion, Lung Opacity, Pleural Effusion, Pneumonia, Pneumothorax, Mass, Nodule, "
        "Emphysema, Fibrosis, PE, Lung Tumor, Tuberculosis, and many more across MIMIC, CheXpert, NIH, "
        "RSNA, VinDr, and Shenzhen datasets. Higher probabilities indicate higher likelihood of condition presence."
    )
    args_schema: Type[BaseModel] = ArcPlusInput
    # Populated in __init__; declared here so pydantic accepts them as fields.
    model: OmniSwinTransformer = None
    device: Optional[str] = "cuda"
    normalize: transforms.Normalize = None
    disease_list: List[str] = None
    num_classes_list: List[int] = None

    # Per-dataset label lists; order matters — it must match the head order
    # [MIMIC, CheXpert, NIH, RSNA, VinDr, Shenzhen] used by the model.
    mimic_diseases: ClassVar[List[str]] = [
        "Atelectasis",
        "Cardiomegaly",
        "Consolidation",
        "Edema",
        "Enlarged Cardiomediastinum",
        "Fracture",
        "Lung Lesion",
        "Lung Opacity",
        "No Finding",
        "Pleural Effusion",
        "Pleural Other",
        "Pneumonia",
        "Pneumothorax",
        "Support Devices",
    ]
    chexpert_diseases: ClassVar[List[str]] = [
        "No Finding",
        "Enlarged Cardiomediastinum",
        "Cardiomegaly",
        "Lung Opacity",
        "Lung Lesion",
        "Edema",
        "Consolidation",
        "Pneumonia",
        "Atelectasis",
        "Pneumothorax",
        "Pleural Effusion",
        "Pleural Other",
        "Fracture",
        "Support Devices",
    ]
    nih14_diseases: ClassVar[List[str]] = [
        "Atelectasis",
        "Cardiomegaly",
        "Effusion",
        "Infiltration",
        "Mass",
        "Nodule",
        "Pneumonia",
        "Pneumothorax",
        "Consolidation",
        "Edema",
        "Emphysema",
        "Fibrosis",
        "Pleural_Thickening",
        "Hernia",
    ]
    rsna_diseases: ClassVar[List[str]] = ["No Lung Opacity/Not Normal", "Normal", "Lung Opacity"]
    vindr_diseases: ClassVar[List[str]] = [
        "PE",
        "Lung tumor",
        "Pneumonia",
        "Tuberculosis",
        "Other diseases",
        "No finding",
    ]
    shenzhen_diseases: ClassVar[List[str]] = ["TB"]

    def __init__(self, cache_dir: str = None, device: Optional[str] = "cuda"):
        """Initialize the ArcPlus Classifier Tool.

        Args:
            cache_dir (str, optional): Directory containing the pre-trained ArcPlus model checkpoint.
                The tool will automatically look for 'Ark6_swinLarge768_ep50.pth.tar' in this directory.
                If None, model will be initialized with random weights (not recommended for inference).
                Default: None.
            device (str, optional): Device to run the model on ('cuda' for GPU, 'cpu' for CPU).
                GPU is recommended for better performance. Default: "cuda".

        Model Architecture Details:
            - OmniSwinTransformer with 6 classification heads
            - Input resolution: 768x768 pixels
            - Projector features: 1376 dimensions
            - Multi-head configuration: [14, 14, 14, 3, 6, 1] classes per head
            - Total pathologies: 52+ across 6 medical datasets
            - Preprocessing: ImageNet normalization (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        Raises:
            FileNotFoundError: If cache_dir is provided but model file doesn't exist.
            RuntimeError: If model loading fails or device is unavailable.
        """
        super().__init__()
        # Combined disease list; concatenation order must mirror the head order.
        self.disease_list = (
            self.mimic_diseases
            + self.chexpert_diseases
            + self.nih14_diseases
            + self.rsna_diseases
            + self.vindr_diseases
            + self.shenzhen_diseases
        )
        # Multi-head configuration: [MIMIC, CheXpert, NIH, RSNA, VinDr, Shenzhen]
        self.num_classes_list = [14, 14, 14, 3, 6, 1]
        # Initialize the OmniSwinTransformer model with the ArcPlus architecture.
        self.model = OmniSwinTransformer(
            num_classes_list=self.num_classes_list,
            projector_features=1376,  # Enhanced feature representation
            use_mlp=False,  # Linear projector (not MLP)
            img_size=768,  # High-resolution input
            patch_size=4,
            window_size=12,
            embed_dim=192,
            depths=(2, 2, 18, 2),  # Swin-Large configuration
            num_heads=(6, 12, 24, 48),
        )
        # Load pre-trained weights if provided.
        if cache_dir:
            model_path = os.path.join(cache_dir, "Ark6_swinLarge768_ep50.pth.tar")
            self._load_checkpoint(model_path)
        self.model.eval()
        # Always hold a torch.device; the old code left a plain "cuda" string
        # when `device` was falsy, making self.device's type inconsistent.
        self.device = torch.device(device or "cuda")
        self.model = self.model.to(self.device)
        # ImageNet normalization parameters for optimal performance.
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def _load_checkpoint(self, model_path: str) -> None:
        """Load the ArcPlus model checkpoint into ``self.model``.

        Args:
            model_path (str): Path to the model checkpoint file.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist (from torch.load).
            KeyError: If the checkpoint has no 'teacher' state dict.
        """
        # weights_only=False is required for this legacy checkpoint format on
        # PyTorch 2.6+. NOTE(review): torch.load without weights_only can execute
        # arbitrary code — only load checkpoints from trusted sources.
        checkpoint = torch.load(model_path, map_location=torch.device("cpu"), weights_only=False)
        state_dict = checkpoint["teacher"]  # weights live under the 'teacher' key
        # Strip only a *leading* "module." prefix (added by DataParallel/DDP).
        # The previous logic dropped every un-prefixed key whenever any key
        # contained "module.", and replaced the substring anywhere in the name.
        state_dict = {
            (k[len("module."):] if k.startswith("module.") else k): v
            for k, v in state_dict.items()
        }
        # strict=False: tolerate heads/keys that this configuration does not use.
        self.model.load_state_dict(state_dict, strict=False)

    def _process_image(self, image_path: str) -> torch.Tensor:
        """Process the input chest X-ray image for model inference.

        Loads the image, applies resizing and ImageNet normalization, and returns
        a batched tensor on ``self.device``.

        Args:
            image_path (str): The file path to the chest X-ray image.

        Returns:
            torch.Tensor: A (1, 3, 768, 768) float tensor ready for model inference.

        Raises:
            ValueError: If the image cannot be properly loaded or processed
                (wraps the underlying exception, including FileNotFoundError).
        """
        try:
            image = Image.open(image_path)
            # Properly handle 16-bit grayscale images (common in medical imaging):
            # rescale the full dynamic range to 8-bit before RGB conversion.
            if image.mode == "I;16":
                img_array = np.array(image)
                value_range = img_array.max() - img_array.min()
                if value_range == 0:
                    # Constant image — avoid division by zero in the rescale.
                    img_normalized = np.zeros_like(img_array, dtype=np.uint8)
                else:
                    img_normalized = ((img_array - img_array.min()) / value_range * 255).astype(np.uint8)
                image = Image.fromarray(img_normalized, mode='L')
            image = image.convert("RGB").resize((768, 768))
            # Convert to [0, 1] floats, HWC -> CHW, then ImageNet-normalize.
            image_array = np.array(image) / 255.0
            image_tensor = torch.from_numpy(image_array).float()
            image_tensor = image_tensor.permute(2, 0, 1)  # HWC to CHW
            image_tensor = self.normalize(image_tensor)
            # Add batch dimension and move to device.
            return image_tensor.unsqueeze(0).to(self.device)
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise ValueError(f"Error processing image {image_path}: {str(e)}") from e

    def _run(
        self,
        image_path: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, float], Dict]:
        """Classify the chest X-ray image using ArcPlus SwinTransformer.

        Args:
            image_path (str): The path to the chest X-ray image file.
            run_manager (Optional[CallbackManagerForToolRun]): The callback manager for the tool run.

        Returns:
            Tuple[Dict[str, float], Dict]: A tuple containing the classification results
                (pathologies and their probabilities from 0 to 1)
                and any additional metadata. On failure the first element is
                ``{"error": ...}`` and the metadata marks the run as failed.
        """
        try:
            image_tensor = self._process_image(image_path)
            # Run inference; the model returns one logit tensor per head.
            with torch.no_grad():
                pre_logits = self.model(image_tensor)
                # Sigmoid per head (multi-label), then concatenate all heads.
                preds = torch.cat([torch.sigmoid(out) for out in pre_logits], dim=1)
            predictions = preds.cpu().numpy().flatten()
            # Defensive length reconciliation against the combined label list.
            if len(predictions) != len(self.disease_list):
                print(f"Warning: Expected {len(self.disease_list)} predictions, got {len(predictions)}")
                if len(predictions) < len(self.disease_list):
                    predictions = np.pad(predictions, (0, len(self.disease_list) - len(predictions)))
                else:
                    predictions = predictions[: len(self.disease_list)]
            # Map disease names to native Python floats for proper serialization.
            output = dict(zip(self.disease_list, (float(pred) for pred in predictions)))
            metadata = {
                "image_path": image_path,
                "model": "ArcPlus OmniSwinTransformer",
                "analysis_status": "completed",
                "num_predictions": len(predictions),
                "num_heads": len(self.num_classes_list),
                "projector_features": 1376,
                "note": "Probabilities range from 0 to 1, with higher values indicating higher likelihood of the condition.",
            }
            return output, metadata
        except Exception as e:
            # Broad catch is deliberate: a tool must return a structured error
            # rather than crash the agent loop.
            return {"error": str(e)}, {
                "image_path": image_path,
                "analysis_status": "failed",
                "error_details": str(e),
            }

    async def _arun(
        self,
        image_path: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, float], Dict]:
        """Asynchronously classify the chest X-ray image using ArcPlus SwinTransformer.

        This method currently calls the synchronous version, as the model inference
        is not inherently asynchronous. For true asynchronous behavior, consider
        using a separate thread or process.

        Args:
            image_path (str): The path to the chest X-ray image file.
            run_manager (Optional[AsyncCallbackManagerForToolRun]): The async callback manager for the tool run.

        Returns:
            Tuple[Dict[str, float], Dict]: A tuple containing the classification results
                (pathologies and their probabilities from 0 to 1)
                and any additional metadata.
        """
        return self._run(image_path)