"""Compare OCR output of DeepSeek-OCR-2 vs MedGemma-1.5-4B on one PDF page.

Renders the first page of a test PDF to an image, runs both models on it,
and writes a side-by-side markdown report to ``model_comparison.md``.
"""

import os
import time

import fitz
import torch
from PIL import Image
from transformers import (
    AutoModel,
    AutoModelForImageTextToText,
    AutoProcessor,
    AutoTokenizer,
)

DEEPSEEK_MODEL = 'deepseek-ai/DeepSeek-OCR-2'
MEDGEMMA_MODEL = 'google/medgemma-1.5-4b-it'

if torch.backends.mps.is_available():
    print("Patching torch for MPS compatibility...")
    device = "mps"
    # HACK: some remote-code models call .cuda() unconditionally; redirect
    # those calls to the MPS device instead of failing.
    torch.Tensor.cuda = lambda self, *args, **kwargs: self.to("mps")
    torch.nn.Module.cuda = lambda self, *args, **kwargs: self.to("mps")
    # HACK: MPS has no bfloat16 support; alias it to float16 globally.
    # NOTE(review): this mutates the torch module for the whole process.
    torch.bfloat16 = torch.float16
    dtype = torch.float16
else:
    device = "cpu"
    dtype = torch.float32


def get_page_image(pdf_path, page_num=0):
    """Render one PDF page to a PIL RGB image at 2x zoom.

    Args:
        pdf_path: Path to the PDF file.
        page_num: Zero-based page index to render (default: first page).

    Returns:
        PIL.Image.Image in RGB mode.
    """
    doc = fitz.open(pdf_path)
    try:
        page = doc.load_page(page_num)
        # 2x zoom matrix for higher-resolution OCR input.
        pix = page.get_pixmap(matrix=fitz.Matrix(2, 2))
        return Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
    finally:
        doc.close()


def run_deepseek(img):
    """Run DeepSeek-OCR-2 on *img* and return its raw OCR result.

    The model's custom ``.infer`` API takes a file path, so the image is
    written to a temporary PNG which is always cleaned up afterwards.
    """
    tokenizer = AutoTokenizer.from_pretrained(DEEPSEEK_MODEL, trust_remote_code=True)
    model = AutoModel.from_pretrained(
        DEEPSEEK_MODEL, trust_remote_code=True, use_safetensors=True
    )
    model = model.to(device=device, dtype=dtype).eval()

    # DeepSeek's .infer expects an image file on disk, not a PIL object.
    temp_path = "temp_comp.png"
    img.save(temp_path)
    try:
        with torch.no_grad():
            res = model.infer(
                tokenizer,
                prompt="\nFree OCR. ",
                image_file=temp_path,
                output_path="outputs",
                base_size=1024,
                image_size=768,
                crop_mode=True,
                eval_mode=True,
            )
    finally:
        # Remove the temp file even if inference raises.
        os.remove(temp_path)
    return res


def run_medgemma(img):
    """Run MedGemma-1.5-4B on *img* and return the generated text.

    Uses the chat template with an image + instruction turn, then decodes
    only the newly generated tokens (prompt tokens are stripped).
    """
    processor = AutoProcessor.from_pretrained(MEDGEMMA_MODEL)
    model = AutoModelForImageTextToText.from_pretrained(
        MEDGEMMA_MODEL,
        trust_remote_code=True,
        # device_map="auto" conflicts with manual MPS placement, so on MPS
        # we load without it and move the model explicitly below.
        dtype=dtype if device == "mps" else torch.float32,
        device_map="auto" if device != "mps" else None,
    ).eval()
    if device == "mps":
        model = model.to("mps")

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": img},
                {"type": "text", "text": "Extract all text from this medical document."},
            ],
        }
    ]
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=2048)

    # Decode only the continuation, excluding the prompt tokens.
    input_len = inputs["input_ids"].shape[-1]
    return processor.decode(output[0][input_len:], skip_special_tokens=True)


def compare():
    """Run both models on the test PDF and write a markdown comparison."""
    pdf_path = "doc_for_testing/pdf12_un.pdf"
    if not os.path.exists(pdf_path):
        print("PDF not found.")
        return

    img = get_page_image(pdf_path)

    print("\n--- Running DeepSeek-OCR-2 ---")
    start = time.time()
    ds_res = run_deepseek(img)
    print(f"Time: {time.time() - start:.2f}s")

    print("\n--- Running MedGemma-1.5-4B ---")
    start = time.time()
    mg_res = run_medgemma(img)
    print(f"Time: {time.time() - start:.2f}s")

    with open("model_comparison.md", "w") as f:
        f.write("# Comparison Report: DeepSeek-OCR-2 vs MedGemma-1.5-4B\n\n")
        f.write("## DeepSeek-OCR-2 Result\n\n")
        f.write(ds_res + "\n\n")
        f.write("## MedGemma-1.5-4B Result\n\n")
        f.write(mg_res + "\n")


if __name__ == "__main__":
    compare()