# NOTE: removed non-code extraction artifacts (file-size line, git blame hashes,
# and a flattened line-number index) that made this file invalid Python.
from transformers import AutoModel, AutoTokenizer
from typing import Dict, List, Any
import torch
import base64
from io import BytesIO
from PIL import Image
import os
import tempfile
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for DeepSeek-OCR.

    Accepts a base64-encoded document image (in several payload shapes) and
    returns the document converted to Markdown text. On failure it returns a
    dict with an ``"error"`` key, or the exception message as a string.
    """

    def __init__(self, model_dir='deepseek-ai/DeepSeek-OCR'):
        """Load the DeepSeek-OCR tokenizer and model.

        Args:
            model_dir: Either a local directory containing the model files or
                a Hugging Face Hub repo id (the default).
        """
        model_path = model_dir
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
            # Fix: `bool(model_dir)` was True for any non-empty string, which
            # wrongly forced local-only loading even for the Hub repo-id
            # default. Only restrict to local files for a real directory.
            local_files_only=os.path.isdir(model_path)
        )
        # Check if CUDA is available
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"Using device: {self.device}")
        # float32 avoids dtype conflicts (fp16 is CUDA-only); eager attention
        # avoids a hard dependency on a flash-attention build.
        model_kwargs = {
            'trust_remote_code': True,
            'torch_dtype': torch.float32,
            '_attn_implementation': 'eager',
        }
        self.model = AutoModel.from_pretrained(model_path, **model_kwargs)
        self.model = self.model.eval()
        # Move to appropriate device
        if self.device == 'cuda':
            self.model = self.model.cuda()

    @staticmethod
    def _find_base64(data):
        """Extract the base64 payload from any of the supported shapes.

        Supported shapes: a raw base64 string, {"inputs": "<b64>"},
        {"inputs": {"base64": "<b64>"}}, or {"base64": "<b64>"}.
        Returns None when nothing usable is found.

        Fix: the string case is checked first — the original tested
        `"inputs" in data` before `isinstance(data, str)`, which performed a
        substring search on string payloads and made the raw-string case
        unreachable (or raised on indexing).
        """
        if isinstance(data, str):
            return data
        if isinstance(data, dict):
            inputs = data.get("inputs")
            if isinstance(inputs, str):
                return inputs
            if isinstance(inputs, dict):
                return inputs.get("base64")
            return data.get("base64")
        return None

    def __call__(self, data: Dict[str, Any]) -> str:
        """Run OCR on the image carried in `data` and return Markdown.

        Returns:
            The Markdown text produced by the model; an {"error": ...} dict
            when the input is unusable; or the exception message as a string
            on unexpected failure (original contract preserved).
        """
        try:
            base64_string = self._find_base64(data)
            if not base64_string:
                # Fix: `data.keys()` raised AttributeError for non-dict input.
                if isinstance(data, dict):
                    return {"error": "No base64 string found in input data. Available keys: " + str(data.keys())}
                return {"error": "No base64 string found in input data."}
            print("Found base64 string, length:", len(base64_string))
            # Strip a data-URL prefix ("data:image/png;base64,....") if present
            if ',' in base64_string:
                base64_string = base64_string.split(',')[1]
            # Decode base64 to raw image bytes
            image_data = base64.b64decode(base64_string)
            # Prompt steering the model toward pure-Markdown output
            prompt = "<image>\n<|grounding|>Convert this document to markdown format using # headers, **bold** for important information, and Markdown table syntax (using | and -) instead of HTML."
            with tempfile.TemporaryDirectory() as temp_dir:
                image_path = os.path.join(temp_dir, "input_image.png")
                with open(image_path, "wb") as f:
                    f.write(image_data)
                print(f"Image saved to: {image_path}")
                # Verify the bytes decode to an image; normalize to RGB
                try:
                    test_image = Image.open(image_path)
                    if test_image.mode != 'RGB':
                        test_image = test_image.convert('RGB')
                        test_image.save(image_path)  # Save converted version
                    print(f"Image verified: {test_image.size}, mode: {test_image.mode}")
                except Exception as img_error:
                    return {"error": f"Invalid image: {str(img_error)}"}
                output_dir = os.path.join(temp_dir, "deepseek_out")
                os.makedirs(output_dir, exist_ok=True)
                # Run OCR inference; with save_results=True the model writes
                # its outputs (including a Markdown file) into output_dir.
                result = self.model.infer(
                    self.tokenizer,
                    prompt=prompt,
                    image_file=image_path,  # path to the image on disk
                    output_path=output_dir,
                    base_size=1024,
                    image_size=640,
                    crop_mode=True,
                    save_results=True,
                )
                # Prefer the saved Markdown file
                for fname in os.listdir(output_dir):
                    print("File:\n", fname)
                    if fname.endswith((".md", ".mmd")):
                        md_path = os.path.join(output_dir, fname)
                        with open(md_path, 'r', encoding='utf-8') as f:
                            markdown = f.read()
                        print("Markdown output:\n", markdown)
                        return markdown
                # Fix: the original fell off the loop and implicitly returned
                # None when no .md/.mmd file was produced, violating -> str.
                return str(result)
        except Exception as e:
            print(f"Error processing image: {e}")
            return str(e)