VaneshDev commited on
Commit
1aa82f5
Β·
verified Β·
1 Parent(s): 4c15f23

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -47
app.py CHANGED
@@ -11,38 +11,23 @@ from torchvision import transforms
11
  from torchcam.methods import SmoothGradCAMpp
12
  from torchcam.utils import overlay_mask
13
  import re
14
- import logging
15
 
16
- logging.basicConfig(level=logging.INFO)
17
- logger = logging.getLogger(__name__)
18
-
19
- # ---------------- MODEL SETUP ---------------- #
20
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
  MODEL = xrv.models.get_model("densenet121-res224-all").to(DEVICE)
22
  LABELS = MODEL.pathologies
23
-
24
- # Grad-CAM extractor (single-channel input)
25
  cam_extractor = SmoothGradCAMpp(MODEL, input_shape=(1, 224, 224))
26
 
27
- # ---------------- IMAGE HANDLING ---------------- #
28
  def preprocess_image(pil_img: Image.Image):
29
- """Convert to grayscale, normalize, resize for model"""
30
  if pil_img.mode != "L":
31
  pil_img = pil_img.convert("L")
32
-
33
  img_array = np.array(pil_img).astype(np.float32)
34
  img_array = xrv.datasets.normalize(img_array, 255)
35
-
36
- # Add channel dimension
37
- img_array = img_array[None, ...] # Shape: [1, H, W]
38
-
39
- transform = transforms.Compose([
40
- xrv.datasets.XRayCenterCrop(),
41
- xrv.datasets.XRayResizer(224)
42
- ])
43
- img_array = transform(img_array)
44
-
45
- # Convert to tensor
46
  tensor = torch.from_numpy(img_array).unsqueeze(0).to(DEVICE)
47
  tensor.requires_grad_(True)
48
  return tensor
@@ -57,56 +42,46 @@ def get_medical_advice(label):
57
  }
58
  return advice_dict.get(label, "Please consult a radiologist for further evaluation.")
59
 
 
60
  def analyse_xray(img: Image.Image):
61
  try:
62
  if img is None:
63
  return "Please upload an X-ray image.", None
64
-
65
- MODEL.train() # required for CAM to calculate gradients
66
  x = preprocess_image(img)
67
  output = MODEL(x)
68
  probs = torch.sigmoid(output)[0] * 100
69
-
70
- # Top 5 predictions
71
  topk = torch.topk(probs, 5)
72
-
73
  html = "<h3>🩻 Top Predictions</h3><table border='1'><tr><th>Condition</th><th>Confidence</th></tr>"
74
  for idx in topk.indices:
75
  html += f"<tr><td>{LABELS[idx]}</td><td>{probs[idx]:.1f}%</td></tr>"
76
  html += "</table><br>"
77
-
78
  top_label = LABELS[topk.indices[0]]
79
  advice = get_medical_advice(top_label)
80
  html += f"<b>Suggested Action for '{top_label}':</b> {advice}"
81
 
82
- # Grad-CAM
83
- cam = cam_extractor(topk.indices[0].item(), output)[0]
84
- img_vis = img.convert("RGB").resize((224, 224))
85
- heat_img = overlay_mask(img_vis, Image.fromarray((cam.cpu().numpy() * 255).astype(np.uint8)), alpha=0.5)
86
-
87
  MODEL.eval()
88
  return html, heat_img
89
-
90
  except Exception as e:
91
- logger.error(e)
92
  return f"Error processing image: {str(e)}", None
93
 
94
- # ---------------- PDF HANDLING ---------------- #
95
  def analyse_report(file):
96
  try:
97
  if file is None:
98
  return "Please upload a PDF report."
99
-
100
- # Use file.name instead of .read()
101
  doc = fitz.open(file.name)
102
  text = "\n".join(page.get_text() for page in doc)
103
  doc.close()
104
-
105
  found = []
106
  for label in LABELS:
107
  if re.search(rf"\b{label.lower()}\b", text.lower()):
108
  found.append(label)
109
-
110
  if found:
111
  html = "<h3>πŸ“„ Detected Conditions</h3><ul>"
112
  for label in found:
@@ -114,14 +89,11 @@ def analyse_report(file):
114
  html += "</ul>"
115
  else:
116
  html = "<p>No specific conditions found in the report.</p>"
117
-
118
  return html
119
-
120
  except Exception as e:
121
- logger.error(e)
122
  return f"Error processing PDF: {str(e)}"
123
 
124
- # ---------------- UI ---------------- #
125
  with gr.Blocks(title="RadiologyScan AI", theme=gr.themes.Soft()) as demo:
126
  gr.Markdown("## 🩻 RadiologyScan AI\nUpload an X-ray or PDF report for AI-assisted analysis")
127
 
@@ -130,20 +102,16 @@ with gr.Blocks(title="RadiologyScan AI", theme=gr.themes.Soft()) as demo:
130
  xray_input = gr.Image(label="Upload Chest X-ray", type="pil")
131
  xray_html = gr.HTML()
132
  xray_cam = gr.Image(label="AI Heatmap")
133
-
134
  analyse_btn = gr.Button("Analyze X-ray")
135
  clear_xray = gr.Button("Clear")
136
-
137
  analyse_btn.click(analyse_xray, inputs=xray_input, outputs=[xray_html, xray_cam])
138
  clear_xray.click(lambda: (None, "", None), None, outputs=[xray_input, xray_html, xray_cam])
139
 
140
  with gr.Tab("πŸ“„ Report Analysis"):
141
  pdf_input = gr.File(label="Upload PDF report", file_types=[".pdf"])
142
  pdf_html = gr.HTML()
143
-
144
  analyse_pdf_btn = gr.Button("Analyze Report")
145
  clear_pdf = gr.Button("Clear")
146
-
147
  analyse_pdf_btn.click(analyse_report, inputs=pdf_input, outputs=pdf_html)
148
  clear_pdf.click(lambda: (None, ""), None, outputs=[pdf_input, pdf_html])
149
 
 
11
  from torchcam.methods import SmoothGradCAMpp
12
  from torchcam.utils import overlay_mask
13
  import re
 
14
 
15
+ # --- Model Setup ---
 
 
 
16
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
  MODEL = xrv.models.get_model("densenet121-res224-all").to(DEVICE)
18
  LABELS = MODEL.pathologies
 
 
19
  cam_extractor = SmoothGradCAMpp(MODEL, input_shape=(1, 224, 224))
20
 
21
+ # --- Preprocessing ---
22
  def preprocess_image(pil_img: Image.Image):
23
+ """Convert image to grayscale, normalize and resize for model"""
24
  if pil_img.mode != "L":
25
  pil_img = pil_img.convert("L")
 
26
  img_array = np.array(pil_img).astype(np.float32)
27
  img_array = xrv.datasets.normalize(img_array, 255)
28
+ img_array = img_array[None, ...] # [1, H, W]
29
+ img_array = xrv.datasets.XRayCenterCrop()(img_array)
30
+ img_array = xrv.datasets.XRayResizer(224)(img_array)
 
 
 
 
 
 
 
 
31
  tensor = torch.from_numpy(img_array).unsqueeze(0).to(DEVICE)
32
  tensor.requires_grad_(True)
33
  return tensor
 
42
  }
43
  return advice_dict.get(label, "Please consult a radiologist for further evaluation.")
44
 
45
+ # --- X-ray Analysis ---
46
  def analyse_xray(img: Image.Image):
47
  try:
48
  if img is None:
49
  return "Please upload an X-ray image.", None
50
+ MODEL.train() # Enable gradients for CAM
 
51
  x = preprocess_image(img)
52
  output = MODEL(x)
53
  probs = torch.sigmoid(output)[0] * 100
 
 
54
  topk = torch.topk(probs, 5)
 
55
  html = "<h3>🩻 Top Predictions</h3><table border='1'><tr><th>Condition</th><th>Confidence</th></tr>"
56
  for idx in topk.indices:
57
  html += f"<tr><td>{LABELS[idx]}</td><td>{probs[idx]:.1f}%</td></tr>"
58
  html += "</table><br>"
 
59
  top_label = LABELS[topk.indices[0]]
60
  advice = get_medical_advice(top_label)
61
  html += f"<b>Suggested Action for '{top_label}':</b> {advice}"
62
 
63
+ # Grad-CAM overlay
64
+ cam = cam_extractor(topk.indices[0].item(), output)[0] # 2D, (224,224)
65
+ img_rgb = img.convert("RGB").resize((224, 224))
66
+ cam_img = Image.fromarray((cam.cpu().numpy() * 255).astype(np.uint8))
67
+ heat_img = overlay_mask(img_rgb, cam_img, alpha=0.5)
68
  MODEL.eval()
69
  return html, heat_img
 
70
  except Exception as e:
 
71
  return f"Error processing image: {str(e)}", None
72
 
73
+ # --- PDF Report Analysis ---
74
  def analyse_report(file):
75
  try:
76
  if file is None:
77
  return "Please upload a PDF report."
 
 
78
  doc = fitz.open(file.name)
79
  text = "\n".join(page.get_text() for page in doc)
80
  doc.close()
 
81
  found = []
82
  for label in LABELS:
83
  if re.search(rf"\b{label.lower()}\b", text.lower()):
84
  found.append(label)
 
85
  if found:
86
  html = "<h3>πŸ“„ Detected Conditions</h3><ul>"
87
  for label in found:
 
89
  html += "</ul>"
90
  else:
91
  html = "<p>No specific conditions found in the report.</p>"
 
92
  return html
 
93
  except Exception as e:
 
94
  return f"Error processing PDF: {str(e)}"
95
 
96
+ # --- Gradio UI ---
97
  with gr.Blocks(title="RadiologyScan AI", theme=gr.themes.Soft()) as demo:
98
  gr.Markdown("## 🩻 RadiologyScan AI\nUpload an X-ray or PDF report for AI-assisted analysis")
99
 
 
102
  xray_input = gr.Image(label="Upload Chest X-ray", type="pil")
103
  xray_html = gr.HTML()
104
  xray_cam = gr.Image(label="AI Heatmap")
 
105
  analyse_btn = gr.Button("Analyze X-ray")
106
  clear_xray = gr.Button("Clear")
 
107
  analyse_btn.click(analyse_xray, inputs=xray_input, outputs=[xray_html, xray_cam])
108
  clear_xray.click(lambda: (None, "", None), None, outputs=[xray_input, xray_html, xray_cam])
109
 
110
  with gr.Tab("πŸ“„ Report Analysis"):
111
  pdf_input = gr.File(label="Upload PDF report", file_types=[".pdf"])
112
  pdf_html = gr.HTML()
 
113
  analyse_pdf_btn = gr.Button("Analyze Report")
114
  clear_pdf = gr.Button("Clear")
 
115
  analyse_pdf_btn.click(analyse_report, inputs=pdf_input, outputs=pdf_html)
116
  clear_pdf.click(lambda: (None, ""), None, outputs=[pdf_input, pdf_html])
117