Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,446 Bytes
b752d16 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
import os
import tempfile
import time

import fitz
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer, AutoProcessor, AutoModelForImageTextToText
# Hugging Face repo IDs for the two OCR models being compared.
DEEPSEEK_MODEL = 'deepseek-ai/DeepSeek-OCR-2'
MEDGEMMA_MODEL = 'google/medgemma-1.5-4b-it'

if torch.backends.mps.is_available():
    print("Patching torch for MPS compatibility...")
    device = "mps"
    # HACK: model repos loaded with trust_remote_code often call .cuda()
    # unconditionally; rebind it so those calls land on the MPS device instead.
    # NOTE(review): this patches torch globally for the whole process.
    torch.Tensor.cuda = lambda self, *args, **kwargs: self.to("mps")
    torch.nn.Module.cuda = lambda self, *args, **kwargs: self.to("mps")
    # HACK: alias the bfloat16 dtype object to float16 — presumably because
    # the remote model code requests bf16, which MPS doesn't handle here.
    # Any later `torch.bfloat16` lookup in this process now yields float16.
    torch.bfloat16 = torch.float16
    dtype = torch.float16
else:
    device = "cpu"
    dtype = torch.float32
def get_page_image(pdf_path, page_num=0, zoom=2):
    """Render one page of a PDF to a PIL RGB image.

    Args:
        pdf_path: Path to the PDF file.
        page_num: Zero-based index of the page to render (default: first page).
        zoom: Rasterization scale factor; the default of 2 matches the
            original hard-coded ``fitz.Matrix(2, 2)``.

    Returns:
        A ``PIL.Image.Image`` in RGB mode.
    """
    doc = fitz.open(pdf_path)
    try:
        page = doc.load_page(page_num)
        # The matrix scales the render resolution (zoom=2 doubles both axes).
        pix = page.get_pixmap(matrix=fitz.Matrix(zoom, zoom))
        img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
    finally:
        # Close the document even if rendering raises, so the OS file
        # handle is never leaked.
        doc.close()
    return img
def run_deepseek(img):
    """OCR a PIL image with DeepSeek-OCR-2 and return the inference result.

    Args:
        img: PIL image of the page to OCR.

    Returns:
        Whatever ``model.infer`` returns in ``eval_mode`` (the recognized
        text payload).
    """
    tokenizer = AutoTokenizer.from_pretrained(DEEPSEEK_MODEL, trust_remote_code=True)
    model = AutoModel.from_pretrained(DEEPSEEK_MODEL, trust_remote_code=True, use_safetensors=True)
    model = model.to(device=device, dtype=dtype).eval()
    # DeepSeek's .infer() API only takes a file path, not an in-memory image,
    # so stage the image in a uniquely-named temp file (a fixed name in the
    # CWD would collide across concurrent runs and leak on failure).
    tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    tmp_path = tmp.name
    tmp.close()
    try:
        img.save(tmp_path)
        with torch.no_grad():
            res = model.infer(
                tokenizer,
                prompt="<image>\nFree OCR. ",
                image_file=tmp_path,
                output_path="outputs",
                base_size=1024,
                image_size=768,
                crop_mode=True,
                eval_mode=True
            )
    finally:
        # Remove the temp file even if inference raises.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    return res
def run_medgemma(img):
    """Ask MedGemma to extract all text from a PIL page image.

    Args:
        img: PIL image of the document page.

    Returns:
        The decoded model output, with the prompt tokens stripped off.
    """
    processor = AutoProcessor.from_pretrained(MEDGEMMA_MODEL)
    on_mps = device == "mps"
    model = AutoModelForImageTextToText.from_pretrained(
        MEDGEMMA_MODEL,
        trust_remote_code=True,
        # On MPS use the patched half-precision dtype; elsewhere stick to fp32.
        dtype=dtype if on_mps else torch.float32,
        # device_map="auto" conflicts with manual MPS placement, so skip it there.
        device_map=None if on_mps else "auto",
    ).eval()
    if on_mps:
        model = model.to("mps")

    conversation = [{
        "role": "user",
        "content": [
            {"type": "image", "image": img},
            {"type": "text", "text": "Extract all text from this medical document."},
        ],
    }]
    model_inputs = processor.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    prompt_len = model_inputs["input_ids"].shape[-1]
    with torch.no_grad():
        generated = model.generate(**model_inputs, max_new_tokens=2048)
    # Decode only the newly generated tokens, not the echoed prompt.
    return processor.decode(generated[0][prompt_len:], skip_special_tokens=True)
def compare():
    """Run both OCR models on the first page of a test PDF and write a
    side-by-side markdown comparison report.

    Reads ``doc_for_testing/pdf12_un.pdf``; writes ``model_comparison.md``
    in the current directory. Prints per-model wall-clock timings.
    """
    pdf_path = "doc_for_testing/pdf12_un.pdf"
    if not os.path.exists(pdf_path):
        print("PDF not found.")
        return
    img = get_page_image(pdf_path)

    print("\n--- Running DeepSeek-OCR-2 ---")
    start = time.time()
    ds_res = run_deepseek(img)
    print(f"Time: {time.time() - start:.2f}s")

    print("\n--- Running MedGemma-1.5-4B ---")
    start = time.time()
    mg_res = run_medgemma(img)
    print(f"Time: {time.time() - start:.2f}s")

    # Explicit UTF-8: OCR output routinely contains non-ASCII characters,
    # and the platform default encoding (e.g. cp1252 on Windows) would
    # raise UnicodeEncodeError on write.
    with open("model_comparison.md", "w", encoding="utf-8") as f:
        f.write("# Comparison Report: DeepSeek-OCR-2 vs MedGemma-1.5-4B\n\n")
        f.write("## DeepSeek-OCR-2 Result\n\n")
        # str() guards against a non-string payload from model.infer().
        f.write(str(ds_res) + "\n\n")
        f.write("## MedGemma-1.5-4B Result\n\n")
        f.write(str(mg_res) + "\n")

if __name__ == "__main__":
    compare()
|