Spaces:
Sleeping
Sleeping
Update policy_retriever_2.py
Browse files- policy_retriever_2.py +9 -20
policy_retriever_2.py
CHANGED
|
@@ -52,7 +52,7 @@ class PolicyRetriever:
|
|
| 52 |
outputs = self.model(**inputs)
|
| 53 |
return outputs.last_hidden_state[:, 0, :].cpu().numpy()[0].tolist()
|
| 54 |
|
| 55 |
-
|
| 56 |
"""Search for relevant policy sections based on query."""
|
| 57 |
print("\nSearching for relevant policy sections...")
|
| 58 |
query_embedding = self.get_query_embedding(query)
|
|
@@ -113,14 +113,9 @@ Policy sections:
|
|
| 113 |
|
| 114 |
# Add chunks to prompt
|
| 115 |
for i, chunk in enumerate(chunks, 1):
|
| 116 |
-
prompt += f"
|
| 117 |
-
Section {i}:
|
| 118 |
-
Heading: {chunk['heading']}
|
| 119 |
-
Content: {chunk['content']}
|
| 120 |
-
"
|
| 121 |
|
| 122 |
-
prompt += "
|
| 123 |
-
Provide a clear, direct answer based only on the information shown above:"
|
| 124 |
|
| 125 |
# Generate answer
|
| 126 |
inputs = self.llm_tokenizer(
|
|
@@ -145,11 +140,9 @@ Provide a clear, direct answer based only on the information shown above:"
|
|
| 145 |
answer = self.llm_tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
|
| 146 |
|
| 147 |
# Format and print response
|
| 148 |
-
print(f"
|
| 149 |
-
Q: {query}")
|
| 150 |
print(f"A: {answer}")
|
| 151 |
-
print("
|
| 152 |
-
Based on sections:")
|
| 153 |
for chunk in chunks:
|
| 154 |
print(f"- {chunk['heading']}")
|
| 155 |
|
|
@@ -157,15 +150,13 @@ Based on sections:")
|
|
| 157 |
|
| 158 |
def search_and_generate(self, query: str, top_k: int = 5) -> str:
|
| 159 |
"""Combined search and answer generation."""
|
| 160 |
-
print(f"
|
| 161 |
-
Processing query: {query}")
|
| 162 |
retrieved_chunks = self.search(query, top_k=top_k)
|
| 163 |
return self.generate_answer(query, retrieved_chunks)
|
| 164 |
|
| 165 |
def test_retriever():
|
| 166 |
"""Test function for the PolicyRetriever."""
|
| 167 |
-
print("
|
| 168 |
-
=== Testing Policy Retriever ===")
|
| 169 |
retriever = PolicyRetriever()
|
| 170 |
|
| 171 |
# Standard test queries
|
|
@@ -179,9 +170,7 @@ def test_retriever():
|
|
| 179 |
# Run tests
|
| 180 |
for query in test_queries:
|
| 181 |
answer = retriever.search_and_generate(query)
|
| 182 |
-
print("
|
| 183 |
-
" + "-"*70 + "
|
| 184 |
-
")
|
| 185 |
|
| 186 |
if __name__ == "__main__":
|
| 187 |
-
test_retriever()
|
|
|
|
| 52 |
outputs = self.model(**inputs)
|
| 53 |
return outputs.last_hidden_state[:, 0, :].cpu().numpy()[0].tolist()
|
| 54 |
|
| 55 |
+
def search(self, query: str, top_k: int = 5) -> List[Dict]:
|
| 56 |
"""Search for relevant policy sections based on query."""
|
| 57 |
print("\nSearching for relevant policy sections...")
|
| 58 |
query_embedding = self.get_query_embedding(query)
|
|
|
|
| 113 |
|
| 114 |
# Add chunks to prompt
|
| 115 |
for i, chunk in enumerate(chunks, 1):
|
| 116 |
+
prompt += f"\nSection {i}:\nHeading: {chunk['heading']}\nContent: {chunk['content']}\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
|
| 118 |
+
prompt += "\nProvide a clear, direct answer based only on the information shown above:"
|
|
|
|
| 119 |
|
| 120 |
# Generate answer
|
| 121 |
inputs = self.llm_tokenizer(
|
|
|
|
| 140 |
answer = self.llm_tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
|
| 141 |
|
| 142 |
# Format and print response
|
| 143 |
+
print(f"\nQ: {query}")
|
|
|
|
| 144 |
print(f"A: {answer}")
|
| 145 |
+
print("\nBased on sections:")
|
|
|
|
| 146 |
for chunk in chunks:
|
| 147 |
print(f"- {chunk['heading']}")
|
| 148 |
|
|
|
|
| 150 |
|
| 151 |
def search_and_generate(self, query: str, top_k: int = 5) -> str:
|
| 152 |
"""Combined search and answer generation."""
|
| 153 |
+
print(f"\nProcessing query: {query}")
|
|
|
|
| 154 |
retrieved_chunks = self.search(query, top_k=top_k)
|
| 155 |
return self.generate_answer(query, retrieved_chunks)
|
| 156 |
|
| 157 |
def test_retriever():
|
| 158 |
"""Test function for the PolicyRetriever."""
|
| 159 |
+
print("\n=== Testing Policy Retriever ===")
|
|
|
|
| 160 |
retriever = PolicyRetriever()
|
| 161 |
|
| 162 |
# Standard test queries
|
|
|
|
| 170 |
# Run tests
|
| 171 |
for query in test_queries:
|
| 172 |
answer = retriever.search_and_generate(query)
|
| 173 |
+
print("\n" + "-"*70 + "\n")
|
|
|
|
|
|
|
| 174 |
|
| 175 |
if __name__ == "__main__":
|
| 176 |
+
test_retriever()
|