VaneshDev commited on
Commit
56914ac
Β·
verified Β·
1 Parent(s): 936353d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +103 -214
app.py CHANGED
@@ -1,262 +1,151 @@
1
- """
2
- RadiologyScan AI – X-ray & Report analyser
3
- """
4
-
5
  import os
6
- # Fix for PyTorch 2.6 weights_only issue
7
  os.environ['TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD'] = '1'
8
 
9
- import re, logging, tempfile
 
10
  import gradio as gr
11
  from PIL import Image
12
- import torch
13
- import torch.nn.functional as F
14
- from torchvision import transforms
15
- import torchxrayvision as xrv
16
  import fitz # PyMuPDF
 
 
17
  from torchcam.methods import SmoothGradCAMpp
18
- from transformers import pipeline
19
- import numpy as np
 
20
 
21
  logging.basicConfig(level=logging.INFO)
22
- log = logging.getLogger(__name__)
23
 
24
- # Load model - IMPORTANT: Don't call .eval() here for CAM to work
25
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
  MODEL = xrv.models.get_model("densenet121-res224-all").to(DEVICE)
27
- # Note: We'll switch between train/eval modes as needed
28
  LABELS = MODEL.pathologies
29
 
30
- # Initialize CAM extractor with correct input shape for grayscale
31
  cam_extractor = SmoothGradCAMpp(MODEL, input_shape=(1, 224, 224))
32
 
33
- def preprocess_xray(pil_img: Image.Image) -> torch.Tensor:
34
- """
35
- Preprocess PIL image for TorchXRayVision model
36
- """
37
- # Convert to grayscale if needed
38
  if pil_img.mode != "L":
39
  pil_img = pil_img.convert("L")
40
-
41
- # Convert to numpy array
42
- img_array = np.array(pil_img, dtype=np.float32)
43
-
44
- # Normalize to [-1024, 1024] range (TorchXRayVision standard)
45
  img_array = xrv.datasets.normalize(img_array, 255)
46
-
47
  # Add channel dimension
48
- img_array = img_array[None, ...]
49
-
50
- # Use TorchXRayVision transforms
51
  transform = transforms.Compose([
52
  xrv.datasets.XRayCenterCrop(),
53
  xrv.datasets.XRayResizer(224)
54
  ])
55
-
56
  img_array = transform(img_array)
57
-
58
- # Convert to tensor with gradient enabled
59
- img_tensor = torch.from_numpy(img_array).unsqueeze(0).to(DEVICE)
60
- img_tensor.requires_grad_(True) # Enable gradients for CAM
61
-
62
- return img_tensor
63
 
64
- def analyse_xray(img: Image.Image):
65
- if img is None:
66
- return "Please upload an image.", None
 
 
 
 
 
 
 
 
 
 
 
67
 
 
68
  try:
69
- # Preprocess image
70
- x = preprocess_xray(img)
71
-
72
- # Set model to train mode temporarily for gradient computation
73
- MODEL.train()
74
-
75
- # Forward pass with gradient tracking
76
- logits = MODEL(x)
77
- probs = torch.sigmoid(logits)[0] * 100 # Convert to percentages
78
-
79
- # Get top 5 predictions
80
  topk = torch.topk(probs, 5)
81
-
82
- # Generate Grad-CAM heat-map for the highest scoring condition
83
- target = topk.indices[0].item()
84
-
85
- # Generate activation map
86
- activation_map = cam_extractor(target, logits)[0]
87
-
88
- # Convert single channel to 3-channel for overlay
89
- input_for_overlay = x.squeeze(0).cpu().detach()
90
- input_for_overlay = input_for_overlay.repeat(3, 1, 1)
91
-
92
- # Generate heatmap
93
- heatmap = cam_extractor.overlay(input_for_overlay, activation_map)
94
-
95
- # Set model back to eval mode
96
  MODEL.eval()
97
-
98
- # Build HTML summary table
99
- table_rows = ""
100
- for i in range(len(topk.indices)):
101
- idx = topk.indices[i].item()
102
- prob = probs[idx].item()
103
- condition = LABELS[idx]
104
- table_rows += f"<tr><td>{condition}</td><td>{prob:.1f}%</td></tr>"
105
-
106
- top_condition = LABELS[target]
107
- advice = get_medical_advice(top_condition)
108
-
109
- html_output = f"""
110
- <div style="font-family: Arial, sans-serif;">
111
- <h3>🩺 AI Analysis Results</h3>
112
- <table border="1" style="border-collapse: collapse; width: 100%;">
113
- <tr style="background-color: #f2f2f2;">
114
- <th style="padding: 8px; text-align: left;">Condition</th>
115
- <th style="padding: 8px; text-align: left;">Probability</th>
116
- </tr>
117
- {table_rows}
118
- </table>
119
- <br>
120
- <h4>πŸ” Top Finding: {top_condition}</h4>
121
- <p><strong>Recommendation:</strong> {advice}</p>
122
- <p><em>⚠️ This is an AI analysis tool for educational purposes only. Always consult qualified medical professionals for diagnosis and treatment.</em></p>
123
- </div>
124
- """
125
-
126
- return html_output, Image.fromarray(heatmap)
127
-
128
  except Exception as e:
129
- log.error(f"Error in X-ray analysis: {e}")
130
  return f"Error processing image: {str(e)}", None
131
 
132
- # Medical advice dictionary
133
- MEDICAL_ADVICE = {
134
- "Atelectasis": "Lung collapse detected. Recommend pulmonology consultation and chest physiotherapy.",
135
- "Cardiomegaly": "Enlarged heart detected. Recommend echocardiography and cardiology consultation.",
136
- "Consolidation": "Lung consolidation detected. May indicate pneumonia or other lung disease. Seek medical attention.",
137
- "Edema": "Pulmonary edema detected. Recommend urgent cardiology evaluation.",
138
- "Emphysema": "Emphysema changes detected. Recommend pulmonology consultation and smoking cessation if applicable.",
139
- "Fibrosis": "Lung fibrosis detected. Recommend pulmonology consultation for further evaluation.",
140
- "Hernia": "Hernia detected. Recommend surgical consultation if symptomatic.",
141
- "Infiltration": "Lung infiltration detected. May indicate infection or inflammation. Seek medical attention.",
142
- "Mass": "Lung mass detected. Recommend urgent oncology consultation and further imaging.",
143
- "Nodule": "Lung nodule detected. Recommend follow-up imaging and pulmonology consultation.",
144
- "Pleural_Thickening": "Pleural thickening detected. Recommend pulmonology consultation.",
145
- "Pneumonia": "Pneumonia detected. Recommend immediate antibiotic treatment and medical supervision.",
146
- "Pneumothorax": "Pneumothorax (collapsed lung) detected. May require immediate medical intervention.",
147
- "Effusion": "Pleural effusion detected. Recommend thoracentesis evaluation and pulmonology consultation."
148
- }
149
-
150
- def get_medical_advice(condition: str) -> str:
151
- return MEDICAL_ADVICE.get(condition, "Consult with a radiologist or pulmonologist for proper interpretation.")
152
-
153
  def analyse_report(file):
154
- if file is None:
155
- return "Please upload a PDF file."
156
-
157
  try:
158
- # Create temporary file
159
- with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
160
- tmp.write(file.read())
161
- tmp_path = tmp.name
162
-
163
- # Extract text from PDF
164
- doc = fitz.open(tmp_path)
165
  text = "\n".join(page.get_text() for page in doc)
166
  doc.close()
167
- os.unlink(tmp_path)
168
-
169
- # Simple pattern matching for common conditions
170
- detected_conditions = []
171
- for condition in LABELS:
172
- if re.search(rf'\b{condition.lower()}\b', text.lower()):
173
- detected_conditions.append(condition)
174
-
175
- if detected_conditions:
176
- html_output = "<h3>πŸ“‹ Report Analysis</h3>"
177
- html_output += "<h4>Detected Conditions:</h4><ul>"
178
- for condition in detected_conditions[:5]: # Show top 5
179
- advice = get_medical_advice(condition)
180
- html_output += f"<li><strong>{condition}</strong>: {advice}</li>"
181
- html_output += "</ul>"
182
- html_output += "<p><em>⚠️ This analysis is for educational purposes only. Consult medical professionals for proper diagnosis.</em></p>"
183
- return html_output
184
  else:
185
- return "<h3>πŸ“‹ Report Analysis</h3><p>No specific pathological conditions detected in the report text. Please consult with a medical professional for proper interpretation.</p>"
186
-
 
 
187
  except Exception as e:
188
- log.error(f"Error in report analysis: {e}")
189
  return f"Error processing PDF: {str(e)}"
190
 
191
- # Gradio interface
192
- with gr.Blocks(title="🩻 RadiologyScan AI", theme=gr.themes.Soft()) as demo:
193
- gr.Markdown("""
194
- # 🩻 RadiologyScan AI
195
- ### AI-Powered Chest X-ray and Medical Report Analysis
196
-
197
- **⚠️ IMPORTANT DISCLAIMER**: This tool is for educational and research purposes only.
198
- It should NOT be used for actual medical diagnosis or treatment decisions.
199
- Always consult qualified healthcare professionals for medical advice.
200
- """)
201
 
202
  with gr.Tabs():
203
  with gr.Tab("πŸ” X-ray Analysis"):
204
- gr.Markdown("Upload a chest X-ray image for AI analysis")
205
-
206
- with gr.Row():
207
- with gr.Column():
208
- img_input = gr.Image(label="Upload Chest X-ray", type="pil")
209
-
210
- with gr.Row():
211
- analyze_btn = gr.Button("πŸ” Analyze X-ray", variant="primary", size="lg")
212
- clear_btn = gr.Button("πŸ—‘οΈ Clear", variant="secondary")
213
-
214
- with gr.Column():
215
- html_output = gr.HTML(label="Analysis Results")
216
- cam_output = gr.Image(label="Attention Heatmap", type="pil")
217
-
218
- analyze_btn.click(
219
- fn=analyse_xray,
220
- inputs=img_input,
221
- outputs=[html_output, cam_output]
222
- )
223
-
224
- clear_btn.click(
225
- fn=lambda: (None, "", None),
226
- inputs=None,
227
- outputs=[img_input, html_output, cam_output]
228
- )
229
 
230
  with gr.Tab("πŸ“„ Report Analysis"):
231
- gr.Markdown("Upload a medical report PDF for AI analysis")
232
-
233
- with gr.Row():
234
- with gr.Column():
235
- pdf_input = gr.File(label="Upload PDF Report", file_types=[".pdf"])
236
-
237
- with gr.Row():
238
- analyze_report_btn = gr.Button("πŸ“„ Analyze Report", variant="primary", size="lg")
239
- clear_report_btn = gr.Button("πŸ—‘οΈ Clear", variant="secondary")
240
-
241
- with gr.Column():
242
- report_output = gr.HTML(label="Report Analysis")
243
-
244
- analyze_report_btn.click(
245
- fn=analyse_report,
246
- inputs=pdf_input,
247
- outputs=report_output
248
- )
249
-
250
- clear_report_btn.click(
251
- fn=lambda: (None, ""),
252
- inputs=None,
253
- outputs=[pdf_input, report_output]
254
- )
255
 
256
  if __name__ == "__main__":
257
- demo.launch(
258
- server_name="0.0.0.0",
259
- server_port=int(os.getenv("PORT", 7860)),
260
- show_error=True,
261
- share=False
262
- )
 
 
 
 
 
1
  import os
 
2
  os.environ['TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD'] = '1'
3
 
4
+ import numpy as np
5
+ import torch
6
  import gradio as gr
7
  from PIL import Image
 
 
 
 
8
  import fitz # PyMuPDF
9
+ import torchxrayvision as xrv
10
+ from torchvision import transforms
11
  from torchcam.methods import SmoothGradCAMpp
12
+ from torchcam.utils import overlay_mask
13
+ import re
14
+ import logging
15
 
16
  logging.basicConfig(level=logging.INFO)
17
+ logger = logging.getLogger(__name__)
18
 
19
+ # ---------------- MODEL SETUP ---------------- #
20
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
  MODEL = xrv.models.get_model("densenet121-res224-all").to(DEVICE)
 
22
  LABELS = MODEL.pathologies
23
 
24
+ # Grad-CAM extractor (single-channel input)
25
  cam_extractor = SmoothGradCAMpp(MODEL, input_shape=(1, 224, 224))
26
 
27
+ # ---------------- IMAGE HANDLING ---------------- #
28
+ def preprocess_image(pil_img: Image.Image):
29
+ """Convert to grayscale, normalize, resize for model"""
 
 
30
  if pil_img.mode != "L":
31
  pil_img = pil_img.convert("L")
32
+
33
+ img_array = np.array(pil_img).astype(np.float32)
 
 
 
34
  img_array = xrv.datasets.normalize(img_array, 255)
35
+
36
  # Add channel dimension
37
+ img_array = img_array[None, ...] # Shape: [1, H, W]
38
+
 
39
  transform = transforms.Compose([
40
  xrv.datasets.XRayCenterCrop(),
41
  xrv.datasets.XRayResizer(224)
42
  ])
 
43
  img_array = transform(img_array)
 
 
 
 
 
 
44
 
45
+ # Convert to tensor
46
+ tensor = torch.from_numpy(img_array).unsqueeze(0).to(DEVICE)
47
+ tensor.requires_grad_(True)
48
+ return tensor
49
+
50
+ def get_medical_advice(label):
51
+ advice_dict = {
52
+ "Pneumonia": "Consider antibiotics. Consult a pulmonologist.",
53
+ "Cardiomegaly": "Recommend echocardiogram and cardiologist review.",
54
+ "Effusion": "Pleural fluid detected. May require thoracentesis.",
55
+ "Fracture": "Possible bone injury. Orthopedic consultation needed.",
56
+ "Edema": "Fluid in lungs. Evaluate for heart failure.",
57
+ }
58
+ return advice_dict.get(label, "Please consult a radiologist for further evaluation.")
59
 
60
+ def analyse_xray(img: Image.Image):
61
  try:
62
+ if img is None:
63
+ return "Please upload an X-ray image.", None
64
+
65
+ MODEL.train() # required for CAM to calculate gradients
66
+ x = preprocess_image(img)
67
+ output = MODEL(x)
68
+ probs = torch.sigmoid(output)[0] * 100
69
+
70
+ # Top 5 predictions
 
 
71
  topk = torch.topk(probs, 5)
72
+
73
+ html = "<h3>🩻 Top Predictions</h3><table border='1'><tr><th>Condition</th><th>Confidence</th></tr>"
74
+ for idx in topk.indices:
75
+ html += f"<tr><td>{LABELS[idx]}</td><td>{probs[idx]:.1f}%</td></tr>"
76
+ html += "</table><br>"
77
+
78
+ top_label = LABELS[topk.indices[0]]
79
+ advice = get_medical_advice(top_label)
80
+ html += f"<b>Suggested Action for '{top_label}':</b> {advice}"
81
+
82
+ # Grad-CAM
83
+ cam = cam_extractor(topk.indices[0].item(), output)[0]
84
+ img_vis = img.convert("RGB").resize((224, 224))
85
+ heat_img = overlay_mask(img_vis, Image.fromarray((cam.cpu().numpy() * 255).astype(np.uint8)), alpha=0.5)
86
+
87
  MODEL.eval()
88
+ return html, heat_img
89
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  except Exception as e:
91
+ logger.error(e)
92
  return f"Error processing image: {str(e)}", None
93
 
94
+ # ---------------- PDF HANDLING ---------------- #
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  def analyse_report(file):
 
 
 
96
  try:
97
+ if file is None:
98
+ return "Please upload a PDF report."
99
+
100
+ # Use file.name instead of .read()
101
+ doc = fitz.open(file.name)
 
 
102
  text = "\n".join(page.get_text() for page in doc)
103
  doc.close()
104
+
105
+ found = []
106
+ for label in LABELS:
107
+ if re.search(rf"\b{label.lower()}\b", text.lower()):
108
+ found.append(label)
109
+
110
+ if found:
111
+ html = "<h3>πŸ“„ Detected Conditions</h3><ul>"
112
+ for label in found:
113
+ html += f"<li><b>{label}</b>: {get_medical_advice(label)}</li>"
114
+ html += "</ul>"
 
 
 
 
 
 
115
  else:
116
+ html = "<p>No specific conditions found in the report.</p>"
117
+
118
+ return html
119
+
120
  except Exception as e:
121
+ logger.error(e)
122
  return f"Error processing PDF: {str(e)}"
123
 
124
+ # ---------------- UI ---------------- #
125
+ with gr.Blocks(title="RadiologyScan AI", theme=gr.themes.Soft()) as demo:
126
+ gr.Markdown("## 🩻 RadiologyScan AI\nUpload an X-ray or PDF report for AI-assisted analysis")
 
 
 
 
 
 
 
127
 
128
  with gr.Tabs():
129
  with gr.Tab("πŸ” X-ray Analysis"):
130
+ xray_input = gr.Image(label="Upload Chest X-ray", type="pil")
131
+ xray_html = gr.HTML()
132
+ xray_cam = gr.Image(label="AI Heatmap")
133
+
134
+ analyse_btn = gr.Button("Analyze X-ray")
135
+ clear_xray = gr.Button("Clear")
136
+
137
+ analyse_btn.click(analyse_xray, inputs=xray_input, outputs=[xray_html, xray_cam])
138
+ clear_xray.click(lambda: (None, "", None), None, outputs=[xray_input, xray_html, xray_cam])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
  with gr.Tab("πŸ“„ Report Analysis"):
141
+ pdf_input = gr.File(label="Upload PDF report", file_types=[".pdf"])
142
+ pdf_html = gr.HTML()
143
+
144
+ analyse_pdf_btn = gr.Button("Analyze Report")
145
+ clear_pdf = gr.Button("Clear")
146
+
147
+ analyse_pdf_btn.click(analyse_report, inputs=pdf_input, outputs=pdf_html)
148
+ clear_pdf.click(lambda: (None, ""), None, outputs=[pdf_input, pdf_html])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
 
150
  if __name__ == "__main__":
151
+ demo.launch(server_port=int(os.getenv("PORT", 7860)), show_error=True)