# GLM-OCR-endpoint / handler.py
# Custom inference handler, uploaded with huggingface_hub (revision 7c7c31b, verified).
import torch
import base64
import io
import re
from typing import Dict, List, Any
from PIL import Image
from transformers import AutoProcessor, AutoModelForImageTextToText
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for an image-text-to-text (OCR) model.

    Accepts a base64-encoded image (optionally wrapped in a ``data:`` URI) and an
    optional text prompt, runs the chat-templated vision-language model, and
    returns only the newly generated text.
    """

    def __init__(self, path: str = ""):
        """Load the processor and model from *path* (the model repository directory).

        The model is loaded in float16 with ``device_map="auto"`` and put in
        eval mode; ``trust_remote_code=True`` is required for this architecture's
        custom processor/model code.
        """
        self.processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)
        self.model = AutoModelForImageTextToText.from_pretrained(
            path,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        ).eval()
        # device_map="auto" may shard the model across devices; inputs are sent
        # to the device holding the first parameters.
        self.device = next(self.model.parameters()).device

    @staticmethod
    def _strip_data_uri(image_b64: str) -> str:
        # Accept both raw base64 and full data URIs such as
        # "data:image/png;base64,<payload>" — keep only the payload.
        if image_b64.startswith("data:") and "," in image_b64:
            return image_b64.split(",", 1)[1]
        return image_b64

    @staticmethod
    def _parse_request(inputs_data: Any):
        """Extract (image_b64, prompt) from the request payload.

        Returns ``None`` when the payload is neither a dict nor a bare
        base64 string.
        """
        if isinstance(inputs_data, dict):
            return (
                inputs_data.get("image", ""),
                inputs_data.get("prompt", "Text Recognition:"),
            )
        if isinstance(inputs_data, str):
            return inputs_data, "Text Recognition:"
        return None

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Handle one inference request.

        *data* is either ``{"inputs": {...}}`` (Inference Endpoints convention)
        or the payload itself: a dict with ``image`` (base64) and optional
        ``prompt``, or a bare base64 string. Returns a one-element list with
        either ``{"generated_text": ...}`` or ``{"error": ...}``.
        """
        parsed = self._parse_request(data.get("inputs", data))
        if parsed is None:
            return [{"error": "Invalid input format"}]
        image_b64, prompt = parsed

        # Decode the image; report decode failures as a structured error
        # instead of letting the endpoint 500.
        try:
            image_bytes = base64.b64decode(self._strip_data_uri(image_b64))
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except Exception as e:
            return [{"error": f"Failed to decode image: {str(e)}"}]

        # Build a single-turn chat message pairing the image with the prompt.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        proc_inputs = self.processor(
            text=[text], images=[image], padding=True, return_tensors="pt"
        )
        proc_inputs = {k: v.to(self.device) for k, v in proc_inputs.items()}

        # inference_mode() is a stricter, slightly faster no_grad for pure inference.
        with torch.inference_mode():
            output = self.model.generate(
                **proc_inputs,
                # NOTE(review): do_sample=True with temperature=0.1 is near-greedy
                # but nondeterministic; kept as-is to preserve existing behavior.
                temperature=0.1,
                max_new_tokens=8192,
                do_sample=True,
            )

        # Strip the prompt tokens so only newly generated text is decoded.
        prompt_len = proc_inputs["input_ids"].shape[1]
        new_tokens = output[:, prompt_len:]
        text_output = self.processor.tokenizer.batch_decode(
            new_tokens, skip_special_tokens=True
        )[0]
        return [{"generated_text": text_output}]