Spaces:
Sleeping
Sleeping
Commit ·
36623c8
1
Parent(s): 44fcc42
Add cache
Browse files- benchmark.py +5 -4
- main.py +25 -14
- requirements old.txt +91 -0
- requirements.txt +30 -2
- retrieval_pipeline/cache.py +94 -0
- retrieval_pipeline/hybrid_search.py +15 -4
benchmark.py
CHANGED
|
@@ -7,6 +7,7 @@ TOP_N = 3
|
|
| 7 |
|
| 8 |
def get_benchmark_result(path, retriever):
|
| 9 |
df = pd.read_csv(path)
|
|
|
|
| 10 |
retrieval_result = []
|
| 11 |
query_result = [[] for i in range(TOP_N)]
|
| 12 |
retrieval_latency = []
|
|
@@ -21,13 +22,13 @@ def get_benchmark_result(path, retriever):
|
|
| 21 |
t0 = time.time()
|
| 22 |
results = retriever.get_relevant_documents(query)
|
| 23 |
t = time.time() - t0
|
| 24 |
-
retrieval_latency.append(t)
|
| 25 |
|
| 26 |
result_content = [result.page_content for result in results]
|
| 27 |
# results_content = get_relevant_documents(query, retriever, top_k=5)
|
| 28 |
|
| 29 |
for i, text in enumerate(result_content):
|
| 30 |
-
query_result[i]
|
| 31 |
|
| 32 |
if target in result_content:
|
| 33 |
retrieval_result.append("Success")
|
|
@@ -37,10 +38,10 @@ def get_benchmark_result(path, retriever):
|
|
| 37 |
# break
|
| 38 |
|
| 39 |
df["retrieval_result"] = retrieval_result
|
| 40 |
-
df["retrieval_latency"] = retrieval_latency
|
| 41 |
for i in range(TOP_N):
|
| 42 |
df[f'q{i+1}'] = query_result[i]
|
| 43 |
-
|
| 44 |
print(df['retrieval_result'].value_counts())
|
| 45 |
print(df['retrieval_result'].value_counts()/ len(df))
|
| 46 |
|
|
|
|
| 7 |
|
| 8 |
def get_benchmark_result(path, retriever):
|
| 9 |
df = pd.read_csv(path)
|
| 10 |
+
|
| 11 |
retrieval_result = []
|
| 12 |
query_result = [[] for i in range(TOP_N)]
|
| 13 |
retrieval_latency = []
|
|
|
|
| 22 |
t0 = time.time()
|
| 23 |
results = retriever.get_relevant_documents(query)
|
| 24 |
t = time.time() - t0
|
| 25 |
+
retrieval_latency.append(str(t))
|
| 26 |
|
| 27 |
result_content = [result.page_content for result in results]
|
| 28 |
# results_content = get_relevant_documents(query, retriever, top_k=5)
|
| 29 |
|
| 30 |
for i, text in enumerate(result_content):
|
| 31 |
+
query_result[i].append(text)
|
| 32 |
|
| 33 |
if target in result_content:
|
| 34 |
retrieval_result.append("Success")
|
|
|
|
| 38 |
# break
|
| 39 |
|
| 40 |
df["retrieval_result"] = retrieval_result
|
| 41 |
+
df["retrieval_latency"] = retrieval_latency
|
| 42 |
for i in range(TOP_N):
|
| 43 |
df[f'q{i+1}'] = query_result[i]
|
| 44 |
+
df.to_csv('benchmark_result.csv')
|
| 45 |
print(df['retrieval_result'].value_counts())
|
| 46 |
print(df['retrieval_result'].value_counts()/ len(df))
|
| 47 |
|
main.py
CHANGED
|
@@ -1,10 +1,13 @@
|
|
| 1 |
from dotenv import load_dotenv
|
| 2 |
import json
|
| 3 |
-
import os
|
| 4 |
import uuid
|
| 5 |
|
| 6 |
from retrieval_pipeline import get_retriever, get_compression_retriever
|
| 7 |
import benchmark
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
load_dotenv()
|
| 10 |
ELASTICSEARCH_URL = os.getenv('ELASTICSEARCH_URL')
|
|
@@ -16,18 +19,26 @@ print(ELASTICSEARCH_URL)
|
|
| 16 |
if __name__ == "__main__":
|
| 17 |
retriever = get_retriever(index='masa.ai', elasticsearch_url=ELASTICSEARCH_URL)
|
| 18 |
compression_retriever = get_compression_retriever(retriever)
|
|
|
|
|
|
|
|
|
|
| 19 |
retrieved_chunks = compression_retriever.get_relevant_documents('Gunung Semeru')
|
| 20 |
print(retrieved_chunks)
|
| 21 |
-
|
| 22 |
-
#
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from dotenv import load_dotenv
|
| 2 |
import json
|
| 3 |
+
import os, time
|
| 4 |
import uuid
|
| 5 |
|
| 6 |
from retrieval_pipeline import get_retriever, get_compression_retriever
|
| 7 |
import benchmark
|
| 8 |
+
from retrieval_pipeline.hybrid_search import store
|
| 9 |
+
|
| 10 |
+
from retrieval_pipeline.cache import SemanticCache
|
| 11 |
|
| 12 |
load_dotenv()
|
| 13 |
ELASTICSEARCH_URL = os.getenv('ELASTICSEARCH_URL')
|
|
|
|
| 19 |
if __name__ == "__main__":
|
| 20 |
retriever = get_retriever(index='masa.ai', elasticsearch_url=ELASTICSEARCH_URL)
|
| 21 |
compression_retriever = get_compression_retriever(retriever)
|
| 22 |
+
|
| 23 |
+
semantic_cache_retriever = SemanticCache(compression_retriever)
|
| 24 |
+
|
| 25 |
retrieved_chunks = compression_retriever.get_relevant_documents('Gunung Semeru')
|
| 26 |
print(retrieved_chunks)
|
| 27 |
+
|
| 28 |
+
# benchmark.get_benchmark_result("benchmark-reranker.csv", retriever=compression_retriever)
|
| 29 |
+
|
| 30 |
+
for i in range(100):
|
| 31 |
+
query = input("query: ")
|
| 32 |
+
t0 = time.time()
|
| 33 |
+
# retrieved_chunks = compression_retriever.get_relevant_documents(query)
|
| 34 |
+
retrieved_chunks = semantic_cache_retriever.get_relevant_documents(query)
|
| 35 |
+
|
| 36 |
+
t = time.time() - t0
|
| 37 |
+
|
| 38 |
+
print(list(store.yield_keys()))
|
| 39 |
+
print('time:', t)
|
| 40 |
+
|
| 41 |
+
print("Result:")
|
| 42 |
+
for r in retrieved_chunks:
|
| 43 |
+
print(r.page_content[:50])
|
| 44 |
+
print()
|
requirements old.txt
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
aiohttp==3.9.5
|
| 2 |
+
aiolimiter==1.1.0
|
| 3 |
+
aiosignal==1.3.1
|
| 4 |
+
altair==5.3.0
|
| 5 |
+
annotated-types==0.6.0
|
| 6 |
+
async-timeout==4.0.3
|
| 7 |
+
attrs==23.2.0
|
| 8 |
+
blinker==1.8.2
|
| 9 |
+
cachetools==5.3.3
|
| 10 |
+
certifi==2024.2.2
|
| 11 |
+
charset-normalizer==3.3.2
|
| 12 |
+
click==8.1.7
|
| 13 |
+
colorama==0.4.6
|
| 14 |
+
Cython==3.0.10
|
| 15 |
+
dataclasses-json==0.6.6
|
| 16 |
+
elastic-transport==8.13.0
|
| 17 |
+
elasticsearch==8.13.1
|
| 18 |
+
filelock==3.14.0
|
| 19 |
+
frozenlist==1.4.1
|
| 20 |
+
fsspec==2024.3.1
|
| 21 |
+
gitdb==4.0.11
|
| 22 |
+
GitPython==3.1.43
|
| 23 |
+
greenlet==3.0.3
|
| 24 |
+
huggingface-hub==0.23.0
|
| 25 |
+
idna==3.7
|
| 26 |
+
intel-openmp==2021.4.0
|
| 27 |
+
Jinja2==3.1.4
|
| 28 |
+
joblib==1.4.2
|
| 29 |
+
jsonpatch==1.33
|
| 30 |
+
jsonpointer==2.4
|
| 31 |
+
jsonschema==4.22.0
|
| 32 |
+
jsonschema-specifications==2023.12.1
|
| 33 |
+
langchain==0.1.20
|
| 34 |
+
langchain-community==0.0.38
|
| 35 |
+
langchain-core==0.1.52
|
| 36 |
+
langchain-text-splitters==0.0.1
|
| 37 |
+
langsmith==0.1.57
|
| 38 |
+
markdown-it-py==3.0.0
|
| 39 |
+
MarkupSafe==2.1.5
|
| 40 |
+
marshmallow==3.21.2
|
| 41 |
+
mdurl==0.1.2
|
| 42 |
+
mkl==2021.4.0
|
| 43 |
+
mpmath==1.3.0
|
| 44 |
+
multidict==6.0.5
|
| 45 |
+
mypy-extensions==1.0.0
|
| 46 |
+
networkx==3.2.1
|
| 47 |
+
numpy==1.26.4
|
| 48 |
+
orjson==3.10.3
|
| 49 |
+
packaging==23.2
|
| 50 |
+
pandas==2.2.2
|
| 51 |
+
pillow==10.3.0
|
| 52 |
+
protobuf==4.25.3
|
| 53 |
+
pyarrow==16.1.0
|
| 54 |
+
pydantic==2.7.1
|
| 55 |
+
pydantic_core==2.18.2
|
| 56 |
+
pydeck==0.9.1
|
| 57 |
+
Pygments==2.18.0
|
| 58 |
+
python-dateutil==2.9.0.post0
|
| 59 |
+
python-dotenv==1.0.1
|
| 60 |
+
pytz==2024.1
|
| 61 |
+
PyYAML==6.0.1
|
| 62 |
+
referencing==0.35.1
|
| 63 |
+
regex==2024.5.10
|
| 64 |
+
requests==2.31.0
|
| 65 |
+
rich==13.7.1
|
| 66 |
+
rpds-py==0.18.1
|
| 67 |
+
safetensors==0.4.3
|
| 68 |
+
scikit-learn==1.4.2
|
| 69 |
+
scipy==1.13.0
|
| 70 |
+
sentence-transformers==2.7.0
|
| 71 |
+
six==1.16.0
|
| 72 |
+
smmap==5.0.1
|
| 73 |
+
SQLAlchemy==2.0.30
|
| 74 |
+
streamlit==1.34.0
|
| 75 |
+
sympy==1.12
|
| 76 |
+
tbb==2021.12.0
|
| 77 |
+
tenacity==8.3.0
|
| 78 |
+
threadpoolctl==3.5.0
|
| 79 |
+
tokenizers==0.19.1
|
| 80 |
+
toml==0.10.2
|
| 81 |
+
toolz==0.12.1
|
| 82 |
+
torch==2.3.0
|
| 83 |
+
tornado==6.4
|
| 84 |
+
tqdm==4.66.4
|
| 85 |
+
transformers==4.40.2
|
| 86 |
+
typing-inspect==0.9.0
|
| 87 |
+
typing_extensions==4.11.0
|
| 88 |
+
tzdata==2024.1
|
| 89 |
+
urllib3==2.2.1
|
| 90 |
+
watchdog==4.0.0
|
| 91 |
+
yarl==1.9.4
|
requirements.txt
CHANGED
|
@@ -3,6 +3,7 @@ aiolimiter==1.1.0
|
|
| 3 |
aiosignal==1.3.1
|
| 4 |
altair==5.3.0
|
| 5 |
annotated-types==0.6.0
|
|
|
|
| 6 |
async-timeout==4.0.3
|
| 7 |
attrs==23.2.0
|
| 8 |
blinker==1.8.2
|
|
@@ -11,17 +12,31 @@ certifi==2024.2.2
|
|
| 11 |
charset-normalizer==3.3.2
|
| 12 |
click==8.1.7
|
| 13 |
colorama==0.4.6
|
|
|
|
| 14 |
Cython==3.0.10
|
| 15 |
dataclasses-json==0.6.6
|
| 16 |
elastic-transport==8.13.0
|
| 17 |
elasticsearch==8.13.1
|
|
|
|
|
|
|
|
|
|
| 18 |
filelock==3.14.0
|
|
|
|
| 19 |
frozenlist==1.4.1
|
| 20 |
fsspec==2024.3.1
|
| 21 |
gitdb==4.0.11
|
| 22 |
GitPython==3.1.43
|
| 23 |
greenlet==3.0.3
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
idna==3.7
|
| 26 |
intel-openmp==2021.4.0
|
| 27 |
Jinja2==3.1.4
|
|
@@ -35,6 +50,8 @@ langchain-community==0.0.38
|
|
| 35 |
langchain-core==0.1.52
|
| 36 |
langchain-text-splitters==0.0.1
|
| 37 |
langsmith==0.1.57
|
|
|
|
|
|
|
| 38 |
markdown-it-py==3.0.0
|
| 39 |
MarkupSafe==2.1.5
|
| 40 |
marshmallow==3.21.2
|
|
@@ -44,21 +61,29 @@ mpmath==1.3.0
|
|
| 44 |
multidict==6.0.5
|
| 45 |
mypy-extensions==1.0.0
|
| 46 |
networkx==3.2.1
|
|
|
|
| 47 |
numpy==1.26.4
|
|
|
|
|
|
|
| 48 |
orjson==3.10.3
|
| 49 |
packaging==23.2
|
| 50 |
pandas==2.2.2
|
| 51 |
pillow==10.3.0
|
| 52 |
-
|
|
|
|
| 53 |
pyarrow==16.1.0
|
| 54 |
pydantic==2.7.1
|
| 55 |
pydantic_core==2.18.2
|
| 56 |
pydeck==0.9.1
|
| 57 |
Pygments==2.18.0
|
|
|
|
| 58 |
python-dateutil==2.9.0.post0
|
| 59 |
python-dotenv==1.0.1
|
| 60 |
pytz==2024.1
|
|
|
|
| 61 |
PyYAML==6.0.1
|
|
|
|
|
|
|
| 62 |
referencing==0.35.1
|
| 63 |
regex==2024.5.10
|
| 64 |
requests==2.31.0
|
|
@@ -67,9 +92,11 @@ rpds-py==0.18.1
|
|
| 67 |
safetensors==0.4.3
|
| 68 |
scikit-learn==1.4.2
|
| 69 |
scipy==1.13.0
|
|
|
|
| 70 |
sentence-transformers==2.7.0
|
| 71 |
six==1.16.0
|
| 72 |
smmap==5.0.1
|
|
|
|
| 73 |
SQLAlchemy==2.0.30
|
| 74 |
streamlit==1.34.0
|
| 75 |
sympy==1.12
|
|
@@ -88,4 +115,5 @@ typing_extensions==4.11.0
|
|
| 88 |
tzdata==2024.1
|
| 89 |
urllib3==2.2.1
|
| 90 |
watchdog==4.0.0
|
|
|
|
| 91 |
yarl==1.9.4
|
|
|
|
| 3 |
aiosignal==1.3.1
|
| 4 |
altair==5.3.0
|
| 5 |
annotated-types==0.6.0
|
| 6 |
+
anyio==4.3.0
|
| 7 |
async-timeout==4.0.3
|
| 8 |
attrs==23.2.0
|
| 9 |
blinker==1.8.2
|
|
|
|
| 12 |
charset-normalizer==3.3.2
|
| 13 |
click==8.1.7
|
| 14 |
colorama==0.4.6
|
| 15 |
+
coloredlogs==15.0.1
|
| 16 |
Cython==3.0.10
|
| 17 |
dataclasses-json==0.6.6
|
| 18 |
elastic-transport==8.13.0
|
| 19 |
elasticsearch==8.13.1
|
| 20 |
+
exceptiongroup==1.2.1
|
| 21 |
+
faiss-cpu==1.8.0
|
| 22 |
+
fastembed==0.2.6
|
| 23 |
filelock==3.14.0
|
| 24 |
+
flatbuffers==24.3.25
|
| 25 |
frozenlist==1.4.1
|
| 26 |
fsspec==2024.3.1
|
| 27 |
gitdb==4.0.11
|
| 28 |
GitPython==3.1.43
|
| 29 |
greenlet==3.0.3
|
| 30 |
+
grpcio==1.63.0
|
| 31 |
+
grpcio-tools==1.63.0
|
| 32 |
+
h11==0.14.0
|
| 33 |
+
h2==4.1.0
|
| 34 |
+
hpack==4.0.0
|
| 35 |
+
httpcore==1.0.5
|
| 36 |
+
httpx==0.27.0
|
| 37 |
+
huggingface-hub==0.20.3
|
| 38 |
+
humanfriendly==10.0
|
| 39 |
+
hyperframe==6.0.1
|
| 40 |
idna==3.7
|
| 41 |
intel-openmp==2021.4.0
|
| 42 |
Jinja2==3.1.4
|
|
|
|
| 50 |
langchain-core==0.1.52
|
| 51 |
langchain-text-splitters==0.0.1
|
| 52 |
langsmith==0.1.57
|
| 53 |
+
llvmlite==0.42.0
|
| 54 |
+
loguru==0.7.2
|
| 55 |
markdown-it-py==3.0.0
|
| 56 |
MarkupSafe==2.1.5
|
| 57 |
marshmallow==3.21.2
|
|
|
|
| 61 |
multidict==6.0.5
|
| 62 |
mypy-extensions==1.0.0
|
| 63 |
networkx==3.2.1
|
| 64 |
+
numba==0.59.1
|
| 65 |
numpy==1.26.4
|
| 66 |
+
onnx==1.16.0
|
| 67 |
+
onnxruntime==1.17.3
|
| 68 |
orjson==3.10.3
|
| 69 |
packaging==23.2
|
| 70 |
pandas==2.2.2
|
| 71 |
pillow==10.3.0
|
| 72 |
+
portalocker==2.8.2
|
| 73 |
+
protobuf==5.26.1
|
| 74 |
pyarrow==16.1.0
|
| 75 |
pydantic==2.7.1
|
| 76 |
pydantic_core==2.18.2
|
| 77 |
pydeck==0.9.1
|
| 78 |
Pygments==2.18.0
|
| 79 |
+
pyreadline3==3.4.1
|
| 80 |
python-dateutil==2.9.0.post0
|
| 81 |
python-dotenv==1.0.1
|
| 82 |
pytz==2024.1
|
| 83 |
+
pywin32==306
|
| 84 |
PyYAML==6.0.1
|
| 85 |
+
qdrant-client==1.9.1
|
| 86 |
+
rankerEval==0.2.0
|
| 87 |
referencing==0.35.1
|
| 88 |
regex==2024.5.10
|
| 89 |
requests==2.31.0
|
|
|
|
| 92 |
safetensors==0.4.3
|
| 93 |
scikit-learn==1.4.2
|
| 94 |
scipy==1.13.0
|
| 95 |
+
semantic-cache==0.1.1
|
| 96 |
sentence-transformers==2.7.0
|
| 97 |
six==1.16.0
|
| 98 |
smmap==5.0.1
|
| 99 |
+
sniffio==1.3.1
|
| 100 |
SQLAlchemy==2.0.30
|
| 101 |
streamlit==1.34.0
|
| 102 |
sympy==1.12
|
|
|
|
| 115 |
tzdata==2024.1
|
| 116 |
urllib3==2.2.1
|
| 117 |
watchdog==4.0.0
|
| 118 |
+
win32-setctime==1.1.0
|
| 119 |
yarl==1.9.4
|
retrieval_pipeline/cache.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import faiss
|
| 2 |
+
from sentence_transformers import SentenceTransformer
|
| 3 |
+
import time
|
| 4 |
+
import json
|
| 5 |
+
|
| 6 |
+
from langchain_core.documents import Document
|
| 7 |
+
|
| 8 |
+
def init_cache():
|
| 9 |
+
index = faiss.IndexFlatL2(1024)
|
| 10 |
+
if index.is_trained:
|
| 11 |
+
print("Index trained")
|
| 12 |
+
|
| 13 |
+
# Initialize Sentence Transformer model
|
| 14 |
+
encoder = SentenceTransformer("multilingual-e5-large")
|
| 15 |
+
|
| 16 |
+
return index, encoder
|
| 17 |
+
|
| 18 |
+
def retrieve_cache(json_file):
|
| 19 |
+
try:
|
| 20 |
+
with open(json_file, "r") as file:
|
| 21 |
+
cache = json.load(file)
|
| 22 |
+
except FileNotFoundError:
|
| 23 |
+
cache = {"query": [], "embeddings": [], "answers": [], "response_text": []}
|
| 24 |
+
|
| 25 |
+
return cache
|
| 26 |
+
|
| 27 |
+
def store_cache(json_file, cache):
|
| 28 |
+
with open(json_file, "w") as file:
|
| 29 |
+
json.dump(cache, file)
|
| 30 |
+
|
| 31 |
+
class SemanticCache:
|
| 32 |
+
def __init__(self, retriever, json_file="cache_file.json", thresold=0.35):
|
| 33 |
+
# Initialize Faiss index with Euclidean distance
|
| 34 |
+
self.retriever = retriever
|
| 35 |
+
self.index, self.encoder = init_cache()
|
| 36 |
+
|
| 37 |
+
# Set Euclidean distance threshold
|
| 38 |
+
# a distance of 0 means identicals sentences
|
| 39 |
+
# We only return from cache sentences under this thresold
|
| 40 |
+
self.euclidean_threshold = thresold
|
| 41 |
+
|
| 42 |
+
self.json_file = json_file
|
| 43 |
+
self.cache = retrieve_cache(self.json_file)
|
| 44 |
+
|
| 45 |
+
def query_database(self, query_text):
|
| 46 |
+
results = self.retriever.get_relevant_documents(query_text)
|
| 47 |
+
return results
|
| 48 |
+
|
| 49 |
+
def get_relevant_documents(self, query: str) -> str:
|
| 50 |
+
# Method to retrieve an answer from the cache or generate a new one
|
| 51 |
+
start_time = time.time()
|
| 52 |
+
# try:
|
| 53 |
+
# First we obtain the embeddings corresponding to the user query
|
| 54 |
+
embedding = self.encoder.encode([query])
|
| 55 |
+
|
| 56 |
+
# Search for the nearest neighbor in the index
|
| 57 |
+
self.index.nprobe = 8
|
| 58 |
+
D, I = self.index.search(embedding, 1)
|
| 59 |
+
|
| 60 |
+
if D[0] >= 0:
|
| 61 |
+
if I[0][0] >= 0 and D[0][0] <= self.euclidean_threshold:
|
| 62 |
+
row_id = int(I[0][0])
|
| 63 |
+
|
| 64 |
+
print("Answer recovered from Cache. ")
|
| 65 |
+
print(f"{D[0][0]:.3f} smaller than {self.euclidean_threshold}")
|
| 66 |
+
print(f"Found cache in row: {row_id} with score {D[0][0]:.3f}")
|
| 67 |
+
|
| 68 |
+
end_time = time.time()
|
| 69 |
+
elapsed_time = end_time - start_time
|
| 70 |
+
print(f"Time taken: {elapsed_time:.3f} seconds")
|
| 71 |
+
return [Document(**doc[k]) for doc in self.cache["answers"][row_id]]
|
| 72 |
+
|
| 73 |
+
# Handle the case when there are not enough results
|
| 74 |
+
# or Euclidean distance is not met, asking to chromaDB.
|
| 75 |
+
answer = self.query_database(query)
|
| 76 |
+
# response_text = answer["documents"][0][0]
|
| 77 |
+
|
| 78 |
+
self.cache["query"].append(query)
|
| 79 |
+
self.cache["embeddings"].append(embedding[0].tolist())
|
| 80 |
+
self.cache["answers"].append([doc.__dict__ for doc in answer])
|
| 81 |
+
# self.cache["response_text"].append(response_text)
|
| 82 |
+
|
| 83 |
+
print("Answer recovered from ChromaDB. ")
|
| 84 |
+
# print(f"response_text: {response_text}")
|
| 85 |
+
|
| 86 |
+
self.index.add(embedding)
|
| 87 |
+
store_cache(self.json_file, self.cache)
|
| 88 |
+
end_time = time.time()
|
| 89 |
+
elapsed_time = end_time - start_time
|
| 90 |
+
print(f"Time taken: {elapsed_time:.3f} seconds")
|
| 91 |
+
|
| 92 |
+
return answer
|
| 93 |
+
# except Exception as e:
|
| 94 |
+
# raise RuntimeError(f"Error during 'get_relevant_documents' method: {e}")
|
retrieval_pipeline/hybrid_search.py
CHANGED
|
@@ -9,6 +9,10 @@ import elasticsearch
|
|
| 9 |
|
| 10 |
from typing import Optional, List
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
class HybridRetriever(BaseRetriever):
|
| 14 |
dense_db: ElasticVectorSearch
|
|
@@ -68,10 +72,6 @@ class HybridRetriever(BaseRetriever):
|
|
| 68 |
|
| 69 |
# Combine results (you'll need a strategy here)
|
| 70 |
combined_results = dense_results + sparse_results
|
| 71 |
-
# result_text = [doc.page_content for doc in combined_results]
|
| 72 |
-
|
| 73 |
-
# reranked_result = rerank.rerank(query, documents=result_text, model="rerank-lite-1", top_k=self.top_k_dense+self.top_k_sparse)
|
| 74 |
-
# reranked_result = sorted(reranked_result.results, key=lambda result: result.index)
|
| 75 |
|
| 76 |
# Create LangChain Documents
|
| 77 |
documents = [Document(page_content=doc.page_content, metadata=doc.metadata) for doc in combined_results]
|
|
@@ -82,10 +82,21 @@ class HybridRetriever(BaseRetriever):
|
|
| 82 |
raise NotImplementedError
|
| 83 |
|
| 84 |
def get_dense_db(elasticsearch_url, index_dense, embeddings):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
dense_db = ElasticVectorSearch(
|
| 86 |
elasticsearch_url=elasticsearch_url,
|
| 87 |
index_name=index_dense,
|
| 88 |
embedding=embeddings,
|
|
|
|
| 89 |
)
|
| 90 |
return dense_db
|
| 91 |
|
|
|
|
| 9 |
|
| 10 |
from typing import Optional, List
|
| 11 |
|
| 12 |
+
from langchain.storage import LocalFileStore
|
| 13 |
+
from langchain.embeddings import CacheBackedEmbeddings
|
| 14 |
+
|
| 15 |
+
store = LocalFileStore("cache")
|
| 16 |
|
| 17 |
class HybridRetriever(BaseRetriever):
|
| 18 |
dense_db: ElasticVectorSearch
|
|
|
|
| 72 |
|
| 73 |
# Combine results (you'll need a strategy here)
|
| 74 |
combined_results = dense_results + sparse_results
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
# Create LangChain Documents
|
| 77 |
documents = [Document(page_content=doc.page_content, metadata=doc.metadata) for doc in combined_results]
|
|
|
|
| 82 |
raise NotImplementedError
|
| 83 |
|
| 84 |
def get_dense_db(elasticsearch_url, index_dense, embeddings):
|
| 85 |
+
# retriever cache
|
| 86 |
+
cached_embedder = CacheBackedEmbeddings.from_bytes_store(
|
| 87 |
+
embeddings, store,
|
| 88 |
+
namespace='sentence-transformer',
|
| 89 |
+
# query_embedding_store=store,
|
| 90 |
+
# query_embedding_cache=True
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
cached_embedder.query_embedding_store = store
|
| 94 |
+
|
| 95 |
dense_db = ElasticVectorSearch(
|
| 96 |
elasticsearch_url=elasticsearch_url,
|
| 97 |
index_name=index_dense,
|
| 98 |
embedding=embeddings,
|
| 99 |
+
# embedding=cached_embedder,
|
| 100 |
)
|
| 101 |
return dense_db
|
| 102 |
|