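"""Custom handler for Hugging Face Inference Endpoints.

Serves a vision-language document model: each request carries an image (a PIL
object, raw bytes, or a base64/data-URI string) plus an optional task or
free-form prompt, and the response carries the decoded text.
"""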
from typing import Dict, Any
import torch
import base64
import io
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
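
# Canned prompt prefixes for the supported tasks; "ocr" is the fallback.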
PROMPTS = {
"ocr": "OCR:",
"table": "Table Recognition:",
"formula": "Formula Recognition:",
"chart": "Chart Recognition:",
}


class EndpointHandler:
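    """Entry point for the endpoint runtime: loads the model once, serves requests via __call__."""
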
def __init__(self, path: str = ""):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
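        # Load the bundled processor and model; trust_remote_code allows the
        # repository's custom modeling/processing code to run.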
self.processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)
self.model = AutoModelForCausalLM.from_pretrained(
path,
trust_remote_code=True,
torch_dtype=torch.bfloat16,
).to(self.device).eval()

    def _load_image(self, image_field):
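        """Normalize a PIL image, raw bytes, or a base64/data-URI string to an RGB PIL image."""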
if isinstance(image_field, Image.Image):
return image_field.convert("RGB")
if isinstance(image_field, (bytes, bytearray)):
return Image.open(io.BytesIO(image_field)).convert("RGB")
if isinstance(image_field, str):
data = image_field
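            # Strip the "data:image/...;base64," prefix from data URIs.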
if data.startswith("data:"):
data = data.split(",", 1)[1]
return Image.open(io.BytesIO(base64.b64decode(data))).convert("RGB")
raise ValueError("Unsupported image input type")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
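        """Handle one request: resolve the image, task, and prompt, then generate text."""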
inputs_data = data.get("inputs", data)
if isinstance(inputs_data, str):
inputs_data = {"image": inputs_data}
image_field = inputs_data.get("image")
if image_field is None:
return {"error": "Missing 'image' (base64-encoded) in inputs"}
params = data.get("parameters", {}) if isinstance(data, dict) else {}
task = inputs_data.get("task") or params.get("task", "ocr")
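        # An explicit prompt overrides the task's canned prompt.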
prompt = (
inputs_data.get("prompt")
or params.get("prompt")
or PROMPTS.get(task, PROMPTS["ocr"])
)
max_new_tokens = int(
inputs_data.get("max_new_tokens")
or params.get("max_new_tokens", 1024)
)
image = self._load_image(image_field)
messages = [{
"role": "user",
"content": [
{"type": "image", "image": image},
{"type": "text", "text": prompt},
],
}]
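        # Let the processor render the chat template and tokenize in one step.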
model_inputs = self.processor.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_dict=True,
return_tensors="pt",
).to(self.device)
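        # Generate under inference_mode (no autograd state); greedy decoding
        # keeps document-parsing output deterministic.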
with torch.inference_mode():
output_ids = self.model.generate(
**model_inputs,
max_new_tokens=max_new_tokens,
do_sample=False,
use_cache=True,
)
        # Decode only the newly generated tokens, not the echoed prompt.
        generated_ids = output_ids[:, model_inputs["input_ids"].shape[1]:]
        text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
return {"generated_text": text, "task": task}
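

if __name__ == "__main__":
    # Minimal local smoke test, not part of the Inference Endpoints contract.
    # "./model" and "sample.png" are hypothetical placeholders: point them at
    # a real model directory and image before running.
    import json

    handler = EndpointHandler("./model")
    with open("sample.png", "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")
    result = handler({
        "inputs": {"image": encoded, "task": "ocr"},
        "parameters": {"max_new_tokens": 512},
    })
    print(json.dumps(result, indent=2))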