from typing import Literal, Any
from collections.abc import Iterator, Iterable
from itertools import groupby
import logging

from langchain_core.documents import Document

from ask_candid.base.retrieval.elastic import (
    # build_sparse_vector_query,
    build_sparse_vector_and_text_query,
    news_query_builder,
    issuelab_query_builder,
    multi_search_base
)
from ask_candid.base.retrieval.sparse_lexical import SpladeEncoder
from ask_candid.base.retrieval.schemas import ElasticHitsResult
import ask_candid.base.retrieval.sources as S
from ask_candid.base.config.connections import SEMANTIC_ELASTIC, ELSER_INFERENCE_ID, NEWS_ELASTIC
from ask_candid.services.small_lm import CandidSmallLanguageModel

SourceNames = Literal[
    "Candid Blog",
    "Candid Help",
    "Candid Learning",
    "Candid News",
    "IssueLab Research Reports",
    "YouTube Training"
]

sparse_encoder = SpladeEncoder()

logging.basicConfig(format="[%(levelname)s] (%(asctime)s) :: %(message)s")
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def get_context(field_name: str, hit: ElasticHitsResult, context_length: int = 1024, add_context: bool = True) -> str:
    """Pads the relevant chunk of text with context before and after.

    Parameters
    ----------
    field_name : str
        A field containing the long text which was chunked into pieces
    hit : ElasticHitsResult
    context_length : int, optional
        Length of text to add before and after the chunk, by default 1024
    add_context : bool, optional
        If `True`, expand the text context by locating the Elastic inner hit inside the larger document and padding
        it with the surrounding text; if `False`, return the chunk as-is, by default True

    Returns
    -------
    str
        Longer chunks stuffed together
    """
    chunks = []
    # NOTE chunks are measured in tokens while the long text is a string, and the text may contain HTML which
    # affects tokenization
    long_text = hit.source.get(field_name) or ""
    long_text = long_text.lower()

    inner_hits_field = f"embeddings.{field_name}.chunks"
    found_chunks = hit.inner_hits.get(inner_hits_field, {}) if hit.inner_hits else None
    if found_chunks:
        for h in found_chunks.get("hits", {}).get("hits") or []:
            chunk = h.get("fields", {})[inner_hits_field][0]["chunk"][0]

            # keep the middle of the chunk, because tokenizing artifacts may appear at the edges
            chunk = chunk[3:-3]

            if add_context:
                # find the start and end indices of the chunk in the large text; lower-case the probe so that it
                # matches the lower-cased long text
                start_index = long_text.find(chunk[:20].lower())

                # chunk is found
                if start_index != -1:
                    end_index = start_index + len(chunk)
                    pre_start_index = max(0, start_index - context_length)
                    post_end_index = min(len(long_text), end_index + context_length)
                    chunks.append(long_text[pre_start_index:post_end_index])
            else:
                chunks.append(chunk)
    return '\n\n'.join(chunks)
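

# Illustrative sketch (not called anywhere in this module): a hand-built hit showing how `get_context` locates an
# inner-hit chunk inside the parent document and pads it with surrounding text. Real hits come from Elasticsearch;
# the field names here simply mirror the ones used above.
def _example_get_context() -> str:
    body = "the quick brown fox jumps over the lazy dog. " * 20
    chunk = body[100:160]
    hit = ElasticHitsResult(
        index="example-index",
        id="doc-1",
        score=1.0,
        source={"content": body},
        inner_hits={
            "embeddings.content.chunks": {
                "hits": {"hits": [{"fields": {"embeddings.content.chunks": [{"chunk": [chunk]}]}}]}
            }
        },
        highlight={}
    )
    # with the default add_context=True the returned text is the matched chunk plus up to 30 characters of
    # surrounding context on either side
    return get_context("content", hit, context_length=30)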


def generate_queries(
    query: str,
    sources: list[SourceNames],
    news_days_ago: int = 60
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Builds Elastic queries against indices which do or do not support sparse vector queries.

    Parameters
    ----------
    query : str
        Text describing a user's question, or a description of investigative work which requires support from
        Candid's knowledge base
    sources : list[SourceNames]
        One or more sources of knowledge from different areas at Candid.
        * Candid Blog: Blog posts from Candid staff and trusted partners intended to help those in the sector or
          illuminate ongoing work
        * Candid Help: Candid FAQs to help users get started with Candid's product platform and learning resources
        * Candid Learning: Training documents from Candid's subject matter experts
        * Candid News: News articles and press releases about real-time activity in the philanthropic sector
        * IssueLab Research Reports: Academic research reports about the social/philanthropic sector
        * YouTube Training: Transcripts from video-based training seminars from Candid's subject matter experts
    news_days_ago : int, optional
        How many days in the past to search for news articles. If a user is asking for recent trends this value
        should be set lower, around 10 or so, by default 60

    Returns
    -------
    tuple[list[dict[str, Any]], list[dict[str, Any]]]
        (sparse vector queries, queries for indices which do not support sparse vectors)
    """
    vector_queries = []
    quasi_vector_queries = []

    # these sources share the same sparse-vector query shape and only differ in their field configurations
    sparse_vector_configs = {
        "Candid Blog": S.CandidBlogConfig,
        "Candid Help": S.CandidHelpConfig,
        "Candid Learning": S.CandidLearningConfig,
        "YouTube Training": S.YoutubeConfig
    }

    for source_name in sources:
        if source_name in sparse_vector_configs:
            config = sparse_vector_configs[source_name]
            q = build_sparse_vector_and_text_query(
                query=query,
                semantic_fields=config.semantic_fields,
                text_fields=config.text_fields,
                highlight_fields=config.highlight_fields,
                excluded_fields=config.excluded_fields,
                inference_id=ELSER_INFERENCE_ID
            )
            q["size"] = 5
            vector_queries.extend([{"index": config.index_name}, q])
        elif source_name == "Candid News":
            q = news_query_builder(
                query=query,
                fields=S.CandidNewsConfig.semantic_fields,
                encoder=sparse_encoder,
                days_ago=news_days_ago
            )
            q["size"] = 5
            quasi_vector_queries.extend([{"index": S.CandidNewsConfig.index_name}, q])
        elif source_name == "IssueLab Research Reports":
            # q = build_sparse_vector_query(query=query, fields=S.IssueLabConfig.semantic_fields)
            # q["_source"] = {"excludes": ["embeddings"]}
            # q["size"] = 1
            # vector_queries.extend([{"index": S.IssueLabConfig.index_name}, q])
            q = issuelab_query_builder(
                query=query,
                fields=S.IssueLabConfig.semantic_fields,
                highlight_fields=S.IssueLabConfig.highlight_fields,
                encoder=sparse_encoder,
            )
            q["size"] = 1
            quasi_vector_queries.extend([{"index": S.IssueLabConfig.index_name}, q])
    return vector_queries, quasi_vector_queries
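

# Usage sketch (hypothetical query text; the news/IssueLab builders need the live SPLADE encoder): the two returned
# lists are in Elasticsearch msearch format, alternating an index header with the query body.
#
#   vector_qs, keyword_qs = generate_queries(
#       query="trends in rural arts funding",
#       sources=["Candid Blog", "Candid News"],
#       news_days_ago=30
#   )
#   # vector_qs  -> [{"index": <blog index>}, {...query body..., "size": 5}]
#   # keyword_qs -> [{"index": <news index>}, {...query body..., "size": 5}]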


def run_search(
    vector_searches: list[dict[str, Any]] | None = None,
    non_vector_searches: list[dict[str, Any]] | None = None,
) -> list[ElasticHitsResult]:
    """Elastic query runner which executes both sparse vector and quasi-sparse vector queries and concatenates the
    results. This does not include re-ranking.

    Parameters
    ----------
    vector_searches : list[dict[str, Any]] | None, optional
        Sparse vector multi-search queries, by default None
    non_vector_searches : list[dict[str, Any]] | None, optional
        Keyword-based multi-search queries, by default None

    Returns
    -------
    list[ElasticHitsResult]
        Concatenated results
    """
    def _msearch_response_generator(responses: Iterable[dict[str, Any]]) -> Iterator[ElasticHitsResult]:
        for query_group in responses:
            for h in query_group.get("hits", {}).get("hits", []):
                inner_hits = h.get("inner_hits", {})
                index_name = h.get("_index", "")

                # news and IssueLab indices do not return inner hits, so fall back to the full document content
                if not inner_hits and ("news" in index_name or "issuelab" in index_name):
                    inner_hits = {"text": h.get("_source", {}).get("content")}

                yield ElasticHitsResult(
                    index=h["_index"],
                    id=h["_id"],
                    score=h["_score"],
                    source=h["_source"],
                    inner_hits=inner_hits,
                    highlight=h.get("highlight", {})
                )

    results = []
    if vector_searches:
        hits = multi_search_base(queries=vector_searches, credentials=SEMANTIC_ELASTIC)
        results.extend(_msearch_response_generator(responses=hits))
    if non_vector_searches:
        hits = multi_search_base(queries=non_vector_searches, credentials=NEWS_ELASTIC)
        results.extend(_msearch_response_generator(responses=hits))
    return results
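

# End-to-end sketch (requires live Elasticsearch credentials, so shown as a comment only): wiring `generate_queries`
# into `run_search`.
#
#   vector_qs, keyword_qs = generate_queries("capacity building grants", sources=["Candid Help"])
#   hits = run_search(vector_searches=vector_qs, non_vector_searches=keyword_qs)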


def retrieved_text(hits: dict[str, Any]) -> str:
    """Extracts retrieved sub-texts from documents which are strong hits from semantic queries, for the purpose of
    re-scoring by a secondary language model.

    Parameters
    ----------
    hits : dict[str, Any]
        Inner-hits payload from an `ElasticHitsResult`

    Returns
    -------
    str
    """
    nlp = CandidSmallLanguageModel()

    text = []
    for key, v in hits.items():
        if key == "text":
            # full-document fallback for news/IssueLab hits: compress the content with the small language model
            s = nlp.summarize(v, top_k=3)
            text.append(s.summary)
            # text.append(v)
            continue
        for h in (v.get("hits", {}).get("hits") or []):
            for field in h.get("fields", {}).values():
                for chunk in field:
                    if chunk.get("chunk"):
                        text.extend(chunk["chunk"])
    return '\n'.join(text)
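

# Shape sketch (hand-built payloads matching the two inner-hit layouts handled above; note the "text" path calls the
# small language model service):
#
#   retrieved_text({"text": "full article body ..."})  # summarization path for news/IssueLab hits
#   retrieved_text({
#       "embeddings.content.chunks": {"hits": {"hits": [{"fields": {"chunks": [{"chunk": ["..."]}]}}]}}
#   })  # chunked inner-hits path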


def reranker(
    query_results: Iterable[ElasticHitsResult],
    search_text: str | None = None,
    max_num_results: int = 5
) -> Iterator[ElasticHitsResult]:
    """Re-ranks Elasticsearch hits coming from multiple indices/queries, which may have scores on different scales.
    Note that this re-orders results relative to their original query order.

    Parameters
    ----------
    query_results : Iterable[ElasticHitsResult]
    search_text : str | None, optional
        Original query text; when supplied, retrieved texts are re-scored against it with the sparse encoder,
        by default None
    max_num_results : int, optional
        Maximum number of results to yield, by default 5

    Yields
    ------
    Iterator[ElasticHitsResult]
    """
    results: list[ElasticHitsResult] = []
    texts: list[str] = []

    # NOTE `groupby` assumes hits from the same index are contiguous, which holds for concatenated msearch responses
    for _, data in groupby(query_results, key=lambda x: x.index):
        data = list(data)  # noqa: PLW2901

        # min-max normalize scores within each index so that results from different indices are comparable
        max_score = max(data, key=lambda x: x.score).score
        min_score = min(data, key=lambda x: x.score).score
        for d in data:
            d.score = (d.score - min_score) / (max_score - min_score + 1e-9)
            results.append(d)

            if search_text:
                text = ""
                if d.inner_hits:
                    text = retrieved_text(d.inner_hits)
                if d.highlight:
                    # prefer highlighted snippets when they are available
                    text = '\n'.join('\n'.join(v) for v in d.highlight.values())
                texts.append(text)

    if search_text and len(texts) == len(results) and len(texts) > max_num_results:
        logger.info("Re-ranking %d retrieval results", len(results))
        scores = sparse_encoder.query_reranking(query=search_text, documents=texts)
        for r, s in zip(results, scores):
            r.score = s
    yield from sorted(results, key=lambda x: x.score, reverse=True)[:max_num_results]
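

# Worked example of the per-index min-max normalization used above (pure Python, no search backend needed): raw
# scores of 3.0, 7.0 and 11.0 map to roughly 0.0, 0.5 and 1.0.
def _example_minmax_normalize(scores: list[float]) -> list[float]:
    min_score, max_score = min(scores), max(scores)
    return [(s - min_score) / (max_score - min_score + 1e-9) for s in scores]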


def process_hit(hit: ElasticHitsResult) -> Document:
    """Process a raw Elasticsearch document into a structured langchain `Document` object.

    Parameters
    ----------
    hit : ElasticHitsResult

    Returns
    -------
    Document

    Raises
    ------
    ValueError
        Raised if a result from an unknown index is passed in
    """
    nlp = CandidSmallLanguageModel()

    if "issuelab-elser" in hit.index:
        doc = Document(
            page_content='\n\n'.join([
                hit.source.get("combined_item_description", ""),
                hit.source.get("description", ""),
                hit.source.get("combined_issuelab_findings", ""),
                get_context("content", hit, context_length=12)
            ]),
            metadata={
                "title": hit.source["title"],
                "source": "IssueLab",
                "source_id": hit.source["resource_id"],
                "url": hit.source.get("permalink", "")
            }
        )
    elif "issuelab" in hit.index:
        content_summary = ""
        if hit.source.get("content", ""):
            content_summary = nlp.summarize(hit.source.get("content", ""), top_k=20).summary
        doc = Document(
            page_content='\n\n'.join([hit.source.get("description", ""), content_summary]),
            metadata={
                "title": hit.source["title"],
                "source": "IssueLab",
                "source_id": hit.source["issuelab_id"],
                "url": hit.source.get("issuelab_url", "")
            }
        )
    elif "youtube" in hit.index:
        highlight = hit.highlight or {}
        doc = Document(
            page_content='\n\n'.join([
                hit.source.get("title", ""),
                hit.source.get("semantic_description", ""),
                ' '.join(highlight.get("semantic_cc_text", []))
            ]),
            metadata={
                "title": hit.source.get("title", ""),
                "source": "Candid YouTube",
                "source_id": hit.source["video_id"],
                "url": f"https://www.youtube.com/watch?v={hit.source['video_id']}"
            }
        )
    elif "blog" in hit.index:
        highlight = hit.highlight or {}
        doc = Document(
            page_content='\n\n'.join([
                hit.source.get("title_summary_tags_text", ""),
                ' '.join(highlight.get("semantic_content", [])),
                hit.source.get("authors_text", "")
            ]),
            metadata={
                "title": hit.source.get("title", ""),
                "source": "Candid Blog",
                "source_id": hit.source["id"],
                "url": hit.source.get("link", "")
            }
        )
    elif "learning" in hit.index:
        highlight = hit.highlight or {}
        doc = Document(
            page_content='\n\n'.join([
                hit.source.get("semantic_title_short_description", ""),
                ' '.join(highlight.get("semantic_lessons_content", []))
            ]),
            metadata={
                "title": hit.source["title"],
                "source": "Candid Learning",
                "source_id": hit.source["course_id"],
                "url": hit.source.get("course_url", "")
            }
        )
    elif "help" in hit.index:
        highlight = hit.highlight or {}
        doc = Document(
            page_content='\n\n'.join([
                hit.source.get("semantic_title_summary_question_category", ""),
                ' '.join(highlight.get("semantic_content", []))
            ]),
            metadata={
                "title": hit.source.get("title", ""),
                "source": "Candid Help",
                "source_id": hit.source["article_id"],
                "url": f"""https://help.candid.org/s/article/{hit.source.get("url", "")}"""
            }
        )
    elif "news" in hit.index:
        doc = Document(
            page_content='\n\n'.join([hit.source.get("title", ""), hit.source.get("content", "")]),
            metadata={
                "title": hit.source.get("title", ""),
                "source": hit.source.get("site_name") or "Candid News",
                "source_id": hit.source["id"],
                "url": hit.source.get("link", "")
            }
        )
    else:
        raise ValueError(f"Unknown source result from index {hit.index}")
    return doc
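

# Putting the pieces together (sketch; every step past query generation needs live Candid services):
#
#   question = "nonprofit board training"
#   hits = run_search(*generate_queries(question, sources=["YouTube Training"]))
#   docs = [process_hit(h) for h in reranker(hits, search_text=question)]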