import re

import cv2
import numpy as np
import torch
from PIL import Image

# Side length (pixels) of the square images produced by the heatmap helpers.
_HEATMAP_SIZE = 224

# Prompt sent to MAIRA-2 together with the image.
# NOTE(review): this literal reached us starting with a bare "\n", i.e. a
# leading <...> tag had been stripped. LLaVA-style processors need an
# "<image>" placeholder marking where image features are inserted; restored
# here -- confirm against the processor's image token.
_MAIRA2_PROMPT = "<image>\nWrite a detailed medical report based on this image."

# Four consecutive location tags: y_min, x_min, y_max, x_max.
# NOTE(review): the original pattern was lost (it was an empty r"" that made
# every call crash); reconstructed as PaliGemma-style <locNNNN> tokens --
# confirm against real MedGemma output.
_LOC_PATTERN = re.compile(r"<loc(\d{4})><loc(\d{4})><loc(\d{4})><loc(\d{4})>")


def generate_maira2_report(model, processor, image):
    """Generate the initial draft radiology report for *image* with MAIRA-2.

    Args:
        model: Generation model exposing ``generate`` and ``device``, or None.
        processor: Matching processor (callable with ``images``/``text`` and
            providing ``decode``), or None.
        image: Image passed straight to ``processor``.

    Returns:
        The generated report text with the prompt removed and surrounding
        whitespace stripped. If the model/processor is missing or generation
        fails, a human-readable error string is returned instead; this
        function never raises.
    """
    if model is None or processor is None:
        return "MAIRA-2 model not loaded correctly."
    try:
        inputs = processor(images=image, text=_MAIRA2_PROMPT, return_tensors="pt")
        # Only tensors can be moved; leave any other processor outputs as-is.
        inputs = {
            k: v.to(model.device) if torch.is_tensor(v) else v
            for k, v in inputs.items()
        }
        # The model is assumed to be decoder-only, so generate() returns the
        # prompt followed by the continuation. Record the prompt length so
        # only newly generated tokens are decoded into the report.
        prompt_len = inputs["input_ids"].shape[-1] if "input_ids" in inputs else 0
        with torch.no_grad():
            output_ids = model.generate(**inputs, max_new_tokens=512)  # Increased for completeness
        report = processor.decode(output_ids[0][prompt_len:], skip_special_tokens=True)
        return report.strip()
    except Exception as e:
        print(f"MAIRA-2 Generation error: {e}")
        return f"Error generating report: {e}"


def _occlusion_saliency(model, preprocess, tokenizer, image, text_query, grid=7, batch_size=49):
    """Return a (grid, grid) float32 map of image-text similarity drop per masked cell.

    Each cell of a ``grid`` x ``grid`` partition of the preprocessed image is
    zeroed in turn. The cell's score is how much that lowers the cosine
    similarity between the image embedding and the ``text_query`` embedding.
    Higher means the region mattered more for matching the text. A coarse
    grid (default 7x7 -> 49 forward passes) keeps this affordable for
    interactive use.
    """
    param = next(model.parameters())
    # presumably (C, H, W) from a torchvision-style transform -- add batch dim.
    pixels = preprocess(image).unsqueeze(0).to(device=param.device, dtype=param.dtype)
    _, _, h, w = pixels.shape
    tokens = tokenizer([text_query]).to(param.device)

    cells = [(r, c) for r in range(grid) for c in range(grid)]
    scores = []
    with torch.no_grad():
        text_feat = torch.nn.functional.normalize(model.encode_text(tokens), dim=-1)
        base_feat = torch.nn.functional.normalize(model.encode_image(pixels), dim=-1)
        base_sim = (base_feat @ text_feat.T).item()
        for start in range(0, len(cells), batch_size):
            chunk = cells[start:start + batch_size]
            batch = pixels.repeat(len(chunk), 1, 1, 1)
            for i, (r, c) in enumerate(chunk):
                # If preprocess ends with mean/std normalization, 0 is the
                # dataset-mean colour, i.e. a neutral occluder.
                batch[i, :, r * h // grid:(r + 1) * h // grid,
                      c * w // grid:(c + 1) * w // grid] = 0
            feats = torch.nn.functional.normalize(model.encode_image(batch), dim=-1)
            scores.extend((feats @ text_feat.T).squeeze(-1).float().cpu().tolist())

    drops = base_sim - np.asarray(scores, dtype=np.float32)
    return drops.reshape(grid, grid)


def _overlay_heatmap(img_rgb, saliency, alpha=0.4):
    """Blend a JET-coloured rendering of *saliency* over the RGB PIL image *img_rgb*."""
    heat = np.maximum(np.asarray(saliency, dtype=np.float32), 0)
    heat /= heat.max() + 1e-5
    # PIL's size is (width, height), which is also cv2.resize's dsize order.
    heat = cv2.resize(heat, img_rgb.size)
    heat_u8 = np.clip(255 * heat, 0, 255).astype(np.uint8)
    # applyColorMap returns BGR; convert so it blends correctly with RGB pixels.
    colored = cv2.cvtColor(cv2.applyColorMap(heat_u8, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
    blended = cv2.addWeighted(np.asarray(img_rgb), 1 - alpha, colored, alpha, 0)
    return Image.fromarray(blended)


def generate_biomedclip_heatmap(model, preprocess, tokenizer, image, text_query):
    """Generate a Visual Confirmation Heatmap for *text_query* using BiomedCLIP.

    The map is an occlusion-sensitivity estimate: regions whose masking most
    reduces BiomedCLIP image-text similarity are drawn hottest. It replaces
    the former random-noise placeholder, which produced fake evidence.

    Args:
        model: CLIP-style model exposing ``encode_image``/``encode_text`` and
            ``parameters()`` (open_clip API assumed), or None.
        preprocess: Image transform producing the model's input tensor, or None.
        tokenizer: Callable mapping a list of strings to token tensors, or None.
        image: PIL image to explain.
        text_query: Finding/phrase whose support in the image is visualised.

    Returns:
        A 224x224 PIL image with the heatmap blended in. If any component is
        missing or computation fails, the image is returned resized to
        224x224 without an overlay (no heatmap is shown rather than a fake one).
    """
    if model is None or preprocess is None or tokenizer is None:
        return image.resize((_HEATMAP_SIZE, _HEATMAP_SIZE))
    try:
        # Score the same square, resized image that is displayed, so the
        # occlusion grid lines up with the overlay.
        img_rgb = image.convert("RGB").resize((_HEATMAP_SIZE, _HEATMAP_SIZE))
        saliency = _occlusion_saliency(model, preprocess, tokenizer, img_rgb, text_query)
        return _overlay_heatmap(img_rgb, saliency)
    except Exception as e:
        print(f"BiomedCLIP Heatmap error: {e}")
        return image.resize((_HEATMAP_SIZE, _HEATMAP_SIZE))


def _to_pixel(frac, size):
    """Map a normalized coordinate to a pixel index clamped to [0, size - 1]."""
    return min(max(int(frac * size), 0), size - 1)


def overlay_medgemma_bboxes(image, vqa_text, *, pattern=None, scale=1000.0,
                            color=(255, 0, 0), thickness=2):
    """Draw bounding boxes parsed from MedGemma location tags onto *image*.

    Args:
        image: PIL image to annotate (converted to RGB; the input is not modified).
        vqa_text: Model output text containing location tags. None or empty
            means no boxes.
        pattern: Optional regex (string or compiled) with exactly four capture
            groups: y_min, x_min, y_max, x_max. Defaults to the module's
            location-tag pattern.
        scale: Value each captured integer is divided by to get a 0-1
            fraction. 1000.0 keeps the original behaviour. NOTE(review):
            PaliGemma-style loc tokens use 1024 bins -- confirm for MedGemma.
        color: Box colour in RGB order (default red).
        thickness: Line thickness in pixels.

    Returns:
        A new RGB PIL image with one rectangle per matched tag group.
    """
    img_rgb = np.array(image.convert("RGB"))
    height, width = img_rgb.shape[:2]
    if not vqa_text:
        return Image.fromarray(img_rgb)

    regex = _LOC_PATTERN if pattern is None else re.compile(pattern)
    for match in regex.finditer(vqa_text):
        y1, x1, y2, x2 = (int(v) / scale for v in match.group(1, 2, 3, 4))
        # Clamp so boxes slightly outside the image still show their edges.
        top_left = (_to_pixel(x1, width), _to_pixel(y1, height))
        bottom_right = (_to_pixel(x2, width), _to_pixel(y2, height))
        # The array is RGB (from PIL), so (255, 0, 0) is red here.
        cv2.rectangle(img_rgb, top_left, bottom_right, color, thickness)
    return Image.fromarray(img_rgb)