# NOTE: Grad-CAM support was removed due to issues with the latest Gradio
# version; a custom PyTorch CAM implementation below serves the same purpose.
# (upstream commit: 2a627de)
import gradio as gr
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import numpy as np
import cv2
import warnings
warnings.filterwarnings('ignore')
# Inference runs on CPU (no GPU assumed on the hosting environment).
DEVICE = torch.device("cpu")
# Fine-tuned EfficientNet-B0 checkpoint (state dict), trained on NIH Chest X-Ray 14.
MODEL_PATH = "effnet_b0_nih_2026_02_19.pth"
# One class label per line, in the same order as the model's output logits.
LABELS_PATH = "labels.txt"
with open(LABELS_PATH, "r") as f:
    categories = [line.strip() for line in f.readlines()]
def load_model():
    """Build EfficientNet-B0, swap in the NIH multi-label head, load the weights.

    Returns:
        The model on DEVICE, in eval mode, ready for inference.
    """
    model = models.efficientnet_b0(weights=None)
    # Replace the 1000-class ImageNet head with one output per NIH label.
    n_inputs = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(n_inputs, len(categories))
    # weights_only=True: the checkpoint is a plain tensor state dict, and this
    # prevents arbitrary-code execution through pickle when loading the file.
    state_dict = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(DEVICE)
    model.eval()
    return model


model = load_model()
class SimpleGradCAM:
    """Minimal hook-based Grad-CAM.

    Captures the target layer's activations on the forward pass and its
    output gradients on the backward pass, then combines them into a
    class-discriminative saliency map normalized to [0, 1].
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None    # filled by the backward hook: dScore/dA
        self.activations = None  # filled by the forward hook: feature maps A
        # Hooks remain registered for the lifetime of this object.
        self.target_layer.register_forward_hook(self.save_activation)
        self.target_layer.register_full_backward_hook(self.save_gradient)

    def save_activation(self, module, input, output):
        # Detach so the cached activations don't keep the autograd graph alive.
        self.activations = output.detach()

    def save_gradient(self, module, grad_input, grad_output):
        # grad_output is a tuple; element 0 is the gradient w.r.t. the layer output.
        self.gradients = grad_output[0].detach()

    def __call__(self, x, class_idx):
        """Return an (H, W) numpy saliency map for `class_idx` given input `x`."""
        self.model.zero_grad()
        logits = self.model(x)
        target_score = logits[:, class_idx].squeeze()
        target_score.backward(retain_graph=True)

        batch, channels, _, _ = self.gradients.size()
        # Global-average-pool the gradients: one importance weight per channel.
        channel_weights = self.gradients.view(batch, channels, -1).mean(2)
        channel_weights = channel_weights.view(batch, channels, 1, 1)

        cam = (channel_weights * self.activations).sum(1, keepdim=True)
        cam = torch.relu(cam).squeeze().cpu().numpy()
        peak = cam.max()
        if peak > 0:
            cam = cam / peak
        return cam
# Hook the last feature stage of EfficientNet-B0 — the deepest conv maps,
# which give the most class-discriminative (if coarse) CAM.
cam_extractor = SimpleGradCAM(model, model.features[-1])
class MedicalTransform:
    """CLAHE (Contrast Limited Adaptive Histogram Equalization) preprocessing.

    Callable that takes a PIL image (grayscale or color) and returns a
    contrast-enhanced single-channel PIL image. Local histogram equalization
    improves the visibility of soft-tissue detail in X-rays.
    """

    def __call__(self, img):
        img_np = np.array(img)
        if img_np.ndim > 2:
            # Collapse color input to one luminance channel. Handle the
            # 4-channel case (RGBA PNG uploads) which COLOR_RGB2GRAY rejects.
            code = cv2.COLOR_RGBA2GRAY if img_np.shape[2] == 4 else cv2.COLOR_RGB2GRAY
            img_np = cv2.cvtColor(img_np, code)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        return Image.fromarray(clahe.apply(img_np))
# Standard ImageNet preprocessing: EfficientNet-B0 was pretrained with these
# normalization statistics, and the fine-tuned checkpoint expects the same.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def predict(input_image):
    """Run the full analysis pipeline on an uploaded X-ray.

    Args:
        input_image: HxW or HxWxC uint8-compatible numpy array from Gradio.

    Returns:
        A tuple for the three Gradio outputs: (CLAHE-enhanced PIL image,
        CAM-overlay numpy image, {label: probability} dict).

    Raises:
        gr.Error: if no image was provided or any stage of analysis fails.
    """
    if input_image is None:
        raise gr.Error("Please upload an X-ray image.")
    try:
        # Normalize the incoming array to a grayscale PIL image.
        raw_img = Image.fromarray(input_image.astype('uint8')).convert('L')

        # Contrast-enhance, then replicate to 3 channels for EfficientNet.
        img_clahe = MedicalTransform()(raw_img)
        img_clahe_rgb = img_clahe.convert('RGB')
        img_tensor = preprocess(img_clahe_rgb).unsqueeze(0).to(DEVICE)

        # Multi-label classification: independent sigmoid per pathology.
        with torch.no_grad():
            probs = torch.sigmoid(model(img_tensor))[0]
        top_idx = torch.argmax(probs).item()

        # CAM needs gradients, so rerun the forward pass with autograd on.
        with torch.enable_grad():
            cam_input = img_tensor.clone().requires_grad_(True)
            heatmap = cam_extractor(cam_input, top_idx)

        # Blend the upsampled heatmap over the enhanced image at 224x224.
        base = cv2.resize(np.array(img_clahe_rgb), (224, 224))
        heat = cv2.resize(heatmap, (224, 224))
        jet = cv2.applyColorMap(np.uint8(255 * heat), cv2.COLORMAP_JET)
        jet = cv2.cvtColor(jet, cv2.COLOR_BGR2RGB)  # OpenCV colormaps are BGR
        cam_image = cv2.addWeighted(base, 0.6, jet, 0.4, 0)

        pred_dict = {label: probs[i].item() for i, label in enumerate(categories)}
        return img_clahe, cam_image, pred_dict
    except Exception as e:
        raise gr.Error(f"Analysis Failed: {str(e)}")
# Consistent visual theme for the whole app.
theme = gr.themes.Soft(
    primary_hue="blue",
    secondary_hue="indigo",
    font=[gr.themes.GoogleFont("Inter"), "system-ui", "sans-serif"]
)

# UI layout — left column: upload, trigger button, bundled examples;
# right column: CLAHE view, CAM overlay, and top-5 label probabilities.
with gr.Blocks(theme=theme, title="AI Radiologist Assistant") as demo:
    gr.Markdown(
        """
        # 🩺 AI Radiologist Assistant
        **Model:** EfficientNet-B0 | **Dataset:** NIH Chest X-Ray 14 | **XAI:** Custom PyTorch CAM
        *Upload a chest X-ray to predict the probability of 14 common thoracic conditions.*
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            # type="numpy" hands predict() a raw HxW(xC) uint8 array.
            image_input = gr.Image(label="Input X-Ray", type="numpy")
            analyze_btn = gr.Button("🔍 Analyze Image", variant="primary")
            gr.Examples(
                examples=["example_1.png", "example_2.png", "example_3.png", "example_4.png", "example_5.png"],
                inputs=image_input,
                cache_examples=False
            )
        with gr.Column(scale=1):
            with gr.Row():
                clahe_output = gr.Image(label="Enhanced View (CLAHE)")
                cam_output = gr.Image(label="Attention Map")
            label_output = gr.Label(num_top_classes=5, label="Top Predictions")
    # Wire the button to the inference pipeline.
    analyze_btn.click(
        fn=predict,
        inputs=image_input,
        outputs=[clahe_output, cam_output, label_output]
    )

if __name__ == "__main__":
    # ssr_mode=False — presumably to work around SSR issues on the host; confirm.
    demo.launch(ssr_mode=False)