from typing import Any
from collections.abc import Iterator

from elasticsearch import Elasticsearch

from ask_candid.base.retrieval.sparse_lexical import SpladeEncoder
from ask_candid.base.config.connections import BaseElasticAPIKeyCredential, BaseElasticSearchConnection

NEWS_TRUST_SCORE_THRESHOLD = 0.8
SPARSE_ENCODING_SCORE_THRESHOLD = 0.4


def build_sparse_vector_query(
    query: str,
    fields: tuple[str, ...],
    inference_id: str = ".elser-2-elasticsearch"
) -> dict[str, Any]:
    """Builds a valid Elasticsearch sparse vector query payload, searching over
    nested embedding chunks.

    Parameters
    ----------
    query : str
        Search context string
    fields : tuple[str, ...]
        Semantic text field names
    inference_id : str, optional
        ID of the model deployed in Elasticsearch, by default ".elser-2-elasticsearch"

    Returns
    -------
    dict[str, Any]
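
    Examples
    --------
    Illustrative sketch; the field names are hypothetical and assume an index
    whose chunk vectors are mapped under ``embeddings.<field>.chunks.vector``:

    >>> body = build_sparse_vector_query("youth homelessness programs", fields=("title", "summary"))
    >>> body["query"]["bool"]["should"][0]["nested"]["path"]
    'embeddings.title.chunks'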
| """ | |
| output = [] | |
| for f in fields: | |
| output.append({ | |
| "nested": { | |
| "path": f"embeddings.{f}.chunks", | |
| "query": { | |
| "sparse_vector": { | |
| "field": f"embeddings.{f}.chunks.vector", | |
| "inference_id": inference_id, | |
| "prune": True, | |
| "query": query, | |
| # "boost": 1 / len(fields) | |
| } | |
| }, | |
| "inner_hits": { | |
| "_source": False, | |
| "size": 2, | |
| "fields": [f"embeddings.{f}.chunks.chunk"] | |
| } | |
| } | |
| }) | |
| return {"query": {"bool": {"should": output}}} | |


def build_sparse_vector_and_text_query(
    query: str,
    semantic_fields: tuple[str, ...],
    text_fields: tuple[str, ...] | None,
    highlight_fields: tuple[str, ...] | None,
    excluded_fields: tuple[str, ...] | None,
    inference_id: str = ".elser-2-elasticsearch"
) -> dict[str, Any]:
    """Builds an Elasticsearch query payload which combines sparse vector semantic
    search with standard full-text matching.

    Parameters
    ----------
    query : str
        Search context string
    semantic_fields : tuple[str, ...]
        Semantic text field names
    text_fields : tuple[str, ...] | None
        Regular text fields
    highlight_fields : tuple[str, ...] | None
        Fields whose relevant chunks will be helpful for the agent to read
    excluded_fields : tuple[str, ...] | None
        Fields to exclude from the source
    inference_id : str, optional
        ID of the model deployed in Elasticsearch, by default ".elser-2-elasticsearch"

    Returns
    -------
    dict[str, Any]
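
    Examples
    --------
    Illustrative sketch; the field names are hypothetical placeholders for
    ``semantic_text`` and plain text mappings in the target index:

    >>> body = build_sparse_vector_and_text_query(
    ...     query="affordable housing grants",
    ...     semantic_fields=("description",),
    ...     text_fields=("title",),
    ...     highlight_fields=("description",),
    ...     excluded_fields=("embeddings",)
    ... )
    >>> sorted(body)
    ['_source', 'highlight', 'query', 'track_total_hits']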
| """ | |
| output = [] | |
| final_query = {} | |
| for f in semantic_fields: | |
| output.append({ | |
| "sparse_vector": { | |
| "field": f"{f}", | |
| "inference_id": inference_id, | |
| "query": query, | |
| "boost": 1, | |
| "prune": True # doesn't seem it changes anything if we use text queries additionally | |
| } | |
| }) | |
| if text_fields: | |
| output.append({ | |
| "multi_match": { | |
| "fields": text_fields, | |
| "query": query, | |
| "boost": 3 | |
| } | |
| }) | |
| final_query = { | |
| "track_total_hits": False, | |
| "query": { | |
| "bool": {"should": output} | |
| } | |
| } | |
| if highlight_fields: | |
| final_query["highlight"] = { | |
| "fields": { | |
| f"{f}": { | |
| "type": "semantic", # ensures that highlighting is applied exclusively to semantic_text fields. | |
| "number_of_fragments": 2, # number of chunks | |
| "order": "none" # can be "score", but we have only two and hope for context | |
| } | |
| for f in highlight_fields | |
| } | |
| } | |
| if excluded_fields: | |
| final_query["_source"] = {"excludes": list(excluded_fields)} | |
| return final_query | |


def news_query_builder(
    query: str,
    fields: tuple[str, ...],
    encoder: SpladeEncoder,
    days_ago: int = 60,
) -> dict[str, Any]:
    """Builds a valid Elasticsearch query against Candid news, simulating a token expansion.

    Parameters
    ----------
    query : str
        Search context string
    fields : tuple[str, ...]
        Text field names to match the expanded tokens against
    encoder : SpladeEncoder
        SPLADE encoder used to expand the query into weighted tokens
    days_ago : int, optional
        Restrict articles to the trailing `days_ago`-day window, by default 60

    Returns
    -------
    dict[str, Any]
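
    Examples
    --------
    Illustrative sketch; the index name and `client` are hypothetical, and this
    assumes a ready-to-use ``SpladeEncoder`` instance (its constructor arguments
    are not shown here):

    >>> body = news_query_builder(  # doctest: +SKIP
    ...     query="funding for rural libraries",
    ...     fields=("title", "content"),
    ...     encoder=encoder,
    ...     days_ago=30
    ... )
    >>> client.search(index="news", body=body)  # doctest: +SKIP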
| """ | |
| tokens = encoder.token_expand(query) | |
| elastic_query = { | |
| "_source": ["id", "link", "title", "content", "site_name"], | |
| "query": { | |
| "bool": { | |
| "filter": [ | |
| {"range": {"event_date": {"gte": f"now-{days_ago}d/d"}}}, | |
| {"range": {"insert_date": {"gte": f"now-{days_ago}d/d"}}}, | |
| {"range": {"article_trust_worthiness": {"gt": NEWS_TRUST_SCORE_THRESHOLD}}} | |
| ], | |
| "should": [] | |
| } | |
| } | |
| } | |
| for token, score in tokens.items(): | |
| if score > SPARSE_ENCODING_SCORE_THRESHOLD: | |
| elastic_query["query"]["bool"]["should"].append({ | |
| "multi_match": { | |
| "query": token, | |
| "fields": fields, | |
| "boost": score | |
| } | |
| }) | |
| return elastic_query | |


def issuelab_query_builder(
    query: str,
    fields: tuple[str, ...],
    highlight_fields: tuple[str, ...] | None,
    encoder: SpladeEncoder,
) -> dict[str, Any]:
    """Builds a valid Elasticsearch query against IssueLab, simulating a token expansion.

    Parameters
    ----------
    query : str
        Search context string
    fields : tuple[str, ...]
        Text field names to match the expanded tokens against
    highlight_fields : tuple[str, ...] | None
        Fields to highlight, defaulting to ("content", "description") if None
    encoder : SpladeEncoder
        SPLADE encoder used to expand the query into weighted tokens

    Returns
    -------
    dict[str, Any]
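
    Examples
    --------
    Illustrative sketch; the index name and `client` are hypothetical, and this
    assumes a ready-to-use ``SpladeEncoder`` instance:

    >>> body = issuelab_query_builder(  # doctest: +SKIP
    ...     query="early childhood education outcomes",
    ...     fields=("title", "description", "content"),
    ...     highlight_fields=None,
    ...     encoder=encoder
    ... )
    >>> client.search(index="issuelab", body=body)  # doctest: +SKIP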
| """ | |
| tokens = encoder.token_expand(query) | |
| elastic_query = { | |
| "_source": ["issuelab_id", "issuelab_url", "title", "description", "content"], | |
| "query": { | |
| "bool": { | |
| # "filter": [ | |
| # # {"range": {"event_date": {"gte": f"now-{days_ago}d/d"}}}, | |
| # # {"range": {"insert_date": {"gte": f"now-{days_ago}d/d"}}}, | |
| # # {"range": {"article_trust_worthiness": {"gt": NEWS_TRUST_SCORE_THRESHOLD}}} | |
| # ], | |
| "should": [] | |
| } | |
| }, | |
| "highlight": { | |
| "fields": dict.fromkeys(highlight_fields or ("content", "description"), {}) | |
| } | |
| } | |
| for token, score in tokens.items(): | |
| if score > SPARSE_ENCODING_SCORE_THRESHOLD: | |
| elastic_query["query"]["bool"]["should"].append({ | |
| "multi_match": { | |
| "query": token, | |
| "fields": fields, | |
| "boost": score | |
| } | |
| }) | |
| return elastic_query | |


def multi_search_base(
    queries: list[dict[str, Any]],
    credentials: BaseElasticSearchConnection | BaseElasticAPIKeyCredential,
    timeout: int = 180
) -> Iterator[dict[str, Any]]:
    """Handles multi-search queries on a single cluster, given the relevant credentials object.

    Parameters
    ----------
    queries : list[dict[str, Any]]
        `msearch` query object (see: https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-msearch)
    credentials : BaseElasticSearchConnection | BaseElasticAPIKeyCredential
    timeout : int, optional, by default 180

    Yields
    ------
    Iterator[dict[str, Any]]

    Raises
    ------
    TypeError
        Raised if invalid credentials are passed
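
    Examples
    --------
    Illustrative sketch; `msearch` bodies interleave a header object (index
    selection) with a query object per search. The index names and the `creds`
    object are hypothetical:

    >>> queries = [
    ...     {"index": "news"},
    ...     {"query": {"match": {"title": "climate"}}},
    ...     {"index": "issuelab"},
    ...     {"query": {"match": {"title": "climate"}}},
    ... ]
    >>> for response in multi_search_base(queries, credentials=creds):  # doctest: +SKIP
    ...     print(len(response["hits"]["hits"]))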
| """ | |
| if isinstance(credentials, BaseElasticAPIKeyCredential): | |
| es = Elasticsearch( | |
| cloud_id=credentials.cloud_id, | |
| api_key=credentials.api_key, | |
| verify_certs=False, | |
| request_timeout=timeout | |
| ) | |
| elif isinstance(credentials, BaseElasticSearchConnection): | |
| es = Elasticsearch( | |
| credentials.url, | |
| http_auth=(credentials.username, credentials.password), | |
| timeout=timeout | |
| ) | |
| else: | |
| raise TypeError(f"Invalid credentials of type `{type(credentials)}") | |
| yield from es.msearch(body=queries).get("responses", []) | |
| es.close() | |