| import subprocess |
| import sys |
| import os |
|
|
| |
# One-time dependency bootstrap: on the first run in this container, install
# bleeding-edge transformers (needed for the GLM OCR architecture) plus
# accelerate, then drop a marker file so subsequent imports skip the install.
# NOTE(review): installing at import time is a deployment hack — prefer baking
# these into the image requirements when the build pipeline allows it.
_INSTALL_MARKER = "/tmp/.transformers_installed"
if not os.path.exists(_INSTALL_MARKER):
    subprocess.check_call([
        sys.executable, "-m", "pip", "install", "-q", "--upgrade",
        "git+https://github.com/huggingface/transformers.git",
        "accelerate",
    ])
    # Touch the marker via a context manager instead of open(...).close().
    with open(_INSTALL_MARKER, "w"):
        pass
|
|
| import torch |
| import base64 |
| import io |
| from typing import Dict, List, Any |
| from PIL import Image |
| from transformers import AutoProcessor, AutoModelForImageTextToText |
|
|
|
|
class EndpointHandler:
    """HF Inference Endpoints handler for a GLM OCR image-text-to-text model.

    Accepts requests shaped as ``{"inputs": {"image": <base64>, "prompt": <str>}}``
    or ``{"inputs": <base64 string>}`` and returns a one-element list holding
    either ``{"generated_text": ...}`` or ``{"error": ...}``.
    """

    def __init__(self, path: str = ""):
        """Load the processor and model from *path* (local dir or hub repo id)."""
        self.processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)
        self.model = AutoModelForImageTextToText.from_pretrained(
            path,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        ).eval()
        # device_map="auto" may shard the model across devices; route input
        # tensors to wherever the first parameter landed.
        self.device = next(self.model.parameters()).device
        print(f"GLM-OCR loaded on {self.device}")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run OCR on one base64-encoded image.

        Args:
            data: request payload; the ``"inputs"`` key (or the payload itself)
                must be a dict with ``image``/``prompt`` keys or a bare base64
                string.

        Returns:
            ``[{"generated_text": str}]`` on success, ``[{"error": str}]`` on a
            malformed request or undecodable image.
        """
        inputs_data = data.get("inputs", data)

        if isinstance(inputs_data, dict):
            image_b64 = inputs_data.get("image", "")
            prompt = inputs_data.get("prompt", "Text Recognition:")
        elif isinstance(inputs_data, str):
            image_b64 = inputs_data
            prompt = "Text Recognition:"
        else:
            return [{"error": "Send {inputs: {image: base64, prompt: str}}"}]

        # Robustness: many clients send data-URI payloads
        # ("data:image/png;base64,...."); strip the header so the base64
        # body decodes to actual image bytes instead of garbage.
        if image_b64.startswith("data:") and "," in image_b64:
            image_b64 = image_b64.split(",", 1)[1]

        # Fail fast with a clear message instead of a cryptic PIL error.
        if not image_b64:
            return [{"error": "Image decode failed: empty image payload"}]

        try:
            image_bytes = base64.b64decode(image_b64)
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except Exception as e:
            return [{"error": f"Image decode failed: {str(e)}"}]

        messages = [{
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }]

        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        proc_inputs = self.processor(
            text=[text], images=[image], padding=True, return_tensors="pt"
        )
        # Move only tensor-like entries; some processors also return plain
        # lists/ints that have no .to() method.
        proc_inputs = {
            k: v.to(self.device) if hasattr(v, "to") else v
            for k, v in proc_inputs.items()
        }

        # NOTE(review): do_sample=True with temperature=0.1 is near-greedy but
        # still nondeterministic; confirm whether greedy decoding was intended
        # for OCR before changing it.
        with torch.no_grad():
            output = self.model.generate(
                **proc_inputs,
                temperature=0.1,
                max_new_tokens=8192,
                do_sample=True,
            )

        # Decode only the newly generated tokens, dropping the echoed prompt.
        prompt_len = proc_inputs["input_ids"].shape[1]
        new_tokens = output[:, prompt_len:]
        text_output = self.processor.tokenizer.batch_decode(
            new_tokens, skip_special_tokens=True
        )[0]

        return [{"generated_text": text_output}]
|
|