from typing import Dict, List, Any
from PIL import Image
import torch
import base64
import io
from transformers import AutoModelForCausalLM, AutoTokenizer
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for a GLM-OCR style vision model.

    Loads a causal-LM checkpoint (with remote code) and answers OCR-style
    prompts ("Text Recognition:", "Table Recognition:", "Formula Recognition:")
    about a single input image.
    """

    def __init__(self, path: str = ""):
        """Load the tokenizer and model from *path*.

        Args:
            path: Local checkpoint directory or hub repo id.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        ).eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run OCR inference on one image.

        Args:
            data:
                - "inputs": http(s) URL, raw base64, or a data-URL image.
                - "parameters": optional dict of generation kwargs
                  (e.g. max_new_tokens, do_sample) forwarded to ``generate``.
                - "prompt": optional task trigger; defaults to
                  "Text Recognition:".

        Returns:
            A one-element list: ``[{"generated_text": <decoded output>}]``.
        """
        inputs = data.get("inputs", "")
        parameters = data.get("parameters", {})
        # GLM-OCR uses specific trigger prompts:
        # "Text Recognition:", "Table Recognition:", "Formula Recognition:"
        prompt = data.get("prompt", "Text Recognition:")

        image = self._load_image(inputs)

        # NOTE(review): from_list_format is a remote-code tokenizer helper
        # (Qwen-VL style) — confirm the target checkpoint provides it.
        query = self.tokenizer.from_list_format([
            {"image": image},
            {"text": prompt},
        ])

        inputs_processed = self.tokenizer(
            query, add_special_tokens=False, return_tensors="pt"
        ).to(self.model.device)

        # Merge defaults with caller-supplied parameters into ONE kwargs dict.
        # The original passed max_new_tokens/do_sample explicitly AND spread
        # **parameters, so any caller who actually set either key crashed with
        # "got multiple values for keyword argument".
        gen_kwargs: Dict[str, Any] = {"max_new_tokens": 2048, "do_sample": False}
        gen_kwargs.update(parameters)

        with torch.no_grad():
            outputs = self.model.generate(**inputs_processed, **gen_kwargs)

        # Decode only the newly generated tokens, skipping the echoed prompt.
        prompt_len = inputs_processed["input_ids"].shape[1]
        response = self.tokenizer.decode(
            outputs[0][prompt_len:], skip_special_tokens=True
        )
        return [{"generated_text": response}]

    @staticmethod
    def _load_image(inputs: str) -> Image.Image:
        """Fetch (URL) or decode (base64 / data URL) the input image as RGB."""
        if inputs.startswith("http"):
            import requests  # local import: only needed on the URL path

            return Image.open(requests.get(inputs, stream=True).raw).convert("RGB")
        # Assume base64; take the last ","-segment to tolerate a data-URL prefix.
        image_data = base64.b64decode(inputs.split(",")[-1])
        return Image.open(io.BytesIO(image_data)).convert("RGB")