from typing import Dict, Any

import base64
import io

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

# Task name -> instruction prefix expected by the model.
PROMPTS = {
    "ocr": "OCR:",
    "table": "Table Recognition:",
    "formula": "Formula Recognition:",
    "chart": "Chart Recognition:",
}


class EndpointHandler:
    """Custom handler for Hugging Face Inference Endpoints."""

    def __init__(self, path: str = ""):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)
        self.model = (
            AutoModelForCausalLM.from_pretrained(
                path,
                trust_remote_code=True,
                torch_dtype=torch.bfloat16,
            )
            .to(self.device)
            .eval()
        )

    def _load_image(self, image_field):
        """Accept a PIL image, raw bytes, or a (data-URL) base64 string."""
        if isinstance(image_field, Image.Image):
            return image_field.convert("RGB")
        if isinstance(image_field, (bytes, bytearray)):
            return Image.open(io.BytesIO(image_field)).convert("RGB")
        if isinstance(image_field, str):
            data = image_field
            # Strip a "data:image/...;base64," prefix if present.
            if data.startswith("data:"):
                data = data.split(",", 1)[1]
            return Image.open(io.BytesIO(base64.b64decode(data))).convert("RGB")
        raise ValueError("Unsupported image input type")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # The payload may wrap the fields under "inputs", send them at the
        # top level, or pass a bare base64 string as the whole input.
        inputs_data = data.get("inputs", data) if isinstance(data, dict) else data
        if isinstance(inputs_data, str):
            inputs_data = {"image": inputs_data}

        image_field = inputs_data.get("image")
        if image_field is None:
            return {"error": "Missing 'image' (base64-encoded) in inputs"}

        # Per-request options may arrive inside "inputs" or under "parameters".
        params = data.get("parameters", {}) if isinstance(data, dict) else {}
        task = inputs_data.get("task") or params.get("task", "ocr")
        prompt = (
            inputs_data.get("prompt")
            or params.get("prompt")
            or PROMPTS.get(task, PROMPTS["ocr"])
        )
        max_new_tokens = int(
            inputs_data.get("max_new_tokens") or params.get("max_new_tokens", 1024)
        )

        image = self._load_image(image_field)
        messages = [{
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }]
        model_inputs = self.processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.device)

        with torch.inference_mode():
            output_ids = self.model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,  # greedy decoding for deterministic recognition output
                use_cache=True,
            )

        # Drop the prompt tokens so only the newly generated text is decoded.
        generated_ids = output_ids[:, model_inputs["input_ids"].shape[-1]:]
        text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return {"generated_text": text, "task": task}
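

# --- Local smoke test: a minimal sketch, not part of the Inference Endpoints
# contract. It assumes a model snapshot in MODEL_PATH and a test image named
# "sample.png" next to this file; both names are placeholders for illustration.
if __name__ == "__main__":
    MODEL_PATH = "."  # hypothetical: directory containing the model weights
    handler = EndpointHandler(MODEL_PATH)
    with open("sample.png", "rb") as f:
        payload = {
            "inputs": {
                "image": base64.b64encode(f.read()).decode("ascii"),
                "task": "ocr",
            }
        }
    print(handler(payload))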