# deepseek-OCR / handler.py
# Published on the Hugging Face Hub by user "wealthcoders".
# Commit 75411f2 "Update handler.py" (verified).
import base64
import os
import tempfile
from io import BytesIO
from typing import Any, Dict, List, Optional, Union

import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer
class EndpointHandler:
    """Run DeepSeek-OCR on a base64-encoded image and return Markdown.

    The class name and the ``__init__(model_dir)`` / ``__call__(data)`` shape
    presumably follow the Hugging Face Inference Endpoints custom-handler
    convention — confirm against the deployment.
    """

    # Instruction sent to the model together with every image.
    _PROMPT = "<image>\n<|grounding|>Convert this document to markdown format using # headers, **bold** for important information, and Markdown table syntax (using | and -) instead of HTML."

    def __init__(self, model_dir='deepseek-ai/DeepSeek-OCR'):
        """Load the tokenizer and model and move the model to GPU if available.

        Args:
            model_dir: A local directory with the model files, or a Hugging
                Face Hub repo id (the default).
        """
        # Forbid network access for the tokenizer only when model_dir is an
        # existing local directory. The default value is a Hub repo id, which
        # must be allowed to download; `bool(model_dir)` was True for it too
        # and forced offline mode.
        local_only = os.path.isdir(model_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_dir,
            trust_remote_code=True,
            local_files_only=local_only,
        )

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {self.device}")

        model_kwargs = {
            'trust_remote_code': True,
            # float32 everywhere, instead of float16, to avoid dtype conflicts.
            'torch_dtype': torch.float32,
            # Explicitly disable flash attention.
            '_attn_implementation': 'eager',
        }
        self.model = AutoModel.from_pretrained(model_dir, **model_kwargs).eval()
        if self.device == 'cuda':
            self.model = self.model.cuda()

    def __call__(self, data: Dict[str, Any]) -> Union[str, Dict[str, Any]]:
        """Decode the image in ``data``, run OCR on it and return Markdown.

        Accepted payload shapes, checked in this order:
            * a bare string holding the base64 payload;
            * ``{"inputs": "<base64>"}``;
            * ``{"inputs": {"base64": "<base64>"}}``;
            * ``{"base64": "<base64>"}``.
        A ``data:<mime>;base64,`` prefix is stripped if present.

        Returns:
            The Markdown text on success. On any failure, a dict of the form
            ``{"error": "<message>"}``, so that errors are never mistaken for
            OCR output.
        """
        try:
            base64_string = self._extract_base64(data)
            if not isinstance(base64_string, str) or not base64_string:
                if isinstance(data, dict):
                    available = str(data.keys())
                else:
                    available = f"none (payload type: {type(data).__name__})"
                return {"error": "No base64 string found in input data. Available keys: " + available}
            print("Found base64 string, length:", len(base64_string))

            try:
                image_data = base64.b64decode(self._strip_data_url(base64_string))
            except ValueError as decode_error:  # binascii.Error subclasses ValueError
                return {"error": f"Invalid base64 data: {decode_error}"}

            with tempfile.TemporaryDirectory() as temp_dir:
                image_path = os.path.join(temp_dir, "input_image.png")
                with open(image_path, "wb") as f:
                    f.write(image_data)
                print(f"Image saved to: {image_path}")

                try:
                    self._normalize_image(image_path)
                except Exception as img_error:
                    return {"error": f"Invalid image: {str(img_error)}"}

                output_dir = os.path.join(temp_dir, "deepseek_out")
                os.makedirs(output_dir, exist_ok=True)

                # The model reads the image from a file path, not a PIL object.
                # eval_mode is left unset; the Markdown is read back from the
                # files written under output_dir (save_results=True).
                result = self.model.infer(
                    self.tokenizer,
                    prompt=self._PROMPT,
                    image_file=image_path,
                    output_path=output_dir,
                    base_size=1024,
                    image_size=640,
                    crop_mode=True,
                    save_results=True,
                )

                markdown = self._find_markdown(output_dir)
                if markdown is not None:
                    return markdown
                # Fallback: use infer()'s return value if it is text.
                # NOTE(review): presumably infer() returns None unless
                # eval_mode is set — verify against the model's remote code.
                if isinstance(result, str) and result:
                    return result
                return {"error": "OCR produced no Markdown output"}
        except Exception as e:
            print(f"Error processing image: {e}")
            return {"error": str(e)}

    @staticmethod
    def _extract_base64(data: Any) -> Any:
        """Return the base64 payload from a request body, or None if absent.

        A bare string is checked first. Otherwise ``in`` would do substring
        tests on it and a data URL containing "base64" would be indexed
        like a dict.
        """
        if isinstance(data, str):
            return data
        if not isinstance(data, dict):
            return None
        inputs = data.get("inputs")
        if isinstance(inputs, str):
            return inputs
        if isinstance(inputs, dict):
            return inputs.get("base64")
        return data.get("base64")

    @staticmethod
    def _strip_data_url(payload: str) -> str:
        """Drop a ``data:...;base64,`` style prefix (everything up to the first comma)."""
        if ',' in payload:
            return payload.split(',', 1)[1]
        return payload

    @staticmethod
    def _normalize_image(image_path: str) -> None:
        """Check that image_path decodes as an image, rewriting it as RGB if needed.

        Raises whatever PIL raises for unreadable or corrupt data.
        """
        with Image.open(image_path) as img:
            # Image.open is lazy; load() forces a full decode so truncated or
            # corrupt files fail here rather than inside the model.
            img.load()
            size = img.size
            rgb = img.convert('RGB') if img.mode != 'RGB' else None
        if rgb is not None:
            # Written after the source handle is closed. The output format
            # follows the ".png" extension of image_path.
            rgb.save(image_path)
        print(f"Image verified: {size}, mode: RGB")

    @staticmethod
    def _find_markdown(output_dir: str) -> Optional[str]:
        """Return the text of the first .md/.mmd file in output_dir (sorted by name), or None."""
        for fname in sorted(os.listdir(output_dir)):
            print("File:\n", fname)
            md_path = os.path.join(output_dir, fname)
            if fname.endswith((".md", ".mmd")) and os.path.isfile(md_path):
                with open(md_path, 'r', encoding='utf-8') as f:
                    markdown = f.read()
                print("Markdown output:\n", markdown)
                return markdown
        return None