|
|
import base64
import io
import os
import re

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image
|
|
|
|
|
|
|
|
|
|
|
class Net(nn.Module):
    """Two-stage convolutional binary classifier with BatchNorm and Dropout.

    Each stage is a 3x3 same-padded convolution, batch normalisation, ReLU and
    a 2x2 max-pool. A three-layer fully connected head (with dropout after the
    first two layers) maps the flattened features to two output scores.
    """

    def __init__(self) -> None:
        super().__init__()

        # Stage 1: 3 -> 16 channels.
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)

        # One pooling module is shared by both stages; each use halves H and W.
        self.pool = nn.MaxPool2d(2, 2)

        # Stage 2: 16 -> 32 channels.
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(32)

        # Head. 32 * 56 * 56 is the flattened size after two 2x pools, so the
        # network expects 224x224 spatial inputs.
        self.fc1 = nn.Linear(32 * 56 * 56, 112)
        self.dropout1 = nn.Dropout(0.5)

        self.fc2 = nn.Linear(112, 84)
        self.dropout2 = nn.Dropout(0.2)

        self.fc3 = nn.Linear(84, 2)

    def forward(self, x) -> torch.Tensor:
        """Compute unnormalised two-class scores for a batch (no softmax applied)."""
        stages = ((self.conv1, self.bn1), (self.conv2, self.bn2))
        for conv, norm in stages:
            x = self.pool(F.relu(norm(conv(x))))

        x = x.flatten(start_dim=1)
        x = self.dropout1(F.relu(self.fc1(x)))
        x = self.dropout2(F.relu(self.fc2(x)))
        return self.fc3(x)
|
|
|
|
|
|
|
|
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Instantiate the classifier on `device`, restore the trained weights from
# best_model.pt (tensors remapped onto `device`), and switch to inference mode
# (Dropout disabled, BatchNorm uses its stored running statistics).
model = Net()
model.to(device)
model.load_state_dict(torch.load("best_model.pt", map_location=device))
model.eval()

# Preprocessing applied to every image before it reaches `model`: resize to the
# 224x224 size Net's first linear layer is built for, convert to a tensor, then
# normalise each of the three channels with mean 0.5 / std 0.5.
transform = T.Compose(
    [
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)
|
|
|
|
|
|
|
|
def _format_prediction(p):
    """Render a pneumonia probability as the (probability, label, confidence) UI strings.

    ``p > 0.5`` is labelled "PNEUMONIA"; anything else (including exactly 0.5)
    is "NORMAL". Confidence is the probability of whichever class was chosen.
    """
    is_pneumonia = p > 0.5
    label = "PNEUMONIA" if is_pneumonia else "NORMAL"
    confidence = p if is_pneumonia else 1 - p
    return f"{p:.3f}", label, f"{confidence * 100:.1f}%"


def predict_pneumonia(image):
    """Classify a chest X-ray with the module-level CNN.

    Args:
        image: PIL image from the Gradio upload widget, or ``None`` when the
            user triggers analysis without uploading anything.

    Returns:
        A tuple ``(probability, label, confidence)`` of strings. These are the
        pneumonia probability to three decimals, "PNEUMONIA" or "NORMAL", and
        the predicted class's probability as a percentage. All three are empty
        strings when no image was provided.
    """
    if image is None:
        # Gradio passes None when "Analyze" is clicked before an upload; without
        # this guard image.convert raises AttributeError.
        return "", "", ""

    img = image.convert("RGB")
    tensor = transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        # Output column 1 is treated as the pneumonia class.
        p = torch.softmax(model(tensor), dim=1)[0, 1].item()

    return _format_prediction(p)
|
|
|
|
|
|
|
|
def image_to_base64(image):
    """Encode a PIL image as a JPEG ``data:`` URI string.

    JPEG cannot store every Pillow mode. Images that are not already "RGB" or
    "L" (e.g. RGBA or palette "P" images from PNG uploads) are converted to RGB
    first; otherwise ``save(format="JPEG")`` raises ``OSError``.

    Returns:
        A string of the form ``"data:image/jpeg;base64,<payload>"``.
    """
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")

    buffer = io.BytesIO()
    image.save(buffer, format="JPEG")
    encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
    return f"data:image/jpeg;base64,{encoded}"
|
|
|
|
|
_MEDGEMMA_ENDPOINT_URL = "https://t911ok4t5x994zcu.us-east-1.aws.endpoints.huggingface.cloud"

# Generation settings sent with every request, whatever the payload layout.
_MEDGEMMA_PARAMETERS = {
    "max_new_tokens": 150,
    "temperature": 0.4,
    "do_sample": True,
    "return_full_text": False,
    "stop": ["</s>", "<|im_end|>", "\n\n"],
    "repetition_penalty": 1.15,
    "top_p": 0.9,
    "seed": 42,
}


def _medgemma_payloads(message, image_base64):
    """Build ``(payload, payload_alt)`` request bodies for the endpoint.

    With an image, ``payload`` nests text and image under "inputs" and
    ``payload_alt`` is a fallback layout with the image at the top level.
    Without an image there is no fallback and ``payload_alt`` is ``None``.
    """
    if image_base64 is None:
        return {"inputs": message, "parameters": dict(_MEDGEMMA_PARAMETERS)}, None

    payload = {
        "inputs": {"text": message, "image": image_base64},
        "parameters": dict(_MEDGEMMA_PARAMETERS),
    }
    payload_alt = {
        "inputs": message,
        "image": image_base64,
        "parameters": dict(_MEDGEMMA_PARAMETERS),
    }
    return payload, payload_alt


def _medgemma_text(response):
    """Extract ``generated_text`` from a 200 response, or describe why it cannot."""
    try:
        result = response.json()
    except ValueError:
        # Body was not JSON: a format problem, not a connection error.
        return f"Formato de respuesta inesperado: {response.text[:200]}"

    if isinstance(result, list) and result and isinstance(result[0], dict):
        return result[0].get("generated_text", "Error en el formato de respuesta")
    if isinstance(result, dict) and "generated_text" in result:
        return result["generated_text"]
    return f"Formato de respuesta inesperado: {str(result)[:200]}"


def query_medgemma(message, history, image=None):
    """Send a prompt (and optionally an image) to the MedGemma endpoint.

    Args:
        message: Prompt text sent as the model input.
        history: Chat history. Accepted for interface compatibility but not
            included in the request.
        image: Optional PIL image, sent as a base64 JPEG data URI.

    Returns:
        The generated text on success. HTTP and network failures are returned
        as user-facing (Spanish) error strings rather than raised.
    """
    headers = {"Content-Type": "application/json"}
    token = os.getenv("HUGGINGFACE_TOKEN")
    if token:
        # Only authenticate when a token is configured; never send "Bearer None".
        headers["Authorization"] = f"Bearer {token}"

    image_base64 = image_to_base64(image) if image is not None else None
    payload, payload_alt = _medgemma_payloads(message, image_base64)

    try:
        response = requests.post(_MEDGEMMA_ENDPOINT_URL, headers=headers, json=payload, timeout=30)
        if response.status_code == 200:
            return _medgemma_text(response)

        if response.status_code == 422 and payload_alt is not None:
            # The nested layout was rejected; retry once with the alternative layout.
            response = requests.post(_MEDGEMMA_ENDPOINT_URL, headers=headers, json=payload_alt, timeout=30)
            if response.status_code == 200:
                return _medgemma_text(response)
            if response.status_code == 422:
                return f"Error 422 en ambos formatos. Detalles: {response.text[:300]}"
            # Any other status from the retry (e.g. 503) is reported by the
            # generic handling below instead of being mislabelled as a 422.

        if response.status_code == 503:
            return "El modelo está escalado a cero. Intenta de nuevo en unos segundos mientras se activa."
        if response.status_code == 422:
            return f"Error de formato en la petición. Detalles: {response.text[:300]}"
        return f"Error del endpoint: {response.status_code}. Detalles: {response.text[:200]}"

    except requests.exceptions.Timeout:
        return "Timeout: El modelo está procesando, intenta de nuevo en unos segundos."
    except Exception as e:
        # UI boundary: surface any other failure as chat text instead of crashing.
        return f"Error de conexión: {str(e)}"
|
|
|
|
|
def medical_chat(message, history, uploaded_image): |
|
|
"""Handle medical chat with context from pneumonia detection""" |
|
|
|
|
|
|
|
|
if uploaded_image is not None: |
|
|
context_message = f"""Eres MedGemma, un asistente médico especializado en radiología. Analiza esta radiografía de tórax y responde la siguiente pregunta médica de forma precisa y profesional: |
|
|
|
|
|
Pregunta: {message} |
|
|
|
|
|
Contexto: Esta es una radiografía de tórax para detectar neumonía. Enfócate en proporcionar información médica relevante, síntomas, diagnósticos o recomendaciones. Siempre menciona que se debe consultar a un profesional médico.""" |
|
|
else: |
|
|
context_message = f"""Eres MedGemma, un asistente médico especializado. Responde la siguiente pregunta médica de forma precisa y profesional: |
|
|
|
|
|
Pregunta: {message} |
|
|
|
|
|
Instrucciones: Proporciona información médica precisa y relevante. Si no es una pregunta médica, redirige hacia temas de salud relacionados. Siempre menciona que se debe consultar a un profesional médico para diagnósticos definitivos.""" |
|
|
|
|
|
response = query_medgemma(context_message, history, uploaded_image) |
|
|
|
|
|
|
|
|
if any(keyword in response.lower() for keyword in ["requests", "python", "código", "api", "import", "status code"]): |
|
|
response = "Como asistente médico, me especializo en temas de salud y medicina. ¿Podrías hacer una pregunta relacionada con medicina, síntomas, diagnósticos o radiografías? Estoy aquí para ayudarte con consultas médicas." |
|
|
|
|
|
|
|
|
history.append([message, response]) |
|
|
return history, "" |
|
|
|
|
|
|
|
|
# Stand-alone single-function Interface around the classifier.
# NOTE(review): this object is built but never launched; the app launches `demo`.
pneumonia_interface = gr.Interface(
    predict_pneumonia,
    gr.Image(type="pil", label="Upload Chest X-ray"),
    [
        gr.Textbox(label="Probability of Pneumonia"),
        gr.Label(label="Prediction"),
        gr.Textbox(label="Confidence"),
    ],
    title="🫁 Pneumonia Detection from Chest X-rays",
    description="Upload a chest X-ray to detect signs of pneumonia using deep learning.",
    flagging_mode="never",
)
|
|
|
|
|
|
|
|
# Main UI: X-ray classification panel on top, MedGemma chat below.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🫁 RADOX - Sistema Inteligente de Detección de Neumonía")
    gr.Markdown("### Análisis de Radiografías + Consulta Médica con IA")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## 🔍 Detección de Neumonía")
            input_image = gr.Image(type="pil", label="Subir Radiografía de Tórax")
            analyze_btn = gr.Button("🔬 Analizar Radiografía", variant="primary")

            with gr.Row():
                prob_output = gr.Textbox(label="Probabilidad de Neumonía")
                pred_output = gr.Label(label="Diagnóstico")
                conf_output = gr.Textbox(label="Confianza")

    gr.Markdown("## 🤖 Consulta Médica con MedGemma")
    gr.Markdown("*Haz preguntas sobre la radiografía o consultas médicas generales*")

    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                label="Chat Médico",
                height=400,
                show_label=True,
            )

            with gr.Row():
                msg_input = gr.Textbox(
                    label="Tu pregunta",
                    placeholder="Ej: ¿Qué significan estos resultados? ¿Cuáles son los síntomas de neumonía?",
                    scale=4,
                )
                send_btn = gr.Button("Enviar", variant="primary", scale=1)

        with gr.Column(scale=1):
            chat_image = gr.Image(
                type="pil",
                label="Imagen para el chat (opcional)",
                height=300,
            )
            gr.Markdown("💡 **Tip:** Puedes subir la misma radiografía aquí para hacer preguntas específicas sobre ella.")

    analyze_btn.click(
        fn=predict_pneumonia,
        inputs=[input_image],
        outputs=[prob_output, pred_output, conf_output],
    )

    # The "Enviar" button and pressing Enter in the textbox run the same handler.
    for chat_trigger in (send_btn.click, msg_input.submit):
        chat_trigger(
            fn=medical_chat,
            inputs=[msg_input, chatbot, chat_image],
            outputs=[chatbot, msg_input],
        )

    gr.Markdown("""
---
⚠️ **Aviso Médico Importante**: Esta herramienta es solo para fines educativos y de apoyo diagnóstico.
Siempre consulte con un profesional médico cualificado para obtener diagnósticos y tratamientos precisos.
""")


# Start the server only when run as a script, so importing this module (tests,
# `gradio` reload mode) does not block on launch().
if __name__ == "__main__":
    demo.launch()