Spaces:
Sleeping
Sleeping
Commit ·
46bf894
1
Parent(s): e2f7b9c
added hybrid retriever
Browse files- app/embedding/vectore_store.py +20 -9
- app/retrieval/retriever.py +27 -13
- app/schemas/metadata_schema.py +3 -3
- app/services/RAG_service.py +26 -19
- app/utils/metadata_utils.py +37 -25
app/embedding/vectore_store.py
CHANGED
|
@@ -5,12 +5,12 @@ from pinecone import ServerlessSpec
|
|
| 5 |
from langchain_pinecone import PineconeVectorStore
|
| 6 |
from datetime import datetime
|
| 7 |
from uuid import uuid4
|
| 8 |
-
|
| 9 |
class VectorStore:
|
| 10 |
def __init__(self, text_chunks, embedding_model):
|
| 11 |
self.text_chunks = text_chunks
|
| 12 |
self.current_time = datetime.now()
|
| 13 |
self.embedding_model = embedding_model
|
|
|
|
| 14 |
|
| 15 |
def create_vectorestore(self):
|
| 16 |
load_dotenv()
|
|
@@ -18,23 +18,34 @@ class VectorStore:
|
|
| 18 |
pc = Pinecone(api_key=pinecone_key)
|
| 19 |
# pc._vector_api.api_client.pool_threads = 1
|
| 20 |
time_string = self.current_time.strftime("%Y-%m-%d-%H-%M")
|
| 21 |
-
index_name =
|
|
|
|
| 22 |
if not pc.has_index(index_name):
|
| 23 |
pc.create_index(
|
| 24 |
-
name
|
| 25 |
-
dimension=
|
| 26 |
metric="cosine",
|
| 27 |
-
spec
|
| 28 |
)
|
| 29 |
|
| 30 |
index = pc.Index(index_name)
|
| 31 |
# model_loader = ModelLoader(model_provider="openai")
|
| 32 |
# embedding_model = model_loader.load_llm()
|
| 33 |
uuids = [str(uuid4()) for _ in range(len(self.text_chunks)) ]
|
| 34 |
-
vector_store = PineconeVectorStore(index = index, embedding=self.embedding_model)
|
| 35 |
-
name_space = f"hackrx-index{time_string}"
|
| 36 |
-
vector_store.add_documents(documents=self.text_chunks, ids = uuids,namespace = name_space )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
-
return index,
|
| 39 |
|
| 40 |
|
|
|
|
| 5 |
from langchain_pinecone import PineconeVectorStore
|
| 6 |
from datetime import datetime
|
| 7 |
from uuid import uuid4
|
|
|
|
| 8 |
class VectorStore:
|
| 9 |
def __init__(self, text_chunks, embedding_model):
|
| 10 |
self.text_chunks = text_chunks
|
| 11 |
self.current_time = datetime.now()
|
| 12 |
self.embedding_model = embedding_model
|
| 13 |
+
# self.index, self.namespace, self.retriever = self.create_vectorestore()
|
| 14 |
|
| 15 |
def create_vectorestore(self):
|
| 16 |
load_dotenv()
|
|
|
|
| 18 |
pc = Pinecone(api_key=pinecone_key)
|
| 19 |
# pc._vector_api.api_client.pool_threads = 1
|
| 20 |
time_string = self.current_time.strftime("%Y-%m-%d-%H-%M")
|
| 21 |
+
index_name = "rag-project"
|
| 22 |
+
namespace = f"rag-project{time_string}"
|
| 23 |
if not pc.has_index(index_name):
|
| 24 |
pc.create_index(
|
| 25 |
+
name=index_name,
|
| 26 |
+
dimension=1024,
|
| 27 |
metric="cosine",
|
| 28 |
+
spec=ServerlessSpec(cloud="aws", region="us-east-1")
|
| 29 |
)
|
| 30 |
|
| 31 |
index = pc.Index(index_name)
|
| 32 |
# model_loader = ModelLoader(model_provider="openai")
|
| 33 |
# embedding_model = model_loader.load_llm()
|
| 34 |
uuids = [str(uuid4()) for _ in range(len(self.text_chunks)) ]
|
| 35 |
+
# vector_store = PineconeVectorStore.from_documents(index = index, embedding=self.embedding_model)
|
| 36 |
+
# name_space = f"hackrx-index{time_string}"
|
| 37 |
+
# vector_store.add_documents(documents=self.text_chunks, ids = uuids,namespace = name_space )
|
| 38 |
+
# retriever = vector_store.as_retriever(
|
| 39 |
+
# search_type="similarity",
|
| 40 |
+
# search_kwargs={"k": 5},
|
| 41 |
+
# )
|
| 42 |
+
vector_store = PineconeVectorStore.from_documents(documents=self.text_chunks,index_name=index_name, embedding=self.embedding_model, namespace = namespace)
|
| 43 |
+
# vector_store.add_documents(documents=docs, ids=uuids)
|
| 44 |
+
# retriever = vector_store.as_retriever(
|
| 45 |
+
# search_type="similarity",
|
| 46 |
+
# search_kwargs={"k": 5,"namespace": namespace}
|
| 47 |
+
# )
|
| 48 |
|
| 49 |
+
return index, namespace, vector_store
|
| 50 |
|
| 51 |
|
app/retrieval/retriever.py
CHANGED
|
@@ -1,14 +1,26 @@
|
|
| 1 |
-
|
|
|
|
| 2 |
|
| 3 |
class Retriever:
|
| 4 |
-
def __init__(self, pinecone_index, query = None, metadata = None, namespace=None):
|
| 5 |
self.pinecone_index = pinecone_index
|
| 6 |
self.query = query
|
| 7 |
self.metadata = metadata
|
| 8 |
self.namespace = namespace
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
-
def retrieval_from_pinecone_vectoreStore(self
|
| 12 |
"""
|
| 13 |
Retrieve the top matching chunks from Pinecone.
|
| 14 |
|
|
@@ -21,14 +33,14 @@ class Retriever:
|
|
| 21 |
Returns:
|
| 22 |
List of ClauseHit objects (lightweight container for chunk info).
|
| 23 |
"""
|
| 24 |
-
res = self.pinecone_index.query(
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
|
| 33 |
# Process the results into the expected format
|
| 34 |
# class ClauseHit:
|
|
@@ -51,5 +63,7 @@ class Retriever:
|
|
| 51 |
# score=match['score']
|
| 52 |
# ))
|
| 53 |
# return hits
|
| 54 |
-
|
| 55 |
-
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from langchain.retrievers import EnsembleRetriever
|
| 3 |
|
| 4 |
class Retriever:
|
| 5 |
+
def __init__(self, pinecone_index, query = None, metadata = None, namespace=None, vectore_store = None,sparse_retriever = None, llm = None):
|
| 6 |
self.pinecone_index = pinecone_index
|
| 7 |
self.query = query
|
| 8 |
self.metadata = metadata
|
| 9 |
self.namespace = namespace
|
| 10 |
+
self.vector_store = vectore_store
|
| 11 |
+
self.sparse_retriever = sparse_retriever
|
| 12 |
+
self.llm = llm
|
| 13 |
+
self.dense_retriever = self.vector_store.as_retriever(
|
| 14 |
+
search_type="similarity",
|
| 15 |
+
search_kwargs={"k": 5,"namespace": self.namespace, "filter": self.metadata}
|
| 16 |
+
)
|
| 17 |
+
self.hybrid_retriever = EnsembleRetriever(
|
| 18 |
+
retrievers=[self.dense_retriever, sparse_retriever], # Use .retriever attribute
|
| 19 |
+
weights=[0.7, 0.3] # Fix: 'weights' not 'weight'
|
| 20 |
+
)
|
| 21 |
|
| 22 |
|
| 23 |
+
def retrieval_from_pinecone_vectoreStore(self):
|
| 24 |
"""
|
| 25 |
Retrieve the top matching chunks from Pinecone.
|
| 26 |
|
|
|
|
| 33 |
Returns:
|
| 34 |
List of ClauseHit objects (lightweight container for chunk info).
|
| 35 |
"""
|
| 36 |
+
# res = self.pinecone_index.query(
|
| 37 |
+
# vector= self.query,
|
| 38 |
+
# top_k =top_k ,
|
| 39 |
+
# include_metadata = True,
|
| 40 |
+
# include_values = False,
|
| 41 |
+
# filter = self.metadata,
|
| 42 |
+
# namespace = self.namespace
|
| 43 |
+
# )
|
| 44 |
|
| 45 |
# Process the results into the expected format
|
| 46 |
# class ClauseHit:
|
|
|
|
| 63 |
# score=match['score']
|
| 64 |
# ))
|
| 65 |
# return hits
|
| 66 |
+
results = self.hybrid_retriever.invoke(self.query)
|
| 67 |
+
for doc in results:
|
| 68 |
+
print(f"printing Doc content : {doc.page_content}")
|
| 69 |
+
return results
|
app/schemas/metadata_schema.py
CHANGED
|
@@ -3,7 +3,7 @@ from typing import List, Dict, Any, Optional, Union, Literal
|
|
| 3 |
|
| 4 |
class CommonMetaData(BaseModel):
|
| 5 |
# --- Common metadata (across all domains) ---
|
| 6 |
-
doc_id: Optional[List[str]] = Field(None, description="Unique document identifier")
|
| 7 |
doc_category: Optional[List[str]] = Field(None, description="General pool/category e.g. Insurance, HR, Legal")
|
| 8 |
doc_type: Optional[List[str]] = Field(None, description="Specific type e.g. Policy doc, Contract, Handbook")
|
| 9 |
jurisdiction: Optional[List[str]] = Field(
|
|
@@ -17,7 +17,7 @@ class CommonMetaData(BaseModel):
|
|
| 17 |
# description="List of short, normalized obligation keywords (2–5 words each, no full sentences)"
|
| 18 |
# )
|
| 19 |
penalties: Optional[List[str]] = Field(None, description="Penalties/non-compliance consequences")
|
| 20 |
-
notes: Optional[List[str]] = Field(None, description="Freeform additional metadata")
|
| 21 |
# added_new_keyword: bool = False
|
| 22 |
added_new_keyword: bool = True
|
| 23 |
class InsuranceMetadata(CommonMetaData):
|
|
@@ -28,7 +28,7 @@ class InsuranceMetadata(CommonMetaData):
|
|
| 28 |
default=None,
|
| 29 |
description="Type(s) of coverage. Short keywords (1–3 words each)."
|
| 30 |
)
|
| 31 |
-
premium_amount: Optional[List[str]] = None
|
| 32 |
exclusions: Optional[List[str]] = Field(
|
| 33 |
description="List of normalized keywords representing exclusions (short, 2-5 words each, not full paragraphs).", default=None
|
| 34 |
)
|
|
|
|
| 3 |
|
| 4 |
class CommonMetaData(BaseModel):
|
| 5 |
# --- Common metadata (across all domains) ---
|
| 6 |
+
# doc_id: Optional[List[str]] = Field(None, description="Unique document identifier")
|
| 7 |
doc_category: Optional[List[str]] = Field(None, description="General pool/category e.g. Insurance, HR, Legal")
|
| 8 |
doc_type: Optional[List[str]] = Field(None, description="Specific type e.g. Policy doc, Contract, Handbook")
|
| 9 |
jurisdiction: Optional[List[str]] = Field(
|
|
|
|
| 17 |
# description="List of short, normalized obligation keywords (2–5 words each, no full sentences)"
|
| 18 |
# )
|
| 19 |
penalties: Optional[List[str]] = Field(None, description="Penalties/non-compliance consequences")
|
| 20 |
+
# notes: Optional[List[str]] = Field(None, description="Freeform additional metadata")
|
| 21 |
# added_new_keyword: bool = False
|
| 22 |
added_new_keyword: bool = True
|
| 23 |
class InsuranceMetadata(CommonMetaData):
|
|
|
|
| 28 |
default=None,
|
| 29 |
description="Type(s) of coverage. Short keywords (1–3 words each)."
|
| 30 |
)
|
| 31 |
+
# premium_amount: Optional[List[str]] = None
|
| 32 |
exclusions: Optional[List[str]] = Field(
|
| 33 |
description="List of normalized keywords representing exclusions (short, 2-5 words each, not full paragraphs).", default=None
|
| 34 |
)
|
app/services/RAG_service.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
from typing import List
|
| 2 |
from app.utils.model_loader import ModelLoader
|
| 3 |
from app.ingestion.file_loader import FileLoader
|
| 4 |
from app.ingestion.text_splitter import splitting_text
|
|
@@ -7,11 +6,10 @@ from app.embedding.embeder import QueryEmbedding
|
|
| 7 |
from app.embedding.vectore_store import VectorStore
|
| 8 |
from app.metadata_extraction.metadata_ext import MetadataExtractor
|
| 9 |
from app.utils.metadata_utils import MetadataService
|
| 10 |
-
# from app.utils.document_op import DocumentOperation
|
| 11 |
from langchain_core.documents import Document
|
| 12 |
import json
|
| 13 |
-
from
|
| 14 |
-
|
| 15 |
|
| 16 |
# Global model instances (loaded once)
|
| 17 |
_embedding_model = None
|
|
@@ -118,29 +116,38 @@ class RAGService:
|
|
| 118 |
|
| 119 |
def create_vector_store(self):
|
| 120 |
print("[RAGService] Creating vector store...")
|
| 121 |
-
self.
|
| 122 |
-
self.index, self.namespace = self.
|
| 123 |
print(f"[RAGService] Vector store created. Index: {self.index}, Namespace: {self.namespace}")
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
-
|
|
|
|
|
|
|
| 126 |
print("[RAGService] Retrieving documents from vector store...")
|
| 127 |
-
self.
|
|
|
|
|
|
|
| 128 |
self.result = self.retriever.retrieval_from_pinecone_vectoreStore()
|
| 129 |
-
|
| 130 |
-
|
|
|
|
| 131 |
def answer_query(self, raw_query:str) -> str:
|
| 132 |
"""Answer user query using retrieved documents and LLM"""
|
| 133 |
print(f"[RAGService] Answering query: {raw_query}")
|
| 134 |
-
top_clause = self.result['matches']
|
| 135 |
-
top_clause_dicts = [r.to_dict() for r in top_clause]
|
| 136 |
-
self.top_clauses = top_clause_dicts
|
| 137 |
-
keys_to_remove = {"file_path", "source", "producer", "keywords", "subject", "added_new_keyword", "author", "chunk_id"}
|
| 138 |
-
for r in top_clause_dicts:
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
|
| 143 |
-
context_clauses = json.dumps(top_clause_dicts, separators=(",", ":"))
|
|
|
|
| 144 |
|
| 145 |
print(f"context_clauses: {context_clauses}")
|
| 146 |
|
|
|
|
|
|
|
| 1 |
from app.utils.model_loader import ModelLoader
|
| 2 |
from app.ingestion.file_loader import FileLoader
|
| 3 |
from app.ingestion.text_splitter import splitting_text
|
|
|
|
| 6 |
from app.embedding.vectore_store import VectorStore
|
| 7 |
from app.metadata_extraction.metadata_ext import MetadataExtractor
|
| 8 |
from app.utils.metadata_utils import MetadataService
|
|
|
|
| 9 |
from langchain_core.documents import Document
|
| 10 |
import json
|
| 11 |
+
from langchain_community.retrievers import BM25Retriever
|
| 12 |
+
from langchain.schema import Document
|
| 13 |
|
| 14 |
# Global model instances (loaded once)
|
| 15 |
_embedding_model = None
|
|
|
|
| 116 |
|
| 117 |
def create_vector_store(self):
|
| 118 |
print("[RAGService] Creating vector store...")
|
| 119 |
+
self.vector_store_class_instance = VectorStore(self.chunks, self.embedding_model)
|
| 120 |
+
self.index, self.namespace, self.vector_store = self.vector_store_class_instance.create_vectorestore()
|
| 121 |
print(f"[RAGService] Vector store created. Index: {self.index}, Namespace: {self.namespace}")
|
| 122 |
+
### Sparse Retriever(BM25)
|
| 123 |
+
self.sparse_retriever=BM25Retriever.from_documents(self.chunks)
|
| 124 |
+
self.sparse_retriever.k=3 ##top- k documents to retriever
|
| 125 |
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def retrive_documents(self, raw_query: str):
|
| 129 |
print("[RAGService] Retrieving documents from vector store...")
|
| 130 |
+
self.create_query_embedding(raw_query)
|
| 131 |
+
|
| 132 |
+
self.retriever = Retriever(self.index,raw_query,self.query_metadata, self.namespace, self.vector_store,sparse_retriever = self.sparse_retriever,llm = self.llm)
|
| 133 |
self.result = self.retriever.retrieval_from_pinecone_vectoreStore()
|
| 134 |
+
# self.result = self.retriever.invoke(raw_query)
|
| 135 |
+
# print(f"[RAGService] Retrieval result: {self.result}")
|
| 136 |
+
|
| 137 |
def answer_query(self, raw_query:str) -> str:
|
| 138 |
"""Answer user query using retrieved documents and LLM"""
|
| 139 |
print(f"[RAGService] Answering query: {raw_query}")
|
| 140 |
+
# top_clause = self.result['matches']
|
| 141 |
+
# top_clause_dicts = [r.to_dict() for r in top_clause]
|
| 142 |
+
# self.top_clauses = top_clause_dicts
|
| 143 |
+
# keys_to_remove = {"file_path", "source", "producer", "keywords", "subject", "added_new_keyword", "author", "chunk_id"}
|
| 144 |
+
# for r in top_clause_dicts:
|
| 145 |
+
# meta = r.get("metadata", {})
|
| 146 |
+
# for k in keys_to_remove:
|
| 147 |
+
# meta.pop(k, None)
|
| 148 |
|
| 149 |
+
# context_clauses = json.dumps(top_clause_dicts, separators=(",", ":"))
|
| 150 |
+
context_clauses = [doc.page_content for doc in self.result]
|
| 151 |
|
| 152 |
print(f"context_clauses: {context_clauses}")
|
| 153 |
|
app/utils/metadata_utils.py
CHANGED
|
@@ -57,39 +57,51 @@ class MetadataService:
|
|
| 57 |
return normalized
|
| 58 |
|
| 59 |
@staticmethod
|
| 60 |
-
def cosine_similarity(
|
| 61 |
-
|
| 62 |
-
vector2 = embedding_model.embed_query(text2)
|
| 63 |
-
cosine_similarity = np.dot(vector1, vector2) / (np.linalg.norm(vector1) * np.linalg.norm(vector2))
|
| 64 |
-
return cosine_similarity
|
| 65 |
|
| 66 |
@staticmethod
|
| 67 |
def keyword_sementic_check(result, data, embedding_model):
|
| 68 |
-
|
| 69 |
-
# result = result.model_dump()
|
| 70 |
-
# data = json.load(open(data, 'r'))
|
| 71 |
-
# Compare all keys present in both result and data, and check if any value in result[key] is present in data[key]
|
| 72 |
for key in result.keys():
|
| 73 |
-
print(f"Comparing key: {key}",flush=True)
|
| 74 |
-
|
| 75 |
if result[key] is not None and data.get(key) is not None:
|
| 76 |
-
print(f"result[{key}]: {result[key]}",flush=True)
|
| 77 |
-
print(f"data[{key}]: {data[key]}",flush=True)
|
| 78 |
-
|
| 79 |
if isinstance(result[key], list) and isinstance(data[key], list):
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
else:
|
| 85 |
-
print(f"'{val}' NOT found in data['{key}']")
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
|
|
|
|
|
|
| 89 |
if similarity > 0.90:
|
| 90 |
-
print(f"'{val}' is similar to '{data_val}' with similarity {similarity}",flush=True)
|
| 91 |
-
## if similarity is greater than 0.90, then consider it as matched and replace the value in result with data value
|
| 92 |
result[key][idx] = data_val
|
|
|
|
| 93 |
else:
|
| 94 |
-
print(f"'{val}' is NOT similar to '{data_val}' with similarity {similarity}",flush=True)
|
|
|
|
| 95 |
return result
|
|
|
|
| 57 |
return normalized
|
| 58 |
|
| 59 |
@staticmethod
|
| 60 |
+
def cosine_similarity(vec1, vec2) -> float:
|
| 61 |
+
return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
@staticmethod
|
| 64 |
def keyword_sementic_check(result, data, embedding_model):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
for key in result.keys():
|
| 66 |
+
print(f"Comparing key: {key}", flush=True)
|
| 67 |
+
|
| 68 |
if result[key] is not None and data.get(key) is not None:
|
| 69 |
+
print(f"result[{key}]: {result[key]}", flush=True)
|
| 70 |
+
print(f"data[{key}]: {data[key]}", flush=True)
|
| 71 |
+
|
| 72 |
if isinstance(result[key], list) and isinstance(data[key], list):
|
| 73 |
+
# Filter to only strings
|
| 74 |
+
data_list = [v for v in data[key] if isinstance(v, str)]
|
| 75 |
+
val_list = [v for v in result[key] if isinstance(v, str)]
|
| 76 |
+
data_set = set(data_list)
|
| 77 |
+
|
| 78 |
+
if not data_list or not val_list:
|
| 79 |
+
print(f"Skipping key '{key}' due to empty valid strings.")
|
| 80 |
+
continue
|
| 81 |
+
|
| 82 |
+
# Precompute embeddings for data_list
|
| 83 |
+
data_embeddings = {val: embedding_model.embed_query(val) for val in data_list}
|
| 84 |
+
|
| 85 |
+
# Precompute embeddings for val_list
|
| 86 |
+
val_embeddings_list = embedding_model.embed_documents(val_list)
|
| 87 |
+
|
| 88 |
+
for idx, val in enumerate(val_list):
|
| 89 |
+
print(f"Comparing value: {val}", flush=True)
|
| 90 |
+
|
| 91 |
+
if val in data_set:
|
| 92 |
+
print(f"'{val}' found in data['{key}']", flush=True)
|
| 93 |
else:
|
| 94 |
+
print(f"'{val}' NOT found in data['{key}']", flush=True)
|
| 95 |
+
val_vector = val_embeddings_list[idx]
|
| 96 |
+
|
| 97 |
+
for data_val, data_vector in data_embeddings.items():
|
| 98 |
+
similarity = MetadataService.cosine_similarity(val_vector, data_vector)
|
| 99 |
+
print(f"Cosine similarity between '{val}' and '{data_val}': {similarity}", flush=True)
|
| 100 |
if similarity > 0.90:
|
| 101 |
+
print(f"'{val}' is similar to '{data_val}' with similarity {similarity}", flush=True)
|
|
|
|
| 102 |
result[key][idx] = data_val
|
| 103 |
+
break
|
| 104 |
else:
|
| 105 |
+
print(f"'{val}' is NOT similar to '{data_val}' with similarity {similarity}", flush=True)
|
| 106 |
+
|
| 107 |
return result
|