Rabbook / tests /test_query_transform.py
Matcry's picture
Deploy snapshot
c76423f
Raw
History Blame Contribute Delete
4.16 kB
import unittest
from rag.retrieve import (
QueryRewriteResult,
generate_sub_queries,
is_valid_retrieval_query,
parse_structured_sub_queries,
parse_sub_queries_json,
)
class FakeResponse:
    """Minimal stand-in for a model reply: it only carries ``content``."""

    def __init__(self, content):
        # Stored verbatim; this class never inspects or validates the value.
        self.content = content
class FakeTransformer:
    """Test double for a text-only model that always answers with canned text.

    It defines only ``invoke``; there is no ``with_structured_output`` method.
    """

    def __init__(self, content):
        # Text that every ``invoke`` call will hand back.
        self.content = content

    def invoke(self, prompt):
        """Ignore ``prompt`` and wrap the canned text in a ``FakeResponse``."""
        canned_reply = FakeResponse(self.content)
        return canned_reply
class FakeStructuredRunnable:
    """Runnable stub that hands back a preset structured payload on ``invoke``."""

    def __init__(self, structured_response):
        # Returned as-is (same object) by every ``invoke`` call.
        self.structured_response = structured_response

    def invoke(self, prompt):
        """Return the preset payload; ``prompt`` is accepted but not used."""
        preset = self.structured_response
        return preset
class FakeStructuredTransformer:
    """Test double for a model that supports structured output.

    ``with_structured_output`` records the schema it was asked for and returns
    a ``FakeStructuredRunnable`` that yields the preset structured response.
    The plain ``invoke`` path records each prompt and replies with empty text.

    Attributes:
        structured_response: Payload returned by the structured runnable.
        prompts: Prompts received through the plain ``invoke`` path, in order.
        schema: Schema passed to the most recent ``with_structured_output``
            call, or ``None`` if that method has not been called yet.
    """

    def __init__(self, structured_response):
        self.structured_response = structured_response
        self.prompts = []
        # Defined up front so a test that inspects ``schema`` before any
        # ``with_structured_output`` call sees ``None`` (a clean assertion
        # failure) instead of crashing with AttributeError.
        self.schema = None

    def with_structured_output(self, schema, **kwargs):
        """Remember ``schema`` and return a runnable for the preset payload.

        Extra keyword arguments are accepted and ignored.
        """
        self.schema = schema
        return FakeStructuredRunnable(self.structured_response)

    def invoke(self, prompt):
        """Record ``prompt`` and return a ``FakeResponse`` with empty content."""
        self.prompts.append(prompt)
        return FakeResponse("")
class QueryTransformTests(unittest.TestCase):
    """Checks for the sub-query parsing/generation helpers from ``rag.retrieve``."""

    # Shared fixtures. Kept immutable; every test builds fresh lists from them
    # so no test can leak a mutation into another.
    FACTS_QUERY = "5 facts about trump"
    FACTS_SUB_QUERIES = ("trump background", "trump family", "trump business")

    def test_parse_sub_queries_json_returns_clean_query_list(self):
        # A JSON object preceded by a line of explanatory text.
        response_text = (
            "\n"
            "Here is the retrieval output:\n"
            '{"sub_queries": ["trump background", "trump family", "trump business"]}\n'
        )
        parsed = parse_sub_queries_json(
            response_text,
            self.FACTS_QUERY,
            max_queries=4,
        )
        self.assertEqual(parsed, list(self.FACTS_SUB_QUERIES))

    def test_generate_sub_queries_prefers_json_over_explanatory_text(self):
        expected = [
            "climate change biodiversity",
            "climate change oceans",
            "climate change agriculture",
        ]
        # Reply mixes a prose sentence with the JSON payload.
        model_reply = (
            "I'll break down the original query.\n"
            '{"sub_queries": ["climate change biodiversity", '
            '"climate change oceans", "climate change agriculture"]}\n'
        )
        produced = generate_sub_queries(
            "What are the impacts of climate change on the environment?",
            FakeTransformer(model_reply),
            max_queries=3,
        )
        self.assertEqual(produced, expected)

    def test_generate_sub_queries_uses_structured_output_when_available(self):
        structured_payload = QueryRewriteResult(
            sub_queries=list(self.FACTS_SUB_QUERIES),
        )
        transformer = FakeStructuredTransformer(structured_payload)
        produced = generate_sub_queries(
            self.FACTS_QUERY,
            transformer,
            max_queries=3,
        )
        self.assertEqual(produced, list(self.FACTS_SUB_QUERIES))
        # The schema handed to with_structured_output must be the model class itself.
        self.assertIs(transformer.schema, QueryRewriteResult)

    def test_parse_structured_sub_queries_handles_plain_dict_payload(self):
        # A plain dict is passed instead of a QueryRewriteResult instance.
        payload = {"sub_queries": list(self.FACTS_SUB_QUERIES)}
        parsed = parse_structured_sub_queries(
            payload,
            self.FACTS_QUERY,
            max_queries=3,
        )
        self.assertEqual(parsed, list(self.FACTS_SUB_QUERIES))

    def test_generate_sub_queries_falls_back_to_line_parsing_when_json_is_missing(self):
        # Numbered list with a heading line and no JSON anywhere.
        numbered_reply = (
            "Sub-queries:\n"
            "1. trump background\n"
            "2. trump family\n"
            "3. trump business\n"
        )
        produced = generate_sub_queries(
            self.FACTS_QUERY,
            FakeTransformer(numbered_reply),
            max_queries=3,
        )
        self.assertEqual(produced, list(self.FACTS_SUB_QUERIES))

    def test_rejects_assistant_style_refusal_text_as_retrieval_query(self):
        refusal = "Please provide a specific question or topic you would like me to research."
        accepted = is_valid_retrieval_query(refusal, original_query="hi")
        self.assertFalse(accepted)
# Run the tests in this module with unittest's runner when the file is
# executed directly as a script.
if __name__ == "__main__":
    unittest.main()