# Hugging Face Spaces app: chest X-ray abnormality classification and
# anatomical segmentation using TorchXRayVision.
| import io, json, os | |
| import gradio as gr | |
| import numpy as np | |
| import torch, torchvision | |
| from PIL import Image | |
| import torchxrayvision as xrv | |
| import cv2 | |
DEVICE = torch.device("cpu")

# Load models with error handling so the app can still start (and report the
# failure in the UI) even if a weight download/load fails on Spaces.
try:
    print("Loading classification model...")
    CLS_MODEL = xrv.models.DenseNet(weights="densenet121-res224-all").to(DEVICE).eval()
    print("Classification model loaded successfully")
except Exception as e:
    print(f"Error loading classification model: {e}")
    CLS_MODEL = None

try:
    print("Loading segmentation model...")
    SEG_MODEL = xrv.baseline_models.chestx_det.PSPNet().to(DEVICE).eval()
    print("Segmentation model loaded successfully")
except Exception as e:
    print(f"Error loading segmentation model: {e}")
    SEG_MODEL = None

# Shared preprocessing transforms.
# BUG FIX: XRayResizer takes a single int edge length, not a (w, h) tuple;
# the original passed (224, 244) — both the wrong type and a 244/224 typo.
CenterCrop = xrv.datasets.XRayCenterCrop()
Resizer_CLS = xrv.datasets.XRayResizer(224)
Resizer_SEG = xrv.datasets.XRayResizer(512)
Transform_CLS = torchvision.transforms.Compose([CenterCrop, Resizer_CLS])
Transform_SEG = torchvision.transforms.Compose([CenterCrop, Resizer_SEG])
def validate_image_format(pil_img):
    """Validate that *pil_img* is a usable PIL image.

    Checks: non-None, PIL-like (has .mode/.size), size within 32x32..4096x4096,
    and non-uniform pixel content. Returns True on success; raises Exception
    (wrapped as "Image validation failed: ...") otherwise.
    """
    try:
        if pil_img is None:
            raise Exception("No image provided")
        # Duck-type check for a PIL Image.
        if not (hasattr(pil_img, 'mode') and hasattr(pil_img, 'size')):
            raise Exception("Invalid image format - not a PIL Image")
        w, h = pil_img.size
        if min(w, h) < 32:
            raise Exception(f"Image too small: {w}x{h}. Minimum size: 32x32")
        if max(w, h) > 4096:
            raise Exception(f"Image too large: {w}x{h}. Maximum size: 4096x4096")
        # Reject blank/uniform images: a tiny std means there is no content.
        arr = np.array(pil_img)
        gray = cv2.cvtColor(arr, cv2.COLOR_RGB2GRAY) if arr.ndim == 3 else arr
        if gray.std() < 5:
            raise Exception("Image appears to be blank or uniform (no visible content)")
        return True
    except Exception as err:
        raise Exception(f"Image validation failed: {str(err)}")
def pil_to_cv2_gray(pil_img):
    """Convert a PIL image to a 2-D uint8 OpenCV-style grayscale array.

    Validates the image first, collapses RGB/RGBA/multi-channel input to a
    single channel, and normalizes the dtype to uint8. Raises Exception
    (wrapped as "Error in pil_to_cv2_gray: ...") on any failure.
    """
    try:
        validate_image_format(pil_img)
        arr = np.array(pil_img)
        if arr.size == 0:
            raise Exception("Empty image array")
        # Collapse to a single channel.
        if arr.ndim == 3:
            channels = arr.shape[2]
            if channels == 3:    # RGB
                gray = cv2.cvtColor(arr, cv2.COLOR_RGB2GRAY)
            elif channels == 4:  # RGBA
                gray = cv2.cvtColor(arr, cv2.COLOR_RGBA2GRAY)
            else:
                # Not RGB/RGBA: fall back to the first channel.
                gray = arr[:, :, 0]
        else:
            gray = arr
        # Normalize dtype to uint8 (scale up floats that live in [0, 1]).
        if gray.dtype != np.uint8:
            if gray.max() <= 1.0:
                gray = (gray * 255).astype(np.uint8)
            else:
                gray = gray.astype(np.uint8)
        # Sanity check on the value range.
        if gray.min() < 0 or gray.max() > 255:
            raise Exception(f"Invalid pixel values: min={gray.min()}, max={gray.max()}")
        return gray
    except Exception as err:
        raise Exception(f"Error in pil_to_cv2_gray: {str(err)}")
def cv2_resize_and_normalize(img_gray, target_size):
    """Resize a 2-D grayscale array with cv2, then normalize via torchxrayvision.

    img_gray: 2-D numpy array.
    target_size: (width, height) pair.
    Returns a float32 array of shape (1, height, width). Raises Exception
    (wrapped as "Error in cv2_resize_and_normalize: ...") on any failure.
    """
    try:
        # ---- input validation ----
        if img_gray is None:
            raise Exception("Input image is None")
        if not isinstance(img_gray, np.ndarray):
            raise Exception(f"Input must be numpy array, got {type(img_gray)}")
        if img_gray.ndim != 2:
            raise Exception(f"Input must be 2D grayscale image, got shape {img_gray.shape}")
        if img_gray.size == 0:
            raise Exception("Input image is empty")
        if not isinstance(target_size, (tuple, list)) or len(target_size) != 2:
            raise Exception(f"Target size must be (width, height) tuple, got {target_size}")
        w, h = target_size
        if w <= 0 or h <= 0:
            raise Exception(f"Invalid target size: {target_size}")

        # ---- resize (cv2 takes (width, height); arrays are (height, width)) ----
        try:
            resized = cv2.resize(img_gray, (w, h), interpolation=cv2.INTER_AREA)
        except Exception as err:
            raise Exception(f"CV2 resize failed: {str(err)}")
        if resized.shape != (h, w):
            raise Exception(f"Resize failed: expected {(h, w)}, got {resized.shape}")

        # ---- map 8-bit values into xrv's expected range ----
        try:
            normalized = xrv.datasets.normalize(resized, 255).astype(np.float32)
        except Exception as err:
            raise Exception(f"XRV normalization failed: {str(err)}")
        # Defensive: collapse any stray extra channel dimension.
        if normalized.ndim == 3:
            normalized = normalized.mean(axis=2)

        if np.isnan(normalized).any():
            raise Exception("Normalized image contains NaN values")
        if np.isinf(normalized).any():
            raise Exception("Normalized image contains infinite values")
        return normalized[None, ...]  # prepend the channel axis
    except Exception as err:
        raise Exception(f"Error in cv2_resize_and_normalize: {str(err)}")
def preprocess_for_classification(pil_img):
    """Preprocess a PIL image into a (1, 1, 224, 224) float tensor for CLS_MODEL.

    Accepts RGB/RGBA or already-grayscale input. Raises Exception with a
    descriptive message on failure.

    BUG FIXES vs original:
    - `.mean(2)` was called unconditionally, crashing on 2-D (grayscale) input;
      it is now applied only when a channel axis exists.
    - The except clause interpolated `img_gray.shape`, which raised
      UnboundLocalError (masking the real error) if the first line failed.
    """
    try:
        # Map 8-bit pixel values into xrv's expected [-1024, 1024] range.
        img = xrv.datasets.normalize(np.array(pil_img), 255)
        # Collapse color images to a single channel; grayscale passes through.
        if img.ndim == 3:
            img = img.mean(2)
        img = img[None, ...]  # channel-first: (1, H, W)
        transform = torchvision.transforms.Compose(
            [xrv.datasets.XRayCenterCrop(), xrv.datasets.XRayResizer(224)]
        )
        img = transform(img)
        # Add the batch dimension: (1, 1, 224, 224).
        return torch.from_numpy(img)[None, ...].to(DEVICE)
    except Exception as e:
        raise Exception(f"Error in classification preprocessing: {str(e)}")
def preprocess_for_segmentation(pil_img):
    """Preprocess a PIL image into a (1, 1, 512, 512) float tensor for SEG_MODEL."""
    try:
        # PIL -> 2-D uint8 grayscale.
        gray = pil_to_cv2_gray(pil_img)
        print(f"Grayscale image shape: {gray.shape}, dtype: {gray.dtype}")
        # Resize to 512x512 and normalize into xrv's value range.
        processed = cv2_resize_and_normalize(gray, (512, 512))
        print(f"Processed image shape: {processed.shape}, dtype: {processed.dtype}")
        # Add the batch dimension.
        return torch.from_numpy(processed)[None, ...].to(DEVICE)
    except Exception as err:
        raise Exception(f"Error in segmentation preprocessing: {str(err)}")
def run_classification(timg):
    """Score a preprocessed tensor with CLS_MODEL; return {pathology: score}."""
    if CLS_MODEL is None:
        raise Exception("Classification model not loaded")
    preds = CLS_MODEL(timg)
    scores = preds[0].detach().cpu().numpy().tolist()
    return dict(zip(CLS_MODEL.pathologies, scores))
def run_segmentation(timg):
    """Run SEG_MODEL on a preprocessed tensor.

    Returns the first batch element of the per-class sigmoid masks as a
    numpy array (one channel per anatomical structure).
    """
    if SEG_MODEL is None:
        raise Exception("Segmentation model not loaded")
    return torch.sigmoid(SEG_MODEL(timg)).detach().cpu().numpy()[0]
def masks_to_overlay(gray_512, masks, alpha=0.45):
    """Blend thresholded (>0.5) per-class masks as colors onto a grayscale base.

    gray_512: 2-D uint8 base image (resized to 512x512 if needed).
    masks: 3-D (channels, height, width) float array.
    alpha: blend weight for the mask colors.
    Returns a PIL RGB Image; raises Exception (wrapped as
    "Error in masks_to_overlay: ...") on any failure.
    """
    try:
        # ---- input validation ----
        if gray_512 is None:
            raise Exception("Base image is None")
        if masks is None:
            raise Exception("Masks are None")
        if not isinstance(gray_512, np.ndarray):
            raise Exception(f"Base image must be numpy array, got {type(gray_512)}")
        if not isinstance(masks, np.ndarray):
            raise Exception(f"Masks must be numpy array, got {type(masks)}")
        if gray_512.ndim != 2:
            raise Exception(f"Base image must be 2D, got shape {gray_512.shape}")
        if masks.ndim != 3:
            raise Exception(f"Masks must be 3D (channels, height, width), got shape {masks.shape}")

        # ---- bring the base image to 512x512 ----
        base = gray_512
        if base.shape != (512, 512):
            try:
                base = cv2.resize(gray_512, (512, 512), interpolation=cv2.INTER_AREA)
            except Exception as err:
                raise Exception(f"Failed to resize base image: {str(err)}")
        if base.min() < 0 or base.max() > 255:
            raise Exception(f"Invalid base image pixel values: min={base.min()}, max={base.max()}")
        try:
            rgb = cv2.cvtColor(base, cv2.COLOR_GRAY2RGB).astype(np.float32) / 255.0
        except Exception as err:
            raise Exception(f"Failed to convert to RGB: {str(err)}")

        # One color per structure; cycled when there are more masks than colors.
        palette = np.array(
            [[0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 1, 0], [1, 0, 1], [0, 1, 1], [1, 0.5, 0]],
            dtype=np.float32,
        )
        overlay = rgb.copy()
        try:
            for idx in range(min(masks.shape[0], 14)):
                color = palette[idx % len(palette)]
                hard = (masks[idx] > 0.5).astype(np.float32)[..., None]
                overlay = overlay * (1 - hard * alpha) + color * hard * alpha
        except Exception as err:
            raise Exception(f"Failed to apply masks: {str(err)}")

        # ---- output validation ----
        if np.isnan(overlay).any():
            raise Exception("Overlay contains NaN values")
        if np.isinf(overlay).any():
            raise Exception("Overlay contains infinite values")
        return Image.fromarray((overlay * 255.0).clip(0, 255).astype(np.uint8))
    except Exception as err:
        raise Exception(f"Error in masks_to_overlay: {str(err)}")
def infer(image):
    """Gradio handler: run classification + segmentation on an uploaded image.

    Returns a 3-tuple matching the Interface's three outputs:
      (classification input view as PIL image or None,
       classification scores JSON string or error text,
       segmentation overlay as PIL image or None)

    BUG FIXES vs original:
    - Three early error paths returned 4-tuples while the Interface declares
      3 outputs; every path now returns exactly 3 values.
    - The `if 'cls_vis_img' not in locals()` check in the segmentation handler
      was dead code (the name is always bound), so segmentation errors were
      silently dropped; they are now reported when no classification result
      is available.
    """
    try:
        if image is None:
            msg = "No image uploaded. Please upload a chest X-ray image."
            return None, msg, None

        # Normalize the input to a PIL image (Gradio may hand us an array).
        try:
            if isinstance(image, np.ndarray):
                if image.ndim == 3:
                    pil_img = Image.fromarray(image)
                else:
                    pil_img = Image.fromarray(image, mode='L')
            elif hasattr(image, 'mode') and hasattr(image, 'size'):
                pil_img = image  # already a PIL Image
            else:
                raise Exception(f"Unsupported image type: {type(image)}")
        except Exception as e:
            return None, f"Image format error: {str(e)}", None

        # Validate format/size/content before touching the models.
        try:
            validate_image_format(pil_img)
        except Exception as e:
            return None, f"Image validation failed: {str(e)}", None

        # Classification (failure is reported in the JSON textbox).
        try:
            t_cls = preprocess_for_classification(pil_img)
            raw_dict = run_classification(t_cls)
            # Min-max normalize the model input for display.
            vis = t_cls[0, 0].detach().cpu().numpy()
            vis = vis - vis.min()
            if vis.max() > 0:
                vis = vis / vis.max()
            cls_vis_img = Image.fromarray((vis * 255).astype(np.uint8))
            raw_json = json.dumps(raw_dict, indent=2)
        except Exception as e:
            cls_vis_img = None
            raw_json = f"Classification Error: {str(e)}"

        # Segmentation (failure leaves the overlay empty but keeps any
        # classification results).
        seg_overlay = None
        try:
            t_seg = preprocess_for_segmentation(pil_img)
            masks = run_segmentation(t_seg)
            base = t_seg[0, 0].detach().cpu().numpy()
            base = base - base.min()
            if base.max() > 0:
                base = base / base.max()
            base = (base * 255).astype(np.uint8)
            if base.shape != (512, 512):
                base = cv2.resize(base, (512, 512), interpolation=cv2.INTER_AREA)
            seg_overlay = masks_to_overlay(base, masks, alpha=0.45)
        except Exception as e:
            # Only surface the segmentation error when classification has
            # nothing useful to show.
            if cls_vis_img is None:
                raw_json = f"Segmentation Error: {str(e)}"

        return cls_vis_img, raw_json, seg_overlay
    except Exception as e:
        return None, f"General Error: {str(e)}", None
# Gradio UI: one image in, three outputs (classification input view, scores
# JSON, segmentation overlay). The older Interface API avoids schema issues
# on Spaces.
# BUG FIX: labels/description said "224×244" — the classification input is
# 224×224 (see XRayResizer(224) in preprocessing).
demo = gr.Interface(
    fn=infer,
    inputs=gr.Image(label="Upload chest X-ray (PNG/JPG)", type="pil"),
    outputs=[
        gr.Image(label="Classification input view (224×224)", type="pil"),
        gr.Textbox(label="Classification probability scores (JSON)", lines=10),
        gr.Image(label="Segmentation overlay (512×512, 14 anatomies)", type="pil"),
    ],
    title="Chest X-ray: Abnormality Classification + Anatomical Segmentation",
    description="""
**Models:** TorchXRayVision DenseNet (classification) & PSPNet (14-class anatomy)
**Shapes:** Classification (224×224), Segmentation (512×512)
*For research/education. Not for clinical use.*
""",
    allow_flagging="never"
)
if __name__ == "__main__":
    # Bind all interfaces and request a public share link when run directly.
    demo.launch(server_name="0.0.0.0", share=True)