Kshitijk20 committed on
Commit
46bf894
·
1 Parent(s): e2f7b9c

added hybrid retriever

Browse files
app/embedding/vectore_store.py CHANGED
@@ -5,12 +5,12 @@ from pinecone import ServerlessSpec
5
  from langchain_pinecone import PineconeVectorStore
6
  from datetime import datetime
7
  from uuid import uuid4
8
-
9
  class VectorStore:
10
  def __init__(self, text_chunks, embedding_model):
11
  self.text_chunks = text_chunks
12
  self.current_time = datetime.now()
13
  self.embedding_model = embedding_model
 
14
 
15
  def create_vectorestore(self):
16
  load_dotenv()
@@ -18,23 +18,34 @@ class VectorStore:
18
  pc = Pinecone(api_key=pinecone_key)
19
  # pc._vector_api.api_client.pool_threads = 1
20
  time_string = self.current_time.strftime("%Y-%m-%d-%H-%M")
21
- index_name = f"rag-project"
 
22
  if not pc.has_index(index_name):
23
  pc.create_index(
24
- name = index_name,
25
- dimension=1536,
26
  metric="cosine",
27
- spec = ServerlessSpec(cloud="aws", region="us-east-1")
28
  )
29
 
30
  index = pc.Index(index_name)
31
  # model_loader = ModelLoader(model_provider="openai")
32
  # embedding_model = model_loader.load_llm()
33
  uuids = [str(uuid4()) for _ in range(len(self.text_chunks)) ]
34
- vector_store = PineconeVectorStore(index = index, embedding=self.embedding_model)
35
- name_space = f"hackrx-index{time_string}"
36
- vector_store.add_documents(documents=self.text_chunks, ids = uuids,namespace = name_space )
 
 
 
 
 
 
 
 
 
 
37
 
38
- return index, name_space
39
 
40
 
 
5
  from langchain_pinecone import PineconeVectorStore
6
  from datetime import datetime
7
  from uuid import uuid4
 
8
  class VectorStore:
9
  def __init__(self, text_chunks, embedding_model):
10
  self.text_chunks = text_chunks
11
  self.current_time = datetime.now()
12
  self.embedding_model = embedding_model
13
+ # self.index, self.namespace, self.retriever = self.create_vectorestore()
14
 
15
  def create_vectorestore(self):
16
  load_dotenv()
 
18
  pc = Pinecone(api_key=pinecone_key)
19
  # pc._vector_api.api_client.pool_threads = 1
20
  time_string = self.current_time.strftime("%Y-%m-%d-%H-%M")
21
+ index_name = "rag-project"
22
+ namespace = f"rag-project{time_string}"
23
  if not pc.has_index(index_name):
24
  pc.create_index(
25
+ name=index_name,
26
+ dimension=1024,
27
  metric="cosine",
28
+ spec=ServerlessSpec(cloud="aws", region="us-east-1")
29
  )
30
 
31
  index = pc.Index(index_name)
32
  # model_loader = ModelLoader(model_provider="openai")
33
  # embedding_model = model_loader.load_llm()
34
  uuids = [str(uuid4()) for _ in range(len(self.text_chunks)) ]
35
+ # vector_store = PineconeVectorStore.from_documents(index = index, embedding=self.embedding_model)
36
+ # name_space = f"hackrx-index{time_string}"
37
+ # vector_store.add_documents(documents=self.text_chunks, ids = uuids,namespace = name_space )
38
+ # retriever = vector_store.as_retriever(
39
+ # search_type="similarity",
40
+ # search_kwargs={"k": 5},
41
+ # )
42
+ vector_store = PineconeVectorStore.from_documents(documents=self.text_chunks,index_name=index_name, embedding=self.embedding_model, namespace = namespace)
43
+ # vector_store.add_documents(documents=docs, ids=uuids)
44
+ # retriever = vector_store.as_retriever(
45
+ # search_type="similarity",
46
+ # search_kwargs={"k": 5,"namespace": namespace}
47
+ # )
48
 
49
+ return index, namespace, vector_store
50
 
51
 
app/retrieval/retriever.py CHANGED
@@ -1,14 +1,26 @@
1
- # from app.schemas.request_models import ClauseHit
 
2
 
3
  class Retriever:
4
- def __init__(self, pinecone_index, query = None, metadata = None, namespace=None):
5
  self.pinecone_index = pinecone_index
6
  self.query = query
7
  self.metadata = metadata
8
  self.namespace = namespace
 
 
 
 
 
 
 
 
 
 
 
9
 
10
 
11
- def retrieval_from_pinecone_vectoreStore(self, top_k= 3):
12
  """
13
  Retrieve the top matching chunks from Pinecone.
14
 
@@ -21,14 +33,14 @@ class Retriever:
21
  Returns:
22
  List of ClauseHit objects (lightweight container for chunk info).
23
  """
24
- res = self.pinecone_index.query(
25
- vector= self.query,
26
- top_k =top_k ,
27
- include_metadata = True,
28
- include_values = False,
29
- filter = self.metadata,
30
- namespace = self.namespace
31
- )
32
 
33
  # Process the results into the expected format
34
  # class ClauseHit:
@@ -51,5 +63,7 @@ class Retriever:
51
  # score=match['score']
52
  # ))
53
  # return hits
54
- return res
55
-
 
 
 
1
+
2
+ from langchain.retrievers import EnsembleRetriever
3
 
4
  class Retriever:
5
+ def __init__(self, pinecone_index, query = None, metadata = None, namespace=None, vectore_store = None,sparse_retriever = None, llm = None):
6
  self.pinecone_index = pinecone_index
7
  self.query = query
8
  self.metadata = metadata
9
  self.namespace = namespace
10
+ self.vector_store = vectore_store
11
+ self.sparse_retriever = sparse_retriever
12
+ self.llm = llm
13
+ self.dense_retriever = self.vector_store.as_retriever(
14
+ search_type="similarity",
15
+ search_kwargs={"k": 5,"namespace": self.namespace, "filter": self.metadata}
16
+ )
17
+ self.hybrid_retriever = EnsembleRetriever(
18
+ retrievers=[self.dense_retriever, sparse_retriever], # Use .retriever attribute
19
+ weights=[0.7, 0.3] # Fix: 'weights' not 'weight'
20
+ )
21
 
22
 
23
+ def retrieval_from_pinecone_vectoreStore(self):
24
  """
25
  Retrieve the top matching chunks from Pinecone.
26
 
 
33
  Returns:
34
  List of ClauseHit objects (lightweight container for chunk info).
35
  """
36
+ # res = self.pinecone_index.query(
37
+ # vector= self.query,
38
+ # top_k =top_k ,
39
+ # include_metadata = True,
40
+ # include_values = False,
41
+ # filter = self.metadata,
42
+ # namespace = self.namespace
43
+ # )
44
 
45
  # Process the results into the expected format
46
  # class ClauseHit:
 
63
  # score=match['score']
64
  # ))
65
  # return hits
66
+ results = self.hybrid_retriever.invoke(self.query)
67
+ for doc in results:
68
+ print(f"printing Doc content : {doc.page_content}")
69
+ return results
app/schemas/metadata_schema.py CHANGED
@@ -3,7 +3,7 @@ from typing import List, Dict, Any, Optional, Union, Literal
3
 
4
  class CommonMetaData(BaseModel):
5
  # --- Common metadata (across all domains) ---
6
- doc_id: Optional[List[str]] = Field(None, description="Unique document identifier")
7
  doc_category: Optional[List[str]] = Field(None, description="General pool/category e.g. Insurance, HR, Legal")
8
  doc_type: Optional[List[str]] = Field(None, description="Specific type e.g. Policy doc, Contract, Handbook")
9
  jurisdiction: Optional[List[str]] = Field(
@@ -17,7 +17,7 @@ class CommonMetaData(BaseModel):
17
  # description="List of short, normalized obligation keywords (2–5 words each, no full sentences)"
18
  # )
19
  penalties: Optional[List[str]] = Field(None, description="Penalties/non-compliance consequences")
20
- notes: Optional[List[str]] = Field(None, description="Freeform additional metadata")
21
  # added_new_keyword: bool = False
22
  added_new_keyword: bool = True
23
  class InsuranceMetadata(CommonMetaData):
@@ -28,7 +28,7 @@ class InsuranceMetadata(CommonMetaData):
28
  default=None,
29
  description="Type(s) of coverage. Short keywords (1–3 words each)."
30
  )
31
- premium_amount: Optional[List[str]] = None
32
  exclusions: Optional[List[str]] = Field(
33
  description="List of normalized keywords representing exclusions (short, 2-5 words each, not full paragraphs).", default=None
34
  )
 
3
 
4
  class CommonMetaData(BaseModel):
5
  # --- Common metadata (across all domains) ---
6
+ # doc_id: Optional[List[str]] = Field(None, description="Unique document identifier")
7
  doc_category: Optional[List[str]] = Field(None, description="General pool/category e.g. Insurance, HR, Legal")
8
  doc_type: Optional[List[str]] = Field(None, description="Specific type e.g. Policy doc, Contract, Handbook")
9
  jurisdiction: Optional[List[str]] = Field(
 
17
  # description="List of short, normalized obligation keywords (2–5 words each, no full sentences)"
18
  # )
19
  penalties: Optional[List[str]] = Field(None, description="Penalties/non-compliance consequences")
20
+ # notes: Optional[List[str]] = Field(None, description="Freeform additional metadata")
21
  # added_new_keyword: bool = False
22
  added_new_keyword: bool = True
23
  class InsuranceMetadata(CommonMetaData):
 
28
  default=None,
29
  description="Type(s) of coverage. Short keywords (1–3 words each)."
30
  )
31
+ # premium_amount: Optional[List[str]] = None
32
  exclusions: Optional[List[str]] = Field(
33
  description="List of normalized keywords representing exclusions (short, 2-5 words each, not full paragraphs).", default=None
34
  )
app/services/RAG_service.py CHANGED
@@ -1,4 +1,3 @@
1
- from typing import List
2
  from app.utils.model_loader import ModelLoader
3
  from app.ingestion.file_loader import FileLoader
4
  from app.ingestion.text_splitter import splitting_text
@@ -7,11 +6,10 @@ from app.embedding.embeder import QueryEmbedding
7
  from app.embedding.vectore_store import VectorStore
8
  from app.metadata_extraction.metadata_ext import MetadataExtractor
9
  from app.utils.metadata_utils import MetadataService
10
- # from app.utils.document_op import DocumentOperation
11
  from langchain_core.documents import Document
12
  import json
13
- from typing import List, Optional
14
- # ...existing imports...
15
 
16
  # Global model instances (loaded once)
17
  _embedding_model = None
@@ -118,29 +116,38 @@ class RAGService:
118
 
119
  def create_vector_store(self):
120
  print("[RAGService] Creating vector store...")
121
- self.vector_store = VectorStore(self.chunks, self.embedding_model)
122
- self.index, self.namespace = self.vector_store.create_vectorestore()
123
  print(f"[RAGService] Vector store created. Index: {self.index}, Namespace: {self.namespace}")
 
 
 
124
 
125
- def retrive_documents(self):
 
 
126
  print("[RAGService] Retrieving documents from vector store...")
127
- self.retriever = Retriever(self.index,self.query_embedding,self.query_metadata, self.namespace)
 
 
128
  self.result = self.retriever.retrieval_from_pinecone_vectoreStore()
129
- print(f"[RAGService] Retrieval result: {self.result}")
130
-
 
131
  def answer_query(self, raw_query:str) -> str:
132
  """Answer user query using retrieved documents and LLM"""
133
  print(f"[RAGService] Answering query: {raw_query}")
134
- top_clause = self.result['matches']
135
- top_clause_dicts = [r.to_dict() for r in top_clause]
136
- self.top_clauses = top_clause_dicts
137
- keys_to_remove = {"file_path", "source", "producer", "keywords", "subject", "added_new_keyword", "author", "chunk_id"}
138
- for r in top_clause_dicts:
139
- meta = r.get("metadata", {})
140
- for k in keys_to_remove:
141
- meta.pop(k, None)
142
 
143
- context_clauses = json.dumps(top_clause_dicts, separators=(",", ":"))
 
144
 
145
  print(f"context_clauses: {context_clauses}")
146
 
 
 
1
  from app.utils.model_loader import ModelLoader
2
  from app.ingestion.file_loader import FileLoader
3
  from app.ingestion.text_splitter import splitting_text
 
6
  from app.embedding.vectore_store import VectorStore
7
  from app.metadata_extraction.metadata_ext import MetadataExtractor
8
  from app.utils.metadata_utils import MetadataService
 
9
  from langchain_core.documents import Document
10
  import json
11
+ from langchain_community.retrievers import BM25Retriever
12
+ from langchain.schema import Document
13
 
14
  # Global model instances (loaded once)
15
  _embedding_model = None
 
116
 
117
  def create_vector_store(self):
118
  print("[RAGService] Creating vector store...")
119
+ self.vector_store_class_instance = VectorStore(self.chunks, self.embedding_model)
120
+ self.index, self.namespace, self.vector_store = self.vector_store_class_instance.create_vectorestore()
121
  print(f"[RAGService] Vector store created. Index: {self.index}, Namespace: {self.namespace}")
122
+ ### Sparse Retriever(BM25)
123
+ self.sparse_retriever=BM25Retriever.from_documents(self.chunks)
124
+ self.sparse_retriever.k=3 ##top- k documents to retriever
125
 
126
+
127
+
128
+ def retrive_documents(self, raw_query: str):
129
  print("[RAGService] Retrieving documents from vector store...")
130
+ self.create_query_embedding(raw_query)
131
+
132
+ self.retriever = Retriever(self.index,raw_query,self.query_metadata, self.namespace, self.vector_store,sparse_retriever = self.sparse_retriever,llm = self.llm)
133
  self.result = self.retriever.retrieval_from_pinecone_vectoreStore()
134
+ # self.result = self.retriever.invoke(raw_query)
135
+ # print(f"[RAGService] Retrieval result: {self.result}")
136
+
137
  def answer_query(self, raw_query:str) -> str:
138
  """Answer user query using retrieved documents and LLM"""
139
  print(f"[RAGService] Answering query: {raw_query}")
140
+ # top_clause = self.result['matches']
141
+ # top_clause_dicts = [r.to_dict() for r in top_clause]
142
+ # self.top_clauses = top_clause_dicts
143
+ # keys_to_remove = {"file_path", "source", "producer", "keywords", "subject", "added_new_keyword", "author", "chunk_id"}
144
+ # for r in top_clause_dicts:
145
+ # meta = r.get("metadata", {})
146
+ # for k in keys_to_remove:
147
+ # meta.pop(k, None)
148
 
149
+ # context_clauses = json.dumps(top_clause_dicts, separators=(",", ":"))
150
+ context_clauses = [doc.page_content for doc in self.result]
151
 
152
  print(f"context_clauses: {context_clauses}")
153
 
app/utils/metadata_utils.py CHANGED
@@ -57,39 +57,51 @@ class MetadataService:
57
  return normalized
58
 
59
  @staticmethod
60
- def cosine_similarity(text1, text2, embedding_model) -> float:
61
- vector1 = embedding_model.embed_query(text1)
62
- vector2 = embedding_model.embed_query(text2)
63
- cosine_similarity = np.dot(vector1, vector2) / (np.linalg.norm(vector1) * np.linalg.norm(vector2))
64
- return cosine_similarity
65
 
66
  @staticmethod
67
  def keyword_sementic_check(result, data, embedding_model):
68
-
69
- # result = result.model_dump()
70
- # data = json.load(open(data, 'r'))
71
- # Compare all keys present in both result and data, and check if any value in result[key] is present in data[key]
72
  for key in result.keys():
73
- print(f"Comparing key: {key}",flush=True)
74
- # Only check if both result[key] and data[key] are not None and are lists
75
  if result[key] is not None and data.get(key) is not None:
76
- print(f"result[{key}]: {result[key]}",flush=True)
77
- print(f"data[{key}]: {data[key]}",flush=True)
78
- # Ensure both are lists (skip if not)
79
  if isinstance(result[key], list) and isinstance(data[key], list):
80
- for idx,val in enumerate(result[key]):
81
- print(f"Comparing value: {val}",flush=True)
82
- if val in data[key]:
83
- print(f"'{val}' found in data['{key}']")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  else:
85
- print(f"'{val}' NOT found in data['{key}']")
86
- for data_val in data[key]:
87
- similarity = MetadataService.cosine_similarity(val, data_val,embedding_model)
88
- print(f"Cosine similarity between '{val}' and '{data_val}': {similarity}")
 
 
89
  if similarity > 0.90:
90
- print(f"'{val}' is similar to '{data_val}' with similarity {similarity}",flush=True)
91
- ## if similarity is greater than 0.90, then consider it as matched and replace the value in result with data value
92
  result[key][idx] = data_val
 
93
  else:
94
- print(f"'{val}' is NOT similar to '{data_val}' with similarity {similarity}",flush=True)
 
95
  return result
 
57
  return normalized
58
 
59
  @staticmethod
60
+ def cosine_similarity(vec1, vec2) -> float:
61
+ return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
 
 
 
62
 
63
  @staticmethod
64
  def keyword_sementic_check(result, data, embedding_model):
 
 
 
 
65
  for key in result.keys():
66
+ print(f"Comparing key: {key}", flush=True)
67
+
68
  if result[key] is not None and data.get(key) is not None:
69
+ print(f"result[{key}]: {result[key]}", flush=True)
70
+ print(f"data[{key}]: {data[key]}", flush=True)
71
+
72
  if isinstance(result[key], list) and isinstance(data[key], list):
73
+ # Filter to only strings
74
+ data_list = [v for v in data[key] if isinstance(v, str)]
75
+ val_list = [v for v in result[key] if isinstance(v, str)]
76
+ data_set = set(data_list)
77
+
78
+ if not data_list or not val_list:
79
+ print(f"Skipping key '{key}' due to empty valid strings.")
80
+ continue
81
+
82
+ # Precompute embeddings for data_list
83
+ data_embeddings = {val: embedding_model.embed_query(val) for val in data_list}
84
+
85
+ # Precompute embeddings for val_list
86
+ val_embeddings_list = embedding_model.embed_documents(val_list)
87
+
88
+ for idx, val in enumerate(val_list):
89
+ print(f"Comparing value: {val}", flush=True)
90
+
91
+ if val in data_set:
92
+ print(f"'{val}' found in data['{key}']", flush=True)
93
  else:
94
+ print(f"'{val}' NOT found in data['{key}']", flush=True)
95
+ val_vector = val_embeddings_list[idx]
96
+
97
+ for data_val, data_vector in data_embeddings.items():
98
+ similarity = MetadataService.cosine_similarity(val_vector, data_vector)
99
+ print(f"Cosine similarity between '{val}' and '{data_val}': {similarity}", flush=True)
100
  if similarity > 0.90:
101
+ print(f"'{val}' is similar to '{data_val}' with similarity {similarity}", flush=True)
 
102
  result[key][idx] = data_val
103
+ break
104
  else:
105
+ print(f"'{val}' is NOT similar to '{data_val}' with similarity {similarity}", flush=True)
106
+
107
  return result