brainsqueeze committed on
Commit
fc7f1bb
·
verified ·
1 Parent(s): f002446

Delete ask_candid/base/retrieval/knowledge_base.py

Browse files
ask_candid/base/retrieval/knowledge_base.py DELETED
@@ -1,362 +0,0 @@
1
- from typing import Literal, Any
2
- from collections.abc import Iterator, Iterable
3
- from itertools import groupby
4
- import logging
5
-
6
- from langchain_core.documents import Document
7
-
8
- from ask_candid.base.retrieval.elastic import (
9
- build_sparse_vector_query,
10
- build_sparse_vector_and_text_query,
11
- news_query_builder,
12
- multi_search_base
13
- )
14
- from ask_candid.base.retrieval.sparse_lexical import SpladeEncoder
15
- from ask_candid.base.retrieval.schemas import ElasticHitsResult
16
- import ask_candid.base.retrieval.sources as S
17
- from ask_candid.services.small_lm import CandidSLM
18
-
19
- from ask_candid.base.config.connections import SEMANTIC_ELASTIC_QA, NEWS_ELASTIC
20
-
21
# Names of the knowledge sources a caller may ask to search.
SourceNames = Literal[
    "Candid Blog", "Candid Help", "Candid Learning",
    "Candid News", "IssueLab Research Reports", "YouTube Training"
]

# One module-wide encoder instance, used for the news query and for re-ranking.
sparse_encoder = SpladeEncoder()

# Logging is configured at import time; this module logs at INFO and above.
logging.basicConfig(format="[%(levelname)s] (%(asctime)s) :: %(message)s")
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
33
-
34
-
35
# TODO remove
def get_context(field_name: str, hit: ElasticHitsResult, context_length: int = 1024, add_context: bool = True) -> str:
    """Collects the matched chunks of a chunked text field, optionally padded with the text around them.

    Parameters
    ----------
    field_name : str
        a field with the long text that was chunked into pieces
    hit : ElasticHitsResult
        the hit whose inner hits under `embeddings.<field_name>.chunks` hold the matched chunks
    context_length : int, optional
        number of characters of text to add before and after the chunk, by default 1024
    add_context : bool, optional
        When `True` every chunk is searched for inside the lower-cased full text of the field and returned
        together with `context_length` characters on either side; a chunk which cannot be located is returned
        unpadded. Set to `False` to return the bare chunks, by default True

    Returns
    -------
    str
        the (padded) chunks joined by blank lines, or an empty string if the hit has no matching inner hits
    """

    inner_hits_field = f"embeddings.{field_name}.chunks"
    found_chunks = hit.inner_hits.get(inner_hits_field, {}) if hit.inner_hits else None
    if not found_chunks:
        return ""

    # NOTE chunks have tokens, long text is a string, but may contain html which affects tokenization
    long_text = (hit.source.get(field_name) or "").lower()

    chunks = []
    for h in found_chunks.get("hits", {}).get("hits") or []:
        chunk = h.get("fields", {})[inner_hits_field][0]["chunk"][0]

        # trimming both edges because we may have tokenizing artifacts there
        chunk = chunk[3:-3]
        if not chunk:
            # nothing left after trimming; an empty search key would "match" at position 0
            continue

        if not add_context:
            chunks.append(chunk)
            continue

        # The full text was lower-cased above, so the search key has to be lower-cased as well
        start_index = long_text.find(chunk[:20].lower())
        if start_index == -1:
            # Could not locate the chunk: keep it unpadded instead of silently dropping the hit
            chunks.append(chunk)
            continue

        end_index = start_index + len(chunk)
        pre_start_index = max(0, start_index - context_length)
        post_end_index = min(len(long_text), end_index + context_length)
        chunks.append(long_text[pre_start_index:post_end_index])
    return '\n\n'.join(chunks)
83
-
84
-
85
def _sparse_source_entries(query: str, config: Any, size: int) -> list[dict[str, Any]]:
    """Builds the `[index header, query body]` pair for a source searched with a plain sparse vector query."""

    body = build_sparse_vector_query(query=query, fields=config.semantic_fields)
    # the stored embeddings are not needed downstream, so keep them out of the response
    body["_source"] = {"excludes": ["embeddings"]}
    body["size"] = size
    return [{"index": config.index_name}, body]


def generate_queries(
    query: str,
    sources: list[SourceNames],
    news_days_ago: int = 60
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Builds Elastic queries against indices which do or do not support sparse vector queries.

    Parameters
    ----------
    query : str
        Text describing a user's question or a description of investigative work which requires support from Candid's
        knowledge base
    sources : list[SourceNames]
        One or more sources of knowledge from different areas at Candid. Names which are not listed below are
        ignored.
            * Candid Blog: Blog posts from Candid staff and trusted partners intended to help those in the sector or
            illuminate ongoing work
            * Candid Help: Candid FAQs to help users get started with Candid's product platform and learning resources
            * Candid Learning: Training documents from Candid's subject matter experts
            * Candid News: News articles and press releases about real-time activity in the philanthropic sector
            * IssueLab Research Reports: Academic research reports about the social/philanthropic sector
            * YouTube Training: Transcripts from video-based training seminars from Candid's subject matter experts
    news_days_ago : int, optional
        How many days in the past to search for news articles, if a user is asking for recent trends then this value
        should be set lower (roughly 10), by default 60

    Returns
    -------
    tuple[list[dict[str, Any]], list[dict[str, Any]]]
        (sparse vector queries, queries for indices which do not support sparse vectors). Each list is flat and
        alternates an `{"index": ...}` header with the query body that belongs to it.
    """

    vector_queries = []
    quasi_vector_queries = []

    for source_name in sources:
        if source_name == "Candid Blog":
            vector_queries.extend(_sparse_source_entries(query, S.CandidBlogConfig, size=5))
        elif source_name == "Candid Help":
            vector_queries.extend(_sparse_source_entries(query, S.CandidHelpConfig, size=5))
        elif source_name == "Candid Learning":
            vector_queries.extend(_sparse_source_entries(query, S.CandidLearningConfig, size=5))
        elif source_name == "Candid News":
            # the only source collected in the second list of the returned tuple
            q = news_query_builder(
                query=query,
                fields=S.CandidNewsConfig.semantic_fields,
                encoder=sparse_encoder,
                days_ago=news_days_ago
            )
            q["size"] = 5
            quasi_vector_queries.extend([{"index": S.CandidNewsConfig.index_name}, q])
        elif source_name == "IssueLab Research Reports":
            vector_queries.extend(_sparse_source_entries(query, S.IssueLabConfig, size=1))
        elif source_name == "YouTube Training":
            q = build_sparse_vector_and_text_query(
                query=query,
                semantic_fields=S.YoutubeConfig.semantic_fields,
                text_fields=S.YoutubeConfig.text_fields,
                highlight_fields=S.YoutubeConfig.highlight_fields,
                excluded_fields=S.YoutubeConfig.excluded_fields
            )
            q["size"] = 5
            vector_queries.extend([{"index": S.YoutubeConfig.index_name}, q])

    return vector_queries, quasi_vector_queries
161
-
162
-
163
def run_search(
    vector_searches: list[dict[str, Any]] | None = None,
    non_vector_searches: list[dict[str, Any]] | None = None,
) -> list[ElasticHitsResult]:
    """Executes the prepared multi-search payloads and gathers every returned hit into a single list.

    Parameters
    ----------
    vector_searches : list[dict[str, Any]] | None, optional
        Queries sent with the `SEMANTIC_ELASTIC_QA` credentials; skipped when empty or `None`, by default None
    non_vector_searches : list[dict[str, Any]] | None, optional
        Queries sent with the `NEWS_ELASTIC` credentials; skipped when empty or `None`, by default None

    Returns
    -------
    list[ElasticHitsResult]
        Hits of the vector searches followed by the hits of the non-vector searches
    """

    def _iter_hits(responses: Iterable[dict[str, Any]]) -> Iterator[ElasticHitsResult]:
        for response in responses:
            for raw in response.get("hits", {}).get("hits", []):
                nested = raw.get("inner_hits", {})

                # A hit without inner hits from an index whose name contains "news" exposes its `content`
                # under a "text" key instead
                if not nested and "news" in raw.get("_index"):
                    nested = {"text": raw.get("_source", {}).get("content")}

                yield ElasticHitsResult(
                    index=raw["_index"],
                    id=raw["_id"],
                    score=raw["_score"],
                    source=raw["_source"],
                    inner_hits=nested,
                    highlight=raw.get("highlight", {})
                )

    collected = []
    searches = (
        (vector_searches, SEMANTIC_ELASTIC_QA),
        (non_vector_searches, NEWS_ELASTIC),
    )
    for payload, credentials in searches:
        if payload is None or len(payload) == 0:
            continue
        responses = multi_search_base(queries=payload, credentials=credentials)
        collected.extend(_iter_hits(responses=responses))
    return collected
194
-
195
-
196
def retrieved_text(hits: dict[str, Any]) -> str:
    """Extracts retrieved sub-texts from documents which are strong hits from semantic queries for the purpose of
    re-scoring by a secondary language model.

    Parameters
    ----------
    hits : dict[str, Any]
        Inner hits of a single search result keyed by inner-hit name. The special key `"text"` holds a raw text
        which is summarized with `CandidSLM` (skipped when empty or `None`); every other value is read as an
        Elastic inner-hits payload whose `fields` carry `chunk` lists.

    Returns
    -------
    str
        The summary and/or chunk texts joined by newlines; an empty string if nothing was found
    """

    text = []
    for name, value in hits.items():
        if name == "text":
            # The summarizer is only created when there is really something to summarize
            if value:
                text.append(CandidSLM().summarize(value, top_k=3).summary)
            continue

        for h in (value.get("hits", {}).get("hits") or []):
            for field in h.get("fields", {}).values():
                for chunk in field:
                    if chunk.get("chunk"):
                        text.extend(chunk["chunk"])
    return '\n'.join(text)
225
-
226
-
227
def reranker(
    query_results: Iterable[ElasticHitsResult],
    search_text: str | None = None,
    max_num_results: int = 5
) -> Iterator[ElasticHitsResult]:
    """Reranks Elasticsearch hits coming from multiple indices/queries which may have scores on different scales.
    This will shuffle results

    Scores are first min-max normalized per index (the `score` of every hit is overwritten in place). If
    `search_text` is given and there is more than one hit, the scores are then replaced by the ones returned from
    `sparse_encoder.query_reranking` for the text of each hit (its highlights if present, else its inner hits).

    Parameters
    ----------
    query_results : Iterable[ElasticHitsResult]
        Hits to rank; hits of the same index do not need to be adjacent
    search_text : str | None, optional
        Query to re-score the hits against, no re-scoring is done if empty, by default None
    max_num_results : int, optional
        Maximum number of hits to yield, by default 5

    Yields
    ------
    Iterator[ElasticHitsResult]
        The best `max_num_results` hits, highest score first
    """

    # `itertools.groupby` only merges *consecutive* items, so bucket by hand; this keeps the first-seen order
    # of indices and hits while tolerating interleaved input
    by_index: dict[str, list[ElasticHitsResult]] = {}
    for hit in query_results:
        by_index.setdefault(hit.index, []).append(hit)

    results: list[ElasticHitsResult] = []
    texts: list[str] = []
    for data in by_index.values():
        group_scores = [d.score for d in data]
        max_score = max(group_scores)
        min_score = min(group_scores)

        for d in data:
            # NOTE(review): a group whose hits all share one score (e.g. a single hit) normalizes to 0 -- confirm
            # that this is intended
            d.score = (d.score - min_score) / (max_score - min_score + 1e-9)
            results.append(d)

            if search_text:
                # Highlights take precedence over inner hits; checking them first avoids extracting
                # (and possibly summarizing) inner-hit text which would be thrown away anyway
                if d.highlight:
                    text = '\n'.join('\n'.join(v) for v in d.highlight.values())
                elif d.inner_hits:
                    text = retrieved_text(d.inner_hits)
                else:
                    # always bind a value so the hit never inherits the text of the previous one
                    text = ""
                texts.append(text)

    if search_text and len(texts) == len(results) and len(texts) > 1:
        logger.info("Re-ranking %d retrieval results", len(results))
        scores = sparse_encoder.query_reranking(query=search_text, documents=texts)
        for r, s in zip(results, scores):
            r.score = s

    yield from sorted(results, key=lambda x: x.score, reverse=True)[:max_num_results]
273
-
274
-
275
def process_hit(hit: ElasticHitsResult) -> Document:
    """Converts an Elasticsearch hit into a `Document`, choosing the text fields and metadata by source index.

    Parameters
    ----------
    hit : ElasticHitsResult
        The source is picked by substring match on `hit.index`, tested in this order: `issuelab-elser`,
        `youtube`, `candid-blog`, `candid-learning`, `candid-help`, `news`.

    Returns
    -------
    Document
        `page_content` is the selected text pieces joined by blank lines; `metadata` holds `title`, `source`,
        `source_id` and `url`

    Raises
    ------
    ValueError
        If the index name matches none of the known sources
    KeyError
        If a field read without a default (e.g. the identifier of the source) is missing from the hit
    """

    source = hit.source

    def _text(key: str) -> str:
        # `source.get(key, "")` still yields None for a field which is stored as null, and a None would make
        # the join below raise TypeError
        return source.get(key) or ""

    if "issuelab-elser" in hit.index:
        parts = [
            _text("combined_item_description"),
            _text("description"),
            _text("combined_issuelab_findings"),
            get_context("content", hit, context_length=12)
        ]
        metadata = {
            "title": source["title"],
            "source": "IssueLab",
            "source_id": source["resource_id"],
            "url": source.get("permalink", "")
        }
    elif "youtube" in hit.index:
        highlight = hit.highlight or {}
        parts = [
            _text("title"),
            _text("semantic_description"),
            ' '.join(highlight.get("semantic_cc_text", []))
        ]
        metadata = {
            "title": source.get("title", ""),
            "source": "Candid YouTube",
            "source_id": source['video_id'],
            "url": f"https://www.youtube.com/watch?v={source['video_id']}"
        }
    elif "candid-blog" in hit.index:
        parts = [
            _text("title"),
            _text("excerpt"),
            get_context("content", hit, context_length=12, add_context=False),
            get_context("authors_text", hit, context_length=12, add_context=False),
            _text("title_summary_tags")
        ]
        metadata = {
            "title": source.get("title", ""),
            "source": "Candid Blog",
            "source_id": source["id"],
            "url": source["link"]
        }
    elif "candid-learning" in hit.index:
        parts = [
            _text("title"),
            _text("staff_recommendations"),
            _text("training_topics"),
            get_context("content", hit, context_length=12)
        ]
        metadata = {
            "title": source["title"],
            "source": "Candid Learning",
            "source_id": source["post_id"],
            "url": source.get("url", "")
        }
    elif "candid-help" in hit.index:
        parts = [
            _text("combined_article_description"),
            get_context("content", hit, context_length=12)
        ]
        metadata = {
            "title": source.get("title", ""),
            "source": "Candid Help",
            "source_id": source["id"],
            "url": source.get("link", "")
        }
    elif "news" in hit.index:
        parts = [_text("title"), _text("content")]
        metadata = {
            "title": source.get("title", ""),
            "source": source.get("site_name") or "Candid News",
            "source_id": source["id"],
            "url": source.get("link", "")
        }
    else:
        raise ValueError(f"Unknown source result from index {hit.index}")
    return Document(page_content='\n\n'.join(parts), metadata=metadata)