bappiahk committed on
Commit
e219dd3
·
verified ·
1 Parent(s): da7ab77

Update policy_retriever_2.py

Browse files
Files changed (1) hide show
  1. policy_retriever_2.py +9 -20
policy_retriever_2.py CHANGED
@@ -52,7 +52,7 @@ class PolicyRetriever:
52
  outputs = self.model(**inputs)
53
  return outputs.last_hidden_state[:, 0, :].cpu().numpy()[0].tolist()
54
 
55
- def search(self, query: str, top_k: int = 5) -> List[Dict]:
56
  """Search for relevant policy sections based on query."""
57
  print("\nSearching for relevant policy sections...")
58
  query_embedding = self.get_query_embedding(query)
@@ -113,14 +113,9 @@ Policy sections:
113
 
114
  # Add chunks to prompt
115
  for i, chunk in enumerate(chunks, 1):
116
- prompt += f"
117
- Section {i}:
118
- Heading: {chunk['heading']}
119
- Content: {chunk['content']}
120
- "
121
 
122
- prompt += "
123
- Provide a clear, direct answer based only on the information shown above:"
124
 
125
  # Generate answer
126
  inputs = self.llm_tokenizer(
@@ -145,11 +140,9 @@ Provide a clear, direct answer based only on the information shown above:"
145
  answer = self.llm_tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
146
 
147
  # Format and print response
148
- print(f"
149
- Q: {query}")
150
  print(f"A: {answer}")
151
- print("
152
- Based on sections:")
153
  for chunk in chunks:
154
  print(f"- {chunk['heading']}")
155
 
@@ -157,15 +150,13 @@ Based on sections:")
157
 
158
  def search_and_generate(self, query: str, top_k: int = 5) -> str:
159
  """Combined search and answer generation."""
160
- print(f"
161
- Processing query: {query}")
162
  retrieved_chunks = self.search(query, top_k=top_k)
163
  return self.generate_answer(query, retrieved_chunks)
164
 
165
  def test_retriever():
166
  """Test function for the PolicyRetriever."""
167
- print("
168
- === Testing Policy Retriever ===")
169
  retriever = PolicyRetriever()
170
 
171
  # Standard test queries
@@ -179,9 +170,7 @@ def test_retriever():
179
  # Run tests
180
  for query in test_queries:
181
  answer = retriever.search_and_generate(query)
182
- print("
183
- " + "-"*70 + "
184
- ")
185
 
186
  if __name__ == "__main__":
187
- test_retriever()
 
52
  outputs = self.model(**inputs)
53
  return outputs.last_hidden_state[:, 0, :].cpu().numpy()[0].tolist()
54
 
55
+ def search(self, query: str, top_k: int = 5) -> List[Dict]:
56
  """Search for relevant policy sections based on query."""
57
  print("\nSearching for relevant policy sections...")
58
  query_embedding = self.get_query_embedding(query)
 
113
 
114
  # Add chunks to prompt
115
  for i, chunk in enumerate(chunks, 1):
116
+ prompt += f"\nSection {i}:\nHeading: {chunk['heading']}\nContent: {chunk['content']}\n"
 
 
 
 
117
 
118
+ prompt += "\nProvide a clear, direct answer based only on the information shown above:"
 
119
 
120
  # Generate answer
121
  inputs = self.llm_tokenizer(
 
140
  answer = self.llm_tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
141
 
142
  # Format and print response
143
+ print(f"\nQ: {query}")
 
144
  print(f"A: {answer}")
145
+ print("\nBased on sections:")
 
146
  for chunk in chunks:
147
  print(f"- {chunk['heading']}")
148
 
 
150
 
151
  def search_and_generate(self, query: str, top_k: int = 5) -> str:
152
  """Combined search and answer generation."""
153
+ print(f"\nProcessing query: {query}")
 
154
  retrieved_chunks = self.search(query, top_k=top_k)
155
  return self.generate_answer(query, retrieved_chunks)
156
 
157
  def test_retriever():
158
  """Test function for the PolicyRetriever."""
159
+ print("\n=== Testing Policy Retriever ===")
 
160
  retriever = PolicyRetriever()
161
 
162
  # Standard test queries
 
170
  # Run tests
171
  for query in test_queries:
172
  answer = retriever.search_and_generate(query)
173
+ print("\n" + "-"*70 + "\n")
 
 
174
 
175
  if __name__ == "__main__":
176
+ test_retriever()