VaneshDev commited on
Commit
35a8d6f
Β·
verified Β·
1 Parent(s): 1736b2d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +220 -108
app.py CHANGED
@@ -16,152 +16,264 @@ import torchxrayvision as xrv
16
  import fitz # PyMuPDF
17
  from torchcam.methods import SmoothGradCAMpp
18
  from transformers import pipeline
 
19
 
20
  logging.basicConfig(level=logging.INFO)
21
  log = logging.getLogger(__name__)
22
 
23
  # Load model
24
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
- MODEL = xrv.models.get_model("densenet121-res224-all").to(DEVICE).eval()
26
  LABELS = MODEL.pathologies
27
 
28
- TRANSFORM = transforms.Compose([
29
- transforms.Resize(224),
30
- transforms.CenterCrop(224),
31
- transforms.ToTensor(),
32
- transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
33
- ])
34
-
35
- def preprocess(pil_img: Image.Image) -> torch.Tensor:
36
- if pil_img.mode != "RGB":
37
- pil_img = pil_img.convert("RGB")
38
- return TRANSFORM(pil_img).unsqueeze(0).to(DEVICE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
- # X-ray prediction with Grad-CAM
41
- cam_extractor = SmoothGradCAMpp(MODEL)
42
 
43
  def analyse_xray(img: Image.Image):
44
- if img is None: return "Please upload an image."
45
-
46
- x = preprocess(img)
47
- with torch.no_grad():
48
- logits = MODEL(x)
49
- probs = torch.sigmoid(logits)[0] * 100
50
- topk = torch.topk(probs, 3)
51
-
52
- # Grad-CAM heat-map
53
- target = topk.indices[0].item()
54
- activation_map = cam_extractor(target, logits)[0]
55
- heatmap = cam_extractor.overlay(torch.squeeze(x).cpu(), activation_map)
56
-
57
- # Build HTML summary
58
- rows = "".join(
59
- f"<tr><td>{LABELS[i]}</td><td>{probs[i]:.1f}%</td></tr>"
60
- for i in topk.indices
61
- )
62
- advice = medical_advice(LABELS[target])
63
- html = f"""
64
- <h3>AI findings</h3>
65
- <table border="1"><tr><th>Condition</th><th>Probability</th></tr>{rows}</table>
66
- <p><b>Advice:</b> {advice}</p>
67
- """
68
 
69
- return html, Image.fromarray(heatmap)
70
-
71
- # Medical advice
72
- ADVICE = {
73
- "Pneumonia": "Consult a pulmonologist; antibiotics or antivirals as indicated.",
74
- "Cardiomegaly": "Recommend echocardiography; refer to cardiology.",
75
- "Atelectasis": "Further imaging may be needed; consult pulmonologist.",
76
- "Consolidation": "Likely infection or inflammation; seek medical attention.",
77
- "Pleural_Thickening": "Monitor for progression; pulmonology consultation.",
78
- "Edema": "Evaluate for heart failure; cardiology consultation.",
79
- "Effusion": "Thoracentesis may be needed; pulmonology consultation.",
80
- "Fracture": "Orthopaedic consultation; consider CT if uncertain.",
81
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
- def medical_advice(label):
84
- return ADVICE.get(label, "Discuss with a radiologist for next steps.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
- # PDF report summariser
87
- try:
88
- summariser = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")
89
- except:
90
- summariser = None
91
- log.warning("Could not load summarization model")
92
 
 
93
  def analyse_report(file):
94
- if file is None: return "Please upload a PDF."
 
95
 
96
  try:
 
97
  with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
98
  tmp.write(file.read())
99
  tmp_path = tmp.name
100
 
 
101
  doc = fitz.open(tmp_path)
102
  text = "\n".join(page.get_text() for page in doc)
103
  doc.close()
104
  os.unlink(tmp_path)
105
-
106
- disease = regex_find_disease(text)
107
- if disease:
108
- advice = medical_advice(disease)
109
- return f"""
110
- <h3>Disease detected:</h3><p>{disease}</p>
111
- <p><b>Recommendation:</b> {advice}</p>
112
- """
113
- elif summariser:
114
- # fallback LLM summary
115
- short_text = text[:4000] if len(text) > 4000 else text
116
- summary = summariser(short_text, max_length=120, min_length=30, do_sample=False)
117
- return f"<h3>Report summary</h3><p>{summary[0]['summary_text']}</p>"
 
 
 
118
  else:
119
- return "<h3>Report processed</h3><p>No specific conditions detected. Please consult with a medical professional for interpretation.</p>"
 
120
  except Exception as e:
 
121
  return f"Error processing PDF: {str(e)}"
122
 
123
- def regex_find_disease(text: str):
124
- patterns = {
125
- "Pneumonia": r"\b(pneumonia|lung infection)\b",
126
- "Cardiomegaly": r"\b(cardiomegaly|enlarged heart)\b",
127
- "Atelectasis": r"\b(atelectasis|lung collapse)\b",
128
- "Consolidation": r"\b(consolidation|lung consolidation)\b",
129
- "Fracture": r"\b(fracture|broken bone|break)\b",
130
- "Edema": r"\b(edema|fluid buildup)\b",
131
- "Effusion": r"\b(effusion|fluid collection)\b",
132
- }
133
- for condition, pattern in patterns.items():
134
- if re.search(pattern, text, flags=re.I):
135
- return condition
136
- return None
137
-
138
- # Gradio UI
139
- with gr.Blocks(title="🩻 RadiologyScan AI") as demo:
140
- gr.Markdown("## 🩻 RadiologyScan AI – Chest X-ray & Report Analyser")
141
 
142
  with gr.Tabs():
143
- with gr.Tab("X-ray Analysis"):
144
- in_img = gr.Image(label="Upload chest X-ray", type="pil")
145
- out_html = gr.HTML()
146
- out_cam = gr.Image(label="Attention Map")
147
 
148
  with gr.Row():
149
- analyze_btn = gr.Button("Analyze X-ray", variant="primary")
150
- clear_btn = gr.Button("Clear")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
- analyze_btn.click(analyse_xray, inputs=in_img, outputs=[out_html, out_cam])
153
- clear_btn.click(lambda: (None, "", None), inputs=None, outputs=[in_img, out_html, out_cam])
 
 
 
154
 
155
- with gr.Tab("Report Analysis"):
156
- in_pdf = gr.File(label="Upload PDF report", file_types=[".pdf"])
157
- out_rep = gr.HTML()
158
 
159
  with gr.Row():
160
- analyze_rep_btn = gr.Button("Analyze Report", variant="primary")
161
- clear_rep_btn = gr.Button("Clear")
 
 
 
 
 
 
 
162
 
163
- analyze_rep_btn.click(analyse_report, inputs=in_pdf, outputs=out_rep)
164
- clear_rep_btn.click(lambda: (None, ""), inputs=None, outputs=[in_pdf, out_rep])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
  if __name__ == "__main__":
167
- demo.launch(show_error=True, server_port=int(os.getenv("PORT", 7860)))
 
 
 
 
 
 
16
  import fitz # PyMuPDF
17
  from torchcam.methods import SmoothGradCAMpp
18
  from transformers import pipeline
19
+ import numpy as np
20
 
21
  logging.basicConfig(level=logging.INFO)
22
  log = logging.getLogger(__name__)
23
 
24
  # Load model
25
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
+ MODEL = xrv.models.get_model("densenet121-res224-all").to(DEVICE).eval()
27
  LABELS = MODEL.pathologies
28
 
29
+ # Correct transform for TorchXRayVision (grayscale, normalized to [-1024, 1024])
30
+ def preprocess_xray(pil_img: Image.Image) -> torch.Tensor:
31
+ """
32
+ Preprocess PIL image for TorchXRayVision model
33
+ TorchXRayVision expects:
34
+ - Single channel grayscale image
35
+ - Values normalized to [-1024, 1024] range
36
+ - Resolution of 224x224
37
+ """
38
+ # Convert to grayscale if needed
39
+ if pil_img.mode != "L":
40
+ pil_img = pil_img.convert("L")
41
+
42
+ # Convert to numpy array
43
+ img_array = np.array(pil_img, dtype=np.float32)
44
+
45
+ # Normalize to [-1024, 1024] range (TorchXRayVision standard)
46
+ # Assume input is 8-bit (0-255), scale to [-1024, 1024]
47
+ img_array = xrv.datasets.normalize(img_array, 255)
48
+
49
+ # Add channel dimension and resize
50
+ img_array = img_array[None, ...] # Add channel dimension
51
+
52
+ # Use TorchXRayVision transforms
53
+ transform = transforms.Compose([
54
+ xrv.datasets.XRayCenterCrop(),
55
+ xrv.datasets.XRayResizer(224)
56
+ ])
57
+
58
+ img_array = transform(img_array)
59
+
60
+ # Convert to tensor
61
+ img_tensor = torch.from_numpy(img_array).unsqueeze(0).to(DEVICE)
62
+
63
+ return img_tensor
64
 
65
+ # Initialize CAM extractor with correct input shape for grayscale
66
+ cam_extractor = SmoothGradCAMpp(MODEL, input_shape=(1, 224, 224)) # Single channel
67
 
68
  def analyse_xray(img: Image.Image):
69
+ if img is None:
70
+ return "Please upload an image.", None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
+ try:
73
+ # Preprocess image for TorchXRayVision
74
+ x = preprocess_xray(img)
75
+
76
+ with torch.no_grad():
77
+ logits = MODEL(x)
78
+ probs = torch.sigmoid(logits)[0] * 100 # Convert to percentages
79
+
80
+ # Get top 5 predictions
81
+ topk = torch.topk(probs, 5)
82
+
83
+ # Generate Grad-CAM heat-map for the highest scoring condition
84
+ target = topk.indices[0].item()
85
+
86
+ # Generate activation map
87
+ activation_map = cam_extractor(target, logits)[0]
88
+
89
+ # Overlay heatmap on original image
90
+ # Convert single channel to 3-channel for overlay
91
+ input_for_overlay = x.squeeze(0).cpu()
92
+ input_for_overlay = input_for_overlay.repeat(3, 1, 1) # Repeat single channel 3 times
93
+
94
+ heatmap = cam_extractor.overlay(input_for_overlay, activation_map)
95
+
96
+ # Build HTML summary table
97
+ table_rows = ""
98
+ for i in range(len(topk.indices)):
99
+ idx = topk.indices[i].item()
100
+ prob = probs[idx].item()
101
+ condition = LABELS[idx]
102
+ table_rows += f"<tr><td>{condition}</td><td>{prob:.1f}%</td></tr>"
103
+
104
+ top_condition = LABELS[target]
105
+ advice = get_medical_advice(top_condition)
106
+
107
+ html_output = f"""
108
+ <div style="font-family: Arial, sans-serif;">
109
+ <h3>🩺 AI Analysis Results</h3>
110
+ <table border="1" style="border-collapse: collapse; width: 100%;">
111
+ <tr style="background-color: #f2f2f2;">
112
+ <th style="padding: 8px; text-align: left;">Condition</th>
113
+ <th style="padding: 8px; text-align: left;">Probability</th>
114
+ </tr>
115
+ {table_rows}
116
+ </table>
117
+ <br>
118
+ <h4>πŸ” Top Finding: {top_condition}</h4>
119
+ <p><strong>Recommendation:</strong> {advice}</p>
120
+ <p><em>⚠️ This is an AI analysis tool for educational purposes only. Always consult qualified medical professionals for diagnosis and treatment.</em></p>
121
+ </div>
122
+ """
123
+
124
+ return html_output, Image.fromarray(heatmap)
125
+
126
+ except Exception as e:
127
+ log.error(f"Error in X-ray analysis: {e}")
128
+ return f"Error processing image: {str(e)}", None
129
 
130
+ # Medical advice dictionary
131
+ MEDICAL_ADVICE = {
132
+ "Atelectasis": "Lung collapse detected. Recommend pulmonology consultation and chest physiotherapy.",
133
+ "Cardiomegaly": "Enlarged heart detected. Recommend echocardiography and cardiology consultation.",
134
+ "Consolidation": "Lung consolidation detected. May indicate pneumonia or other lung disease. Seek medical attention.",
135
+ "Edema": "Pulmonary edema detected. Recommend urgent cardiology evaluation.",
136
+ "Emphysema": "Emphysema changes detected. Recommend pulmonology consultation and smoking cessation if applicable.",
137
+ "Fibrosis": "Lung fibrosis detected. Recommend pulmonology consultation for further evaluation.",
138
+ "Hernia": "Hernia detected. Recommend surgical consultation if symptomatic.",
139
+ "Infiltration": "Lung infiltration detected. May indicate infection or inflammation. Seek medical attention.",
140
+ "Mass": "Lung mass detected. Recommend urgent oncology consultation and further imaging.",
141
+ "Nodule": "Lung nodule detected. Recommend follow-up imaging and pulmonology consultation.",
142
+ "Pleural_Thickening": "Pleural thickening detected. Recommend pulmonology consultation.",
143
+ "Pneumonia": "Pneumonia detected. Recommend immediate antibiotic treatment and medical supervision.",
144
+ "Pneumothorax": "Pneumothorax (collapsed lung) detected. May require immediate medical intervention.",
145
+ "Effusion": "Pleural effusion detected. Recommend thoracentesis evaluation and pulmonology consultation."
146
+ }
147
 
148
+ def get_medical_advice(condition: str) -> str:
149
+ return MEDICAL_ADVICE.get(condition, "Consult with a radiologist or pulmonologist for proper interpretation.")
 
 
 
 
150
 
151
+ # PDF report analysis (simplified - focusing on the main issue)
152
  def analyse_report(file):
153
+ if file is None:
154
+ return "Please upload a PDF file."
155
 
156
  try:
157
+ # Create temporary file
158
  with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
159
  tmp.write(file.read())
160
  tmp_path = tmp.name
161
 
162
+ # Extract text from PDF
163
  doc = fitz.open(tmp_path)
164
  text = "\n".join(page.get_text() for page in doc)
165
  doc.close()
166
  os.unlink(tmp_path)
167
+
168
+ # Simple pattern matching for common conditions
169
+ detected_conditions = []
170
+ for condition in LABELS:
171
+ if re.search(rf'\b{condition.lower()}\b', text.lower()):
172
+ detected_conditions.append(condition)
173
+
174
+ if detected_conditions:
175
+ html_output = "<h3>πŸ“‹ Report Analysis</h3>"
176
+ html_output += "<h4>Detected Conditions:</h4><ul>"
177
+ for condition in detected_conditions[:5]: # Show top 5
178
+ advice = get_medical_advice(condition)
179
+ html_output += f"<li><strong>{condition}</strong>: {advice}</li>"
180
+ html_output += "</ul>"
181
+ html_output += "<p><em>⚠️ This analysis is for educational purposes only. Consult medical professionals for proper diagnosis.</em></p>"
182
+ return html_output
183
  else:
184
+ return "<h3>πŸ“‹ Report Analysis</h3><p>No specific pathological conditions detected in the report text. Please consult with a medical professional for proper interpretation.</p>"
185
+
186
  except Exception as e:
187
+ log.error(f"Error in report analysis: {e}")
188
  return f"Error processing PDF: {str(e)}"
189
 
190
+ # Gradio interface
191
+ with gr.Blocks(title="🩻 RadiologyScan AI", theme=gr.themes.Soft()) as demo:
192
+ gr.Markdown("""
193
+ # 🩻 RadiologyScan AI
194
+ ### AI-Powered Chest X-ray and Medical Report Analysis
195
+
196
+ **⚠️ IMPORTANT DISCLAIMER**: This tool is for educational and research purposes only.
197
+ It should NOT be used for actual medical diagnosis or treatment decisions.
198
+ Always consult qualified healthcare professionals for medical advice.
199
+ """)
 
 
 
 
 
 
 
 
200
 
201
  with gr.Tabs():
202
+ with gr.Tab("πŸ” X-ray Analysis"):
203
+ gr.Markdown("Upload a chest X-ray image for AI analysis")
 
 
204
 
205
  with gr.Row():
206
+ with gr.Column():
207
+ img_input = gr.Image(label="Upload Chest X-ray", type="pil")
208
+
209
+ with gr.Row():
210
+ analyze_btn = gr.Button("πŸ” Analyze X-ray", variant="primary", size="lg")
211
+ clear_btn = gr.Button("πŸ—‘οΈ Clear", variant="secondary")
212
+
213
+ with gr.Column():
214
+ html_output = gr.HTML(label="Analysis Results")
215
+ cam_output = gr.Image(label="Attention Heatmap", type="pil")
216
+
217
+ analyze_btn.click(
218
+ fn=analyse_xray,
219
+ inputs=img_input,
220
+ outputs=[html_output, cam_output]
221
+ )
222
 
223
+ clear_btn.click(
224
+ fn=lambda: (None, "", None),
225
+ inputs=None,
226
+ outputs=[img_input, html_output, cam_output]
227
+ )
228
 
229
+ with gr.Tab("πŸ“„ Report Analysis"):
230
+ gr.Markdown("Upload a medical report PDF for AI analysis")
 
231
 
232
  with gr.Row():
233
+ with gr.Column():
234
+ pdf_input = gr.File(label="Upload PDF Report", file_types=[".pdf"])
235
+
236
+ with gr.Row():
237
+ analyze_report_btn = gr.Button("πŸ“„ Analyze Report", variant="primary", size="lg")
238
+ clear_report_btn = gr.Button("πŸ—‘οΈ Clear", variant="secondary")
239
+
240
+ with gr.Column():
241
+ report_output = gr.HTML(label="Report Analysis")
242
 
243
+ analyze_report_btn.click(
244
+ fn=analyse_report,
245
+ inputs=pdf_input,
246
+ outputs=report_output
247
+ )
248
+
249
+ clear_report_btn.click(
250
+ fn=lambda: (None, ""),
251
+ inputs=None,
252
+ outputs=[pdf_input, report_output]
253
+ )
254
+
255
+ gr.Markdown("""
256
+ ### πŸ“– How to Use
257
+ 1. **X-ray Analysis**: Upload a chest X-ray image (JPEG, PNG) and click "Analyze X-ray"
258
+ 2. **Report Analysis**: Upload a medical report PDF and click "Analyze Report"
259
+
260
+ ### πŸ”¬ Technical Details
261
+ - Uses TorchXRayVision pre-trained DenseNet-121 model
262
+ - Trained on multiple chest X-ray datasets
263
+ - Provides attention heatmaps for interpretability
264
+ - Supports 18 different pathological conditions
265
+
266
+ ### ⚠️ Limitations
267
+ - For educational use only
268
+ - Not a substitute for professional medical diagnosis
269
+ - Results may vary based on image quality
270
+ - Always consult healthcare professionals
271
+ """)
272
 
273
  if __name__ == "__main__":
274
+ demo.launch(
275
+ server_name="0.0.0.0",
276
+ server_port=int(os.getenv("PORT", 7860)),
277
+ show_error=True,
278
+ share=False
279
+ )