| import unittest |
|
|
| from rag.retrieve import ( |
| QueryRewriteResult, |
| generate_sub_queries, |
| is_valid_retrieval_query, |
| parse_structured_sub_queries, |
| parse_sub_queries_json, |
| ) |
|
|
|
|
class FakeResponse:
    """Minimal response stand-in that exposes a single ``content`` attribute."""

    def __init__(self, content):
        # Kept verbatim so tests can control exactly what the caller reads back.
        self.content = content
|
|
|
|
class FakeTransformer:
    """Test double whose ``invoke`` always answers with the same canned content."""

    def __init__(self, content):
        # Canned text handed back, wrapped in a FakeResponse, on every invoke().
        self.content = content

    def invoke(self, prompt):
        """Ignore ``prompt`` and return a ``FakeResponse`` carrying the canned content."""
        return FakeResponse(self.content)
|
|
|
|
class FakeStructuredRunnable:
    """Test double that returns a pre-built structured payload from ``invoke``."""

    def __init__(self, structured_response):
        # Returned as-is (same object) from every invoke() call.
        self.structured_response = structured_response

    def invoke(self, prompt):
        """Ignore ``prompt`` and return the stored structured response unchanged."""
        return self.structured_response
|
|
|
|
class FakeStructuredTransformer:
    """Test double offering both a structured-output path and a plain ``invoke``.

    ``with_structured_output`` records the schema it was asked for and hands
    back a ``FakeStructuredRunnable`` bound to the canned structured response.
    The plain ``invoke`` records each prompt and returns an empty-content
    ``FakeResponse``.
    """

    def __init__(self, structured_response):
        self.structured_response = structured_response
        self.prompts = []
        # Initialised up front so tests that inspect ``schema`` get a clean
        # assertion failure (None) rather than an AttributeError when the
        # structured path was never taken.
        self.schema = None

    def with_structured_output(self, schema, **kwargs):
        """Remember ``schema`` and return a runnable yielding the canned response.

        Extra keyword arguments are accepted and ignored.
        """
        self.schema = schema
        return FakeStructuredRunnable(self.structured_response)

    def invoke(self, prompt):
        """Record ``prompt`` and return a ``FakeResponse`` with empty content."""
        self.prompts.append(prompt)
        return FakeResponse("")
|
|
|
|
class QueryTransformTests(unittest.TestCase):
    """Tests for the sub-query generation and parsing helpers from ``rag.retrieve``."""

    def test_parse_sub_queries_json_returns_clean_query_list(self):
        """A JSON object embedded after prose is parsed into its query list."""
        response_text = (
            "\n"
            "Here is the retrieval output:\n"
            '{"sub_queries": ["trump background", "trump family", "trump business"]}\n'
        )

        result = parse_sub_queries_json(
            response_text,
            "5 facts about trump",
            max_queries=4,
        )

        self.assertEqual(
            result,
            ["trump background", "trump family", "trump business"],
        )

    def test_generate_sub_queries_prefers_json_over_explanatory_text(self):
        """Only the JSON payload is used; the leading explanation is not a query."""
        transformer = FakeTransformer(
            "I'll break down the original query.\n"
            '{"sub_queries": ["climate change biodiversity", "climate change oceans", "climate change agriculture"]}\n'
        )

        result = generate_sub_queries(
            "What are the impacts of climate change on the environment?",
            transformer,
            max_queries=3,
        )

        self.assertEqual(
            result,
            [
                "climate change biodiversity",
                "climate change oceans",
                "climate change agriculture",
            ],
        )

    def test_generate_sub_queries_uses_structured_output_when_available(self):
        """A transformer with ``with_structured_output`` is asked for the schema."""
        transformer = FakeStructuredTransformer(
            QueryRewriteResult(
                sub_queries=[
                    "trump background",
                    "trump family",
                    "trump business",
                ]
            )
        )

        result = generate_sub_queries(
            "5 facts about trump",
            transformer,
            max_queries=3,
        )

        self.assertEqual(
            result,
            ["trump background", "trump family", "trump business"],
        )
        # The fake records the schema it was handed; it must be the result model.
        self.assertIs(transformer.schema, QueryRewriteResult)

    def test_parse_structured_sub_queries_handles_plain_dict_payload(self):
        """A plain dict with a ``sub_queries`` key is accepted as structured output."""
        result = parse_structured_sub_queries(
            {"sub_queries": ["trump background", "trump family", "trump business"]},
            "5 facts about trump",
            max_queries=3,
        )

        self.assertEqual(
            result,
            ["trump background", "trump family", "trump business"],
        )

    def test_generate_sub_queries_falls_back_to_line_parsing_when_json_is_missing(self):
        """A numbered list without JSON still yields the three queries."""
        transformer = FakeTransformer(
            "Sub-queries:\n"
            "1. trump background\n"
            "2. trump family\n"
            "3. trump business\n"
        )

        result = generate_sub_queries(
            "5 facts about trump",
            transformer,
            max_queries=3,
        )

        self.assertEqual(
            result,
            ["trump background", "trump family", "trump business"],
        )

    def test_rejects_assistant_style_refusal_text_as_retrieval_query(self):
        """Assistant-style clarification text is not a valid retrieval query."""
        self.assertFalse(
            is_valid_retrieval_query(
                'Please provide a specific question or topic you would like me to research.',
                original_query="hi",
            )
        )
|
|
|
|
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
|