import base64
import json
import os
from io import BytesIO
from typing import List, Dict, Any, Union

import torch
from PIL import Image
from transformers import AutoProcessor

from custom_st import Transformer


class ModelHandler:
    """
    Custom handler for the embedding model using the Transformer class from custom_st.py
    """

    def __init__(self):
        self.initialized = False
        self.model = None
        self.processor = None
        self.device = None
        self.default_task = "retrieval"  # Default task, can be overridden in initialize
        self.max_seq_length = 8192  # Default max sequence length

    def initialize(self, context):
        """
        Load the model and processor from the model directory.

        Reads ``model_dir`` from ``context.system_properties``. If a
        ``config.json`` exists in that directory, its ``default_task`` and
        ``max_seq_length`` entries (when present) override the defaults set
        in ``__init__``.

        ``self.initialized`` is only set once everything has loaded, so a
        failed load is retried on the next call to ``handle`` instead of
        leaving a half-initialized handler behind.
        """
        properties = context.system_properties
        model_dir = properties.get("model_dir")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Optional per-model overrides.
        config_path = os.path.join(model_dir, "config.json")
        if os.path.exists(config_path):
            with open(config_path, "r", encoding="utf-8") as f:
                config = json.load(f)
            self.default_task = config.get("default_task", self.default_task)
            self.max_seq_length = config.get("max_seq_length", self.max_seq_length)

        self.model = Transformer(
            model_name_or_path=model_dir,
            max_seq_length=self.max_seq_length,
            model_args={"default_task": self.default_task},
        )
        self.model.model.to(self.device)
        self.model.model.eval()

        # The processor is owned by the Transformer wrapper; keep a shortcut.
        self.processor = self.model.processor

        # Mark ready last: everything above may raise.
        self.initialized = True

    @staticmethod
    def _parse_body(row):
        """Return the request body of *row*, JSON-decoded if it is bytes or str.

        A missing, ``None`` or empty body yields ``{}`` rather than being
        passed to ``json.loads``.
        """
        body = row.get("body") or {}
        if isinstance(body, (bytes, bytearray)):
            return json.loads(body.decode("utf-8"))
        if isinstance(body, str):
            return json.loads(body)
        return body

    @staticmethod
    def _load_image(image_data):
        """Decode a ``data:image`` URL into an RGB PIL image.

        Any other value is returned unchanged (the original code treats it
        as a URL or file path to be resolved downstream).
        """
        if isinstance(image_data, str) and image_data.startswith("data:image"):
            # Split once: everything after the first comma is the payload.
            encoded = image_data.split(",", 1)[1]
            return Image.open(BytesIO(base64.b64decode(encoded))).convert("RGB")
        return image_data

    def preprocess(self, data):
        """
        Collect the inputs of every request in *data* and tokenize them.

        Each row's body may carry, in order of precedence:

        * ``"inputs"`` - a single string or a list of inputs,
        * ``"text"``   - a single input,
        * ``"image"``  - a base64 ``data:image`` URL, or a value passed
          through unchanged.

        Rows with an empty body, or with none of these keys, contribute
        nothing; they no longer discard inputs gathered from other rows.

        Returns:
            The result of ``self.model.tokenize(inputs)``, or ``[]`` when no
            inputs were found.
        """
        inputs = []
        for row in data:
            body = self._parse_body(row)
            if not body:
                # Empty request: skip it, keep processing the rest of the batch.
                continue

            if "inputs" in body:
                raw_inputs = body["inputs"]
                if isinstance(raw_inputs, str):
                    inputs.append(raw_inputs)
                elif isinstance(raw_inputs, list):
                    inputs.extend(raw_inputs)
            elif "text" in body:
                inputs.append(body["text"])
            elif "image" in body:
                inputs.append(self._load_image(body["image"]))

        if not inputs:
            return []
        return self.model.tokenize(inputs)

    def inference(self, features):
        """
        Run the model on tokenized *features* and return embeddings.

        Tensor values in *features* are moved to ``self.device`` in place
        before the forward pass, which runs under ``torch.no_grad()``.

        Returns:
            ``{"embeddings": [...]}`` with the ``"sentence_embedding"`` output
            converted to nested lists, ``{"embeddings": []}`` when *features*
            is empty, or ``{"error": ...}`` when the model output has no
            ``"sentence_embedding"`` entry.
        """
        if not features:
            return {"embeddings": []}

        # Re-assigning existing keys while iterating is safe (dict size is unchanged).
        for key, value in features.items():
            if isinstance(value, torch.Tensor):
                features[key] = value.to(self.device)

        with torch.no_grad():
            outputs = self.model.forward(features, task=self.default_task)

        embeddings = outputs.get("sentence_embedding", None)
        if embeddings is None:
            return {"error": "No embeddings were generated"}
        # Convert to plain lists for JSON serialization.
        return {"embeddings": embeddings.cpu().numpy().tolist()}

    def postprocess(self, inference_output):
        """
        Wrap the inference result in a single-element list for the response.
        """
        # NOTE(review): this always yields one response item, even when several
        # requests were batched into one call - confirm the server is configured
        # with batch size 1, or split embeddings per request.
        return [inference_output]

    def handle(self, data, context):
        """
        Main handler function: initialize lazily, then preprocess, run
        inference and postprocess.

        Always returns a list, so the empty-request path has the same shape
        as every other response.

        Raises:
            Exception: wrapping any error raised while processing the
                request; the original error is attached as ``__cause__``.
        """
        if not self.initialized:
            self.initialize(context)

        if not data:
            return [{"embeddings": []}]

        try:
            processed_data = self.preprocess(data)
            if not processed_data:
                return [{"embeddings": []}]
            inference_result = self.inference(processed_data)
            return self.postprocess(inference_result)
        except Exception as e:
            raise Exception(f"Error processing request: {str(e)}") from e


# Define the handler for torchserve
_service = ModelHandler()


def handle(data, context):
    """
    Torchserve handler function
    """
    return _service.handle(data, context)