import os
import json
import torch
import base64
from io import BytesIO
from typing import List, Dict, Any, Union
from PIL import Image
from transformers import AutoProcessor
from custom_st import Transformer

class ModelHandler:
    """TorchServe handler for the embedding model defined in custom_st.py.

    Loads the custom ``Transformer`` wrapper once per worker, converts
    incoming request rows (text, base64 data-URL images, or image URLs/paths)
    into model features, runs a no-grad forward pass, and serializes the
    resulting embeddings for the JSON response.
    """

    def __init__(self):
        self.initialized = False
        self.model = None
        self.processor = None
        self.device = None
        # Defaults; both may be overridden by config.json in initialize().
        self.default_task = "retrieval"
        self.max_seq_length = 8192

    def initialize(self, context):
        """Load the model and processor from the TorchServe model directory.

        Reads optional ``default_task`` / ``max_seq_length`` overrides from
        ``config.json`` when present. ``self.initialized`` is set only after
        loading succeeds, so a failed initialization is retried on the next
        request instead of being skipped forever.
        """
        properties = context.system_properties
        model_dir = properties.get("model_dir")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Optional overrides shipped alongside the model artifacts.
        config_path = os.path.join(model_dir, "config.json")
        if os.path.exists(config_path):
            with open(config_path, 'r') as f:
                config = json.load(f)
            self.default_task = config.get("default_task", self.default_task)
            self.max_seq_length = config.get("max_seq_length", self.max_seq_length)

        self.model = Transformer(
            model_name_or_path=model_dir,
            max_seq_length=self.max_seq_length,
            model_args={"default_task": self.default_task}
        )
        self.model.model.to(self.device)
        self.model.model.eval()

        # Reuse the processor the Transformer wrapper already constructed.
        self.processor = self.model.processor

        # Fix: mark initialized only after everything above succeeded
        # (previously set first, so a failed load was never retried).
        self.initialized = True

    def preprocess(self, data):
        """Convert a batch of TorchServe request rows into model features.

        Accepted body shapes: ``{"inputs": str | list}``, ``{"text": str}``,
        or ``{"image": <data-url | URL | path>}``. Rows with no recognized
        payload are skipped (previously an empty body aborted the whole
        batch by returning early from inside the loop).

        Returns the feature dict produced by ``self.model.tokenize``, or an
        empty list when no usable inputs were found.
        """
        inputs = []

        for row in data:
            body = row.get("body", {})
            # Bodies may arrive as raw bytes or as a JSON string.
            if isinstance(body, (bytes, bytearray)):
                body = json.loads(body.decode('utf-8'))
            elif isinstance(body, str):
                body = json.loads(body)

            if "inputs" in body:
                raw_inputs = body["inputs"]
                if isinstance(raw_inputs, str):
                    inputs.append(raw_inputs)
                elif isinstance(raw_inputs, list):
                    inputs.extend(raw_inputs)
            elif "text" in body:
                inputs.append(body["text"])
            elif "image" in body:
                image_data = body["image"]
                if isinstance(image_data, str) and image_data.startswith("data:image"):
                    # Decode the base64 payload of a data URL into a PIL image.
                    image_data = image_data.split(",")[1]
                    image = Image.open(BytesIO(base64.b64decode(image_data))).convert("RGB")
                    inputs.append(image)
                else:
                    # URL or file path; resolved downstream by the model.
                    inputs.append(image_data)
            # Fix: an empty or unrecognized body no longer short-circuits the
            # whole batch -- it is simply skipped.

        if inputs:
            return self.model.tokenize(inputs)
        return []

    def inference(self, features):
        """Run a no-grad forward pass and return embeddings as nested lists."""
        if not features:
            return {"embeddings": []}

        # Move every tensor feature onto the model's device; leave
        # non-tensor entries untouched.
        features = {
            key: value.to(self.device) if isinstance(value, torch.Tensor) else value
            for key, value in features.items()
        }

        with torch.no_grad():
            outputs = self.model.forward(features, task=self.default_task)

        embeddings = outputs.get("sentence_embedding", None)
        if embeddings is None:
            return {"error": "No embeddings were generated"}
        # Plain Python lists for JSON serialization.
        return {"embeddings": embeddings.cpu().numpy().tolist()}

    def postprocess(self, inference_output):
        """Wrap the inference result in the per-request list TorchServe expects."""
        return [inference_output]

    def handle(self, data, context):
        """Entry point: lazily initialize, then preprocess -> infer -> postprocess.

        Always returns a list (one element per TorchServe batch), matching
        the framework's response contract on every path.
        """
        if not self.initialized:
            self.initialize(context)

        if not data:
            # Fix: previously returned a bare dict here, inconsistent with
            # every other return path (TorchServe expects a list).
            return [{"embeddings": []}]

        try:
            processed_data = self.preprocess(data)
            if not processed_data:
                return [{"embeddings": []}]

            inference_result = self.inference(processed_data)
            return self.postprocess(inference_result)
        except Exception as e:
            # Chain the original exception so its traceback is preserved.
            raise RuntimeError(f"Error processing request: {str(e)}") from e

# Module-level singleton shared across requests by TorchServe's entry point.
_service = ModelHandler()


def handle(data, context):
    """TorchServe entry point; delegates to the shared ModelHandler instance."""
    return _service.handle(data, context)