from typing import Any, Dict

import base64
import io

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

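# Expected request payload (a sketch based on how this handler parses its
# input below; the field names are defined by this file, not by the
# transformers API):
#
#   {"inputs": {"image": "<base64 string or data: URL>", "task": "ocr"},
#    "parameters": {"max_new_tokens": 1024}}
#
# A bare base64 string is also accepted directly as "inputs".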
# Task-specific prompt prefixes; the keys double as the accepted values of
# the request's "task" field.
PROMPTS = {
    "ocr": "OCR:",
    "table": "Table Recognition:",
    "formula": "Formula Recognition:",
    "chart": "Chart Recognition:",
}

class EndpointHandler:
    def __init__(self, path: str = ""):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # bfloat16 kernel support on CPU is patchy, so fall back to float32 there.
        dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
        self.processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)
        self.model = (
            AutoModelForCausalLM.from_pretrained(
                path,
                trust_remote_code=True,
                torch_dtype=dtype,
            )
            .to(self.device)
            .eval()
        )
    def _load_image(self, image_field) -> Image.Image:
        """Accept a PIL image, raw bytes, or a (data-URL) base64 string."""
        if isinstance(image_field, Image.Image):
            return image_field.convert("RGB")
        if isinstance(image_field, (bytes, bytearray)):
            return Image.open(io.BytesIO(image_field)).convert("RGB")
        if isinstance(image_field, str):
            data = image_field
            if data.startswith("data:"):
                # Strip the "data:image/...;base64," prefix of a data URL.
                data = data.split(",", 1)[1]
            return Image.open(io.BytesIO(base64.b64decode(data))).convert("RGB")
        raise ValueError("Unsupported image input type")
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Accept both {"inputs": {...}} and a flat payload; a bare string is
        # treated as the base64-encoded image itself.
        inputs_data = data.get("inputs", data) if isinstance(data, dict) else data
        if isinstance(inputs_data, str):
            inputs_data = {"image": inputs_data}
        if not isinstance(inputs_data, dict):
            return {"error": "Expected 'inputs' to be an object or a base64 string"}

        image_field = inputs_data.get("image")
        if image_field is None:
            return {"error": "Missing 'image' (base64-encoded) in inputs"}

        # Options may arrive either inside "inputs" or under "parameters".
        params = data.get("parameters", {}) if isinstance(data, dict) else {}
        task = inputs_data.get("task") or params.get("task", "ocr")
        prompt = (
            inputs_data.get("prompt")
            or params.get("prompt")
            or PROMPTS.get(task, PROMPTS["ocr"])
        )
        max_new_tokens = int(
            inputs_data.get("max_new_tokens")
            or params.get("max_new_tokens", 1024)
        )

        try:
            image = self._load_image(image_field)
        except Exception as exc:  # malformed base64, truncated bytes, ...
            return {"error": f"Could not decode image: {exc}"}
        # Single-turn chat message pairing the image with the task prompt.
        messages = [{
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }]

        model_inputs = self.processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.device)
        with torch.inference_mode():
            output_ids = self.model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,  # greedy decoding: OCR output should be deterministic
                use_cache=True,
            )

        # generate() echoes the prompt tokens, so decode only the newly
        # generated portion of the sequence.
        prompt_len = model_inputs["input_ids"].shape[1]
        text = self.processor.batch_decode(
            output_ids[:, prompt_len:], skip_special_tokens=True
        )[0].strip()
        return {"generated_text": text, "task": task}
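# Minimal local smoke test, not part of the Inference Endpoints contract.
# "./model" and "sample.png" are hypothetical placeholders; point them at a
# real checkpoint directory and image before running.
if __name__ == "__main__":
    with open("sample.png", "rb") as f:
        payload = {
            "inputs": {"image": base64.b64encode(f.read()).decode("utf-8")},
            "parameters": {"task": "table", "max_new_tokens": 512},
        }
    handler = EndpointHandler("./model")
    print(handler(payload))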