| from typing import Any, Dict, List |
|
|
| from colbert.infra import ColBERTConfig |
| from colbert.modeling.checkpoint import Checkpoint |
| import torch |
| import logging |
|
|
# Module-level logger named after this module, per standard logging convention.
logger = logging.getLogger(__name__)


# Hugging Face Hub repo id of the ColBERT-XM checkpoint this handler serves.
MODEL = "fdurant/colbert-xm-for-inference-api"
|
|
class EndpointHandler():
    """Hugging Face Inference Endpoints handler around a ColBERT-XM checkpoint.

    A single input string is encoded as a *query*; a list of two or more
    strings is encoded as a batch of *document chunks*. The distinction
    matters because ColBERT uses different encoders for queries and documents
    (``queryFromText`` vs ``docFromText``).
    """

    def __init__(self, path=""):
        # `path` is part of the Inference Endpoints handler contract but is
        # unused here: the checkpoint is always pulled from the pinned MODEL
        # repo via the ColBERT config.
        self._config = ColBERTConfig(
            doc_maxlen=512,    # maximum document tokens per chunk
            nbits=2,           # residual compression bits (indexing-time knob)
            kmeans_niters=4,
            nranks=-1,         # NOTE(review): presumably "use all available ranks" — confirm
            checkpoint=MODEL,
        )
        self._checkpoint = Checkpoint(
            self._config.checkpoint, colbert_config=self._config, verbose=3
        )

    def __call__(self, data: Any) -> List[Dict[str, Any]]:
        """Embed the text(s) found under ``data["inputs"]``.

        Args:
            data: request payload; ``data["inputs"]`` must be a string
                (treated as a query) or a non-empty list of strings
                (treated as document chunks).

        Returns:
            One dict per input text:
            - query: ``{"input": <str>, "query_embedding": <2-D list>}``
            - chunk: ``{"input": <str>, "chunk_embedding": <2-D list>,
              "token_count": <list>}``

        Raises:
            ValueError: if ``inputs`` is neither a string nor a list of
                strings, or if the list is empty.
            KeyError: if ``data`` has no ``"inputs"`` key.
        """
        inputs = data["inputs"]
        if isinstance(inputs, str):
            texts = [inputs]
        elif isinstance(inputs, list) and all(isinstance(text, str) for text in inputs):
            texts = inputs
        else:
            raise ValueError("Invalid input data format")

        # Disable autograd bookkeeping for cheaper inference.
        with torch.inference_mode():
            if len(texts) == 1:
                return self._embed_query(texts)
            elif len(texts) > 1:
                return self._embed_chunks(texts)
            else:
                raise ValueError("No data to process")

    def _embed_query(self, texts: List[str]) -> List[Dict[str, Any]]:
        """Encode a single query string with ColBERT's query encoder."""
        # Lazy %-style logging args avoid f-string formatting work when the
        # log level is disabled.
        logger.info("Query: %s", texts)
        embedding = self._checkpoint.queryFromText(
            queries=texts,
            full_length_search=False,
        )
        logger.info("Query embedding shape: %s", embedding.shape)
        # Fix: previously this returned the raw `inputs` object, so a
        # one-element list produced {"input": ["..."]} while the batch branch
        # produced {"input": "..."} per item. Return the string itself so the
        # output schema is consistent across branches.
        return [{"input": texts[0], "query_embedding": embedding.tolist()[0]}]

    def _embed_chunks(self, texts: List[str]) -> List[Dict[str, Any]]:
        """Encode a batch of document chunks with ColBERT's document encoder."""
        logger.info("Batch of chunks: %s", texts)
        embeddings, token_counts = self._checkpoint.docFromText(
            docs=texts,
            bsize=self._config.bsize,
            keep_dims=True,       # pad each chunk's embedding to uniform dims
            return_tokens=True,
        )
        results = []
        for text, embedding, token_count in zip(texts, embeddings, token_counts):
            logger.info("Chunk: %s", text)
            logger.info("Chunk embedding shape: %s", embedding.shape)
            logger.info("Chunk count: %s", token_count)
            results.append(
                {
                    "input": text,
                    "chunk_embedding": embedding.tolist(),
                    "token_count": token_count.tolist(),
                }
            )
        return results
|
|