File size: 1,924 Bytes
b752d16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from transformers import AutoModel, AutoTokenizer
import torch
import torch.nn as nn
import os
from PIL import Image, ImageOps
import math

# --- Force CPU execution ---
# The model's remote code assumes a CUDA device and bfloat16 weights; the
# settings and monkeypatches below coerce everything onto CPU / float32.
device = "cpu"
dtype = torch.float32
print(f"Forcing device: {device} with dtype: {dtype}")

# Patch torch types to avoid mixed precision errors in their custom code
# NOTE(review): these are process-wide monkeypatches — every other piece of
# torch code in this process will also see bfloat16 aliased to float32 and
# .cuda() redirected to CPU. Acceptable for a standalone test script only.
torch.bfloat16 = torch.float32  # Force bfloat16 to float32
# Redirect any .cuda() calls made by the downloaded model code to a CPU move.
torch.Tensor.cuda = lambda self, *args, **kwargs: self.to("cpu")
torch.nn.Module.cuda = lambda self, *args, **kwargs: self.to("cpu")

# Hugging Face hub id of the OCR model under test.
model_name = 'deepseek-ai/DeepSeek-OCR-2'

def test_inference():
    """Smoke-test DeepSeek-OCR inference on CPU against a local sample image.

    Downloads the tokenizer and model from the Hugging Face hub (executing
    the repo's remote code), runs the model's ``infer`` helper on
    ``sample_test.png``, and prints the recognized text. Artifacts are
    written under ``outputs/``. Returns early with a message if the sample
    image is missing; inference failures are reported with a traceback
    instead of crashing the script.
    """
    print(f"Loading tokenizer for {model_name}...")
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    print(f"Loading model for {model_name}...")
    model = AutoModel.from_pretrained(
        model_name,
        trust_remote_code=True,
        use_safetensors=True,
        torch_dtype=torch.float32,  # Explicitly float32
    ).eval()  # Already on CPU by default if no device_map

    output_dir = 'outputs'
    os.makedirs(output_dir, exist_ok=True)

    prompt = "<image>\nFree OCR. "
    image_file = 'sample_test.png'

    # Guard clause: nothing to do without the sample image.
    if not os.path.exists(image_file):
        print(f"Error: {image_file} not found.")
        return

    print("Running inference on CPU...")
    try:
        with torch.no_grad():
            res = model.infer(
                tokenizer,
                prompt=prompt,
                image_file=image_file,
                output_path=output_dir,
                base_size=512,
                image_size=384,
                crop_mode=False,
                eval_mode=True,
            )
    except Exception as e:
        # Top-level boundary of a test script: report and dump the
        # traceback rather than letting the process die.
        print(f"Inference failed: {e}")
        import traceback
        traceback.print_exc()
    else:
        print("\n--- OCR Result ---")
        print(res)
        print("------------------")

if __name__ == "__main__":
    test_inference()