File size: 2,900 Bytes
28a2884
 
 
99d1be1
513ab80
28a2884
4b05c12
e9beda9
 
856dd67
 
99d1be1
18c6d65
 
 
 
 
 
 
 
 
856dd67
513ab80
99d1be1
513ab80
 
99d1be1
18c6d65
99d1be1
 
513ab80
 
7634259
856dd67
18c6d65
 
513ab80
 
 
28a2884
513ab80
18c6d65
5456c8e
28a2884
 
513ab80
99d1be1
 
18c6d65
 
99d1be1
 
18c6d65
99d1be1
 
 
 
 
 
 
 
 
 
18c6d65
99d1be1
 
 
18c6d65
99d1be1
 
 
18c6d65
99d1be1
b7bae01
 
 
 
99d1be1
18c6d65
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer
from io import BytesIO
import base64

class EndpointHandler:
    """Inference endpoint handler for the MiniCPM-V-2.6 (int4) vision-language model.

    Accepts a request dict carrying a base64-encoded image plus a question and
    returns the model's answer as ``{"output": ...}`` (or ``{"error": ...}`` on
    failure). Model weights are downloaded and loaded eagerly at construction.
    """

    def __init__(self, model_dir=None):
        # NOTE(review): model_dir is accepted for compatibility with the hosting
        # framework but is currently ignored -- the repo id is hard-coded in
        # load_model(). Confirm whether loading from model_dir is intended.
        self.load_model()

    def load_model(self):
        """Load the tokenizer and model from the Hugging Face Hub.

        Sets ``self.tokenizer`` and ``self.model`` (in eval mode, fp16).

        Raises:
            Exception: re-raises whatever ``from_pretrained`` fails with,
                after logging it.
        """
        model_name = "openbmb/MiniCPM-V-2_6-int4"
        print(f"Loading model: {model_name}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True, torch_dtype=torch.float16)
            # Inference only: disable dropout / batch-norm updates.
            self.model.eval()
            print("Model loaded successfully.")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise

    def load_image(self, image_base64):
        """Decode a base64 string into an RGB PIL image.

        Args:
            image_base64: base64-encoded image bytes (any format PIL can open).

        Returns:
            A ``PIL.Image.Image`` converted to RGB.

        Raises:
            ValueError: if the string is not valid base64 or the bytes are not
                a decodable image.
        """
        try:
            print("Decoding base64 image...")
            image_bytes = base64.b64decode(image_base64)
            image = Image.open(BytesIO(image_bytes)).convert("RGB")
            print("Image loaded successfully.")
            return image
        except Exception as e:
            print(f"Failed to decode or open image: {e}")
            raise ValueError(f"Failed to open image from base64 string: {e}")

    def predict(self, request):
        """Run the model on a request and return ``{"output": ...}`` or ``{"error": ...}``.

        Expected payload shape: ``{"inputs": {"image": <base64 str>,
        "question": <str, optional>, "stream": <bool, optional>}}``.
        """
        print(f"Received request: {request}")

        # Read the payload once. ``or {}`` also guards against an explicit
        # {"inputs": null} payload, which previously raised an uncaught
        # AttributeError on the chained .get() calls.
        inputs = request.get("inputs") or {}
        image_base64 = inputs.get("image")
        question = inputs.get("question", "What is in the image?")
        stream = inputs.get("stream", False)

        if not image_base64:
            print("Missing image in the request.")
            return {"error": "Missing image."}

        try:
            image = self.load_image(image_base64)
            msgs = [{"role": "user", "content": [image, question]}]

            print(f"Processing prediction with question: {question}")

            # Inference only: skip autograd bookkeeping to save memory.
            with torch.no_grad():
                if stream:
                    # NOTE(review): the streaming path enables sampling at
                    # temperature 0.7 while the non-streaming path uses the
                    # model defaults -- confirm this asymmetry is intentional.
                    generated_text = ""
                    print("Starting stream prediction...")
                    res = self.model.chat(
                        image=None,
                        msgs=msgs,
                        tokenizer=self.tokenizer,
                        sampling=True,
                        temperature=0.7,
                        stream=True
                    )
                    # Streaming is collapsed into one response; tokens are
                    # accumulated rather than forwarded incrementally.
                    for new_text in res:
                        generated_text += new_text
                    print("Stream prediction completed.")
                    return {"output": generated_text}
                else:
                    output = self.model.chat(image=None, msgs=msgs, tokenizer=self.tokenizer)
                    print("Prediction completed.")
                    return {"output": output}

        except Exception as e:
            # Boundary handler: surface the failure to the client as JSON
            # instead of crashing the endpoint.
            print(f"Error during prediction: {e}")
            return {"error": str(e)}

    def __call__(self, data):
        """Framework entry point; delegates to :meth:`predict`."""
        return self.predict(data)

# Module-level singleton expected by the serving framework (it looks up a
# top-level ``handler``). NOTE: constructing it here downloads and loads the
# model at import time, so importing this module is slow and requires network.
handler = EndpointHandler()