import gradio as gr
import torch
from torchvision import models, transforms
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from PIL import Image
import numpy as np
from skimage.transform import resize
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

# Constants
REPO_ID = "itsomk/chexpert-densenet121"
FILENAME = "pytorch_model.safetensors"

# Model Definition
class DenseNet121_CheXpert(torch.nn.Module):
    def __init__(self, num_labels=14, pretrained=None):
        super().__init__()
        self.densenet = models.densenet121(weights=pretrained)
        num_features = self.densenet.classifier.in_features
        self.densenet.classifier = torch.nn.Linear(num_features, num_labels)
    
    def forward(self, x):
        return self.densenet(x)

# Labels
LABELS = [
    "No Finding", "Enlarged Cardiomediastinum", "Cardiomegaly", "Lung Opacity",
    "Lung Lesion", "Edema", "Consolidation", "Pneumonia", "Atelectasis",
    "Pneumothorax", "Pleural Effusion", "Pleural Other", "Fracture", "Support Devices"
]
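
# Sketch (illustrative helper, not called by the app): the classifier head emits one
# logit per entry in LABELS, and predict() passes each logit through its own sigmoid.
# Findings are therefore scored independently (multi-label), several can exceed the
# threshold at once, and the probabilities need not sum to 1 as they would with a
# softmax. `findings_above_threshold` is a hypothetical name, not part of any library.
import math


def findings_above_threshold(logits, labels, threshold=0.5):
    """Return (label, probability) pairs with sigmoid(logit) > threshold, highest first."""
    if len(logits) != len(labels):
        raise ValueError("expected exactly one logit per label")
    probs = []
    for z in logits:
        # Numerically stable sigmoid: avoids math.exp overflow for large |z|
        if z >= 0:
            probs.append(1.0 / (1.0 + math.exp(-z)))
        else:
            e = math.exp(z)
            probs.append(e / (1.0 + e))
    hits = [(label, p) for label, p in zip(labels, probs) if p > threshold]
    return sorted(hits, key=lambda item: item[1], reverse=True)

# Example: findings_above_threshold([2.0, 0.0, -1.0], ["A", "B", "C"]) keeps only "A"
# (sigmoid(2.0) is about 0.881; sigmoid(0.0) is exactly 0.5, which is not > 0.5).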

# Preprocessing: resize to 224x224, convert to a [0, 1] float tensor, normalize with ImageNet mean/std
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
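
# Sketch (illustrative helper, not called by the app): for an 8-bit RGB pixel,
# ToTensor() divides each channel by 255 and Normalize() then computes
# (x - mean) / std per channel with the ImageNet statistics used above. Reusing these
# statistics assumes the checkpoint was trained with the same normalization; note that
# a grayscale X-ray converted to RGB has equal channels, which still normalize to
# different values. `normalize_rgb_pixel` is a hypothetical name.
def normalize_rgb_pixel(rgb, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Apply the ToTensor + Normalize arithmetic to one (R, G, B) pixel in 0..255."""
    return tuple((value / 255.0 - m) / s for value, m, s in zip(rgb, mean, std))

# Example: a black pixel (0, 0, 0) maps to roughly (-2.118, -2.036, -1.804).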

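# Load model
print("Loading model...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
local_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
state = load_file(local_path)
model = DenseNet121_CheXpert(num_labels=len(LABELS), pretrained=None)
# strict=False silently skips keys that do not match, so report any mismatch;
# if many keys are missing, the model would run with partly random weights.
load_result = model.load_state_dict(state, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        f"Warning: {len(load_result.missing_keys)} missing and "
        f"{len(load_result.unexpected_keys)} unexpected keys while loading {FILENAME}"
    )
model.to(device)
model.eval()
if device.type == "cuda":
    print(f"Model loaded successfully on GPU {torch.cuda.get_device_name(torch.cuda.current_device())}")
else:
    print("Model loaded successfully on CPU")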
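
# Sketch (illustrative helper, not called by the app): with strict=False,
# load_state_dict() skips every key it cannot match. A checkpoint saved from a bare
# torchvision DenseNet has keys such as "features.conv0.weight", while this wrapper
# expects "densenet.features.conv0.weight", so nothing would load. Whether the Hub
# checkpoint actually needs this remapping is an assumption to verify against the
# warning printed above; `add_key_prefix` is a hypothetical name.
def add_key_prefix(state_dict, prefix="densenet."):
    """Return a new dict with `prefix` prepended to every key that lacks it."""
    return {
        (key if key.startswith(prefix) else prefix + key): value
        for key, value in state_dict.items()
    }

# Usage sketch: model.load_state_dict(add_key_prefix(state), strict=True)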

def predict(image, threshold):
    """Generate predictions and Grad-CAM visualizations"""
    if image is None:
        return None, None, "Please upload an X-ray image"
    
    try:
        # Convert to PIL Image
        if isinstance(image, np.ndarray):
            img = Image.fromarray(image).convert("RGB")
        else:
            img = image.convert("RGB")
        
        # Preprocess
        img_tensor = preprocess(img).unsqueeze(0).to(device)
        rgb_img = np.array(img.resize((224, 224)), dtype=np.float32) / 255.0
        
        # Get predictions
        with torch.no_grad():
            logits = model(img_tensor)
            probs = torch.sigmoid(logits).squeeze().cpu().numpy()
        
        # Collect conditions whose sigmoid probability exceeds the threshold
        detected = [(i, prob) for i, prob in enumerate(probs) if prob > threshold]

        # Markdown needs list markers to render one prediction per line
        all_predictions = "\n".join(f"- {label}: {prob:.4f}" for label, prob in zip(LABELS, probs))

        if not detected:
            summary = (
                f"No conditions detected above threshold {threshold}\n\n"
                f"## All Predictions:\n{all_predictions}"
            )
            return None, img, summary

        # Only one heatmap is displayed, so run Grad-CAM just for the
        # highest-scoring detected condition instead of for every detection
        top_idx, top_prob = max(detected, key=lambda item: item[1])
        target_layer = model.densenet.features.denseblock4
        # The context manager removes the hooks GradCAM registers on the model
        with GradCAM(model=model, target_layers=[target_layer]) as cam:
            targets = [ClassifierOutputTarget(top_idx)]
            grayscale_cam = cam(input_tensor=img_tensor, targets=targets)[0, :]

        resized_rgb_img = resize(rgb_img, grayscale_cam.shape, anti_aliasing=True)
        cam_image = show_cam_on_image(resized_rgb_img, grayscale_cam, use_rgb=True)

        detected_lines = "\n".join(f"- **{LABELS[i]}**: {prob:.4f}" for i, prob in detected)
        summary = (
            f"## Detected Conditions (>{threshold}):\n{detected_lines}\n\n"
            f"Grad-CAM heatmap shown for **{LABELS[top_idx]}** ({top_prob:.4f}).\n\n"
            f"## All Predictions:\n{all_predictions}"
        )
        return cam_image, img, summary

    except Exception as e:
        return None, None, f"Error: {e}"
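

# Sketch (illustrative helper, not called by the app): the core Grad-CAM arithmetic
# that pytorch_grad_cam performs inside predict(). Each channel k of the target layer's
# activations A_k gets a weight w_k equal to the spatial mean of the gradient of the
# class score with respect to A_k; the map is ReLU(sum_k w_k * A_k), min-max scaled to
# [0, 1]. The library also resizes the map to the input resolution, which this
# plain-Python version (hypothetical name `grad_cam_map`) omits.
def grad_cam_map(activations, gradients):
    """activations, gradients: lists of K channels, each an H x W list of lists."""
    height, width = len(activations[0]), len(activations[0][0])
    # Channel weights: global average pooling of the gradients
    weights = [sum(map(sum, grad)) / (height * width) for grad in gradients]
    # Weighted sum of the activation maps, followed by ReLU
    cam = [
        [max(0.0, sum(w * act[r][c] for w, act in zip(weights, activations)))
         for c in range(width)]
        for r in range(height)
    ]
    # Min-max normalize to [0, 1]; a constant map becomes all zeros
    lo = min(map(min, cam))
    hi = max(map(max, cam))
    span = hi - lo
    return [[(v - lo) / span if span > 0 else 0.0 for v in row] for row in cam]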

# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🩻 X-Ray Grad-CAM Visualization
        
        Upload a chest X-ray image to analyze potential conditions using DenseNet121 with Grad-CAM visualization.
        
        **Model**: [itsomk/chexpert-densenet121](https://huggingface.co/itsomk/chexpert-densenet121)
        """
    )
    
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Upload X-Ray Image", type="pil")
            threshold = gr.Slider(
                minimum=0.0, 
                maximum=1.0, 
                value=0.5, 
                step=0.05, 
                label="Prediction Threshold"
            )
            analyze_btn = gr.Button("🔍 Analyze X-Ray", variant="primary", size="lg")
        
        with gr.Column():
            output_gradcam = gr.Image(label="Grad-CAM Visualization")
            output_image = gr.Image(label="Original Image")
    
    with gr.Row():
        output_text = gr.Markdown(label="Analysis Results")
    
    # Instructions
    gr.Markdown("### 📋 Instructions:")
    gr.Markdown(
        """
        1. Upload a chest X-ray image (JPG, PNG)
        2. Adjust the prediction threshold if needed (default: 0.5)
        3. Click 'Analyze X-Ray' to see results
        4. Review the detected conditions and the Grad-CAM heatmap
        """
    )
    
    # Connect components
    analyze_btn.click(
        fn=predict,
        inputs=[input_image, threshold],
        outputs=[output_gradcam, output_image, output_text]
    )

if __name__ == "__main__":
    demo.launch()