| # handler.py | |
| import os | |
| import onnxruntime as ort | |
| import numpy as np | |
| from transformers import AutoTokenizer | |
| from typing import Dict, List, Any | |
| from colbert_configuration import ColBERTConfig # Import ColBERTConfig | |
| # Assuming modeling.py and colbert_configuration.py are in the same directory | |
| # We'll use local imports since this handler will run within the model's directory context | |
| # For ConstBERT to be recognized, you need to ensure these are importable. | |
| # If you run into issues, consider a custom Docker image or ensuring the model | |
| # is loadable via AutoModel.from_pretrained if it has auto_map in config.json | |
| # For simplicity, we're relying on ConstBERT.from_pretrained working with ONNXRuntime path. | |
| # Note: The EndpointHandler class must be named exactly this. | |
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for an ONNX-exported multi-vector
    (ColBERT-style) text encoder.

    Loads the tokenizer and ``model.onnx`` from the model repository directory
    and returns one fixed-size multi-vector embedding per input text.
    The class must be named exactly ``EndpointHandler`` for HF Endpoints.
    """

    def __init__(self, path=""):
        """Load the tokenizer and the ONNX inference session.

        Args:
            path: Directory holding the tokenizer files and ``model.onnx``
                (on HF Endpoints this is ``/repository``).
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        print(f"Tokenizer loaded from: {path}")

        # Use the doc_maxlen the ONNX model was *actually exported with* (250)
        # so the handler's tokenization matches the model's expectation.
        # NOTE(review): if other colbert_config parameters are ever needed,
        # load ColBERTConfig.load_from_checkpoint(path) here instead of
        # hardcoding — kept fixed to avoid export/runtime mismatches.
        self.doc_max_length = 250
        print(f"Hardcoded doc_maxlen for tokenizer as: {self.doc_max_length}")

        onnx_model_path = os.path.join(path, "model.onnx")
        self.session = ort.InferenceSession(onnx_model_path)
        print(f"ONNX model loaded from: {onnx_model_path}")

        # Record the graph's declared input names; __call__ builds its feed
        # dict from these so we never pass a tensor the model lacks (e.g.
        # token_type_ids) nor silently hardcode a subset of what it needs.
        # (Renamed loop var: `input` shadows the builtin.)
        self.input_names = [inp.name for inp in self.session.get_inputs()]
        print(f"ONNX input names: {self.input_names}")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Inference call for the endpoint.

        Args:
            data: The request payload. Expected to contain "inputs"
                (str or list of str).

        Returns:
            List[Dict[str, Any]]: One dict per input text, each holding the
            raw multi-vector output for that input.
            Example: [{"embedding": [[...], [...], ...]}, ...]

        Raises:
            ValueError: If "inputs" is missing, or is not a string / list
                of strings.
        """
        # `get`, not `pop`: do not mutate the caller's payload dict.
        inputs = data.get("inputs")
        if inputs is None:
            raise ValueError("No 'inputs' found in the request payload.")

        # Normalize a single string to a one-element batch, then validate.
        if isinstance(inputs, str):
            inputs = [inputs]
        if not isinstance(inputs, list) or not all(isinstance(t, str) for t in inputs):
            raise ValueError("'inputs' must be a string or a list of strings.")

        # Tokenize with consistent padding/truncation to doc_max_length so
        # every request produces the sequence length the model was exported for.
        tokenized = self.tokenizer(
            inputs,
            padding="max_length",
            truncation=True,
            max_length=self.doc_max_length,
            return_tensors="np",
        )

        # Feed exactly what the graph declares (fixes the hardcoded
        # input_ids/attention_mask pair). Cast to int64: ONNX integer inputs
        # are typically exported as int64, while numpy's default int is
        # platform-dependent (int32 on Windows).
        onnx_inputs = {
            name: tokenized[name].astype(np.int64)
            for name in self.input_names
            if name in tokenized
        }

        outputs = self.session.run(None, onnx_inputs)
        # First graph output is the multi-vector embedding,
        # presumably (batch, num_vectors, dim) — confirm against the export.
        multi_vector_embeddings = outputs[0]

        # Convert to JSON-serializable nested lists, one entry per input text
        # (handles client-side batching, not just batch_size == 1).
        return [
            {"embedding": multi_vector_embeddings[i].tolist()}
            for i in range(multi_vector_embeddings.shape[0])
        ]