Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.models as models | |
| import torchvision.transforms as transforms | |
| from pytorch_grad_cam import GradCAM | |
| from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget | |
| from pytorch_grad_cam.utils.image import show_cam_on_image | |
# --- Configuration -----------------------------------------------------------
# Width of the classifier output: one sigmoid unit per entry in CLASS_NAMES.
num_classes = 15
# Fine-tuned checkpoint (a state dict) read once at start-up.
model_path = "model_epoch_20.pth"  # Path to your trained model weights
# Prefer the GPU when PyTorch can see one, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Class names, keyed by the string form of the model's output index
# (predict_xray looks up CLASS_NAMES[str(i)] for output column i).
CLASS_NAMES = {
    str(position): finding
    for position, finding in enumerate(
        (
            "Atelectasis",
            "Cardiomegaly",
            "Consolidation",
            "Edema",
            "Effusion",
            "Emphysema",
            "Fibrosis",
            "Hernia",
            "Infiltration",
            "Mass",
            "No Finding",
            "Nodule",
            "Pleural_Thickening",
            "Pneumonia",
            "Pneumothorax",
        )
    )
}
# Build the Inception v3 architecture that was fine-tuned. Every parameter and
# buffer is overwritten by the strict load_state_dict call below, so the
# ImageNet weights would be discarded anyway; skip downloading them.
# transform_input=True and init_weights=False are exactly what torchvision sets
# when weights='IMAGENET1K_V1' is requested, so inference behaviour is unchanged.
model = models.inception_v3(
    weights=None,
    aux_logits=True,
    transform_input=True,
    init_weights=False,
)


def _multilabel_head(in_features: int) -> nn.Sequential:
    """Return the MLP head (Linear-ReLU-Linear-Sigmoid) emitting per-class probabilities."""
    return nn.Sequential(
        nn.Linear(in_features, 512),
        nn.ReLU(),
        nn.Linear(512, num_classes),
        nn.Sigmoid(),  # outputs are already probabilities in [0, 1]
    )


# Replace the final classifier (fc) and the auxiliary classifier with the
# multi-label heads used during training (same module layout => same keys).
model.fc = _multilabel_head(model.fc.in_features)
model.AuxLogits.fc = _multilabel_head(model.AuxLogits.fc.in_features)

# Load the fine-tuned weights.
# NOTE(review): torch.load unpickles the file; only ever point model_path at a
# trusted checkpoint (weights_only=True is available on torch >= 1.13).
try:
    model.load_state_dict(torch.load(model_path, map_location=device))
except FileNotFoundError as err:
    raise FileNotFoundError(
        f"Model weights file '{model_path}' not found. Please upload it to the Hugging Face Space."
    ) from err
except Exception as err:
    # RuntimeError is an Exception subclass, so existing handlers still match;
    # chaining keeps the original traceback (e.g. missing/unexpected keys).
    raise RuntimeError(f"Error loading model weights: {err}") from err

model.eval()
model.to(device)
# Preprocessing: resize any PIL image to 224x224, then ToTensor converts it to
# a float (C, H, W) tensor with pixel values scaled to [0, 1].
# NOTE(review): no Normalize step is applied; torchvision's ImageNet-configured
# Inception v3 typically expects ImageNet-normalised input — confirm this
# matches the pipeline the checkpoint was trained with.
preprocess = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor()]
)
# Grad-CAM explainer. Mixed_7c is the last Inception block before global
# average pooling, i.e. the deepest spatial feature map of the network.
target_layer = model.Mixed_7c
gradcam = GradCAM(target_layers=[target_layer], model=model)
| # Inference function with Grad-CAM | |
| def predict_xray(image: np.ndarray): | |
| # Resize input image if needed | |
| if image.shape[:2] != (224, 224): | |
| image = Image.fromarray(image).resize((224, 224)) | |
| image = np.array(image) | |
| # Ensure image is RGB | |
| if image.ndim == 2: # Grayscale | |
| image = np.stack([image] * 3, axis=-1) | |
| elif image.shape[-1] == 1: | |
| image = np.repeat(image, 3, axis=-1) | |
| # Preprocess image | |
| pil_img = Image.fromarray(image.astype("uint8"), "RGB") | |
| input_tensor = preprocess(pil_img).unsqueeze(0).to(device) | |
| # Get model predictions | |
| with torch.no_grad(): | |
| logits = model(input_tensor) | |
| probs = torch.sigmoid(logits)[0].cpu().numpy() | |
| # Generate results with probabilities and binary labels (threshold=0.5) | |
| result = { | |
| CLASS_NAMES[str(i)]: probs[i] | |
| for i in range(len(CLASS_NAMES)) if i != 10 # Exclude "No Finding" | |
| } | |
| # Generate Grad-CAM for top 4 classes | |
| top_k = 4 | |
| top_indices = np.argsort(probs)[-top_k:][::-1] | |
| heatmaps = [] | |
| rgb_img = input_tensor.squeeze().permute(1, 2, 0).cpu().numpy() | |
| rgb_img = (rgb_img - rgb_img.min()) / (rgb_img.max() - rgb_img.min() + 1e-8) # Normalize | |
| for idx in top_indices: | |
| if idx == 10: # Skip "No Finding" for Grad-CAM | |
| continue | |
| targets = [ClassifierOutputTarget(idx)] | |
| grayscale_cam = gradcam(input_tensor=input_tensor, targets=targets)[0] | |
| cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True) | |
| heatmaps.append((cam_image, f"{CLASS_NAMES[str(idx)]} (Prob: {probs[idx]:.3f})")) | |
| return result, heatmaps | |
# --- Gradio interface ---------------------------------------------------------
# predict_xray returns (label dict, gallery items), which map onto the two
# output components below in order.
interface = gr.Interface(
    fn=predict_xray,
    inputs=gr.Image(type="numpy"),
    outputs=[
        gr.Label(
            label="Predicted Probabilities (and Binary Labels)",
            num_top_classes=14,
        ),
        gr.Gallery(label="Grad-CAM Heatmaps (Top 4 Classes)"),
    ],
    title="NIH Chest X-ray Multi-Label Classifier",
    description=(
        "Upload a chest X-ray (resized to 224x224). "
        "The model outputs probabilities and binary labels (threshold=0.5) "
        "for 14 findings, excluding 'No Finding'. "
        "Grad-CAM heatmaps highlight regions for the top 4 predicted findings. "
        "Low probabilities are common for rare conditions like Hernia "
        "due to dataset imbalance."
    ),
)
if __name__ == "__main__":
    print("starting Gradio interface...")
    # share=True also requests a temporary public share link when run locally;
    # Gradio does not support share links on Hugging Face Spaces, which serve
    # the app themselves.
    launch_options = {"share": True}
    interface.launch(**launch_options)