File size: 12,088 Bytes
5cf329d
 
 
 
 
 
a93d23f
5cf329d
a93d23f
 
 
 
5cf329d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a93d23f
 
 
 
 
 
 
 
 
 
5cf329d
 
a93d23f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284e8fb
a93d23f
 
 
 
 
 
 
 
 
 
284e8fb
a93d23f
284e8fb
a93d23f
284e8fb
f6426fc
284e8fb
 
 
 
 
 
d3faaea
 
284e8fb
f6426fc
d3faaea
 
 
 
a93d23f
284e8fb
 
 
 
 
 
 
d3faaea
 
284e8fb
f6426fc
d3faaea
 
 
 
284e8fb
 
 
 
f6426fc
284e8fb
 
 
d3faaea
 
284e8fb
 
d3faaea
 
 
 
284e8fb
 
 
a93d23f
 
284e8fb
a93d23f
 
 
 
284e8fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a93d23f
284e8fb
 
 
 
 
 
a93d23f
284e8fb
a93d23f
 
284e8fb
a93d23f
 
 
 
 
 
f6426fc
a93d23f
f6426fc
 
 
 
 
 
 
 
 
 
 
a93d23f
 
 
f6426fc
 
 
 
a93d23f
 
 
 
 
 
 
5cf329d
a93d23f
 
 
 
 
5cf329d
a93d23f
5cf329d
 
 
a93d23f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ea11f9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
import matplotlib.pyplot as plt
import numpy as np, torch
import torchvision.transforms as T
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
# Grad-CAM imports removed for simplified UI
from PIL import Image
import requests
import os
import base64
import io


# Define CNN
class Net(nn.Module):
    """Compact CNN (two conv blocks + three-layer MLP head) for binary
    chest X-ray classification.

    Expects (N, 3, 224, 224) inputs and emits raw logits of shape (N, 2);
    no softmax is applied inside the network.
    """

    def __init__(self) -> None:
        super().__init__()
        # Conv block 1: 3 -> 16 channels; padding=1 keeps spatial dims.
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)

        # Conv block 2: 16 -> 32 channels. A single 2x2 max-pool module is
        # shared and applied after each conv block (pooling has no params).
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(32)

        # Classifier head over the flattened 32 * 56 * 56 feature map.
        self.fc1 = nn.Linear(32 * 56 * 56, 112)
        self.dropout1 = nn.Dropout(0.5)

        self.fc2 = nn.Linear(112, 84)
        self.dropout2 = nn.Dropout(0.2)

        self.fc3 = nn.Linear(84, 2)

    def forward(self, x) -> torch.Tensor:  # x: (N, C, H, W)
        """Forward pass returning raw logits (no softmax)."""
        feat = self.pool(F.relu(self.bn1(self.conv1(x))))     # (N, 16, 112, 112)
        feat = self.pool(F.relu(self.bn2(self.conv2(feat))))  # (N, 32, 56, 56)
        flat = torch.flatten(feat, 1)                         # (N, 32*56*56)
        hidden = self.dropout1(F.relu(self.fc1(flat)))        # (N, 112)
        hidden = self.dropout2(F.relu(self.fc2(hidden)))      # (N, 84)
        return self.fc3(hidden)                               # (N, 2)

# Load pre-trained model
# Prefer GPU when available; map_location remaps the checkpoint tensors
# onto whichever device was chosen, so CPU-only hosts still work.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net().to(device)
model.load_state_dict(torch.load("best_model.pt", map_location=device))
model.eval()  # inference mode: disables dropout, uses running BN statistics

# Preprocessing pipeline: resize to the 224x224 input the network expects,
# convert to a tensor, and normalize each channel to roughly [-1, 1].
# NOTE(review): presumably matches the training-time transform — confirm.
transform = T.Compose([T.Resize((224,224)),
                       T.ToTensor(),
                       T.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])])

# Simplified prediction function without Grad-CAM
def predict_pneumonia(image):
    """Classify a chest X-ray as NORMAL or PNEUMONIA.

    Args:
        image: PIL image of a chest X-ray, or None when the user clicks
            the analyze button without uploading anything (Gradio passes
            None in that case).

    Returns:
        Tuple of display-ready strings:
        (probability of pneumonia, predicted label, confidence percentage).
    """
    # Guard against a missing upload instead of crashing on .convert()
    # with AttributeError.
    if image is None:
        return "N/A", "NO IMAGE", "N/A"

    # Ensure 3 channels — X-rays are often grayscale or palette images.
    img = image.convert("RGB")

    # Preprocess (resize to 224x224, normalize) and add the batch dim.
    tensor = transform(img).unsqueeze(0).to(device)  # Shape: [1, 3, 224, 224]

    # Forward pass; softmax over the two logits, take P(pneumonia) = class 1.
    with torch.no_grad():
        p = torch.softmax(model(tensor), dim=1)[0, 1].item()

    # Format results for the three Gradio output components.
    prob = f"{p:.3f}"
    label = "PNEUMONIA" if p > 0.5 else "NORMAL"
    confidence = f"{p*100:.1f}%" if p > 0.5 else f"{(1-p)*100:.1f}%"

    return prob, label, confidence

# MedGemma Chatbot functionality
def image_to_base64(image):
    """Serialize a PIL image to a JPEG data-URI string (base64 payload)."""
    raw = io.BytesIO()
    image.save(raw, format="JPEG")
    encoded = base64.b64encode(raw.getvalue()).decode()
    return f"data:image/jpeg;base64,{encoded}"

# Generation parameters shared by every request variant; previously this
# 9-key dict was duplicated three times and could silently drift.
_GEN_PARAMS = {
    "max_new_tokens": 150,      # Optimized for TGI
    "temperature": 0.4,         # Balanced for medical responses
    "do_sample": True,
    "return_full_text": False,
    "stop": ["</s>", "<|im_end|>", "\n\n"],
    "repetition_penalty": 1.15, # Adjusted for TGI
    "top_p": 0.9,               # Better for medical content
    "seed": 42                  # Consistent responses for testing
}


def _extract_generated_text(result):
    """Pull the generated text out of a decoded TGI JSON response.

    TGI may return either a list of generations or a single object; any
    other shape is reported back to the user as an error string.
    """
    if isinstance(result, list) and len(result) > 0:
        return result[0].get("generated_text", "Error en el formato de respuesta")
    if "generated_text" in result:
        return result["generated_text"]
    return f"Formato de respuesta inesperado: {str(result)[:200]}"


def query_medgemma(message, history, image=None):
    """Query the MedGemma endpoint with proper multimodal format.

    Args:
        message: prompt text to send.
        history: chat history (currently unused; kept for the caller's API).
        image: optional PIL image sent alongside the prompt.

    Returns:
        The generated text on success, or a user-facing (Spanish) error
        message describing the failure. Never raises.
    """
    # Dedicated HF Inference Endpoint for MedGemma.
    endpoint_url = "https://t911ok4t5x994zcu.us-east-1.aws.endpoints.huggingface.cloud"

    # Auth via the HUGGINGFACE_TOKEN environment variable.
    headers = {
        "Authorization": f"Bearer {os.getenv('HUGGINGFACE_TOKEN')}",
        "Content-Type": "application/json"
    }

    # Prepare payload based on whether we have an image or not.
    if image is not None:
        image_base64 = image_to_base64(image)

        # Primary TGI multimodal format: image embedded in "inputs".
        payload = {
            "inputs": {
                "text": message,
                "image": image_base64
            },
            "parameters": dict(_GEN_PARAMS)
        }

        # Alternative schema, tried only if the primary one returns HTTP 422.
        payload_alt = {
            "inputs": message,
            "image": image_base64,
            "parameters": dict(_GEN_PARAMS)
        }
    else:
        # Text-only format.
        payload = {
            "inputs": message,
            "parameters": dict(_GEN_PARAMS)
        }
        payload_alt = None

    try:
        # Try primary format.
        response = requests.post(endpoint_url, headers=headers, json=payload, timeout=30)

        if response.status_code == 200:
            return _extract_generated_text(response.json())

        if response.status_code == 422 and payload_alt is not None:
            # Schema mismatch: retry once with the alternative multimodal format.
            response = requests.post(endpoint_url, headers=headers, json=payload_alt, timeout=30)
            if response.status_code == 200:
                return _extract_generated_text(response.json())
            return f"Error 422 en ambos formatos. Detalles: {response.text[:300]}"

        if response.status_code == 503:
            # Endpoint is scaled to zero and warming up.
            return "El modelo está escalado a cero. Intenta de nuevo en unos segundos mientras se activa."
        if response.status_code == 422:
            return f"Error de formato en la petición. Detalles: {response.text[:300]}"
        return f"Error del endpoint: {response.status_code}. Detalles: {response.text[:200]}"

    except requests.exceptions.Timeout:
        return "Timeout: El modelo está procesando, intenta de nuevo en unos segundos."
    except Exception as e:
        return f"Error de conexión: {str(e)}"

def medical_chat(message, history, uploaded_image):
    """Handle medical chat with context from pneumonia detection.

    Wraps the user's question in a Spanish medical-assistant prompt,
    queries MedGemma, redirects clearly technical (non-medical) replies,
    and appends the exchange to the Gradio chat history.

    Args:
        message: the user's question.
        history: Gradio chatbot history (list of [user, bot] pairs); mutated.
        uploaded_image: optional PIL image to send with the query.

    Returns:
        (updated history, "") — the empty string clears the input textbox.
    """
    import re  # local import keeps this function self-contained

    # Always add medical context to keep the model focused
    if uploaded_image is not None:
        context_message = f"""Eres MedGemma, un asistente médico especializado en radiología. Analiza esta radiografía de tórax y responde la siguiente pregunta médica de forma precisa y profesional:

Pregunta: {message}

Contexto: Esta es una radiografía de tórax para detectar neumonía. Enfócate en proporcionar información médica relevante, síntomas, diagnósticos o recomendaciones. Siempre menciona que se debe consultar a un profesional médico."""
    else:
        context_message = f"""Eres MedGemma, un asistente médico especializado. Responde la siguiente pregunta médica de forma precisa y profesional:

Pregunta: {message}

Instrucciones: Proporciona información médica precisa y relevante. Si no es una pregunta médica, redirige hacia temas de salud relacionados. Siempre menciona que se debe consultar a un profesional médico para diagnósticos definitivos."""

    response = query_medgemma(context_message, history, uploaded_image)

    # Filter out replies that drifted into technical/programming content.
    # Match whole words only: the old substring test wrongly triggered on
    # common Spanish words, e.g. "importante" contains "import" and
    # "terapia" contains "api", nuking legitimate medical answers.
    tech_markers = ["requests", "python", "código", "api", "import", "status code"]
    lowered = response.lower()
    if any(re.search(rf"\b{re.escape(marker)}\b", lowered) for marker in tech_markers):
        response = "Como asistente médico, me especializo en temas de salud y medicina. ¿Podrías hacer una pregunta relacionada con medicina, síntomas, diagnósticos o radiografías? Estoy aquí para ayudarte con consultas médicas."

    # Add the exchange to history
    history.append([message, response])
    return history, ""

# Create the main pneumonia detection interface
# Standalone Gradio Interface wrapping predict_pneumonia: one image input,
# three text/label outputs matching the function's return tuple.
# NOTE(review): this object is built but never launched — the Blocks app
# `demo` below is what runs; presumably kept for reuse/embedding. Confirm.
pneumonia_interface = gr.Interface(
    fn=predict_pneumonia,
    inputs=gr.Image(type="pil", label="Upload Chest X-ray"),
    outputs=[
        gr.Textbox(label="Probability of Pneumonia"), 
        gr.Label(label="Prediction"), 
        gr.Textbox(label="Confidence")
    ],
    title="🫁 Pneumonia Detection from Chest X-rays",
    description="Upload a chest X-ray to detect signs of pneumonia using deep learning.",
    flagging_mode="never"  # disable Gradio's example-flagging UI
)

# Create the MedGemma chatbot interface
# Main application layout: pneumonia detection on top, MedGemma chat below.
# Components are declared first, then wired to callbacks via event handlers.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🫁 RADOX - Sistema Inteligente de Detección de Neumonía")
    gr.Markdown("### Análisis de Radiografías + Consulta Médica con IA")
    
    with gr.Row():
        with gr.Column(scale=1):
            # Pneumonia Detection Section
            gr.Markdown("## 🔍 Detección de Neumonía")
            input_image = gr.Image(type="pil", label="Subir Radiografía de Tórax")
            analyze_btn = gr.Button("🔬 Analizar Radiografía", variant="primary")
            
            # Three outputs mirror predict_pneumonia's (prob, label, confidence).
            with gr.Row():
                prob_output = gr.Textbox(label="Probabilidad de Neumonía")
                pred_output = gr.Label(label="Diagnóstico")
                conf_output = gr.Textbox(label="Confianza")
    
    # Medical Chatbot Section
    gr.Markdown("## 🤖 Consulta Médica con MedGemma")
    gr.Markdown("*Haz preguntas sobre la radiografía o consultas médicas generales*")
    
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                label="Chat Médico",
                height=400,
                show_label=True
            )
            
            with gr.Row():
                msg_input = gr.Textbox(
                    label="Tu pregunta",
                    placeholder="Ej: ¿Qué significan estos resultados? ¿Cuáles son los síntomas de neumonía?",
                    scale=4
                )
                send_btn = gr.Button("Enviar", variant="primary", scale=1)
        
        with gr.Column(scale=1):
            # Separate image slot for the chat; not linked to input_image above.
            chat_image = gr.Image(
                type="pil", 
                label="Imagen para el chat (opcional)",
                height=300
            )
            gr.Markdown("💡 **Tip:** Puedes subir la misma radiografía aquí para hacer preguntas específicas sobre ella.")
    
    # Event handlers
    # Button click runs the CNN classifier on the uploaded X-ray.
    analyze_btn.click(
        fn=predict_pneumonia,
        inputs=[input_image],
        outputs=[prob_output, pred_output, conf_output]
    )
    
    # Both the send button and pressing Enter in the textbox trigger the
    # same chat callback; medical_chat returns (history, "") so the input
    # box is cleared after each message.
    send_btn.click(
        fn=medical_chat,
        inputs=[msg_input, chatbot, chat_image],
        outputs=[chatbot, msg_input]
    )
    
    msg_input.submit(
        fn=medical_chat,
        inputs=[msg_input, chatbot, chat_image],
        outputs=[chatbot, msg_input]
    )
    
    # Footer
    gr.Markdown("""
    ---
    ⚠️ **Aviso Médico Importante**: Esta herramienta es solo para fines educativos y de apoyo diagnóstico. 
    Siempre consulte con un profesional médico cualificado para obtener diagnósticos y tratamientos precisos.
    """)

demo.launch()