# File size: 4,155 Bytes | c76423f
import unittest
from rag.retrieve import (
QueryRewriteResult,
generate_sub_queries,
is_valid_retrieval_query,
parse_structured_sub_queries,
parse_sub_queries_json,
)
class FakeResponse:
    """Minimal reply object that exposes its payload as ``content``."""

    def __init__(self, content):
        # Kept exactly as given; the fakes in this file hand it back untouched.
        self.content = content
class FakeTransformer:
    """Test double whose ``invoke`` always answers with a canned reply.

    The text supplied at construction time is wrapped in a ``FakeResponse``
    on every call, regardless of the prompt.
    """

    def __init__(self, content):
        self.content = content

    def invoke(self, prompt):
        """Return a ``FakeResponse`` carrying the canned text; *prompt* is ignored."""
        canned_reply = FakeResponse(self.content)
        return canned_reply
class FakeStructuredRunnable:
    """Test double that returns a pre-built structured payload from ``invoke``."""

    def __init__(self, structured_response):
        self.structured_response = structured_response

    def invoke(self, prompt):
        """Return the stored payload object unchanged; *prompt* is ignored."""
        payload = self.structured_response
        return payload
class FakeStructuredTransformer:
    """Test double for a transformer offering ``with_structured_output``.

    ``with_structured_output`` records the requested schema and returns a
    ``FakeStructuredRunnable`` bound to the canned structured response.
    ``invoke`` records each prompt it receives and returns an empty
    ``FakeResponse``.

    Attributes:
        structured_response: Payload returned by the structured runnable.
        prompts: Prompts passed to ``invoke``, in call order.
        schema: Schema last passed to ``with_structured_output``, or ``None``
            if that method has not been called yet.
    """

    def __init__(self, structured_response):
        self.structured_response = structured_response
        self.prompts = []
        # Defined up front so that reading ``schema`` before any
        # ``with_structured_output`` call gives None (a clean assertion
        # failure in tests) instead of raising AttributeError.
        self.schema = None

    def with_structured_output(self, schema, **kwargs):
        """Remember *schema* and return a runnable yielding the canned payload.

        Extra keyword arguments are accepted and ignored.
        """
        self.schema = schema
        return FakeStructuredRunnable(self.structured_response)

    def invoke(self, prompt):
        """Record *prompt* and return a ``FakeResponse`` with empty content."""
        self.prompts.append(prompt)
        return FakeResponse("")
class QueryTransformTests(unittest.TestCase):
    """Tests for the sub-query parsing, generation, and validation helpers."""

    def test_parse_sub_queries_json_returns_clean_query_list(self):
        """A JSON object surrounded by prose is parsed into its query list."""
        raw_reply = """
Here is the retrieval output:
{"sub_queries": ["trump background", "trump family", "trump business"]}
"""
        expected = ["trump background", "trump family", "trump business"]

        parsed = parse_sub_queries_json(
            raw_reply,
            "5 facts about trump",
            max_queries=4,
        )

        self.assertEqual(parsed, expected)

    def test_generate_sub_queries_prefers_json_over_explanatory_text(self):
        """The JSON payload is used even when the reply starts with chatter."""
        llm = FakeTransformer(
            """I'll break down the original query.
{"sub_queries": ["climate change biodiversity", "climate change oceans", "climate change agriculture"]}
"""
        )
        expected = [
            "climate change biodiversity",
            "climate change oceans",
            "climate change agriculture",
        ]

        sub_queries = generate_sub_queries(
            "What are the impacts of climate change on the environment?",
            llm,
            max_queries=3,
        )

        self.assertEqual(sub_queries, expected)

    def test_generate_sub_queries_uses_structured_output_when_available(self):
        """A transformer with structured output is asked for QueryRewriteResult."""
        canned = QueryRewriteResult(
            sub_queries=[
                "trump background",
                "trump family",
                "trump business",
            ]
        )
        llm = FakeStructuredTransformer(canned)
        expected = ["trump background", "trump family", "trump business"]

        sub_queries = generate_sub_queries(
            "5 facts about trump",
            llm,
            max_queries=3,
        )

        self.assertEqual(sub_queries, expected)
        # The fake records whichever schema it was asked to produce.
        self.assertIs(llm.schema, QueryRewriteResult)

    def test_parse_structured_sub_queries_handles_plain_dict_payload(self):
        """A plain dict with a ``sub_queries`` key is accepted as the payload."""
        payload = {"sub_queries": ["trump background", "trump family", "trump business"]}
        expected = ["trump background", "trump family", "trump business"]

        parsed = parse_structured_sub_queries(
            payload,
            "5 facts about trump",
            max_queries=3,
        )

        self.assertEqual(parsed, expected)

    def test_generate_sub_queries_falls_back_to_line_parsing_when_json_is_missing(self):
        """A numbered list without any JSON still yields the three queries."""
        llm = FakeTransformer(
            """Sub-queries:
1. trump background
2. trump family
3. trump business
"""
        )
        expected = ["trump background", "trump family", "trump business"]

        sub_queries = generate_sub_queries(
            "5 facts about trump",
            llm,
            max_queries=3,
        )

        self.assertEqual(sub_queries, expected)

    def test_rejects_assistant_style_refusal_text_as_retrieval_query(self):
        """Assistant-style clarification text is not a valid retrieval query."""
        refusal_text = 'Please provide a specific question or topic you would like me to research.'

        verdict = is_valid_retrieval_query(refusal_text, original_query="hi")

        self.assertFalse(verdict)
# Run the test suite only when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|