File size: 3,254 Bytes
ce654c7
5053334
 
 
a92f70e
5053334
 
 
 
 
 
 
 
 
 
e630943
5053334
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bc894d4
 
 
5053334
 
 
 
 
 
bc894d4
5053334
 
 
69a70df
5053334
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from transformers import AutoModel, AutoTokenizer, AutoModelForImageTextToText, AutoProcessor
from typing import Dict, List, Any
import torch
import base64
import io
from io import BytesIO
from PIL import Image
import os
import tempfile

class EndpointHandler:
    """Inference-endpoint handler that runs OCR on a base64-encoded image.

    Accepted request shapes:

    * ``"<base64>"`` - the request body itself is the base64 string;
    * ``{"inputs": "<base64>"}``;
    * ``{"inputs": {"base64": "<base64>"}}``;
    * ``{"base64": "<base64>"}``.

    The base64 string may start with a data-URL header such as
    ``data:image/png;base64,``; everything up to the first comma is dropped
    before decoding. A request dict may also carry an optional
    ``"parameters"`` dict overriding ``"prompt"`` and ``"max_new_tokens"``.
    """

    # Instruction sent alongside the image unless the request overrides it.
    DEFAULT_PROMPT = "Return content as markdown"
    # Generation length cap unless the request overrides it.
    DEFAULT_MAX_NEW_TOKENS = 10000

    def __init__(self, model_dir = 'scb10x/typhoon-ocr1.5-2b'):
        """Load the image-text-to-text model and its processor.

        Args:
            model_dir: Local path or Hub model id handed to ``from_pretrained``.
        """
        # device_map="auto" lets transformers choose device placement; inputs
        # are later moved to ``self.model.device`` to match it.
        self.model = AutoModelForImageTextToText.from_pretrained(
            model_dir, dtype="auto", device_map="auto"
        )
        self.processor = AutoProcessor.from_pretrained(model_dir)

    def __call__(self, data: Any) -> Any:
        """Run OCR on the image carried by ``data``.

        Args:
            data: The request payload (see the class docstring for the
                accepted shapes).

        Returns:
            The decoded model output (``str``) on success, or a dict
            ``{"error": <message>}`` when no image is found or processing
            fails.
        """
        try:
            base64_string = self._extract_base64(data)
            if not isinstance(base64_string, str) or not base64_string:
                # Only dicts have keys; other payload types get their type name.
                if isinstance(data, dict):
                    detail = "Available keys: " + str(data.keys())
                else:
                    detail = "Received input of type " + type(data).__name__
                return {"error": "No base64 string found in input data. " + detail}

            print("Found base64 string, length:", len(base64_string))

            image_bytes = self._decode_base64(base64_string)
            image = Image.open(io.BytesIO(image_bytes))

            # Optional per-request overrides; anything but a dict is ignored.
            params = data.get("parameters") if isinstance(data, dict) else None
            if not isinstance(params, dict):
                params = {}
            prompt = params.get("prompt", self.DEFAULT_PROMPT)
            max_new_tokens = int(params.get("max_new_tokens", self.DEFAULT_MAX_NEW_TOKENS))

            output_text = self._run_ocr(image, prompt, max_new_tokens)
            print(output_text)
            return output_text

        except Exception as e:
            # Top-level request boundary: report every failure in the same
            # {"error": ...} shape as the missing-input case, so clients can
            # tell errors apart from OCR text.
            print(f"Error processing image: {e}")
            return {"error": str(e)}

    @staticmethod
    def _extract_base64(data: Any) -> Any:
        """Return the base64 payload found in ``data``, or ``None`` if absent.

        A raw string is checked first: on a ``str`` the ``in`` operator is a
        substring test, so dict-style lookups would misfire (a data URL
        contains the text "base64" and would be indexed like a dict).
        """
        if isinstance(data, str):
            return data
        if not isinstance(data, dict):
            return None
        inputs = data.get("inputs")
        if isinstance(inputs, str):
            return inputs
        if isinstance(inputs, dict):
            return inputs.get("base64")
        # "inputs" missing or of another type: fall back to a root-level key.
        return data.get("base64")

    @staticmethod
    def _decode_base64(base64_string: str) -> bytes:
        """Drop an optional data-URL header and decode the base64 payload.

        Raises:
            binascii.Error: if the payload is not valid base64.
        """
        # "data:image/png;base64,AAAA" -> "AAAA"; a bare payload is unchanged.
        payload = base64_string.split(",", 1)[-1]
        return base64.b64decode(payload)

    def _run_ocr(self, image: Any, prompt: str, max_new_tokens: int) -> str:
        """Generate text for ``image`` guided by ``prompt``; return the first decoded output."""
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        inputs = self.processor.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self.model.device)

        generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        # Drop the prompt tokens from each output so only new text is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        return output_text[0]