| from typing import Dict, List, Any |
| from PIL import Image |
| import torch |
| import base64 |
| import io |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
class EndpointHandler:
    """Callable request handler that runs an image plus a text prompt through a causal LM.

    The tokenizer and model are loaded once in ``__init__`` from ``path``;
    each call to the instance decodes one image, builds a query, generates
    a continuation and returns the decoded text.
    """

    # Defaults used when the request payload does not override them.
    _DEFAULT_PROMPT = "Text Recognition:"
    _DEFAULT_MAX_NEW_TOKENS = 2048
    # Seconds to wait on the remote host before giving up on a URL download.
    _HTTP_TIMEOUT_SECONDS = 30

    def __init__(self, path=""):
        """Load the tokenizer and model found at ``path``.

        Both are loaded with ``trust_remote_code=True``, i.e. Python code
        shipped with the checkpoint is executed.  The model is loaded in
        bfloat16 with ``device_map="auto"`` and switched to eval mode.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        ).eval()

    @staticmethod
    def _generation_kwargs(parameters) -> Dict[str, Any]:
        """Merge caller-supplied generation parameters over the defaults.

        Building a single dict (instead of passing the defaults as explicit
        keywords next to ``**parameters``) avoids a ``TypeError`` for
        duplicated keyword arguments when the caller overrides
        ``max_new_tokens`` or ``do_sample``.
        """
        merged = {
            "max_new_tokens": EndpointHandler._DEFAULT_MAX_NEW_TOKENS,
            "do_sample": False,
        }
        # ``"parameters": null`` in the payload arrives here as None.
        merged.update(parameters or {})
        return merged

    @staticmethod
    def _load_image(source: str) -> "Image.Image":
        """Return an RGB PIL image from an http(s) URL or a base64 string."""
        if source.startswith(("http://", "https://")):
            # Imported lazily so the base64 path works without `requests`.
            import requests

            # SECURITY NOTE(review): the URL comes straight from the request
            # payload, so this fetch can reach internal hosts (SSRF).
            # Restrict or disable it if the endpoint is exposed to untrusted
            # clients.
            response = requests.get(
                source, timeout=EndpointHandler._HTTP_TIMEOUT_SECONDS
            )
            # Fail with the HTTP error rather than a confusing PIL error on
            # an HTML error page.
            response.raise_for_status()
            # `.content` has Content-Encoding (gzip/deflate) already undone,
            # unlike the raw socket stream.
            payload = response.content
        else:
            # Keep only the part after the last comma so that data URIs
            # ("data:image/png;base64,....") are accepted as well.
            payload = base64.b64decode(source.split(",")[-1])
        return Image.open(io.BytesIO(payload)).convert("RGB")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate text for one image.

        Args:
            data (:obj: `Dict[str, Any]`):
                - "inputs": Base64 encoded image (optionally a data URI) or
                  an http(s) URL
                - "prompt": Text placed after the image in the query
                  (optional, defaults to "Text Recognition:")
                - "parameters": Dict of generation parameters forwarded to
                  ``model.generate`` (optional); ``max_new_tokens`` defaults
                  to 2048 and ``do_sample`` to False

        Returns:
            A one-element list ``[{"generated_text": <str>}]`` containing
            only the newly generated tokens, decoded without special tokens.

        Raises:
            ValueError: If "inputs" is missing, empty or not a string.
        """
        inputs = data.get("inputs", "")
        if not isinstance(inputs, str) or not inputs.strip():
            raise ValueError(
                '"inputs" must be a non-empty base64 encoded image or an http(s) URL'
            )

        prompt = data.get("prompt", self._DEFAULT_PROMPT)
        generation_kwargs = self._generation_kwargs(data.get("parameters"))

        image = self._load_image(inputs.strip())

        # NOTE(review): `from_list_format` is presumably provided by the
        # checkpoint's remote tokenizer code - confirm it accepts a PIL image
        # object for the "image" entry.
        query = self.tokenizer.from_list_format([
            {"image": image},
            {"text": prompt},
        ])

        inputs_processed = self.tokenizer(
            query, add_special_tokens=False, return_tensors="pt"
        ).to(self.model.device)

        with torch.no_grad():
            outputs = self.model.generate(**inputs_processed, **generation_kwargs)

        # Slice off the prompt tokens so only the continuation is decoded.
        prompt_length = inputs_processed["input_ids"].shape[1]
        response = self.tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        )

        return [{"generated_text": response}]