# deepseek-OCR / handler.py
# Custom inference handler for the DeepSeek-OCR model.
# (Provenance: uploaded by wealthcoders, commit ae154ba "Update handler.py", verified; 2.68 kB)
from transformers import AutoModel, AutoTokenizer
from typing import Dict, List, Any
import torch
import base64
from io import BytesIO
from PIL import Image
import os
class EndpointHandler:
    """Inference-endpoint handler for DeepSeek-OCR.

    Decodes a base64-encoded document image from the request payload and
    returns the model's markdown transcription produced by ``model.infer``.
    """

    def __init__(self, model_dir: str = 'deepseek-ai/DeepSeek-OCR'):
        """Load the DeepSeek-OCR tokenizer and model.

        Args:
            model_dir: Path to a local model directory (as mounted by the
                endpoint runtime) or a Hub model id to download from.
        """
        model_path = model_dir
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
            # Only restrict to local files when an actual directory was given;
            # the previous bool(model_dir) was always True (non-empty default)
            # and would have blocked Hub downloads for a model id.
            local_files_only=os.path.isdir(model_path)
        )
        # Pick the device once; dtype follows it (bf16 on GPU, fp32 on CPU).
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        model_kwargs = {
            'trust_remote_code': True,
            'torch_dtype': torch.bfloat16 if self.device == 'cuda' else torch.float32,
        }
        if self.device == 'cuda':
            model_kwargs['_attn_implementation'] = 'flash_attention_2'
        try:
            self.model = AutoModel.from_pretrained(model_path, **model_kwargs)
        except Exception:
            # flash-attention may not be installed/compatible; the original
            # try/except wrapped a dict assignment (which never raises), so it
            # never actually fell back. Retry the load with default attention.
            if model_kwargs.pop('_attn_implementation', None) is None:
                raise  # the failure was not attention-related; propagate it
            self.model = AutoModel.from_pretrained(model_path, **model_kwargs)
        self.model = self.model.eval()
        # Move weights onto the GPU when available.
        if self.device == 'cuda':
            self.model = self.model.cuda()

    def __call__(self, data: Dict[str, Any]) -> str:
        """Run OCR on a base64-encoded image.

        Args:
            data: Request payload. Expects ``data["inputs"]["base64"]``
                containing the image bytes, optionally as a full data URL
                (``data:image/png;base64,...``). An optional top-level
                ``data["output_path"]`` enables saving of intermediate
                results by the model.

        Returns:
            The model's markdown output on success, or ``None`` if any step
            of decoding/inference failed (the error is printed).
        """
        try:
            inputs = data.get("inputs")
            base64_string = inputs["base64"]
            # Strip a data-URL prefix if present; split at most once so a
            # (theoretical) comma inside the payload is preserved.
            if ',' in base64_string:
                base64_string = base64_string.split(',', 1)[1]
            # Decode base64 into a PIL image.
            image_data = base64.b64decode(base64_string)
            image = Image.open(BytesIO(image_data))
            # Normalize to RGB (handles palette PNGs, RGBA, grayscale, ...).
            if image.mode != 'RGB':
                image = image.convert('RGB')
            # Prompt instructing the model to emit markdown for the document.
            prompt = "<image>\n<|grounding|>Convert the document to markdown."
            # BUG FIX: output_path was previously undefined here, raising a
            # NameError that the broad except swallowed — the handler always
            # returned None. Take it from the payload; default to no saving.
            output_path = data.get("output_path")
            result = self.model.infer(
                self.tokenizer,
                prompt=prompt,
                image_file=image,  # Pass PIL Image directly
                output_path=output_path,
                base_size=1024,
                image_size=640,
                crop_mode=True,
                save_results=output_path is not None
            )
            return result
        except Exception as e:
            # Endpoint contract: never raise out of the handler; report the
            # failure and return None so the caller gets a clean error body.
            print(f"Error processing image: {e}")
            return None