Spaces:
Runtime error
Runtime error
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| import cv2 | |
| import re | |
def _move_inputs_to_model(model, inputs):
    """Return a dict of *inputs* with tensors placed on the model's device/dtype."""
    device = model.device
    dtype = getattr(model, "dtype", None)
    prepared = {}
    for key, value in inputs.items():
        if isinstance(value, torch.Tensor):
            value = value.to(device)
            # Only floating tensors (e.g. pixel values) follow the model's
            # precision (fp16/bf16 weights reject float32 images); token ids
            # and attention masks must stay integral.
            if isinstance(dtype, torch.dtype) and value.is_floating_point():
                value = value.to(dtype)
        prepared[key] = value
    return prepared


def _strip_prompt(sequence, input_ids):
    """Drop echoed prompt tokens from the start of a generated 1-D sequence.

    Decoder-only ``generate`` returns ``prompt + new tokens``; the prefix is
    removed only when it actually matches, so models that return just the new
    tokens are left untouched.
    """
    if not isinstance(input_ids, torch.Tensor):
        return sequence
    prompt = input_ids[0].to(sequence.device)
    prompt_len = prompt.shape[-1]
    if sequence.shape[-1] >= prompt_len and torch.equal(sequence[:prompt_len], prompt):
        return sequence[prompt_len:]
    return sequence


def generate_maira2_report(model, processor, image):
    """Generate the initial draft radiology report using MAIRA-2.

    Args:
        model: Loaded generation model exposing ``device`` and ``generate``
            (and optionally ``dtype``), or ``None`` if loading failed.
        processor: Matching processor, callable with ``images=``/``text=``/
            ``return_tensors=`` and exposing ``decode``, or ``None``.
        image: Input image, passed unchanged to ``processor``.

    Returns:
        The decoded report text, without the echoed prompt. If the model or
        processor is missing, or any step raises, a human-readable message
        string is returned instead. Errors are printed, never raised.
    """
    if model is None or processor is None:
        return "MAIRA-2 model not loaded correctly."
    try:
        raw_inputs = processor(
            images=image,
            text="<image>\nWrite a detailed medical report based on this image.",
            return_tensors="pt",
        )
        inputs = _move_inputs_to_model(model, raw_inputs)
        with torch.no_grad():
            # 512 new tokens leaves room for a complete multi-section report.
            output_ids = model.generate(**inputs, max_new_tokens=512)
        report_ids = _strip_prompt(output_ids[0], inputs.get("input_ids"))
        return processor.decode(report_ids, skip_special_tokens=True)
    except Exception as e:
        # UI boundary: surface the failure as text instead of crashing the app.
        print(f"MAIRA-2 Generation error: {e}")
        return f"Error generating report: {e}"
def generate_biomedclip_heatmap(model, preprocess, tokenizer, image, text_query):
    """Overlay a "visual confirmation" heatmap on a 224x224 copy of *image*.

    WARNING: the heatmap is currently random noise (a placeholder).
    ``model``, ``preprocess``, ``tokenizer`` and ``text_query`` are only
    checked for ``None`` and are NOT used to compute it. The overlay
    therefore carries no information about the image or the query. Do not
    present it as a model-derived explanation until a real BiomedCLIP
    saliency/attention map replaces the placeholder.

    Args:
        model: BiomedCLIP model, or ``None`` if it failed to load.
        preprocess: Image preprocessing callable, or ``None``.
        tokenizer: Text tokenizer, or ``None``.
        image: PIL image to overlay.
        text_query: Query text (currently unused; see warning above).

    Returns:
        A 224x224 RGB PIL image blending the input (60%) with a JET-colored
        heatmap (40%). If any model component is ``None``, or the overlay
        fails, ``image`` resized to 224x224 in its original mode is returned.
    """
    if model is None or preprocess is None or tokenizer is None:
        return image.resize((224, 224))
    try:
        img_rgb = image.convert("RGB").resize((224, 224))
        img_np = np.array(img_rgb)
        # TODO: placeholder -- replace with a real BiomedCLIP spatial map for
        # text_query. 14x14 presumably mirrors a ViT patch grid at 224 px.
        heatmap = np.random.randn(14, 14)
        heatmap = np.maximum(heatmap, 0)
        heatmap /= (np.max(heatmap) + 1e-5)
        heatmap_resized = cv2.resize(heatmap, (224, 224))
        heatmap_uint8 = np.uint8(255 * heatmap_resized)
        # applyColorMap produces BGR, but img_np (from PIL) is RGB. Convert so
        # high values render red rather than blue in the blended result.
        heatmap_bgr = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
        heatmap_color = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB)
        overlay = cv2.addWeighted(img_np, 0.6, heatmap_color, 0.4, 0)
        return Image.fromarray(overlay)
    except Exception as e:
        print(f"BiomedCLIP Heatmap error: {e}")
        return image.resize((224, 224))
# One box = four consecutive location tags in (y_min, x_min, y_max, x_max) order.
_LOC_BOX_PATTERN = re.compile(r"<loc(\d{3,4})>" * 4)


def overlay_medgemma_bboxes(image, vqa_text):
    """Draw red bounding boxes described by MedGemma ``<loc>`` tags.

    Four-digit tags (``<loc0000>``..``<loc1023>``) are treated as bin indices
    out of 1024, following the PaliGemma-family detection-token format.
    NOTE(review): confirm MedGemma emits this format. Three-digit tags keep
    this function's original interpretation as fractions out of 1000.

    Args:
        image: PIL image to annotate. It is converted to RGB and the original
            object is not modified.
        vqa_text: Model output to scan for tags. ``None``, or text without
            tags, yields an unannotated RGB copy.

    Returns:
        A new RGB PIL image with a 2-px red rectangle for each detected box.
    """
    canvas = np.array(image.convert("RGB"))
    height, width = canvas.shape[:2]
    for match in _LOC_BOX_PATTERN.finditer(vqa_text or ""):
        tags = match.groups()
        scale = 1024.0 if any(len(tag) == 4 for tag in tags) else 1000.0
        y1, x1, y2, x2 = (int(tag) / scale for tag in tags)
        top_left = (int(x1 * width), int(y1 * height))
        bottom_right = (int(x2 * width), int(y2 * height))
        # canvas is RGB (from PIL), so (255, 0, 0) is red.
        cv2.rectangle(canvas, top_left, bottom_right, (255, 0, 0), 2)
    return Image.fromarray(canvas)