VaneshDev commited on
Commit
2285f01
Β·
verified Β·
1 Parent(s): 1aa82f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -59
app.py CHANGED
@@ -8,112 +8,120 @@ from PIL import Image
8
  import fitz # PyMuPDF
9
  import torchxrayvision as xrv
10
  from torchvision import transforms
11
- from torchcam.methods import SmoothGradCAMpp
12
- from torchcam.utils import overlay_mask
13
  import re
14
 
15
- # --- Model Setup ---
16
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
- MODEL = xrv.models.get_model("densenet121-res224-all").to(DEVICE)
18
  LABELS = MODEL.pathologies
19
- cam_extractor = SmoothGradCAMpp(MODEL, input_shape=(1, 224, 224))
20
 
21
- # --- Preprocessing ---
22
- def preprocess_image(pil_img: Image.Image):
23
- """Convert image to grayscale, normalize and resize for model"""
24
  if pil_img.mode != "L":
25
  pil_img = pil_img.convert("L")
 
26
  img_array = np.array(pil_img).astype(np.float32)
27
- img_array = xrv.datasets.normalize(img_array, 255)
28
  img_array = img_array[None, ...] # [1, H, W]
29
- img_array = xrv.datasets.XRayCenterCrop()(img_array)
30
- img_array = xrv.datasets.XRayResizer(224)(img_array)
 
 
 
 
 
 
31
  tensor = torch.from_numpy(img_array).unsqueeze(0).to(DEVICE)
32
- tensor.requires_grad_(True)
33
  return tensor
34
 
35
- def get_medical_advice(label):
36
- advice_dict = {
37
- "Pneumonia": "Consider antibiotics. Consult a pulmonologist.",
38
- "Cardiomegaly": "Recommend echocardiogram and cardiologist review.",
39
- "Effusion": "Pleural fluid detected. May require thoracentesis.",
40
- "Fracture": "Possible bone injury. Orthopedic consultation needed.",
41
- "Edema": "Fluid in lungs. Evaluate for heart failure.",
42
- }
43
- return advice_dict.get(label, "Please consult a radiologist for further evaluation.")
44
-
45
- # --- X-ray Analysis ---
 
 
46
  def analyse_xray(img: Image.Image):
47
  try:
48
  if img is None:
49
  return "Please upload an X-ray image.", None
50
- MODEL.train() # Enable gradients for CAM
51
  x = preprocess_image(img)
52
- output = MODEL(x)
53
- probs = torch.sigmoid(output)[0] * 100
 
 
54
  topk = torch.topk(probs, 5)
55
- html = "<h3>🩻 Top Predictions</h3><table border='1'><tr><th>Condition</th><th>Confidence</th></tr>"
56
  for idx in topk.indices:
57
  html += f"<tr><td>{LABELS[idx]}</td><td>{probs[idx]:.1f}%</td></tr>"
58
  html += "</table><br>"
59
- top_label = LABELS[topk.indices[0]]
60
- advice = get_medical_advice(top_label)
61
- html += f"<b>Suggested Action for '{top_label}':</b> {advice}"
62
-
63
- # Grad-CAM overlay
64
- cam = cam_extractor(topk.indices[0].item(), output)[0] # 2D, (224,224)
65
- img_rgb = img.convert("RGB").resize((224, 224))
66
- cam_img = Image.fromarray((cam.cpu().numpy() * 255).astype(np.uint8))
67
- heat_img = overlay_mask(img_rgb, cam_img, alpha=0.5)
68
- MODEL.eval()
69
- return html, heat_img
70
  except Exception as e:
71
  return f"Error processing image: {str(e)}", None
72
 
73
- # --- PDF Report Analysis ---
74
  def analyse_report(file):
75
  try:
76
  if file is None:
77
  return "Please upload a PDF report."
 
78
  doc = fitz.open(file.name)
79
  text = "\n".join(page.get_text() for page in doc)
80
  doc.close()
 
81
  found = []
82
  for label in LABELS:
83
  if re.search(rf"\b{label.lower()}\b", text.lower()):
84
  found.append(label)
 
85
  if found:
86
- html = "<h3>πŸ“„ Detected Conditions</h3><ul>"
87
  for label in found:
88
- html += f"<li><b>{label}</b>: {get_medical_advice(label)}</li>"
89
  html += "</ul>"
90
  else:
91
- html = "<p>No specific conditions found in the report.</p>"
 
92
  return html
93
  except Exception as e:
94
  return f"Error processing PDF: {str(e)}"
95
 
96
  # --- Gradio UI ---
97
- with gr.Blocks(title="RadiologyScan AI", theme=gr.themes.Soft()) as demo:
98
- gr.Markdown("## 🩻 RadiologyScan AI\nUpload an X-ray or PDF report for AI-assisted analysis")
99
 
100
  with gr.Tabs():
101
  with gr.Tab("πŸ” X-ray Analysis"):
102
- xray_input = gr.Image(label="Upload Chest X-ray", type="pil")
103
- xray_html = gr.HTML()
104
- xray_cam = gr.Image(label="AI Heatmap")
105
- analyse_btn = gr.Button("Analyze X-ray")
106
- clear_xray = gr.Button("Clear")
107
- analyse_btn.click(analyse_xray, inputs=xray_input, outputs=[xray_html, xray_cam])
108
- clear_xray.click(lambda: (None, "", None), None, outputs=[xray_input, xray_html, xray_cam])
109
-
110
- with gr.Tab("πŸ“„ Report Analysis"):
111
- pdf_input = gr.File(label="Upload PDF report", file_types=[".pdf"])
112
- pdf_html = gr.HTML()
113
- analyse_pdf_btn = gr.Button("Analyze Report")
114
- clear_pdf = gr.Button("Clear")
115
- analyse_pdf_btn.click(analyse_report, inputs=pdf_input, outputs=pdf_html)
116
- clear_pdf.click(lambda: (None, ""), None, outputs=[pdf_input, pdf_html])
 
 
 
117
 
118
  if __name__ == "__main__":
119
  demo.launch(server_port=int(os.getenv("PORT", 7860)), show_error=True)
 
8
  import fitz # PyMuPDF
9
  import torchxrayvision as xrv
10
  from torchvision import transforms
 
 
11
  import re
12
 
13
+ # --- Device & Model ---
14
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+ MODEL = xrv.models.get_model("densenet121-res224-all").to(DEVICE).eval()
16
  LABELS = MODEL.pathologies
 
17
 
18
+ # --- Image Preprocessing ---
19
+ def preprocess_image(pil_img: Image.Image) -> torch.Tensor:
20
+ """Convert to grayscale, normalize, and resize for model."""
21
  if pil_img.mode != "L":
22
  pil_img = pil_img.convert("L")
23
+
24
  img_array = np.array(pil_img).astype(np.float32)
25
+ img_array = xrv.datasets.normalize(img_array, 255) # normalize [-1024, 1024]
26
  img_array = img_array[None, ...] # [1, H, W]
27
+
28
+ # Apply center crop and resize
29
+ transform = transforms.Compose([
30
+ xrv.datasets.XRayCenterCrop(),
31
+ xrv.datasets.XRayResizer(224)
32
+ ])
33
+ img_array = transform(img_array)
34
+
35
  tensor = torch.from_numpy(img_array).unsqueeze(0).to(DEVICE)
 
36
  return tensor
37
 
38
+ # --- Medical Recommendations ---
39
+ ADVICE = {
40
+ "Pneumonia": "Possible infection. Recommend antibiotics and pulmonology consult.",
41
+ "Cardiomegaly": "Enlarged heart. Recommend echocardiography and cardiologist review.",
42
+ "Effusion": "Fluid in lung space. May need thoracentesis.",
43
+ "Fracture": "Possible bone break. Requires orthopedic consultation.",
44
+ "Edema": "Pulmonary fluid overload. Evaluate for heart failure.",
45
+ }
46
+
47
+ def get_advice(label):
48
+ return ADVICE.get(label, "Please consult a radiologist for further evaluation.")
49
+
50
+ # --- X-ray Analysis (No CAM) ---
51
  def analyse_xray(img: Image.Image):
52
  try:
53
  if img is None:
54
  return "Please upload an X-ray image.", None
55
+
56
  x = preprocess_image(img)
57
+ with torch.no_grad():
58
+ output = MODEL(x)
59
+ probs = torch.sigmoid(output)[0] * 100 # convert to percent
60
+
61
  topk = torch.topk(probs, 5)
62
+ html = "<h3>🩺 Top 5 Predictions</h3><table border='1'><tr><th>Condition</th><th>Confidence</th></tr>"
63
  for idx in topk.indices:
64
  html += f"<tr><td>{LABELS[idx]}</td><td>{probs[idx]:.1f}%</td></tr>"
65
  html += "</table><br>"
66
+
67
+ top_label = LABELS[topk.indices[0].item()]
68
+ html += f"<b>Recommended Action for '{top_label}':</b> {get_advice(top_label)}"
69
+
70
+ return html, img.resize((224, 224)) # return resized image for display
 
 
 
 
 
 
71
  except Exception as e:
72
  return f"Error processing image: {str(e)}", None
73
 
74
+ # --- Report PDF Analysis ---
75
  def analyse_report(file):
76
  try:
77
  if file is None:
78
  return "Please upload a PDF report."
79
+
80
  doc = fitz.open(file.name)
81
  text = "\n".join(page.get_text() for page in doc)
82
  doc.close()
83
+
84
  found = []
85
  for label in LABELS:
86
  if re.search(rf"\b{label.lower()}\b", text.lower()):
87
  found.append(label)
88
+
89
  if found:
90
+ html = "<h3>πŸ“ƒ Findings Detected in Report:</h3><ul>"
91
  for label in found:
92
+ html += f"<li><b>{label}</b>: {get_advice(label)}</li>"
93
  html += "</ul>"
94
  else:
95
+ html = "<p>No known conditions detected from report text.</p>"
96
+
97
  return html
98
  except Exception as e:
99
  return f"Error processing PDF: {str(e)}"
100
 
101
  # --- Gradio UI ---
102
+ with gr.Blocks(title="🩻 RadiologyScan AI") as demo:
103
+ gr.Markdown("## 🩻 RadiologyScan AI\nPerform fast AI-based analysis of Chest X-rays and medical reports")
104
 
105
  with gr.Tabs():
106
  with gr.Tab("πŸ” X-ray Analysis"):
107
+ x_input = gr.Image(label="Upload Chest X-ray", type="pil")
108
+ x_out_html = gr.HTML()
109
+ x_out_image = gr.Image(label="Resized X-ray (224x224)")
110
+
111
+ analyze_btn = gr.Button("Analyze X-ray")
112
+ clear_btn = gr.Button("Clear")
113
+
114
+ analyze_btn.click(analyse_xray, inputs=x_input, outputs=[x_out_html, x_out_image])
115
+ clear_btn.click(lambda: (None, "", None), None, [x_input, x_out_html, x_out_image])
116
+
117
+ with gr.Tab("πŸ“„ PDF Report Analysis"):
118
+ pdf_input = gr.File(file_types=[".pdf"], label="Upload PDF Medical Report")
119
+ pdf_output = gr.HTML()
120
+ analyze_pdf_btn = gr.Button("Analyze Report")
121
+ clear_pdf_btn = gr.Button("Clear")
122
+
123
+ analyze_pdf_btn.click(analyse_report, inputs=pdf_input, outputs=pdf_output)
124
+ clear_pdf_btn.click(lambda: (None, ""), None, [pdf_input, pdf_output])
125
 
126
  if __name__ == "__main__":
127
  demo.launch(server_port=int(os.getenv("PORT", 7860)), show_error=True)