File size: 13,371 Bytes
9748112 e422ed7 9748112 448ea49 9748112 e422ed7 4a34168 e422ed7 4a34168 e422ed7 4a34168 e422ed7 9748112 e422ed7 9748112 e422ed7 9748112 e422ed7 9748112 e422ed7 9748112 e422ed7 9748112 e422ed7 9748112 e422ed7 9748112 e422ed7 9748112 e422ed7 9748112 e422ed7 9748112 e422ed7 9748112 e422ed7 9748112 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 |
import gradio as gr
import torch
from torchvision import models, transforms
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
from PIL import Image
import numpy as np
from skimage.transform import resize
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
# Constants
# Hugging Face Hub repository and file holding the fine-tuned checkpoint.
REPO_ID: str = "itsomk/chexpert-densenet121"
FILENAME: str = "pytorch_model.safetensors"  # safetensors weights inside REPO_ID
# Model Definition
class DenseNet121_CheXpert(torch.nn.Module):
    """torchvision DenseNet-121 whose classifier is replaced by a multi-label linear head.

    Args:
        num_labels: number of output logits, one per label.
        pretrained: forwarded unchanged as ``weights`` to
            ``torchvision.models.densenet121``; ``None`` builds the network
            without pretrained weights.
    """

    def __init__(self, num_labels=14, pretrained=None):
        super().__init__()
        backbone = models.densenet121(weights=pretrained)
        # Swap the stock classifier for a fresh Linear layer sized to our label set.
        head_inputs = backbone.classifier.in_features
        backbone.classifier = torch.nn.Linear(head_inputs, num_labels)
        self.densenet = backbone

    def forward(self, x):
        # Raw logits; this file turns them into probabilities with torch.sigmoid in predict().
        return self.densenet(x)
# Labels
# Label names in the order of the model's output logits: output index i is LABELS[i].
LABELS = [
    "No Finding",
    "Enlarged Cardiomediastinum",
    "Cardiomegaly",
    "Lung Opacity",
    "Lung Lesion",
    "Edema",
    "Consolidation",
    "Pneumonia",
    "Atelectasis",
    "Pneumothorax",
    "Pleural Effusion",
    "Pleural Other",
    "Fracture",
    "Support Devices",
]
# Label Descriptions for Report
# One full sentence (ending in ".") per label, used in the "Findings" section.
label_descriptions = {
    "No Finding": "No significant cardiopulmonary abnormality is identified.",
    "Enlarged Cardiomediastinum": (
        "The cardiomediastinal silhouette appears enlarged, "
        "which may reflect cardiac or mediastinal pathology."
    ),
    "Cardiomegaly": (
        "The cardiac silhouette is enlarged, which may be seen in a variety of cardiac "
        "conditions including cardiomyopathy or volume overload."
    ),
    "Lung Opacity": (
        "There are areas of increased lung opacity, which may represent infection, "
        "inflammation, or other parenchymal processes."
    ),
    "Lung Lesion": (
        "There is a focal abnormality in the lung that may represent an underlying "
        "lesion and may warrant further evaluation."
    ),
    "Edema": (
        "The pulmonary parenchyma demonstrates changes that may represent "
        "pulmonary edema."
    ),
    "Consolidation": (
        "There is focal or multifocal consolidation compatible with alveolar filling, "
        "such as infection or aspiration."
    ),
    "Pneumonia": (
        "The pattern of opacities is suspicious for pneumonia in the appropriate "
        "clinical context."
    ),
    "Atelectasis": (
        "There is volume loss with increased opacity, "
        "which may represent atelectasis."
    ),
    "Pneumothorax": (
        "There is suspicion for pneumothorax, which represents air within the pleural "
        "space and may be clinically significant."
    ),
    "Pleural Effusion": (
        "There is fluid in the pleural space, "
        "which may compress the adjacent lung parenchyma."
    ),
    "Pleural Other": (
        "There are pleural abnormalities that may represent pleural thickening, "
        "plaques, or other pleural processes."
    ),
    "Fracture": (
        "There is suspicion of osseous fracture, which may require correlation with "
        "dedicated imaging and clinical findings."
    ),
    "Support Devices": (
        "Support devices are present (e.g. lines, tubes, pacemaker leads) "
        "which should be correlated with position and clinical need."
    ),
}
# Per-label probability cut-offs read by rule_based_labeling(); labels not listed
# here fall back to that function's ``default_threshold`` argument.
LABEL_THRESHOLDS = {
    "No Finding": 0.5,
    "Cardiomegaly": 0.6,  # stricter than rule_based_labeling's 0.5 default
    "Pneumothorax": 0.6,
    "Pleural Effusion": 0.5,
    "Fracture": 0.6,
}
# Preprocessing
# Standard ImageNet per-channel mean/std; presumably what the checkpoint was
# trained with — TODO confirm against the training pipeline.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# PIL image -> 224x224 -> float tensor in [0, 1] -> channel-normalised tensor.
preprocess = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)
# Load model
# Downloads the checkpoint from the Hub (cached by huggingface_hub), builds the
# network with one output per entry of LABELS, and moves it to GPU when available.
print("Loading model...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
local_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
state = load_file(local_path)
model = DenseNet121_CheXpert(num_labels=len(LABELS), pretrained=None)
# strict=False tolerates extra/missing entries, but a key mismatch would otherwise
# leave layers randomly initialised with no error at all, so report it explicitly.
load_result = model.load_state_dict(state, strict=False)
if load_result.missing_keys:
    print(
        f"WARNING: {len(load_result.missing_keys)} model parameters were not found in the "
        f"checkpoint and remain randomly initialised, e.g. {load_result.missing_keys[:5]}"
    )
if load_result.unexpected_keys:
    print(
        f"WARNING: {len(load_result.unexpected_keys)} checkpoint entries were not used by "
        f"the model, e.g. {load_result.unexpected_keys[:5]}"
    )
model.to(device)
model.eval()  # inference mode for dropout / batch-norm layers
if device.type == "cuda":
    print(f"Model loaded successfully on GPU {torch.cuda.get_device_name(torch.cuda.current_device())}")
else:
    print("Model loaded successfully on CPU")
# Report Generation Functions
def prob_to_phrase(p: float) -> str:
    """Map a probability to a hedging phrase for the generated report.

    ``p >= 0.8`` -> "highly suggestive of", ``p >= 0.6`` -> "likely",
    anything lower (including NaN) -> "may represent".
    """
    # Checked from the highest cut-off down; the first match wins.
    for cutoff, phrase in ((0.8, "highly suggestive of"), (0.6, "likely")):
        if p >= cutoff:
            return phrase
    return "may represent"
def rule_based_labeling(probs, default_threshold: float = 0.5):
    """Pick the labels whose probability reaches their decision threshold.

    Each label uses its entry in ``LABEL_THRESHOLDS`` if present, otherwise
    ``default_threshold``; the comparison is inclusive (``>=``).

    Args:
        probs: one probability per entry of ``LABELS``, in the same order.
        default_threshold: cut-off for labels without a specific threshold.

    Returns:
        List of ``(label_index, probability)`` pairs in ``LABELS`` order.

    Raises:
        ValueError: if ``probs`` does not have exactly ``len(LABELS)`` entries.
    """
    expected = len(LABELS)
    if len(probs) != expected:
        raise ValueError(f"Expected {expected} probabilities, got {len(probs)}")
    return [
        (idx, prob)
        for idx, (name, prob) in enumerate(zip(LABELS, probs))
        if prob >= LABEL_THRESHOLDS.get(name, default_threshold)
    ]
def handle_no_finding(selected):
    """Drop "No Finding" when it was selected together with any other label.

    Args:
        selected: list of ``(label_index, probability)`` pairs.

    Returns:
        ``selected`` itself when nothing needs removing, otherwise a new list
        without the "No Finding" entries.
    """
    names = [LABELS[idx] for idx, _ in selected]
    if len(names) > 1 and "No Finding" in names:
        # "No Finding" contradicts any positive finding, so the positive ones win.
        return [(idx, prob) for (idx, prob), name in zip(selected, names) if name != "No Finding"]
    return selected
def remove_redundant_labels(selected):
    """Remove generic labels that a more specific selected label already explains.

    Rules, applied in order:
      * "Pneumonia" present      -> drop "Lung Opacity"
      * "Consolidation" present  -> drop "Lung Opacity"
      * "Pleural Effusion" present -> drop "Pleural Other"

    Args:
        selected: list of ``(label_index, probability)`` pairs.

    Returns:
        The (possibly filtered) list of pairs, original order preserved.
    """
    rules = (
        ("Pneumonia", "Lung Opacity"),
        ("Consolidation", "Lung Opacity"),
        ("Pleural Effusion", "Pleural Other"),
    )
    for specific, generic in rules:
        present = {LABELS[idx] for idx, _ in selected}
        if specific in present and generic in present:
            selected = [(idx, prob) for idx, prob in selected if LABELS[idx] != generic]
    return selected
def build_impression_from_labels(selected):
    """Build the "Impression" section of the report from the selected labels.

    Related labels are first merged into combined sentences (edema / pleural
    effusion, pneumonia / consolidation / atelectasis, cardiomegaly, support
    devices). Every other selected label then gets its own generic sentence
    built from ``prob_to_phrase``, so no selected finding is silently omitted.
    If nothing besides "No Finding" was selected, a single "no acute process"
    sentence is emitted.

    Args:
        selected: list of ``(label_index, probability)`` pairs, indices into ``LABELS``.

    Returns:
        ``"Impression:"`` followed by one ``"- "`` bullet line per sentence.
    """
    name_to_prob = {LABELS[i]: p for i, p in selected}
    lines = []
    covered = set()  # labels already expressed by one of the combined sentences

    has_edema = "Edema" in name_to_prob
    has_peff = "Pleural Effusion" in name_to_prob
    has_consolidation = "Consolidation" in name_to_prob
    has_pneumonia = "Pneumonia" in name_to_prob
    has_atelectasis = "Atelectasis" in name_to_prob

    if has_edema and has_peff:
        lines.append("Pattern consistent with pulmonary edema with associated pleural effusions.")
        covered.update({"Edema", "Pleural Effusion"})
    elif has_edema:
        lines.append("Pattern consistent with pulmonary edema.")
        covered.add("Edema")
    elif has_peff:
        lines.append("Pleural effusion is suspected, which may be clinically significant.")
        covered.add("Pleural Effusion")

    if has_pneumonia and has_atelectasis:
        lines.append("Focal pulmonary opacity suspicious for pneumonia, atelectasis remains a differential consideration.")
        covered.update({"Pneumonia", "Atelectasis", "Consolidation"})
    elif has_pneumonia or has_consolidation:
        lines.append("Focal pulmonary opacity is suspicious for pneumonia in the appropriate clinical context.")
        covered.update({"Pneumonia", "Consolidation"})
    elif has_atelectasis:
        lines.append("Areas of volume loss may represent atelectasis.")
        covered.add("Atelectasis")

    if "Cardiomegaly" in name_to_prob:
        lines.append("Cardiac silhouette appears enlarged, correlate clinically for cardiomegaly.")
        covered.add("Cardiomegaly")
    if "Support Devices" in name_to_prob:
        lines.append("Support devices/tubes are present, correlate with clinical indication and positioning.")
        covered.add("Support Devices")

    # "No Finding" never gets a generic sentence: it is only meaningful when
    # nothing else was reported, which the final fallback below handles.
    covered.add("No Finding")
    for i, p in selected:
        label = LABELS[i]
        if label in covered:
            continue
        covered.add(label)  # guard against duplicate indices producing duplicate lines
        sentence = f"{prob_to_phrase(p)} {label.lower()}."
        lines.append(sentence[0].upper() + sentence[1:])

    if not lines:
        lines.append("No acute cardiopulmonary process detected by the model.")
    return "Impression:\n- " + "\n- ".join(lines)
def generate_textual_report(probs, default_threshold: float = 0.5, top_k: "int | None" = None) -> str:
    """Turn per-label probabilities into a "Findings" + "Impression" text report.

    Pipeline: threshold the probabilities (``rule_based_labeling``), drop
    "No Finding" when other labels fire, drop labels made redundant by a more
    specific one, sort by descending probability, optionally keep only the
    ``top_k`` most probable, then render both sections.

    Args:
        probs: one probability per entry of ``LABELS``, in the same order.
        default_threshold: cut-off for labels without an entry in ``LABEL_THRESHOLDS``.
        top_k: if not None, maximum number of labels included in the report.

    Returns:
        The report text; a fixed "no abnormality" report when no label passes
        its threshold.

    Raises:
        ValueError: if ``len(probs) != len(LABELS)`` (raised by ``rule_based_labeling``).
    """
    selected = rule_based_labeling(probs, default_threshold)
    if not selected:
        return (
            "Findings:\n"
            "No significant cardiopulmonary abnormality is identified by the model.\n\n"
            "Impression:\n"
            "No acute cardiopulmonary process detected by the model."
        )
    selected = handle_no_finding(selected)
    selected = remove_redundant_labels(selected)
    # Stable sort: equal probabilities keep LABELS order.
    selected = sorted(selected, key=lambda pair: pair[1], reverse=True)
    if top_k is not None:
        selected = selected[:top_k]

    findings_lines = []
    for idx, _prob in selected:
        label = LABELS[idx]
        # Descriptions are full sentences already ending in "."; strip it so the
        # bullet does not end in "..".
        description = label_descriptions.get(label, "").rstrip(".")
        findings_lines.append(f"- {label}: {description}." if description else f"- {label}.")
    findings_text = "Findings:\n" + "\n".join(findings_lines)
    impression_text = build_impression_from_labels(selected)
    return findings_text + "\n\n" + impression_text
def _gradcam_overlay(img_tensor, rgb_img, class_idx):
    """Return the Grad-CAM heatmap for output ``class_idx`` blended over ``rgb_img``.

    Args:
        img_tensor: preprocessed model input batch (one image).
        rgb_img: float32 RGB array scaled to [0, 1], used as the overlay background.
        class_idx: index into ``LABELS`` of the output to explain.
    """
    cam = GradCAM(model=model, target_layers=[model.densenet.features.denseblock4])
    grayscale_cam = cam(input_tensor=img_tensor, targets=[ClassifierOutputTarget(class_idx)])[0, :]
    # Match the background to the heatmap's spatial size before blending.
    resized_rgb_img = resize(rgb_img, grayscale_cam.shape, anti_aliasing=True)
    return show_cam_on_image(resized_rgb_img, grayscale_cam, use_rgb=True)


def predict(image, threshold):
    """Generate predictions, Grad-CAM visualizations, and report.

    Args:
        image: uploaded image (PIL image or numpy array), or None.
        threshold: probability cut-off (exclusive) for the "Detected Conditions"
            list and for choosing the Grad-CAM target.

    Returns:
        ``(gradcam_overlay or None, RGB PIL image or None, summary markdown, report text)``.
        On missing input or any error the report is ``""``.
    """
    if image is None:
        return None, None, "Please upload an X-ray image", ""
    try:
        # Convert to an RGB PIL image whatever the upload type.
        source = Image.fromarray(image) if isinstance(image, np.ndarray) else image
        img = source.convert("RGB")
        img_tensor = preprocess(img).unsqueeze(0).to(device)
        rgb_img = np.array(img.resize((224, 224)), dtype=np.float32) / 255.0

        with torch.no_grad():
            logits = model(img_tensor)
        probs = torch.sigmoid(logits).squeeze().cpu().numpy()

        detected = [(i, prob) for i, prob in enumerate(probs) if prob > threshold]
        detected_conditions = [f"**{LABELS[i]}**: {prob:.4f}" for i, prob in detected]
        all_predictions = "\n".join([f"{LABELS[i]}: {prob:.4f}" for i, prob in enumerate(probs)])
        # NOTE: the report applies its own fixed thresholds (0.5 default plus
        # LABEL_THRESHOLDS); it does not follow the UI slider.
        report = generate_textual_report(probs, default_threshold=0.5, top_k=5)

        if not detected:
            summary = f"No conditions detected above threshold {threshold}\n\n## All Predictions:\n{all_predictions}"
            return None, img, summary, report

        # Only one heatmap is shown (the first detected label in LABELS order), so run
        # Grad-CAM's forward+backward pass once instead of once per detected label.
        gradcam_image = _gradcam_overlay(img_tensor, rgb_img, detected[0][0])
        summary = f"## Detected Conditions (>{threshold}):\n" + "\n".join(detected_conditions)
        summary += f"\n\n## All Predictions:\n{all_predictions}"
        return gradcam_image, img, summary, report
    except Exception as e:
        # UI boundary: show the message to the user, keep the traceback in the server log.
        import traceback

        traceback.print_exc()
        return None, None, f"Error: {str(e)}", ""
# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
# 🩻 X-Ray Grad-CAM Visualization with Report Generation
Upload a chest X-ray image to analyze potential conditions using DenseNet121 with Grad-CAM visualization.
**Model**: [itsomk/chexpert-densenet121](https://huggingface.co/itsomk/chexpert-densenet121)
"""
    )
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Upload X-Ray Image", type="pil")
            threshold = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.5,
                step=0.05,
                label="Prediction Threshold",
            )
            analyze_btn = gr.Button("🔍 Analyze X-Ray", variant="primary", size="lg")
        with gr.Column():
            output_gradcam = gr.Image(label="Grad-CAM Visualization")
            output_image = gr.Image(label="Original Image")
    with gr.Row():
        output_text = gr.Markdown(label="Analysis Results")
    # Report Section
    with gr.Row():
        with gr.Column():
            gr.Markdown("## 📋 Generated Report")
            output_report = gr.Textbox(
                label="Clinical Report",
                lines=12,
                max_lines=20,
                show_copy_button=True,
            )
            download_btn = gr.DownloadButton(
                label="📥 Download Report",
                visible=True,
            )
    # Instructions
    gr.Markdown("### 📋 Instructions:")
    gr.Markdown(
        """
1. Upload a chest X-ray image (JPG, PNG)
2. Adjust the prediction threshold if needed (default: 0.5)
3. Click 'Analyze X-Ray' to see results
4. View detected conditions with Grad-CAM heatmaps
5. Review the generated clinical report
6. Download the report as a text file if needed
"""
    )

    # Connect components
    def analyze_and_prepare_download(image, threshold):
        """Run ``predict`` and attach the report as a downloadable text file.

        Returns the four ``predict`` outputs plus an updated DownloadButton that
        points at the report file, or is hidden when there is no report.
        """
        import tempfile
        from pathlib import Path

        gradcam, original, summary, report = predict(image, threshold)
        if not report:
            # predict() returns an empty report on missing input or error: nothing to download.
            return gradcam, original, summary, report, gr.DownloadButton(visible=False)
        # A fresh directory per request, so concurrent requests never overwrite each
        # other's report (and the CWD need not be writable); the file keeps the
        # user-facing name "xray_report.txt".
        report_path = Path(tempfile.mkdtemp(prefix="xray_report_")) / "xray_report.txt"
        report_path.write_text(report, encoding="utf-8")
        return gradcam, original, summary, report, gr.DownloadButton(value=str(report_path), visible=True)

    analyze_btn.click(
        fn=analyze_and_prepare_download,
        inputs=[input_image, threshold],
        outputs=[output_gradcam, output_image, output_text, output_report, download_btn],
    )
# Launch the Gradio app only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()