from transformers import AutoModel, AutoTokenizer
from typing import Dict, List, Any
import torch
import base64
from io import BytesIO
from PIL import Image
import os
class EndpointHandler:
    """Callable handler that runs DeepSeek-OCR on a base64-encoded image.

    The handler loads the tokenizer and model once at construction time and
    is then invoked per request with a payload of the form
    ``{"inputs": {"base64": "<base64 or data-URL string>"}}``.
    """

    def __init__(self, model_dir='deepseek-ai/DeepSeek-OCR', output_path=None):
        """Load the tokenizer and model.

        Args:
            model_dir: Local directory containing the model, or a model id to
                resolve through ``from_pretrained``.
            output_path: Optional directory forwarded to ``model.infer``.
                When it is ``None`` (the default) ``save_results`` is passed
                as ``False``.
        """
        model_path = model_dir
        self.output_path = output_path

        # Restrict the tokenizer to local files only when model_dir really is
        # a directory on disk; a model id (such as the default) must be
        # allowed to resolve remotely, otherwise loading fails on a cold cache.
        local_only = bool(model_dir) and os.path.isdir(model_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
            local_files_only=local_only,
        )

        # Check if CUDA is available
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # bfloat16 on GPU, float32 on CPU
        model_kwargs = {
            'trust_remote_code': True,
            'torch_dtype': torch.bfloat16 if self.device == 'cuda' else torch.float32,
        }

        if self.device == 'cuda':
            # Try flash attention first. The failure (if any) happens inside
            # from_pretrained, so that is the call that must be guarded.
            try:
                self.model = AutoModel.from_pretrained(
                    model_path,
                    _attn_implementation='flash_attention_2',
                    **model_kwargs,
                )
            except (ImportError, ValueError) as exc:
                print(f"Flash attention unavailable, using default attention: {exc}")
                self.model = AutoModel.from_pretrained(model_path, **model_kwargs)
        else:
            self.model = AutoModel.from_pretrained(model_path, **model_kwargs)

        self.model = self.model.eval()

        # Move to appropriate device
        if self.device == 'cuda':
            self.model = self.model.cuda()

    @staticmethod
    def _decode_image_bytes(data):
        """Extract the base64 payload from a request and return the raw bytes.

        Accepts either a plain base64 string or a data URL
        (``data:<mime>;base64,<payload>``) under ``data["inputs"]["base64"]``.

        Raises:
            ValueError: If the payload is missing or is not a non-empty string.
        """
        inputs = data.get("inputs") if isinstance(data, dict) else None
        encoded = inputs.get("base64") if isinstance(inputs, dict) else None
        if not isinstance(encoded, str) or not encoded:
            raise ValueError('expected a base64 string at data["inputs"]["base64"]')

        # Remove data URL prefix if present (base64 itself never contains ',')
        if ',' in encoded:
            encoded = encoded.split(',', 1)[1]

        return base64.b64decode(encoded)

    def __call__(self, data: Dict[str, Any]) -> str:
        """Run OCR on the image in ``data`` and return the model's result.

        Args:
            data: Request payload; the image is read from
                ``data["inputs"]["base64"]``.

        Returns:
            Whatever ``self.model.infer`` returns for the image, or ``None``
            if any step fails (the error is printed rather than raised).
        """
        try:
            image_data = self._decode_image_bytes(data)
            image = Image.open(BytesIO(image_data))

            # Convert to RGB if necessary (handles PNG, JPEG, etc.)
            if image.mode != 'RGB':
                image = image.convert('RGB')

            # Define the prompt for Markdown conversion
            prompt = "<image>\n<|grounding|>Convert the document to markdown."

            # The original code referenced an undefined `output_path` here,
            # which made every request fail; it is now handler configuration.
            output_path = self.output_path

            # Run OCR inference
            # NOTE(review): a PIL image is passed as image_file, as in the
            # original; confirm that the loaded model's infer() accepts one.
            result = self.model.infer(
                self.tokenizer,
                prompt=prompt,
                image_file=image,
                output_path=output_path,
                base_size=1024,
                image_size=640,
                crop_mode=True,
                save_results=output_path is not None,
            )
            return result
        except Exception as e:
            # Top-level request boundary: report and signal failure with None.
            print(f"Error processing image: {e}")
            return None