# GLM-OCR / handler.py
# Author: nanekhachatryan
# Added handler.py to enable HF Inference (commit 48869de, verified).
from typing import Dict, List, Any
from PIL import Image
import torch
import base64
import io
from transformers import AutoModelForCausalLM, AutoTokenizer
class EndpointHandler:
    """Custom inference handler for GLM-OCR.

    The tokenizer and model are loaded once in ``__init__``; each call to
    ``__call__`` handles one request payload and returns the recognized text.
    """

    # Task trigger used when the request does not supply "prompt".
    # Other triggers named for this model: "Table Recognition:",
    # "Formula Recognition:".
    DEFAULT_PROMPT = "Text Recognition:"
    # Default generation budget; adjust to the expected document length.
    # Callers can override it per request via parameters["max_new_tokens"].
    DEFAULT_MAX_NEW_TOKENS = 2048
    # Seconds to wait when fetching an image from a URL, so a slow or
    # unresponsive host cannot hang the request indefinitely.
    REQUEST_TIMEOUT = 30

    def __init__(self, path=""):
        """Load the tokenizer and model from ``path``.

        Args:
            path: Directory or repo id passed to ``from_pretrained``.
                ``trust_remote_code=True`` allows the repository's own model
                and tokenizer code to be executed.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        ).eval()  # inference only

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run OCR on one image and return the generated text.

        Args:
            data: Request payload with keys
                - "inputs": an http(s) URL, or a base64-encoded image string.
                  Anything up to the last comma (e.g. a
                  ``data:image/png;base64,`` prefix) is dropped before decoding.
                - "parameters": optional dict of keyword arguments forwarded to
                  ``model.generate``. Its values override the defaults
                  ``max_new_tokens=DEFAULT_MAX_NEW_TOKENS`` and
                  ``do_sample=False``. ``null``/missing means no overrides.
                - "prompt": optional task trigger, default ``DEFAULT_PROMPT``.

        Returns:
            A one-element list ``[{"generated_text": <decoded text>}]``.

        Raises:
            ValueError: if "parameters" is not a dict, or "inputs" is missing,
                not a string, blank, or decodes to an empty payload.
            requests.HTTPError: if fetching a URL returns an error status.
        """
        # 1. Parse the request. JSON ``null`` arrives as None; treat it like
        #    an omitted "parameters" field instead of crashing on **None.
        parameters = data.get("parameters") or {}
        if not isinstance(parameters, dict):
            raise ValueError(
                f"'parameters' must be a dict, got {type(parameters).__name__}"
            )
        prompt = data.get("prompt", self.DEFAULT_PROMPT)

        # 2. Turn "inputs" into an RGB image.
        image = self._load_image(data.get("inputs"))

        # 3. Build the prompt and tokenize it.
        # NOTE(review): ``from_list_format`` appears to be the Qwen-VL
        # remote-code tokenizer helper; confirm GLM-OCR's tokenizer provides it
        # and that it accepts a PIL image (Qwen-VL's version formats the
        # "image" value into a text tag). Nothing below passes pixel data to
        # the model, since only the tokenized ``query`` reaches generate().
        query = self.tokenizer.from_list_format([
            {"image": image},
            {"text": prompt},
        ])
        model_inputs = self.tokenizer(
            query, add_special_tokens=False, return_tensors="pt"
        ).to(self.model.device)

        # 4. Generate. Caller parameters override the defaults. Passing the
        #    defaults as explicit keywords next to **parameters would raise
        #    "got multiple values for keyword argument" whenever the caller
        #    set max_new_tokens or do_sample.
        gen_kwargs = {
            "max_new_tokens": self.DEFAULT_MAX_NEW_TOKENS,
            "do_sample": False,
            **parameters,
        }
        with torch.no_grad():
            outputs = self.model.generate(**model_inputs, **gen_kwargs)

        # Decode only the newly generated tokens, skipping the echoed prompt.
        prompt_len = model_inputs["input_ids"].shape[1]
        response = self.tokenizer.decode(
            outputs[0][prompt_len:], skip_special_tokens=True
        )
        return [{"generated_text": response}]

    def _load_image(self, source):
        """Decode the request's "inputs" value into an RGB PIL image.

        Args:
            source: An http(s) URL (fetched with ``REQUEST_TIMEOUT``) or a
                base64 string, optionally prefixed by text ending in a comma.

        Raises:
            ValueError: if ``source`` is not a non-blank string or decodes to
                no bytes (``binascii.Error`` from bad base64 is also a
                ValueError).
            requests.HTTPError: if the URL fetch returns an error status.
        """
        if not isinstance(source, str) or not source.strip():
            raise ValueError(
                "'inputs' must be a non-empty URL or base64-encoded image string"
            )
        source = source.strip()

        if source.startswith(("http://", "https://")):
            import requests  # lazy: only needed for URL inputs

            response = requests.get(source, timeout=self.REQUEST_TIMEOUT)
            response.raise_for_status()
            # ``.content`` undoes any Content-Encoding (e.g. gzip), unlike
            # ``response.raw``.
            image_bytes = response.content
        else:
            # Drop a data-URL style prefix such as "data:image/png;base64,".
            image_bytes = base64.b64decode(source.split(",")[-1])

        if not image_bytes:
            raise ValueError("'inputs' decoded to an empty image payload")
        return Image.open(io.BytesIO(image_bytes)).convert("RGB")