# NOTE(review): the three lines that were here ("Spaces:" / "Sleeping" /
# "Sleeping") are a Hugging Face Spaces status banner captured when this
# file was scraped from the Space page — they are not part of the program.
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import models, transforms | |
| from PIL import Image | |
| import numpy as np | |
| import cv2 | |
| import warnings | |
| warnings.filterwarnings('ignore') | |
| DEVICE = torch.device("cpu") | |
| MODEL_PATH = "effnet_b0_nih_2026_02_19.pth" | |
| LABELS_PATH = "labels.txt" | |
| with open(LABELS_PATH, "r") as f: | |
| categories = [line.strip() for line in f.readlines()] | |
def load_model():
    """Build an EfficientNet-B0 with a head sized for the NIH label set
    and load the fine-tuned weights.

    Returns:
        The model on ``DEVICE`` in eval mode, ready for inference.
    """
    model = models.efficientnet_b0(weights=None)
    # Swap the ImageNet classifier head for one output per NIH condition.
    n_inputs = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(n_inputs, len(categories))
    # weights_only=True: the checkpoint is a plain state_dict, and this
    # avoids torch.load's pickle-based arbitrary-code-execution risk.
    state_dict = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(DEVICE)
    model.eval()
    return model
# Load once at import time; every request reuses this global instance.
model = load_model()
class SimpleGradCAM:
    """Minimal Grad-CAM implementation.

    Hooks a target layer of ``model`` to capture its forward activations
    and the gradient flowing back into them, then turns the class-score
    gradient into a [0, 1]-normalized spatial saliency map.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        # Capture the layer's output on forward and its incoming grad on backward.
        self.target_layer.register_forward_hook(self.save_activation)
        self.target_layer.register_full_backward_hook(self.save_gradient)

    def save_activation(self, module, input, output):
        # Feature maps produced by the hooked layer on the forward pass.
        self.activations = output.detach()

    def save_gradient(self, module, grad_input, grad_output):
        # d(score)/d(activations) delivered on the backward pass.
        self.gradients = grad_output[0].detach()

    def __call__(self, x, class_idx):
        """Return the Grad-CAM saliency map for ``class_idx`` on input ``x``."""
        self.model.zero_grad()
        score = self.model(x)[:, class_idx].squeeze()
        score.backward(retain_graph=True)

        batch, channels = self.gradients.shape[:2]
        # Per-channel weights: global-average-pool the gradients spatially.
        channel_weights = self.gradients.reshape(batch, channels, -1).mean(dim=2)
        channel_weights = channel_weights.reshape(batch, channels, 1, 1)

        # Weighted sum of activations, rectified so only positive evidence shows.
        cam = torch.relu((channel_weights * self.activations).sum(dim=1, keepdim=True))
        cam = cam.squeeze().cpu().numpy()

        peak = cam.max()
        if peak > 0:
            cam = cam / peak
        return cam
# Hook Grad-CAM onto the last block of the EfficientNet feature extractor.
cam_extractor = SimpleGradCAM(model, model.features[-1])
class MedicalTransform:
    """Callable preprocessing step: force grayscale, then apply CLAHE
    (contrast-limited adaptive histogram equalization) to the image."""

    def __call__(self, img):
        pixels = np.array(img)
        # Collapse any color input down to a single luminance channel.
        if pixels.ndim > 2:
            pixels = cv2.cvtColor(pixels, cv2.COLOR_RGB2GRAY)
        enhancer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        return Image.fromarray(enhancer.apply(pixels))
# Resize to the 224x224 network input and normalize with the standard
# ImageNet mean/std statistics.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def predict(input_image):
    """Run the full analysis pipeline on an uploaded chest X-ray.

    Args:
        input_image: numpy array from the Gradio image widget
            (None when nothing was uploaded).

    Returns:
        Tuple of (enhanced CLAHE PIL image, 224x224 RGB Grad-CAM overlay
        array, dict mapping each category name to its sigmoid probability).

    Raises:
        gr.Error: when no image was supplied or any pipeline stage fails.
    """
    if input_image is None:
        raise gr.Error("Please upload an X-ray image.")
    try:
        img_clahe, img_tensor = _prepare_input(input_image)
        # Multi-label head: an independent sigmoid probability per condition.
        with torch.no_grad():
            probs = torch.sigmoid(model(img_tensor))[0]
        top_idx = torch.argmax(probs).item()
        # Grad-CAM needs gradients, so re-run the forward pass with them enabled.
        with torch.enable_grad():
            img_tensor_cam = img_tensor.clone().requires_grad_(True)
            heatmap = cam_extractor(img_tensor_cam, top_idx)
        cam_image = _overlay_heatmap(img_clahe, heatmap)
        pred_dict = {categories[i]: probs[i].item() for i in range(len(categories))}
        return img_clahe, cam_image, pred_dict
    except Exception as e:
        # Surface any pipeline failure as a user-visible Gradio error.
        raise gr.Error(f"Analysis Failed: {str(e)}")

def _prepare_input(input_image):
    """Grayscale + CLAHE-enhance the raw array; return (PIL image, model tensor)."""
    raw_img = Image.fromarray(input_image.astype('uint8')).convert('L')
    img_clahe = MedicalTransform()(raw_img)
    # Replicate to 3 channels so the ImageNet-style stem accepts the input.
    img_tensor = preprocess(img_clahe.convert('RGB')).unsqueeze(0).to(DEVICE)
    return img_clahe, img_tensor

def _overlay_heatmap(img_clahe, heatmap):
    """Blend the [0,1] CAM heatmap over the enhanced image (224x224 RGB array)."""
    display_img = cv2.resize(np.array(img_clahe.convert('RGB')), (224, 224))
    heatmap_resized = cv2.resize(heatmap, (224, 224))
    heatmap_color = cv2.applyColorMap(np.uint8(255 * heatmap_resized), cv2.COLORMAP_JET)
    # OpenCV colormaps come back BGR; convert for the RGB display path.
    heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)
    return cv2.addWeighted(display_img, 0.6, heatmap_color, 0.4, 0)
# Soft blue/indigo Gradio theme with the Inter font for the web UI.
theme = gr.themes.Soft(
    primary_hue="blue",
    secondary_hue="indigo",
    font=[gr.themes.GoogleFont("Inter"), "system-ui", "sans-serif"]
)
# UI layout: input + examples on the left; enhanced view, attention map,
# and top-5 predictions on the right.
with gr.Blocks(theme=theme, title="AI Radiologist Assistant") as demo:
    gr.Markdown(
        """
        # 🩺 AI Radiologist Assistant
        **Model:** EfficientNet-B0 | **Dataset:** NIH Chest X-Ray 14 | **XAI:** Custom PyTorch CAM
        *Upload a chest X-ray to predict the probability of 14 common thoracic conditions.*
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            # type="numpy" matches predict(), which expects a numpy array.
            image_input = gr.Image(label="Input X-Ray", type="numpy")
            analyze_btn = gr.Button("🔍 Analyze Image", variant="primary")
            # Bundled sample X-rays; not cached so predict() runs fresh each time.
            gr.Examples(
                examples=["example_1.png", "example_2.png", "example_3.png", "example_4.png", "example_5.png"],
                inputs=image_input,
                cache_examples=False
            )
        with gr.Column(scale=1):
            with gr.Row():
                clahe_output = gr.Image(label="Enhanced View (CLAHE)")
                cam_output = gr.Image(label="Attention Map")
            label_output = gr.Label(num_top_classes=5, label="Top Predictions")
    # Outputs map 1:1 onto predict()'s (clahe image, cam overlay, prob dict).
    analyze_btn.click(
        fn=predict,
        inputs=image_input,
        outputs=[clahe_output, cam_output, label_output]
    )
if __name__ == "__main__":
    # ssr_mode=False disables Gradio's server-side rendering — presumably
    # for hosted-environment compatibility; confirm before changing.
    demo.launch(ssr_mode=False)