File size: 4,806 Bytes
c35506c
 
 
 
 
 
 
0c1f2c5
c35506c
 
68dd4db
ae154ba
68dd4db
 
 
 
82c7d5c
68dd4db
 
 
 
82c7d5c
68dd4db
7b796cd
68dd4db
 
7b796cd
68dd4db
 
82c7d5c
 
68dd4db
 
 
 
 
 
 
c35506c
 
 
1f02db6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7cfe72
1f02db6
 
 
 
c35506c
 
 
 
 
 
 
40c2ea5
de83374
40c2ea5
0c1f2c5
 
8a4adbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a9f73c7
470f97a
 
 
f998089
 
 
 
 
470f97a
f998089
 
 
5c8d9ac
75411f2
f998089
47a7709
470f97a
 
476f460
470f97a
476f460
 
 
 
 
 
 
c35506c
 
 
b899856
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from transformers import AutoModel, AutoTokenizer
from typing import Dict, List, Any
import torch
import base64
from io import BytesIO
from PIL import Image
import os
import tempfile

class EndpointHandler:
    """Hugging Face Inference Endpoints handler for DeepSeek-OCR.

    Loads the model and tokenizer once at startup, then converts each
    base64-encoded document image it receives into Markdown text.
    """

    def __init__(self, model_dir: str = 'deepseek-ai/DeepSeek-OCR'):
        """Load the tokenizer and model.

        Args:
            model_dir: A local model directory (as mounted by Inference
                Endpoints) or a Hugging Face Hub repo id.
        """
        model_path = model_dir

        # BUG FIX: the original passed local_files_only=bool(model_dir),
        # which is True for ANY non-empty string -- including the default
        # Hub repo id -- and would therefore block first-time downloads.
        # Only force local loading when model_dir is really a directory.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
            local_files_only=os.path.isdir(model_dir)
        )

        # Pick GPU when available; CPU endpoints still work (slowly).
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {self.device}")

        # Load model in float32 to avoid dtype conflicts; float16 weights
        # can fail on hosts without half-precision support.
        model_kwargs = {
            'trust_remote_code': True,
            'torch_dtype': torch.float32  # Use float32 instead of float16
        }

        # Explicitly disable flash attention (not installed on all hosts).
        model_kwargs['_attn_implementation'] = 'eager'

        self.model = AutoModel.from_pretrained(model_path, **model_kwargs)
        self.model = self.model.eval()

        # Move to the selected device.
        if self.device == 'cuda':
            self.model = self.model.cuda()

    @staticmethod
    def _extract_base64(data: Any) -> Any:
        """Return the base64 payload from any accepted request shape, or None."""
        # Case 1: {"inputs": "<base64>"}
        if "inputs" in data and isinstance(data["inputs"], str):
            return data["inputs"]
        # Case 2: {"inputs": {"base64": "<base64>"}}
        if "inputs" in data and isinstance(data["inputs"], dict):
            return data["inputs"].get("base64")
        # Case 3: {"base64": "<base64>"} at the root
        if "base64" in data:
            return data["base64"]
        # Case 4: the payload itself is the raw base64 string
        if isinstance(data, str):
            return data
        return None

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Run OCR on a base64-encoded image and return Markdown.

        Returns the Markdown string on success, or a dict with an "error"
        key (or a plain error string for unexpected exceptions) on failure.
        NOTE: the original annotation claimed "-> str" but dicts are also
        returned; the annotation is widened to reflect reality.
        """
        try:
            base64_string = self._extract_base64(data)
            if not base64_string:
                return {"error": "No base64 string found in input data. Available keys: " + str(data.keys())}

            print("Found base64 string, length:", len(base64_string))

            # Strip a data-URL prefix ("data:image/png;base64,...") if present.
            if ',' in base64_string:
                base64_string = base64_string.split(',')[1]

            # Decode base64 to raw image bytes.
            image_data = base64.b64decode(base64_string)

            # Prompt steering the model toward pure-Markdown output.
            prompt = "<image>\n<|grounding|>Convert this document to markdown format using # headers, **bold** for important information, and Markdown table syntax (using | and -) instead of HTML."

            with tempfile.TemporaryDirectory() as temp_dir:

                image_path = os.path.join(temp_dir, "input_image.png")
                with open(image_path, "wb") as f:
                    f.write(image_data)

                print(f"Image saved to: {image_path}")

                # Verify the bytes decode to an image; normalize to RGB.
                try:
                    test_image = Image.open(image_path)
                    if test_image.mode != 'RGB':
                        test_image = test_image.convert('RGB')
                        test_image.save(image_path)  # Save converted version
                    print(f"Image verified: {test_image.size}, mode: {test_image.mode}")
                except Exception as img_error:
                    return {"error": f"Invalid image: {str(img_error)}"}

                output_dir = os.path.join(temp_dir, "deepseek_out")
                os.makedirs(output_dir, exist_ok=True)

                # Run OCR inference; the model writes result files to
                # output_dir. NOTE: infer() takes a file *path* here, not a
                # PIL Image (the original comment claimed otherwise).
                result = self.model.infer(
                    self.tokenizer,
                    prompt=prompt,
                    image_file=image_path,
                    output_path=output_dir,
                    base_size=1024,
                    image_size=640,
                    crop_mode=True,
                    save_results=True,
                )

                # Return the first Markdown file the model produced.
                for fname in os.listdir(output_dir):
                    print("File:\n", fname)
                    if fname.endswith((".md", ".mmd")):
                        md_path = os.path.join(output_dir, fname)
                        with open(md_path, 'r', encoding='utf-8') as f:
                            markdown = f.read()
                        print("Markdown output:\n", markdown)
                        return markdown

                # BUG FIX: the original fell off the end here and implicitly
                # returned None; surface an explicit error instead.
                return {"error": "OCR produced no markdown output",
                        "raw_result": str(result)}

        except Exception as e:
            print(f"Error processing image: {e}")
            return str(e)