"""RADOX — pneumonia detection from chest X-rays plus a MedGemma medical chatbot.

Three parts:
1. ``Net`` — a small CNN whose weights are loaded from ``best_model.pt``.
2. ``predict_pneumonia`` — runs a single X-ray through the CNN.
3. A Gradio UI combining the detector with a chat front-end that proxies
   questions to a MedGemma endpoint on Hugging Face Inference Endpoints.
"""

import base64
import io
import os

import matplotlib.pyplot as plt  # noqa: F401  (kept: may be used interactively)
import numpy as np  # noqa: F401
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import gradio as gr
from PIL import Image  # noqa: F401

# Grad-CAM imports removed for simplified UI


# Define CNN
class Net(nn.Module):
    """Simple CNN with Batch Normalization and Dropout regularisation."""

    def __init__(self) -> None:
        super().__init__()
        # Convolutional block 1
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        # Convolutional block 2
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        # Fully-connected head — the 32*56*56 input size assumes 224x224 RGB
        # inputs (two 2x2 poolings: 224 -> 112 -> 56).
        self.fc1 = nn.Linear(32 * 56 * 56, 112)
        self.dropout1 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(112, 84)
        self.dropout2 = nn.Dropout(0.2)
        self.fc3 = nn.Linear(84, 2)

    def forward(self, x) -> torch.Tensor:  # N,C,H,W
        """Forward pass returning raw logits (no softmax)."""
        c1 = self.pool(F.relu(self.bn1(self.conv1(x))))   # N,16,112,112
        c2 = self.pool(F.relu(self.bn2(self.conv2(c1))))  # N,32,56,56
        c2 = torch.flatten(c2, 1)                         # N,32*56*56
        f3 = self.dropout1(F.relu(self.fc1(c2)))          # N,112
        f4 = self.dropout2(F.relu(self.fc2(f3)))          # N,84
        out = self.fc3(f4)                                # N,2
        return out


# Load pre-trained model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net().to(device)
model.load_state_dict(torch.load("best_model.pt", map_location=device))
model.eval()

transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])


# Simplified prediction function without Grad-CAM
def predict_pneumonia(image):
    """Classify a chest X-ray image as PNEUMONIA or NORMAL.

    Parameters
    ----------
    image : PIL.Image.Image
        X-ray uploaded through the Gradio UI.

    Returns
    -------
    tuple[str, str, str]
        (probability of pneumonia, predicted label, confidence percentage).
    """
    # Convert image to RGB
    img = image.convert("RGB")

    # HERE IS WHERE THE IMAGE ENTERS THE MODEL:
    # 1. Apply transforms (resize to 224x224, normalize)
    tensor = transform(img).unsqueeze(0).to(device)  # Shape: [1, 3, 224, 224]

    # 2. Pass through the model; softmax column 1 = P(pneumonia)
    with torch.no_grad():
        p = torch.softmax(model(tensor), dim=1)[0, 1].item()

    # Format results
    prob = f"{p:.3f}"
    label = f"{'PNEUMONIA' if p > 0.5 else 'NORMAL'}"
    confidence = f"{p*100:.1f}%" if p > 0.5 else f"{(1-p)*100:.1f}%"
    return prob, label, confidence


# MedGemma Chatbot functionality
def image_to_base64(image):
    """Convert PIL image to a base64 data-URI string.

    Fix: JPEG cannot encode RGBA/paletted images, so convert to RGB first
    (previously this raised ``OSError`` for PNG uploads with an alpha channel).
    """
    buffer = io.BytesIO()
    image.convert("RGB").save(buffer, format="JPEG")
    img_base64 = base64.b64encode(buffer.getvalue()).decode()
    return f"data:image/jpeg;base64,{img_base64}"


# Shared TGI generation parameters (previously duplicated three times).
# NOTE(review): the empty string in `stop` looks like a stripped end-of-turn
# token (e.g. an HTML-like marker lost in a paste) — confirm against the
# deployed model's chat template. Kept as-is to preserve behavior.
_GENERATION_PARAMS = {
    "max_new_tokens": 150,       # Optimized for TGI
    "temperature": 0.4,          # Balanced for medical responses
    "do_sample": True,
    "return_full_text": False,
    "stop": ["", "<|im_end|>", "\n\n"],
    "repetition_penalty": 1.15,  # Adjusted for TGI
    "top_p": 0.9,                # Better for medical content
    "seed": 42,                  # Consistent responses for testing
}


def _parse_tgi_response(result):
    """Extract generated text from the different TGI response shapes."""
    if isinstance(result, list) and len(result) > 0:
        return result[0].get("generated_text", "Error en el formato de respuesta")
    if "generated_text" in result:
        return result["generated_text"]
    return f"Formato de respuesta inesperado: {str(result)[:200]}"


def query_medgemma(message, history, image=None):
    """Query the MedGemma endpoint with the proper multimodal format.

    ``history`` is accepted for interface compatibility but is not used.
    Returns the generated text, or a Spanish error message on failure.
    """
    # Your endpoint URL
    endpoint_url = "https://t911ok4t5x994zcu.us-east-1.aws.endpoints.huggingface.cloud"

    # Headers with your HF token
    headers = {
        "Authorization": f"Bearer {os.getenv('HUGGINGFACE_TOKEN')}",
        "Content-Type": "application/json",
    }

    # Prepare payload based on whether we have an image or not
    if image is not None:
        # Multimodal format: send image as base64 in the content
        image_base64 = image_to_base64(image)

        # TGI multimodal format for MedGemma
        payload = {
            "inputs": {"text": message, "image": image_base64},
            "parameters": dict(_GENERATION_PARAMS),
        }
        # Alternative format if the above doesn't work
        payload_alt = {
            "inputs": message,
            "image": image_base64,
            "parameters": dict(_GENERATION_PARAMS),
        }
    else:
        # Text-only format
        payload = {"inputs": message, "parameters": dict(_GENERATION_PARAMS)}
        payload_alt = None

    try:
        # Try primary format
        response = requests.post(endpoint_url, headers=headers, json=payload, timeout=30)

        if response.status_code == 200:
            return _parse_tgi_response(response.json())
        elif response.status_code == 422 and payload_alt is not None:
            # Try alternative format for multimodal
            response = requests.post(endpoint_url, headers=headers, json=payload_alt, timeout=30)
            if response.status_code == 200:
                return _parse_tgi_response(response.json())
            return f"Error 422 en ambos formatos. Detalles: {response.text[:300]}"
        elif response.status_code == 503:
            return "El modelo está escalado a cero. Intenta de nuevo en unos segundos mientras se activa."
        elif response.status_code == 422:
            return f"Error de formato en la petición. Detalles: {response.text[:300]}"
        else:
            return f"Error del endpoint: {response.status_code}. Detalles: {response.text[:200]}"

    except requests.exceptions.Timeout:
        return "Timeout: El modelo está procesando, intenta de nuevo en unos segundos."
    except Exception as e:  # network-boundary catch-all: surfaced as a chat message
        return f"Error de conexión: {str(e)}"


def medical_chat(message, history, uploaded_image):
    """Handle medical chat with context from pneumonia detection.

    Wraps the user's question in a Spanish system prompt (with or without
    radiograph context), queries MedGemma, filters obviously non-medical
    output, and appends the exchange to the chat history.
    """
    # Always add medical context to keep the model focused
    if uploaded_image is not None:
        context_message = f"""Eres MedGemma, un asistente médico especializado en radiología.

Analiza esta radiografía de tórax y responde la siguiente pregunta médica de forma precisa y profesional:

Pregunta: {message}

Contexto: Esta es una radiografía de tórax para detectar neumonía. Enfócate en proporcionar información médica relevante, síntomas, diagnósticos o recomendaciones. Siempre menciona que se debe consultar a un profesional médico."""
    else:
        context_message = f"""Eres MedGemma, un asistente médico especializado. Responde la siguiente pregunta médica de forma precisa y profesional:

Pregunta: {message}

Instrucciones: Proporciona información médica precisa y relevante. Si no es una pregunta médica, redirige hacia temas de salud relacionados. Siempre menciona que se debe consultar a un profesional médico para diagnósticos definitivos."""

    response = query_medgemma(context_message, history, uploaded_image)

    # Clean the response if it contains non-medical content
    if any(keyword in response.lower() for keyword in ["requests", "python", "código", "api", "import", "status code"]):
        response = "Como asistente médico, me especializo en temas de salud y medicina. ¿Podrías hacer una pregunta relacionada con medicina, síntomas, diagnósticos o radiografías? Estoy aquí para ayudarte con consultas médicas."

    # Add the exchange to history
    history.append([message, response])
    return history, ""


# Create the main pneumonia detection interface
# NOTE: kept for compatibility although the Blocks `demo` below is what launches.
pneumonia_interface = gr.Interface(
    fn=predict_pneumonia,
    inputs=gr.Image(type="pil", label="Upload Chest X-ray"),
    outputs=[
        gr.Textbox(label="Probability of Pneumonia"),
        gr.Label(label="Prediction"),
        gr.Textbox(label="Confidence"),
    ],
    title="🫁 Pneumonia Detection from Chest X-rays",
    description="Upload a chest X-ray to detect signs of pneumonia using deep learning.",
    flagging_mode="never",
)

# Create the MedGemma chatbot interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🫁 RADOX - Sistema Inteligente de Detección de Neumonía")
    gr.Markdown("### Análisis de Radiografías + Consulta Médica con IA")

    with gr.Row():
        with gr.Column(scale=1):
            # Pneumonia Detection Section
            gr.Markdown("## 🔍 Detección de Neumonía")
            input_image = gr.Image(type="pil", label="Subir Radiografía de Tórax")
            analyze_btn = gr.Button("🔬 Analizar Radiografía", variant="primary")

            with gr.Row():
                prob_output = gr.Textbox(label="Probabilidad de Neumonía")
                pred_output = gr.Label(label="Diagnóstico")
                conf_output = gr.Textbox(label="Confianza")

    # Medical Chatbot Section
    gr.Markdown("## 🤖 Consulta Médica con MedGemma")
    gr.Markdown("*Haz preguntas sobre la radiografía o consultas médicas generales*")

    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                label="Chat Médico",
                height=400,
                show_label=True,
            )
            with gr.Row():
                msg_input = gr.Textbox(
                    label="Tu pregunta",
                    placeholder="Ej: ¿Qué significan estos resultados? ¿Cuáles son los síntomas de neumonía?",
                    scale=4,
                )
                send_btn = gr.Button("Enviar", variant="primary", scale=1)
        with gr.Column(scale=1):
            chat_image = gr.Image(
                type="pil",
                label="Imagen para el chat (opcional)",
                height=300,
            )
            gr.Markdown("💡 **Tip:** Puedes subir la misma radiografía aquí para hacer preguntas específicas sobre ella.")

    # Event handlers
    analyze_btn.click(
        fn=predict_pneumonia,
        inputs=[input_image],
        outputs=[prob_output, pred_output, conf_output],
    )
    send_btn.click(
        fn=medical_chat,
        inputs=[msg_input, chatbot, chat_image],
        outputs=[chatbot, msg_input],
    )
    msg_input.submit(
        fn=medical_chat,
        inputs=[msg_input, chatbot, chat_image],
        outputs=[chatbot, msg_input],
    )

    # Footer
    gr.Markdown("""
---
⚠️ **Aviso Médico Importante**: Esta herramienta es solo para fines educativos y de apoyo diagnóstico.
Siempre consulte con un profesional médico cualificado para obtener diagnósticos y tratamientos precisos.
""")

# Guard so importing this module (e.g. for testing) does not start the server.
if __name__ == "__main__":
    demo.launch()