File size: 5,385 Bytes
e031746 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import os
import json
import torch
import base64
from io import BytesIO
from typing import List, Dict, Any, Union
from PIL import Image
from transformers import AutoProcessor
from custom_st import Transformer
class ModelHandler:
    """
    Custom TorchServe handler for an embedding model, built on the
    Transformer wrapper from custom_st.py.

    Request lifecycle: handle() -> preprocess() -> inference() -> postprocess().
    """
    def __init__(self):
        self.initialized = False
        self.model = None       # custom_st.Transformer instance, set in initialize()
        self.processor = None   # processor exposed by the Transformer wrapper
        self.device = None      # torch.device chosen in initialize()
        self.default_task = "retrieval"  # overridable via config.json "default_task"
        self.max_seq_length = 8192       # overridable via config.json "max_seq_length"

    def initialize(self, context):
        """
        Load the model from the TorchServe model directory.

        Reads optional overrides (default_task, max_seq_length) from a
        config.json in the model directory when present, then moves the
        underlying model to GPU if one is available and sets eval mode.
        """
        self.initialized = True
        # Get model directory from TorchServe system properties
        properties = context.system_properties
        model_dir = properties.get("model_dir")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Optional per-deployment overrides
        config_path = os.path.join(model_dir, "config.json")
        if os.path.exists(config_path):
            with open(config_path, 'r') as f:
                config = json.load(f)
            self.default_task = config.get("default_task", self.default_task)
            self.max_seq_length = config.get("max_seq_length", self.max_seq_length)
        # Initialize model
        self.model = Transformer(
            model_name_or_path=model_dir,
            max_seq_length=self.max_seq_length,
            model_args={"default_task": self.default_task}
        )
        self.model.model.to(self.device)
        self.model.model.eval()
        # Get processor from the model
        self.processor = self.model.processor

    def preprocess(self, data):
        """
        Convert the raw request batch into tokenized model features.

        Accepted body shapes per request row:
          - {"inputs": <str | list>}   one or more text/mixed inputs
          - {"text": <str>}            a single text input
          - {"image": <data-URL | URL | path>}  an image input

        Returns the feature dict produced by the model's tokenize(), or []
        when the batch contains no usable inputs.
        """
        inputs = []
        for row in data:
            body = row.get("body", {})
            # Bodies may arrive as raw bytes or as a JSON string
            if isinstance(body, (bytes, bytearray)):
                body = json.loads(body.decode('utf-8'))
            elif isinstance(body, str):
                body = json.loads(body)
            if not body:
                # FIX: an empty body previously triggered an early `return []`
                # inside the loop, aborting the whole batch and dropping inputs
                # already collected from earlier rows; skip just this row.
                continue
            # Handle different input formats
            if "inputs" in body:
                raw_inputs = body["inputs"]
                if isinstance(raw_inputs, str):
                    inputs.append(raw_inputs)
                elif isinstance(raw_inputs, list):
                    inputs.extend(raw_inputs)
            elif "text" in body:
                inputs.append(body["text"])
            elif "image" in body:
                # Handle base64 encoded images
                image_data = body["image"]
                if isinstance(image_data, str) and image_data.startswith("data:image"):
                    # Extract base64 payload from the data URL
                    image_data = image_data.split(",")[1]
                    image = Image.open(BytesIO(base64.b64decode(image_data))).convert("RGB")
                    inputs.append(image)
                else:
                    inputs.append(image_data)  # URL or file path; resolved downstream
        # Use the model's tokenize method to process inputs
        if inputs:
            return self.model.tokenize(inputs)
        return []

    def inference(self, features):
        """
        Run a forward pass over tokenized features.

        Returns {"embeddings": <nested list>} on success, or
        {"error": <message>} when the model produced no embeddings.
        """
        if not features:
            return {"embeddings": []}
        # Move all tensor features to the model's device
        for key, value in features.items():
            if isinstance(value, torch.Tensor):
                features[key] = value.to(self.device)
        with torch.no_grad():
            outputs = self.model.forward(features, task=self.default_task)
        embeddings = outputs.get("sentence_embedding", None)
        if embeddings is None:
            return {"error": "No embeddings were generated"}
        # Convert to plain Python lists for JSON serialization
        return {"embeddings": embeddings.cpu().numpy().tolist()}

    def postprocess(self, inference_output):
        """
        Wrap the inference result in the list TorchServe expects.

        NOTE(review): TorchServe expects one response element per request in
        the batch; this handler returns a single element regardless of batch
        size — confirm the deployment uses batch size 1.
        """
        return [inference_output]

    def handle(self, data, context):
        """
        Main handler entry point: lazily initialize, then run the
        preprocess -> inference -> postprocess pipeline.
        """
        if not self.initialized:
            self.initialize(context)
        if not data:
            # FIX: this path previously returned a bare dict while every other
            # path returns a list; keep the response shape consistent.
            return [{"embeddings": []}]
        try:
            processed_data = self.preprocess(data)
            if not processed_data:
                return [{"embeddings": []}]
            inference_result = self.inference(processed_data)
            return self.postprocess(inference_result)
        except Exception as e:
            # FIX: chain the original exception so the traceback is preserved
            raise Exception(f"Error processing request: {str(e)}") from e
# Singleton handler instance shared across all requests served by this worker.
_service = ModelHandler()


def handle(data, context):
    """Module-level entry point invoked by TorchServe.

    Delegates the request batch to the shared ModelHandler instance.
    """
    return _service.handle(data, context)
|