File size: 3,446 Bytes
b752d16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import torch
from transformers import AutoModel, AutoTokenizer, AutoProcessor, AutoModelForImageTextToText
from PIL import Image
import fitz
import os
import time

# Hugging Face model identifiers for the two systems being compared.
DEEPSEEK_MODEL = 'deepseek-ai/DeepSeek-OCR-2'
MEDGEMMA_MODEL = 'google/medgemma-1.5-4b-it'

# DeepSeek's remote code assumes CUDA; on Apple Silicon we monkey-patch torch
# so its internal `.cuda()` calls land on the MPS backend instead.
if torch.backends.mps.is_available():
    print("Patching torch for MPS compatibility...")
    device = "mps"
    # Redirect any `.cuda()` call (tensor or module) to `.to("mps")`.
    torch.Tensor.cuda = lambda self, *args, **kwargs: self.to("mps")
    torch.nn.Module.cuda = lambda self, *args, **kwargs: self.to("mps")
    # HACK: alias bfloat16 to float16 process-wide — MPS has no bf16 kernels,
    # and the models' remote code requests bfloat16 unconditionally.
    torch.bfloat16 = torch.float16
    dtype = torch.float16
else:
    # CPU fallback: full precision (fp16 on CPU is slow and poorly supported).
    device = "cpu"
    dtype = torch.float32

def get_page_image(pdf_path, page_num=0):
    """Render one page of a PDF to a PIL RGB image.

    Args:
        pdf_path: Path to the PDF file.
        page_num: Zero-based page index to render (default: first page).

    Returns:
        A ``PIL.Image.Image`` in RGB mode, rendered at 2x zoom so the OCR
        models get a usable resolution.
    """
    # Context manager guarantees the document handle is released even if
    # load_page/get_pixmap raises (the original leaked it on error).
    with fitz.open(pdf_path) as doc:
        page = doc.load_page(page_num)
        # A 2x2 matrix doubles the render resolution vs. the 72-dpi default.
        pix = page.get_pixmap(matrix=fitz.Matrix(2, 2))
        # NOTE(review): assumes the pixmap has no alpha channel (fitz default
        # for get_pixmap), so `samples` is tightly packed RGB — confirm if the
        # get_pixmap arguments ever change.
        return Image.frombytes("RGB", (pix.width, pix.height), pix.samples)

def run_deepseek(img):
    """Run DeepSeek-OCR on a PIL image and return the recognized text.

    Loads tokenizer and model fresh on every call (no caching), moves the
    model to the module-level ``device``/``dtype``, and invokes the model's
    custom ``infer`` entry point.

    Args:
        img: PIL image of the document page.

    Returns:
        Whatever ``model.infer`` returns in eval_mode (the OCR result).
    """
    tokenizer = AutoTokenizer.from_pretrained(DEEPSEEK_MODEL, trust_remote_code=True)
    model = AutoModel.from_pretrained(DEEPSEEK_MODEL, trust_remote_code=True, use_safetensors=True)
    model = model.to(device=device, dtype=dtype).eval()

    # DeepSeek's custom .infer() only accepts a file path, not an in-memory
    # image, so we round-trip through a temp file in the working directory.
    temp_path = "temp_comp.png"
    with torch.no_grad():
        img.save(temp_path)
        try:
            res = model.infer(
                tokenizer,
                prompt="<image>\nFree OCR. ",
                image_file=temp_path,
                output_path="outputs",
                base_size=1024,
                image_size=768,
                crop_mode=True,
                eval_mode=True
            )
        finally:
            # The original leaked the temp file when infer() raised; always
            # clean up, tolerating the save itself having failed.
            if os.path.exists(temp_path):
                os.remove(temp_path)
    return res

def run_medgemma(img):
    """Run MedGemma image-text-to-text extraction on a PIL image.

    Loads the processor and model fresh on every call, builds a single-turn
    chat prompt containing the page image, generates up to 2048 new tokens,
    and returns only the newly generated text (prompt tokens stripped).

    Args:
        img: PIL image of the document page.

    Returns:
        The decoded model output as a string.
    """
    proc = AutoProcessor.from_pretrained(MEDGEMMA_MODEL)

    # On MPS: load in the patched half precision and move manually, since
    # device_map="auto" does not place onto MPS; elsewhere let accelerate
    # handle placement in full precision.
    on_mps = device == "mps"
    model = AutoModelForImageTextToText.from_pretrained(
        MEDGEMMA_MODEL,
        trust_remote_code=True,
        dtype=dtype if on_mps else torch.float32,
        device_map=None if on_mps else "auto",
    ).eval()
    if on_mps:
        model = model.to("mps")

    picture = {"type": "image", "image": img}
    instruction = {"type": "text", "text": "Extract all text from this medical document."}
    chat = [{"role": "user", "content": [picture, instruction]}]

    batch = proc.apply_chat_template(
        chat,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    with torch.no_grad():
        generated = model.generate(**batch, max_new_tokens=2048)

    # Slice off the prompt so only the model's continuation is decoded.
    prompt_len = batch["input_ids"].shape[-1]
    return proc.decode(generated[0][prompt_len:], skip_special_tokens=True)

def compare(pdf_path="doc_for_testing/pdf12_un.pdf", out_path="model_comparison.md"):
    """Run both models on the first page of a PDF and write a markdown report.

    Args:
        pdf_path: PDF to test (default preserves the original hard-coded path).
        out_path: Markdown report destination (default preserves the original).

    Side effects:
        Prints per-model wall-clock timings and writes the report file.
    """
    if not os.path.exists(pdf_path):
        print("PDF not found.")
        return

    img = get_page_image(pdf_path)

    print("\n--- Running DeepSeek-OCR-2 ---")
    start = time.time()
    ds_res = run_deepseek(img)
    print(f"Time: {time.time() - start:.2f}s")

    print("\n--- Running MedGemma-1.5-4B ---")
    start = time.time()
    mg_res = run_medgemma(img)
    print(f"Time: {time.time() - start:.2f}s")

    # Explicit UTF-8: OCR output routinely contains non-ASCII glyphs, and the
    # platform default encoding (e.g. cp1252 on Windows) would raise on write.
    # str() guards against a model returning a non-string result object.
    with open(out_path, "w", encoding="utf-8") as f:
        f.write("# Comparison Report: DeepSeek-OCR-2 vs MedGemma-1.5-4B\n\n")
        f.write("## DeepSeek-OCR-2 Result\n\n")
        f.write(str(ds_res) + "\n\n")
        f.write("## MedGemma-1.5-4B Result\n\n")
        f.write(str(mg_res) + "\n")

# Script entry point: run the full model comparison.
if __name__ == "__main__":
    compare()