File size: 2,882 Bytes
d6b1b16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from typing import Dict, Any
import torch
import base64
import io
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

# Task name -> generation prompt prefix understood by the model.
# "ocr" is the fallback task when the request names an unknown one.
PROMPTS = dict(
    ocr="OCR:",
    table="Table Recognition:",
    formula="Formula Recognition:",
    chart="Chart Recognition:",
)


class EndpointHandler:
    """Hugging Face Inference Endpoints handler for a multimodal OCR model.

    Accepts a request carrying an image (PIL image, raw bytes, or a
    base64-encoded string / ``data:`` URL) plus an optional ``task`` or
    ``prompt``, runs greedy generation, and returns the generated text.
    """

    def __init__(self, path: str = ""):
        """Load the processor and model from *path*, on GPU when available."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.processor = AutoProcessor.from_pretrained(path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            trust_remote_code=True,
            # NOTE(review): assumes the target hardware handles bfloat16;
            # confirm for CPU-only deployments.
            torch_dtype=torch.bfloat16,
        ).to(self.device).eval()

    def _load_image(self, image_field) -> Image.Image:
        """Normalize a supported image input to an RGB PIL image.

        Accepts a PIL image, raw bytes/bytearray, or a base64 string
        (optionally wrapped in a ``data:`` URL).

        Raises:
            ValueError: if *image_field* is none of the supported types.
        """
        if isinstance(image_field, Image.Image):
            return image_field.convert("RGB")
        if isinstance(image_field, (bytes, bytearray)):
            return Image.open(io.BytesIO(image_field)).convert("RGB")
        if isinstance(image_field, str):
            data = image_field
            if data.startswith("data:"):
                # Strip the "data:<mime>;base64," prefix of a data URL.
                data = data.split(",", 1)[1]
            return Image.open(io.BytesIO(base64.b64decode(data))).convert("RGB")
        raise ValueError("Unsupported image input type")

    def __call__(self, data) -> Dict[str, Any]:
        """Handle one inference request.

        The payload may be ``{"inputs": {...}, "parameters": {...}}``, a bare
        dict of inputs, or a bare base64 string. Returns
        ``{"generated_text": ..., "task": ...}`` on success or
        ``{"error": ...}`` on malformed input.
        """
        # Resolve the inputs object without assuming the payload is a dict:
        # an unexpected payload type should yield a structured error, not a
        # 500 from an AttributeError.
        if isinstance(data, dict):
            inputs_data = data.get("inputs", data)
            params = data.get("parameters", {})
        else:
            inputs_data = data
            params = {}
        if isinstance(inputs_data, str):
            # A bare string is treated as the base64-encoded image itself.
            inputs_data = {"image": inputs_data}
        if not isinstance(inputs_data, dict):
            return {"error": "Expected 'inputs' to be an object or a base64 string"}

        image_field = inputs_data.get("image")
        if image_field is None:
            return {"error": "Missing 'image' (base64-encoded) in inputs"}

        # Inputs take precedence over parameters; unknown tasks fall back to OCR.
        task = inputs_data.get("task") or params.get("task", "ocr")
        prompt = (
            inputs_data.get("prompt")
            or params.get("prompt")
            or PROMPTS.get(task, PROMPTS["ocr"])
        )
        max_new_tokens = int(
            inputs_data.get("max_new_tokens")
            or params.get("max_new_tokens", 1024)
        )

        image = self._load_image(image_field)

        messages = [{
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }]

        model_inputs = self.processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.device)

        with torch.inference_mode():
            output_ids = self.model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,  # deterministic (greedy) decoding
                use_cache=True,
            )

        # Decode ONLY the newly generated tokens: `generate` returns the
        # prompt followed by the completion, so decoding the full sequence
        # would echo the chat template and prompt back to the caller.
        prompt_len = model_inputs["input_ids"].shape[1]
        text = self.processor.batch_decode(
            output_ids[:, prompt_len:], skip_special_tokens=True
        )[0]
        return {"generated_text": text, "task": task}