File size: 3,151 Bytes
eec31df
5053334
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import base64
import os
import tempfile
from io import BytesIO
from typing import Any, Dict, List, Union

import torch
from PIL import Image
from transformers import AutoModel, AutoModelForImageTextToText, AutoProcessor, AutoTokenizer

class EndpointHandler:
    """Custom inference handler that runs OCR on a base64-encoded image.

    The model and processor are loaded once at construction. Each call decodes
    the request's image, sends it with a text prompt through the processor's
    chat template, generates with the model, and returns the decoded text.
    """

    # Repo id used when no model directory is supplied.
    DEFAULT_MODEL = "scb10x/typhoon-ocr1.5-2b"
    # NOTE(review): generic placeholder instruction. The original code referenced
    # an undefined `prompt`; confirm the recommended prompt for this model and
    # substitute it here.
    DEFAULT_PROMPT = "Extract all text from this image."
    # Generation budget when a request does not override it (the value that was
    # previously hard-coded).
    DEFAULT_MAX_NEW_TOKENS = 10000

    def __init__(self, model_dir='scb10x/typhoon-ocr1.5-2b'):
        """Load the image-text-to-text model and its processor.

        Args:
            model_dir: Hub repo id or local directory to load from. A falsy
                value (e.g. an empty path) falls back to ``DEFAULT_MODEL``.
        """
        # from_pretrained("") cannot resolve anything, so use the known repo id.
        model_path = model_dir or self.DEFAULT_MODEL

        self.model = AutoModelForImageTextToText.from_pretrained(model_path, dtype="auto", device_map="auto")
        # The processor is used by __call__ for chat templating (image + text)
        # and for decoding generated token ids.
        self.processor = AutoProcessor.from_pretrained(model_path)

    @staticmethod
    def _extract_base64(data):
        """Return the base64 payload carried by a request, or None if absent."""
        # A bare string is the payload itself. This must be checked first:
        # on a str, `"base64" in data` is a substring test that matches inside
        # a data URL and would then fail on `data["base64"]`.
        if isinstance(data, str):
            return data
        if not isinstance(data, dict):
            return None
        inputs = data.get("inputs")
        if isinstance(inputs, str):
            return inputs
        if isinstance(inputs, dict):
            return inputs.get("base64")
        return data.get("base64")

    def __call__(self, data: Union[Dict[str, Any], str]) -> Union[str, Dict[str, str]]:
        """Run OCR on one base64-encoded image.

        Accepted request shapes, checked in this order:
          1. a bare string holding the base64 payload;
          2. ``{"inputs": "<base64>"}``;
          3. ``{"inputs": {"base64": "<base64>"}}``;
          4. ``{"base64": "<base64>"}``.
        A ``data:<mime>;base64,`` URL prefix is stripped if present. A dict
        request may also include ``{"parameters": {"prompt": str,
        "max_new_tokens": int}}`` to override the defaults.

        Returns:
            The generated text on success, or ``{"error": message}`` if no
            payload is found or any step (decoding, inference) fails.
        """
        try:
            base64_string = self._extract_base64(data)
            if not base64_string:
                if isinstance(data, dict):
                    detail = "Available keys: " + str(data.keys())
                else:
                    detail = "Received type: " + type(data).__name__
                return {"error": "No base64 string found in input data. " + detail}

            print("Found base64 string, length:", len(base64_string))

            # Strip a data-URL header such as "data:image/png;base64,".
            # Base64 text never contains ",", so a single split is enough.
            if "," in base64_string:
                base64_string = base64_string.split(",", 1)[1]

            # Decode to a PIL image. The chat template expects an image object
            # rather than raw bytes. convert() forces the lazy load, so corrupt
            # data fails here, and normalizes the mode to RGB.
            image_bytes = base64.b64decode(base64_string)
            image = Image.open(BytesIO(image_bytes)).convert("RGB")

            # Optional per-request overrides.
            params = data.get("parameters") if isinstance(data, dict) else None
            if not isinstance(params, dict):
                params = {}
            prompt = params.get("prompt") or self.DEFAULT_PROMPT
            max_new_tokens = int(params.get("max_new_tokens", self.DEFAULT_MAX_NEW_TOKENS))

            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image},
                        {"type": "text", "text": prompt},
                    ],
                }
            ]

            # Tokenize the conversation (including image preprocessing) and move
            # the tensors to the model's device.
            inputs = self.processor.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_dict=True,
                return_tensors="pt",
            )
            inputs = inputs.to(self.model.device)

            generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
            # The generated sequences include the prompt tokens; keep only the
            # newly generated continuation of each row.
            generated_ids_trimmed = [
                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]
            output_text = self.processor.batch_decode(
                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )
            return output_text[0]

        except Exception as e:
            # Request boundary: report the failure to the client in the same
            # {"error": ...} shape as the missing-payload case, so errors are
            # distinguishable from OCR text.
            print(f"Error processing image: {e}")
            return {"error": str(e)}