# Hugging Face Space header (scrape residue): Running on Zero
| import torch | |
| from transformers import AutoModel, AutoTokenizer, AutoProcessor, AutoModelForImageTextToText | |
| from PIL import Image | |
| import fitz | |
| import os | |
| import time | |
# Model identifiers on the Hugging Face Hub.
DEEPSEEK_MODEL = 'deepseek-ai/DeepSeek-OCR-2'
MEDGEMMA_MODEL = 'google/medgemma-1.5-4b-it'

# Select compute device and dtype. On Apple Silicon, monkey-patch torch so
# that third-party remote model code which hard-codes .cuda() / bfloat16
# still runs on the MPS backend.
if torch.backends.mps.is_available():
    print("Patching torch for MPS compatibility...")
    device = "mps"
    # Redirect .cuda() calls (common in trust_remote_code model files) to MPS.
    torch.Tensor.cuda = lambda self, *args, **kwargs: self.to("mps")
    torch.nn.Module.cuda = lambda self, *args, **kwargs: self.to("mps")
    # NOTE(review): global rebinding — every later torch.bfloat16 reference in
    # this process resolves to float16 (MPS lacks full bfloat16 support).
    # This must run before either model is loaded; keep it at module top level.
    torch.bfloat16 = torch.float16
    dtype = torch.float16
else:
    device = "cpu"
    dtype = torch.float32
def get_page_image(pdf_path, page_num=0, zoom=2):
    """Render one page of a PDF to a PIL RGB image.

    Args:
        pdf_path: Path to the PDF file.
        page_num: Zero-based page index to render (default: first page).
        zoom: Rendering scale factor; the default 2 doubles the native
            resolution, which helps OCR accuracy.

    Returns:
        PIL.Image.Image in RGB mode.
    """
    doc = fitz.open(pdf_path)
    try:
        page = doc.load_page(page_num)
        pix = page.get_pixmap(matrix=fitz.Matrix(zoom, zoom))
        # pix.samples is a flat byte buffer of width*height RGB pixels.
        return Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
    finally:
        # Close even when rendering raises, so the file handle is not leaked.
        doc.close()
def run_deepseek(img):
    """OCR a PIL image with DeepSeek-OCR-2 and return the extracted text.

    Loads the tokenizer and model on every call — simple but slow; fine for
    a one-shot comparison script.

    Args:
        img: PIL.Image.Image of the page to OCR.

    Returns:
        The model's OCR result (whatever ``model.infer`` returns in
        eval_mode; presumably a string — confirmed only by the report
        writer concatenating it with str).
    """
    import tempfile  # local: only this helper needs it

    tokenizer = AutoTokenizer.from_pretrained(DEEPSEEK_MODEL, trust_remote_code=True)
    model = AutoModel.from_pretrained(DEEPSEEK_MODEL, trust_remote_code=True, use_safetensors=True)
    model = model.to(device=device, dtype=dtype).eval()

    # DeepSeek's .infer() API takes a file path, not an in-memory image, so
    # write the page to a unique temp file (avoids collisions between runs).
    fd, tmp_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    try:
        img.save(tmp_path)
        with torch.no_grad():
            res = model.infer(
                tokenizer,
                prompt="<image>\nFree OCR. ",
                image_file=tmp_path,
                output_path="outputs",
                base_size=1024,
                image_size=768,
                crop_mode=True,
                eval_mode=True,
            )
    finally:
        # Remove the temp file even if inference raises (the original left
        # temp_comp.png behind on any error).
        os.remove(tmp_path)
    return res
def run_medgemma(img):
    """Transcribe the page image with MedGemma and return the generated text."""
    processor = AutoProcessor.from_pretrained(MEDGEMMA_MODEL)

    on_mps = device == "mps"
    model = AutoModelForImageTextToText.from_pretrained(
        MEDGEMMA_MODEL,
        trust_remote_code=True,
        dtype=dtype if on_mps else torch.float32,
        device_map=None if on_mps else "auto",
    ).eval()
    if on_mps:
        model = model.to("mps")

    # Single-turn chat: one image plus an extraction instruction.
    chat = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": img},
                {"type": "text", "text": "Extract all text from this medical document."},
            ],
        }
    ]
    model_inputs = processor.apply_chat_template(
        chat,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    with torch.no_grad():
        generated = model.generate(**model_inputs, max_new_tokens=2048)

    # Decode only the newly generated tokens, skipping the prompt prefix.
    prompt_len = model_inputs["input_ids"].shape[-1]
    return processor.decode(generated[0][prompt_len:], skip_special_tokens=True)
def compare():
    """Run both models on the first page of the test PDF and write a
    side-by-side markdown report to ``model_comparison.md``.

    Prints per-model wall-clock timing to stdout. Returns early (with a
    message) if the test PDF is missing.
    """
    pdf_path = "doc_for_testing/pdf12_un.pdf"
    if not os.path.exists(pdf_path):
        print("PDF not found.")
        return

    img = get_page_image(pdf_path)

    print("\n--- Running DeepSeek-OCR-2 ---")
    start = time.time()
    ds_res = run_deepseek(img)
    print(f"Time: {time.time() - start:.2f}s")

    print("\n--- Running MedGemma-1.5-4B ---")
    start = time.time()
    mg_res = run_medgemma(img)
    print(f"Time: {time.time() - start:.2f}s")

    # Explicit UTF-8: OCR output routinely contains non-ASCII characters,
    # and the platform default encoding (e.g. cp1252 on Windows) would
    # raise UnicodeEncodeError when writing the report.
    with open("model_comparison.md", "w", encoding="utf-8") as f:
        f.write("# Comparison Report: DeepSeek-OCR-2 vs MedGemma-1.5-4B\n\n")
        f.write("## DeepSeek-OCR-2 Result\n\n")
        f.write(ds_res + "\n\n")
        f.write("## MedGemma-1.5-4B Result\n\n")
        f.write(mg_res + "\n")
# Script entry point: run the model comparison when executed directly.
if __name__ == "__main__":
    compare()