# PaddleOCR-VL-endpoint/handler.py
import subprocess
import sys

# Install a transformers dev build that supports the paddleocr_vl architecture.
subprocess.check_call([
    sys.executable, "-m", "pip", "install", "--quiet", "--upgrade",
    "git+https://github.com/huggingface/transformers.git",
])
# Monkey-patch the old `transformers.file_utils` module that
# huggingface_inference_toolkit still imports from.
import transformers

if not hasattr(transformers, "file_utils"):
    import types

    file_utils_shim = types.ModuleType("transformers.file_utils")
    # Provide stubs for commonly used symbols.
    file_utils_shim.is_tf_available = transformers.utils.is_tf_available
    file_utils_shim.is_torch_available = transformers.utils.is_torch_available
    sys.modules["transformers.file_utils"] = file_utils_shim
    transformers.file_utils = file_utils_shim
from transformers import AutoProcessor, AutoModelForImageTextToText
import torch
from PIL import Image
import requests
from io import BytesIO
import base64
MODEL_ID = "strangervisionhf/PaddleOCR-VL-1.5-hf-transformers-v5.2.0.dev0"

class EndpointHandler:
    def __init__(self, path=""):
        self.processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
        self.model = AutoModelForImageTextToText.from_pretrained(
            MODEL_ID,
            trust_remote_code=True,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        self.model.eval()

    def __call__(self, data: dict):
        """Expects {"inputs": {"image": <URL or base64 string>, "text": <optional prompt>}}."""
        # Accept both the wrapped {"inputs": {...}} payload and a flat dict.
        inputs_data = data.get("inputs", data)
        image_src = inputs_data.get("image", "")
        if not image_src:
            return {"error": "No image provided"}

        if image_src.startswith("http://") or image_src.startswith("https://"):
            response = requests.get(image_src, timeout=30)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
        else:
            # Treat anything else as base64-encoded image bytes.
            image = Image.open(BytesIO(base64.b64decode(image_src))).convert("RGB")

        prompt = inputs_data.get("text", "Recognize text in the image.")
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt},
                ],
            }
        ]

        text_input = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        model_inputs = self.processor(
            text=[text_input],
            images=[image],
            return_tensors="pt",
        ).to(self.model.device)

        with torch.no_grad():
            output_ids = self.model.generate(
                **model_inputs,
                max_new_tokens=512,
                do_sample=False,
            )

        # Strip the prompt tokens so only the newly generated text is decoded.
        output_ids_trimmed = [
            o[len(i):] for i, o in zip(model_inputs.input_ids, output_ids)
        ]
        output_text = self.processor.batch_decode(
            output_ids_trimmed, skip_special_tokens=True
        )
        return {"generated_text": output_text[0]}