# nih-chest-xray / app.py
# Gradio app for multi-label chest X-ray classification (Hugging Face Space).
import gradio as gr
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
# Configuration
num_classes = 15
model_path = "model_epoch_20.pth"  # Path to your trained model weights
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Human-readable names for the 15 NIH ChestX-ray14 labels, keyed by the
# string form of the class index (predictions are looked up via str(i)).
_LABELS = [
    "Atelectasis",
    "Cardiomegaly",
    "Consolidation",
    "Edema",
    "Effusion",
    "Emphysema",
    "Fibrosis",
    "Hernia",
    "Infiltration",
    "Mass",
    "No Finding",
    "Nodule",
    "Pleural_Thickening",
    "Pneumonia",
    "Pneumothorax",
]
CLASS_NAMES = {str(index): label for index, label in enumerate(_LABELS)}
# Load pretrained Inception v3 (aux_logits=True so the checkpoint's
# AuxLogits weights have matching modules to load into).
model = models.inception_v3(weights='IMAGENET1K_V1', aux_logits=True)

# Replace the final classifier (fc) and the auxiliary classifier with
# multi-label heads: in_features -> 512 -> num_classes, ending in Sigmoid
# so each class gets an independent probability (multi-label setup).
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, 512),
    nn.ReLU(),
    nn.Linear(512, num_classes),
    nn.Sigmoid()
)
model.AuxLogits.fc = nn.Sequential(
    nn.Linear(model.AuxLogits.fc.in_features, 512),
    nn.ReLU(),
    nn.Linear(512, num_classes),
    nn.Sigmoid()
)

# Load the fine-tuned weights. Chain exceptions with `from e` so the
# original traceback is preserved (the previous `raise Exception(...)`
# discarded both the cause and the specific error type).
try:
    # NOTE(review): on torch >= 1.13 consider torch.load(..., weights_only=True)
    # to avoid unpickling arbitrary objects — confirm the deployed torch version.
    model.load_state_dict(torch.load(model_path, map_location=device))
except FileNotFoundError as e:
    raise FileNotFoundError(f"Model weights file '{model_path}' not found. Please upload it to the Hugging Face Space.") from e
except Exception as e:
    # RuntimeError is a subclass of Exception, so existing handlers still match.
    raise RuntimeError(f"Error loading model weights: {str(e)}") from e

model.eval()  # inference mode: disables dropout; Inception skips aux output in eval
model.to(device)
# Preprocessing pipeline: resize to 224x224, then convert to a [0, 1] tensor.
_transform_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
]
preprocess = transforms.Compose(_transform_steps)

# Grad-CAM hooks into the last convolutional block of Inception v3.
_last_conv = model.Mixed_7c
gradcam = GradCAM(model=model, target_layers=[_last_conv])
# Inference function with Grad-CAM
def predict_xray(image: np.ndarray):
    """Run multi-label inference on a chest X-ray and build Grad-CAM overlays.

    Args:
        image: H x W (grayscale) or H x W x C uint8 array from the Gradio
            Image component.

    Returns:
        A tuple ``(result, heatmaps)``: ``result`` maps each of the 14
        finding names (excluding "No Finding", index 10) to its predicted
        probability; ``heatmaps`` is a list of ``(cam_image, caption)``
        pairs for the top predicted classes.
    """
    # Resize input image if needed (square target, so (w, h) order is moot).
    if image.shape[:2] != (224, 224):
        image = np.array(Image.fromarray(image).resize((224, 224)))

    # Ensure the image is 3-channel RGB.
    if image.ndim == 2:  # grayscale
        image = np.stack([image] * 3, axis=-1)
    elif image.shape[-1] == 1:  # single channel
        image = np.repeat(image, 3, axis=-1)
    elif image.shape[-1] == 4:  # RGBA upload: drop alpha, else fromarray(..., "RGB") fails
        image = image[..., :3]

    # Preprocess image for the network.
    pil_img = Image.fromarray(image.astype("uint8"), "RGB")
    input_tensor = preprocess(pil_img).unsqueeze(0).to(device)

    # Model prediction. The classifier head (model.fc) already ends in
    # nn.Sigmoid, so the network output IS the per-class probability
    # vector. The original applied torch.sigmoid a second time, which
    # squashed every probability into [0.5, 0.731] — do NOT re-apply it.
    with torch.no_grad():
        probs = model(input_tensor)[0].cpu().numpy()

    # Probabilities for the 14 findings; index 10 ("No Finding") excluded.
    # Cast to plain float so Gradio's Label component gets JSON-safe values.
    result = {
        CLASS_NAMES[str(i)]: float(probs[i])
        for i in range(len(CLASS_NAMES))
        if i != 10  # Exclude "No Finding"
    }

    # Grad-CAM for the top-k predicted classes.
    top_k = 4
    top_indices = np.argsort(probs)[-top_k:][::-1]

    # Base RGB image normalized to [0, 1] for the heatmap overlay.
    rgb_img = input_tensor.squeeze().permute(1, 2, 0).cpu().numpy()
    rgb_img = (rgb_img - rgb_img.min()) / (rgb_img.max() - rgb_img.min() + 1e-8)

    heatmaps = []
    for idx in top_indices:
        if idx == 10:  # Skip "No Finding" for Grad-CAM
            continue
        # gradcam() is intentionally outside torch.no_grad(): it needs gradients.
        targets = [ClassifierOutputTarget(idx)]
        grayscale_cam = gradcam(input_tensor=input_tensor, targets=targets)[0]
        cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
        heatmaps.append((cam_image, f"{CLASS_NAMES[str(idx)]} (Prob: {probs[idx]:.3f})"))
    return result, heatmaps
# Gradio UI: one image input; outputs a probability label plus a gallery
# of Grad-CAM overlays for the top predicted findings.
_label_out = gr.Label(num_top_classes=14, label="Predicted Probabilities (and Binary Labels)")
_gallery_out = gr.Gallery(label="Grad-CAM Heatmaps (Top 4 Classes)")
interface = gr.Interface(
    fn=predict_xray,
    inputs=gr.Image(type="numpy"),
    outputs=[_label_out, _gallery_out],
    title="NIH Chest X-ray Multi-Label Classifier",
    description="Upload a chest X-ray (resized to 224x224). The model outputs probabilities and binary labels (threshold=0.5) for 14 findings, excluding 'No Finding'. Grad-CAM heatmaps highlight regions for the top 4 predicted findings. Low probabilities are common for rare conditions like Hernia due to dataset imbalance."
)
# Script entry point: launch the Gradio app. share=True creates a public
# link, which is what Hugging Face Spaces expects.
if __name__ == "__main__":
    print("starting Gradio interface...")
    interface.launch(share=True) # Set to True for Hugging Face Spaces