Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| POC RMMM - Automatic Medical Report Generation with Ground Truth Comparison & Evaluation Metrics | |
| This application provides a Gradio interface for generating medical reports from X-ray images | |
| using the RMMM PyTorch model, with automatic evaluation metrics (BLEU-4, ROUGE-L) | |
| to compare against ground truth reports. | |
| """ | |
| # Standard library imports | |
| import asyncio | |
| import hashlib | |
| import json | |
| import os | |
| import pickle | |
| import random | |
| import re | |
| import sys | |
| import time | |
| import traceback | |
| import warnings | |
| from typing import Dict, List, Union | |
| # Third-party imports | |
| import gradio as gr | |
| import nltk | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from rouge import Rouge | |
| from sacrebleu import BLEU | |
| from transformers import GPT2Tokenizer | |
# Configuration and warnings: silence noisy library warnings for the whole process.
warnings.filterwarnings(action="ignore", message=".*trust_remote_code.*")
warnings.filterwarnings(action="ignore", category=FutureWarning)
warnings.filterwarnings(action="ignore", category=UserWarning)

# Environment variables, applied in the order listed.
os.environ.update({
    "HF_DATASETS_OFFLINE": "1",
    "TRANSFORMERS_OFFLINE": "0",  # Allow limited online access for core models
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "2",
    "CUDA_LAUNCH_BLOCKING": "1",
})

# Ensure the NLTK "punkt" data is present; fetch it quietly when it is missing.
try:
    nltk.data.find("tokenizers/punkt")
except LookupError:
    nltk.download("punkt", quiet=True)

# Global constants
DEVICE = torch.device("cpu")  # Force CPU for compatibility
print(f"🖥️ Using device: {DEVICE}")
AUTH_TOKEN = os.environ.get("auth_token")
MAX_LENGTH, NUM_BEAMS = 100, 4
GEN_KWARGS = dict(max_length=MAX_LENGTH, num_beams=NUM_BEAMS)
# Load MIMIC dataset from JSON file
def load_mimic_data(json_path: str = "./data/sample_mimic_test.json") -> Dict[str, str]:
    """Load ground-truth reports from a MIMIC sample JSON file.

    The file is expected to hold a JSON object with a ``sample_data`` list whose
    items carry an ``id`` and a ``report`` field.

    Args:
        json_path: Path to the JSON file. Defaults to the bundled sample file.

    Returns:
        Dict[str, str]: Mapping from image ID (as a string) to report text.
        Items without an ``id`` or that are not objects are skipped; a missing,
        null or non-string ``report`` becomes ``'No report available.'``.
        An empty dict is returned (and a message printed) when the file is
        missing, unreadable, or not shaped as expected.
    """
    if not os.path.exists(json_path):
        print(f"Warning: {json_path} not found. Using empty dataset.")
        return {}
    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers JSON and Unicode decode errors
        print(f"Error loading MIMIC data: {e}")
        return {}

    items = data.get('sample_data', []) if isinstance(data, dict) else None
    if not isinstance(items, list):
        print(f"Error loading MIMIC data: expected an object with a 'sample_data' list in {json_path}")
        return {}

    ground_truth_reports = {}
    for item in items:
        if not isinstance(item, dict):
            continue
        image_id = item.get('id')
        if not image_id:
            continue
        report = item.get('report')
        # A null/non-string report would break the .strip() done on lookup.
        if not isinstance(report, str):
            report = 'No report available.'
        # Lookups use IDs parsed from file names, which are strings.
        ground_truth_reports[str(image_id)] = report
    print(f"Loaded {len(ground_truth_reports)} ground truth reports from MIMIC dataset")
    return ground_truth_reports


# Load ground truth reports from JSON file
GROUND_TRUTH_REPORTS = load_mimic_data()
# Load model globally at startup
def load_rmmm_model(model_path: str = "./rmmm/rmmm_mimic_cut.pt"):
    """Load the TorchScript RMMM model on CPU and smoke-test it.

    The archive is always mapped to CPU. A (1, 3, 224, 224) random tensor is
    run through the model to confirm that it executes. If that fails, all
    buffers are moved to CPU and the test is retried once.

    Args:
        model_path: Path to the ``torch.jit`` archive.

    Returns:
        The loaded ScriptModule in eval mode, or ``None`` if the file is
        missing, cannot be loaded, or still fails the smoke test after the
        retry. Errors are printed, never raised.
    """
    cpu = torch.device('cpu')
    try:
        if not os.path.exists(model_path):
            print(f"❌ Model not found: {model_path}")
            return None
        print(f"🤖 Loading RMMM model from: {model_path}")
        print(f"🖥️ Target device: {cpu}")

        print("Loading model with CPU mapping...")
        scripted_model = torch.jit.load(model_path, map_location=cpu)
        scripted_model.eval()

        # map_location should already place everything on CPU; the explicit
        # moves below are a defensive double-check.
        print("Moving all parameters to CPU...")
        scripted_model = scripted_model.cpu()
        print("Verifying model device placement...")
        for param in scripted_model.parameters():
            if param.device != cpu:
                print(f"⚠️ Found parameter on {param.device}, moving to CPU")
                param.data = param.data.cpu()

        print("Testing model with dummy input...")
        dummy_input = torch.randn(1, 3, 224, 224, device=cpu)
        with torch.no_grad():
            try:
                scripted_model(dummy_input)
                print("✅ Model CPU compatibility test passed")
            except Exception as test_error:
                print(f"⚠️ Model compatibility test failed: {test_error}")
                # Try to recover by ensuring all buffers are also on CPU.
                for buffer in scripted_model.buffers():
                    if buffer.device != cpu:
                        buffer.data = buffer.data.cpu()
                print("Retrying after moving buffers to CPU...")
                # If this raises again, the outer handler returns None.
                scripted_model(dummy_input)
                print("✅ Model CPU compatibility test passed after buffer fix")

        # Only report success once the smoke test has actually passed.
        print("✅ RMMM model loaded successfully on CPU")
        return scripted_model
    except Exception as e:
        print(f"❌ Error loading RMMM model: {e}")
        traceback.print_exc()
        return None
# Load tokenizer globally at startup
def load_mimic_tokenizer(cache_file: str = "./rmmm/tokenizer_cache/tokenizer.pkl"):
    """Load the custom MIMIC ``idx2token`` vocabulary from a pickled cache.

    Args:
        cache_file: Path to the pickle file. It must contain a mapping with an
            ``'idx2token'`` entry.

    Returns:
        The stored ``idx2token`` object, or ``None`` if the file is missing,
        unreadable or lacks that entry. Callers then fall back to GPT-2. Errors
        are printed, not raised.
    """
    try:
        # SECURITY NOTE: pickle.load can execute arbitrary code embedded in the
        # file. This is only acceptable for a trusted, locally shipped cache
        # (presumably the case here — confirm). Never point it at user-supplied
        # or downloaded data.
        with open(cache_file, 'rb') as f:
            tokenizer_data = pickle.load(f)
        idx2token = tokenizer_data['idx2token']
        print(f"✅ Custom MIMIC tokenizer loaded with vocab size: {len(idx2token)}")
        return idx2token
    except Exception as e:
        # Unpickling can fail in many ways (I/O, EOF, missing classes, bad
        # schema); any failure just disables the custom tokenizer.
        print(f"⚠️ Failed to load custom tokenizer: {e}")
        return None
# Global model and tokenizer instances, loaded once when this module is imported.
# Either one is None if its loader failed.
print("🚀 Initializing RMMM application...")
RMMM_MODEL, MIMIC_TOKENIZER = load_rmmm_model(), load_mimic_tokenizer()
def get_available_image_paths(json_path: str = "./images/sample_mimic_test.json",
                              images_dir: str = "./images") -> List[str]:
    """List the image files to show in the gallery.

    If ``json_path`` exists, each ``sample_data`` item's ``id`` is looked up in
    ``images_dir`` as ``<id>.jpg``, ``<id>.jpeg`` or ``<id>.png`` (first match
    wins). JSON order is kept, and IDs without an image file are skipped.
    Otherwise every ``.jpg``/``.jpeg``/``.png`` file (any case) in
    ``images_dir`` is returned, sorted by name so the gallery order is stable.

    NOTE(review): this default JSON path is under ./images, while
    load_mimic_data() reads ./data/sample_mimic_test.json. Confirm which
    location is intended.

    Args:
        json_path: Path to the MIMIC sample JSON.
        images_dir: Directory that holds the image files.

    Returns:
        List[str]: Image file paths (``images_dir`` joined with the file name).
        Empty if nothing is found or the JSON cannot be read.
    """
    extensions = ('.jpg', '.jpeg', '.png')

    if not os.path.exists(json_path):
        # Fallback: scan the images directory directly.
        if not os.path.isdir(images_dir):
            return []
        return [os.path.join(images_dir, name)
                for name in sorted(os.listdir(images_dir))
                if name.lower().endswith(extensions)]

    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        print(f"Error loading image paths: {e}")
        return []

    items = data.get('sample_data', []) if isinstance(data, dict) else []
    if not isinstance(items, list):
        items = []

    image_paths = []
    for item in items:
        image_id = item.get('id') if isinstance(item, dict) else None
        if not image_id:
            continue
        for ext in extensions:
            candidate = os.path.join(images_dir, f"{image_id}{ext}")
            if os.path.exists(candidate):
                image_paths.append(candidate)
                break
    print(f"Found {len(image_paths)} available images from MIMIC dataset")
    return image_paths
def preprocess_text_for_metrics(text: str) -> str:
    """Clean report text before BLEU/ROUGE scoring.

    Steps:
    1. Remove ``**bold**`` markers and a few UI emojis.
    2. If a section header (``RADIOLOGIST REPORT:``, ``IMPRESSION:`` or
       ``FINDINGS:``) is present, keep the content after the first header.
       This includes text on the header line itself; the header labels are
       dropped. Lines that start with ``**`` and contain ``:`` are skipped.
    3. If step 2 collected nothing, keep every non-empty line not starting
       with ``**``.
    4. Collapse whitespace and drop characters other than word characters,
       whitespace and ``. , ; : - ( )``.

    Args:
        text (str): Raw text to preprocess

    Returns:
        str: Cleaned single-line text, or ``""`` for empty/blank input.
    """
    if not text or text.strip() == "":
        return ""

    headers = ('RADIOLOGIST REPORT:', 'IMPRESSION:', 'FINDINGS:')

    # Remove markdown formatting and special characters
    text = re.sub(r'\*\*(.*?)\*\*', r'\1', text)  # Remove **bold**
    text = re.sub(r'[📊📋🤖🩻]', '', text)  # Remove emojis

    lines = [line.strip() for line in text.split('\n')]
    report_lines = []
    in_report = False
    for line in lines:
        header_positions = [line.find(header) for header in headers if header in line]
        if header_positions:
            in_report = True
            # Keep content that follows the header on the same line (e.g.
            # "FINDINGS: Lungs are clear."), minus the header labels.
            remainder = line[min(header_positions):]
            for header in headers:
                remainder = remainder.replace(header, ' ')
            remainder = remainder.strip()
            if remainder:
                report_lines.append(remainder)
        elif line.startswith('**') and ':' in line:
            # Metadata line such as "**Study ID:** ...". Bold pairs on one line
            # were already removed above, so only an unclosed "**" reaches here.
            continue
        elif in_report and line:
            report_lines.append(line)

    # If no structured report found, use the whole cleaned text
    if not report_lines:
        report_lines = [line for line in lines if line and not line.startswith('**')]

    result = ' '.join(report_lines).strip()
    result = re.sub(r'\s+', ' ', result)  # Multiple spaces to single
    result = re.sub(r'[^\w\s\.\,\;\:\-\(\)]', '', result)  # Keep only basic punctuation
    return result
def calculate_evaluation_metrics(prediction: str, ground_truth: str) -> Dict[str, Union[float, str, None]]:
    """Score a generated report against its reference with BLEU-4 and ROUGE-L.

    Both texts go through ``preprocess_text_for_metrics`` and are lowercased
    before scoring. If one metric raises, the error is printed, that metric is
    reported as 0.0, and the other metric is still computed.

    Args:
        prediction (str): Generated prediction text
        ground_truth (str): Reference ground truth text

    Returns:
        Dict[str, Union[float, str, None]]: ``'bleu4_score'`` (sacrebleu
        sentence BLEU rescaled from 0-100 to 0-1) and ``'rougeL_f'`` (ROUGE-L
        F-measure), each rounded to 4 decimals, plus ``'error'``. ``'error'``
        is ``None`` on success; otherwise it holds a message and both scores
        are 0.0.
    """
    def _failure(message):
        return {'bleu4_score': 0.0, 'rougeL_f': 0.0, 'error': message}

    if not (prediction and ground_truth):
        return _failure('Empty prediction or ground truth')

    try:
        hypothesis = preprocess_text_for_metrics(prediction).lower()
        reference = preprocess_text_for_metrics(ground_truth).lower()
        if not (hypothesis and reference):
            return _failure('Empty text after preprocessing')

        try:
            # sentence_score(hypothesis, [references]) scores on a 0-100 scale.
            bleu4 = BLEU().sentence_score(hypothesis, [reference]).score / 100.0
        except Exception as err:
            print(f"BLEU-4 calculation error: {err}")
            bleu4 = 0.0

        try:
            rouge_l = Rouge().get_scores(hypothesis, reference)[0]['rouge-l']['f']
        except Exception as err:
            print(f"ROUGE-L calculation error: {err}")
            rouge_l = 0.0

        return {
            'bleu4_score': round(bleu4, 4),
            'rougeL_f': round(rouge_l, 4),
            'error': None,
        }
    except Exception as e:
        return _failure(f'Metric calculation error: {str(e)}')
def format_metrics_display(metrics: Dict[str, Union[float, str, None]]) -> str:
    """Format evaluation metrics as an HTML card for the Gradio UI.

    Args:
        metrics (Dict[str, Union[float, str, None]]): Output of
            ``calculate_evaluation_metrics``: ``'bleu4_score'`` and
            ``'rougeL_f'`` (a missing or None value is shown as 0.0) and an
            optional ``'error'`` message.

    Returns:
        str: HTML showing either the HTML-escaped error message, or both scores
        with a badge each: > 0.3 "Bom", > 0.1 "Regular", otherwise "Baixo".
    """
    import html  # only needed here, to escape the error message

    if metrics.get('error'):
        # The message can contain raw exception text; escape it before
        # interpolating it into HTML.
        return f"""
<div class="info-card">
    <h2>⚠️ Erro nas Métricas</h2>
    <p style="color: #e53e3e;">{html.escape(str(metrics['error']))}</p>
</div>
"""

    bleu_score = float(metrics.get('bleu4_score') or 0.0)
    rouge_score = float(metrics.get('rougeL_f') or 0.0)

    badge_style = ("color: white; padding: 6px 14px; border-radius: 20px; "
                   "font-size: 0.85rem; font-weight: 600;")

    def performance_badge(score):
        """Return a colored pill labelling the score as good/fair/low."""
        if score > 0.3:
            background, label = "#22543d", "🟢 Bom"
        elif score > 0.1:
            background, label = "#b7791f", "🟡 Regular"
        else:
            background, label = "#c53030", "🔴 Baixo"
        return f'<span style="background: {background}; {badge_style}">{label}</span>'

    # The footer heading is no longer wrapped in a <p>: a block element inside
    # <p> is invalid HTML and leaves a stray empty paragraph after parsing.
    return f"""
<div class="metrics-card" style="background: linear-gradient(135deg, #f7fafc 0%, #edf2f7 100%); border: 2px solid #4299e1; border-radius: 12px; padding: 1.2rem;">
    <h2 style="color: #1a365d; margin-bottom: 1rem; font-size: 1.3rem; font-weight: 700;">📊 Métricas de Avaliação</h2>
    <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 1rem; margin-bottom: 0.8rem;">
        <div style="background: white; padding: 1rem; border-radius: 8px; box-shadow: 0 2px 8px rgba(0,0,0,0.08); border: 1px solid #e2e8f0;">
            <h3 style="color: #1e3a8a; margin: 0 0 0.3rem 0; font-size: 1rem; font-weight: 600;">BLEU-4 Score</h3>
            <div style="font-size: 1.8rem; font-weight: 700; color: #1e40af; margin-bottom: 0.6rem;">{bleu_score:.4f}</div>
            {performance_badge(bleu_score)}
        </div>
        <div style="background: white; padding: 1rem; border-radius: 8px; box-shadow: 0 2px 8px rgba(0,0,0,0.08); border: 1px solid #e2e8f0;">
            <h3 style="color: #6b21a8; margin: 0 0 0.3rem 0; font-size: 1rem; font-weight: 600;">ROUGE-L F1</h3>
            <div style="font-size: 1.8rem; font-weight: 700; color: #7c3aed; margin-bottom: 0.6rem;">{rouge_score:.4f}</div>
            {performance_badge(rouge_score)}
        </div>
    </div>
    <div style="background: white; padding: 0.8rem; border-radius: 8px; margin-top: 0.8rem; border: 1px solid #e2e8f0;">
        <h3 style="color: #6b21a8; margin: 0 0 0.3rem 0; font-size: 1rem; font-weight: 600;">💡 Pontuações mais altas indicam maior similaridade com o relatório ground truth</h3>
    </div>
</div>
"""
# ImageNet channel statistics used for input normalization (matching training preprocessing).
_IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
_IMAGENET_STD = np.array([0.229, 0.224, 0.225])
# Spatial size fed to the model (same as the dummy input in load_rmmm_model).
_MODEL_INPUT_SIZE = (224, 224)
# GPT-2 fallback tokenizer, loaded on first use and then reused.
_GPT2_TOKENIZER = None


def _load_rgb_image(image_input):
    """Convert a file path, numpy array or PIL image into an RGB PIL image.

    Returns ``None`` when ``image_input`` is ``None`` (no image selected).

    Raises:
        TypeError: For any other unsupported input type.
    """
    if image_input is None:
        return None
    if isinstance(image_input, str):
        print(f"📁 Loading image from path: {image_input}")
        # convert() returns a fully loaded copy, so the file can be closed
        # immediately by the context manager.
        with Image.open(image_input) as opened:
            return opened.convert("RGB")
    if isinstance(image_input, np.ndarray):
        print("🔢 Converting numpy array to PIL Image")
        array = image_input
        if array.dtype != np.uint8:
            # Assumes floats in [0, 1] (as before); clip so out-of-range values
            # saturate instead of wrapping around when cast to uint8.
            array = np.clip(array * 255, 0, 255).astype(np.uint8)
        image = Image.fromarray(array)
    elif hasattr(image_input, 'mode'):
        print("🖼️ Already a PIL Image")
        image = image_input
    else:
        raise TypeError(f"Unsupported image input type: {type(image_input)}")
    return image if image.mode == "RGB" else image.convert("RGB")


def _image_to_tensor(image):
    """Resize an RGB image to 224x224, ImageNet-normalize it, and return a (1, 3, 224, 224) float32 CPU tensor."""
    image = image.resize(_MODEL_INPUT_SIZE)
    pixels = np.array(image).astype(np.float32)

    # Debug statistics and a short hash, to confirm different inputs really differ.
    print(f"Image stats - Mean: {pixels.mean():.4f}, Std: {pixels.std():.4f}")
    print(f"Image range - Min: {pixels.min():.4f}, Max: {pixels.max():.4f}")
    print(f"Image hash (first 8 chars): {hashlib.md5(pixels.tobytes()).hexdigest()[:8]}")

    # Scale to [0, 1], then normalize per channel (broadcast over the last, channel, axis).
    normalized = (pixels / 255.0 - _IMAGENET_MEAN) / _IMAGENET_STD
    chw = np.transpose(normalized, (2, 0, 1))  # HWC -> CHW
    tensor = torch.tensor(chw, dtype=torch.float32, device='cpu').unsqueeze(0)

    print(f"Input tensor shape: {tensor.shape}")
    print(f"Input tensor stats - Mean: {tensor.mean():.4f}, Std: {tensor.std():.4f}")
    print(f"Input tensor range - Min: {tensor.min():.4f}, Max: {tensor.max():.4f}")
    return tensor


def _outputs_to_token_ids(outputs):
    """Reduce raw model outputs to a 1-D int32 numpy array of token IDs.

    The output format is guessed from shape and dtype:
    - 3-D: argmax over the last axis.
    - 2-D integer: used directly as IDs.
    - 2-D with max > 1000: cast to IDs.
    - Other 2-D: argmax over the last axis.
    - 1-D (or any other rank): used as-is.
    A leading batch axis is then dropped.
    """
    if outputs.dim() == 3:
        print("Processing 3D output (batch, seq_len, vocab_size)")
        token_ids = torch.argmax(outputs, dim=-1)
    elif outputs.dim() == 2:
        if outputs.dtype in (torch.long, torch.int32, torch.int64):
            print("Processing 2D output as token IDs (integer dtype)")
            token_ids = outputs
        elif outputs.max() > 1000:  # Likely token IDs already
            print("Processing 2D output as token IDs (high values)")
            token_ids = outputs.long()
        else:
            print("Processing 2D output as logits, taking argmax")
            token_ids = torch.argmax(outputs, dim=-1)
    elif outputs.dim() == 1:
        print("Processing 1D output as token IDs")
        token_ids = outputs
    else:
        print(f"Unexpected output shape: {outputs.shape}")
        token_ids = outputs

    if token_ids.dim() == 2:  # drop the batch dimension
        token_ids = token_ids[0]
    return token_ids.cpu().numpy().astype(np.int32)


def _get_gpt2_tokenizer():
    """Return the GPT-2 fallback tokenizer, loading it only on first use."""
    global _GPT2_TOKENIZER
    if _GPT2_TOKENIZER is None:
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
        tokenizer.pad_token = tokenizer.eos_token
        _GPT2_TOKENIZER = tokenizer
    return _GPT2_TOKENIZER


def _decode_token_ids(token_ids):
    """Decode token IDs with the MIMIC vocabulary, or with GPT-2 if it is unavailable."""
    if MIMIC_TOKENIZER is not None:
        # ID 0 ends the sequence; IDs missing from the vocabulary are skipped.
        # NOTE(review): `in` assumes MIMIC_TOKENIZER is a mapping keyed by ID
        # (for a list it would test values instead) — confirm.
        tokens = []
        for token_id in token_ids:
            if token_id == 0:
                break
            if token_id in MIMIC_TOKENIZER:
                tokens.append(MIMIC_TOKENIZER[token_id])
        print("✅ Used custom MIMIC tokenizer")
        return ' '.join(tokens).strip()

    print("⚠️ Using GPT-2 fallback tokenizer")
    tokenizer = _get_gpt2_tokenizer()
    # Keep IDs within GPT-2's vocabulary range.
    clipped = np.clip(token_ids, 0, tokenizer.vocab_size - 1)
    return tokenizer.decode(clipped, skip_special_tokens=True).strip()


def inference_torch_model_fast(image_input):
    """Generate a report for one image with the pre-loaded RMMM model.

    Args:
        image_input: File path, numpy array or PIL image; ``None`` when no
            image was selected.

    Returns:
        str: The decoded report. If decoding yields fewer than 10 characters, a
        generic placeholder report is returned instead. On failure, a message
        starting with "❌" is returned; errors are printed, never raised.
    """
    try:
        if RMMM_MODEL is None:
            return "❌ Erro: Modelo RMMM não foi carregado corretamente na inicialização."
        print("🔄 Running inference with pre-loaded RMMM model")
        print(f"🖼️ Image input type: {type(image_input)}")

        image = _load_rgb_image(image_input)
        if image is None:
            return "❌ Erro: Nenhuma imagem fornecida. Selecione uma imagem antes de gerar o relatório."
        print(f"✅ Image loaded successfully: {image.size}")

        image_tensor = _image_to_tensor(image)
        with torch.no_grad():
            outputs = RMMM_MODEL(image_tensor).cpu()

        print(f"Model output shape: {outputs.shape}")
        print(f"Model output dtype: {outputs.dtype}")
        if outputs.dtype in (torch.float32, torch.float64):
            print(f"Output stats - Mean: {outputs.mean():.4f}, Std: {outputs.std():.4f}")
        else:
            # For integer outputs (token IDs), show basic stats
            print(f"Output stats - Min: {outputs.min()}, Max: {outputs.max()}")
            print(f"Output unique values: {len(torch.unique(outputs))}")
        print(f"First few output values: {outputs.flatten()[:10]}")

        token_ids = _outputs_to_token_ids(outputs)
        print(f"Token IDs shape: {token_ids.shape}")
        print(f"Token IDs sample: {token_ids[:10]}")
        print(f"Token IDs unique count: {len(np.unique(token_ids))}")

        decoded_text = _decode_token_ids(token_ids)
        print(f"Decoded text length: {len(decoded_text)}")
        print(f"Decoded text preview: {decoded_text[:100]}...")

        # Very short output is treated as a failed decode; show a placeholder.
        if len(decoded_text) < 10:
            decoded_text = (
                f"Medical Report - RMMM Model:\n\n"
                f"Chest X-ray analysis completed using PyTorch model. "
                f"The radiological examination has been processed successfully.\n\n"
                f"Model: rmmm_mimic_cut.pt\n"
                f"Status: Processing completed"
            )
        return decoded_text
    except Exception as e:
        error_msg = f"❌ Erro ao processar com o modelo RMMM: {str(e)}"
        print(error_msg)
        traceback.print_exc()
        return error_msg
def get_ground_truth_from_filename(selected_image_filename, reports=None):
    """Look up the ground-truth report for the selected gallery image.

    The image ID is the file's base name. A trailing ``.jpg``, ``.jpeg`` or
    ``.png`` extension (any case) is removed; other extensions are kept.

    Args:
        selected_image_filename: Path or file name of the selected image; may
            be empty or ``None``.
        reports: Optional mapping from image ID to report text. Defaults to the
            module-level ``GROUND_TRUTH_REPORTS``.

    Returns:
        str: The stripped report text, or a human-readable "not available"
        message when there is no filename or no usable report for the ID.
    """
    if not selected_image_filename:
        return "Ground truth not available."
    if reports is None:
        reports = GROUND_TRUTH_REPORTS

    filename = os.path.basename(selected_image_filename)
    stem, ext = os.path.splitext(filename)
    # Only strip a real trailing image extension, case-insensitively; the
    # gallery directory scan also accepts upper-case extensions like ".JPG".
    image_id = stem if ext.lower() in ('.jpg', '.jpeg', '.png') else filename
    print(f"Debug - selected_image_filename: {selected_image_filename}")
    print(f"Debug - extracted image_id: {image_id}")

    report = reports.get(image_id) if image_id else None
    if isinstance(report, str):
        # Return only the clean report text without metadata
        return report.strip()
    return (
        f"Ground truth not available for this image (ID: {image_id}). "
        f"Upload one of the example images to see ground truth comparison."
    )
def inference_image_pipe_with_state(image_input, selected_image_filename):
    """Run the UI pipeline: ground-truth lookup, report generation, scoring.

    Args:
        image_input: Image passed to ``inference_torch_model_fast``.
        selected_image_filename: Remembered file name of the selected gallery
            image, used to look up the ground-truth report.

    Returns:
        tuple: ``(prediction, ground_truth, metrics_html)``. ``metrics_html``
        is ``""`` when no ground truth is available or generation failed.
    """
    ground_truth = get_ground_truth_from_filename(selected_image_filename)
    prediction = inference_torch_model_fast(image_input)

    # inference_torch_model_fast reports failures by returning a message that
    # starts with "❌" instead of raising. Scoring that message against a real
    # report would display meaningless metrics, so skip it.
    generation_failed = isinstance(prediction, str) and prediction.startswith("❌")
    has_ground_truth = bool(ground_truth) and "Ground truth not available" not in ground_truth

    metrics_display = ""
    if prediction and has_ground_truth and not generation_failed:
        metrics = calculate_evaluation_metrics(prediction, ground_truth)
        metrics_display = format_metrics_display(metrics)
    return prediction, ground_truth, metrics_display
| with gr.Blocks( | |
| title="XRaySwinGen / RMMM - AI Medical Report Generator", | |
| theme=gr.themes.Soft( | |
| primary_hue="blue", | |
| secondary_hue="gray", | |
| neutral_hue="slate", | |
| font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"] | |
| ), | |
| css=""" | |
| /* Import Google Fonts */ | |
| @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap'); | |
| /* Global styles */ | |
| .gradio-container { | |
| max-width: 1400px !important; | |
| margin: 0 auto !important; | |
| font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif !important; | |
| padding: 1rem !important; | |
| } | |
| /* Header styling - mais compacto */ | |
| .main-header { | |
| background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important; | |
| color: white !important; | |
| padding: 1.5rem !important; | |
| border-radius: 15px !important; | |
| margin-bottom: 1.5rem !important; | |
| box-shadow: 0 8px 25px rgba(0,0,0,0.1) !important; | |
| } | |
| .main-header h1 { | |
| font-size: 2rem !important; | |
| font-weight: 700 !important; | |
| margin: 0 !important; | |
| text-shadow: 2px 2px 4px rgba(0,0,0,0.3) !important; | |
| } | |
| .main-header p { | |
| font-size: 1rem !important; | |
| margin-top: 0.3rem !important; | |
| opacity: 0.9 !important; | |
| } | |
| /* Card-like sections - mais compactos */ | |
| .info-card { | |
| background: white !important; | |
| border-radius: 12px !important; | |
| padding: 1.2rem !important; | |
| box-shadow: 0 3px 15px rgba(0,0,0,0.06) !important; | |
| border: 1px solid #e2e8f0 !important; | |
| margin-bottom: 0.8rem !important; | |
| } | |
| .info-card h2 { | |
| color: #1a202c !important; | |
| font-weight: 600 !important; | |
| margin-bottom: 0.8rem !important; | |
| font-size: 1.2rem !important; | |
| } | |
| .info-card h3 { | |
| color: #2d3748 !important; | |
| font-weight: 500 !important; | |
| margin-bottom: 0.6rem !important; | |
| font-size: 1rem !important; | |
| } | |
| .info-card p { | |
| color: #2d3748 !important; | |
| line-height: 1.5 !important; | |
| margin-bottom: 0.5rem !important; | |
| } | |
| .info-card li { | |
| color: #2d3748 !important; | |
| margin-bottom: 0.3rem !important; | |
| } | |
| /* Gallery improvements - mais compacto */ | |
| .gallery-container { | |
| max-height: 450px !important; | |
| overflow-y: auto !important; | |
| border-radius: 12px !important; | |
| box-shadow: 0 3px 15px rgba(0,0,0,0.06) !important; | |
| border: 1px solid #e2e8f0 !important; | |
| } | |
| .gradio-gallery { | |
| max-height: 450px !important; | |
| border-radius: 12px !important; | |
| } | |
| div[data-testid="gallery"] { | |
| max-height: 450px !important; | |
| border-radius: 12px !important; | |
| } | |
| /* Image input styling */ | |
| .gradio-image { | |
| border-radius: 15px !important; | |
| border: 2px dashed #cbd5e0 !important; | |
| transition: all 0.3s ease !important; | |
| background: white !important; | |
| } | |
| .gradio-image label { | |
| color: #2d3748 !important; | |
| font-weight: 600 !important; | |
| } | |
| .gradio-image:hover { | |
| border-color: #667eea !important; | |
| box-shadow: 0 4px 20px rgba(102, 126, 234, 0.1) !important; | |
| } | |
| /* Button improvements */ | |
| .gradio-button { | |
| border-radius: 12px !important; | |
| font-weight: 500 !important; | |
| transition: all 0.3s ease !important; | |
| border: none !important; | |
| box-shadow: 0 2px 10px rgba(0,0,0,0.1) !important; | |
| } | |
| .gradio-button.primary { | |
| background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important; | |
| color: white !important; | |
| font-size: 1.1rem !important; | |
| padding: 0.8rem 2rem !important; | |
| } | |
| .gradio-button.primary:hover { | |
| transform: translateY(-2px) !important; | |
| box-shadow: 0 4px 20px rgba(102, 126, 234, 0.3) !important; | |
| } | |
| .gradio-button.secondary { | |
| background: #f7fafc !important; | |
| color: #4a5568 !important; | |
| border: 2px solid #e2e8f0 !important; | |
| } | |
| .gradio-button.secondary:hover { | |
| background: #edf2f7 !important; | |
| border-color: #cbd5e0 !important; | |
| } | |
| /* Textbox improvements */ | |
| .gradio-textbox { | |
| border-radius: 12px !important; | |
| border: 1px solid #e2e8f0 !important; | |
| box-shadow: 0 2px 10px rgba(0,0,0,0.05) !important; | |
| background: white !important; | |
| } | |
| .gradio-textbox textarea { | |
| font-family: 'Inter', sans-serif !important; | |
| line-height: 1.6 !important; | |
| font-size: 0.95rem !important; | |
| border-radius: 12px !important; | |
| border: none !important; | |
| padding: 1rem !important; | |
| color: #2d3748 !important; | |
| background: white !important; | |
| } | |
| .gradio-textbox label { | |
| color: #2d3748 !important; | |
| font-weight: 600 !important; | |
| } | |
| .gradio-textbox:focus-within { | |
| box-shadow: 0 0 0 3px rgba(102, 126, 234, 0.1) !important; | |
| border-color: #667eea !important; | |
| } | |
| /* Dropdown styling */ | |
| .gradio-dropdown { | |
| border-radius: 12px !important; | |
| border: 1px solid #e2e8f0 !important; | |
| box-shadow: 0 2px 10px rgba(0,0,0,0.05) !important; | |
| background: white !important; | |
| } | |
| .gradio-dropdown label { | |
| color: #2d3748 !important; | |
| font-weight: 600 !important; | |
| } | |
| .gradio-dropdown .wrap { | |
| background: white !important; | |
| } | |
| .gradio-dropdown input { | |
| color: #2d3748 !important; | |
| background: white !important; | |
| } | |
| /* Metrics section - mais compacto */ | |
| .metrics-card { | |
| background: linear-gradient(135deg, #f7fafc 0%, #edf2f7 100%) !important; | |
| border-radius: 12px !important; | |
| padding: 1.2rem !important; | |
| border: 1px solid #e2e8f0 !important; | |
| margin: 0.8rem 0 !important; | |
| } | |
| .metrics-card table { | |
| width: 100% !important; | |
| border-collapse: collapse !important; | |
| margin-top: 1rem !important; | |
| } | |
| .metrics-card th, .metrics-card td { | |
| padding: 0.8rem !important; | |
| text-align: left !important; | |
| border: 1px solid #e2e8f0 !important; | |
| } | |
| .metrics-card th { | |
| background: #667eea !important; | |
| color: white !important; | |
| font-weight: 600 !important; | |
| } | |
| /* Status indicators */ | |
| .status-good { | |
| color: #22543d !important; | |
| background: #c6f6d5 !important; | |
| padding: 4px 8px !important; | |
| border-radius: 6px !important; | |
| font-weight: 600 !important; | |
| } | |
| .status-fair { | |
| color: #744210 !important; | |
| background: #faf089 !important; | |
| padding: 4px 8px !important; | |
| border-radius: 6px !important; | |
| font-weight: 600 !important; | |
| } | |
| .status-poor { | |
| color: #742a2a !important; | |
| background: #fed7d7 !important; | |
| padding: 4px 8px !important; | |
| border-radius: 6px !important; | |
| font-weight: 600 !important; | |
| } | |
| /* Loading animations */ | |
| @keyframes pulse { | |
| 0%, 100% { opacity: 1; } | |
| 50% { opacity: 0.5; } | |
| } | |
| .loading { | |
| animation: pulse 2s infinite !important; | |
| } | |
| /* Responsive design */ | |
| @media (max-width: 768px) { | |
| .gradio-container { | |
| max-width: 100% !important; | |
| padding: 1rem !important; | |
| } | |
| .main-header h1 { | |
| font-size: 2rem !important; | |
| } | |
| .main-header p { | |
| font-size: 1rem !important; | |
| } | |
| } | |
| /* Scrollbar styling */ | |
| ::-webkit-scrollbar { | |
| width: 8px !important; | |
| } | |
| ::-webkit-scrollbar-track { | |
| background: #f1f1f1 !important; | |
| border-radius: 10px !important; | |
| } | |
| ::-webkit-scrollbar-thumb { | |
| background: #c1c1c1 !important; | |
| border-radius: 10px !important; | |
| } | |
| ::-webkit-scrollbar-thumb:hover { | |
| background: #a8a8a8 !important; | |
| } | |
| /* Tab styling if needed */ | |
| .gradio-tabs { | |
| border-radius: 15px !important; | |
| overflow: hidden !important; | |
| box-shadow: 0 4px 20px rgba(0,0,0,0.08) !important; | |
| } | |
| /* Footer styling - mais compacto */ | |
| .footer-info { | |
| background: #f8fafc !important; | |
| border-radius: 12px !important; | |
| padding: 1.5rem !important; | |
| margin-top: 2rem !important; | |
| border: 1px solid #e2e8f0 !important; | |
| text-align: center !important; | |
| } | |
| """ | |
| ) as demo: | |
| # Modern header with gradient background | |
| model_status = "✅ Modelo RMMM carregado e pronto" if RMMM_MODEL is not None else "❌ Erro ao carregar modelo RMMM" | |
| tokenizer_status = "✅ Tokenizer MIMIC carregado" if MIMIC_TOKENIZER is not None else "⚠️ Usando tokenizer GPT-2 como fallback" | |
| gr.HTML(f""" | |
| <div class="main-header"> | |
| <h1>🩻 XRaySwinGen / RMMM</h1> | |
| <p>AI-Powered Medical Report Generation with Real-time Evaluation Metrics</p> | |
| <div style="margin-top: 0.8rem; font-size: 0.9rem; opacity: 0.9;"> | |
| <div>{model_status}</div> | |
| <div>{tokenizer_status}</div> | |
| </div> | |
| </div> | |
| """) | |
| # Instructions and Metrics section with improved cards | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| gr.HTML(""" | |
| <div class="info-card"> | |
| <h2 style="color: #1a202c; font-weight: 700;">📖 Como Usar</h2> | |
| <div style="line-height: 1.8; color: #2d3748;"> | |
| <p><strong style="color: #1a202c;">1️⃣ Selecionar Imagem:</strong> Clique em qualquer raio-X na galeria para carregá-lo</p> | |
| <p><strong style="color: #1a202c;">2️⃣ Gerar Relatório:</strong> Clique em 'Gerar Relatório' para ver a análise</p> | |
| <p><strong style="color: #1a202c;">3️⃣ Avaliar Resultados:</strong> Revise as métricas de avaliação (BLEU-4, ROUGE-L)</p> | |
| </div> | |
| </div> | |
| """) | |
| with gr.Column(scale=1): | |
| metrics_display = gr.HTML( | |
| value=""" | |
| <div class="info-card"> | |
| <h2 style="color: #1a202c; font-weight: 700;">📊 Métricas de Avaliação</h2> | |
| <p style="color: #4a5568; font-style: italic;"> | |
| Selecione uma imagem e gere um relatório para ver as métricas de avaliação em tempo real | |
| </p> | |
| </div> | |
| """, | |
| label="Métricas de Avaliação" | |
| ) | |
# Main interface: example gallery on the left, image preview and action buttons on the right.
# Build the gallery from the MIMIC dataset images that are available locally.
example_images = get_available_image_paths()
with gr.Row(equal_height=True):
    # Left side - gallery of example X-rays
    with gr.Column(scale=1):
        gr.HTML("<div class='info-card'><h2>🩻 Galeria MIMIC-CXR</h2><p>Clique em qualquer imagem para carregá-la</p></div>")
        # Read-only picker (interactive=False). An interactive Gallery lets users
        # upload into / clear it, after which the index carried by select events
        # no longer matches positions in `example_images`, which the gallery
        # select handler indexes directly.
        gallery = gr.Gallery(
            value=example_images,
            columns=3,
            height=450,
            object_fit="contain",
            allow_preview=True,
            show_label=False,
            show_download_button=False,
            interactive=False,
            container=True
        )
    # Right side - image preview and controls
    with gr.Column(scale=1):
        gr.HTML("<div class='info-card'><h2>🎛️ Controles</h2></div>")
        # Preview of the currently selected (or uploaded) X-ray
        image_input = gr.Image(
            height=220,
            width=220,
            label="📸 Imagem Selecionada",
            show_label=True,
            container=True
        )
        # Hidden state holding the path of the last gallery selection ("" = none)
        selected_image_state = gr.State(value="")
        # Action buttons: generate (wide) and clear (narrow)
        with gr.Row():
            submit_btn = gr.Button(
                "🚀 Gerar Relatório",
                variant="primary",
                size="lg",
                scale=3,
                elem_classes=["primary"]
            )
            clear_btn = gr.Button(
                "🗑️ Limpar",
                size="lg",
                scale=1,
                elem_classes=["secondary"]
            )
| # Reports section with enhanced styling | |
| gr.HTML("<div style='margin: 2rem 0;'><h2 style='text-align: center; color: #2d3748; font-size: 1.8rem; font-weight: 600;'>📋 Relatórios Médicos</h2></div>") | |
| with gr.Row(equal_height=True): | |
| ai_report = gr.Textbox( | |
| label="🤖 Relatório Gerado por IA", | |
| lines=8, | |
| placeholder="Clique em 'Gerar Relatório' para ver a análise da IA...", | |
| container=True, | |
| show_copy_button=True | |
| ) | |
| ground_truth = gr.Textbox( | |
| label="📋 Relatório Ground Truth", | |
| lines=8, | |
| placeholder="O relatório verdadeiro aparecerá aqui quando você selecionar uma imagem de exemplo...", | |
| container=True, | |
| show_copy_button=True | |
| ) | |
| # Enhanced information section with modern cards | |
| with gr.Row(): | |
| with gr.Column(): | |
| gr.HTML(""" | |
| <div class="footer-info"> | |
| <h2 style="color: #1a202c; margin-bottom: 1.2rem; font-weight: 700; font-size: 1.3rem;">🔬 Sobre Esta Aplicação</h2> | |
| <div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(280px, 1fr)); gap: 1.2rem; margin: 1.5rem 0;"> | |
| <div class="info-card"> | |
| <h3 style="color: #1a365d; font-weight: 600; font-size: 1rem;">📄 Artigo Científico</h3> | |
| <p><a href="https://www.cell.com/heliyon/fulltext/S2405-8440(24)03547-3" | |
| target="_blank" | |
| style="color: #2563eb; text-decoration: none; font-weight: 600; font-size: 0.9rem;"> | |
| XRaySwinGen: Automatic medical reporting for X-ray exams with multimodal model | |
| </a></p> | |
| <p style="font-size: 0.85rem; color: #4a5568;"> | |
| Pesquisa publicada na revista Heliyon sobre geração automática de relatórios médicos. | |
| </p> | |
| </div> | |
| <div class="info-card"> | |
| <h3 style="color: #1a365d; font-weight: 600; font-size: 1rem;">🤖 Modelo Disponível</h3> | |
| <ul style="text-align: left; padding-left: 1rem; color: #2d3748; font-size: 0.9rem;"> | |
| <li><strong style="color: #1a202c;">RMMM (MIMIC):</strong> Modelo PyTorch para análise de raios-X</li> | |
| </ul> | |
| </div> | |
| <div class="info-card"> | |
| <h3 style="color: #1a365d; font-weight: 600; font-size: 1rem;">📊 Métricas</h3> | |
| <ul style="text-align: left; padding-left: 1rem; color: #2d3748; font-size: 0.9rem;"> | |
| <li><strong style="color: #1a202c;">BLEU-4:</strong> Sobreposição de 4-gramas</li> | |
| <li><strong style="color: #1a202c;">ROUGE-L:</strong> Subsequência comum mais longa</li> | |
| </ul> | |
| </div> | |
| <div class="info-card"> | |
| <h3 style="color: #1a365d; font-weight: 600; font-size: 1rem;">🏥 Dataset</h3> | |
| <p style="color: #2d3748; font-size: 0.9rem;">Imagens do dataset MIMIC-CXR com relatórios de radiologistas especialistas.</p> | |
| </div> | |
| </div> | |
| <div style="margin-top: 1.5rem; padding-top: 1.2rem; border-top: 2px solid #e2e8f0;"> | |
| <p style="color: #4a5568; font-size: 0.9rem;"> | |
| 💡 <strong style="color: #1a202c;">Dica:</strong> Carregue uma imagem da galeria para gerar um relatório médico automaticamente. | |
| </p> | |
| </div> | |
| </div> | |
| """) | |
| # Gallery click handler to load selected image and remember filename | |
def load_selected_image(evt: gr.SelectData):
    """Load the gallery image the user clicked and remember its path.

    Args:
        evt: Gradio select event; ``evt.index`` is used as a position in
            ``example_images``.

    Returns:
        A ``(image, path)`` pair feeding ``[image_input, selected_image_state]``:
        the RGB PIL image (``None`` if it could not be loaded) and the selected
        path (``""`` when the index does not map to an entry of
        ``example_images``).
    """
    index = evt.index
    # Guard against indices that don't map to `example_images` (e.g. a
    # non-integer index, or gallery contents that changed after startup)
    # instead of raising IndexError inside the event handler.
    if not isinstance(index, int) or not 0 <= index < len(example_images):
        print(f"❌ Invalid gallery selection index: {index!r}")
        return None, ""
    selected_image_path = example_images[index]
    print(f"Gallery selection - Loading image: {selected_image_path}")
    try:
        # The context manager releases the file handle; convert() returns an
        # independent copy that remains usable after the file is closed.
        with Image.open(selected_image_path) as source_image:
            loaded_image = source_image.convert('RGB')
    except Exception as e:
        # Best-effort UI handler: report the failure, clear the preview, but
        # still record which example was chosen (same as before).
        print(f"❌ Error loading image: {e}")
        return None, selected_image_path
    print(f"✅ Successfully loaded image: {loaded_image.size}")
    return loaded_image, selected_image_path
# Gallery click: show the chosen image and remember its path in the hidden state.
gallery.select(
    fn=load_selected_image,
    outputs=[image_input, selected_image_state]
)
# A user upload replaces any gallery selection, so forget the previously
# selected path; otherwise inference_image_pipe_with_state would receive the
# new image together with the stale path of the earlier gallery image.
# (`.upload` fires on user uploads only, not when the gallery handler sets the
# image programmatically.)
image_input.upload(
    fn=lambda: "",
    outputs=[selected_image_state]
)
# Main generation button - ONLY manual trigger
submit_btn.click(
    fn=inference_image_pipe_with_state,
    inputs=[image_input, selected_image_state],
    outputs=[ai_report, ground_truth, metrics_display]
)


def clear_interface():
    """Reset the image, the remembered selection, both reports and the metrics card.

    Resetting the selection state too prevents a later upload from being
    paired with the path of a previously selected gallery image.
    """
    return None, "", "", "", """
<div class="info-card">
<h2 style="color: #1a202c; font-weight: 700;">📊 Métricas de Avaliação</h2>
<p style="color: #4a5568; font-style: italic;">
Selecione uma imagem e gere um relatório para ver as métricas de avaliação em tempo real
</p>
</div>
"""


# Clear button
clear_btn.click(
    fn=clear_interface,
    outputs=[image_input, selected_image_state, ai_report, ground_truth, metrics_display]
)
if __name__ == "__main__":
    # Windows only: explicitly install the Proactor event-loop policy before
    # Gradio starts its server. Failures are deliberately ignored so the app
    # still launches.
    # NOTE(review): ProactorEventLoop is already the default on Windows since
    # Python 3.8, so this is likely a no-op; confirm whether the Selector
    # policy was intended for the "connection issues" workaround.
    running_on_windows = sys.platform.startswith('win')
    if running_on_windows:
        try:
            proactor_policy = asyncio.WindowsProactorEventLoopPolicy()
            asyncio.set_event_loop_policy(proactor_policy)
        except Exception:
            pass

    # Start the app: show handler errors in the UI, keep console output,
    # and request a public share link.
    launch_options = {"show_error": True, "quiet": False, "share": True}
    demo.launch(**launch_options)