File size: 2,707 Bytes
65b0659
 
f9d5d64
 
190efd2
65b0659
f9d5d64
 
 
 
190efd2
f9d5d64
 
 
190efd2
f9d5d64
190efd2
 
f9d5d64
 
 
 
 
65b0659
f9d5d64
190efd2
f9d5d64
 
190efd2
65b0659
f9d5d64
 
 
190efd2
 
f9d5d64
 
190efd2
f9d5d64
 
 
190efd2
f9d5d64
 
65b0659
f9d5d64
 
65b0659
f9d5d64
190efd2
f9d5d64
 
 
190efd2
f9d5d64
 
190efd2
f9d5d64
 
 
 
 
 
 
 
190efd2
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
import numpy as np
from PIL import Image
import cv2
import re

def generate_maira2_report(model, processor, image):
    """Generate the initial draft radiology report using MAIRA-2.

    Args:
        model: Loaded generation model exposing ``device`` and ``generate``.
            ``None`` is treated as "model failed to load".
        processor: Matching processor; called to build the model inputs and
            used to decode the generated token ids. ``None`` is treated as
            "processor failed to load".
        image: Image forwarded to the processor as ``images``.

    Returns:
        The decoded report text, without the instruction prompt. This
        function never raises: if the model/processor is missing or any step
        fails, a human-readable message string is returned instead.
    """
    if model is None or processor is None:
        return "MAIRA-2 model not loaded correctly."

    try:
        inputs = processor(
            images=image,
            text="<image>\nWrite a detailed medical report based on this image.",
            return_tensors="pt",
        )
        # Move every processor output onto the model's device.
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        with torch.no_grad():
            # 512 new tokens leaves room for a complete report.
            output_ids = model.generate(**inputs, max_new_tokens=512)

        generated = output_ids[0]

        # Decoder-only generation returns the prompt tokens followed by the
        # continuation. Drop the prompt so the instruction text is not
        # echoed at the start of the report. Encoder-decoder models do not
        # echo the prompt, so they are left untouched.
        config = getattr(model, "config", None)
        is_encoder_decoder = getattr(config, "is_encoder_decoder", False)
        prompt_ids = inputs.get("input_ids")
        if prompt_ids is not None and not is_encoder_decoder:
            prompt_len = prompt_ids.shape[-1]
            if generated.shape[-1] >= prompt_len:
                generated = generated[prompt_len:]

        return processor.decode(generated, skip_special_tokens=True)
    except Exception as e:
        # Best-effort boundary: report the failure as text so the caller
        # always receives a string.
        print(f"MAIRA-2 Generation error: {e}")
        return f"Error generating report: {e}"

def generate_biomedclip_heatmap(model, preprocess, tokenizer, image, text_query):
    """Generate a Visual Confirmation Heatmap overlay using BiomedCLIP.

    NOTE: the attention map is currently a random placeholder. ``model``,
    ``preprocess`` and ``tokenizer`` are only checked for ``None`` and
    ``text_query`` is not used yet, so the overlay does not reflect the
    image content or the query.

    Args:
        model: BiomedCLIP model, or ``None`` if it failed to load.
        preprocess: Image preprocessing callable, or ``None``.
        tokenizer: Text tokenizer, or ``None``.
        image: PIL image to overlay the heatmap on.
        text_query: Text the heatmap is meant to be conditioned on
            (currently unused).

    Returns:
        A 224x224 RGB PIL image with a JET-coloured heatmap blended over the
        input. If any component is ``None`` or an error occurs, the input
        image resized to 224x224 is returned without an overlay.
    """
    if model is None or preprocess is None or tokenizer is None:
        return image.resize((224, 224))

    try:
        img_rgb = image.convert("RGB").resize((224, 224))
        img_np = np.array(img_rgb)

        # Placeholder for spatial attention mapping: random 14x14 grid,
        # negatives clipped to zero, then scaled into [0, 1).
        heatmap = np.random.randn(14, 14)
        heatmap = np.maximum(heatmap, 0)
        heatmap /= (np.max(heatmap) + 1e-5)  # epsilon avoids division by zero

        heatmap_resized = cv2.resize(heatmap, (224, 224))
        heatmap_uint8 = np.uint8(255 * heatmap_resized)
        heatmap_bgr = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)
        # applyColorMap produces BGR, while img_np (from PIL) is RGB.
        # Convert before blending so hot regions stay red instead of
        # turning blue in the final PIL image.
        heatmap_color = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB)

        overlay = cv2.addWeighted(img_np, 0.6, heatmap_color, 0.4, 0)
        return Image.fromarray(overlay)
    except Exception as e:
        # Best-effort: fall back to the plain resized image on any failure.
        print(f"BiomedCLIP Heatmap error: {e}")
        return image.resize((224, 224))

def overlay_medgemma_bboxes(image, vqa_text):
    """Draw red bounding boxes from MedGemma's <loc> tags onto the image.

    Each box is encoded in ``vqa_text`` as four consecutive three-digit tags
    ``<locYYY><locXXX><locYYY><locXXX>`` in the order y1, x1, y2, x2. Every
    value is divided by 1000 and scaled by the image height (y) or width (x)
    to obtain pixel coordinates.

    Args:
        image: PIL image to annotate; it is not modified.
        vqa_text: Model answer text that may contain location tags. ``None``
            or any non-string value is treated as "no boxes".

    Returns:
        An RGB PIL image with a 2-pixel red rectangle per box found, or an
        unannotated RGB copy of the input when there are no boxes.
    """
    rgb_image = image.convert("RGB")

    # A missing or non-text answer carries no boxes; return the plain image
    # instead of failing inside the regex search.
    if not isinstance(vqa_text, str):
        return rgb_image

    # Regex for MedGemma location tags
    pattern = r"<loc(\d{3})><loc(\d{3})><loc(\d{3})><loc(\d{3})>"
    boxes = re.findall(pattern, vqa_text)
    if not boxes:
        return rgb_image

    img_cv = np.array(rgb_image)
    height, width = img_cv.shape[:2]

    for y1_bin, x1_bin, y2_bin, x2_bin in boxes:
        y1, x1 = int(y1_bin) / 1000.0, int(x1_bin) / 1000.0
        y2, x2 = int(y2_bin) / 1000.0, int(x2_bin) / 1000.0

        top_left = (int(x1 * width), int(y1 * height))
        bottom_right = (int(x2 * width), int(y2 * height))
        # The array is RGB, so (255, 0, 0) is red.
        cv2.rectangle(img_cv, top_left, bottom_right, (255, 0, 0), 2)

    return Image.fromarray(img_cv)