mbrq13 commited on
Commit
a93d23f
·
1 Parent(s): 5ea11f9

Add pneumonia detection app with Grad-CAM

Browse files
Files changed (2) hide show
  1. app.py +187 -26
  2. requirements.txt +2 -3
app.py CHANGED
@@ -4,9 +4,12 @@ import torchvision.transforms as T
4
  import torch.nn as nn
5
  import torch.nn.functional as F
6
  import gradio as gr
7
- from pytorch_grad_cam import GradCAM
8
- from pytorch_grad_cam.utils.image import show_cam_on_image
9
  from PIL import Image
 
 
 
 
10
 
11
 
12
  # Define CNN
@@ -53,35 +56,193 @@ transform = T.Compose([T.Resize((224,224)),
53
  T.ToTensor(),
54
  T.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])])
55
 
56
- # Upload and visualize an image
57
- def predict_gradcam(image):
58
- # prediction
59
- img = image.convert("RGB")
60
- plt.imshow(image); plt.axis('off'); plt.show()
61
- tensor = transform(img).unsqueeze(0).to(device)
 
 
 
 
62
  with torch.no_grad():
63
  p = torch.softmax(model(tensor), dim=1)[0,1].item()
64
- prob= f"{p:.3f}"
65
- label= f"{'PNEUMONIA' if p>0.5 else 'NORMAL'}"
66
-
67
- # Grad-CAM
68
- target_layer = model.conv2
69
- input_tensor = transform(img).unsqueeze(0).to(device)
70
- cam = GradCAM(model=model, target_layers=[target_layer])
71
- grayscale_cam = cam(input_tensor=input_tensor)[0]
72
- img_np = np.array(img.resize((224,224)), dtype=np.float32)/255.0
73
- heatmap = show_cam_on_image(img_np, grayscale_cam, use_rgb=True)
74
- heatmap_pil = Image.fromarray(heatmap)
75
-
76
- return prob, label, heatmap_pil
77
-
78
- demo = gr.Interface(
79
- fn=predict_gradcam,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  inputs=gr.Image(type="pil", label="Upload Chest X-ray"),
81
- outputs=[gr.Textbox(label="Probability of Pneumonia"), gr.Label(label="Prediction"), gr.Image(label="Grad-CAM")],
 
 
 
 
82
  title="🫁 Pneumonia Detection from Chest X-rays",
83
- description="Upload a chest X-ray to see whether it shows signs of pneumonia. The model will predict the probability and show a Grad-CAM visualization of the most important regions.",
84
  flagging_mode="never"
85
  )
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  demo.launch()
 
4
  import torch.nn as nn
5
  import torch.nn.functional as F
6
  import gradio as gr
7
+ # Grad-CAM imports removed for simplified UI
 
8
  from PIL import Image
9
+ import requests
10
+ import os
11
+ import base64
12
+ import io
13
 
14
 
15
  # Define CNN
 
56
  T.ToTensor(),
57
  T.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])])
58
 
59
# Simplified prediction function without Grad-CAM
def predict_pneumonia(image):
    """Classify a chest X-ray as PNEUMONIA or NORMAL.

    Args:
        image: PIL image of the X-ray, or None when the user clicked
            "analyze" without uploading anything (Gradio passes None).

    Returns:
        Tuple of (probability string, label string, confidence string)
        matching the three Gradio output components.
    """
    # Guard: without this, image.convert raises AttributeError on None.
    if image is None:
        return "N/A", "No image provided", "N/A"

    # Convert image to RGB (the transform expects 3 channels)
    img = image.convert("RGB")

    # 1. Apply transforms (resize to 224x224, normalize) -> [1, 3, 224, 224]
    tensor = transform(img).unsqueeze(0).to(device)

    # 2. Forward pass; softmax over the two logits, take P(pneumonia).
    with torch.no_grad():
        p = torch.softmax(model(tensor), dim=1)[0, 1].item()

    # Format results
    prob = f"{p:.3f}"
    label = "PNEUMONIA" if p > 0.5 else "NORMAL"
    # Confidence is the probability of the *predicted* class.
    confidence = f"{p*100:.1f}%" if p > 0.5 else f"{(1-p)*100:.1f}%"

    return prob, label, confidence
78
+
79
# MedGemma Chatbot functionality
def image_to_base64(image):
    """Convert a PIL image to a base64-encoded JPEG data URL.

    Args:
        image: PIL image (any mode).

    Returns:
        A "data:image/jpeg;base64,..." string suitable for an
        OpenAI-style image_url message part.
    """
    # JPEG cannot encode alpha/palette modes (e.g. RGBA PNG uploads);
    # normalize to RGB first or PIL raises OSError on save.
    if image.mode != "RGB":
        image = image.convert("RGB")
    buffer = io.BytesIO()
    image.save(buffer, format="JPEG")
    img_base64 = base64.b64encode(buffer.getvalue()).decode()
    return f"data:image/jpeg;base64,{img_base64}"
87
+
88
def query_medgemma(message, history, image=None):
    """Query the MedGemma inference endpoint with text and an optional image.

    Args:
        message: user prompt text.
        history: chat history (currently unused; kept for interface parity).
        image: optional PIL image to attach to the message.

    Returns:
        The assistant's reply text, or a user-facing (Spanish) error string.
    """
    # Your endpoint URL
    endpoint_url = "https://t911ok4t5x994zcu.us-east-1.aws.endpoints.huggingface.cloud"

    # Fail fast with a clear message instead of sending "Bearer None"
    # to the endpoint when the token env var is missing.
    token = os.getenv("HUGGINGFACE_TOKEN")
    if not token:
        return "Error de configuración: falta la variable de entorno HUGGINGFACE_TOKEN."

    # Headers with your HF token
    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json"
    }

    # OpenAI-style multimodal content: optional image part, then the text.
    content = []
    if image is not None:
        content.append({
            "type": "image_url",
            "image_url": {"url": image_to_base64(image)}
        })
    content.append({
        "type": "text",
        "text": message
    })

    # Prepare payload (TGI messages API)
    payload = {
        "model": "tgi",
        "messages": [
            {
                "role": "user",
                "content": content
            }
        ],
        "max_tokens": 500,
        "temperature": 0.7
    }

    try:
        response = requests.post(endpoint_url, headers=headers, json=payload, timeout=30)

        if response.status_code == 200:
            result = response.json()
            if "choices" in result and len(result["choices"]) > 0:
                return result["choices"][0]["message"]["content"]
            return "Lo siento, no pude obtener una respuesta del modelo."
        # Non-200: the serverless endpoint may be scaled to zero.
        return f"Error del endpoint: {response.status_code}. El modelo puede estar escalado a cero - intenta de nuevo en unos segundos."

    except requests.exceptions.Timeout:
        return "Timeout: El modelo está despertando, intenta de nuevo en unos segundos."
    except Exception as e:
        return f"Error de conexión: {str(e)}"
146
+
147
def medical_chat(message, history, uploaded_image):
    """Handle one chat turn, adding radiology context when an image is attached.

    Returns the updated history and an empty string to clear the input box.
    """
    # With an image attached, wrap the question in a radiology-assistant prompt.
    if uploaded_image is None:
        prompt = message
    else:
        prompt = f"""Como asistente médico especializado en radiología, analiza esta imagen de rayos X y responde: {message}

Contexto: Esta es una radiografía de tórax que puede mostrar signos de neumonía. Proporciona información médica precisa pero recuerda que siempre se debe consultar a un profesional médico."""

    reply = query_medgemma(prompt, history, uploaded_image)

    # Record the exchange (original question, not the augmented prompt).
    history.append([message, reply])
    return history, ""
162
+
163
# Create the main pneumonia detection interface.
# NOTE(review): this Interface is constructed but never launched or mounted —
# the Blocks `demo` defined below is the UI actually served. Confirm it is
# kept intentionally (e.g. for standalone/API reuse) or remove it.
pneumonia_interface = gr.Interface(
    fn=predict_pneumonia,
    inputs=gr.Image(type="pil", label="Upload Chest X-ray"),
    outputs=[
        gr.Textbox(label="Probability of Pneumonia"),
        gr.Label(label="Prediction"),
        gr.Textbox(label="Confidence")
    ],
    title="🫁 Pneumonia Detection from Chest X-rays",
    description="Upload a chest X-ray to detect signs of pneumonia using deep learning.",
    flagging_mode="never"
)
176
 
177
# Create the combined app: pneumonia detection panel + MedGemma chatbot,
# assembled in a single Blocks layout. This `demo` is what gets launched.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🫁 RADOX - Sistema Inteligente de Detección de Neumonía")
    gr.Markdown("### Análisis de Radiografías + Consulta Médica con IA")

    with gr.Row():
        with gr.Column(scale=1):
            # Pneumonia Detection Section: upload + analyze button + 3 outputs.
            gr.Markdown("## 🔍 Detección de Neumonía")
            input_image = gr.Image(type="pil", label="Subir Radiografía de Tórax")
            analyze_btn = gr.Button("🔬 Analizar Radiografía", variant="primary")

            with gr.Row():
                prob_output = gr.Textbox(label="Probabilidad de Neumonía")
                pred_output = gr.Label(label="Diagnóstico")
                conf_output = gr.Textbox(label="Confianza")

            # Medical Chatbot Section (MedGemma-backed).
            gr.Markdown("## 🤖 Consulta Médica con MedGemma")
            gr.Markdown("*Haz preguntas sobre la radiografía o consultas médicas generales*")

            with gr.Row():
                with gr.Column(scale=3):
                    chatbot = gr.Chatbot(
                        label="Chat Médico",
                        height=400,
                        show_label=True
                    )

                    with gr.Row():
                        msg_input = gr.Textbox(
                            label="Tu pregunta",
                            placeholder="Ej: ¿Qué significan estos resultados? ¿Cuáles son los síntomas de neumonía?",
                            scale=4
                        )
                        send_btn = gr.Button("Enviar", variant="primary", scale=1)

                with gr.Column(scale=1):
                    # Separate image slot so the chat can reference an X-ray
                    # independently of the detection panel's upload.
                    chat_image = gr.Image(
                        type="pil",
                        label="Imagen para el chat (opcional)",
                        height=300
                    )
                    gr.Markdown("💡 **Tip:** Puedes subir la misma radiografía aquí para hacer preguntas específicas sobre ella.")

    # Event handlers: the analyze button runs the local CNN; both the send
    # button and Enter in the textbox route the question to MedGemma.
    analyze_btn.click(
        fn=predict_pneumonia,
        inputs=[input_image],
        outputs=[prob_output, pred_output, conf_output]
    )

    send_btn.click(
        fn=medical_chat,
        inputs=[msg_input, chatbot, chat_image],
        outputs=[chatbot, msg_input]
    )

    msg_input.submit(
        fn=medical_chat,
        inputs=[msg_input, chatbot, chat_image],
        outputs=[chatbot, msg_input]
    )

    # Footer: medical disclaimer.
    gr.Markdown("""
    ---
    ⚠️ **Aviso Médico Importante**: Esta herramienta es solo para fines educativos y de apoyo diagnóstico.
    Siempre consulte con un profesional médico cualificado para obtener diagnósticos y tratamientos precisos.
    """)

demo.launch()
requirements.txt CHANGED
@@ -1,8 +1,7 @@
1
- grad-cam==1.5.5
2
  gradio==4.44.1
3
  matplotlib==3.9.4
4
  numpy==2.0.2
5
- opencv-python==4.11.0.86
6
  pillow==10.4.0
7
  torch==2.7.1
8
- torchvision==0.22.1
 
 
 
1
  gradio==4.44.1
2
  matplotlib==3.9.4
3
  numpy==2.0.2
 
4
  pillow==10.4.0
5
  torch==2.7.1
6
+ torchvision==0.22.1
7
+ requests>=2.25.0