Spaces:
Sleeping
Sleeping
Vivek Vaddina
commited on
♻️ Refactor to improve model performance
Browse files- README.md +68 -0
- app.py +117 -446
- requirements.txt +4 -0
- src/config.py +33 -1
- src/hyde_rag.py +0 -206
- src/main.py +249 -323
- src/prompts.yaml +43 -44
- src/utils.py +148 -0
README.md
CHANGED
|
@@ -11,3 +11,71 @@ short_description: answer based on input documents
|
|
| 11 |
---
|
| 12 |
|
| 13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
---
|
| 12 |
|
| 13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
| 14 |
+
|
| 15 |
+
# RAG HYDE
|
| 16 |
+
|
| 17 |
+
This challenge is to build a minimal but powerful Retrieval-Augmented Generation (RAG) workflow inspired by the techniques in the articles already shared.
|
| 18 |
+
|
| 19 |
+
Your solution should:
|
| 20 |
+
|
| 21 |
+
* Ingest and chunk local PDFs efficiently.
|
| 22 |
+
* Use a small, fast embedding model for retrieval.
|
| 23 |
+
* Apply HyDE to improve query relevance.
|
| 24 |
+
* Fuse multiple query variants with Reciprocal Rank Fusion for higher accuracy.
|
| 25 |
+
* Generate concise, context-grounded answers with clear source citations.
|
| 26 |
+
|
| 27 |
+
Goal: Deliver a working RAG example that is fast, lightweight, and high-quality in both retrieval and final answers.
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
## Description
|
| 31 |
+
|
| 32 |
+
### Approach
|
| 33 |
+
- Get the folder containing data
|
| 34 |
+
- Build Corpus & Index
|
| 35 |
+
- Get the user query
|
| 36 |
+
- transform it to separate search & intent (Query Rewrite)
|
| 37 |
+
- generate hypothetical documents from a (short) user query
|
| 38 |
+
- get their corresponding embeddings
|
| 39 |
+
- for each of those embeddings, get relevant results from the corpus
|
| 40 |
+
- fuse them all together using Reciprocal Rank Fusion
|
| 41 |
+
- extract top relevant results
|
| 42 |
+
- pre-process those to make a context and send it to the LLM one last time with the user query
|
| 43 |
+
- receive the final answer based on the user's query & provided context.
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
## Run
|
| 47 |
+
|
| 48 |
+
`MODEL_COMBOS` in [config.py](config.py) provides multiple variants of embedding & generative LLM model combinations keeping in mind the host system's limitations & capabilities.
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
## Tips/Observations:
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
- Extracting text as a markdown greatly preserved the structure and continuity of the text. This resulted in better logical chunking which in turn led to better embeddings and as a consequence, better search results.
|
| 55 |
+
|
| 56 |
+
- Reading the document via `docling` extracted more and correct text compared to `pymupdf4llm` but at a bit of an expense of speed. It is enabled by default for prioritising accuracy.
|
| 57 |
+
- This proved esp. useful in extracting data containing lots of tables spread over multiple pages.
|
| 58 |
+
- You can pass `--fast-extract` from CLI or tick a box via gradio UI to use pymupdf instead.
|
| 59 |
+
|
| 60 |
+
- Increasing the model size (coupled with correct text extraction in markdown) greatly improved performance. The Qwen3 models very much adhered to instructions but the smaller variants instead of hallucinating simply fell back to saying _'I don't know'_ (as per instructions). The `4B` variant understood the user intent which sometimes was vague and yet managed to give relevant results. The base variant is huge and it wouldn't have been fit and run fast enough on a consumer grade laptop GPU. Loading the `AWQ` variant of it helped as it occupied substantially less memory compared to the original without much loss in performance.
|
| 61 |
+
|
| 62 |
+
- This model also showed great multilingual capabilities. User can upload document in one language and ask questions in another. Or they could upload multilingual documents and ask multilingual queries. For the demo, I tested mostly in English & German.
|
| 63 |
+
|
| 64 |
+
- The data is now stored in datasets format that allows for better storage & scaling (arrow) along with indexing (FAISS) for querying.
|
| 65 |
+
|
| 66 |
+
---
|
| 67 |
+
|
| 68 |
+
## Limitations / Known Issues
|
| 69 |
+
|
| 70 |
+
- Even though `docling` with mostly default options proved to be better than `pymupdf4llm` to extract text, it's not perfect everytime. There're instances where _pymupdf_ extracted text from an embedded image inside a PDF better than docling. However, docling is highly configurable and allows for deep customization via 'pipelines'. And it also comes with a very permissive license for commercial use compared to PyMuPDF.
|
| 71 |
+
- docling comes with `easyocr` by default for text OCR. It's not powerful enough compared to _tesseract_ or similar models. But since installing the latter and linking it with docling involves touching system config, it's not pursued.
|
| 72 |
+
|
| 73 |
+
- When user uploads multiple PDFs, we can improve load times by reading them asynchronously. Attempts to do that with `docling` sometimes resulted in pages with ordering different than the original. So it's dropped for the demo. More investigation is needed later.
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
## Next Steps
|
| 77 |
+
|
| 78 |
+
- Checkout [EmbeddingGemma](https://huggingface.co/blog/embeddinggemma) for embeddings
|
| 79 |
+
- Checkout [fastembed](https://github.com/qdrant/fastembed) to generate embeddings faster
|
| 80 |
+
- Improve text extraction via docling pipeline
|
| 81 |
+
- Checkout `GGUF` models for CPU Inferencing
|
app.py
CHANGED
|
@@ -1,474 +1,145 @@
|
|
| 1 |
-
import
|
| 2 |
-
import math
|
| 3 |
-
import yaml
|
| 4 |
-
import json
|
| 5 |
-
import torch
|
| 6 |
-
import faiss
|
| 7 |
import string
|
| 8 |
-
import asyncio
|
| 9 |
-
import pymupdf
|
| 10 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
-
from time import time
|
| 13 |
-
from pathlib import Path
|
| 14 |
-
from functools import lru_cache
|
| 15 |
-
from ast import literal_eval
|
| 16 |
-
from collections import defaultdict
|
| 17 |
-
from concurrent.futures import ThreadPoolExecutor
|
| 18 |
-
from sentence_transformers import SentenceTransformer
|
| 19 |
-
from langchain_text_splitters import SentenceTransformersTokenTextSplitter
|
| 20 |
-
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
|
| 21 |
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
max_concurrence (int): Maximum number of concurrent PDF processing tasks
|
| 32 |
-
|
| 33 |
-
Returns:
|
| 34 |
-
list: List of tuples containing (filename, extracted_text)
|
| 35 |
-
"""
|
| 36 |
-
|
| 37 |
-
def _load_pdf_sync(file):
|
| 38 |
-
"""Synchronous PDF loading function for thread pool execution"""
|
| 39 |
-
text = ""
|
| 40 |
-
try:
|
| 41 |
-
with pymupdf.open(file, filetype="pdf") as doc:
|
| 42 |
-
text = "\n".join(page.get_text() for page in doc)
|
| 43 |
-
except Exception:
|
| 44 |
-
log.exception(f"Error reading {file.name}")
|
| 45 |
-
pass
|
| 46 |
-
|
| 47 |
-
return (file.name, text)
|
| 48 |
-
|
| 49 |
-
loop = asyncio.get_event_loop()
|
| 50 |
-
with ThreadPoolExecutor(max_workers=max_concurrence) as executor:
|
| 51 |
-
futures = [
|
| 52 |
-
loop.run_in_executor(executor, _load_pdf_sync, file)
|
| 53 |
-
for file in files
|
| 54 |
-
if file is not None
|
| 55 |
-
]
|
| 56 |
-
|
| 57 |
-
results = await asyncio.gather(*futures, return_exceptions=True)
|
| 58 |
-
|
| 59 |
-
valid_results = [result for result in results if not isinstance(result, Exception)]
|
| 60 |
-
|
| 61 |
-
log.info(f"successfully processed {len(valid_results)} out of {len(files)} PDFs ")
|
| 62 |
-
return valid_results
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
async def build_corpus(pdfs, text_splitter, **load_kwargs):
|
| 66 |
-
texts = await load_pdfs(pdfs, **load_kwargs)
|
| 67 |
-
corpus, meta = [], []
|
| 68 |
-
for file_name, raw_text in texts:
|
| 69 |
-
chunks = text_splitter.split_text(raw_text)
|
| 70 |
-
for i, chunk in enumerate(chunks):
|
| 71 |
-
corpus.append(chunk)
|
| 72 |
-
meta.append({"file": file_name, "chunk_id": i})
|
| 73 |
-
return corpus, meta
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
def generate_text(
|
| 77 |
-
tokenizer, model, user_prompts, system_prompt=None, **llm_kwargs
|
| 78 |
-
): # max_new_tokens=512, temperature=.4):
|
| 79 |
-
if system_prompt is None or "":
|
| 80 |
-
system_prompt = "You are a helpful assistant."
|
| 81 |
-
|
| 82 |
-
if isinstance(user_prompts, str):
|
| 83 |
-
user_prompts = [user_prompts]
|
| 84 |
-
|
| 85 |
-
messages = [
|
| 86 |
-
[
|
| 87 |
-
{"role": "system", "content": system_prompt},
|
| 88 |
-
{"role": "user", "content": user_prompt},
|
| 89 |
-
]
|
| 90 |
-
for user_prompt in user_prompts
|
| 91 |
-
]
|
| 92 |
-
|
| 93 |
-
texts = tokenizer.apply_chat_template(
|
| 94 |
-
messages, tokenize=False, add_generation_prompt=True
|
| 95 |
-
)
|
| 96 |
-
|
| 97 |
-
model_inputs = tokenizer(
|
| 98 |
-
texts, return_tensors="pt", truncation=True, padding=True
|
| 99 |
-
).to(model.device)
|
| 100 |
-
generated_ids = model.generate(
|
| 101 |
-
**model_inputs,
|
| 102 |
-
max_new_tokens=llm_kwargs.pop("max_new_tokens", 512),
|
| 103 |
-
temperature=llm_kwargs.pop("temperature", 0.4),
|
| 104 |
-
)
|
| 105 |
-
generated_ids = [
|
| 106 |
-
output_ids[len(input_ids) :]
|
| 107 |
-
for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
|
| 108 |
-
]
|
| 109 |
-
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
| 110 |
-
return response if len(user_prompts) > 1 else response[0]
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
def load_models(
|
| 114 |
-
embed_model_name: str,
|
| 115 |
-
gen_model_name: str,
|
| 116 |
-
causal_lm: bool = False,
|
| 117 |
-
device=None,
|
| 118 |
-
bitsandbytesconfig=None,
|
| 119 |
-
):
|
| 120 |
-
# This will take some time to run for the first time if the model(s) don't exist locally.
|
| 121 |
-
if not device:
|
| 122 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 123 |
-
embedder = SentenceTransformer(
|
| 124 |
-
embed_model_name,
|
| 125 |
-
device=device,
|
| 126 |
-
model_kwargs={"dtype": "float16"} if device == "cuda" else {},
|
| 127 |
-
)
|
| 128 |
-
|
| 129 |
-
if not causal_lm:
|
| 130 |
-
tok = AutoTokenizer.from_pretrained(gen_model_name)
|
| 131 |
-
gen = AutoModelForSeq2SeqLM.from_pretrained(
|
| 132 |
-
gen_model_name, # device_map='auto',
|
| 133 |
-
quantization_config=bitsandbytesconfig if bitsandbytesconfig else None,
|
| 134 |
-
)
|
| 135 |
-
else:
|
| 136 |
-
tok = AutoTokenizer.from_pretrained(gen_model_name, padding_side="left")
|
| 137 |
-
gen = AutoModelForCausalLM.from_pretrained(
|
| 138 |
-
gen_model_name,
|
| 139 |
-
dtype="float16", # device_map='auto',
|
| 140 |
-
quantization_config=bitsandbytesconfig if bitsandbytesconfig else None,
|
| 141 |
-
)
|
| 142 |
-
gen.to(device)
|
| 143 |
-
return embedder, tok, gen
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
def make_query_variants(
|
| 147 |
-
tokenizer, model, query: str, prompt: str, n: int = 3, **llm_kwargs
|
| 148 |
-
):
|
| 149 |
-
instructions = f"Now give me at least {n} variations."
|
| 150 |
-
resp = generate_text(tokenizer, model, query + instructions, prompt, **llm_kwargs)
|
| 151 |
-
|
| 152 |
-
clean_resp = re.sub(r"^\d+\.\s*", "", resp, flags=re.MULTILINE).split("\n")
|
| 153 |
-
return [query] + [q for q in clean_resp if q.strip()]
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
def clean_rewrite_resp(resp):
|
| 157 |
-
try:
|
| 158 |
-
resp = json.loads(resp) # Parse JSON
|
| 159 |
-
except json.JSONDecodeError:
|
| 160 |
-
try:
|
| 161 |
-
resp = literal_eval(resp) # Fallback parse
|
| 162 |
-
except Exception:
|
| 163 |
-
pass # Keep resp as-is if both fail
|
| 164 |
-
|
| 165 |
-
# Ensure resp is a string before strip and slicing
|
| 166 |
-
if isinstance(resp, str):
|
| 167 |
-
resp = resp.strip()
|
| 168 |
-
if resp:
|
| 169 |
-
start = resp.find("{")
|
| 170 |
-
if start != -1:
|
| 171 |
-
end = resp[::-1].find("}")
|
| 172 |
-
if end != -1:
|
| 173 |
-
resp = resp[start : len(resp) - end]
|
| 174 |
-
return clean_rewrite_resp(resp)
|
| 175 |
-
return resp
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
def transform_query(
|
| 179 |
-
tokenizer, model, query: str, rewrite_prompt: str, **llm_kwargs
|
| 180 |
-
) -> dict:
|
| 181 |
-
"""split the query into things to search and actions to take"""
|
| 182 |
-
resp = generate_text(tokenizer, model, query, rewrite_prompt, **llm_kwargs)
|
| 183 |
-
try:
|
| 184 |
-
resp = clean_rewrite_resp(resp)
|
| 185 |
-
except:
|
| 186 |
-
pass
|
| 187 |
-
return resp
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
def aggregate_queries_and_tasks(
|
| 191 |
-
tokenizer,
|
| 192 |
-
model,
|
| 193 |
-
orig_query,
|
| 194 |
-
rewrite_prompt,
|
| 195 |
-
variants_prompt,
|
| 196 |
-
n_variations=3,
|
| 197 |
-
**llm_kwargs,
|
| 198 |
):
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
|
|
|
| 207 |
)
|
| 208 |
|
| 209 |
-
|
| 210 |
-
tr_q = transform_query(tokenizer, model, orig_query.strip(), rewrite_prompt)
|
| 211 |
-
end = time()
|
| 212 |
-
log.debug(f"\t\t transforming query task took {(end - start):.1f} seconds...")
|
| 213 |
-
|
| 214 |
-
# transformed query might have multiple things to search and tasks to perform depending on user query
|
| 215 |
-
# recursively get variations for each of the search queries but keep the tasks as is.
|
| 216 |
-
tasks = []
|
| 217 |
-
if isinstance(tr_q, dict):
|
| 218 |
-
search_results, tasks = tr_q.get("search", []), tr_q.get("tasks", [])
|
| 219 |
-
for search_result in search_results:
|
| 220 |
-
queries.extend(
|
| 221 |
-
make_query_variants(
|
| 222 |
-
tokenizer,
|
| 223 |
-
model,
|
| 224 |
-
search_result,
|
| 225 |
-
variants_prompt,
|
| 226 |
-
n_variations,
|
| 227 |
-
**llm_kwargs,
|
| 228 |
-
)
|
| 229 |
-
)
|
| 230 |
-
|
| 231 |
-
queries = [q.strip(string.punctuation) for q in queries]
|
| 232 |
-
tasks = [t.strip(string.punctuation) for t in tasks]
|
| 233 |
-
|
| 234 |
-
# keep the original user query as is (if in case LLM messes up the original query) and pick some after shuffling the rest
|
| 235 |
-
# This is disabled as we don't do loops and instead take advantage of batches.
|
| 236 |
-
# Since it's efficient, we can take many query variations at once without worrying about performance.
|
| 237 |
-
# q, queries = queries[:1], queries[1:]
|
| 238 |
-
# shuffle(queries)
|
| 239 |
-
# q += queries[:n_variations-1]
|
| 240 |
-
|
| 241 |
-
return queries, tasks
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
def build_index(corpus_emb, n_cells=5, n_probe=2):
|
| 245 |
-
log.debug(f"building index with {n_cells=}, {n_probe=}")
|
| 246 |
-
d = corpus_emb.shape[1]
|
| 247 |
-
quantizer = faiss.IndexFlatIP(d)
|
| 248 |
-
index = faiss.IndexIVFFlat(quantizer, d, n_cells)
|
| 249 |
-
index.n_probe = n_probe
|
| 250 |
-
index.train(corpus_emb)
|
| 251 |
-
index.add(corpus_emb)
|
| 252 |
-
# index.make_direct_map()
|
| 253 |
-
return index
|
| 254 |
|
| 255 |
|
| 256 |
-
def
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
return [chunk_id for chunk_id, _ in results]
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
class HyDeRAGFusion:
|
| 267 |
-
def __init__(
|
| 268 |
-
self,
|
| 269 |
-
embed_model: str,
|
| 270 |
-
generator_llm_model: str,
|
| 271 |
-
causal_lm: bool = True,
|
| 272 |
-
chunk_overlap: int = 50,
|
| 273 |
-
tokens_per_chunk: int = 256,
|
| 274 |
-
embed_batch_size: int = 64,
|
| 275 |
-
bitsandbytesconfig=None,
|
| 276 |
-
):
|
| 277 |
-
self.embed_batch_size = embed_batch_size
|
| 278 |
-
self.text_splitter = SentenceTransformersTokenTextSplitter(
|
| 279 |
-
chunk_overlap, embed_model, tokens_per_chunk
|
| 280 |
-
)
|
| 281 |
-
self.embedder, self.tok, self.gen = load_models(
|
| 282 |
-
embed_model, generator_llm_model, causal_lm, bitsandbytesconfig
|
| 283 |
-
)
|
| 284 |
-
with open(PROMPTS_FILEPATH) as fl:
|
| 285 |
-
self.prompts = yaml.safe_load(fl)
|
| 286 |
-
|
| 287 |
-
@lru_cache(maxsize=8)
|
| 288 |
-
def preprocess_pdfs(self, pdfs, data_load_kwargs={}, faiss_index_kwargs={}):
|
| 289 |
-
self.corpus, self.meta = asyncio.run(
|
| 290 |
-
build_corpus(pdfs, self.text_splitter, **data_load_kwargs)
|
| 291 |
-
)
|
| 292 |
-
log.debug(f"{len(self.corpus)}, {len(self.meta)}")
|
| 293 |
-
self.corpus_emb = self.embedder.encode(
|
| 294 |
-
self.corpus,
|
| 295 |
-
batch_size=self.embed_batch_size,
|
| 296 |
-
show_progress_bar=True,
|
| 297 |
-
normalize_embeddings=True,
|
| 298 |
-
)
|
| 299 |
-
log.debug(f"{self.corpus_emb.shape}")
|
| 300 |
-
|
| 301 |
-
# https://github.com/facebookresearch/faiss/issues/112
|
| 302 |
-
# n_cells = int(round(4 * (self.corpus_emb.shape[0])**.5))
|
| 303 |
-
|
| 304 |
-
# one centroid for every 100 or so vectors and 20% of them as n_probe
|
| 305 |
-
n_cells = faiss_index_kwargs.pop("n_cells", self.corpus_emb.shape[0] // 100 + 1)
|
| 306 |
-
n_probe = faiss_index_kwargs.pop("n_probe", math.ceil(0.2 * n_cells))
|
| 307 |
-
|
| 308 |
-
self.index = build_index(self.corpus_emb, n_cells, n_probe)
|
| 309 |
-
|
| 310 |
-
def retrieve(
|
| 311 |
-
self, query, n_variants=3, top_k_per_variant=10, top_k_retrieve=3, **llm_kwargs
|
| 312 |
-
):
|
| 313 |
-
start = time()
|
| 314 |
-
|
| 315 |
-
queries, tasks = aggregate_queries_and_tasks(
|
| 316 |
-
self.tok,
|
| 317 |
-
self.gen,
|
| 318 |
-
query.strip(),
|
| 319 |
-
self.prompts["rewrite"],
|
| 320 |
-
self.prompts["variants"],
|
| 321 |
-
n_variants,
|
| 322 |
-
**llm_kwargs,
|
| 323 |
-
)
|
| 324 |
-
|
| 325 |
-
end = time()
|
| 326 |
-
log.debug(f"aggregate task took {(end - start):.1f} seconds...")
|
| 327 |
-
|
| 328 |
-
start = time()
|
| 329 |
-
hyde_docs = generate_text(
|
| 330 |
-
self.tok, self.gen, queries, self.prompts["hyde"], **llm_kwargs
|
| 331 |
-
)
|
| 332 |
-
end = time()
|
| 333 |
-
log.debug(f"generating hyde docs took {(end - start):.1f} seconds...")
|
| 334 |
-
|
| 335 |
-
start = time()
|
| 336 |
-
chunks = []
|
| 337 |
-
for hyde_doc in hyde_docs:
|
| 338 |
-
chunks.extend(self.text_splitter.split_text(hyde_doc))
|
| 339 |
-
q_emb = self.embedder.encode(
|
| 340 |
-
chunks, batch_size=self.embed_batch_size, normalize_embeddings=True
|
| 341 |
-
)
|
| 342 |
-
end = time()
|
| 343 |
-
log.debug(f"embedding hyde docs took {(end - start):.1f} seconds...")
|
| 344 |
-
|
| 345 |
-
_, I = self.index.search(q_emb, top_k_per_variant)
|
| 346 |
-
chunk_ids = reciprocal_rank_fusion(I, top_k_retrieve)
|
| 347 |
-
return chunk_ids, tasks
|
| 348 |
-
|
| 349 |
-
def answer(self, query, doc_ids, tasks, max_ctx_chars=128000):
|
| 350 |
-
total, text, prompt_length = 0, "", 10000
|
| 351 |
-
sep = "\n\n-----\n\n"
|
| 352 |
-
tasks = ", ".join(tasks)
|
| 353 |
-
|
| 354 |
-
for doc_id in doc_ids:
|
| 355 |
-
# adding tags in the context caused more hallucinations.
|
| 356 |
-
# Instead, we list them as sources beneath the model response.
|
| 357 |
-
# _meta = self.meta[doc_id]
|
| 358 |
-
# tag = f"(source: {_meta['file_name']}:{_meta['chunk_id']})"
|
| 359 |
-
chunk = self.corpus[doc_id].strip()
|
| 360 |
-
tag = ""
|
| 361 |
-
|
| 362 |
-
ctx = f"{sep}{tag}\n\n{chunk}"
|
| 363 |
-
if total + len(ctx) + len(tasks) + len(sep) + prompt_length > max_ctx_chars:
|
| 364 |
-
break
|
| 365 |
-
|
| 366 |
-
text += ctx
|
| 367 |
-
total = len(text)
|
| 368 |
-
|
| 369 |
-
text += f"{sep}{tasks}"
|
| 370 |
-
|
| 371 |
-
# instruction = "Answer concisely and also cite file names & chunk ids inline like (pdf_file_name:chunk_id)."
|
| 372 |
-
instruction = "go ahead and answer!"
|
| 373 |
-
user_query = f"\nq: {query}\n\nctx:{text}" + f"\n\n{instruction}\n\n"
|
| 374 |
-
|
| 375 |
-
start = time()
|
| 376 |
-
resp = generate_text(
|
| 377 |
-
self.tok,
|
| 378 |
-
self.gen,
|
| 379 |
-
user_query,
|
| 380 |
-
self.prompts["final_answer"],
|
| 381 |
-
temperature=0.3,
|
| 382 |
-
)
|
| 383 |
-
end = time()
|
| 384 |
-
log.debug(f"final resp took {(end - start):.1f} seconds...")
|
| 385 |
-
|
| 386 |
-
return resp
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
def initial_setup(embed_model, generator_model, bitsandbytesconfig=None):
|
| 390 |
-
return HyDeRAGFusion(
|
| 391 |
-
embed_model, generator_model, bitsandbytesconfig=bitsandbytesconfig
|
| 392 |
-
)
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
start = time()
|
| 396 |
-
HRF = initial_setup("sentence-transformers/LaBSE", "Qwen/Qwen2.5-0.5B-Instruct")
|
| 397 |
-
end = time()
|
| 398 |
-
msg = f"init took {(end - start):.1f} seconds"
|
| 399 |
-
log.debug(msg)
|
| 400 |
-
|
| 401 |
|
| 402 |
-
def main(
|
| 403 |
-
pdfs,
|
| 404 |
-
query,
|
| 405 |
-
n_variants=3,
|
| 406 |
-
top_k_per_variant=5,
|
| 407 |
-
top_k_retrieve=3,
|
| 408 |
-
temperature=0.4,
|
| 409 |
-
max_new_tokens=512,
|
| 410 |
-
):
|
| 411 |
-
start = time()
|
| 412 |
-
if pdfs:
|
| 413 |
-
HRF.preprocess_pdfs(tuple(sorted(pdfs)))
|
| 414 |
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
}
|
| 420 |
-
doc_ids, tasks = HRF.retrieve(
|
| 421 |
-
query,
|
| 422 |
-
int(n_variants),
|
| 423 |
-
int(top_k_per_variant),
|
| 424 |
-
int(top_k_retrieve),
|
| 425 |
-
**llm_kwargs,
|
| 426 |
-
)
|
| 427 |
-
docs = [HRF.corpus[doc_id] for doc_id in doc_ids]
|
| 428 |
-
reply = HRF.answer(query, doc_ids, tasks)
|
| 429 |
-
sources = [
|
| 430 |
-
{
|
| 431 |
-
"source": f"{Path(HRF.meta[doc_id]['file']).stem}:{HRF.meta[doc_id]['chunk_id']}",
|
| 432 |
-
"content": doc,
|
| 433 |
-
}
|
| 434 |
-
for doc_id, doc in zip(doc_ids, docs)
|
| 435 |
-
]
|
| 436 |
|
| 437 |
-
resp = f"{reply}\n\n{'-' * 25}\n\n"
|
| 438 |
-
resp += "Top 3 sources:"
|
| 439 |
-
resp += f"\n\n{'-' * 25}\n\n"
|
| 440 |
-
for source in sources:
|
| 441 |
-
resp += f"source: {source['source']}\n\n"
|
| 442 |
-
resp += source["content"]
|
| 443 |
-
resp += f"\n\n{'-' * 25}\n\n"
|
| 444 |
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
return resp
|
| 448 |
|
| 449 |
|
| 450 |
-
def
|
| 451 |
-
""
|
| 452 |
-
Reset text input when input docs change
|
| 453 |
-
"""
|
| 454 |
-
return ""
|
| 455 |
|
| 456 |
|
| 457 |
with gr.Blocks(title="RAG with HYDE") as demo:
|
| 458 |
gr.Markdown("# RAG with HYDE")
|
| 459 |
with gr.Row():
|
| 460 |
pdf_input = gr.File(
|
| 461 |
-
label="upload PDF(s)",
|
|
|
|
|
|
|
| 462 |
)
|
| 463 |
-
query = gr.Textbox(label="
|
| 464 |
|
| 465 |
-
gr.
|
| 466 |
-
|
| 467 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 468 |
|
| 469 |
-
|
| 470 |
-
|
|
|
|
|
|
|
|
|
|
| 471 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 472 |
|
| 473 |
if __name__ == "__main__":
|
| 474 |
demo.launch()
|
|
|
|
| 1 |
+
import sys
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import string
|
|
|
|
|
|
|
| 3 |
import gradio as gr
|
| 4 |
+
from src.main import ask
|
| 5 |
+
from src.utils import empty_cache
|
| 6 |
+
from src.config import log
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
+
def ask_wrapper(
|
| 10 |
+
pdfs,
|
| 11 |
+
query,
|
| 12 |
+
model_combo_key,
|
| 13 |
+
fast_extract,
|
| 14 |
+
n_variants,
|
| 15 |
+
top_k_per_variant,
|
| 16 |
+
top_k_retrieve,
|
| 17 |
+
temperature,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
):
|
| 19 |
+
resp = ask.callback(
|
| 20 |
+
pdfs,
|
| 21 |
+
query,
|
| 22 |
+
model_combo_key,
|
| 23 |
+
fast_extract,
|
| 24 |
+
n_variants,
|
| 25 |
+
top_k_per_variant,
|
| 26 |
+
top_k_retrieve,
|
| 27 |
+
temperature,
|
| 28 |
)
|
| 29 |
|
| 30 |
+
return f"## Final Answer:\n\n{resp}\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
|
| 33 |
+
def reset(pdfs):
|
| 34 |
+
"""
|
| 35 |
+
Reset text input and empty cache
|
| 36 |
+
"""
|
| 37 |
+
log.warning("emptying cache")
|
| 38 |
+
empty_cache()
|
| 39 |
+
return ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
+
# Enable the button only when both fields are nonempty
|
| 43 |
+
def _enable_submit_if_filled(pdfs, query):
|
| 44 |
+
status = bool(pdfs) and bool(len(query.strip(string.punctuation + " ")) > 10)
|
| 45 |
+
return gr.update(interactive=status)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
+
def disable_button():
|
| 49 |
+
return gr.update(interactive=False, value="Processing...")
|
|
|
|
| 50 |
|
| 51 |
|
| 52 |
+
def enable_button():
|
| 53 |
+
return gr.update(interactive=True, value="Submit")
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
|
| 56 |
with gr.Blocks(title="RAG with HYDE") as demo:
|
| 57 |
gr.Markdown("# RAG with HYDE")
|
| 58 |
with gr.Row():
|
| 59 |
pdf_input = gr.File(
|
| 60 |
+
label="upload PDF(s)",
|
| 61 |
+
file_types=[".pdf"],
|
| 62 |
+
file_count="multiple",
|
| 63 |
)
|
| 64 |
+
query = gr.Textbox(label="Question (Enter at least 10 valid characters)")
|
| 65 |
|
| 66 |
+
with gr.Accordion("Advanced Settings", open=False):
|
| 67 |
+
gr.Markdown(
|
| 68 |
+
"*These parameters have sensible defaults but can be customized if needed*"
|
| 69 |
+
)
|
| 70 |
+
with gr.Row():
|
| 71 |
+
_default_combo = "linux" if sys.platform == "linux" else "mac"
|
| 72 |
+
model_combo_key = gr.Dropdown(
|
| 73 |
+
label="Model Combo Key",
|
| 74 |
+
choices=[_default_combo, "HF-mid"],
|
| 75 |
+
value=_default_combo,
|
| 76 |
+
)
|
| 77 |
+
fast_extract = gr.Checkbox(
|
| 78 |
+
value=False, label="Use PyMuPDF to extract content in markdown"
|
| 79 |
+
)
|
| 80 |
+
n_variants = gr.Number(
|
| 81 |
+
value=3, minimum=1, maximum=5, label="no. of query variants"
|
| 82 |
+
)
|
| 83 |
+
with gr.Row():
|
| 84 |
+
top_k_per_variant = gr.Number(
|
| 85 |
+
value=5,
|
| 86 |
+
minimum=2,
|
| 87 |
+
maximum=10,
|
| 88 |
+
label="top `k` hits per query variant for RRF",
|
| 89 |
+
)
|
| 90 |
+
top_k_retrieve = gr.Number(
|
| 91 |
+
value=3,
|
| 92 |
+
minimum=1,
|
| 93 |
+
maximum=5,
|
| 94 |
+
label="top `k` chunks to retrieve after RRF",
|
| 95 |
+
)
|
| 96 |
+
temperature = gr.Slider(
|
| 97 |
+
value=0.7, minimum=0.1, maximum=1.0, step=0.1, label="temperature"
|
| 98 |
+
)
|
| 99 |
|
| 100 |
+
gr.Markdown(
|
| 101 |
+
"### *Please be patient after hitting the submit button* esp. for the first question after uploading new document(s)"
|
| 102 |
+
)
|
| 103 |
+
submit_btn = gr.Button("Submit", variant="primary", interactive=False)
|
| 104 |
+
answer = gr.Markdown(label="## Answer")
|
| 105 |
|
| 106 |
+
pdf_input.change(
|
| 107 |
+
_enable_submit_if_filled, [pdf_input, query], submit_btn, queue=False
|
| 108 |
+
)
|
| 109 |
+
query.change(_enable_submit_if_filled, [pdf_input, query], submit_btn, queue=False)
|
| 110 |
+
|
| 111 |
+
submit_btn.click(fn=disable_button, outputs=submit_btn).then(
|
| 112 |
+
fn=ask_wrapper,
|
| 113 |
+
inputs=[
|
| 114 |
+
pdf_input,
|
| 115 |
+
query,
|
| 116 |
+
model_combo_key,
|
| 117 |
+
fast_extract,
|
| 118 |
+
n_variants,
|
| 119 |
+
top_k_per_variant,
|
| 120 |
+
top_k_retrieve,
|
| 121 |
+
temperature,
|
| 122 |
+
],
|
| 123 |
+
outputs=answer,
|
| 124 |
+
).then(fn=enable_button, outputs=submit_btn)
|
| 125 |
+
|
| 126 |
+
query.submit(fn=disable_button, outputs=submit_btn).then(
|
| 127 |
+
fn=ask_wrapper,
|
| 128 |
+
inputs=[
|
| 129 |
+
pdf_input,
|
| 130 |
+
query,
|
| 131 |
+
model_combo_key,
|
| 132 |
+
fast_extract,
|
| 133 |
+
n_variants,
|
| 134 |
+
top_k_per_variant,
|
| 135 |
+
top_k_retrieve,
|
| 136 |
+
temperature,
|
| 137 |
+
],
|
| 138 |
+
outputs=answer,
|
| 139 |
+
).then(fn=enable_button, outputs=submit_btn)
|
| 140 |
+
|
| 141 |
+
pdf_input.change(reset, pdf_input, query)
|
| 142 |
+
demo.load(reset, pdf_input, query)
|
| 143 |
|
| 144 |
if __name__ == "__main__":
|
| 145 |
demo.launch()
|
requirements.txt
CHANGED
|
@@ -3,6 +3,10 @@ numpy
|
|
| 3 |
transformers
|
| 4 |
sentence-transformers
|
| 5 |
pymupdf
|
|
|
|
| 6 |
langchain-text-splitters
|
| 7 |
#faiss-cpu
|
| 8 |
faiss-gpu-cu12
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
transformers
|
| 4 |
sentence-transformers
|
| 5 |
pymupdf
|
| 6 |
+
pymupdf4llm
|
| 7 |
langchain-text-splitters
|
| 8 |
#faiss-cpu
|
| 9 |
faiss-gpu-cu12
|
| 10 |
+
autoawq
|
| 11 |
+
docling
|
| 12 |
+
|
src/config.py
CHANGED
|
@@ -8,7 +8,7 @@ def get_logger(LOG_LEVEL="INFO"):
|
|
| 8 |
LOG_PATH = Path("logs.log")
|
| 9 |
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
| 10 |
|
| 11 |
-
log = logging.Logger("
|
| 12 |
log.setLevel(LOG_LEVEL)
|
| 13 |
|
| 14 |
file_handler = logging.FileHandler(LOG_PATH)
|
|
@@ -21,3 +21,35 @@ def get_logger(LOG_LEVEL="INFO"):
|
|
| 21 |
|
| 22 |
|
| 23 |
log = get_logger("DEBUG")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
LOG_PATH = Path("logs.log")
|
| 9 |
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
| 10 |
|
| 11 |
+
log = logging.Logger("agentic_search")
|
| 12 |
log.setLevel(LOG_LEVEL)
|
| 13 |
|
| 14 |
file_handler = logging.FileHandler(LOG_PATH)
|
|
|
|
| 21 |
|
| 22 |
|
| 23 |
log = get_logger("DEBUG")
|
| 24 |
+
|
| 25 |
+
MODEL_COMBOS = {
|
| 26 |
+
"linux": {
|
| 27 |
+
"embed_model": "Qwen/Qwen3-Embedding-0.6B",
|
| 28 |
+
"gen_model": "Qwen/Qwen3-4B-AWQ",
|
| 29 |
+
# 'gen_model': "Qwen/Qwen3-0.6B-GPTQ-Int8"
|
| 30 |
+
# 'gen_model': "Qwen/Qwen3-1.7B-GPTQ-Int8"
|
| 31 |
+
},
|
| 32 |
+
# feel free to replace with any ??B-MLX-?bit versions from Qwen3 Collection at:
|
| 33 |
+
# https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f
|
| 34 |
+
"mac": {
|
| 35 |
+
"embed_model": "Qwen/Qwen3-Embedding-0.6B",
|
| 36 |
+
"gen_model": "Qwen/Qwen3-4B-MLX-4bit",
|
| 37 |
+
},
|
| 38 |
+
"mac_mid": {
|
| 39 |
+
"embed_model": "Qwen/Qwen3-Embedding-0.6B",
|
| 40 |
+
"gen_model": "Qwen/Qwen3-4B-MLX-6bit",
|
| 41 |
+
},
|
| 42 |
+
"mac_high": {
|
| 43 |
+
"embed_model": "Qwen/Qwen3-Embedding-0.6B",
|
| 44 |
+
"gen_model": "Qwen/Qwen3-4B-MLX-8bit",
|
| 45 |
+
},
|
| 46 |
+
# HF-low is same as `linux-local`
|
| 47 |
+
"HF-mid": {
|
| 48 |
+
"embed_model": "Qwen/Qwen3-Embedding-0.6B",
|
| 49 |
+
"gen_model": "Qwen/Qwen3-8B-AWQ",
|
| 50 |
+
},
|
| 51 |
+
"HF-high": {
|
| 52 |
+
"embed_model": "Qwen/Qwen3-Embedding-4B",
|
| 53 |
+
"gen_model": "Qwen/Qwen3-14B-AWQ",
|
| 54 |
+
},
|
| 55 |
+
}
|
src/hyde_rag.py
DELETED
|
@@ -1,206 +0,0 @@
|
|
| 1 |
-
# hyde_ragfusion.py
|
| 2 |
-
# Minimal HyDE + RAG-Fusion over local PDFs.
|
| 3 |
-
# Dependencies: transformers, sentence-transformers, scikit-learn, pymupdf, numpy
|
| 4 |
-
|
| 5 |
-
import os
|
| 6 |
-
import re
|
| 7 |
-
import heapq
|
| 8 |
-
import fitz # PyMuPDF
|
| 9 |
-
from sklearn.neighbors import NearestNeighbors
|
| 10 |
-
from sentence_transformers import SentenceTransformer
|
| 11 |
-
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
# -----------------------------
|
| 15 |
-
# Ingestion & Chunking
|
| 16 |
-
# -----------------------------
|
| 17 |
-
def load_pdfs(folder):
|
| 18 |
-
docs = []
|
| 19 |
-
for fn in os.listdir(folder):
|
| 20 |
-
if fn.lower().endswith(".pdf"):
|
| 21 |
-
path = os.path.join(folder, fn)
|
| 22 |
-
with fitz.open(path) as doc:
|
| 23 |
-
text = "\n".join(page.get_text("text") for page in doc)
|
| 24 |
-
text = re.sub(r"\s+\n", "\n", text).strip()
|
| 25 |
-
docs.append((fn, text))
|
| 26 |
-
return docs
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
def chunk_text(text, chunk_size=300, overlap=50):
|
| 30 |
-
words = text.split()
|
| 31 |
-
chunks, i = [], 0
|
| 32 |
-
while i < len(words):
|
| 33 |
-
chunk = " ".join(words[i : i + chunk_size])
|
| 34 |
-
chunks.append(chunk)
|
| 35 |
-
i += chunk_size - overlap
|
| 36 |
-
return chunks
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
def build_corpus(pdf_folder):
|
| 40 |
-
raw = load_pdfs(pdf_folder)
|
| 41 |
-
corpus, meta = [], []
|
| 42 |
-
for fn, txt in raw:
|
| 43 |
-
for i, ch in enumerate(chunk_text(txt)):
|
| 44 |
-
corpus.append(ch)
|
| 45 |
-
meta.append({"file": fn, "chunk_id": i})
|
| 46 |
-
return corpus, meta
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
# -----------------------------
|
| 50 |
-
# Models (local)
|
| 51 |
-
# -----------------------------
|
| 52 |
-
def load_models():
|
| 53 |
-
# Small, fast encoder for embeddings
|
| 54 |
-
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
|
| 55 |
-
# Lightweight local generator for HyDE + answers
|
| 56 |
-
gen_name = "google/flan-t5-base"
|
| 57 |
-
tok = AutoTokenizer.from_pretrained(gen_name)
|
| 58 |
-
gen = AutoModelForSeq2SeqLM.from_pretrained(gen_name)
|
| 59 |
-
return embedder, tok, gen
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
# -----------------------------
|
| 63 |
-
# Index (cosine)
|
| 64 |
-
# -----------------------------
|
| 65 |
-
def fit_index(embeddings, n_neighbors=12):
|
| 66 |
-
nn = NearestNeighbors(metric="cosine", algorithm="auto")
|
| 67 |
-
nn.fit(embeddings)
|
| 68 |
-
return nn
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
# -----------------------------
|
| 72 |
-
# RAG-Fusion (query variants) + HyDE
|
| 73 |
-
# -----------------------------
|
| 74 |
-
Q_VARIANTS_PROMPT = """You rewrite the user query into {n} diverse, specific search queries (short).
|
| 75 |
-
User query: "{q}"
|
| 76 |
-
Return each on a new line, no numbering, no extra text."""
|
| 77 |
-
|
| 78 |
-
HYDE_PROMPT = """Write a factual, neutral, self-contained paragraph that could answer:
|
| 79 |
-
"{q}"
|
| 80 |
-
Avoid fluff. Include likely key terms and entities. 120-180 words."""
|
| 81 |
-
|
| 82 |
-
ANSWER_PROMPT = """You are a helpful assistant. Use ONLY the provided context.
|
| 83 |
-
Question: {q}
|
| 84 |
-
|
| 85 |
-
Context:
|
| 86 |
-
{ctx}
|
| 87 |
-
|
| 88 |
-
Answer concisely and cite file names & chunk ids inline like (file:chunk).
|
| 89 |
-
"""
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
def generate_text(gen, tok, prompt, max_new_tokens=160, temperature=0.3):
|
| 93 |
-
inputs = tok(prompt, return_tensors="pt")
|
| 94 |
-
out = gen.generate(
|
| 95 |
-
**inputs,
|
| 96 |
-
max_new_tokens=max_new_tokens,
|
| 97 |
-
do_sample=False,
|
| 98 |
-
temperature=temperature,
|
| 99 |
-
)
|
| 100 |
-
return tok.decode(out[0], skip_special_tokens=True).strip()
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
def make_query_variants(gen, tok, q, n=4):
|
| 104 |
-
txt = generate_text(
|
| 105 |
-
gen, tok, Q_VARIANTS_PROMPT.format(q=q, n=n), max_new_tokens=120
|
| 106 |
-
)
|
| 107 |
-
# Split cleanly into lines (drop empties/dups; include original)
|
| 108 |
-
lines = [l.strip(" -•\t") for l in txt.split("\n") if l.strip()]
|
| 109 |
-
uniq = []
|
| 110 |
-
seen = set()
|
| 111 |
-
for l in lines + [q]:
|
| 112 |
-
if l not in seen:
|
| 113 |
-
seen.add(l)
|
| 114 |
-
uniq.append(l)
|
| 115 |
-
return uniq[:n]
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
def hyde_doc(gen, tok, q):
|
| 119 |
-
return generate_text(gen, tok, HYDE_PROMPT.format(q=q), max_new_tokens=220)
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
# -----------------------------
|
| 123 |
-
# Retrieval + RRF
|
| 124 |
-
# -----------------------------
|
| 125 |
-
def cosine_search(nn, corpus_embeddings, query_vec, top_k=8):
|
| 126 |
-
dists, idxs = nn.kneighbors(query_vec.reshape(1, -1), n_neighbors=top_k)
|
| 127 |
-
# Convert cosine distance to similarity
|
| 128 |
-
sims = 1 - dists[0]
|
| 129 |
-
return list(zip(idxs[0].tolist(), sims.tolist()))
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
def reciprocal_rank_fusion(rank_lists, k=60, top_k=8):
|
| 133 |
-
# rank_lists: list of [doc_id, ...] ordered best→worst
|
| 134 |
-
scores = {}
|
| 135 |
-
for ranks in rank_lists:
|
| 136 |
-
for rank, doc_id in enumerate(ranks, start=1):
|
| 137 |
-
scores[doc_id] = scores.get(doc_id, 0.0) + 1.0 / (k + rank)
|
| 138 |
-
# top by fused score
|
| 139 |
-
best = heapq.nlargest(top_k, scores.items(), key=lambda x: x[1])
|
| 140 |
-
return [doc_id for doc_id, _ in best]
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
# -----------------------------
|
| 144 |
-
# Pipeline
|
| 145 |
-
# -----------------------------
|
| 146 |
-
class HyDeRAGFusion:
|
| 147 |
-
def __init__(self, pdf_folder):
|
| 148 |
-
self.corpus, self.meta = build_corpus(pdf_folder)
|
| 149 |
-
self.embedder, self.tok, self.gen = load_models()
|
| 150 |
-
self.corpus_emb = self.embedder.encode(
|
| 151 |
-
self.corpus,
|
| 152 |
-
batch_size=64,
|
| 153 |
-
show_progress_bar=True,
|
| 154 |
-
normalize_embeddings=True,
|
| 155 |
-
)
|
| 156 |
-
self.nn = fit_index(self.corpus_emb)
|
| 157 |
-
|
| 158 |
-
def retrieve(self, query, n_variants=4, per_variant_k=8, final_top_k=6, rrf_k=60):
|
| 159 |
-
variants = make_query_variants(self.gen, self.tok, query, n=n_variants)
|
| 160 |
-
rank_lists = []
|
| 161 |
-
for v in variants:
|
| 162 |
-
hypo = hyde_doc(self.gen, self.tok, v) # HyDE
|
| 163 |
-
q_vec = self.embedder.encode([hypo], normalize_embeddings=True)[0]
|
| 164 |
-
hits = cosine_search(self.nn, self.corpus_emb, q_vec, top_k=per_variant_k)
|
| 165 |
-
rank_lists.append([doc_id for doc_id, _ in hits])
|
| 166 |
-
fused = reciprocal_rank_fusion(rank_lists, k=rrf_k, top_k=final_top_k)
|
| 167 |
-
return fused
|
| 168 |
-
|
| 169 |
-
def answer(self, query, doc_ids, max_ctx_chars=4000):
|
| 170 |
-
# Build compact context with inline provenance
|
| 171 |
-
ctx_parts = []
|
| 172 |
-
total = 0
|
| 173 |
-
for i in doc_ids:
|
| 174 |
-
piece = self.corpus[i]
|
| 175 |
-
tag = f"(source: {self.meta[i]['file']}:{self.meta[i]['chunk_id']})"
|
| 176 |
-
chunk = piece.strip()
|
| 177 |
-
if total + len(chunk) + len(tag) + 5 > max_ctx_chars:
|
| 178 |
-
break
|
| 179 |
-
ctx_parts.append(f"{chunk}\n{tag}")
|
| 180 |
-
total += len(chunk) + len(tag) + 5
|
| 181 |
-
ctx = "\n\n---\n\n".join(ctx_parts)
|
| 182 |
-
prompt = ANSWER_PROMPT.format(q=query, ctx=ctx)
|
| 183 |
-
return generate_text(self.gen, self.tok, prompt, max_new_tokens=300)
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
# -----------------------------
|
| 187 |
-
# Example usage
|
| 188 |
-
# -----------------------------
|
| 189 |
-
if __name__ == "__main__":
|
| 190 |
-
import argparse
|
| 191 |
-
|
| 192 |
-
ap = argparse.ArgumentParser()
|
| 193 |
-
ap.add_argument("--pdf_folder", required=True, help="Folder with PDFs to index")
|
| 194 |
-
ap.add_argument("--query", required=True, help="Your user question")
|
| 195 |
-
ap.add_argument("--show_sources", action="store_true")
|
| 196 |
-
args = ap.parse_args()
|
| 197 |
-
|
| 198 |
-
rag = HyDeRAGFusion(args.pdf_folder)
|
| 199 |
-
doc_ids = rag.retrieve(args.query)
|
| 200 |
-
answer = rag.answer(args.query, doc_ids)
|
| 201 |
-
print("\n=== ANSWER ===\n")
|
| 202 |
-
print(answer)
|
| 203 |
-
if args.show_sources:
|
| 204 |
-
print("\n=== TOP SOURCES ===")
|
| 205 |
-
for i in doc_ids:
|
| 206 |
-
print(f"- {rag.meta[i]['file']}:{rag.meta[i]['chunk_id']}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/main.py
CHANGED
|
@@ -1,79 +1,32 @@
|
|
| 1 |
-
import pymupdf
|
| 2 |
-
import math
|
| 3 |
-
import faiss
|
| 4 |
import string
|
|
|
|
| 5 |
import yaml
|
| 6 |
import re
|
| 7 |
-
import
|
| 8 |
-
import asyncio
|
| 9 |
import torch
|
| 10 |
-
import streamlit as st
|
| 11 |
import click
|
| 12 |
|
| 13 |
-
from collections import defaultdict
|
| 14 |
-
from ast import literal_eval
|
| 15 |
from time import time
|
|
|
|
|
|
|
| 16 |
from sentence_transformers import SentenceTransformer
|
| 17 |
-
from transformers import AutoTokenizer,
|
| 18 |
-
from
|
| 19 |
-
|
| 20 |
-
from src.config import PROMPTS_FILEPATH, log
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
folder (str): Path to folder containing PDF files
|
| 29 |
-
max_concurrence (int): Maximum number of concurrent PDF processing tasks
|
| 30 |
-
|
| 31 |
-
Returns:
|
| 32 |
-
list: List of tuples containing (filename, extracted_text)
|
| 33 |
-
"""
|
| 34 |
-
|
| 35 |
-
def _load_pdf_sync(file):
|
| 36 |
-
"""Synchronous PDF loading function for thread pool execution"""
|
| 37 |
-
text = ""
|
| 38 |
-
try:
|
| 39 |
-
with pymupdf.open(stream=file.getvalue(), filetype="pdf") as doc:
|
| 40 |
-
text = "\n".join(page.get_text() for page in doc)
|
| 41 |
-
except Exception:
|
| 42 |
-
log.exception(f"Error reading {file.name}")
|
| 43 |
-
pass
|
| 44 |
-
|
| 45 |
-
return (file.name, text)
|
| 46 |
-
|
| 47 |
-
loop = asyncio.get_event_loop()
|
| 48 |
-
with ThreadPoolExecutor(max_workers=max_concurrence) as executor:
|
| 49 |
-
futures = [
|
| 50 |
-
loop.run_in_executor(executor, _load_pdf_sync, file)
|
| 51 |
-
for file in files
|
| 52 |
-
if file is not None
|
| 53 |
-
]
|
| 54 |
-
|
| 55 |
-
results = await asyncio.gather(*futures, return_exceptions=True)
|
| 56 |
-
|
| 57 |
-
valid_results = [result for result in results if not isinstance(result, Exception)]
|
| 58 |
-
|
| 59 |
-
log.info(f"successfully processed {len(valid_results)} out of {len(files)} PDFs ")
|
| 60 |
-
return valid_results
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
async def build_corpus(pdfs, text_splitter, **load_kwargs):
|
| 64 |
-
texts = await load_pdfs(pdfs, **load_kwargs)
|
| 65 |
-
corpus, meta = [], []
|
| 66 |
-
for file_name, raw_text in texts:
|
| 67 |
-
chunks = text_splitter.split_text(raw_text)
|
| 68 |
-
for i, chunk in enumerate(chunks):
|
| 69 |
-
corpus.append(chunk)
|
| 70 |
-
meta.append({"file": file_name, "chunk_id": i})
|
| 71 |
-
return corpus, meta
|
| 72 |
|
| 73 |
|
| 74 |
def generate_text(
|
| 75 |
-
tokenizer, model, user_prompts, system_prompt=None, **llm_kwargs
|
| 76 |
-
):
|
|
|
|
| 77 |
if system_prompt is None or "":
|
| 78 |
system_prompt = "You are a helpful assistant."
|
| 79 |
|
|
@@ -88,96 +41,122 @@ def generate_text(
|
|
| 88 |
for user_prompt in user_prompts
|
| 89 |
]
|
| 90 |
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
)
|
| 94 |
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
|
|
|
| 110 |
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
gen_model_name: str,
|
| 114 |
-
causal_lm: bool = False,
|
| 115 |
-
device=None,
|
| 116 |
-
bitsandbytesconfig=None,
|
| 117 |
-
):
|
| 118 |
# This will take some time to run for the first time if the model(s) don't exist locally.
|
| 119 |
if not device:
|
| 120 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
embedder = SentenceTransformer(
|
| 122 |
embed_model_name,
|
| 123 |
device=device,
|
| 124 |
-
model_kwargs={"dtype":
|
| 125 |
)
|
| 126 |
-
|
| 127 |
-
if not causal_lm:
|
| 128 |
-
tok = AutoTokenizer.from_pretrained(gen_model_name)
|
| 129 |
-
gen = AutoModelForSeq2SeqLM.from_pretrained(
|
| 130 |
-
gen_model_name, # device_map='auto',
|
| 131 |
-
quantization_config=bitsandbytesconfig if bitsandbytesconfig else None,
|
| 132 |
-
)
|
| 133 |
-
else:
|
| 134 |
-
tok = AutoTokenizer.from_pretrained(gen_model_name, padding_side="left")
|
| 135 |
-
gen = AutoModelForCausalLM.from_pretrained(
|
| 136 |
-
gen_model_name,
|
| 137 |
-
dtype="float16", # device_map='auto',
|
| 138 |
-
quantization_config=bitsandbytesconfig if bitsandbytesconfig else None,
|
| 139 |
-
)
|
| 140 |
-
gen.to(device)
|
| 141 |
return embedder, tok, gen
|
| 142 |
|
| 143 |
|
| 144 |
def make_query_variants(
|
| 145 |
-
tokenizer, model, query: str, prompt: str, n: int = 3, **llm_kwargs
|
| 146 |
):
|
| 147 |
-
instructions = f"Now give me at least {n} variations
|
| 148 |
-
|
| 149 |
-
|
|
|
|
|
|
|
| 150 |
clean_resp = re.sub(r"^\d+\.\s*", "", resp, flags=re.MULTILINE).split("\n")
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
try:
|
| 156 |
-
resp = json.loads(resp) # Parse JSON
|
| 157 |
-
except json.JSONDecodeError:
|
| 158 |
-
try:
|
| 159 |
-
resp = literal_eval(resp) # Fallback parse
|
| 160 |
-
except Exception:
|
| 161 |
-
pass # Keep resp as-is if both fail
|
| 162 |
-
|
| 163 |
-
# Ensure resp is a string before strip and slicing
|
| 164 |
-
if isinstance(resp, str):
|
| 165 |
-
resp = resp.strip()
|
| 166 |
-
if resp:
|
| 167 |
-
start = resp.find("{")
|
| 168 |
-
if start != -1:
|
| 169 |
-
end = resp[::-1].find("}")
|
| 170 |
-
if end != -1:
|
| 171 |
-
resp = resp[start : len(resp) - end]
|
| 172 |
-
return clean_rewrite_resp(resp)
|
| 173 |
-
return resp
|
| 174 |
|
| 175 |
|
| 176 |
def transform_query(
|
| 177 |
-
tokenizer, model, query: str, rewrite_prompt: str, **llm_kwargs
|
| 178 |
) -> dict:
|
| 179 |
"""split the query into things to search and actions to take"""
|
| 180 |
-
resp = generate_text(
|
|
|
|
|
|
|
| 181 |
try:
|
| 182 |
resp = clean_rewrite_resp(resp)
|
| 183 |
except:
|
|
@@ -192,6 +171,7 @@ def aggregate_queries_and_tasks(
|
|
| 192 |
rewrite_prompt,
|
| 193 |
variants_prompt,
|
| 194 |
n_variations=3,
|
|
|
|
| 195 |
**llm_kwargs,
|
| 196 |
):
|
| 197 |
# make variations for the original query as is
|
|
@@ -201,14 +181,13 @@ def aggregate_queries_and_tasks(
|
|
| 201 |
orig_query.strip(),
|
| 202 |
variants_prompt,
|
| 203 |
n_variations,
|
|
|
|
| 204 |
**llm_kwargs,
|
|
|
|
|
|
|
|
|
|
| 205 |
)
|
| 206 |
|
| 207 |
-
start = time()
|
| 208 |
-
tr_q = transform_query(tokenizer, model, orig_query.strip(), rewrite_prompt)
|
| 209 |
-
end = time()
|
| 210 |
-
log.debug(f"\t\t transforming query task took {(end - start):.1f} seconds...")
|
| 211 |
-
|
| 212 |
# transformed query might have multiple things to search and tasks to perform depending on user query
|
| 213 |
# recursively get variations for each of the search queries but keep the tasks as is.
|
| 214 |
tasks = []
|
|
@@ -222,91 +201,77 @@ def aggregate_queries_and_tasks(
|
|
| 222 |
search_result,
|
| 223 |
variants_prompt,
|
| 224 |
n_variations,
|
|
|
|
| 225 |
**llm_kwargs,
|
| 226 |
)
|
| 227 |
)
|
| 228 |
|
| 229 |
-
queries = [q.strip(string.punctuation) for q in queries]
|
| 230 |
-
tasks = [t.strip(string.punctuation) for t in tasks]
|
| 231 |
-
|
| 232 |
# keep the original user query as is (if in case LLM messes up the original query) and pick some after shuffling the rest
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
|
|
|
| 238 |
|
| 239 |
return queries, tasks
|
| 240 |
|
| 241 |
|
| 242 |
-
def build_index(corpus_emb, n_cells=5, n_probe=2):
|
| 243 |
-
log.debug(f"building index with {n_cells=}, {n_probe=}")
|
| 244 |
-
d = corpus_emb.shape[1]
|
| 245 |
-
quantizer = faiss.IndexFlatIP(d)
|
| 246 |
-
index = faiss.IndexIVFFlat(quantizer, d, n_cells)
|
| 247 |
-
index.n_probe = n_probe
|
| 248 |
-
index.train(corpus_emb)
|
| 249 |
-
index.add(corpus_emb)
|
| 250 |
-
# index.make_direct_map()
|
| 251 |
-
return index
|
| 252 |
-
|
| 253 |
-
|
| 254 |
-
def reciprocal_rank_fusion(indices, top_k=3, denom=50):
|
| 255 |
-
ii = indices.tolist()
|
| 256 |
-
scores = defaultdict(int)
|
| 257 |
-
for row in ii:
|
| 258 |
-
for rank, chunk_id in enumerate(row):
|
| 259 |
-
scores[chunk_id] += 1 / (rank + denom)
|
| 260 |
-
results = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:top_k]
|
| 261 |
-
return [chunk_id for chunk_id, _ in results]
|
| 262 |
-
|
| 263 |
-
|
| 264 |
class HyDeRAGFusion:
|
| 265 |
def __init__(
|
| 266 |
self,
|
| 267 |
embed_model: str,
|
| 268 |
generator_llm_model: str,
|
| 269 |
-
|
| 270 |
-
chunk_overlap: int = 50,
|
| 271 |
-
tokens_per_chunk: int = 256,
|
| 272 |
-
embed_batch_size: int = 64,
|
| 273 |
-
bitsandbytesconfig=None,
|
| 274 |
):
|
| 275 |
self.embed_batch_size = embed_batch_size
|
| 276 |
-
self.
|
| 277 |
-
chunk_overlap, embed_model, tokens_per_chunk
|
| 278 |
-
)
|
| 279 |
self.embedder, self.tok, self.gen = load_models(
|
| 280 |
-
embed_model, generator_llm_model
|
| 281 |
)
|
|
|
|
| 282 |
with open(PROMPTS_FILEPATH) as fl:
|
| 283 |
self.prompts = yaml.safe_load(fl)
|
| 284 |
|
| 285 |
-
def
|
| 286 |
-
|
| 287 |
-
|
|
|
|
| 288 |
)
|
| 289 |
-
|
| 290 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
batch_size=self.embed_batch_size,
|
| 292 |
-
|
| 293 |
-
|
|
|
|
|
|
|
| 294 |
)
|
| 295 |
|
| 296 |
-
|
| 297 |
-
#
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
|
|
|
| 304 |
|
| 305 |
def retrieve(
|
| 306 |
-
self, query, n_variants=3, top_k_per_variant=
|
| 307 |
):
|
| 308 |
-
start = time()
|
| 309 |
-
|
| 310 |
queries, tasks = aggregate_queries_and_tasks(
|
| 311 |
self.tok,
|
| 312 |
self.gen,
|
|
@@ -314,48 +279,38 @@ class HyDeRAGFusion:
|
|
| 314 |
self.prompts["rewrite"],
|
| 315 |
self.prompts["variants"],
|
| 316 |
n_variants,
|
|
|
|
| 317 |
**llm_kwargs,
|
| 318 |
)
|
| 319 |
-
|
| 320 |
-
end = time()
|
| 321 |
-
log.debug(f"aggregate task took {(end - start):.1f} seconds...")
|
| 322 |
-
|
| 323 |
-
start = time()
|
| 324 |
hyde_docs = generate_text(
|
| 325 |
-
self.tok,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 326 |
)
|
| 327 |
-
end = time()
|
| 328 |
-
log.debug(f"generating hyde docs took {(end - start):.1f} seconds...")
|
| 329 |
-
|
| 330 |
-
start = time()
|
| 331 |
chunks = []
|
| 332 |
for hyde_doc in hyde_docs:
|
| 333 |
chunks.extend(self.text_splitter.split_text(hyde_doc))
|
| 334 |
-
q_emb = self.
|
| 335 |
-
|
|
|
|
| 336 |
)
|
| 337 |
-
|
| 338 |
-
|
|
|
|
| 339 |
|
| 340 |
-
|
| 341 |
-
chunk_ids = reciprocal_rank_fusion(I, top_k_retrieve)
|
| 342 |
-
return chunk_ids, tasks
|
| 343 |
-
|
| 344 |
-
def answer(self, query, doc_ids, tasks, max_ctx_chars=128000):
|
| 345 |
total, text, prompt_length = 0, "", 10000
|
| 346 |
sep = "\n\n-----\n\n"
|
| 347 |
-
tasks = ", ".join(tasks)
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
# _meta = self.meta[doc_id]
|
| 353 |
-
# tag = f"(source: {_meta['file_name']}:{_meta['chunk_id']})"
|
| 354 |
-
chunk = self.corpus[doc_id].strip()
|
| 355 |
-
tag = ""
|
| 356 |
-
|
| 357 |
-
ctx = f"{sep}{tag}\n\n{chunk}"
|
| 358 |
if total + len(ctx) + len(tasks) + len(sep) + prompt_length > max_ctx_chars:
|
|
|
|
| 359 |
break
|
| 360 |
|
| 361 |
text += ctx
|
|
@@ -363,41 +318,52 @@ class HyDeRAGFusion:
|
|
| 363 |
|
| 364 |
text += f"{sep}{tasks}"
|
| 365 |
|
| 366 |
-
# instruction = "Answer concisely and also cite file names & chunk ids inline like (pdf_file_name:chunk_id)."
|
| 367 |
instruction = "go ahead and answer!"
|
| 368 |
user_query = f"\nq: {query}\n\nctx:{text}" + f"\n\n{instruction}\n\n"
|
| 369 |
-
|
| 370 |
-
start = time()
|
| 371 |
resp = generate_text(
|
| 372 |
self.tok,
|
| 373 |
self.gen,
|
| 374 |
user_query,
|
| 375 |
self.prompts["final_answer"],
|
| 376 |
-
|
| 377 |
)
|
| 378 |
-
end = time()
|
| 379 |
-
log.debug(f"final resp took {(end - start):.1f} seconds...")
|
| 380 |
|
| 381 |
-
|
|
|
|
|
|
|
|
|
|
| 382 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 383 |
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
return HyDeRAGFusion(
|
| 387 |
-
embed_model, generator_model, bitsandbytesconfig=bitsandbytesconfig
|
| 388 |
-
)
|
| 389 |
|
| 390 |
|
| 391 |
@click.command(context_settings=dict(show_default=True))
|
| 392 |
@click.option(
|
| 393 |
-
"--
|
| 394 |
-
|
| 395 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 396 |
)
|
| 397 |
@click.option(
|
| 398 |
-
"--
|
| 399 |
-
default=
|
| 400 |
-
help="
|
| 401 |
)
|
| 402 |
@click.option("--n-variants", default=3, help="no. of query variants")
|
| 403 |
@click.option(
|
|
@@ -408,97 +374,57 @@ def initial_setup(embed_model, generator_model, bitsandbytesconfig=None):
|
|
| 408 |
@click.option(
|
| 409 |
"--top-k-retrieve", default=3, help="top `k` chunks to retrieve after RRF"
|
| 410 |
)
|
| 411 |
-
@click.option("--temperature", default=0.
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
)
|
| 418 |
-
def main(
|
| 419 |
-
embed_model,
|
| 420 |
-
generator_llm_model,
|
| 421 |
n_variants,
|
| 422 |
top_k_per_variant,
|
| 423 |
top_k_retrieve,
|
| 424 |
temperature,
|
| 425 |
-
max_new_tokens,
|
| 426 |
-
faiss_index_kwargs,
|
| 427 |
):
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
start = time()
|
| 432 |
-
hrf = initial_setup(embed_model, generator_llm_model)
|
| 433 |
-
end = time()
|
| 434 |
-
msg = f"init took {(end - start):.1f} seconds"
|
| 435 |
-
log.debug(msg)
|
| 436 |
-
st.write(msg)
|
| 437 |
-
|
| 438 |
-
st.set_page_config(page_title="RAG HYDE")
|
| 439 |
-
st.header("Ask Questions")
|
| 440 |
-
|
| 441 |
-
state = st.session_state
|
| 442 |
-
if "uploaded_names" not in state:
|
| 443 |
-
state.uploaded_names = []
|
| 444 |
-
|
| 445 |
-
pdfs = st.file_uploader(
|
| 446 |
-
"Upload your PDF(s)", type="pdf", accept_multiple_files=True, key="upload"
|
| 447 |
)
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
f"corpus embeddings shape: {hrf.corpus_emb.shape}, computed in {end - start:.1f} seconds"
|
| 462 |
-
)
|
| 463 |
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
st.write("upload data to query")
|
| 468 |
|
| 469 |
-
query
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
llm_kwargs = {
|
| 473 |
-
"temperature": temperature,
|
| 474 |
-
"max_new_tokens": max_new_tokens,
|
| 475 |
-
}
|
| 476 |
-
doc_ids, tasks = hrf.retrieve(
|
| 477 |
-
query,
|
| 478 |
int(n_variants),
|
| 479 |
int(top_k_per_variant),
|
| 480 |
int(top_k_retrieve),
|
| 481 |
-
|
| 482 |
)
|
| 483 |
-
|
|
|
|
|
|
|
| 484 |
end = time()
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
{
|
| 491 |
-
"source": f"{hrf.meta[doc_id]['file']}:{hrf.meta[doc_id]['chunk_id']}",
|
| 492 |
-
"content": doc,
|
| 493 |
-
}
|
| 494 |
-
for doc_id, doc in zip(doc_ids, docs)
|
| 495 |
-
]
|
| 496 |
-
st.json(sources[:3])
|
| 497 |
|
| 498 |
|
| 499 |
if __name__ == "__main__":
|
| 500 |
-
|
| 501 |
-
# 'n_cells': 20,
|
| 502 |
-
# 'n_probe': 8
|
| 503 |
-
# }
|
| 504 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import string
|
| 2 |
+
import faiss
|
| 3 |
import yaml
|
| 4 |
import re
|
| 5 |
+
import sys
|
|
|
|
| 6 |
import torch
|
|
|
|
| 7 |
import click
|
| 8 |
|
|
|
|
|
|
|
| 9 |
from time import time
|
| 10 |
+
from random import shuffle
|
| 11 |
+
from functools import lru_cache
|
| 12 |
from sentence_transformers import SentenceTransformer
|
| 13 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 14 |
+
from langchain_text_splitters import MarkdownTextSplitter
|
| 15 |
+
|
| 16 |
+
from src.config import PROMPTS_FILEPATH, MODEL_COMBOS, log
|
| 17 |
+
from src.utils import (
|
| 18 |
+
empty_cache,
|
| 19 |
+
find_think_tag_in_each_row,
|
| 20 |
+
build_corpus,
|
| 21 |
+
reciprocal_rank_fusion,
|
| 22 |
+
clean_rewrite_resp,
|
| 23 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
|
| 26 |
def generate_text(
|
| 27 |
+
tokenizer, model, user_prompts, system_prompt=None, model_name="", **llm_kwargs
|
| 28 |
+
):
|
| 29 |
+
assert model_name, "pass on generative model name"
|
| 30 |
if system_prompt is None or "":
|
| 31 |
system_prompt = "You are a helpful assistant."
|
| 32 |
|
|
|
|
| 41 |
for user_prompt in user_prompts
|
| 42 |
]
|
| 43 |
|
| 44 |
+
if "mlx" in model_name.lower() and sys.platform == "darwin":
|
| 45 |
+
from mlx_lm import generate
|
|
|
|
| 46 |
|
| 47 |
+
texts = [
|
| 48 |
+
tokenizer.apply_chat_template(
|
| 49 |
+
message,
|
| 50 |
+
tokenize=False,
|
| 51 |
+
add_generation_prompt=True,
|
| 52 |
+
enable_thinking=False,
|
| 53 |
+
)
|
| 54 |
+
for message in messages
|
| 55 |
+
]
|
| 56 |
+
responses = [
|
| 57 |
+
generate(
|
| 58 |
+
model,
|
| 59 |
+
tokenizer,
|
| 60 |
+
prompt=text,
|
| 61 |
+
verbose=False,
|
| 62 |
+
max_tokens=llm_kwargs.pop("max_new_tokens", 32768),
|
| 63 |
+
)
|
| 64 |
+
for text in texts
|
| 65 |
+
]
|
| 66 |
+
else:
|
| 67 |
+
texts = tokenizer.apply_chat_template(
|
| 68 |
+
messages, tokenize=False, add_generation_prompt=True, enable_thinking=False
|
| 69 |
+
)
|
| 70 |
+
model_inputs = tokenizer(
|
| 71 |
+
texts, return_tensors="pt", truncation=True, padding=True
|
| 72 |
+
).to(model.device)
|
| 73 |
+
|
| 74 |
+
with torch.no_grad():
|
| 75 |
+
generated_ids = model.generate(
|
| 76 |
+
**model_inputs,
|
| 77 |
+
max_new_tokens=llm_kwargs.pop("max_new_tokens", 32768),
|
| 78 |
+
temperature=llm_kwargs.pop("temperature", 0.7),
|
| 79 |
+
top_p=llm_kwargs.pop("top_p", 0.8),
|
| 80 |
+
top_k=llm_kwargs.pop("top_k", 20),
|
| 81 |
+
min_p=llm_kwargs.pop("min_p", 0),
|
| 82 |
+
**llm_kwargs,
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
output_ids = generated_ids[:, model_inputs.input_ids.shape[1] :]
|
| 86 |
+
idxs = find_think_tag_in_each_row(output_ids)
|
| 87 |
+
thinking_contents = [
|
| 88 |
+
tokenizer.decode(output_ids[i][:idx], skip_special_tokens=True).strip("\n")
|
| 89 |
+
for i, idx in enumerate(idxs)
|
| 90 |
+
]
|
| 91 |
+
contents = [
|
| 92 |
+
tokenizer.decode(output_ids[i][idx:], skip_special_tokens=True).strip("\n")
|
| 93 |
+
for i, idx in enumerate(idxs)
|
| 94 |
+
]
|
| 95 |
+
responses = [
|
| 96 |
+
f"{think_resp}{cont}"
|
| 97 |
+
for think_resp, cont in zip(thinking_contents, contents)
|
| 98 |
+
]
|
| 99 |
|
| 100 |
+
return responses[0] if len(user_prompts) == 1 else responses
|
| 101 |
|
| 102 |
+
|
| 103 |
+
def load_models(embed_model_name: str, gen_model_name: str, device: str = None):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
# This will take some time to run for the first time if the model(s) don't exist locally.
|
| 105 |
if not device:
|
| 106 |
+
if torch.cuda.is_available():
|
| 107 |
+
device = "cuda"
|
| 108 |
+
elif torch.mps.is_available():
|
| 109 |
+
device = "mps"
|
| 110 |
+
else:
|
| 111 |
+
device = "cpu"
|
| 112 |
+
|
| 113 |
+
dtype = torch.bfloat16 if device == "cuda" else torch.float16
|
| 114 |
+
if device != "mps" or (device == "mps" and "mlx" not in gen_model_name.lower()):
|
| 115 |
+
tok = AutoTokenizer.from_pretrained(gen_model_name, padding_side="left")
|
| 116 |
+
# sometimes loading an AWQ model on my local machine fails for the first time
|
| 117 |
+
try:
|
| 118 |
+
gen = AutoModelForCausalLM.from_pretrained(
|
| 119 |
+
gen_model_name, dtype=dtype, device_map=device
|
| 120 |
+
).eval()
|
| 121 |
+
except ImportError:
|
| 122 |
+
gen = AutoModelForCausalLM.from_pretrained(
|
| 123 |
+
gen_model_name, dtype=dtype, device_map=device
|
| 124 |
+
).eval()
|
| 125 |
+
else:
|
| 126 |
+
from mlx_lm import load
|
| 127 |
+
|
| 128 |
+
gen, tok = load(gen_model_name)
|
| 129 |
+
|
| 130 |
embedder = SentenceTransformer(
|
| 131 |
embed_model_name,
|
| 132 |
device=device,
|
| 133 |
+
model_kwargs={"dtype": dtype},
|
| 134 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
return embedder, tok, gen
|
| 136 |
|
| 137 |
|
| 138 |
def make_query_variants(
|
| 139 |
+
tokenizer, model, query: str, prompt: str, n: int = 3, model_name="", **llm_kwargs
|
| 140 |
):
|
| 141 |
+
# instructions = f"\n\n(Now give me at least {n} diverse variations of user query in the same language as the user provided query)"
|
| 142 |
+
# query += instructions
|
| 143 |
+
resp = generate_text(
|
| 144 |
+
tokenizer, model, query.format(n=n), prompt, model_name=model_name, **llm_kwargs
|
| 145 |
+
)
|
| 146 |
clean_resp = re.sub(r"^\d+\.\s*", "", resp, flags=re.MULTILINE).split("\n")
|
| 147 |
+
queries = [q.strip() for q in clean_resp if q.strip()]
|
| 148 |
+
return [query.lower().strip()] + sorted(
|
| 149 |
+
set(map(lambda x: str.lower(x).strip(), queries))
|
| 150 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
|
| 152 |
|
| 153 |
def transform_query(
|
| 154 |
+
tokenizer, model, query: str, rewrite_prompt: str, model_name="", **llm_kwargs
|
| 155 |
) -> dict:
|
| 156 |
"""split the query into things to search and actions to take"""
|
| 157 |
+
resp = generate_text(
|
| 158 |
+
tokenizer, model, query, rewrite_prompt, model_name=model_name, **llm_kwargs
|
| 159 |
+
)
|
| 160 |
try:
|
| 161 |
resp = clean_rewrite_resp(resp)
|
| 162 |
except:
|
|
|
|
| 171 |
rewrite_prompt,
|
| 172 |
variants_prompt,
|
| 173 |
n_variations=3,
|
| 174 |
+
gen_model_name="",
|
| 175 |
**llm_kwargs,
|
| 176 |
):
|
| 177 |
# make variations for the original query as is
|
|
|
|
| 181 |
orig_query.strip(),
|
| 182 |
variants_prompt,
|
| 183 |
n_variations,
|
| 184 |
+
gen_model_name,
|
| 185 |
**llm_kwargs,
|
| 186 |
+
)[: n_variations + 1]
|
| 187 |
+
tr_q = transform_query(
|
| 188 |
+
tokenizer, model, orig_query.strip(), rewrite_prompt, gen_model_name
|
| 189 |
)
|
| 190 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
# transformed query might have multiple things to search and tasks to perform depending on user query
|
| 192 |
# recursively get variations for each of the search queries but keep the tasks as is.
|
| 193 |
tasks = []
|
|
|
|
| 201 |
search_result,
|
| 202 |
variants_prompt,
|
| 203 |
n_variations,
|
| 204 |
+
gen_model_name,
|
| 205 |
**llm_kwargs,
|
| 206 |
)
|
| 207 |
)
|
| 208 |
|
|
|
|
|
|
|
|
|
|
| 209 |
# keep the original user query as is (if in case LLM messes up the original query) and pick some after shuffling the rest
|
| 210 |
+
q, qq = queries[0], queries[1:]
|
| 211 |
+
shuffle(qq)
|
| 212 |
+
queries = [q] + sorted(
|
| 213 |
+
set(map(lambda x: str.lower(x).strip(string.punctuation), qq[:n_variations]))
|
| 214 |
+
)
|
| 215 |
+
tasks = sorted(set(map(lambda x: str.lower(x).strip(string.punctuation), tasks)))
|
| 216 |
|
| 217 |
return queries, tasks
|
| 218 |
|
| 219 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
class HyDeRAGFusion:
|
| 221 |
def __init__(
|
| 222 |
self,
|
| 223 |
embed_model: str,
|
| 224 |
generator_llm_model: str,
|
| 225 |
+
embed_batch_size: int = 8,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
):
|
| 227 |
self.embed_batch_size = embed_batch_size
|
| 228 |
+
self.gen_model_name = generator_llm_model
|
|
|
|
|
|
|
| 229 |
self.embedder, self.tok, self.gen = load_models(
|
| 230 |
+
embed_model, generator_llm_model
|
| 231 |
)
|
| 232 |
+
self.text_splitter = MarkdownTextSplitter(chunk_overlap=450, chunk_size=3000)
|
| 233 |
with open(PROMPTS_FILEPATH) as fl:
|
| 234 |
self.prompts = yaml.safe_load(fl)
|
| 235 |
|
| 236 |
+
def get_embeddings(self, texts, **kwargs):
|
| 237 |
+
log.debug(f"batching size: {len(texts)} aka {self.embed_batch_size}")
|
| 238 |
+
return self.embedder.encode(
|
| 239 |
+
texts, batch_size=self.embed_batch_size, normalize_embeddings=True, **kwargs
|
| 240 |
)
|
| 241 |
+
|
| 242 |
+
@lru_cache(maxsize=2)
|
| 243 |
+
def preprocess_pdfs(self, pdfs, **data_load_kwargs):
|
| 244 |
+
log.debug(f"\n\n{'@@@@' * 20}\n\n preprocessing {pdfs=}")
|
| 245 |
+
empty_cache()
|
| 246 |
+
self.dataset = build_corpus(pdfs, self.text_splitter, **data_load_kwargs)
|
| 247 |
+
empty_cache()
|
| 248 |
+
self.dataset = self.dataset.map(
|
| 249 |
+
lambda x: {
|
| 250 |
+
"embeddings": self.get_embeddings(
|
| 251 |
+
x["chunk"], prompt_name="query", show_progress_bar=False
|
| 252 |
+
)
|
| 253 |
+
},
|
| 254 |
+
batched=True,
|
| 255 |
batch_size=self.embed_batch_size,
|
| 256 |
+
)
|
| 257 |
+
empty_cache()
|
| 258 |
+
self.dataset.add_faiss_index(
|
| 259 |
+
"embeddings", metric_type=faiss.METRIC_INNER_PRODUCT
|
| 260 |
)
|
| 261 |
|
| 262 |
+
def get_filtered_entries(self, idxs):
|
| 263 |
+
# We need to drop the index before filtering/selecting the desired indices and re-add it later
|
| 264 |
+
# Since it's FAISS and we index very little data, it's not noticeable
|
| 265 |
+
self.dataset.drop_index("embeddings")
|
| 266 |
+
entries = self.dataset.select(idxs)
|
| 267 |
+
self.dataset.add_faiss_index(
|
| 268 |
+
"embeddings", metric_type=faiss.METRIC_INNER_PRODUCT
|
| 269 |
+
)
|
| 270 |
+
return entries
|
| 271 |
|
| 272 |
def retrieve(
|
| 273 |
+
self, query, n_variants=3, top_k_per_variant=5, top_k_retrieve=3, **llm_kwargs
|
| 274 |
):
|
|
|
|
|
|
|
| 275 |
queries, tasks = aggregate_queries_and_tasks(
|
| 276 |
self.tok,
|
| 277 |
self.gen,
|
|
|
|
| 279 |
self.prompts["rewrite"],
|
| 280 |
self.prompts["variants"],
|
| 281 |
n_variants,
|
| 282 |
+
self.gen_model_name,
|
| 283 |
**llm_kwargs,
|
| 284 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 285 |
hyde_docs = generate_text(
|
| 286 |
+
self.tok,
|
| 287 |
+
self.gen,
|
| 288 |
+
queries,
|
| 289 |
+
self.prompts["hyde"],
|
| 290 |
+
self.gen_model_name,
|
| 291 |
+
**llm_kwargs,
|
| 292 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
chunks = []
|
| 294 |
for hyde_doc in hyde_docs:
|
| 295 |
chunks.extend(self.text_splitter.split_text(hyde_doc))
|
| 296 |
+
q_emb = self.get_embeddings(chunks)
|
| 297 |
+
matches = self.dataset.get_nearest_examples_batch(
|
| 298 |
+
"embeddings", q_emb, top_k_per_variant
|
| 299 |
)
|
| 300 |
+
indices = [match["id"] for match in matches.total_examples]
|
| 301 |
+
top_idxs = reciprocal_rank_fusion(indices, top_k_retrieve)
|
| 302 |
+
return top_idxs, tasks
|
| 303 |
|
| 304 |
+
def answer(self, query, idxs, tasks, max_ctx_chars=32768):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 305 |
total, text, prompt_length = 0, "", 10000
|
| 306 |
sep = "\n\n-----\n\n"
|
| 307 |
+
tasks = ", ".join(tasks) if tasks else ""
|
| 308 |
+
log.debug("filtering entries")
|
| 309 |
+
entries = self.get_filtered_entries(idxs)
|
| 310 |
+
for chunk in entries["chunk"]:
|
| 311 |
+
ctx = f"{sep}\n\n{chunk}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 312 |
if total + len(ctx) + len(tasks) + len(sep) + prompt_length > max_ctx_chars:
|
| 313 |
+
log.warning("context overflow")
|
| 314 |
break
|
| 315 |
|
| 316 |
text += ctx
|
|
|
|
| 318 |
|
| 319 |
text += f"{sep}{tasks}"
|
| 320 |
|
|
|
|
| 321 |
instruction = "go ahead and answer!"
|
| 322 |
user_query = f"\nq: {query}\n\nctx:{text}" + f"\n\n{instruction}\n\n"
|
|
|
|
|
|
|
| 323 |
resp = generate_text(
|
| 324 |
self.tok,
|
| 325 |
self.gen,
|
| 326 |
user_query,
|
| 327 |
self.prompts["final_answer"],
|
| 328 |
+
self.gen_model_name,
|
| 329 |
)
|
|
|
|
|
|
|
| 330 |
|
| 331 |
+
sources = ""
|
| 332 |
+
for idx, entry in enumerate(entries):
|
| 333 |
+
source = f'<h2 style="color: cyan;">Source {idx + 1} :: {entry["file"]}:{entry["chunk_id"]}</h2>'
|
| 334 |
+
sources += f"{sep}{source}\n\n{entry['chunk']}"
|
| 335 |
|
| 336 |
+
return resp, sources.replace("```", "`")
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
def initial_setup(model_combo_key):
|
| 340 |
+
models = MODEL_COMBOS[model_combo_key]
|
| 341 |
+
hrf = HyDeRAGFusion(models["embed_model"], models["gen_model"])
|
| 342 |
+
hrf._model_combo_key = model_combo_key
|
| 343 |
+
return hrf
|
| 344 |
|
| 345 |
+
|
| 346 |
+
HRF = None
|
|
|
|
|
|
|
|
|
|
| 347 |
|
| 348 |
|
| 349 |
@click.command(context_settings=dict(show_default=True))
|
| 350 |
@click.option(
|
| 351 |
+
"--pdfs",
|
| 352 |
+
multiple=True,
|
| 353 |
+
type=click.Path(exists=True),
|
| 354 |
+
help="list of PDF filepaths to extract text from",
|
| 355 |
+
)
|
| 356 |
+
@click.option("--query", help="user query")
|
| 357 |
+
@click.option(
|
| 358 |
+
"--model-combo-key",
|
| 359 |
+
type=click.Choice(["linux", "HF-mid"]),
|
| 360 |
+
default="linux",
|
| 361 |
+
help="embedder and generator llm models combination to load (see config.py)",
|
| 362 |
)
|
| 363 |
@click.option(
|
| 364 |
+
"--fast-extract/--no-fast-extract",
|
| 365 |
+
default=False,
|
| 366 |
+
help="Extract markdown text quickly (uses pymupdf if set, else docling if available)",
|
| 367 |
)
|
| 368 |
@click.option("--n-variants", default=3, help="no. of query variants")
|
| 369 |
@click.option(
|
|
|
|
| 374 |
@click.option(
|
| 375 |
"--top-k-retrieve", default=3, help="top `k` chunks to retrieve after RRF"
|
| 376 |
)
|
| 377 |
+
@click.option("--temperature", default=0.7, help="LLM Model Temperature")
|
| 378 |
+
def ask(
|
| 379 |
+
pdfs,
|
| 380 |
+
query,
|
| 381 |
+
model_combo_key,
|
| 382 |
+
fast_extract,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 383 |
n_variants,
|
| 384 |
top_k_per_variant,
|
| 385 |
top_k_retrieve,
|
| 386 |
temperature,
|
|
|
|
|
|
|
| 387 |
):
|
| 388 |
+
pdfs = tuple(sorted(pdfs))
|
| 389 |
+
log.debug(
|
| 390 |
+
f"{pdfs=}, {query=}, {model_combo_key=}, {fast_extract=}, {n_variants=}, {top_k_per_variant=}, {top_k_retrieve=}, {temperature=}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 391 |
)
|
| 392 |
+
global HRF
|
| 393 |
+
if HRF is None or HRF._model_combo_key != model_combo_key:
|
| 394 |
+
if HRF is not None:
|
| 395 |
+
log.debug("deleting HRF object")
|
| 396 |
+
del HRF
|
| 397 |
+
log.debug("emptying cache")
|
| 398 |
+
empty_cache()
|
| 399 |
+
log.debug(f"\n\n{'=:-:' * 20}\n\n initializing")
|
| 400 |
+
start = time()
|
| 401 |
+
HRF = initial_setup(model_combo_key)
|
| 402 |
+
end = time()
|
| 403 |
+
msg = f"init took {(end - start):.1f} seconds"
|
| 404 |
+
log.debug(msg)
|
|
|
|
|
|
|
| 405 |
|
| 406 |
+
start = time()
|
| 407 |
+
if pdfs:
|
| 408 |
+
HRF.preprocess_pdfs(pdfs, fast_extract=fast_extract)
|
|
|
|
| 409 |
|
| 410 |
+
if query and query.strip():
|
| 411 |
+
top_idxs, tasks = HRF.retrieve(
|
| 412 |
+
query.strip(),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 413 |
int(n_variants),
|
| 414 |
int(top_k_per_variant),
|
| 415 |
int(top_k_retrieve),
|
| 416 |
+
temperature=temperature,
|
| 417 |
)
|
| 418 |
+
|
| 419 |
+
log.debug("retrieving")
|
| 420 |
+
resp, sources = HRF.answer(query, top_idxs, tasks)
|
| 421 |
end = time()
|
| 422 |
+
final_response = f"\nSearch took {(end - start):.1f} seconds\n\n{resp}{sources}"
|
| 423 |
+
log.debug(final_response)
|
| 424 |
+
return final_response
|
| 425 |
+
|
| 426 |
+
return ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 427 |
|
| 428 |
|
| 429 |
if __name__ == "__main__":
|
| 430 |
+
ask()
|
|
|
|
|
|
|
|
|
|
|
|
src/prompts.yaml
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
rewrite: >
|
| 2 |
You are a professional content writer and editor who deeply pays attention to user's query & intention. You strictly reply ONLY in JSON.
|
| 3 |
|
| 4 |
-
Your mission is to analyze user input and intention and transform it
|
| 5 |
|
| 6 |
The user input can be a query or a statement. There can be multiple of them. And sometimes the input also contains actions to be taken depending on the query/statement.
|
| 7 |
|
|
@@ -109,7 +109,7 @@ rewrite: >
|
|
| 109 |
-----------
|
| 110 |
|
| 111 |
|
| 112 |
-
Remember to
|
| 113 |
|
| 114 |
|
| 115 |
user:
|
|
@@ -117,10 +117,10 @@ rewrite: >
|
|
| 117 |
|
| 118 |
variants: >
|
| 119 |
You are a multilingual professional content writer and editor who deeply pays attention to user's query & intention.
|
| 120 |
-
|
| 121 |
-
Your goal is to transform the given query into diverse search queries keeping the user's context & intention in mind.
|
| 122 |
|
| 123 |
-
|
|
|
|
|
|
|
| 124 |
|
| 125 |
You MUST respond with only what's asked. Avoid explanations or verbose information of your actions.
|
| 126 |
|
|
@@ -130,59 +130,59 @@ variants: >
|
|
| 130 |
--------------
|
| 131 |
|
| 132 |
|
| 133 |
-
user:
|
| 134 |
|
| 135 |
assistant:
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
|
| 142 |
|
| 143 |
-
user:
|
| 144 |
|
| 145 |
assistant:
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
|
| 151 |
|
| 152 |
-
user:
|
| 153 |
|
| 154 |
assistant:
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
|
| 159 |
|
| 160 |
-
user:
|
| 161 |
|
| 162 |
assistant:
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
"Welche Faktoren könnten die Veränderungen des EBITDA beeinflussen?",
|
| 168 |
|
| 169 |
|
| 170 |
-
user:
|
| 171 |
|
| 172 |
assistant:
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
|
| 178 |
|
| 179 |
-
user:
|
| 180 |
|
| 181 |
assistant:
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
|
|
|
| 186 |
|
| 187 |
--------------
|
| 188 |
|
|
@@ -193,14 +193,14 @@ variants: >
|
|
| 193 |
hyde: >
|
| 194 |
You are a professional editor at a prestigious international media organization.
|
| 195 |
Given user's query, write a neutral, self-contained paragraph ABSOLUTELY GROUNDED IN FACTS and established sources. Avoid fluff. Include likely key terms and entities. 120-180 words.
|
| 196 |
-
You write content in the same language as the user query
|
| 197 |
|
| 198 |
Examples:
|
| 199 |
--------
|
| 200 |
|
| 201 |
-
user: Quelle est le niveau actuel de l'engagement de
|
| 202 |
|
| 203 |
-
assistant:
|
| 204 |
|
| 205 |
|
| 206 |
user: BMW Group expansion into southern Asia
|
|
@@ -223,8 +223,8 @@ final_answer: >
|
|
| 223 |
You are a journalist at a media organization. Your main specializations include fact checking, accurate information retrieval from sources among others.
|
| 224 |
|
| 225 |
YOU ALWAYS ADHERE TO THE FOLLOWING INSTRUCTIONS:
|
| 226 |
-
- When given a user query `q` and a context `ctx`, your goal is to answer `q` FROM ONLY WITHIN the given context `ctx
|
| 227 |
-
- You reply in the same language as user
|
| 228 |
- You do not state anything that is not present within `ctx`. NEVER GUESS.
|
| 229 |
- ALWAYS GROUND YOUR TRUTH based only on what was provided within the context `ctx`.
|
| 230 |
- If you believe `q` has nothing to do with `ctx`, simply state "I don't know" (or its equivalent in the user query language) instead of guessing.
|
|
@@ -238,7 +238,7 @@ final_answer: >
|
|
| 238 |
ctx:
|
| 239 |
-----
|
| 240 |
|
| 241 |
-
|
| 242 |
|
| 243 |
go ahead and answer!
|
| 244 |
|
|
@@ -267,8 +267,7 @@ final_answer: >
|
|
| 267 |
Zur Erhöhung der Transparenz und Effektivität kooperiert sie verstärkt mit öffentlichen und privaten Organisationen.
|
| 268 |
Sicherheit ist durch das „Security by Design“-Prinzip fester Bestandteil im Entwicklungsprozess neuer Produkte und Informationssysteme.
|
| 269 |
Es werden intensive und obligatorische digitale Sicherheitstests durchgeführt, um Schwachstellen systematisch aufzudecken.
|
| 270 |
-
Alle sicherheitsrelevanten Abteilungen wurden unter dem Dach der Deutschen Telekom Security zusammengeführt. Mit diesem End-to-End-Sicherheitsportfolio zielt das Unternehmen darauf ab, Marktanteile zu gewinnen und im Rahmen der Megatrends Internet der Dinge und Industrie 4.0 neue Sicherheitskonzepte zu etablieren.
|
| 271 |
-
Zudem wird das Partner-Ökosystem im Bereich Cybersicherheit kontinuierlich ausgebaut, und auf der Unternehmenswebsite wird fortlaufend über aktuelle Entwicklungen in Datenschutz und Datensicherheit berichtet.
|
| 272 |
|
| 273 |
|
| 274 |
=====
|
|
|
|
| 1 |
rewrite: >
|
| 2 |
You are a professional content writer and editor who deeply pays attention to user's query & intention. You strictly reply ONLY in JSON.
|
| 3 |
|
| 4 |
+
Your mission is to analyze user input and intention and transform it **IN THE SAME LANGUAGE** as the input query to make it search engine optimised by determining the appropriate context.
|
| 5 |
|
| 6 |
The user input can be a query or a statement. There can be multiple of them. And sometimes the input also contains actions to be taken depending on the query/statement.
|
| 7 |
|
|
|
|
| 109 |
-----------
|
| 110 |
|
| 111 |
|
| 112 |
+
Remember to **NOT** answer the user's question but only transform it and **IN THE SAME LANGUAGE** as given.
|
| 113 |
|
| 114 |
|
| 115 |
user:
|
|
|
|
| 117 |
|
| 118 |
variants: >
|
| 119 |
You are a multilingual professional content writer and editor who deeply pays attention to user's query & intention.
|
|
|
|
|
|
|
| 120 |
|
| 121 |
+
Your goal is to transform the given query into at least {n} diverse queries as long as they're related to the topic of the original query.
|
| 122 |
+
|
| 123 |
+
You MUST ALWAYS respond **IN THE SAME LANGUAGE** as the user.
|
| 124 |
|
| 125 |
You MUST respond with only what's asked. Avoid explanations or verbose information of your actions.
|
| 126 |
|
|
|
|
| 130 |
--------------
|
| 131 |
|
| 132 |
|
| 133 |
+
user: EBITDA last quarter
|
| 134 |
|
| 135 |
assistant:
|
| 136 |
+
What was the EBITDA for the quarter ending March?
|
| 137 |
+
How has the company's performance changed from the previous quarter?
|
| 138 |
+
What is the current trend of EBITDA growth over the past few quarters?
|
| 139 |
+
Which companies have had similar EBITDA performance recently?
|
| 140 |
+
What factors might be influencing the changes in EBITDA?
|
| 141 |
|
| 142 |
|
| 143 |
+
user: what are the growing concerns of the middle class?
|
| 144 |
|
| 145 |
assistant:
|
| 146 |
+
How are the economic challenges impacting the middle class?
|
| 147 |
+
How's middle class people coping up with social and political pressures?
|
| 148 |
+
What are the long-term implications for the middle class's well-being?
|
| 149 |
+
What are the current trends and future prospects for the middle class?
|
| 150 |
|
| 151 |
|
| 152 |
+
user: Capital of France
|
| 153 |
|
| 154 |
assistant:
|
| 155 |
+
What is the capital city of France?
|
| 156 |
+
How is Paris known internationally?
|
| 157 |
+
Where is the capital of France located?
|
| 158 |
|
| 159 |
|
| 160 |
+
user: Wer ist der aktuelle CEO dieses Unternehmens?
|
| 161 |
|
| 162 |
assistant:
|
| 163 |
+
Wer ist der Geschäftsführer dieses Unternehmens?
|
| 164 |
+
Wie sieht das Organigramm des Unternehmens aus?
|
| 165 |
+
Wann wurde der aktuelle CEO ernannt?
|
| 166 |
+
Wer war in der Vergangenheit CEO dieses Unternehmens?
|
|
|
|
| 167 |
|
| 168 |
|
| 169 |
+
user: Wie reitet man ein Pferd?
|
| 170 |
|
| 171 |
assistant:
|
| 172 |
+
Wo fängt man an, ein Pferd zu lernen?
|
| 173 |
+
Wo kann man Reiten lernen?
|
| 174 |
+
Grundlagen des Reitens
|
| 175 |
+
Wo kann ich Reiten lernen?
|
| 176 |
|
| 177 |
|
| 178 |
+
user: Effets de la gravité dans l'espace
|
| 179 |
|
| 180 |
assistant:
|
| 181 |
+
Comment fonctionne la gravité ?
|
| 182 |
+
Qu'est-ce que la gravité ?
|
| 183 |
+
Quelle sensation procure l'apesanteur dans l'espace ?
|
| 184 |
+
Comment la gravité affecte-t-elle l'espace-temps ?
|
| 185 |
+
L'attraction gravitationnelle des objets dans l'espace
|
| 186 |
|
| 187 |
--------------
|
| 188 |
|
|
|
|
| 193 |
hyde: >
|
| 194 |
You are a professional editor at a prestigious international media organization.
|
| 195 |
Given user's query, write a neutral, self-contained paragraph ABSOLUTELY GROUNDED IN FACTS and established sources. Avoid fluff. Include likely key terms and entities. 120-180 words.
|
| 196 |
+
You MUST ALWAYS write content in the same language as the user query.
|
| 197 |
|
| 198 |
Examples:
|
| 199 |
--------
|
| 200 |
|
| 201 |
+
user: Quelle est le niveau actuel de l'engagement de Lego avec le développement durable
|
| 202 |
|
| 203 |
+
assistant: Lego, la plus grande entreprise d'électricité et d'information au monde, a lancé un programme d'engagement durable en 2015 pour atteindre ses objectifs de développement durable. Ce programme visait à réduire son émission de gaz à effet de serre (CO2) par rapport aux niveaux de 1990, ainsi que les émissions de déchets et de produits chimiques. Le groupe a mis en place des politiques d'énergie basées sur la transition verte, comme la production de biogaz, l'utilisation de technologies solaires et l'innovation dans la gestion des ressources naturelles. L'objectif principal était de se démarquer du marché mondial en termes de performance énergétique et environnementale.
|
| 204 |
|
| 205 |
|
| 206 |
user: BMW Group expansion into southern Asia
|
|
|
|
| 223 |
You are a journalist at a media organization. Your main specializations include fact checking, accurate information retrieval from sources among others.
|
| 224 |
|
| 225 |
YOU ALWAYS ADHERE TO THE FOLLOWING INSTRUCTIONS:
|
| 226 |
+
- When given a user query `q` and a context `ctx`, your goal is to answer `q` FROM ONLY WITHIN the given context `ctx`.
|
| 227 |
+
- You MUST ALWAYS reply in the same language as user's.
|
| 228 |
- You do not state anything that is not present within `ctx`. NEVER GUESS.
|
| 229 |
- ALWAYS GROUND YOUR TRUTH based only on what was provided within the context `ctx`.
|
| 230 |
- If you believe `q` has nothing to do with `ctx`, simply state "I don't know" (or its equivalent in the user query language) instead of guessing.
|
|
|
|
| 238 |
ctx:
|
| 239 |
-----
|
| 240 |
|
| 241 |
+
This will enable us to guarantee transparency and comparability in the validation and measurement of our targets and, at the same time, ensure they are in line with the latest scientific findings. ↗ Carbon emissions ↗ Control parameters such as ↗ carbon emissions over the entire prod - uct life cycle are important ↗ Performance indicators during the de - velopment phase of our vehicle projects. The Board of Manage - ment receives and discusses a status report on sustainability every quarter and derives appropriate measures as required. The BMW Group is actively working on numerous projects and initiatives to improve the framework conditions for electromobil- ity, including the expansion of charging infrastructure on a broad basis. The ambitious goals of the Paris Climate Agreement are designed to tackle climate change in the transport sector, requir - ing a combination of modern drive technologies that are closely aligned with customer needs and different mobility requirements around the world. In addition to all - electric models, plug - in hybrids and modern combustion engine technologies also make an im - portant contribution to the reduction of global CO2 emissions. The BMW Group is also continuously forging ahead with its work with hydrogen. ↗ Products ESG criteria are built into individual market strategies across our global organisation. Best practices in the fields of environmental protection, social sustainability, corporate citizenship and gov
|
| 242 |
|
| 243 |
go ahead and answer!
|
| 244 |
|
|
|
|
| 267 |
Zur Erhöhung der Transparenz und Effektivität kooperiert sie verstärkt mit öffentlichen und privaten Organisationen.
|
| 268 |
Sicherheit ist durch das „Security by Design“-Prinzip fester Bestandteil im Entwicklungsprozess neuer Produkte und Informationssysteme.
|
| 269 |
Es werden intensive und obligatorische digitale Sicherheitstests durchgeführt, um Schwachstellen systematisch aufzudecken.
|
| 270 |
+
Alle sicherheitsrelevanten Abteilungen wurden unter dem Dach der Deutschen Telekom Security zusammengeführt. Mit diesem End-to-End-Sicherheitsportfolio zielt das Unternehmen darauf ab, Marktanteile zu gewinnen und im Rahmen der Megatrends Internet der Dinge und Industrie 4.0 neue Sicherheitskonzepte zu etablieren. Zudem wird das Partner-Ökosystem im Bereich Cybersicherheit kontinuierlich ausgebaut, und auf der Unternehmenswebsite wird fortlaufend über aktuelle Entwicklungen in Datenschutz und Datensicherheit berichtet.
|
|
|
|
| 271 |
|
| 272 |
|
| 273 |
=====
|
src/utils.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gc
|
| 2 |
+
import json
|
| 3 |
+
import torch
|
| 4 |
+
import pymupdf, pymupdf4llm
|
| 5 |
+
|
| 6 |
+
from ast import literal_eval
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from datasets import Dataset
|
| 9 |
+
from collections import defaultdict
|
| 10 |
+
from docling.document_converter import DocumentConverter
|
| 11 |
+
|
| 12 |
+
from src.config import log
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def empty_cache():
|
| 17 |
+
gc.collect()
|
| 18 |
+
if torch.cuda.is_available():
|
| 19 |
+
torch.cuda.empty_cache()
|
| 20 |
+
elif torch.mps.is_available():
|
| 21 |
+
torch.mps.empty_cache()
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def extract_text(file, markdown=False, backend="pymupdf", **kwargs):
|
| 25 |
+
if backend == "pymupdf":
|
| 26 |
+
if not markdown:
|
| 27 |
+
with pymupdf.open(file, filetype="pdf") as doc:
|
| 28 |
+
return "\n".join(page.get_text(**kwargs) for page in doc)
|
| 29 |
+
else:
|
| 30 |
+
log.debug("\n\n using pymupdf4llm \n\n")
|
| 31 |
+
return pymupdf4llm.to_markdown(file, show_progress=True, **kwargs)
|
| 32 |
+
|
| 33 |
+
elif backend == "docling":
|
| 34 |
+
converter = DocumentConverter(allowed_formats=["pdf"])
|
| 35 |
+
doc = converter.convert(file, **kwargs).document
|
| 36 |
+
res = doc.export_to_markdown() if markdown else doc.export_to_text()
|
| 37 |
+
del converter, doc
|
| 38 |
+
empty_cache()
|
| 39 |
+
return res
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _load_pdf_sync(file, markdown=True, fast=False, **kwargs):
|
| 43 |
+
"""Synchronous PDF loading function for thread pool execution"""
|
| 44 |
+
text = extract_text(
|
| 45 |
+
file,
|
| 46 |
+
markdown,
|
| 47 |
+
backend="docling"
|
| 48 |
+
if ((not fast) and (torch.cuda.is_available() or torch.mps.is_available()))
|
| 49 |
+
else "pymupdf",
|
| 50 |
+
**kwargs,
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
return (Path(file).stem, text)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def load_pdfs(files, markdown=True, fast_extract=False, **kwargs):
|
| 57 |
+
"""
|
| 58 |
+
Load multiple PDF files
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
files: PDF filepaths
|
| 62 |
+
markdown: whether to extract text in markdown
|
| 63 |
+
fast_extract: whether to use pymupdf to extract text in markdown
|
| 64 |
+
Returns:
|
| 65 |
+
list: List of tuples containing (filename, extracted_text)
|
| 66 |
+
"""
|
| 67 |
+
# # Use ThreadPoolExecutor to run synchronous operations concurrently
|
| 68 |
+
# loop = asyncio.get_event_loop()
|
| 69 |
+
|
| 70 |
+
# # Create executor with limited workers
|
| 71 |
+
# with ThreadPoolExecutor(max_workers=max_concurrence) as executor:
|
| 72 |
+
# # Submit all PDF processing tasks
|
| 73 |
+
# futures = [
|
| 74 |
+
# loop.run_in_executor(executor, _load_pdf_sync, file, markdown, fast_extract, **kwargs) for file in files if file is not None
|
| 75 |
+
# ]
|
| 76 |
+
|
| 77 |
+
# results = await asyncio.gather(*futures, return_exceptions=True)
|
| 78 |
+
|
| 79 |
+
# valid_results = [result for result in results if not isinstance(result, Exception)]
|
| 80 |
+
|
| 81 |
+
# log.debug(f"Successfully processed {len(valid_results)} out of {len(files)} PDFs")
|
| 82 |
+
# return valid_results
|
| 83 |
+
|
| 84 |
+
results = []
|
| 85 |
+
for file in files:
|
| 86 |
+
results.append(_load_pdf_sync(file, markdown, fast_extract, **kwargs))
|
| 87 |
+
return results
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def find_think_tag_in_each_row(tensor):
|
| 91 |
+
# look for `</think>` tag
|
| 92 |
+
res = dict((tensor == 151668).nonzero().tolist())
|
| 93 |
+
if not res:
|
| 94 |
+
return [0] * len(tensor)
|
| 95 |
+
idxs = []
|
| 96 |
+
for idx in range(len(tensor)):
|
| 97 |
+
idxs.append(res.get(idx, -1))
|
| 98 |
+
return [x + 1 for x in idxs]
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def build_corpus(pdfs, text_splitter, **load_pdf_kwargs):
|
| 102 |
+
texts = load_pdfs(pdfs, **load_pdf_kwargs)
|
| 103 |
+
corpus_with_meta = []
|
| 104 |
+
_id = 0
|
| 105 |
+
for file_name, raw_text in texts:
|
| 106 |
+
chunks = text_splitter.split_text(raw_text)
|
| 107 |
+
for idx, chunk in enumerate(chunks):
|
| 108 |
+
corpus_with_meta.append(
|
| 109 |
+
{
|
| 110 |
+
"id": _id,
|
| 111 |
+
"file": Path(file_name).stem,
|
| 112 |
+
"chunk_id": idx,
|
| 113 |
+
"chunk": chunk,
|
| 114 |
+
}
|
| 115 |
+
)
|
| 116 |
+
_id += 1
|
| 117 |
+
return Dataset.from_list(corpus_with_meta)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def reciprocal_rank_fusion(indices, top_k=3, denom=50):
|
| 121 |
+
scores = defaultdict(int)
|
| 122 |
+
for row in indices:
|
| 123 |
+
for rank, idx in enumerate(row):
|
| 124 |
+
scores[idx] += 1 / (rank + denom)
|
| 125 |
+
results = sorted(scores.items(), key=lambda x: x[1], reverse=True)[:top_k]
|
| 126 |
+
return [idx for idx, _ in results]
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def clean_rewrite_resp(resp):
|
| 130 |
+
try:
|
| 131 |
+
resp = json.loads(resp) # Parse JSON
|
| 132 |
+
except json.JSONDecodeError:
|
| 133 |
+
try:
|
| 134 |
+
resp = literal_eval(resp) # Fallback parse
|
| 135 |
+
except Exception:
|
| 136 |
+
pass # Keep resp as-is if both fail
|
| 137 |
+
|
| 138 |
+
# Ensure resp is a string before strip and slicing
|
| 139 |
+
if isinstance(resp, str):
|
| 140 |
+
resp = resp.strip()
|
| 141 |
+
if resp:
|
| 142 |
+
start = resp.find("{")
|
| 143 |
+
if start != -1:
|
| 144 |
+
end = resp[::-1].find("}")
|
| 145 |
+
if end != -1:
|
| 146 |
+
resp = resp[start : len(resp) - end]
|
| 147 |
+
return clean_rewrite_resp(resp)
|
| 148 |
+
return resp
|