Vivek Vaddina committed on
Commit
92f1a38
·
unverified ·
1 Parent(s): c289850

♻️ Refactor to improve model performance

Files changed (8)
  1. README.md +68 -0
  2. app.py +117 -446
  3. requirements.txt +4 -0
  4. src/config.py +33 -1
  5. src/hyde_rag.py +0 -206
  6. src/main.py +249 -323
  7. src/prompts.yaml +43 -44
  8. src/utils.py +148 -0
README.md CHANGED
@@ -11,3 +11,71 @@ short_description: answer based on input documents
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
14
+
15
+ # RAG HyDE
16
+
17
+ This challenge is to build a minimal but powerful Retrieval-Augmented Generation (RAG) workflow inspired by the techniques in the previously shared articles.
18
+
19
+ Your solution should:
20
+
21
+ * Ingest and chunk local PDFs efficiently.
22
+ * Use a small, fast embedding model for retrieval.
23
+ * Apply HyDE (Hypothetical Document Embeddings) to improve query relevance.
24
+ * Fuse multiple query variants with Reciprocal Rank Fusion for higher accuracy.
25
+ * Generate concise, context-grounded answers with clear source citations.
26
+
27
+ Goal: Deliver a working RAG example that is fast, lightweight, and high-quality in both retrieval and final answers.
28
+
29
+
30
+ ## Description
31
+
32
+ ### Approach
33
+ - Get the folder containing the data
34
+ - Build the corpus & index
35
+ - Get the user query
36
+ - Transform it to separate the search terms from the intent (query rewrite)
37
+ - Generate hypothetical documents (HyDE) from the (short) user query
38
+ - Get their corresponding embeddings
39
+ - For each of those embeddings, fetch relevant results from the corpus
40
+ - Fuse them all together using Reciprocal Rank Fusion (see the sketch below)
41
+ - Extract the top relevant results
42
+ - Pre-process those into a context and send it to the LLM one last time along with the user query
43
+ - Receive the final answer based on the user's query & the provided context
44
+
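A minimal sketch of the Reciprocal Rank Fusion step named above (the repo's own implementation lives in `src/utils.py`; the smoothing constant `k=60`, the exact signature, and the toy ranked lists here are illustrative):

```python
from collections import defaultdict

def reciprocal_rank_fusion(ranked_lists, top_k=3, k=60):
    """Fuse several ranked lists of chunk ids into one list.

    Each chunk earns 1 / (k + rank) per list it appears in, so chunks
    ranked highly by many query variants float to the top.
    """
    scores = defaultdict(float)
    for ranks in ranked_lists:
        for rank, chunk_id in enumerate(ranks, start=1):
            scores[chunk_id] += 1.0 / (k + rank)
    fused = sorted(scores.items(), key=lambda x: x[1], reverse=True)
    return [chunk_id for chunk_id, _ in fused[:top_k]]

# e.g. fuse the hits returned for three query variants
print(reciprocal_rank_fusion([[7, 2, 9], [2, 7, 4], [2, 9, 7]]))  # -> [2, 7, 9]
```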
45
+
46
+ ## Run
47
+
48
+ `MODEL_COMBOS` in [src/config.py](src/config.py) provides several embedding & generative LLM model combinations, chosen with the host system's limitations & capabilities in mind (see the sketch below).
49
+
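A minimal sketch of wiring one of those combos into the pipeline (the constructor call mirrors `HyDeRAGFusion` in `src/main.py`; the choice of the `"linux"` key is just an example):

```python
from src.config import MODEL_COMBOS
from src.main import HyDeRAGFusion

combo = MODEL_COMBOS["linux"]  # or "mac", "HF-mid", ...
rag = HyDeRAGFusion(combo["embed_model"], combo["gen_model"])
```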
50
+
51
+ ## Tips / Observations
52
+
53
+
54
+ - Extracting text as Markdown greatly preserved the structure and continuity of the text. This resulted in better logical chunking, which in turn led to better embeddings and, as a consequence, better search results.
55
+
56
+ - Reading documents via `docling` extracted more (and more accurate) text than `pymupdf4llm`, at some cost in speed. It is enabled by default to prioritise accuracy.
57
+ - This proved especially useful for extracting data with lots of tables spread over multiple pages.
58
+ - You can pass `--fast-extract` on the CLI or tick a checkbox in the Gradio UI to use PyMuPDF instead (see the sketch below).
59
+
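A minimal sketch of the two extraction paths, assuming the standard `docling` and `pymupdf4llm` entry points (both packages are added to `requirements.txt` in this commit; the file name is illustrative):

```python
import pymupdf4llm
from docling.document_converter import DocumentConverter

pdf = "report.pdf"  # hypothetical input

# accurate path (default): docling -> Markdown
markdown = DocumentConverter().convert(pdf).document.export_to_markdown()

# fast path (--fast-extract): PyMuPDF -> Markdown
fast_markdown = pymupdf4llm.to_markdown(pdf)
```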
60
+ - Increasing the model size (coupled with correct Markdown text extraction) greatly improved performance. The Qwen3 models adhered closely to instructions; the smaller variants, rather than hallucinating, simply fell back to saying _'I don't know'_ (as instructed). The `4B` variant understood the user intent, which was sometimes vague, and still managed to give relevant results. The base variant is huge and would not have fit, nor run fast enough, on a consumer-grade laptop GPU. Loading its `AWQ` variant helped, as it occupies substantially less memory than the original without much loss in performance (see the sketch below).
61
+
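A minimal sketch of loading that AWQ variant (this mirrors `load_models` in `src/main.py`, including its `dtype` keyword; `autoawq`, added to `requirements.txt` in this commit, must be installed):

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

name = "Qwen/Qwen3-4B-AWQ"
tok = AutoTokenizer.from_pretrained(name, padding_side="left")
# the AWQ checkpoint is pre-quantized; only the dtype of the
# non-quantized layers is set here
gen = AutoModelForCausalLM.from_pretrained(
    name, dtype=torch.bfloat16, device_map="cuda"
).eval()
```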
62
+ - This model also showed great multilingual capabilities. Users can upload a document in one language and ask questions in another, or upload multilingual documents and ask multilingual queries. For the demo, I tested mostly in English & German.
63
+
64
+ - The data is now stored in the Hugging Face `datasets` format, which gives better storage & scaling (Apache Arrow) along with FAISS indexing for querying (see the sketch below).
65
+
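A minimal sketch of that pattern, assuming a toy corpus (the embedding model matches `MODEL_COMBOS`; everything else here is illustrative):

```python
import faiss
from datasets import Dataset
from sentence_transformers import SentenceTransformer

embedder = SentenceTransformer("Qwen/Qwen3-Embedding-0.6B")  # any ST model works
chunks = ["first chunk of text ...", "second chunk of text ..."]
ds = Dataset.from_dict({"id": list(range(len(chunks))), "chunk": chunks})

# Arrow-backed storage: embeddings live alongside the chunks
ds = ds.map(
    lambda x: {"embeddings": embedder.encode(x["chunk"], normalize_embeddings=True)},
    batched=True,
)
ds.add_faiss_index("embeddings", metric_type=faiss.METRIC_INNER_PRODUCT)

q_emb = embedder.encode(["what does the first chunk say?"], normalize_embeddings=True)
scores, examples = ds.get_nearest_examples_batch("embeddings", q_emb, k=1)
print(examples[0]["chunk"])
```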
66
+ ---
67
+
68
+ ## Limitations / Known Issues
69
+
70
+ - Even though `docling` with mostly default options proved better than `pymupdf4llm` at extracting text, it is not perfect every time. There are instances where _pymupdf_ extracted text from an image embedded inside a PDF better than docling did. However, docling is highly configurable and allows deep customization via 'pipelines', and it comes with a far more permissive license for commercial use than PyMuPDF.
71
+ - docling ships with `easyocr` by default for OCR. It is not as powerful as _tesseract_ or similar engines, but since installing the latter and linking it to docling involves touching the system config, this was not pursued.
72
+
73
+ - When a user uploads multiple PDFs, load times could be improved by reading them asynchronously (see the sketch below). Attempts to do that with `docling` sometimes produced pages in a different order than the original, so it was dropped for the demo. More investigation is needed.
74
+
75
+
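For reference, a minimal sketch of the async pattern that was attempted, shown here with the PyMuPDF fast path (it mirrors the `load_pdfs` helper removed from `app.py` in this commit; how the docling converter behaves under this kind of concurrency is the open question):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

import pymupdf

def _extract_sync(path):
    # blocking per-file extraction, run inside a worker thread
    with pymupdf.open(path) as doc:
        return path, "\n".join(page.get_text() for page in doc)

async def load_pdfs(paths, max_workers=5):
    loop = asyncio.get_event_loop()
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [loop.run_in_executor(pool, _extract_sync, p) for p in paths]
        return await asyncio.gather(*futures)
```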
76
+ ## Next Steps
77
+
78
+ - Check out [EmbeddingGemma](https://huggingface.co/blog/embeddinggemma) for embeddings
79
+ - Check out [fastembed](https://github.com/qdrant/fastembed) to generate embeddings faster
80
+ - Improve text extraction via docling pipeline
81
+ - Check out `GGUF` models for CPU inference
app.py CHANGED
@@ -1,474 +1,145 @@
1
- import re
2
- import math
3
- import yaml
4
- import json
5
- import torch
6
- import faiss
7
  import string
8
- import asyncio
9
- import pymupdf
10
  import gradio as gr
 
 
 
11
 
12
- from time import time
13
- from pathlib import Path
14
- from functools import lru_cache
15
- from ast import literal_eval
16
- from collections import defaultdict
17
- from concurrent.futures import ThreadPoolExecutor
18
- from sentence_transformers import SentenceTransformer
19
- from langchain_text_splitters import SentenceTransformersTokenTextSplitter
20
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
21
 
22
- from src.config import PROMPTS_FILEPATH, log
23
-
24
-
25
- async def load_pdfs(files, max_concurrence=5):
26
- """
27
- Load multiple PDF files async.
28
-
29
- Args:
30
- folder (str): Path to folder containing PDF files
31
- max_concurrence (int): Maximum number of concurrent PDF processing tasks
32
-
33
- Returns:
34
- list: List of tuples containing (filename, extracted_text)
35
- """
36
-
37
- def _load_pdf_sync(file):
38
- """Synchronous PDF loading function for thread pool execution"""
39
- text = ""
40
- try:
41
- with pymupdf.open(file, filetype="pdf") as doc:
42
- text = "\n".join(page.get_text() for page in doc)
43
- except Exception:
44
- log.exception(f"Error reading {file.name}")
45
- pass
46
-
47
- return (file.name, text)
48
-
49
- loop = asyncio.get_event_loop()
50
- with ThreadPoolExecutor(max_workers=max_concurrence) as executor:
51
- futures = [
52
- loop.run_in_executor(executor, _load_pdf_sync, file)
53
- for file in files
54
- if file is not None
55
- ]
56
-
57
- results = await asyncio.gather(*futures, return_exceptions=True)
58
-
59
- valid_results = [result for result in results if not isinstance(result, Exception)]
60
-
61
- log.info(f"successfully processed {len(valid_results)} out of {len(files)} PDFs ")
62
- return valid_results
63
-
64
-
65
- async def build_corpus(pdfs, text_splitter, **load_kwargs):
66
- texts = await load_pdfs(pdfs, **load_kwargs)
67
- corpus, meta = [], []
68
- for file_name, raw_text in texts:
69
- chunks = text_splitter.split_text(raw_text)
70
- for i, chunk in enumerate(chunks):
71
- corpus.append(chunk)
72
- meta.append({"file": file_name, "chunk_id": i})
73
- return corpus, meta
74
-
75
-
76
- def generate_text(
77
- tokenizer, model, user_prompts, system_prompt=None, **llm_kwargs
78
- ): # max_new_tokens=512, temperature=.4):
79
- if system_prompt is None or "":
80
- system_prompt = "You are a helpful assistant."
81
-
82
- if isinstance(user_prompts, str):
83
- user_prompts = [user_prompts]
84
-
85
- messages = [
86
- [
87
- {"role": "system", "content": system_prompt},
88
- {"role": "user", "content": user_prompt},
89
- ]
90
- for user_prompt in user_prompts
91
- ]
92
-
93
- texts = tokenizer.apply_chat_template(
94
- messages, tokenize=False, add_generation_prompt=True
95
- )
96
-
97
- model_inputs = tokenizer(
98
- texts, return_tensors="pt", truncation=True, padding=True
99
- ).to(model.device)
100
- generated_ids = model.generate(
101
- **model_inputs,
102
- max_new_tokens=llm_kwargs.pop("max_new_tokens", 512),
103
- temperature=llm_kwargs.pop("temperature", 0.4),
104
- )
105
- generated_ids = [
106
- output_ids[len(input_ids) :]
107
- for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
108
- ]
109
- response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
110
- return response if len(user_prompts) > 1 else response[0]
111
-
112
-
113
- def load_models(
114
- embed_model_name: str,
115
- gen_model_name: str,
116
- causal_lm: bool = False,
117
- device=None,
118
- bitsandbytesconfig=None,
119
- ):
120
- # This will take some time to run for the first time if the model(s) don't exist locally.
121
- if not device:
122
- device = "cuda" if torch.cuda.is_available() else "cpu"
123
- embedder = SentenceTransformer(
124
- embed_model_name,
125
- device=device,
126
- model_kwargs={"dtype": "float16"} if device == "cuda" else {},
127
- )
128
-
129
- if not causal_lm:
130
- tok = AutoTokenizer.from_pretrained(gen_model_name)
131
- gen = AutoModelForSeq2SeqLM.from_pretrained(
132
- gen_model_name, # device_map='auto',
133
- quantization_config=bitsandbytesconfig if bitsandbytesconfig else None,
134
- )
135
- else:
136
- tok = AutoTokenizer.from_pretrained(gen_model_name, padding_side="left")
137
- gen = AutoModelForCausalLM.from_pretrained(
138
- gen_model_name,
139
- dtype="float16", # device_map='auto',
140
- quantization_config=bitsandbytesconfig if bitsandbytesconfig else None,
141
- )
142
- gen.to(device)
143
- return embedder, tok, gen
144
-
145
-
146
- def make_query_variants(
147
- tokenizer, model, query: str, prompt: str, n: int = 3, **llm_kwargs
148
- ):
149
- instructions = f"Now give me at least {n} variations."
150
- resp = generate_text(tokenizer, model, query + instructions, prompt, **llm_kwargs)
151
-
152
- clean_resp = re.sub(r"^\d+\.\s*", "", resp, flags=re.MULTILINE).split("\n")
153
- return [query] + [q for q in clean_resp if q.strip()]
154
-
155
-
156
- def clean_rewrite_resp(resp):
157
- try:
158
- resp = json.loads(resp) # Parse JSON
159
- except json.JSONDecodeError:
160
- try:
161
- resp = literal_eval(resp) # Fallback parse
162
- except Exception:
163
- pass # Keep resp as-is if both fail
164
-
165
- # Ensure resp is a string before strip and slicing
166
- if isinstance(resp, str):
167
- resp = resp.strip()
168
- if resp:
169
- start = resp.find("{")
170
- if start != -1:
171
- end = resp[::-1].find("}")
172
- if end != -1:
173
- resp = resp[start : len(resp) - end]
174
- return clean_rewrite_resp(resp)
175
- return resp
176
-
177
-
178
- def transform_query(
179
- tokenizer, model, query: str, rewrite_prompt: str, **llm_kwargs
180
- ) -> dict:
181
- """split the query into things to search and actions to take"""
182
- resp = generate_text(tokenizer, model, query, rewrite_prompt, **llm_kwargs)
183
- try:
184
- resp = clean_rewrite_resp(resp)
185
- except:
186
- pass
187
- return resp
188
-
189
-
190
- def aggregate_queries_and_tasks(
191
- tokenizer,
192
- model,
193
- orig_query,
194
- rewrite_prompt,
195
- variants_prompt,
196
- n_variations=3,
197
- **llm_kwargs,
198
  ):
199
- # make variations for the original query as is
200
- queries = make_query_variants(
201
- tokenizer,
202
- model,
203
- orig_query.strip(),
204
- variants_prompt,
205
- n_variations,
206
- **llm_kwargs,
 
207
  )
208
 
209
- start = time()
210
- tr_q = transform_query(tokenizer, model, orig_query.strip(), rewrite_prompt)
211
- end = time()
212
- log.debug(f"\t\t transforming query task took {(end - start):.1f} seconds...")
213
-
214
- # transformed query might have multiple things to search and tasks to perform depending on user query
215
- # recursively get variations for each of the search queries but keep the tasks as is.
216
- tasks = []
217
- if isinstance(tr_q, dict):
218
- search_results, tasks = tr_q.get("search", []), tr_q.get("tasks", [])
219
- for search_result in search_results:
220
- queries.extend(
221
- make_query_variants(
222
- tokenizer,
223
- model,
224
- search_result,
225
- variants_prompt,
226
- n_variations,
227
- **llm_kwargs,
228
- )
229
- )
230
-
231
- queries = [q.strip(string.punctuation) for q in queries]
232
- tasks = [t.strip(string.punctuation) for t in tasks]
233
-
234
- # keep the original user query as is (if in case LLM messes up the original query) and pick some after shuffling the rest
235
- # This is disabled as we don't do loops and instead take advantage of batches.
236
- # Since it's efficient, we can take many query variations at once without worrying about performance.
237
- # q, queries = queries[:1], queries[1:]
238
- # shuffle(queries)
239
- # q += queries[:n_variations-1]
240
-
241
- return queries, tasks
242
-
243
-
244
- def build_index(corpus_emb, n_cells=5, n_probe=2):
245
- log.debug(f"building index with {n_cells=}, {n_probe=}")
246
- d = corpus_emb.shape[1]
247
- quantizer = faiss.IndexFlatIP(d)
248
- index = faiss.IndexIVFFlat(quantizer, d, n_cells)
249
- index.n_probe = n_probe
250
- index.train(corpus_emb)
251
- index.add(corpus_emb)
252
- # index.make_direct_map()
253
- return index
254
 
255
 
256
- def reciprocal_rank_fusion(indices, top_k=3, denom=50):
257
- ii = indices.tolist()
258
- scores = defaultdict(int)
259
- for row in ii:
260
- for rank, chunk_id in enumerate(row):
261
- scores[chunk_id] += 1 / (rank + denom)
262
- results = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:top_k]
263
- return [chunk_id for chunk_id, _ in results]
264
-
265
-
266
- class HyDeRAGFusion:
267
- def __init__(
268
- self,
269
- embed_model: str,
270
- generator_llm_model: str,
271
- causal_lm: bool = True,
272
- chunk_overlap: int = 50,
273
- tokens_per_chunk: int = 256,
274
- embed_batch_size: int = 64,
275
- bitsandbytesconfig=None,
276
- ):
277
- self.embed_batch_size = embed_batch_size
278
- self.text_splitter = SentenceTransformersTokenTextSplitter(
279
- chunk_overlap, embed_model, tokens_per_chunk
280
- )
281
- self.embedder, self.tok, self.gen = load_models(
282
- embed_model, generator_llm_model, causal_lm, bitsandbytesconfig
283
- )
284
- with open(PROMPTS_FILEPATH) as fl:
285
- self.prompts = yaml.safe_load(fl)
286
-
287
- @lru_cache(maxsize=8)
288
- def preprocess_pdfs(self, pdfs, data_load_kwargs={}, faiss_index_kwargs={}):
289
- self.corpus, self.meta = asyncio.run(
290
- build_corpus(pdfs, self.text_splitter, **data_load_kwargs)
291
- )
292
- log.debug(f"{len(self.corpus)}, {len(self.meta)}")
293
- self.corpus_emb = self.embedder.encode(
294
- self.corpus,
295
- batch_size=self.embed_batch_size,
296
- show_progress_bar=True,
297
- normalize_embeddings=True,
298
- )
299
- log.debug(f"{self.corpus_emb.shape}")
300
-
301
- # https://github.com/facebookresearch/faiss/issues/112
302
- # n_cells = int(round(4 * (self.corpus_emb.shape[0])**.5))
303
-
304
- # one centroid for every 100 or so vectors and 20% of them as n_probe
305
- n_cells = faiss_index_kwargs.pop("n_cells", self.corpus_emb.shape[0] // 100 + 1)
306
- n_probe = faiss_index_kwargs.pop("n_probe", math.ceil(0.2 * n_cells))
307
-
308
- self.index = build_index(self.corpus_emb, n_cells, n_probe)
309
-
310
- def retrieve(
311
- self, query, n_variants=3, top_k_per_variant=10, top_k_retrieve=3, **llm_kwargs
312
- ):
313
- start = time()
314
-
315
- queries, tasks = aggregate_queries_and_tasks(
316
- self.tok,
317
- self.gen,
318
- query.strip(),
319
- self.prompts["rewrite"],
320
- self.prompts["variants"],
321
- n_variants,
322
- **llm_kwargs,
323
- )
324
-
325
- end = time()
326
- log.debug(f"aggregate task took {(end - start):.1f} seconds...")
327
-
328
- start = time()
329
- hyde_docs = generate_text(
330
- self.tok, self.gen, queries, self.prompts["hyde"], **llm_kwargs
331
- )
332
- end = time()
333
- log.debug(f"generating hyde docs took {(end - start):.1f} seconds...")
334
-
335
- start = time()
336
- chunks = []
337
- for hyde_doc in hyde_docs:
338
- chunks.extend(self.text_splitter.split_text(hyde_doc))
339
- q_emb = self.embedder.encode(
340
- chunks, batch_size=self.embed_batch_size, normalize_embeddings=True
341
- )
342
- end = time()
343
- log.debug(f"embedding hyde docs took {(end - start):.1f} seconds...")
344
-
345
- _, I = self.index.search(q_emb, top_k_per_variant)
346
- chunk_ids = reciprocal_rank_fusion(I, top_k_retrieve)
347
- return chunk_ids, tasks
348
-
349
- def answer(self, query, doc_ids, tasks, max_ctx_chars=128000):
350
- total, text, prompt_length = 0, "", 10000
351
- sep = "\n\n-----\n\n"
352
- tasks = ", ".join(tasks)
353
-
354
- for doc_id in doc_ids:
355
- # adding tags in the context caused more hallucinations.
356
- # Instead, we list them as sources beneath the model response.
357
- # _meta = self.meta[doc_id]
358
- # tag = f"(source: {_meta['file_name']}:{_meta['chunk_id']})"
359
- chunk = self.corpus[doc_id].strip()
360
- tag = ""
361
-
362
- ctx = f"{sep}{tag}\n\n{chunk}"
363
- if total + len(ctx) + len(tasks) + len(sep) + prompt_length > max_ctx_chars:
364
- break
365
-
366
- text += ctx
367
- total = len(text)
368
-
369
- text += f"{sep}{tasks}"
370
-
371
- # instruction = "Answer concisely and also cite file names & chunk ids inline like (pdf_file_name:chunk_id)."
372
- instruction = "go ahead and answer!"
373
- user_query = f"\nq: {query}\n\nctx:{text}" + f"\n\n{instruction}\n\n"
374
-
375
- start = time()
376
- resp = generate_text(
377
- self.tok,
378
- self.gen,
379
- user_query,
380
- self.prompts["final_answer"],
381
- temperature=0.3,
382
- )
383
- end = time()
384
- log.debug(f"final resp took {(end - start):.1f} seconds...")
385
-
386
- return resp
387
-
388
-
389
- def initial_setup(embed_model, generator_model, bitsandbytesconfig=None):
390
- return HyDeRAGFusion(
391
- embed_model, generator_model, bitsandbytesconfig=bitsandbytesconfig
392
- )
393
-
394
-
395
- start = time()
396
- HRF = initial_setup("sentence-transformers/LaBSE", "Qwen/Qwen2.5-0.5B-Instruct")
397
- end = time()
398
- msg = f"init took {(end - start):.1f} seconds"
399
- log.debug(msg)
400
-
401
 
402
- def main(
403
- pdfs,
404
- query,
405
- n_variants=3,
406
- top_k_per_variant=5,
407
- top_k_retrieve=3,
408
- temperature=0.4,
409
- max_new_tokens=512,
410
- ):
411
- start = time()
412
- if pdfs:
413
- HRF.preprocess_pdfs(tuple(sorted(pdfs)))
414
 
415
- if query:
416
- llm_kwargs = {
417
- "temperature": temperature,
418
- "max_new_tokens": max_new_tokens,
419
- }
420
- doc_ids, tasks = HRF.retrieve(
421
- query,
422
- int(n_variants),
423
- int(top_k_per_variant),
424
- int(top_k_retrieve),
425
- **llm_kwargs,
426
- )
427
- docs = [HRF.corpus[doc_id] for doc_id in doc_ids]
428
- reply = HRF.answer(query, doc_ids, tasks)
429
- sources = [
430
- {
431
- "source": f"{Path(HRF.meta[doc_id]['file']).stem}:{HRF.meta[doc_id]['chunk_id']}",
432
- "content": doc,
433
- }
434
- for doc_id, doc in zip(doc_ids, docs)
435
- ]
436
 
437
- resp = f"{reply}\n\n{'-' * 25}\n\n"
438
- resp += "Top 3 sources:"
439
- resp += f"\n\n{'-' * 25}\n\n"
440
- for source in sources:
441
- resp += f"source: {source['source']}\n\n"
442
- resp += source["content"]
443
- resp += f"\n\n{'-' * 25}\n\n"
444
 
445
- end = time()
446
- log.debug(f"final resp took {(end - start):.1f} seconds")
447
- return resp
448
 
449
 
450
- def reset_text_on_file_change(pdfs):
451
- """
452
- Reset text input when input docs change
453
- """
454
- return ""
455
 
456
 
457
  with gr.Blocks(title="RAG with HYDE") as demo:
458
  gr.Markdown("# RAG with HYDE")
459
  with gr.Row():
460
  pdf_input = gr.File(
461
- label="upload PDF(s)", file_types=[".pdf"], file_count="multiple"
 
 
462
  )
463
- query = gr.Textbox(label="question")
464
 
465
- gr.Markdown("*Please be patient after hitting the submit button*")
466
- btn = gr.Button("Submit")
467
- answer = gr.Markdown(label="### Answer")
 
468
 
469
- btn.click(main, inputs=[pdf_input, query], outputs=answer)
470
- pdf_input.change(reset_text_on_file_change, inputs=pdf_input, outputs=query)
471
472
 
473
  if __name__ == "__main__":
474
  demo.launch()
 
1
+ import sys
2
  import string
 
 
3
  import gradio as gr
4
+ from src.main import ask
5
+ from src.utils import empty_cache
6
+ from src.config import log
7
 
8
 
9
+ def ask_wrapper(
10
+ pdfs,
11
+ query,
12
+ model_combo_key,
13
+ fast_extract,
14
+ n_variants,
15
+ top_k_per_variant,
16
+ top_k_retrieve,
17
+ temperature,
18
  ):
19
+ resp = ask.callback(
20
+ pdfs,
21
+ query,
22
+ model_combo_key,
23
+ fast_extract,
24
+ n_variants,
25
+ top_k_per_variant,
26
+ top_k_retrieve,
27
+ temperature,
28
  )
29
 
30
+ return f"## Final Answer:\n\n{resp}\n"
31
 
32
 
33
+ def reset(pdfs):
34
+ """
35
+ Reset text input and empty cache
36
+ """
37
+ log.warning("emptying cache")
38
+ empty_cache()
39
+ return ""
40
41
 
42
+ # Enable the button only when PDFs are uploaded and the query is long enough
43
+ def _enable_submit_if_filled(pdfs, query):
44
+ status = bool(pdfs) and bool(len(query.strip(string.punctuation + " ")) > 10)
45
+ return gr.update(interactive=status)
46
47
 
48
+ def disable_button():
49
+ return gr.update(interactive=False, value="Processing...")
 
50
 
51
 
52
+ def enable_button():
53
+ return gr.update(interactive=True, value="Submit")
54
 
55
 
56
  with gr.Blocks(title="RAG with HYDE") as demo:
57
  gr.Markdown("# RAG with HYDE")
58
  with gr.Row():
59
  pdf_input = gr.File(
60
+ label="upload PDF(s)",
61
+ file_types=[".pdf"],
62
+ file_count="multiple",
63
  )
64
+ query = gr.Textbox(label="Question (Enter at least 10 valid characters)")
65
 
66
+ with gr.Accordion("Advanced Settings", open=False):
67
+ gr.Markdown(
68
+ "*These parameters have sensible defaults but can be customized if needed*"
69
+ )
70
+ with gr.Row():
71
+ _default_combo = "linux" if sys.platform == "linux" else "mac"
72
+ model_combo_key = gr.Dropdown(
73
+ label="Model Combo Key",
74
+ choices=[_default_combo, "HF-mid"],
75
+ value=_default_combo,
76
+ )
77
+ fast_extract = gr.Checkbox(
78
+ value=False, label="Use PyMuPDF to extract content in markdown"
79
+ )
80
+ n_variants = gr.Number(
81
+ value=3, minimum=1, maximum=5, label="no. of query variants"
82
+ )
83
+ with gr.Row():
84
+ top_k_per_variant = gr.Number(
85
+ value=5,
86
+ minimum=2,
87
+ maximum=10,
88
+ label="top `k` hits per query variant for RRF",
89
+ )
90
+ top_k_retrieve = gr.Number(
91
+ value=3,
92
+ minimum=1,
93
+ maximum=5,
94
+ label="top `k` chunks to retrieve after RRF",
95
+ )
96
+ temperature = gr.Slider(
97
+ value=0.7, minimum=0.1, maximum=1.0, step=0.1, label="temperature"
98
+ )
99
 
100
+ gr.Markdown(
101
+ "### *Please be patient after hitting the submit button* esp. for the first question after uploading new document(s)"
102
+ )
103
+ submit_btn = gr.Button("Submit", variant="primary", interactive=False)
104
+ answer = gr.Markdown(label="## Answer")
105
 
106
+ pdf_input.change(
107
+ _enable_submit_if_filled, [pdf_input, query], submit_btn, queue=False
108
+ )
109
+ query.change(_enable_submit_if_filled, [pdf_input, query], submit_btn, queue=False)
110
+
111
+ submit_btn.click(fn=disable_button, outputs=submit_btn).then(
112
+ fn=ask_wrapper,
113
+ inputs=[
114
+ pdf_input,
115
+ query,
116
+ model_combo_key,
117
+ fast_extract,
118
+ n_variants,
119
+ top_k_per_variant,
120
+ top_k_retrieve,
121
+ temperature,
122
+ ],
123
+ outputs=answer,
124
+ ).then(fn=enable_button, outputs=submit_btn)
125
+
126
+ query.submit(fn=disable_button, outputs=submit_btn).then(
127
+ fn=ask_wrapper,
128
+ inputs=[
129
+ pdf_input,
130
+ query,
131
+ model_combo_key,
132
+ fast_extract,
133
+ n_variants,
134
+ top_k_per_variant,
135
+ top_k_retrieve,
136
+ temperature,
137
+ ],
138
+ outputs=answer,
139
+ ).then(fn=enable_button, outputs=submit_btn)
140
+
141
+ pdf_input.change(reset, pdf_input, query)
142
+ demo.load(reset, pdf_input, query)
143
 
144
  if __name__ == "__main__":
145
  demo.launch()
requirements.txt CHANGED
@@ -3,6 +3,10 @@ numpy
3
  transformers
4
  sentence-transformers
5
  pymupdf
 
6
  langchain-text-splitters
7
  #faiss-cpu
8
  faiss-gpu-cu12
3
  transformers
4
  sentence-transformers
5
  pymupdf
6
+ pymupdf4llm
7
  langchain-text-splitters
8
  #faiss-cpu
9
  faiss-gpu-cu12
10
+ autoawq
11
+ docling
12
+
src/config.py CHANGED
@@ -8,7 +8,7 @@ def get_logger(LOG_LEVEL="INFO"):
8
  LOG_PATH = Path("logs.log")
9
  formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
10
 
11
- log = logging.Logger("hyde_rag")
12
  log.setLevel(LOG_LEVEL)
13
 
14
  file_handler = logging.FileHandler(LOG_PATH)
@@ -21,3 +21,35 @@ def get_logger(LOG_LEVEL="INFO"):
21
 
22
 
23
  log = get_logger("DEBUG")
8
  LOG_PATH = Path("logs.log")
9
  formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
10
 
11
+ log = logging.Logger("agentic_search")
12
  log.setLevel(LOG_LEVEL)
13
 
14
  file_handler = logging.FileHandler(LOG_PATH)
 
21
 
22
 
23
  log = get_logger("DEBUG")
24
+
25
+ MODEL_COMBOS = {
26
+ "linux": {
27
+ "embed_model": "Qwen/Qwen3-Embedding-0.6B",
28
+ "gen_model": "Qwen/Qwen3-4B-AWQ",
29
+ # 'gen_model': "Qwen/Qwen3-0.6B-GPTQ-Int8"
30
+ # 'gen_model': "Qwen/Qwen3-1.7B-GPTQ-Int8"
31
+ },
32
+ # feel free to replace with any ??B-MLX-?bit versions from Qwen3 Collection at:
33
+ # https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f
34
+ "mac": {
35
+ "embed_model": "Qwen/Qwen3-Embedding-0.6B",
36
+ "gen_model": "Qwen/Qwen3-4B-MLX-4bit",
37
+ },
38
+ "mac_mid": {
39
+ "embed_model": "Qwen/Qwen3-Embedding-0.6B",
40
+ "gen_model": "Qwen/Qwen3-4B-MLX-6bit",
41
+ },
42
+ "mac_high": {
43
+ "embed_model": "Qwen/Qwen3-Embedding-0.6B",
44
+ "gen_model": "Qwen/Qwen3-4B-MLX-8bit",
45
+ },
46
+ # HF-low is the same as the `linux` combo
47
+ "HF-mid": {
48
+ "embed_model": "Qwen/Qwen3-Embedding-0.6B",
49
+ "gen_model": "Qwen/Qwen3-8B-AWQ",
50
+ },
51
+ "HF-high": {
52
+ "embed_model": "Qwen/Qwen3-Embedding-4B",
53
+ "gen_model": "Qwen/Qwen3-14B-AWQ",
54
+ },
55
+ }
src/hyde_rag.py DELETED
@@ -1,206 +0,0 @@
1
- # hyde_ragfusion.py
2
- # Minimal HyDE + RAG-Fusion over local PDFs.
3
- # Dependencies: transformers, sentence-transformers, scikit-learn, pymupdf, numpy
4
-
5
- import os
6
- import re
7
- import heapq
8
- import fitz # PyMuPDF
9
- from sklearn.neighbors import NearestNeighbors
10
- from sentence_transformers import SentenceTransformer
11
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
12
-
13
-
14
- # -----------------------------
15
- # Ingestion & Chunking
16
- # -----------------------------
17
- def load_pdfs(folder):
18
- docs = []
19
- for fn in os.listdir(folder):
20
- if fn.lower().endswith(".pdf"):
21
- path = os.path.join(folder, fn)
22
- with fitz.open(path) as doc:
23
- text = "\n".join(page.get_text("text") for page in doc)
24
- text = re.sub(r"\s+\n", "\n", text).strip()
25
- docs.append((fn, text))
26
- return docs
27
-
28
-
29
- def chunk_text(text, chunk_size=300, overlap=50):
30
- words = text.split()
31
- chunks, i = [], 0
32
- while i < len(words):
33
- chunk = " ".join(words[i : i + chunk_size])
34
- chunks.append(chunk)
35
- i += chunk_size - overlap
36
- return chunks
37
-
38
-
39
- def build_corpus(pdf_folder):
40
- raw = load_pdfs(pdf_folder)
41
- corpus, meta = [], []
42
- for fn, txt in raw:
43
- for i, ch in enumerate(chunk_text(txt)):
44
- corpus.append(ch)
45
- meta.append({"file": fn, "chunk_id": i})
46
- return corpus, meta
47
-
48
-
49
- # -----------------------------
50
- # Models (local)
51
- # -----------------------------
52
- def load_models():
53
- # Small, fast encoder for embeddings
54
- embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
55
- # Lightweight local generator for HyDE + answers
56
- gen_name = "google/flan-t5-base"
57
- tok = AutoTokenizer.from_pretrained(gen_name)
58
- gen = AutoModelForSeq2SeqLM.from_pretrained(gen_name)
59
- return embedder, tok, gen
60
-
61
-
62
- # -----------------------------
63
- # Index (cosine)
64
- # -----------------------------
65
- def fit_index(embeddings, n_neighbors=12):
66
- nn = NearestNeighbors(metric="cosine", algorithm="auto")
67
- nn.fit(embeddings)
68
- return nn
69
-
70
-
71
- # -----------------------------
72
- # RAG-Fusion (query variants) + HyDE
73
- # -----------------------------
74
- Q_VARIANTS_PROMPT = """You rewrite the user query into {n} diverse, specific search queries (short).
75
- User query: "{q}"
76
- Return each on a new line, no numbering, no extra text."""
77
-
78
- HYDE_PROMPT = """Write a factual, neutral, self-contained paragraph that could answer:
79
- "{q}"
80
- Avoid fluff. Include likely key terms and entities. 120-180 words."""
81
-
82
- ANSWER_PROMPT = """You are a helpful assistant. Use ONLY the provided context.
83
- Question: {q}
84
-
85
- Context:
86
- {ctx}
87
-
88
- Answer concisely and cite file names & chunk ids inline like (file:chunk).
89
- """
90
-
91
-
92
- def generate_text(gen, tok, prompt, max_new_tokens=160, temperature=0.3):
93
- inputs = tok(prompt, return_tensors="pt")
94
- out = gen.generate(
95
- **inputs,
96
- max_new_tokens=max_new_tokens,
97
- do_sample=False,
98
- temperature=temperature,
99
- )
100
- return tok.decode(out[0], skip_special_tokens=True).strip()
101
-
102
-
103
- def make_query_variants(gen, tok, q, n=4):
104
- txt = generate_text(
105
- gen, tok, Q_VARIANTS_PROMPT.format(q=q, n=n), max_new_tokens=120
106
- )
107
- # Split cleanly into lines (drop empties/dups; include original)
108
- lines = [l.strip(" -•\t") for l in txt.split("\n") if l.strip()]
109
- uniq = []
110
- seen = set()
111
- for l in lines + [q]:
112
- if l not in seen:
113
- seen.add(l)
114
- uniq.append(l)
115
- return uniq[:n]
116
-
117
-
118
- def hyde_doc(gen, tok, q):
119
- return generate_text(gen, tok, HYDE_PROMPT.format(q=q), max_new_tokens=220)
120
-
121
-
122
- # -----------------------------
123
- # Retrieval + RRF
124
- # -----------------------------
125
- def cosine_search(nn, corpus_embeddings, query_vec, top_k=8):
126
- dists, idxs = nn.kneighbors(query_vec.reshape(1, -1), n_neighbors=top_k)
127
- # Convert cosine distance to similarity
128
- sims = 1 - dists[0]
129
- return list(zip(idxs[0].tolist(), sims.tolist()))
130
-
131
-
132
- def reciprocal_rank_fusion(rank_lists, k=60, top_k=8):
133
- # rank_lists: list of [doc_id, ...] ordered best→worst
134
- scores = {}
135
- for ranks in rank_lists:
136
- for rank, doc_id in enumerate(ranks, start=1):
137
- scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank)
138
- # top by fused score
139
- best = heapq.nlargest(top_k, scores.items(), key=lambda x: x[1])
140
- return [doc_id for doc_id, _ in best]
141
-
142
-
143
- # -----------------------------
144
- # Pipeline
145
- # -----------------------------
146
- class HyDeRAGFusion:
147
- def __init__(self, pdf_folder):
148
- self.corpus, self.meta = build_corpus(pdf_folder)
149
- self.embedder, self.tok, self.gen = load_models()
150
- self.corpus_emb = self.embedder.encode(
151
- self.corpus,
152
- batch_size=64,
153
- show_progress_bar=True,
154
- normalize_embeddings=True,
155
- )
156
- self.nn = fit_index(self.corpus_emb)
157
-
158
- def retrieve(self, query, n_variants=4, per_variant_k=8, final_top_k=6, rrf_k=60):
159
- variants = make_query_variants(self.gen, self.tok, query, n=n_variants)
160
- rank_lists = []
161
- for v in variants:
162
- hypo = hyde_doc(self.gen, self.tok, v) # HyDE
163
- q_vec = self.embedder.encode([hypo], normalize_embeddings=True)[0]
164
- hits = cosine_search(self.nn, self.corpus_emb, q_vec, top_k=per_variant_k)
165
- rank_lists.append([doc_id for doc_id, _ in hits])
166
- fused = reciprocal_rank_fusion(rank_lists, k=rrf_k, top_k=final_top_k)
167
- return fused
168
-
169
- def answer(self, query, doc_ids, max_ctx_chars=4000):
170
- # Build compact context with inline provenance
171
- ctx_parts = []
172
- total = 0
173
- for i in doc_ids:
174
- piece = self.corpus[i]
175
- tag = f"(source: {self.meta[i]['file']}:{self.meta[i]['chunk_id']})"
176
- chunk = piece.strip()
177
- if total + len(chunk) + len(tag) + 5 > max_ctx_chars:
178
- break
179
- ctx_parts.append(f"{chunk}\n{tag}")
180
- total += len(chunk) + len(tag) + 5
181
- ctx = "\n\n---\n\n".join(ctx_parts)
182
- prompt = ANSWER_PROMPT.format(q=query, ctx=ctx)
183
- return generate_text(self.gen, self.tok, prompt, max_new_tokens=300)
184
-
185
-
186
- # -----------------------------
187
- # Example usage
188
- # -----------------------------
189
- if __name__ == "__main__":
190
- import argparse
191
-
192
- ap = argparse.ArgumentParser()
193
- ap.add_argument("--pdf_folder", required=True, help="Folder with PDFs to index")
194
- ap.add_argument("--query", required=True, help="Your user question")
195
- ap.add_argument("--show_sources", action="store_true")
196
- args = ap.parse_args()
197
-
198
- rag = HyDeRAGFusion(args.pdf_folder)
199
- doc_ids = rag.retrieve(args.query)
200
- answer = rag.answer(args.query, doc_ids)
201
- print("\n=== ANSWER ===\n")
202
- print(answer)
203
- if args.show_sources:
204
- print("\n=== TOP SOURCES ===")
205
- for i in doc_ids:
206
- print(f"- {rag.meta[i]['file']}:{rag.meta[i]['chunk_id']}")
src/main.py CHANGED
@@ -1,79 +1,32 @@
1
- import pymupdf
2
- import math
3
- import faiss
4
  import string
 
5
  import yaml
6
  import re
7
- import json
8
- import asyncio
9
  import torch
10
- import streamlit as st
11
  import click
12
 
13
- from collections import defaultdict
14
- from ast import literal_eval
15
  from time import time
 
 
16
  from sentence_transformers import SentenceTransformer
17
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
18
- from concurrent.futures import ThreadPoolExecutor
19
- from langchain.text_splitter import SentenceTransformersTokenTextSplitter
20
- from src.config import PROMPTS_FILEPATH, log
21
-
22
-
23
- async def load_pdfs(files, max_concurrence=5):
24
- """
25
- Load multiple PDF files async.
26
-
27
- Args:
28
- folder (str): Path to folder containing PDF files
29
- max_concurrence (int): Maximum number of concurrent PDF processing tasks
30
-
31
- Returns:
32
- list: List of tuples containing (filename, extracted_text)
33
- """
34
-
35
- def _load_pdf_sync(file):
36
- """Synchronous PDF loading function for thread pool execution"""
37
- text = ""
38
- try:
39
- with pymupdf.open(stream=file.getvalue(), filetype="pdf") as doc:
40
- text = "\n".join(page.get_text() for page in doc)
41
- except Exception:
42
- log.exception(f"Error reading {file.name}")
43
- pass
44
-
45
- return (file.name, text)
46
-
47
- loop = asyncio.get_event_loop()
48
- with ThreadPoolExecutor(max_workers=max_concurrence) as executor:
49
- futures = [
50
- loop.run_in_executor(executor, _load_pdf_sync, file)
51
- for file in files
52
- if file is not None
53
- ]
54
-
55
- results = await asyncio.gather(*futures, return_exceptions=True)
56
-
57
- valid_results = [result for result in results if not isinstance(result, Exception)]
58
-
59
- log.info(f"successfully processed {len(valid_results)} out of {len(files)} PDFs ")
60
- return valid_results
61
-
62
-
63
- async def build_corpus(pdfs, text_splitter, **load_kwargs):
64
- texts = await load_pdfs(pdfs, **load_kwargs)
65
- corpus, meta = [], []
66
- for file_name, raw_text in texts:
67
- chunks = text_splitter.split_text(raw_text)
68
- for i, chunk in enumerate(chunks):
69
- corpus.append(chunk)
70
- meta.append({"file": file_name, "chunk_id": i})
71
- return corpus, meta
72
 
73
 
74
  def generate_text(
75
- tokenizer, model, user_prompts, system_prompt=None, **llm_kwargs
76
- ): # max_new_tokens=512, temperature=.4):
 
77
  if not system_prompt:
78
  system_prompt = "You are a helpful assistant."
79
 
@@ -88,96 +41,122 @@ def generate_text(
88
  for user_prompt in user_prompts
89
  ]
90
 
91
- texts = tokenizer.apply_chat_template(
92
- messages, tokenize=False, add_generation_prompt=True
93
- )
94
 
95
- model_inputs = tokenizer(
96
- texts, return_tensors="pt", truncation=True, padding=True
97
- ).to(model.device)
98
- generated_ids = model.generate(
99
- **model_inputs,
100
- max_new_tokens=llm_kwargs.pop("max_new_tokens", 512),
101
- temperature=llm_kwargs.pop("temperature", 0.4),
102
- )
103
- generated_ids = [
104
- output_ids[len(input_ids) :]
105
- for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
106
- ]
107
- response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
108
- return response if len(user_prompts) > 1 else response[0]
109
 
 
110
 
111
- def load_models(
112
- embed_model_name: str,
113
- gen_model_name: str,
114
- causal_lm: bool = False,
115
- device=None,
116
- bitsandbytesconfig=None,
117
- ):
118
  # This will take some time to run for the first time if the model(s) don't exist locally.
119
  if not device:
120
- device = "cuda" if torch.cuda.is_available() else "cpu"
 
121
  embedder = SentenceTransformer(
122
  embed_model_name,
123
  device=device,
124
- model_kwargs={"dtype": "float16"} if device == "cuda" else {},
125
  )
126
-
127
- if not causal_lm:
128
- tok = AutoTokenizer.from_pretrained(gen_model_name)
129
- gen = AutoModelForSeq2SeqLM.from_pretrained(
130
- gen_model_name, # device_map='auto',
131
- quantization_config=bitsandbytesconfig if bitsandbytesconfig else None,
132
- )
133
- else:
134
- tok = AutoTokenizer.from_pretrained(gen_model_name, padding_side="left")
135
- gen = AutoModelForCausalLM.from_pretrained(
136
- gen_model_name,
137
- dtype="float16", # device_map='auto',
138
- quantization_config=bitsandbytesconfig if bitsandbytesconfig else None,
139
- )
140
- gen.to(device)
141
  return embedder, tok, gen
142
 
143
 
144
  def make_query_variants(
145
- tokenizer, model, query: str, prompt: str, n: int = 3, **llm_kwargs
146
  ):
147
- instructions = f"Now give me at least {n} variations."
148
- resp = generate_text(tokenizer, model, query + instructions, prompt, **llm_kwargs)
149
-
 
 
150
  clean_resp = re.sub(r"^\d+\.\s*", "", resp, flags=re.MULTILINE).split("\n")
151
- return [query] + [q for q in clean_resp if q.strip()]
152
-
153
-
154
- def clean_rewrite_resp(resp):
155
- try:
156
- resp = json.loads(resp) # Parse JSON
157
- except json.JSONDecodeError:
158
- try:
159
- resp = literal_eval(resp) # Fallback parse
160
- except Exception:
161
- pass # Keep resp as-is if both fail
162
-
163
- # Ensure resp is a string before strip and slicing
164
- if isinstance(resp, str):
165
- resp = resp.strip()
166
- if resp:
167
- start = resp.find("{")
168
- if start != -1:
169
- end = resp[::-1].find("}")
170
- if end != -1:
171
- resp = resp[start : len(resp) - end]
172
- return clean_rewrite_resp(resp)
173
- return resp
174
 
175
 
176
  def transform_query(
177
- tokenizer, model, query: str, rewrite_prompt: str, **llm_kwargs
178
  ) -> dict:
179
  """split the query into things to search and actions to take"""
180
- resp = generate_text(tokenizer, model, query, rewrite_prompt, **llm_kwargs)
 
 
181
  try:
182
  resp = clean_rewrite_resp(resp)
183
  except:
@@ -192,6 +171,7 @@ def aggregate_queries_and_tasks(
192
  rewrite_prompt,
193
  variants_prompt,
194
  n_variations=3,
 
195
  **llm_kwargs,
196
  ):
197
  # make variations for the original query as is
@@ -201,14 +181,13 @@ def aggregate_queries_and_tasks(
201
  orig_query.strip(),
202
  variants_prompt,
203
  n_variations,
 
204
  **llm_kwargs,
 
 
 
205
  )
206
 
207
- start = time()
208
- tr_q = transform_query(tokenizer, model, orig_query.strip(), rewrite_prompt)
209
- end = time()
210
- log.debug(f"\t\t transforming query task took {(end - start):.1f} seconds...")
211
-
212
  # transformed query might have multiple things to search and tasks to perform depending on user query
213
  # recursively get variations for each of the search queries but keep the tasks as is.
214
  tasks = []
@@ -222,91 +201,77 @@ def aggregate_queries_and_tasks(
222
  search_result,
223
  variants_prompt,
224
  n_variations,
 
225
  **llm_kwargs,
226
  )
227
  )
228
 
229
- queries = [q.strip(string.punctuation) for q in queries]
230
- tasks = [t.strip(string.punctuation) for t in tasks]
231
-
232
  # keep the original user query as is (if in case LLM messes up the original query) and pick some after shuffling the rest
233
- # This is disabled as we don't do loops and instead take advantage of batches.
234
- # Since it's efficient, we can take many query variations at once without worrying about performance.
235
- # q, queries = queries[:1], queries[1:]
236
- # shuffle(queries)
237
- # q += queries[:n_variations-1]
 
238
 
239
  return queries, tasks
240
 
241
 
242
- def build_index(corpus_emb, n_cells=5, n_probe=2):
243
- log.debug(f"building index with {n_cells=}, {n_probe=}")
244
- d = corpus_emb.shape[1]
245
- quantizer = faiss.IndexFlatIP(d)
246
- index = faiss.IndexIVFFlat(quantizer, d, n_cells)
247
- index.n_probe = n_probe
248
- index.train(corpus_emb)
249
- index.add(corpus_emb)
250
- # index.make_direct_map()
251
- return index
252
-
253
-
254
- def reciprocal_rank_fusion(indices, top_k=3, denom=50):
255
- ii = indices.tolist()
256
- scores = defaultdict(int)
257
- for row in ii:
258
- for rank, chunk_id in enumerate(row):
259
- scores[chunk_id] += 1 / (rank + denom)
260
- results = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:top_k]
261
- return [chunk_id for chunk_id, _ in results]
262
-
263
-
264
  class HyDeRAGFusion:
265
  def __init__(
266
  self,
267
  embed_model: str,
268
  generator_llm_model: str,
269
- causal_lm: bool = True,
270
- chunk_overlap: int = 50,
271
- tokens_per_chunk: int = 256,
272
- embed_batch_size: int = 64,
273
- bitsandbytesconfig=None,
274
  ):
275
  self.embed_batch_size = embed_batch_size
276
- self.text_splitter = SentenceTransformersTokenTextSplitter(
277
- chunk_overlap, embed_model, tokens_per_chunk
278
- )
279
  self.embedder, self.tok, self.gen = load_models(
280
- embed_model, generator_llm_model, causal_lm, bitsandbytesconfig
281
  )
 
282
  with open(PROMPTS_FILEPATH) as fl:
283
  self.prompts = yaml.safe_load(fl)
284
 
285
- def preprocess_pdfs(self, pdfs, data_load_kwargs={}, faiss_index_kwargs={}):
286
- self.corpus, self.meta = asyncio.run(
287
- build_corpus(pdfs, self.text_splitter, **data_load_kwargs)
 
288
  )
289
- self.corpus_emb = self.embedder.encode(
290
- self.corpus,
291
  batch_size=self.embed_batch_size,
292
- show_progress_bar=True,
293
- normalize_embeddings=True,
 
 
294
  )
295
 
296
- # https://github.com/facebookresearch/faiss/issues/112
297
- # n_cells = int(round(4 * (self.corpus_emb.shape[0])**.5))
298
-
299
- # one centroid for every 100 or so vectors and 20% of them as n_probe
300
- n_cells = faiss_index_kwargs.pop("n_cells", self.corpus_emb.shape[0] // 100 + 1)
301
- n_probe = faiss_index_kwargs.pop("n_probe", math.ceil(0.2 * n_cells))
302
-
303
- self.index = build_index(self.corpus_emb, n_cells, n_probe)
 
304
 
305
  def retrieve(
306
- self, query, n_variants=3, top_k_per_variant=10, top_k_retrieve=3, **llm_kwargs
307
  ):
308
- start = time()
309
-
310
  queries, tasks = aggregate_queries_and_tasks(
311
  self.tok,
312
  self.gen,
@@ -314,48 +279,38 @@ class HyDeRAGFusion:
314
  self.prompts["rewrite"],
315
  self.prompts["variants"],
316
  n_variants,
 
317
  **llm_kwargs,
318
  )
319
-
320
- end = time()
321
- log.debug(f"aggregate task took {(end - start):.1f} seconds...")
322
-
323
- start = time()
324
  hyde_docs = generate_text(
325
- self.tok, self.gen, queries, self.prompts["hyde"], **llm_kwargs
326
  )
327
- end = time()
328
- log.debug(f"generating hyde docs took {(end - start):.1f} seconds...")
329
-
330
- start = time()
331
  chunks = []
332
  for hyde_doc in hyde_docs:
333
  chunks.extend(self.text_splitter.split_text(hyde_doc))
334
- q_emb = self.embedder.encode(
335
- chunks, batch_size=self.embed_batch_size, normalize_embeddings=True
 
336
  )
337
- end = time()
338
- log.debug(f"embedding hyde docs took {(end - start):.1f} seconds...")
 
339
 
340
- _, I = self.index.search(q_emb, top_k_per_variant)
341
- chunk_ids = reciprocal_rank_fusion(I, top_k_retrieve)
342
- return chunk_ids, tasks
343
-
344
- def answer(self, query, doc_ids, tasks, max_ctx_chars=128000):
345
  total, text, prompt_length = 0, "", 10000
346
  sep = "\n\n-----\n\n"
347
- tasks = ", ".join(tasks)
348
-
349
- for doc_id in doc_ids:
350
- # adding tags in the context caused more hallucinations.
351
- # Instead, we list them as sources beneath the model response.
352
- # _meta = self.meta[doc_id]
353
- # tag = f"(source: {_meta['file_name']}:{_meta['chunk_id']})"
354
- chunk = self.corpus[doc_id].strip()
355
- tag = ""
356
-
357
- ctx = f"{sep}{tag}\n\n{chunk}"
358
  if total + len(ctx) + len(tasks) + len(sep) + prompt_length > max_ctx_chars:
 
359
  break
360
 
361
  text += ctx
@@ -363,41 +318,52 @@ class HyDeRAGFusion:
363
 
364
  text += f"{sep}{tasks}"
365
 
366
- # instruction = "Answer concisely and also cite file names & chunk ids inline like (pdf_file_name:chunk_id)."
367
  instruction = "go ahead and answer!"
368
  user_query = f"\nq: {query}\n\nctx:{text}" + f"\n\n{instruction}\n\n"
369
-
370
- start = time()
371
  resp = generate_text(
372
  self.tok,
373
  self.gen,
374
  user_query,
375
  self.prompts["final_answer"],
376
- temperature=0.3,
377
  )
378
- end = time()
379
- log.debug(f"final resp took {(end - start):.1f} seconds...")
380
 
381
- return resp
382
383
 
384
- @st.cache_resource
385
- def initial_setup(embed_model, generator_model, bitsandbytesconfig=None):
386
- return HyDeRAGFusion(
387
- embed_model, generator_model, bitsandbytesconfig=bitsandbytesconfig
388
- )
389
 
390
 
391
  @click.command(context_settings=dict(show_default=True))
392
  @click.option(
393
- "--embed-model",
394
- default="sentence-transformers/LaBSE",
395
- help="sentence transformers embedding model",
396
  )
397
  @click.option(
398
- "--generator-llm-model",
399
- default="Qwen/Qwen2.5-0.5B-Instruct",
400
- help="Seq2Seq or CausalLM model (preferably multi-lingual)",
401
  )
402
  @click.option("--n-variants", default=3, help="no. of query variants")
403
  @click.option(
@@ -408,97 +374,57 @@ def initial_setup(embed_model, generator_model, bitsandbytesconfig=None):
408
  @click.option(
409
  "--top-k-retrieve", default=3, help="top `k` chunks to retrieve after RRF"
410
  )
411
- @click.option("--temperature", default=0.4, help="LLM Model Temperature")
412
- @click.option("--max-new-tokens", default=512, help="LLM max tokens")
413
- @click.option(
414
- "--faiss-index-kwargs",
415
- default=dict(),
416
- help="kwargs to pass to FAISS Index such as `n_cells, n_probe`",
417
- )
418
- def main(
419
- embed_model,
420
- generator_llm_model,
421
  n_variants,
422
  top_k_per_variant,
423
  top_k_retrieve,
424
  temperature,
425
- max_new_tokens,
426
- faiss_index_kwargs,
427
  ):
428
- # bits_and_bytes_cfg = BitsAndBytesConfig(
429
- # load_in_8bit=True
430
- # )
431
- start = time()
432
- hrf = initial_setup(embed_model, generator_llm_model)
433
- end = time()
434
- msg = f"init took {(end - start):.1f} seconds"
435
- log.debug(msg)
436
- st.write(msg)
437
-
438
- st.set_page_config(page_title="RAG HYDE")
439
- st.header("Ask Questions")
440
-
441
- state = st.session_state
442
- if "uploaded_names" not in state:
443
- state.uploaded_names = []
444
-
445
- pdfs = st.file_uploader(
446
- "Upload your PDF(s)", type="pdf", accept_multiple_files=True, key="upload"
447
  )
448
- if pdfs:
449
- current_names = sorted([pdf.name for pdf in pdfs])
450
- # reinitialize if uploaded files are changed
451
- if current_names != state.uploaded_names:
452
- start = time()
453
-
454
- hrf = initial_setup(embed_model, generator_llm_model)
455
- hrf.preprocess_pdfs(
456
- pdfs, faiss_index_kwargs=literal_eval(faiss_index_kwargs)
457
- )
458
-
459
- end = time()
460
- st.write(
461
- f"corpus embeddings shape: {hrf.corpus_emb.shape}, computed in {end - start:.1f} seconds"
462
- )
463
 
464
- state.uploaded_names = current_names
465
- else:
466
- state.uploaded_names = []
467
- st.write("upload data to query")
468
 
469
- query = st.text_input("ask question").strip()
470
- if query and state.uploaded_names:
471
- start = time()
472
- llm_kwargs = {
473
- "temperature": temperature,
474
- "max_new_tokens": max_new_tokens,
475
- }
476
- doc_ids, tasks = hrf.retrieve(
477
- query,
478
  int(n_variants),
479
  int(top_k_per_variant),
480
  int(top_k_retrieve),
481
- **llm_kwargs,
482
  )
483
- docs = [hrf.corpus[doc_id] for doc_id in doc_ids]
 
 
484
  end = time()
485
- reply = hrf.answer(query, doc_ids, tasks)
486
- st.write(f"search took {(end - start):.1f} seconds")
487
- st.write(f"\n\nFinal Answer: \n{reply}\n\n")
488
- st.write("Top 3 sources:")
489
- sources = [
490
- {
491
- "source": f"{hrf.meta[doc_id]['file']}:{hrf.meta[doc_id]['chunk_id']}",
492
- "content": doc,
493
- }
494
- for doc_id, doc in zip(doc_ids, docs)
495
- ]
496
- st.json(sources[:3])
497
 
498
 
499
  if __name__ == "__main__":
500
- # faiss_index_kwargs = {
501
- # 'n_cells': 20,
502
- # 'n_probe': 8
503
- # }
504
- main()
1
  import string
2
+ import faiss
3
  import yaml
4
  import re
5
+ import sys
 
6
  import torch
 
7
  import click
8
 
 
 
9
  from time import time
10
+ from random import shuffle
11
+ from functools import lru_cache
12
  from sentence_transformers import SentenceTransformer
13
+ from transformers import AutoTokenizer, AutoModelForCausalLM
14
+ from langchain_text_splitters import MarkdownTextSplitter
15
+
16
+ from src.config import PROMPTS_FILEPATH, MODEL_COMBOS, log
17
+ from src.utils import (
18
+ empty_cache,
19
+ find_think_tag_in_each_row,
20
+ build_corpus,
21
+ reciprocal_rank_fusion,
22
+ clean_rewrite_resp,
23
+ )
24
 
25
 
26
  def generate_text(
27
+ tokenizer, model, user_prompts, system_prompt=None, model_name="", **llm_kwargs
28
+ ):
29
+ assert model_name, "pass the generative model name"
30
  if not system_prompt:
31
  system_prompt = "You are a helpful assistant."
32
 
 
41
  for user_prompt in user_prompts
42
  ]
43
 
44
+ if "mlx" in model_name.lower() and sys.platform == "darwin":
45
+ from mlx_lm import generate
 
46
 
47
+ texts = [
48
+ tokenizer.apply_chat_template(
49
+ message,
50
+ tokenize=False,
51
+ add_generation_prompt=True,
52
+ enable_thinking=False,
53
+ )
54
+ for message in messages
55
+ ]
56
+ responses = [
57
+ generate(
58
+ model,
59
+ tokenizer,
60
+ prompt=text,
61
+ verbose=False,
62
+ max_tokens=llm_kwargs.pop("max_new_tokens", 32768),
63
+ )
64
+ for text in texts
65
+ ]
66
+ else:
67
+ texts = tokenizer.apply_chat_template(
68
+ messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
69
+ )
70
+ model_inputs = tokenizer(
71
+ texts, return_tensors="pt", truncation=True, padding=True
72
+ ).to(model.device)
73
+
74
+ with torch.no_grad():
75
+ generated_ids = model.generate(
76
+ **model_inputs,
77
+ max_new_tokens=llm_kwargs.pop("max_new_tokens", 32768),
78
+ temperature=llm_kwargs.pop("temperature", 0.7),
79
+ top_p=llm_kwargs.pop("top_p", 0.8),
80
+ top_k=llm_kwargs.pop("top_k", 20),
81
+ min_p=llm_kwargs.pop("min_p", 0),
82
+ **llm_kwargs,
83
+ )
84
+
85
+ output_ids = generated_ids[:, model_inputs.input_ids.shape[1] :]
86
+ idxs = find_think_tag_in_each_row(output_ids)
87
+ thinking_contents = [
88
+ tokenizer.decode(output_ids[i][:idx], skip_special_tokens=True).strip("\n")
89
+ for i, idx in enumerate(idxs)
90
+ ]
91
+ contents = [
92
+ tokenizer.decode(output_ids[i][idx:], skip_special_tokens=True).strip("\n")
93
+ for i, idx in enumerate(idxs)
94
+ ]
95
+ responses = [
96
+ f"{think_resp}{cont}"
97
+ for think_resp, cont in zip(thinking_contents, contents)
98
+ ]
99
 
100
+ return responses[0] if len(user_prompts) == 1 else responses
101
 
102
+
103
+ def load_models(embed_model_name: str, gen_model_name: str, device: str = None):
104
  # This will take some time to run for the first time if the model(s) don't exist locally.
105
  if not device:
106
+ if torch.cuda.is_available():
107
+ device = "cuda"
108
+ elif torch.mps.is_available():
109
+ device = "mps"
110
+ else:
111
+ device = "cpu"
112
+
113
+ dtype = torch.bfloat16 if device == "cuda" else torch.float16
114
+ if device != "mps" or (device == "mps" and "mlx" not in gen_model_name.lower()):
115
+ tok = AutoTokenizer.from_pretrained(gen_model_name, padding_side="left")
116
+ # sometimes loading an AWQ model on my local machine fails for the first time
117
+ try:
118
+ gen = AutoModelForCausalLM.from_pretrained(
119
+ gen_model_name, dtype=dtype, device_map=device
120
+ ).eval()
121
+ except ImportError:
122
+ gen = AutoModelForCausalLM.from_pretrained(
123
+ gen_model_name, dtype=dtype, device_map=device
124
+ ).eval()
125
+ else:
126
+ from mlx_lm import load
127
+
128
+ gen, tok = load(gen_model_name)
129
+
130
  embedder = SentenceTransformer(
131
  embed_model_name,
132
  device=device,
133
+ model_kwargs={"dtype": dtype},
134
  )
135
  return embedder, tok, gen
136
 
137
 
138
  def make_query_variants(
139
+ tokenizer, model, query: str, prompt: str, n: int = 3, model_name="", **llm_kwargs
140
  ):
141
+ # instructions = f"\n\n(Now give me at least {n} diverse variations of user query in the same language as the user provided query)"
142
+ # query += instructions
143
+ resp = generate_text(
144
+ tokenizer, model, query.format(n=n), prompt, model_name=model_name, **llm_kwargs
145
+ )
146
  clean_resp = re.sub(r"^\d+\.\s*", "", resp, flags=re.MULTILINE).split("\n")
147
+ queries = [q.strip() for q in clean_resp if q.strip()]
148
+ return [query.lower().strip()] + sorted(
149
+ set(map(lambda x: str.lower(x).strip(), queries))
150
+ )
151
 
152
 
153
  def transform_query(
154
+ tokenizer, model, query: str, rewrite_prompt: str, model_name="", **llm_kwargs
155
  ) -> dict:
156
  """split the query into things to search and actions to take"""
157
+ resp = generate_text(
158
+ tokenizer, model, query, rewrite_prompt, model_name=model_name, **llm_kwargs
159
+ )
160
  try:
161
  resp = clean_rewrite_resp(resp)
162
  except:
 
171
  rewrite_prompt,
172
  variants_prompt,
173
  n_variations=3,
174
+ gen_model_name="",
175
  **llm_kwargs,
176
  ):
177
  # make variations for the original query as is
 
181
  orig_query.strip(),
182
  variants_prompt,
183
  n_variations,
184
+ gen_model_name,
185
  **llm_kwargs,
186
+ )[: n_variations + 1]
187
+ tr_q = transform_query(
188
+ tokenizer, model, orig_query.strip(), rewrite_prompt, gen_model_name
189
  )
190
191
  # transformed query might have multiple things to search and tasks to perform depending on user query
192
  # recursively get variations for each of the search queries but keep the tasks as is.
193
  tasks = []
 
201
  search_result,
202
  variants_prompt,
203
  n_variations,
204
+ gen_model_name,
205
  **llm_kwargs,
206
  )
207
  )
208
 
 
 
 
209
  # keep the original user query as is (if in case LLM messes up the original query) and pick some after shuffling the rest
210
+ q, qq = queries[0], queries[1:]
211
+ shuffle(qq)
212
+ queries = [q] + sorted(
213
+ set(map(lambda x: str.lower(x).strip(string.punctuation), qq[:n_variations]))
214
+ )
215
+ tasks = sorted(set(map(lambda x: str.lower(x).strip(string.punctuation), tasks)))
216
 
217
  return queries, tasks
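For intuition, this is the output shape the function aims for; the values below are illustrative only, not real model output:

```python
# Illustrative only: expected return shape for a query that mixes a search
# with an action (actual wording depends on the LLM).
queries = [
    "summarise last quarter's ebitda as bullet points",  # original query, kept verbatim
    "how did ebitda change from the previous quarter",   # variants of the search part
    "what was the ebitda for the last quarter",
]
tasks = ["summarise as bullet points"]                   # actions, deduplicated
```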
 
 class HyDeRAGFusion:
     def __init__(
         self,
         embed_model: str,
         generator_llm_model: str,
+        embed_batch_size: int = 8,
     ):
         self.embed_batch_size = embed_batch_size
+        self.gen_model_name = generator_llm_model
         self.embedder, self.tok, self.gen = load_models(
+            embed_model, generator_llm_model
         )
+        self.text_splitter = MarkdownTextSplitter(chunk_overlap=450, chunk_size=3000)
         with open(PROMPTS_FILEPATH) as fl:
             self.prompts = yaml.safe_load(fl)
 
+    def get_embeddings(self, texts, **kwargs):
+        log.debug(f"encoding {len(texts)} texts in batches of {self.embed_batch_size}")
+        return self.embedder.encode(
+            texts, batch_size=self.embed_batch_size, normalize_embeddings=True, **kwargs
+        )
+
+    @lru_cache(maxsize=2)
+    def preprocess_pdfs(self, pdfs, **data_load_kwargs):
+        log.debug(f"\n\n{'@@@@' * 20}\n\n preprocessing {pdfs=}")
+        empty_cache()
+        self.dataset = build_corpus(pdfs, self.text_splitter, **data_load_kwargs)
+        empty_cache()
+        self.dataset = self.dataset.map(
+            lambda x: {
+                "embeddings": self.get_embeddings(
+                    x["chunk"], prompt_name="query", show_progress_bar=False
+                )
+            },
+            batched=True,
             batch_size=self.embed_batch_size,
+        )
+        empty_cache()
+        self.dataset.add_faiss_index(
+            "embeddings", metric_type=faiss.METRIC_INNER_PRODUCT
+        )
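Why `METRIC_INNER_PRODUCT` is the right FAISS metric here: `get_embeddings` L2-normalises its output, and for unit vectors the inner product equals cosine similarity. A small self-contained check:

```python
import numpy as np

a, b = np.array([1.0, 2.0, 3.0]), np.array([3.0, 1.0, 2.0])
cosine = a @ b / (np.linalg.norm(a) * np.linalg.norm(b))
a_n, b_n = a / np.linalg.norm(a), b / np.linalg.norm(b)
assert np.isclose(cosine, a_n @ b_n)  # inner product of unit vectors == cosine
```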
 
+    def get_filtered_entries(self, idxs):
+        # drop the FAISS index before selecting rows, then re-add it;
+        # with this little indexed data the rebuild cost is negligible
+        self.dataset.drop_index("embeddings")
+        entries = self.dataset.select(idxs)
+        self.dataset.add_faiss_index(
+            "embeddings", metric_type=faiss.METRIC_INNER_PRODUCT
+        )
+        return entries
 
     def retrieve(
+        self, query, n_variants=3, top_k_per_variant=5, top_k_retrieve=3, **llm_kwargs
     ):
         queries, tasks = aggregate_queries_and_tasks(
             self.tok,
             self.gen,
             query,
             self.prompts["rewrite"],
             self.prompts["variants"],
             n_variants,
+            self.gen_model_name,
             **llm_kwargs,
         )
         hyde_docs = generate_text(
+            self.tok,
+            self.gen,
+            queries,
+            self.prompts["hyde"],
+            self.gen_model_name,
+            **llm_kwargs,
         )
         chunks = []
         for hyde_doc in hyde_docs:
             chunks.extend(self.text_splitter.split_text(hyde_doc))
+        q_emb = self.get_embeddings(chunks)
+        matches = self.dataset.get_nearest_examples_batch(
+            "embeddings", q_emb, top_k_per_variant
+        )
+        indices = [match["id"] for match in matches.total_examples]
+        top_idxs = reciprocal_rank_fusion(indices, top_k_retrieve)
+        return top_idxs, tasks
 
+    def answer(self, query, idxs, tasks, max_ctx_chars=32768):
         total, text, prompt_length = 0, "", 10000
         sep = "\n\n-----\n\n"
+        tasks = ", ".join(tasks) if tasks else ""
+        log.debug("filtering entries")
+        entries = self.get_filtered_entries(idxs)
+        for chunk in entries["chunk"]:
+            ctx = f"{sep}\n\n{chunk}"
             if total + len(ctx) + len(tasks) + len(sep) + prompt_length > max_ctx_chars:
+                log.warning("context overflow")
                 break
 
             text += ctx
             total += len(ctx)
 
         text += f"{sep}{tasks}"
 
         instruction = "go ahead and answer!"
         user_query = f"\nq: {query}\n\nctx:{text}" + f"\n\n{instruction}\n\n"
         resp = generate_text(
             self.tok,
             self.gen,
             user_query,
             self.prompts["final_answer"],
+            self.gen_model_name,
         )
 
+        sources = ""
+        for idx, entry in enumerate(entries):
+            source = f'<h2 style="color: cyan;">Source {idx + 1} :: {entry["file"]}:{entry["chunk_id"]}</h2>'
+            sources += f"{sep}{source}\n\n{entry['chunk']}"
 
+        return resp, sources.replace("```", "`")
 
 
+def initial_setup(model_combo_key):
+    models = MODEL_COMBOS[model_combo_key]
+    hrf = HyDeRAGFusion(models["embed_model"], models["gen_model"])
+    # remember which combo is loaded so the CLI can reuse or rebuild the object
+    hrf._model_combo_key = model_combo_key
+    return hrf
 
 
+HRF = None
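A minimal end-to-end usage sketch of the class (the PDF path is a placeholder; combo keys come from `MODEL_COMBOS` in config.py). Note that `preprocess_pdfs` is wrapped in `lru_cache`, so `pdfs` must be a hashable tuple:

```python
# Hedged usage sketch; "data/report.pdf" is a placeholder path and the import
# path assumes this module lives at src/main.py.
from src.main import initial_setup

hrf = initial_setup("linux")
hrf.preprocess_pdfs(("data/report.pdf",), fast_extract=False)
idxs, tasks = hrf.retrieve("EBITDA last quarter", n_variants=3)
resp, sources = hrf.answer("EBITDA last quarter", idxs, tasks)
print(resp)
```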
 @click.command(context_settings=dict(show_default=True))
 @click.option(
+    "--pdfs",
+    multiple=True,
+    type=click.Path(exists=True),
+    help="list of PDF filepaths to extract text from",
+)
+@click.option("--query", help="user query")
+@click.option(
+    "--model-combo-key",
+    type=click.Choice(["linux", "HF-mid"]),
+    default="linux",
+    help="embedder and generator llm models combination to load (see config.py)",
 )
 @click.option(
+    "--fast-extract/--no-fast-extract",
+    default=False,
+    help="extract markdown text quickly (uses pymupdf if set, else docling if available)",
 )
 @click.option("--n-variants", default=3, help="no. of query variants")
 @click.option(
 ...
 @click.option(
     "--top-k-retrieve", default=3, help="top `k` chunks to retrieve after RRF"
 )
+@click.option("--temperature", default=0.7, help="LLM model temperature")
+def ask(
+    pdfs,
+    query,
+    model_combo_key,
+    fast_extract,
     n_variants,
     top_k_per_variant,
     top_k_retrieve,
     temperature,
 ):
+    pdfs = tuple(sorted(pdfs))
+    log.debug(
+        f"{pdfs=}, {query=}, {model_combo_key=}, {fast_extract=}, {n_variants=}, "
+        f"{top_k_per_variant=}, {top_k_retrieve=}, {temperature=}"
     )
+    global HRF
+    if HRF is None or HRF._model_combo_key != model_combo_key:
+        if HRF is not None:
+            log.debug("deleting HRF object and emptying cache")
+            del HRF
+            empty_cache()
+        log.debug(f"\n\n{'=:-:' * 20}\n\n initializing")
+        start = time()
+        HRF = initial_setup(model_combo_key)
+        end = time()
+        log.debug(f"init took {(end - start):.1f} seconds")
 
+    start = time()
+    if pdfs:
+        HRF.preprocess_pdfs(pdfs, fast_extract=fast_extract)
 
+    if query and query.strip():
+        top_idxs, tasks = HRF.retrieve(
+            query.strip(),
             int(n_variants),
             int(top_k_per_variant),
             int(top_k_retrieve),
+            temperature=temperature,
         )
+        log.debug("generating the final answer")
+        resp, sources = HRF.answer(query, top_idxs, tasks)
         end = time()
+        final_response = f"\nSearch took {(end - start):.1f} seconds\n\n{resp}{sources}"
+        click.echo(final_response)
+        return final_response
+
+    return ""
 
 
 if __name__ == "__main__":
+    ask()
 
 
 
 
src/prompts.yaml CHANGED
@@ -1,7 +1,7 @@
 rewrite: >
   You are a professional content writer and editor who deeply pays attention to user's query & intention. You strictly reply ONLY in JSON.
 
-  Your mission is to analyze user input and intention and transform it in the same language as the input query to make it search engine optimised by determining the appropriate context.
+  Your mission is to analyze user input and intention and transform it **IN THE SAME LANGUAGE** as the input query to make it search engine optimised by determining the appropriate context.
 
   The user input can be a query or a statement. There can be multiple of them. And sometimes the input also contains actions to be taken depending on the query/statement.
 
@@ -109,7 +109,7 @@ rewrite: >
   -----------
 
 
-  Remember to not answer the user's question but only transform it and in the same language as given.
+  Remember to **NOT** answer the user's question but only transform it and **IN THE SAME LANGUAGE** as given.
 
 
 user:
@@ -117,10 +117,10 @@ rewrite: >
 
 variants: >
   You are a multilingual professional content writer and editor who deeply pays attention to user's query & intention.
-
-  Your goal is to transform the given query into diverse search queries keeping the user's context & intention in mind.
-
-  You MUST respond in the same language as the user query which need not always be English.
+
+  Your goal is to transform the given query into at least {n} diverse queries as long as they're related to the topic of the original query.
+
+  You MUST ALWAYS respond **IN THE SAME LANGUAGE** as the user.
 
   You MUST respond with only what's asked. Avoid explanations or verbose information of your actions.
 
@@ -130,59 +130,59 @@ variants: >
   --------------
 
 
-  user: "EBITDA last quarter"
+  user: EBITDA last quarter
 
   assistant:
-  "What was the EBITDA for the quarter ending March?",
-  "How has the company's EBITDA performance changed from the previous quarter?",
-  "What is the current trend of EBITDA growth over the past few quarters?",
-  "Which companies have had similar EBITDA performance recently?",
-  "What factors might be influencing the changes in EBITDA?",
+  What was the EBITDA for the quarter ending March?
+  How has the company's performance changed from the previous quarter?
+  What is the current trend of EBITDA growth over the past few quarters?
+  Which companies have had similar EBITDA performance recently?
+  What factors might be influencing the changes in EBITDA?
 
 
-  user: "what are the growing concerns of the middle class?"
+  user: what are the growing concerns of the middle class?
 
   assistant:
-  "How are the economic challenges impacting the middle class?",
-  "What are the social and political pressures on the middle class?",
-  "What are the long-term implications for the middle class's well-being?",
-  "What are the current trends and future prospects for the middle class"
+  How are the economic challenges impacting the middle class?
+  How's middle class people coping up with social and political pressures?
+  What are the long-term implications for the middle class's well-being?
+  What are the current trends and future prospects for the middle class?
 
 
-  user: "Capital of France"
+  user: Capital of France
 
   assistant:
-  "What is the capital city of France?",
-  "How is Paris known internationally?",
-  "Where is Paris located on the map?",
+  What is the capital city of France?
+  How is Paris known internationally?
+  Where is the capital of France located?
 
 
-  user: "EBITDA letztes Quartal"
+  user: Wer ist der aktuelle CEO dieses Unternehmens?
 
   assistant:
-  "Wie hoch war das EBITDA für das im März endende Quartal?",
-  "Wie hat sich die EBITDA-Performance des Unternehmens gegenüber dem Vorquartal verändert?",
-  "Wie ist der aktuelle Trend des EBITDA-Wachstums in den letzten Quartalen?",
-  "Welche Unternehmen hatten in letzter Zeit eine ähnliche EBITDA-Entwicklung?",
-  "Welche Faktoren könnten die Veränderungen des EBITDA beeinflussen?",
+  Wer ist der Geschäftsführer dieses Unternehmens?
+  Wie sieht das Organigramm des Unternehmens aus?
+  Wann wurde der aktuelle CEO ernannt?
+  Wer war in der Vergangenheit CEO dieses Unternehmens?
 
 
-  user: "Was sind die wachsenden Sorgen der Mittelschicht?"
+  user: Wie reitet man ein Pferd?
 
   assistant:
-  "Wie wirken sich die wirtschaftlichen Herausforderungen auf die Mittelschicht aus?",
-  "Welchen sozialen und politischen Druck erlebt die Mittelschicht?",
-  "Was sind die langfristigen Auswirkungen auf das Wohlergehen der Mittelschicht?",
-  "Was sind die aktuellen Trends und Zukunftsaussichten für die Mittelschicht?"
+  Wo fängt man an, ein Pferd zu lernen?
+  Wo kann man Reiten lernen?
+  Grundlagen des Reitens
+  Wo kann ich Reiten lernen?
 
 
-  user: "Hauptstadt von Frankreich"
+  user: Effets de la gravité dans l'espace
 
   assistant:
-  "Was ist die Hauptstadt von Frankreich?",
-  "Wie ist Paris international bekannt?",
-  "Wo liegt Paris auf der Karte?",
-
+  Comment fonctionne la gravité ?
+  Qu'est-ce que la gravité ?
+  Quelle sensation procure l'apesanteur dans l'espace ?
+  Comment la gravité affecte-t-elle l'espace-temps ?
+  L'attraction gravitationnelle des objets dans l'espace
 
   --------------
 
@@ -193,14 +193,14 @@ variants: >
 hyde: >
   You are a professional editor at a prestigious international media organization.
   Given user's query, write a neutral, self-contained paragraph ABSOLUTELY GROUNDED IN FACTS and established sources. Avoid fluff. Include likely key terms and entities. 120-180 words.
-  You write content in the same language as the user query which need not always be in English.
+  You MUST ALWAYS write content in the same language as the user query.
 
   Examples:
   --------
 
-  user: Quelle est le niveau actuel de l'engagement de Deutsche Telekom avec le développement durable
+  user: Quelle est le niveau actuel de l'engagement de Lego avec le développement durable
 
-  assistant: Deutsche Telekom, la plus grande entreprise d'électricité et d'information au monde, a lancé un programme d'engagement durable en 2015 pour atteindre ses objectifs de développement durable. Ce programme visait à réduire son émission de gaz à effet de serre (CO2) par rapport aux niveaux de 1990, ainsi que les émissions de déchets et de produits chimiques. Le groupe a mis en place des politiques d'énergie basées sur la transition verte, comme la production de biogaz, l'utilisation de technologies solaires et l'innovation dans la gestion des ressources naturelles. L'objectif principal était de se démarquer du marché mondial en termes de performance énergétique et environnementale.
+  assistant: Lego, la plus grande entreprise d'électricité et d'information au monde, a lancé un programme d'engagement durable en 2015 pour atteindre ses objectifs de développement durable. Ce programme visait à réduire son émission de gaz à effet de serre (CO2) par rapport aux niveaux de 1990, ainsi que les émissions de déchets et de produits chimiques. Le groupe a mis en place des politiques d'énergie basées sur la transition verte, comme la production de biogaz, l'utilisation de technologies solaires et l'innovation dans la gestion des ressources naturelles. L'objectif principal était de se démarquer du marché mondial en termes de performance énergétique et environnementale.
 
 
   user: BMW Group expansion into southern Asia
@@ -223,8 +223,8 @@ final_answer: >
   You are a journalist at a media organization. Your main specializations include fact checking, accurate information retrieval from sources among others.
 
   YOU ALWAYS ADHERE TO THE FOLLOWING INSTRUCTIONS:
-  - When given a user query `q` and a context `ctx`, your goal is to answer `q` FROM ONLY WITHIN the given context `ctx` and add citations where applicable.
-  - You reply in the same language as user input which need not be always English.
+  - When given a user query `q` and a context `ctx`, your goal is to answer `q` FROM ONLY WITHIN the given context `ctx`.
+  - You MUST ALWAYS reply in the same language as user's.
   - You do not state anything that is not present within `ctx`. NEVER GUESS.
   - ALWAYS GROUND YOUR TRUTH based only on what was provided within the context `ctx`.
   - If you believe `q` has nothing to do with `ctx`, simply state "I don't know" (or its equivalent in the user query language) instead of guessing.
@@ -238,7 +238,7 @@ final_answer: >
   ctx:
   -----
 
-  this purpose. This will enable us to guarantee transparency and comparability in the validation and measurement of our targets and, at the same time, ensure they are in line with the latest scientific findings. ↗ Carbon emissions ↗ Control parameters such as ↗ carbon emissions over the entire prod - uct life cycle are important ↗ Performance indicators during the de - velopment phase of our vehicle projects. The Board of Manage - ment receives and discusses a status report on sustainability every quarter and derives appropriate measures as required. The BMW Group is actively working on numerous projects and initiatives to improve the framework conditions for electromobil- ity, including the expansion of charging infrastructure on a broad basis. The ambitious goals of the Paris Climate Agreement are designed to tackle climate change in the transport sector, requir - ing a combination of modern drive technologies that are closely aligned with customer needs and different mobility requirements around the world. In addition to all - electric models, plug - in hybrids and modern combustion engine technologies also make an im - portant contribution to the reduction of global CO2 emissions. The BMW Group is also continuously forging ahead with its work with hydrogen. ↗ Products ESG criteria are built into individual market strategies across our global organisation. Best practices in the fields of environmental protection, social sustainability, corporate citizenship and gov
+  This will enable us to guarantee transparency and comparability in the validation and measurement of our targets and, at the same time, ensure they are in line with the latest scientific findings. ↗ Carbon emissions ↗ Control parameters such as ↗ carbon emissions over the entire prod - uct life cycle are important ↗ Performance indicators during the de - velopment phase of our vehicle projects. The Board of Manage - ment receives and discusses a status report on sustainability every quarter and derives appropriate measures as required. The BMW Group is actively working on numerous projects and initiatives to improve the framework conditions for electromobil- ity, including the expansion of charging infrastructure on a broad basis. The ambitious goals of the Paris Climate Agreement are designed to tackle climate change in the transport sector, requir - ing a combination of modern drive technologies that are closely aligned with customer needs and different mobility requirements around the world. In addition to all - electric models, plug - in hybrids and modern combustion engine technologies also make an im - portant contribution to the reduction of global CO2 emissions. The BMW Group is also continuously forging ahead with its work with hydrogen. ↗ Products ESG criteria are built into individual market strategies across our global organisation. Best practices in the fields of environmental protection, social sustainability, corporate citizenship and gov
 
   go ahead and answer!
 
@@ -267,8 +267,7 @@ final_answer: >
   Zur Erhöhung der Transparenz und Effektivität kooperiert sie verstärkt mit öffentlichen und privaten Organisationen.
   Sicherheit ist durch das „Security by Design“-Prinzip fester Bestandteil im Entwicklungsprozess neuer Produkte und Informationssysteme.
   Es werden intensive und obligatorische digitale Sicherheitstests durchgeführt, um Schwachstellen systematisch aufzudecken.
-  Alle sicherheitsrelevanten Abteilungen wurden unter dem Dach der Deutschen Telekom Security zusammengeführt. Mit diesem End-to-End-Sicherheitsportfolio zielt das Unternehmen darauf ab, Marktanteile zu gewinnen und im Rahmen der Megatrends Internet der Dinge und Industrie 4.0 neue Sicherheitskonzepte zu etablieren.
-  Zudem wird das Partner-Ökosystem im Bereich Cybersicherheit kontinuierlich ausgebaut, und auf der Unternehmenswebsite wird fortlaufend über aktuelle Entwicklungen in Datenschutz und Datensicherheit berichtet.
+  Alle sicherheitsrelevanten Abteilungen wurden unter dem Dach der Deutschen Telekom Security zusammengeführt. Mit diesem End-to-End-Sicherheitsportfolio zielt das Unternehmen darauf ab, Marktanteile zu gewinnen und im Rahmen der Megatrends Internet der Dinge und Industrie 4.0 neue Sicherheitskonzepte zu etablieren. Zudem wird das Partner-Ökosystem im Bereich Cybersicherheit kontinuierlich ausgebaut, und auf der Unternehmenswebsite wird fortlaufend über aktuelle Entwicklungen in Datenschutz und Datensicherheit berichtet.
 
 
   =====
src/utils.py ADDED
@@ -0,0 +1,148 @@
+import gc
+import json
+
+import torch
+import pymupdf
+import pymupdf4llm
+
+from ast import literal_eval
+from pathlib import Path
+from collections import defaultdict
+
+from datasets import Dataset
+from docling.document_converter import DocumentConverter
+
+from src.config import log
+
+
+def empty_cache():
+    # free Python- and accelerator-side memory between heavy steps
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    elif torch.mps.is_available():
+        torch.mps.empty_cache()
+
+
+def extract_text(file, markdown=False, backend="pymupdf", **kwargs):
+    if backend == "pymupdf":
+        if not markdown:
+            with pymupdf.open(file, filetype="pdf") as doc:
+                return "\n".join(page.get_text(**kwargs) for page in doc)
+        else:
+            log.debug("\n\n using pymupdf4llm \n\n")
+            return pymupdf4llm.to_markdown(file, show_progress=True, **kwargs)
+
+    elif backend == "docling":
+        converter = DocumentConverter(allowed_formats=["pdf"])
+        doc = converter.convert(file, **kwargs).document
+        res = doc.export_to_markdown() if markdown else doc.export_to_text()
+        del converter, doc
+        empty_cache()
+        return res
+
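Example usage of the two backends (the path is a placeholder); `docling` is slower but more faithful, matching the trade-off described in the README:

```python
# Placeholder path; docling is the accurate default, pymupdf the fast fallback.
md = extract_text("data/report.pdf", markdown=True, backend="docling")
md_fast = extract_text("data/report.pdf", markdown=True, backend="pymupdf")
```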
+def _load_pdf_sync(file, markdown=True, fast=False, **kwargs):
+    """Synchronous PDF loading function for thread pool execution"""
+    text = extract_text(
+        file,
+        markdown,
+        backend="docling"
+        if ((not fast) and (torch.cuda.is_available() or torch.mps.is_available()))
+        else "pymupdf",
+        **kwargs,
+    )
+    return (Path(file).stem, text)
+
+
+def load_pdfs(files, markdown=True, fast_extract=False, **kwargs):
+    """
+    Load multiple PDF files
+
+    Args:
+        files: PDF filepaths
+        markdown: whether to extract text in markdown
+        fast_extract: whether to use pymupdf to extract text in markdown
+    Returns:
+        list: List of tuples containing (filename, extracted_text)
+    """
+    # NOTE: a ThreadPoolExecutor/asyncio variant was tried here, but extraction is
+    # compute-bound per file (especially with docling), so PDFs are loaded sequentially.
+    results = []
+    for file in files:
+        results.append(_load_pdf_sync(file, markdown, fast_extract, **kwargs))
+    return results
+
+
+def find_think_tag_in_each_row(tensor):
+    # 151668 is the id of the `</think>` tag in Qwen3 tokenizers; map each row
+    # to the position just after its last `</think>` (0 when the tag is absent)
+    res = dict((tensor == 151668).nonzero().tolist())
+    if not res:
+        return [0] * len(tensor)
+    idxs = []
+    for idx in range(len(tensor)):
+        idxs.append(res.get(idx, -1))
+    return [x + 1 for x in idxs]
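A toy check of the splitting logic: row 0 contains the `</think>` id at position 1, so decoding of the answer starts at index 2; row 1 has no tag, so its whole output counts as the answer:

```python
import torch

from src.utils import find_think_tag_in_each_row

toy = torch.tensor([[5, 151668, 7], [9, 9, 9]])
print(find_think_tag_in_each_row(toy))  # [2, 0]
```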
+
+
+def build_corpus(pdfs, text_splitter, **load_pdf_kwargs):
+    texts = load_pdfs(pdfs, **load_pdf_kwargs)
+    corpus_with_meta = []
+    _id = 0
+    for file_name, raw_text in texts:
+        chunks = text_splitter.split_text(raw_text)
+        for idx, chunk in enumerate(chunks):
+            corpus_with_meta.append(
+                {
+                    "id": _id,
+                    "file": Path(file_name).stem,
+                    "chunk_id": idx,
+                    "chunk": chunk,
+                }
+            )
+            _id += 1
+    return Dataset.from_list(corpus_with_meta)
+
+
+def reciprocal_rank_fusion(indices, top_k=3, denom=50):
+    # RRF: each ranked list contributes 1 / (rank + denom) per id, so ids that
+    # place highly across many lists accumulate the largest scores
+    scores = defaultdict(float)
+    for row in indices:
+        for rank, idx in enumerate(row):
+            scores[idx] += 1 / (rank + denom)
+    results = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:top_k]
+    return [idx for idx, _ in results]
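A tiny worked example: chunk id 1 places high in both ranked lists, so fusion puts it first even though id 3 tops one of them:

```python
from src.utils import reciprocal_rank_fusion

ranked = [[3, 1, 2], [1, 9, 3]]  # chunk ids, best-first, one list per query variant
print(reciprocal_rank_fusion(ranked, top_k=2))
# [1, 3] -> id 1 scores 1/51 + 1/50; id 3 scores 1/50 + 1/52
```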
+
+
+def clean_rewrite_resp(resp):
+    try:
+        resp = json.loads(resp)  # parse JSON
+    except json.JSONDecodeError:
+        try:
+            resp = literal_eval(resp)  # fallback parse
+        except Exception:
+            pass  # keep resp as-is if both fail
+
+    # ensure resp is a string before strip and slicing
+    if isinstance(resp, str):
+        resp = resp.strip()
+        if resp:
+            start = resp.find("{")
+            if start != -1:
+                end = resp[::-1].find("}")
+                if end != -1:
+                    trimmed = resp[start : len(resp) - end]
+                    # only recurse when trimming changed something, otherwise an
+                    # unparseable "{...}" string would recurse forever
+                    if trimmed != resp:
+                        return clean_rewrite_resp(trimmed)
+    return resp
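An example of what this cleanup recovers: the rewrite model sometimes wraps its JSON in chatter, and the function trims to the outermost braces before re-parsing:

```python
from src.utils import clean_rewrite_resp

raw = 'Sure! {"search": ["ebitda last quarter"], "tasks": ["summarise"]} Hope that helps.'
print(clean_rewrite_resp(raw))
# {'search': ['ebitda last quarter'], 'tasks': ['summarise']}
```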