# Local_OCR_Demo / test_real_docs.py
# Author: DocUA
# Initial commit: DeepSeek-OCR-2 & MedGemma-1.5 multimodal analysis app with ZeroGPU support
# (commit b752d16)
from transformers import AutoModel, AutoTokenizer
import torch
import os
from PIL import Image
import time
# Pin everything to the CPU for stability.
device = "cpu"
print(f"Using device: {device}")


def _to_cpu(self, *args, **kwargs):
    """Redirect any ``.cuda()`` call onto ``device`` instead of a GPU."""
    return self.to(device)


# The model's trust_remote_code inference path calls .cuda() unconditionally;
# patch tensors and modules so those calls land on the CPU.
torch.Tensor.cuda = _to_cpu
torch.nn.Module.cuda = _to_cpu

model_name = 'deepseek-ai/DeepSeek-OCR-2'
def test_docs():
    """Run DeepSeek-OCR-2 over every .png in ``doc_images`` and save results.

    Loads the tokenizer and model (on CPU, per the module-level patches),
    runs free-form OCR on each image, and writes the output for ``X.png``
    to ``ocr_results/X.png.md``. A failure on one image is logged and does
    not stop the remaining images.
    """
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    print("Loading model (may take a minute)...")
    # Load with default parameters that worked in test_minimal.py
    model = AutoModel.from_pretrained(
        model_name,
        trust_remote_code=True,
        use_safetensors=True,
    )
    model = model.eval()

    # After loading, we monkeypatch bfloat16 for the inference logic:
    # the remote inference code casts to bfloat16, which is slow or
    # unsupported on some CPUs, so alias it to float32.
    torch.bfloat16 = torch.float32

    image_dir = "doc_images"
    output_dir = "ocr_results"
    os.makedirs(output_dir, exist_ok=True)

    images = sorted(f for f in os.listdir(image_dir) if f.endswith(".png"))
    if not images:
        # Make an empty run explicit instead of silently doing nothing.
        print(f"No .png files found in {image_dir!r} - nothing to do.")

    for img_name in images:
        img_path = os.path.join(image_dir, img_name)
        print(f"\n--- Processing: {img_name} ---")

        # DeepSeek-OCR-2 needs specific ratios for its hardcoded query embeddings
        # base_size=1024 -> n_query=256 (supported)
        # image_size=768 -> n_query=144 (supported)
        prompt = "<image>\nFree OCR. "
        start_time = time.time()
        try:
            with torch.no_grad():
                res = model.infer(
                    tokenizer,
                    prompt=prompt,
                    image_file=img_path,
                    output_path=output_dir,
                    base_size=1024,  # Must be 1024 for 256 queries
                    image_size=768,  # Must be 768 for 144 queries
                    crop_mode=False,
                    eval_mode=True,
                )
            elapsed = time.time() - start_time
            print(f"Done in {elapsed:.2f}s")

            result_file = os.path.join(output_dir, f"{img_name}.md")
            # Explicit UTF-8: OCR output routinely contains non-ASCII text,
            # and the platform default encoding (e.g. cp1252) could crash here.
            with open(result_file, "w", encoding="utf-8") as f:
                f.write(res)
            print(f"Result saved to {result_file}")

            print("Preview (first 500 chars):")
            print("-" * 20)
            print(res[:500] + "...")
            print("-" * 20)
        except Exception as e:
            # Best-effort batch run: report the failure and continue.
            print(f"Inference failed for {img_name}: {e}")
            import traceback
            traceback.print_exc()
if __name__ == "__main__":
    # Entry point: run the OCR batch only when executed as a script.
    test_docs()