# AI-AVECINNA / inference.py
# bdtimuhammad's picture
# Update inference.py
# 190efd2 verified
import torch
import numpy as np
from PIL import Image
import cv2
import re
def generate_maira2_report(model, processor, image):
    """Generate the initial draft radiology report using MAIRA-2.

    Args:
        model: Loaded generative model exposing ``device`` and ``generate``,
            or ``None`` if loading failed.
        processor: Matching processor; it is called to build the model inputs
            and its ``decode`` method turns token ids back into text. ``None``
            if loading failed.
        image: Image handed to the processor as ``images``.

    Returns:
        The decoded report text. When the model or processor is missing, or
        when anything in the generation pipeline raises, a human-readable
        error string is returned instead of raising.
    """
    if model is None or processor is None:
        return "MAIRA-2 model not loaded correctly."
    try:
        inputs = processor(
            images=image,
            text="<image>\nWrite a detailed medical report based on this image.",
            return_tensors="pt",
        )
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        with torch.no_grad():
            # 512 new tokens leaves room for a complete report.
            output_ids = model.generate(**inputs, max_new_tokens=512)

        generated = output_ids[0]
        # ``generate`` returns the prompt followed by the new tokens; decoding
        # the whole sequence would echo the instruction text into the report.
        # Drop the prompt only when the output really starts with it, so a
        # model that returns just the continuation is left untouched.
        prompt_ids = inputs.get("input_ids")
        if prompt_ids is not None:
            prompt_len = prompt_ids.shape[-1]
            if generated[:prompt_len].tolist() == prompt_ids[0].tolist():
                generated = generated[prompt_len:]

        return processor.decode(generated, skip_special_tokens=True)
    except Exception as e:
        # Deliberate best-effort boundary: report the failure as text so the
        # caller always receives a string.
        print(f"MAIRA-2 Generation error: {e}")
        return f"Error generating report: {e}"
def generate_biomedclip_heatmap(model, preprocess, tokenizer, image, text_query):
    """Generate a visual-confirmation heatmap overlay for BiomedCLIP.

    NOTE: the heatmap is currently a random placeholder. ``model``,
    ``preprocess`` and ``tokenizer`` are only checked for ``None`` and
    ``text_query`` is not used, so the overlay does not reflect the image
    content or the query.
    TODO: replace the random map with real spatial attention.

    Args:
        model: BiomedCLIP model, or ``None`` if it failed to load.
        preprocess: Image preprocessing callable, or ``None``.
        tokenizer: Text tokenizer, or ``None``.
        image: Image exposing ``convert`` and ``resize``.
        text_query: Text the heatmap is meant to confirm (currently unused).

    Returns:
        A 224x224 RGB image with a JET-coloured heatmap blended over the
        input. If any component is ``None`` or an error occurs, the input
        image resized to 224x224 is returned instead.
    """
    if model is None or preprocess is None or tokenizer is None:
        return image.resize((224, 224))
    try:
        img_rgb = image.convert("RGB").resize((224, 224))
        img_np = np.array(img_rgb)

        # Placeholder for spatial attention mapping: a 14x14 random grid,
        # negative values clipped to zero and scaled into [0, 1).
        heatmap = np.random.randn(14, 14)
        heatmap = np.maximum(heatmap, 0)
        heatmap /= (np.max(heatmap) + 1e-5)  # epsilon guards an all-zero map

        heatmap_resized = cv2.resize(heatmap, (224, 224))
        heatmap_uint8 = np.uint8(255 * heatmap_resized)

        # applyColorMap yields BGR; the photo array and the final PIL image
        # are RGB, so convert before blending to keep hot regions red.
        heatmap_bgr = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
        heatmap_rgb = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB)

        overlay = cv2.addWeighted(img_np, 0.6, heatmap_rgb, 0.4, 0)
        return Image.fromarray(overlay)
    except Exception as e:
        # Best-effort: fall back to the plain resized image on any failure.
        print(f"BiomedCLIP Heatmap error: {e}")
        return image.resize((224, 224))
def overlay_medgemma_bboxes(image, vqa_text):
    """Draw red bounding boxes described by ``<loc>`` tags in ``vqa_text``.

    Each run of four consecutive three-digit tags,
    ``<locYYY><locXXX><locYYY><locXXX>``, is read as
    (y1, x1, y2, x2). Every value is divided by 1000 and scaled by the image
    height or width to obtain pixel coordinates.

    Args:
        image: Image exposing ``convert``; it is converted to RGB and copied
            into an array before drawing.
        vqa_text: Text to scan for location tags.

    Returns:
        A new image with one 2-pixel-wide (255, 0, 0) rectangle per tag group;
        with no matching tags the result is just the RGB copy.
    """
    canvas = np.array(image.convert("RGB"))
    img_h, img_w, _channels = canvas.shape

    # Four back-to-back three-digit location tags form one box.
    loc_quad = re.compile(r"<loc(\d{3})><loc(\d{3})><loc(\d{3})><loc(\d{3})>")

    for hit in loc_quad.finditer(vqa_text):
        top, left, bottom, right = (int(token) / 1000.0 for token in hit.groups())
        corner_a = (int(left * img_w), int(top * img_h))
        corner_b = (int(right * img_w), int(bottom * img_h))
        cv2.rectangle(canvas, corner_a, corner_b, (255, 0, 0), 2)

    return Image.fromarray(canvas)