from typing import Dict, List, Any
from PIL import Image
import torch
import base64
import io
from transformers import AutoModelForCausalLM, AutoTokenizer

class EndpointHandler:
    def __init__(self, path=""):
        # Load tokenizer and model. trust_remote_code runs the modeling code
        # shipped with the checkpoint; device_map="auto" requires `accelerate`.
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        ).eval()

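    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Args:
            data (:obj:`Dict[str, Any]`):
                - "inputs": image as an http(s) URL, a raw base64 string, or a base64 data URI
                - "prompt": task trigger (optional, defaults to "Text Recognition:")
                - "parameters": dict of keyword arguments forwarded to `model.generate` (optional)

        Returns:
            A one-element list: `[{"generated_text": <decoded model output>}]`.
        """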
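        # 1. Parse Input
        inputs = data.get("inputs", "")
        if not isinstance(inputs, str) or not inputs:
            raise ValueError('"inputs" must be a non-empty URL or base64-encoded image string')
        parameters = dict(data.get("parameters") or {})

        # GLM-OCR selects the task with a plain-text trigger prompt:
        # "Text Recognition:", "Table Recognition:", or "Formula Recognition:"
        prompt = data.get("prompt", "Text Recognition:")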

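        # 2. Process Image (URL, or base64 with an optional data-URI prefix)
        if inputs.startswith(("http://", "https://")):
            import requests  # only needed for URL inputs

            # Fail fast on unreachable hosts and HTTP errors instead of
            # handing an error page to PIL
            response = requests.get(inputs, timeout=30)
            response.raise_for_status()
            image = Image.open(io.BytesIO(response.content)).convert("RGB")
        else:
            # Drop an optional "data:image/...;base64," prefix, then decode
            image_data = base64.b64decode(inputs.split(",")[-1])
            image = Image.open(io.BytesIO(image_data)).convert("RGB")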

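        # 3. Build the multimodal query
        # NOTE: `from_list_format` is a Qwen-VL-style helper defined by some
        # remote-code tokenizers, and Qwen-VL's version expects an image path or
        # URL string rather than a PIL image. Confirm that this checkpoint's
        # tokenizer defines it and accepts a PIL image before deploying;
        # otherwise use the input format documented on the model card.
        query = self.tokenizer.from_list_format([
            {"image": image},
            {"text": prompt},
        ])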

        # 4. Generate
        inputs_processed = self.tokenizer(
            query, add_special_tokens=False, return_tensors="pt"
        ).to(self.model.device)

        # Put defaults first so caller-supplied parameters override them without
        # passing the same keyword to generate() twice (which raises TypeError).
        # Increase max_new_tokens for long documents.
        gen_kwargs = {"max_new_tokens": 2048, "do_sample": False, **parameters}

        with torch.no_grad():
            outputs = self.model.generate(**inputs_processed, **gen_kwargs)

        # Decode only the newly generated tokens, skipping the prompt
        prompt_length = inputs_processed["input_ids"].shape[1]
        response = self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

        return [{"generated_text": response}]
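

# --- Client-side sketch (illustrative; not used by the endpoint) ---
# A minimal sketch of how a caller might build the JSON body that
# EndpointHandler.__call__ reads. "inputs" carries the image as a base64 data
# URI (the handler keeps only the text after the last comma before decoding),
# "prompt" selects the OCR task, and "parameters" is forwarded to
# model.generate(). The helper name `build_payload` and the "image/png" MIME
# default are hypothetical choices; only the three field names come from the
# handler above.
import base64  # re-imported so this snippet can be copied on its own
from typing import Any, Dict


def build_payload(
    image_bytes: bytes,
    prompt: str = "Text Recognition:",
    mime_type: str = "image/png",
    **generation_params: Any,
) -> Dict[str, Any]:
    """Encode raw image bytes into the request dict the handler expects."""
    encoded = base64.b64encode(image_bytes).decode("ascii")
    return {
        "inputs": f"data:{mime_type};base64,{encoded}",
        "prompt": prompt,
        "parameters": dict(generation_params),
    }


# Example usage ("page.png" is a placeholder path):
#   with open("page.png", "rb") as f:
#       payload = build_payload(f.read(), prompt="Table Recognition:", max_new_tokens=1024)
#   The payload can then be sent as JSON, e.g. with requests.post(url, json=payload).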