File size: 2,845 Bytes
2bf2b8c
 
bffb0ab
0fbd23f
bffb0ab
 
2bf2b8c
0fbd23f
 
9053cea
0fbd23f
2bf2b8c
0fbd23f
 
2bf2b8c
 
 
 
 
 
 
 
 
0fbd23f
9053cea
2bf2b8c
9053cea
 
 
 
0fbd23f
2bf2b8c
9053cea
 
6e113e1
9053cea
2bf2b8c
 
 
 
 
 
 
 
9053cea
 
2bf2b8c
9053cea
 
2bf2b8c
9053cea
 
 
2bf2b8c
 
 
 
 
 
9053cea
2bf2b8c
 
 
9053cea
2bf2b8c
 
bffb0ab
2bf2b8c
 
bffb0ab
 
 
2bf2b8c
 
 
 
 
 
 
 
 
 
 
 
 
9053cea
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from typing import Dict, Any
from transformers import AutoModelForImageTextToText, AutoProcessor
from PIL import Image
import torch
import base64
import io
import os
from subprocess import run

# Pin Pillow at runtime. Use an argument list with the default shell=False
# instead of a shell string — no shell parsing, no injection surface.
# NOTE(review): PIL is imported above this line, so the pinned version only
# takes effect on the *next* process start; consider moving this pin into the
# endpoint's requirements instead.
run(["pip3", "install", "pillow==9.4.0"], check=True)

class EndpointHandler:
    def __init__(self, path=""):
        self.HF_READ_TOKEN = os.getenv("HF_READ_TOKEN")
        # Load model and processor with 4-bit quantization
        self.model = AutoModelForImageTextToText.from_pretrained(
            path,
            load_in_4bit=True,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        self.processor = AutoProcessor.from_pretrained(path)
        print("Model and processor loaded")

    def decode_image(self, image_input):
        """Convert base64 string or bytes to PIL Image"""
        if isinstance(image_input, str):
            image_data = base64.b64decode(image_input)
        elif isinstance(image_input, bytes):
            image_data = image_input
        else:
            raise ValueError("Image must be base64 string or bytes")
        return Image.open(io.BytesIO(image_data)).convert("RGB")

    def __call__(self, data)
        try:
            # Validate inputs
            if "image" not in data or "text" not in data:
                return {"error": "Both 'image' and 'text' are required"}
            
            # Process image
            image = self.decode_image(data["image"])
            
            # Prepare chat template
            messages = [
                {
                    "role": "user", 
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": data["text"]}
                    ]
                }
            ]
            input_text = self.processor.apply_chat_template(messages)
            
            # Process inputs
            inputs = self.processor(
                text=input_text,
                images=image,
                return_tensors="pt",
                padding=True,
                truncation=True
            ).to("cuda")

            # Generate response
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=256,
                temperature=0.7,
                min_p=0.1,
                do_sample=True
            )
            
            # Decode and clean output
            generated = self.processor.decode(
                outputs[0], 
                skip_special_tokens=True
            ).strip()
            
            # Remove prompt fragments
            if "assistant" in generated:
                generated = generated.split("assistant")[-1].strip()
                
            return {"generated_text": generated}
            
        except Exception as e:
            return {"error": str(e)}