File size: 4,155 Bytes
c76423f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import unittest

from rag.retrieve import (
    QueryRewriteResult,
    generate_sub_queries,
    is_valid_retrieval_query,
    parse_structured_sub_queries,
    parse_sub_queries_json,
)


class FakeResponse:
    """Minimal stand-in for a chat-model reply that exposes only ``content``."""

    def __init__(self, content):
        # Text the fake model "returned"; the code under test reads it back.
        self.content = content


class FakeTransformer:
    """Test double for a plain-text LLM whose every reply has fixed content.

    Each call to ``invoke`` appends its prompt to ``prompts`` so tests can
    inspect what the code under test sent. This matches how
    ``FakeStructuredTransformer`` records prompts. The returned object is a
    ``FakeResponse`` carrying ``content``.
    """

    def __init__(self, content):
        # Canned text returned (wrapped in FakeResponse) by every invoke call.
        self.content = content
        # Prompts received by invoke, in call order.
        self.prompts = []

    def invoke(self, prompt):
        """Record *prompt* and return a ``FakeResponse`` with the canned content."""
        self.prompts.append(prompt)
        return FakeResponse(self.content)


class FakeStructuredRunnable:
    """Runnable returned by ``FakeStructuredTransformer.with_structured_output``.

    ``invoke`` returns the canned ``structured_response`` object unchanged,
    whatever the prompt is. Each prompt is appended to ``prompts`` so tests
    can inspect what was sent on the structured-output path.
    """

    def __init__(self, structured_response):
        # Object handed back verbatim by every invoke call.
        self.structured_response = structured_response
        # Prompts received by invoke, in call order.
        self.prompts = []

    def invoke(self, prompt):
        """Record *prompt* and return the canned structured response."""
        self.prompts.append(prompt)
        return self.structured_response


class FakeStructuredTransformer:
    """Test double for an LLM that supports ``with_structured_output``.

    ``with_structured_output`` records the requested schema and any extra
    keyword arguments. It returns a ``FakeStructuredRunnable`` that yields the
    canned ``structured_response``. Plain ``invoke`` calls are recorded in
    ``prompts`` and answered with an empty-content ``FakeResponse``.
    """

    def __init__(self, structured_response):
        # Object returned by the runnable produced by with_structured_output.
        self.structured_response = structured_response
        # Prompts sent through the plain (unstructured) invoke path.
        self.prompts = []
        # Initialised up front so assertions on ``schema`` fail with a clear
        # assertion error, not AttributeError, when the code under test never
        # requested structured output.
        self.schema = None
        # Extra keyword arguments forwarded to with_structured_output.
        self.structured_output_kwargs = {}

    def with_structured_output(self, schema, **kwargs):
        """Record *schema* and *kwargs*; return a runnable yielding the canned response."""
        self.schema = schema
        self.structured_output_kwargs = kwargs
        return FakeStructuredRunnable(self.structured_response)

    def invoke(self, prompt):
        """Record *prompt* and return a ``FakeResponse`` with empty content."""
        self.prompts.append(prompt)
        return FakeResponse("")


class QueryTransformTests(unittest.TestCase):
    """Tests for sub-query parsing and generation in ``rag.retrieve``."""

    # Shared fixture for the "5 facts about trump" scenarios.
    TRUMP_QUERY = "5 facts about trump"
    TRUMP_SUB_QUERIES = ("trump background", "trump family", "trump business")

    def _assert_trump_sub_queries(self, result):
        """Assert that *result* is exactly the expected trump sub-query list."""
        self.assertEqual(result, list(self.TRUMP_SUB_QUERIES))

    def test_parse_sub_queries_json_returns_clean_query_list(self):
        """A JSON object that follows leading prose is extracted as a clean list."""
        response_text = """
Here is the retrieval output:
{"sub_queries": ["trump background", "trump family", "trump business"]}
"""
        parsed = parse_sub_queries_json(
            response_text, self.TRUMP_QUERY, max_queries=4
        )
        self._assert_trump_sub_queries(parsed)

    def test_generate_sub_queries_prefers_json_over_explanatory_text(self):
        """An embedded JSON payload wins over the model's explanatory text."""
        llm = FakeTransformer(
            """I'll break down the original query.
{"sub_queries": ["climate change biodiversity", "climate change oceans", "climate change agriculture"]}
"""
        )
        expected = [
            "climate change biodiversity",
            "climate change oceans",
            "climate change agriculture",
        ]

        generated = generate_sub_queries(
            "What are the impacts of climate change on the environment?",
            llm,
            max_queries=3,
        )

        self.assertEqual(generated, expected)

    def test_generate_sub_queries_uses_structured_output_when_available(self):
        """The structured-output path is used and receives the result schema."""
        llm = FakeStructuredTransformer(
            QueryRewriteResult(sub_queries=list(self.TRUMP_SUB_QUERIES))
        )

        generated = generate_sub_queries(self.TRUMP_QUERY, llm, max_queries=3)

        self._assert_trump_sub_queries(generated)
        self.assertIs(llm.schema, QueryRewriteResult)

    def test_parse_structured_sub_queries_handles_plain_dict_payload(self):
        """A bare dict with a ``sub_queries`` key is accepted as structured output."""
        payload = {"sub_queries": list(self.TRUMP_SUB_QUERIES)}
        parsed = parse_structured_sub_queries(
            payload, self.TRUMP_QUERY, max_queries=3
        )
        self._assert_trump_sub_queries(parsed)

    def test_generate_sub_queries_falls_back_to_line_parsing_when_json_is_missing(self):
        """A numbered list without JSON is still parsed line by line."""
        llm = FakeTransformer(
            """Sub-queries:
1. trump background
2. trump family
3. trump business
"""
        )

        generated = generate_sub_queries(self.TRUMP_QUERY, llm, max_queries=3)

        self._assert_trump_sub_queries(generated)

    def test_rejects_assistant_style_refusal_text_as_retrieval_query(self):
        """Assistant-style refusal text is not a usable retrieval query."""
        refusal = "Please provide a specific question or topic you would like me to research."
        self.assertFalse(is_valid_retrieval_query(refusal, original_query="hi"))


if __name__ == "__main__":
    # Run every test in this module when executed directly as a script.
    unittest.main(module="__main__")