# compare_models.py — Local_OCR_Demo
# Side-by-side comparison of DeepSeek-OCR-2 and MedGemma-1.5 on a sample
# medical PDF page (multimodal OCR analysis app with ZeroGPU support).
import os
import tempfile
import time

import fitz  # PyMuPDF
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer, AutoProcessor, AutoModelForImageTextToText
# Hugging Face model identifiers used throughout the script.
DEEPSEEK_MODEL = 'deepseek-ai/DeepSeek-OCR-2'
MEDGEMMA_MODEL = 'google/medgemma-1.5-4b-it'

# Device/dtype selection with Apple-Silicon (MPS) compatibility shims.
# NOTE(review): these monkey-patches mutate torch GLOBALLY at import time —
# any later `.cuda()` call in model code is silently redirected to MPS, and
# bfloat16 (unsupported on MPS) is aliased to float16. Confirm downstream
# code tolerates this before reusing the module elsewhere.
if torch.backends.mps.is_available():
    print("Patching torch for MPS compatibility...")
    device = "mps"
    # Redirect .cuda() on tensors and modules to the MPS device.
    torch.Tensor.cuda = lambda self, *args, **kwargs: self.to("mps")
    torch.nn.Module.cuda = lambda self, *args, **kwargs: self.to("mps")
    # MPS has no bfloat16 kernels; alias it to float16 process-wide.
    torch.bfloat16 = torch.float16
    dtype = torch.float16
else:
    # CPU fallback; float32 for numerical stability without GPU kernels.
    device = "cpu"
    dtype = torch.float32
def get_page_image(pdf_path, page_num=0, zoom=2):
    """Render one page of a PDF to a PIL RGB image.

    Args:
        pdf_path: Path to the PDF file.
        page_num: Zero-based index of the page to render (default: first page).
        zoom: Scale factor applied on both axes (default 2 -> ~144 DPI),
            previously hard-coded; higher values improve OCR fidelity.

    Returns:
        PIL.Image.Image in RGB mode.
    """
    doc = fitz.open(pdf_path)
    try:
        page = doc.load_page(page_num)
        pix = page.get_pixmap(matrix=fitz.Matrix(zoom, zoom))
        return Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
    finally:
        # Close the document even if rendering raises (original leaked it on error).
        doc.close()
def run_deepseek(img):
    """Run DeepSeek-OCR-2 on a PIL image and return the OCR result.

    Loads the tokenizer and model on every call, moves the model to the
    globally selected device/dtype, and invokes the model's custom
    ``infer`` entry point, which only accepts an image *file path*, not an
    in-memory image.

    Args:
        img: PIL.Image.Image to OCR.

    Returns:
        Whatever ``model.infer`` returns (the OCR text in eval mode).
    """
    tokenizer = AutoTokenizer.from_pretrained(DEEPSEEK_MODEL, trust_remote_code=True)
    model = AutoModel.from_pretrained(DEEPSEEK_MODEL, trust_remote_code=True, use_safetensors=True)
    model = model.to(device=device, dtype=dtype).eval()

    # .infer requires a file on disk. Use a unique temp file instead of a
    # fixed name (the original "temp_comp.png" collided across concurrent
    # runs and leaked when infer raised before os.remove).
    fd, tmp_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    try:
        img.save(tmp_path)
        with torch.no_grad():
            res = model.infer(
                tokenizer,
                prompt="<image>\nFree OCR. ",
                image_file=tmp_path,
                output_path="outputs",
                base_size=1024,
                image_size=768,
                crop_mode=True,
                eval_mode=True,
            )
    finally:
        # Always clean up the temp image, even on inference failure.
        os.remove(tmp_path)
    return res
def run_medgemma(img):
    """OCR a PIL image with MedGemma-1.5-4B via a chat-style prompt.

    Builds a single-turn user message containing the image plus an
    extraction instruction, generates up to 2048 new tokens, and returns
    only the decoded continuation (prompt tokens stripped).
    """
    processor = AutoProcessor.from_pretrained(MEDGEMMA_MODEL)

    on_mps = device == "mps"
    model = AutoModelForImageTextToText.from_pretrained(
        MEDGEMMA_MODEL,
        trust_remote_code=True,
        # On MPS use the patched half-precision dtype; elsewhere stay float32.
        dtype=dtype if on_mps else torch.float32,
        # device_map="auto" is incompatible with MPS; move manually below.
        device_map=None if on_mps else "auto",
    ).eval()
    if on_mps:
        model = model.to("mps")

    conversation = [{
        "role": "user",
        "content": [
            {"type": "image", "image": img},
            {"type": "text", "text": "Extract all text from this medical document."},
        ],
    }]
    inputs = processor.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    prompt_len = inputs["input_ids"].shape[-1]
    with torch.no_grad():
        generated = model.generate(**inputs, max_new_tokens=2048)
    # Decode only the newly generated tokens, skipping the echoed prompt.
    return processor.decode(generated[0][prompt_len:], skip_special_tokens=True)
def compare():
    """Run both models on one test PDF page and write a markdown report.

    Renders the first page of the hard-coded test PDF, times each model's
    end-to-end run, and writes the two results to ``model_comparison.md``.
    Prints a message and returns early if the test PDF is missing.
    """
    pdf_path = "doc_for_testing/pdf12_un.pdf"
    if not os.path.exists(pdf_path):
        print("PDF not found.")
        return

    img = get_page_image(pdf_path)

    print("\n--- Running DeepSeek-OCR-2 ---")
    start = time.time()
    ds_res = run_deepseek(img)
    print(f"Time: {time.time() - start:.2f}s")

    print("\n--- Running MedGemma-1.5-4B ---")
    start = time.time()
    mg_res = run_medgemma(img)
    print(f"Time: {time.time() - start:.2f}s")

    # encoding= makes the report valid UTF-8 regardless of locale; f-strings
    # coerce the results to str, so a model returning None no longer raises
    # TypeError on concatenation (the original did `ds_res + "\n\n"`).
    with open("model_comparison.md", "w", encoding="utf-8") as f:
        f.write("# Comparison Report: DeepSeek-OCR-2 vs MedGemma-1.5-4B\n\n")
        f.write("## DeepSeek-OCR-2 Result\n\n")
        f.write(f"{ds_res}\n\n")
        f.write("## MedGemma-1.5-4B Result\n\n")
        f.write(f"{mg_res}\n")
# Script entry point: run the side-by-side model comparison.
if __name__ == "__main__":
    compare()