# RADOX / app.py
# Pneumonia detection app with Grad-CAM (author: mbrq13, commit d3faaea)
import base64
import io
import os
import re

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image

# Grad-CAM imports removed for simplified UI
# Define CNN
class Net(nn.Module):
"""Simple CNN with Batch Normalization and Dropout regularisation."""
def __init__(self) -> None:
super().__init__()
# Convolutional block 1
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(16)
# Convolutional block 2
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(32)
# Fully - connected head
self.fc1 = nn.Linear(32 * 56 * 56, 112)
self.dropout1 = nn.Dropout(0.5)
self.fc2 = nn.Linear(112, 84)
self.dropout2 = nn.Dropout(0.2)
self.fc3 = nn.Linear(84, 2)
def forward(self, x) -> torch.Tensor: # N,C,H,W
"""Forward pass returning raw logits (no softmax)."""
c1 = self.pool(F.relu(self.bn1(self.conv1(x)))) # N,16,112,112
c2 = self.pool(F.relu(self.bn2(self.conv2(c1)))) # N,32,56,56
c2 = torch.flatten(c2, 1) # N,32*56*56
f3 = self.dropout1(F.relu(self.fc1(c2))) # N,112
f4 = self.dropout2(F.relu(self.fc2(f3))) # N,84
out = self.fc3(f4) # N,2
return out
# Load pre-trained model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net().to(device)
model.load_state_dict(torch.load("best_model.pt", map_location=device))
model.eval()
transform = T.Compose([T.Resize((224,224)),
T.ToTensor(),
T.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])])
# Simplified prediction function without Grad-CAM
def predict_pneumonia(image):
# Convert image to RGB
img = image.convert("RGB")
# HERE IS WHERE THE IMAGE ENTERS THE MODEL:
# 1. Apply transforms (resize to 224x224, normalize)
tensor = transform(img).unsqueeze(0).to(device) # Shape: [1, 3, 224, 224]
# 2. Pass through the model
with torch.no_grad():
p = torch.softmax(model(tensor), dim=1)[0,1].item()
# Format results
prob = f"{p:.3f}"
label = f"{'PNEUMONIA' if p>0.5 else 'NORMAL'}"
confidence = f"{p*100:.1f}%" if p > 0.5 else f"{(1-p)*100:.1f}%"
return prob, label, confidence
# MedGemma Chatbot functionality
def image_to_base64(image):
"""Convert PIL image to base64 string"""
buffer = io.BytesIO()
image.save(buffer, format="JPEG")
img_bytes = buffer.getvalue()
img_base64 = base64.b64encode(img_bytes).decode()
return f"data:image/jpeg;base64,{img_base64}"
def query_medgemma(message, history, image=None):
"""Query MedGemma endpoint with proper multimodal format"""
# Your endpoint URL
endpoint_url = "https://t911ok4t5x994zcu.us-east-1.aws.endpoints.huggingface.cloud"
# Headers with your HF token
headers = {
"Authorization": f"Bearer {os.getenv('HUGGINGFACE_TOKEN')}",
"Content-Type": "application/json"
}
# Prepare payload based on whether we have an image or not
if image is not None:
# Multimodal format: Send image as base64 in the content
image_base64 = image_to_base64(image)
# TGI multimodal format for MedGemma with better stopping
payload = {
"inputs": {
"text": message,
"image": image_base64
},
"parameters": {
"max_new_tokens": 150, # Optimized for TGI
"temperature": 0.4, # Balanced for medical responses
"do_sample": True,
"return_full_text": False,
"stop": ["</s>", "<|im_end|>", "\n\n"],
"repetition_penalty": 1.15, # Adjusted for TGI
"top_p": 0.9, # Better for medical content
"seed": 42 # Consistent responses for testing
}
}
# Alternative format if the above doesn't work
payload_alt = {
"inputs": message,
"image": image_base64,
"parameters": {
"max_new_tokens": 150,
"temperature": 0.4,
"do_sample": True,
"return_full_text": False,
"stop": ["</s>", "<|im_end|>", "\n\n"],
"repetition_penalty": 1.15,
"top_p": 0.9,
"seed": 42
}
}
else:
# Text-only format with better parameters
payload = {
"inputs": message,
"parameters": {
"max_new_tokens": 150,
"temperature": 0.4,
"do_sample": True,
"return_full_text": False,
"stop": ["</s>", "<|im_end|>", "\n\n"],
"repetition_penalty": 1.15,
"top_p": 0.9,
"seed": 42
}
}
payload_alt = None
try:
# Try primary format
response = requests.post(endpoint_url, headers=headers, json=payload, timeout=30)
if response.status_code == 200:
result = response.json()
# Handle different TGI response formats
if isinstance(result, list) and len(result) > 0:
return result[0].get("generated_text", "Error en el formato de respuesta")
elif "generated_text" in result:
return result["generated_text"]
else:
return f"Formato de respuesta inesperado: {str(result)[:200]}"
elif response.status_code == 422 and payload_alt is not None:
# Try alternative format for multimodal
response = requests.post(endpoint_url, headers=headers, json=payload_alt, timeout=30)
if response.status_code == 200:
result = response.json()
if isinstance(result, list) and len(result) > 0:
return result[0].get("generated_text", "Error en el formato de respuesta")
elif "generated_text" in result:
return result["generated_text"]
else:
return f"Formato de respuesta inesperado: {str(result)[:200]}"
else:
return f"Error 422 en ambos formatos. Detalles: {response.text[:300]}"
elif response.status_code == 503:
return "El modelo está escalado a cero. Intenta de nuevo en unos segundos mientras se activa."
elif response.status_code == 422:
return f"Error de formato en la petición. Detalles: {response.text[:300]}"
else:
return f"Error del endpoint: {response.status_code}. Detalles: {response.text[:200]}"
except requests.exceptions.Timeout:
return "Timeout: El modelo está procesando, intenta de nuevo en unos segundos."
except Exception as e:
return f"Error de conexión: {str(e)}"
def medical_chat(message, history, uploaded_image):
"""Handle medical chat with context from pneumonia detection"""
# Always add medical context to keep the model focused
if uploaded_image is not None:
context_message = f"""Eres MedGemma, un asistente médico especializado en radiología. Analiza esta radiografía de tórax y responde la siguiente pregunta médica de forma precisa y profesional:
Pregunta: {message}
Contexto: Esta es una radiografía de tórax para detectar neumonía. Enfócate en proporcionar información médica relevante, síntomas, diagnósticos o recomendaciones. Siempre menciona que se debe consultar a un profesional médico."""
else:
context_message = f"""Eres MedGemma, un asistente médico especializado. Responde la siguiente pregunta médica de forma precisa y profesional:
Pregunta: {message}
Instrucciones: Proporciona información médica precisa y relevante. Si no es una pregunta médica, redirige hacia temas de salud relacionados. Siempre menciona que se debe consultar a un profesional médico para diagnósticos definitivos."""
response = query_medgemma(context_message, history, uploaded_image)
# Clean the response if it contains non-medical content
if any(keyword in response.lower() for keyword in ["requests", "python", "código", "api", "import", "status code"]):
response = "Como asistente médico, me especializo en temas de salud y medicina. ¿Podrías hacer una pregunta relacionada con medicina, síntomas, diagnósticos o radiografías? Estoy aquí para ayudarte con consultas médicas."
# Add the exchange to history
history.append([message, response])
return history, ""
# Create the main pneumonia detection interface
pneumonia_interface = gr.Interface(
fn=predict_pneumonia,
inputs=gr.Image(type="pil", label="Upload Chest X-ray"),
outputs=[
gr.Textbox(label="Probability of Pneumonia"),
gr.Label(label="Prediction"),
gr.Textbox(label="Confidence")
],
title="🫁 Pneumonia Detection from Chest X-rays",
description="Upload a chest X-ray to detect signs of pneumonia using deep learning.",
flagging_mode="never"
)
# Create the MedGemma chatbot interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("# 🫁 RADOX - Sistema Inteligente de Detección de Neumonía")
gr.Markdown("### Análisis de Radiografías + Consulta Médica con IA")
with gr.Row():
with gr.Column(scale=1):
# Pneumonia Detection Section
gr.Markdown("## 🔍 Detección de Neumonía")
input_image = gr.Image(type="pil", label="Subir Radiografía de Tórax")
analyze_btn = gr.Button("🔬 Analizar Radiografía", variant="primary")
with gr.Row():
prob_output = gr.Textbox(label="Probabilidad de Neumonía")
pred_output = gr.Label(label="Diagnóstico")
conf_output = gr.Textbox(label="Confianza")
# Medical Chatbot Section
gr.Markdown("## 🤖 Consulta Médica con MedGemma")
gr.Markdown("*Haz preguntas sobre la radiografía o consultas médicas generales*")
with gr.Row():
with gr.Column(scale=3):
chatbot = gr.Chatbot(
label="Chat Médico",
height=400,
show_label=True
)
with gr.Row():
msg_input = gr.Textbox(
label="Tu pregunta",
placeholder="Ej: ¿Qué significan estos resultados? ¿Cuáles son los síntomas de neumonía?",
scale=4
)
send_btn = gr.Button("Enviar", variant="primary", scale=1)
with gr.Column(scale=1):
chat_image = gr.Image(
type="pil",
label="Imagen para el chat (opcional)",
height=300
)
gr.Markdown("💡 **Tip:** Puedes subir la misma radiografía aquí para hacer preguntas específicas sobre ella.")
# Event handlers
analyze_btn.click(
fn=predict_pneumonia,
inputs=[input_image],
outputs=[prob_output, pred_output, conf_output]
)
send_btn.click(
fn=medical_chat,
inputs=[msg_input, chatbot, chat_image],
outputs=[chatbot, msg_input]
)
msg_input.submit(
fn=medical_chat,
inputs=[msg_input, chatbot, chat_image],
outputs=[chatbot, msg_input]
)
# Footer
gr.Markdown("""
---
⚠️ **Aviso Médico Importante**: Esta herramienta es solo para fines educativos y de apoyo diagnóstico.
Siempre consulte con un profesional médico cualificado para obtener diagnósticos y tratamientos precisos.
""")
demo.launch()