| import os |
| import time |
| from typing import Any, Dict, Tuple |
| import asyncio |
| import numpy as np |
| import torch |
| from dotenv import load_dotenv |
| from vespa.application import Vespa |
| from vespa.io import VespaQueryResponse |
| from .colpali import SimMapGenerator |
| import backend.stopwords |
| import logging |
|
|
|
|
class VespaQueryClient:
    """Async client for querying a Vespa application serving the ``pdf_page`` schema.

    Supports BM25, ColPali (binary nearest-neighbor + MaxSim) and hybrid
    ranking, similarity-map retrieval, full-image fetch, query suggestions,
    and a keepalive ping. Connection credentials are read from environment
    variables (optionally via a ``.env`` file), using either mTLS or a
    Vespa Cloud secret token.
    """

    # Upper bound on the number of per-token query tensors sent to Vespa.
    MAX_QUERY_TERMS = 64
    VESPA_SCHEMA_NAME = "pdf_page"
    # Fields selected for regular (non-sim-map) queries.
    SELECT_FIELDS = "id,title,url,blur_image,page_number,snippet,text"

    def __init__(self, logger: logging.Logger):
        """
        Initialize the VespaQueryClient by loading environment variables and
        establishing a connection to the Vespa application.

        Args:
            logger (logging.Logger): Logger used for all client diagnostics.

        Raises:
            ValueError: If a required environment variable for the selected
                authentication mode (mTLS or token) is missing.
        """
        load_dotenv()
        self.logger = logger

        if os.environ.get("USE_MTLS") == "true":
            self.logger.info("Connected using mTLS")
            mtls_key = os.environ.get("VESPA_CLOUD_MTLS_KEY")
            mtls_cert = os.environ.get("VESPA_CLOUD_MTLS_CERT")

            self.vespa_app_url = os.environ.get("VESPA_APP_MTLS_URL")
            if not self.vespa_app_url:
                raise ValueError(
                    "Please set the VESPA_APP_MTLS_URL environment variable"
                )

            if not mtls_cert or not mtls_key:
                raise ValueError(
                    "USE_MTLS was true, but VESPA_CLOUD_MTLS_KEY and VESPA_CLOUD_MTLS_CERT were not set"
                )

            # pyvespa expects key/cert as file paths, so persist the PEM
            # material from the environment to disk first.
            mtls_key_path = "/tmp/vespa-data-plane-private-key.pem"
            with open(mtls_key_path, "w") as f:
                f.write(mtls_key)

            mtls_cert_path = "/tmp/vespa-data-plane-public-cert.pem"
            with open(mtls_cert_path, "w") as f:
                f.write(mtls_cert)

            self.app = Vespa(
                url=self.vespa_app_url, key=mtls_key_path, cert=mtls_cert_path
            )
        else:
            self.logger.info("Connected using token")
            self.vespa_app_url = os.environ.get("VESPA_APP_TOKEN_URL")
            if not self.vespa_app_url:
                raise ValueError(
                    "Please set the VESPA_APP_TOKEN_URL environment variable"
                )

            self.vespa_cloud_secret_token = os.environ.get("VESPA_CLOUD_SECRET_TOKEN")

            if not self.vespa_cloud_secret_token:
                raise ValueError(
                    "Please set the VESPA_CLOUD_SECRET_TOKEN environment variable"
                )

            self.app = Vespa(
                url=self.vespa_app_url,
                vespa_cloud_secret_token=self.vespa_cloud_secret_token,
            )

        # Block until the application answers; raises on failure.
        self.app.wait_for_application_up()
        self.logger.info(f"Connected to Vespa at {self.vespa_app_url}")

    def get_fields(self, sim_map: bool = False) -> str:
        """Return the YQL select-field list.

        Sim-map queries only need the summary features; everything else
        selects the regular document fields.
        """
        if not sim_map:
            return self.SELECT_FIELDS
        else:
            return "summaryfeatures"

    def format_query_results(
        self, query: str, response: VespaQueryResponse, hits: int = 5
    ) -> dict:
        """
        Log a one-line summary of a Vespa response and return its JSON body.

        Args:
            query (str): The query text.
            response (VespaQueryResponse): The response from Vespa.
            hits (int, optional): Unused; kept for interface compatibility.
                Defaults to 5.

        Returns:
            dict: The JSON content of the response.
        """
        # searchtime is Vespa-reported; -1 marks a response without timing.
        query_time = response.json.get("timing", {}).get("searchtime", -1)
        query_time = round(query_time, 2)
        count = response.json.get("root", {}).get("fields", {}).get("totalCount", 0)
        result_text = f"Query text: '{query}', query time {query_time}s, count={count}, top results:\n"
        self.logger.debug(result_text)
        return response.json

    async def query_vespa_bm25(
        self,
        query: str,
        q_emb: torch.Tensor,
        hits: int = 3,
        timeout: str = "10s",
        sim_map: bool = False,
        **kwargs,
    ) -> dict:
        """
        Query Vespa using the BM25 ranking profile.
        This corresponds to the "BM25" radio button in the UI.

        Args:
            query (str): The query text.
            q_emb (torch.Tensor): Query embeddings (one row per query token).
            hits (int, optional): Number of hits to retrieve. Defaults to 3.
            timeout (str, optional): Query timeout. Defaults to "10s".
            sim_map (bool, optional): If True, select summary features and use
                the "_sim" rank-profile variant. Defaults to False.
            **kwargs: Extra raw query-API parameters passed through to Vespa.

        Returns:
            dict: The formatted query results (response JSON).
        """
        async with self.app.asyncio(connections=1) as session:
            query_embedding = self.format_q_embs(q_emb)

            start = time.perf_counter()
            response: VespaQueryResponse = await session.query(
                body={
                    "yql": (
                        f"select {self.get_fields(sim_map=sim_map)} from {self.VESPA_SCHEMA_NAME} where userQuery();"
                    ),
                    "ranking": self.get_rank_profile("bm25", sim_map),
                    "query": query,
                    "timeout": timeout,
                    "hits": hits,
                    "input.query(qt)": query_embedding,
                    "presentation.timing": True,
                    **kwargs,
                },
            )
            assert response.is_successful(), response.json
            stop = time.perf_counter()
            self.logger.debug(
                f"Query time + data transfer took: {stop - start} s, Vespa reported searchtime was "
                f"{response.json.get('timing', {}).get('searchtime', -1)} s"
            )
        return self.format_query_results(query, response)

    def float_to_binary_embedding(self, float_query_embedding: dict) -> dict:
        """
        Convert float query embeddings to packed binary embeddings.

        Each vector is binarized by sign (value > 0 -> 1, else 0) and packed
        8 bits per int8 with numpy. At most MAX_QUERY_TERMS entries are kept.

        Args:
            float_query_embedding (dict): Mapping of token index -> float vector.

        Returns:
            dict: Mapping of token index -> packed int8 list.
        """
        binary_query_embeddings = {}
        for key, vector in float_query_embedding.items():
            binary_vector = (
                np.packbits(np.where(np.array(vector) > 0, 1, 0))
                .astype(np.int8)
                .tolist()
            )
            binary_query_embeddings[key] = binary_vector
            # Stop once the cap is reached; Vespa rank profiles only declare
            # MAX_QUERY_TERMS query tensors.
            if len(binary_query_embeddings) >= self.MAX_QUERY_TERMS:
                self.logger.warning(
                    f"Warning: Query has more than {self.MAX_QUERY_TERMS} terms. Truncating."
                )
                break
        return binary_query_embeddings

    def create_nn_query_strings(
        self, binary_query_embeddings: dict, target_hits_per_query_tensor: int = 20
    ) -> Tuple[str, dict]:
        """
        Create nearest neighbor query strings for Vespa.

        Builds one nearestNeighbor clause per query-token tensor, OR-ed
        together, plus the matching input.query(rq{i}) tensor parameters.

        Args:
            binary_query_embeddings (dict): Binary query embeddings, keyed by
                consecutive integer token indices starting at 0.
            target_hits_per_query_tensor (int, optional): Target hits per query
                tensor. Defaults to 20.

        Returns:
            Tuple[str, dict]: Nearest neighbor query string and query tensor dictionary.
        """
        nn_query_dict = {}
        for i in range(len(binary_query_embeddings)):
            nn_query_dict[f"input.query(rq{i})"] = binary_query_embeddings[i]
        nn = " OR ".join(
            [
                f"({{targetHits:{target_hits_per_query_tensor}}}nearestNeighbor(embedding,rq{i}))"
                for i in range(len(binary_query_embeddings))
            ]
        )
        return nn, nn_query_dict

    def format_q_embs(self, q_embs: torch.Tensor) -> dict:
        """
        Convert query embeddings to a dictionary of lists.

        Args:
            q_embs (torch.Tensor): Query embeddings tensor (rows are tokens).

        Returns:
            dict: Dictionary where each key is a row index and value is the embedding list.
        """
        return {idx: emb.tolist() for idx, emb in enumerate(q_embs)}

    async def get_result_from_query(
        self,
        query: str,
        q_embs: torch.Tensor,
        ranking: str,
        idx_to_token: dict,
    ) -> Dict[str, Any]:
        """
        Get query results from Vespa based on the ranking method.

        The ranking string selects the retrieval method ("colpali", "hybrid",
        or "bm25") and an optional "_sim" suffix requests similarity maps.
        The returned dict is guaranteed to have result["root"]["children"].

        Args:
            query (str): The query text.
            q_embs (torch.Tensor): Query embeddings.
            ranking (str): The ranking method to use, e.g. "colpali_sim".
            idx_to_token (dict): Index to token mapping (currently unused here;
                kept for interface compatibility).

        Returns:
            Dict[str, Any]: The query results.

        Raises:
            ValueError: If the ranking method is not supported.
        """
        # Strip stopwords before sending the query to Vespa.
        query = backend.stopwords.filter(query)

        rank_method = ranking.split("_")[0]
        sim_map: bool = len(ranking.split("_")) > 1 and ranking.split("_")[1] == "sim"
        if rank_method in ("colpali", "hybrid"):
            # Both use nearest-neighbor retrieval; only the rank profile differs.
            result = await self.query_vespa_colpali(
                query=query, ranking=rank_method, q_emb=q_embs, sim_map=sim_map
            )
        elif rank_method == "bm25":
            result = await self.query_vespa_bm25(query, q_embs, sim_map=sim_map)
        else:
            raise ValueError(f"Unsupported ranking: {rank_method}")
        # Normalize an empty response: ensure root.children exists, preserving
        # any other fields already present under "root".
        root = result.setdefault("root", {})
        if "children" not in root:
            root["children"] = []
            return result
        for single_result in result["root"]["children"]:
            self.logger.debug(single_result["fields"].keys())
        return result

    def get_sim_maps_from_query(
        self, query: str, q_embs: torch.Tensor, ranking: str, idx_to_token: dict
    ):
        """
        Get similarity maps from Vespa based on the ranking method.

        Runs the async query to completion and extracts the "summaryfeatures"
        field from every hit.

        Args:
            query (str): The query text.
            q_embs (torch.Tensor): Query embeddings.
            ranking (str): The ranking method to use.
            idx_to_token (dict): Index to token mapping.

        Returns:
            list: One summaryfeatures dict per hit.

        Raises:
            ValueError: If a hit has no "summaryfeatures" field.
        """
        result = asyncio.run(
            self.get_result_from_query(query, q_embs, ranking, idx_to_token)
        )
        vespa_sim_maps = []
        for single_result in result["root"]["children"]:
            vespa_sim_map = single_result["fields"].get("summaryfeatures", None)
            if vespa_sim_map is not None:
                vespa_sim_maps.append(vespa_sim_map)
            else:
                raise ValueError("No sim_map found in Vespa response")
        return vespa_sim_maps

    async def get_full_image_from_vespa(self, doc_id: str) -> str:
        """
        Retrieve the full image from Vespa for a given document ID.

        Args:
            doc_id (str): The document ID.

        Returns:
            str: The full image data (as stored in the "full_image" field).

        Raises:
            IndexError/KeyError: If no document matches doc_id.
        """
        async with self.app.asyncio(connections=1) as session:
            start = time.perf_counter()
            # NOTE(security): doc_id is interpolated directly into YQL; callers
            # must ensure it is a trusted/validated identifier.
            response: VespaQueryResponse = await session.query(
                body={
                    "yql": f'select full_image from {self.VESPA_SCHEMA_NAME} where id contains "{doc_id}"',
                    "ranking": "unranked",
                    "presentation.timing": True,
                    "ranking.matching.numThreadsPerSearch": 1,
                },
            )
            assert response.is_successful(), response.json
            stop = time.perf_counter()
            self.logger.debug(
                f"Getting image from Vespa took: {stop - start} s, Vespa reported searchtime was "
                f"{response.json.get('timing', {}).get('searchtime', -1)} s"
            )
        return response.json["root"]["children"][0]["fields"]["full_image"]

    def get_results_children(self, result: dict) -> list:
        """Return the list of hits (root.children) from a response JSON dict."""
        return result["root"]["children"]

    def results_to_search_results(self, result: dict, idx_to_token: dict) -> list:
        """Attach placeholder sim-map fields to every hit and return the hits.

        One "sim_map_{token}_{idx}" key (initialized to None) is added per
        non-filtered query token, so downstream consumers see a stable schema.
        """
        fields_to_add = [
            f"sim_map_{token}_{idx}"
            for idx, token in idx_to_token.items()
            if not SimMapGenerator.should_filter_token(token)
        ]
        for child in result["root"]["children"]:
            for sim_map_key in fields_to_add:
                child["fields"][sim_map_key] = None
        return self.get_results_children(result)

    async def get_suggestions(self, query: str) -> list:
        """Return unique question suggestions whose text matches the query.

        Uses a regex "matches" clause over the "questions" field and the
        dedicated "suggestions" document summary.
        """
        async with self.app.asyncio(connections=1) as session:
            start = time.perf_counter()
            # NOTE(security): query is interpolated into a YQL regex; it should
            # be sanitized upstream before reaching this call.
            yql = f'select questions from {self.VESPA_SCHEMA_NAME} where questions matches (".*{query}.*")'
            response: VespaQueryResponse = await session.query(
                body={
                    "yql": yql,
                    "query": query,
                    "ranking": "unranked",
                    "presentation.timing": True,
                    "presentation.summary": "suggestions",
                    "ranking.matching.numThreadsPerSearch": 1,
                },
            )
            assert response.is_successful(), response.json
            stop = time.perf_counter()
            self.logger.debug(
                f"Getting suggestions from Vespa took: {stop - start} s, Vespa reported searchtime was "
                f"{response.json.get('timing', {}).get('searchtime', -1)} s"
            )
            search_results = (
                response.json["root"]["children"]
                if "root" in response.json and "children" in response.json["root"]
                else []
            )
            # Each hit carries a list of questions; flatten and deduplicate.
            questions = [
                result["fields"]["questions"]
                for result in search_results
                if "questions" in result["fields"]
            ]

            unique_questions = set([item for sublist in questions for item in sublist])

            # "string" is a placeholder value from ingestion; drop it.
            if "string" in unique_questions:
                unique_questions.remove("string")

            return list(unique_questions)

    def get_rank_profile(self, ranking: str, sim_map: bool) -> str:
        """Return the rank-profile name, with a "_sim" suffix for sim-map queries."""
        if sim_map:
            return f"{ranking}_sim"
        else:
            return ranking

    async def query_vespa_colpali(
        self,
        query: str,
        ranking: str,
        q_emb: torch.Tensor,
        target_hits_per_query_tensor: int = 100,
        hnsw_explore_additional_hits: int = 300,
        hits: int = 3,
        timeout: str = "10s",
        sim_map: bool = False,
        **kwargs,
    ) -> dict:
        """
        Query Vespa using nearest neighbor search with mixed tensors for MaxSim calculations.
        This corresponds to the "ColPali" radio button in the UI.

        Args:
            query (str): The query text.
            ranking (str): Base rank-profile name ("colpali" or "hybrid").
            q_emb (torch.Tensor): Query embeddings.
            target_hits_per_query_tensor (int, optional): Target hits per query
                tensor. Defaults to 100.
            hnsw_explore_additional_hits (int, optional): Extra HNSW candidates
                to explore. Defaults to 300.
            hits (int, optional): Number of hits to retrieve. Defaults to 3.
            timeout (str, optional): Query timeout. Defaults to "10s".
            sim_map (bool, optional): If True, select summary features and use
                the "_sim" rank-profile variant. Defaults to False.
            **kwargs: Extra raw query-API parameters passed through to Vespa.

        Returns:
            dict: The formatted query results.
        """
        async with self.app.asyncio(connections=1) as session:
            float_query_embedding = self.format_q_embs(q_emb)
            binary_query_embeddings = self.float_to_binary_embedding(
                float_query_embedding
            )

            # Binary tensors drive retrieval (Hamming phase); float tensors
            # drive the MaxSim re-ranking phase.
            query_tensors = {
                "input.query(qtb)": binary_query_embeddings,
                "input.query(qt)": float_query_embedding,
            }
            nn_string, nn_query_dict = self.create_nn_query_strings(
                binary_query_embeddings, target_hits_per_query_tensor
            )
            query_tensors.update(nn_query_dict)
            response: VespaQueryResponse = await session.query(
                body={
                    **query_tensors,
                    "presentation.timing": True,
                    "yql": (
                        f"select {self.get_fields(sim_map=sim_map)} from {self.VESPA_SCHEMA_NAME} where {nn_string} or userQuery()"
                    ),
                    "ranking.profile": self.get_rank_profile(
                        ranking=ranking, sim_map=sim_map
                    ),
                    "timeout": timeout,
                    "hits": hits,
                    "query": query,
                    "hnsw.exploreAdditionalHits": hnsw_explore_additional_hits,
                    "ranking.rerankCount": 100,
                    **kwargs,
                },
            )
            assert response.is_successful(), response.json
        return self.format_query_results(query, response)

    async def keepalive(self) -> bool:
        """
        Query Vespa to keep the connection alive.

        Returns:
            bool: True if the connection is alive.
        """
        async with self.app.asyncio(connections=1) as session:
            response: VespaQueryResponse = await session.query(
                body={
                    "yql": f"select title from {self.VESPA_SCHEMA_NAME} where true limit 1;",
                    "ranking": "unranked",
                    "query": "keepalive",
                    "timeout": "3s",
                    "hits": 1,
                },
            )
            assert response.is_successful(), response.json
        return True
|
|