# Source: Hugging Face Spaces — "Running on Zero" (page-header residue from the original paste)
import os
import time

import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer

# Force CPU for stability (e.g. on a GPU-less Space).
device = "cpu"
print(f"Using device: {device}")

# The model's trust_remote_code implementation calls .cuda() directly.
# Patch those entry points so any such call becomes a move to `device`
# instead of raising on a CUDA-less host.
torch.Tensor.cuda = lambda self, *args, **kwargs: self.to(device)
torch.nn.Module.cuda = lambda self, *args, **kwargs: self.to(device)

model_name = 'deepseek-ai/DeepSeek-OCR-2'
def test_docs():
    """Run DeepSeek-OCR-2 over every .png in ``doc_images`` and save results.

    Loads the tokenizer and model (CPU, fp32), then for each image calls the
    model's custom ``infer`` entry point and writes the returned text to
    ``ocr_results/<image>.md``. Per-image failures are logged, not fatal.
    """
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    print("Loading model (may take a minute)...")
    # Load with default parameters that worked in test_minimal.py.
    model = AutoModel.from_pretrained(
        model_name,
        trust_remote_code=True,
        use_safetensors=True,
    )
    model = model.eval()

    # After loading, monkeypatch bfloat16 so the remote inference code
    # (which hard-codes bfloat16) effectively runs in float32 on CPU.
    torch.bfloat16 = torch.float32

    image_dir = "doc_images"
    output_dir = "ocr_results"
    os.makedirs(output_dir, exist_ok=True)

    images = sorted(f for f in os.listdir(image_dir) if f.endswith(".png"))
    for img_name in images:
        img_path = os.path.join(image_dir, img_name)
        print(f"\n--- Processing: {img_name} ---")

        # DeepSeek-OCR-2 needs specific ratios for its hardcoded query embeddings:
        #   base_size=1024 -> n_query=256 (supported)
        #   image_size=768 -> n_query=144 (supported)
        prompt = "<image>\nFree OCR. "
        start_time = time.time()
        try:
            with torch.no_grad():
                res = model.infer(
                    tokenizer,
                    prompt=prompt,
                    image_file=img_path,
                    output_path=output_dir,
                    base_size=1024,   # Must be 1024 for 256 queries
                    image_size=768,   # Must be 768 for 144 queries
                    crop_mode=False,
                    eval_mode=True,
                )
            elapsed = time.time() - start_time
            print(f"Done in {elapsed:.2f}s")

            result_file = os.path.join(output_dir, f"{img_name}.md")
            # Explicit UTF-8: OCR output may contain non-ASCII characters,
            # and the platform default encoding is not guaranteed to handle them.
            with open(result_file, "w", encoding="utf-8") as f:
                f.write(res)
            print(f"Result saved to {result_file}")

            print("Preview (first 500 chars):")
            print("-" * 20)
            print(res[:500] + "...")
            print("-" * 20)
        except Exception as e:
            # Best-effort batch run: report the failure and move on to the
            # next image rather than aborting the whole pass.
            print(f"Inference failed for {img_name}: {e}")
            import traceback
            traceback.print_exc()
| if __name__ == "__main__": | |
| test_docs() | |