# CXR_Demo / app.py
# Author: alfahimmohammad
# Commit e5dcb32 — "simplify schema + stable gradio"
import io, json, os
import gradio as gr
import numpy as np
import torch, torchvision
from PIL import Image
import torchxrayvision as xrv
import cv2
DEVICE = torch.device("cpu")

# Load models with error handling so the app still starts (and reports a
# useful message in the UI) if a weight download or load fails.
try:
    print("Loading classification model...")
    CLS_MODEL = xrv.models.DenseNet(weights="densenet121-res224-all").to(DEVICE).eval()
    print("Classification model loaded successfully")
except Exception as e:
    print(f"Error loading classification model: {e}")
    CLS_MODEL = None

try:
    print("Loading segmentation model...")
    SEG_MODEL = xrv.baseline_models.chestx_det.PSPNet().to(DEVICE).eval()
    print("Segmentation model loaded successfully")
except Exception as e:
    print(f"Error loading segmentation model: {e}")
    SEG_MODEL = None

# Shared preprocessing transforms.
# FIX: xrv.datasets.XRayResizer expects a single int edge length, not a
# (width, height) tuple — and the old value (224, 244) also carried a
# 244-for-224 typo. The working code path (preprocess_for_classification)
# already uses XRayResizer(224).
CenterCrop = xrv.datasets.XRayCenterCrop()
Resizer_CLS = xrv.datasets.XRayResizer(224)
Resizer_SEG = xrv.datasets.XRayResizer(512)
Transform_CLS = torchvision.transforms.Compose([CenterCrop, Resizer_CLS])
Transform_SEG = torchvision.transforms.Compose([CenterCrop, Resizer_SEG])
def validate_image_format(pil_img):
    """Validate that *pil_img* is a usable image with visible content.

    Checks: not None, duck-typed PIL Image (has ``mode``/``size``), size
    within 32x32..4096x4096, and non-uniform pixel content.

    Returns True on success.
    Raises Exception with a descriptive message on any failure.

    FIX: the old blanket ``try/except`` re-wrapped this function's own
    exceptions, so callers saw double-prefixed messages
    ("Image validation failed: Image validation failed: ..."). The checks
    now raise directly.
    """
    if pil_img is None:
        raise Exception("No image provided")
    # Duck-typed PIL check: any object exposing mode + size is accepted.
    if not hasattr(pil_img, 'mode') or not hasattr(pil_img, 'size'):
        raise Exception("Invalid image format - not a PIL Image")
    width, height = pil_img.size
    if width < 32 or height < 32:
        raise Exception(f"Image too small: {width}x{height}. Minimum size: 32x32")
    if width > 4096 or height > 4096:
        raise Exception(f"Image too large: {width}x{height}. Maximum size: 4096x4096")
    # Reject blank/uniform images: a near-zero standard deviation means
    # there is essentially no visible structure to analyze.
    img_array = np.array(pil_img)
    if len(img_array.shape) == 3:
        img_gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
    else:
        img_gray = img_array
    if img_gray.std() < 5:
        raise Exception("Image appears to be blank or uniform (no visible content)")
    return True
def pil_to_cv2_gray(pil_img):
    """Convert a PIL image into a 2D uint8 OpenCV-style grayscale array.

    Validates the input first, collapses RGB/RGBA to one channel, and
    coerces the dtype to uint8 (floats in [0, 1] are rescaled to [0, 255]).
    Any failure is re-raised with an ``Error in pil_to_cv2_gray`` prefix.
    """
    try:
        # Reject anything that is not a usable image up front.
        validate_image_format(pil_img)

        arr = np.array(pil_img)
        if arr.size == 0:
            raise Exception("Empty image array")

        # Collapse color channels down to a single grayscale plane.
        if arr.ndim == 3:
            channels = arr.shape[2]
            if channels == 3:       # RGB
                gray = cv2.cvtColor(arr, cv2.COLOR_RGB2GRAY)
            elif channels == 4:     # RGBA
                gray = cv2.cvtColor(arr, cv2.COLOR_RGBA2GRAY)
            else:
                gray = arr[:, :, 0]  # unusual layout: keep the first channel
        else:
            gray = arr

        # Coerce to uint8, rescaling floats that live in [0, 1].
        if gray.dtype != np.uint8:
            if gray.max() <= 1.0:
                gray = (gray * 255).astype(np.uint8)
            else:
                gray = gray.astype(np.uint8)

        # Defensive range check (uint8 already guarantees this).
        if gray.min() < 0 or gray.max() > 255:
            raise Exception(f"Invalid pixel values: min={gray.min()}, max={gray.max()}")
        return gray
    except Exception as e:
        raise Exception(f"Error in pil_to_cv2_gray: {str(e)}")
def cv2_resize_and_normalize(img_gray, target_size):
    """Resize a 2D grayscale array with cv2, then normalize via xrv.

    img_gray:    2D numpy array (uint8 expected upstream).
    target_size: (width, height) pair.
    Returns a float32 array of shape (1, height, width) in xrv's value
    range. Every failure is wrapped with a function-name prefix.
    """
    try:
        # ---- input validation --------------------------------------------
        if img_gray is None:
            raise Exception("Input image is None")
        if not isinstance(img_gray, np.ndarray):
            raise Exception(f"Input must be numpy array, got {type(img_gray)}")
        if len(img_gray.shape) != 2:
            raise Exception(f"Input must be 2D grayscale image, got shape {img_gray.shape}")
        if img_gray.size == 0:
            raise Exception("Input image is empty")
        if not isinstance(target_size, (tuple, list)) or len(target_size) != 2:
            raise Exception(f"Target size must be (width, height) tuple, got {target_size}")
        width, height = target_size
        if width <= 0 or height <= 0:
            raise Exception(f"Invalid target size: {target_size}")

        # ---- resize -------------------------------------------------------
        try:
            resized = cv2.resize(img_gray, (width, height), interpolation=cv2.INTER_AREA)
        except Exception as e:
            raise Exception(f"CV2 resize failed: {str(e)}")
        # cv2 returns (rows, cols) == (height, width).
        if resized.shape != (height, width):
            raise Exception(f"Resize failed: expected {(height, width)}, got {resized.shape}")

        # ---- normalize into the xrv model value range ---------------------
        try:
            normed = xrv.datasets.normalize(resized, 255).astype(np.float32)
        except Exception as e:
            raise Exception(f"XRV normalization failed: {str(e)}")
        # Defensive: input is 2D so this branch should never fire.
        if normed.ndim == 3:
            normed = normed.mean(axis=2)

        if np.isnan(normed).any():
            raise Exception("Normalized image contains NaN values")
        if np.isinf(normed).any():
            raise Exception("Normalized image contains infinite values")
        # Prepend the channel axis the models expect.
        return normed[None, ...]
    except Exception as e:
        raise Exception(f"Error in cv2_resize_and_normalize: {str(e)}")
def preprocess_for_classification(pil_img):
    """Build the (1, 1, 224, 224) tensor the DenseNet classifier expects.

    Normalizes an 8-bit image into xrv's value range, collapses color
    channels, center-crops and resizes to 224, and adds a batch dimension.

    FIXES:
    - grayscale (2D) input previously crashed on ``.mean(2)``; it is now
      passed through as-is (``infer`` can legitimately produce mode='L').
    - the except clause referenced ``img_gray`` which is undefined when the
      first line raises, turning the real error into a NameError.
    """
    try:
        # 8-bit image -> xrv's [-1024, 1024] range.
        img = xrv.datasets.normalize(np.array(pil_img), 255)
        # Collapse color channels; grayscale input is already single-channel.
        if img.ndim == 3:
            img = img.mean(2)
        img = img[None, ...]  # (1, H, W)
        transform = torchvision.transforms.Compose(
            [xrv.datasets.XRayCenterCrop(), xrv.datasets.XRayResizer(224)]
        )
        img = transform(img)
        # Add the batch dimension and move to the configured device.
        return torch.from_numpy(img)[None, ...].to(DEVICE)
    except Exception as e:
        raise Exception(f"Error in classification preprocessing: {str(e)}")
def preprocess_for_segmentation(pil_img):
    """Build the (1, 1, 512, 512) tensor for the PSPNet segmentation model."""
    try:
        # PIL -> 2D uint8 grayscale array.
        gray = pil_to_cv2_gray(pil_img)
        print(f"Grayscale image shape: {gray.shape}, dtype: {gray.dtype}")
        # Resize to 512x512 and normalize into xrv's value range.
        processed = cv2_resize_and_normalize(gray, (512, 512))
        print(f"Processed image shape: {processed.shape}, dtype: {processed.dtype}")
        # Add the batch dimension and move to the configured device.
        return torch.from_numpy(processed)[None, ...].to(DEVICE)
    except Exception as e:
        raise Exception(f"Error in segmentation preprocessing: {str(e)}")
@torch.no_grad()
def run_classification(timg):
    """Run the DenseNet classifier on a batched tensor.

    Returns a dict mapping each model pathology name to its output score.
    Raises Exception if the model failed to load at startup.
    """
    if CLS_MODEL is None:
        raise Exception("Classification model not loaded")
    scores = CLS_MODEL(timg)[0].detach().cpu().numpy().tolist()
    return dict(zip(CLS_MODEL.pathologies, scores))
@torch.no_grad()
def run_segmentation(timg):
    """Run PSPNet on a batched tensor.

    Returns the sigmoid-activated masks for the first batch element as a
    numpy array of shape (channels, H, W).
    Raises Exception if the model failed to load at startup.
    """
    if SEG_MODEL is None:
        raise Exception("Segmentation model not loaded")
    probs = torch.sigmoid(SEG_MODEL(timg))
    return probs.detach().cpu().numpy()[0]
def masks_to_overlay(gray_512, masks, alpha=0.45):
    """Alpha-blend colored segmentation masks over a grayscale base image.

    gray_512: 2D numpy base image (resized to 512x512 if it is not already).
    masks:    3D (channels, height, width) array; pixels > 0.5 get painted.
    alpha:    blend strength of each mask color.
    Returns a PIL RGB image. Every failure is wrapped with a
    function-name prefix.
    """
    try:
        # ---- validate inputs (same checks and messages, same order) ------
        if gray_512 is None:
            raise Exception("Base image is None")
        if masks is None:
            raise Exception("Masks are None")
        if not isinstance(gray_512, np.ndarray):
            raise Exception(f"Base image must be numpy array, got {type(gray_512)}")
        if not isinstance(masks, np.ndarray):
            raise Exception(f"Masks must be numpy array, got {type(masks)}")
        if len(gray_512.shape) != 2:
            raise Exception(f"Base image must be 2D, got shape {gray_512.shape}")
        if len(masks.shape) != 3:
            raise Exception(f"Masks must be 3D (channels, height, width), got shape {masks.shape}")

        # Bring the base image to 512x512 when needed.
        if gray_512.shape == (512, 512):
            base = gray_512
        else:
            try:
                base = cv2.resize(gray_512, (512, 512), interpolation=cv2.INTER_AREA)
            except Exception as e:
                raise Exception(f"Failed to resize base image: {str(e)}")

        if base.min() < 0 or base.max() > 255:
            raise Exception(f"Invalid base image pixel values: min={base.min()}, max={base.max()}")

        # Grayscale -> float RGB in [0, 1] for blending.
        try:
            base_rgb = cv2.cvtColor(base, cv2.COLOR_GRAY2RGB).astype(np.float32) / 255.0
        except Exception as e:
            raise Exception(f"Failed to convert to RGB: {str(e)}")

        # Seven distinct colors, cycled across the mask channels.
        palette = np.array(
            [[0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 1, 0], [1, 0, 1], [0, 1, 1], [1, 0.5, 0]],
            dtype=np.float32,
        )
        overlay = base_rgb.copy()
        try:
            # Threshold each mask at 0.5 and blend its palette color in.
            for idx in range(min(masks.shape[0], 14)):
                binary = (masks[idx] > 0.5).astype(np.float32)[..., None]
                color = palette[idx % len(palette)]
                overlay = overlay * (1 - binary * alpha) + (color * binary * alpha)
        except Exception as e:
            raise Exception(f"Failed to apply masks: {str(e)}")

        if np.isnan(overlay).any():
            raise Exception("Overlay contains NaN values")
        if np.isinf(overlay).any():
            raise Exception("Overlay contains infinite values")
        return Image.fromarray((overlay * 255.0).clip(0, 255).astype(np.uint8))
    except Exception as e:
        raise Exception(f"Error in masks_to_overlay: {str(e)}")
def infer(image):
    """Gradio entry point: classify + segment an uploaded chest X-ray.

    Always returns a 3-tuple matching the Interface outputs:
    (classification input preview PIL image or None,
     JSON scores / error text,
     segmentation overlay PIL image or None).

    FIXES:
    - the early-error paths previously returned FOUR values while the
      Interface declares THREE outputs, breaking Gradio's output mapping.
    - the dead ``'cls_vis_img' not in locals()`` check meant segmentation
      failures were silently swallowed; they are now appended to the text
      output without discarding the classification results.
    """
    try:
        if image is None:
            msg = "No image uploaded. Please upload a chest X-ray image."
            return None, msg, None

        # Normalize the input to a PIL image (Gradio may hand us numpy).
        try:
            if isinstance(image, np.ndarray):
                if image.ndim == 3:
                    pil_img = Image.fromarray(image)
                else:
                    pil_img = Image.fromarray(image, mode='L')
            elif hasattr(image, 'mode') and hasattr(image, 'size'):
                pil_img = image  # already a PIL image
            else:
                raise Exception(f"Unsupported image type: {type(image)}")
        except Exception as e:
            msg = f"Image format error: {str(e)}"
            return None, msg, None

        try:
            validate_image_format(pil_img)
        except Exception as e:
            msg = f"Image validation failed: {str(e)}"
            return None, msg, None

        # ---- classification (a failure here must not kill segmentation) --
        try:
            t_cls = preprocess_for_classification(pil_img)
            raw_dict = run_classification(t_cls)
            # Min-max scale the model input for display.
            cls_vis = t_cls[0, 0].detach().cpu().numpy()
            cls_vis = cls_vis - cls_vis.min()
            if cls_vis.max() > 0:
                cls_vis = cls_vis / cls_vis.max()
            cls_vis_img = Image.fromarray((cls_vis * 255).astype(np.uint8))
            raw_json = json.dumps(raw_dict, indent=2)
        except Exception as e:
            cls_vis_img = None
            raw_json = f"Classification Error: {str(e)}"

        # ---- segmentation -------------------------------------------------
        try:
            t_seg = preprocess_for_segmentation(pil_img)
            masks = run_segmentation(t_seg)
            # Min-max scale the model input as the overlay's base image.
            seg_vis = t_seg[0, 0].detach().cpu().numpy()
            seg_vis = seg_vis - seg_vis.min()
            if seg_vis.max() > 0:
                seg_vis = seg_vis / seg_vis.max()
            seg_vis = (seg_vis * 255).astype(np.uint8)
            if seg_vis.shape != (512, 512):
                seg_vis = cv2.resize(seg_vis, (512, 512), interpolation=cv2.INTER_AREA)
            seg_overlay = masks_to_overlay(seg_vis, masks, alpha=0.45)
        except Exception as e:
            seg_overlay = None
            # Surface the failure while keeping any classification results.
            raw_json = f"{raw_json}\n\nSegmentation Error: {str(e)}"

        return cls_vis_img, raw_json, seg_overlay
    except Exception as e:
        return None, f"General Error: {str(e)}", None
# Use the older Interface API to avoid schema issues.
# FIX: labels/description said "224×244" — a typo propagated from the old
# resizer constant; the classifier input is actually 224×224
# (XRayResizer(224) in preprocess_for_classification).
demo = gr.Interface(
    fn=infer,
    inputs=gr.Image(label="Upload chest X-ray (PNG/JPG)", type="pil"),
    outputs=[
        gr.Image(label="Classification input view (224×224)", type="pil"),
        gr.Textbox(label="Classification probability scores (JSON)", lines=10),
        gr.Image(label="Segmentation overlay (512×512, 14 anatomies)", type="pil")
    ],
    title="Chest X-ray: Abnormality Classification + Anatomical Segmentation",
    description="""
    **Models:** TorchXRayVision DenseNet (classification) & PSPNet (14-class anatomy)
    **Shapes:** Classification (224×224), Segmentation (512×512)
    *For research/education. Not for clinical use.*
    """,
    allow_flagging="never"
)
if __name__ == "__main__":
    # Always expose a public link on Spaces
    # NOTE(review): share=True is presumably redundant when hosted on HF
    # Spaces (the platform serves the app itself) — confirm; binding to
    # 0.0.0.0 is what makes the server reachable from outside the container.
    demo.launch(server_name="0.0.0.0", share=True)