File size: 2,683 Bytes
c35506c
 
 
 
 
 
 
 
 
68dd4db
ae154ba
68dd4db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c35506c
 
 
68dd4db
c35506c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import base64
import importlib.util
import logging
import os
import tempfile
from io import BytesIO
from typing import Any, Dict, List

import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer

class EndpointHandler:
    """Inference-endpoint handler that runs DeepSeek-OCR on a base64-encoded image.

    The tokenizer and model are loaded once at construction. Each call decodes
    the request image and asks the model to transcribe it as Markdown.
    """

    # Prompt sent with every image: grounded document-to-Markdown conversion.
    PROMPT = "<image>\n<|grounding|>Convert the document to markdown."

    def __init__(self, model_dir='deepseek-ai/DeepSeek-OCR'):
        """Load the tokenizer and model.

        Args:
            model_dir: Either a local directory containing the model files or a
                Hugging Face Hub repo id. Loading is restricted to local files
                only when this names an existing directory, so a Hub id (such
                as the default) can still be fetched.
        """
        # A Hub repo id like the default must not be forced offline; only an
        # existing local directory implies "use local files only".
        local_only = os.path.isdir(model_dir)

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_dir,
            trust_remote_code=True,
            local_files_only=local_only,
        )

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        model_kwargs = {
            'trust_remote_code': True,
            # bfloat16 on GPU; full precision on CPU.
            'torch_dtype': torch.bfloat16 if self.device == 'cuda' else torch.float32,
            'local_files_only': local_only,
        }
        # Request FlashAttention-2 only if the package is actually installed;
        # otherwise let from_pretrained pick its default attention implementation.
        if self.device == 'cuda' and self._flash_attention_available():
            model_kwargs['_attn_implementation'] = 'flash_attention_2'

        self.model = AutoModel.from_pretrained(model_dir, **model_kwargs).eval()
        if self.device == 'cuda':
            self.model = self.model.cuda()

    @staticmethod
    def _flash_attention_available():
        """Return True if the ``flash_attn`` package can be imported."""
        return importlib.util.find_spec("flash_attn") is not None

    @staticmethod
    def _b64_payload(inputs):
        """Extract the bare base64 payload from the request's ``inputs`` field.

        Args:
            inputs: Either a dict holding the string under ``"base64"``, or the
                base64 string itself. The string may carry a data-URL prefix
                such as ``data:image/png;base64,``.

        Returns:
            The base64 text with any data-URL prefix removed.

        Raises:
            ValueError: If no base64 string can be found.
        """
        if isinstance(inputs, dict):
            inputs = inputs.get("base64")
        if not isinstance(inputs, str):
            raise ValueError("expected a base64 string in inputs['base64'] or inputs")
        # Strip a data-URL prefix ("data:<mime>;base64,") if present.
        _, sep, payload = inputs.partition(',')
        return payload if sep else inputs

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Run OCR on the image carried in ``data["inputs"]``.

        Args:
            data: Request payload. ``data["inputs"]`` is either a dict with a
                ``"base64"`` key or the base64 string itself.

        Returns:
            Whatever ``self.model.infer`` returns (expected to be the Markdown
            text), or ``None`` if any step fails. Failures are logged with
            their traceback rather than raised.
        """
        try:
            payload = self._b64_payload(data.get("inputs"))
            image = Image.open(BytesIO(base64.b64decode(payload)))

            # Normalize palette/RGBA/grayscale inputs to 3-channel RGB.
            if image.mode != 'RGB':
                image = image.convert('RGB')

            # infer() needs a writable output directory and, per the model
            # card's usage, an image *file path*. Both live in a per-request
            # scratch directory that is deleted when the block exits.
            with tempfile.TemporaryDirectory() as work_dir:
                image_path = os.path.join(work_dir, 'input.png')
                image.save(image_path)
                return self.model.infer(
                    self.tokenizer,
                    prompt=self.PROMPT,
                    image_file=image_path,
                    output_path=work_dir,
                    base_size=1024,
                    image_size=640,
                    crop_mode=True,
                    save_results=False,
                    # NOTE(review): the model's remote code appears to return
                    # the decoded text only when eval_mode=True; confirm
                    # against the pinned model revision.
                    eval_mode=True,
                )
        except Exception as e:
            # Best-effort endpoint: report the failure with traceback, return None.
            logging.getLogger(__name__).exception("Error processing image: %s", e)
            return None