File size: 1,775 Bytes
c76423f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import unittest

from rag.retrieve import AnswerDraftResult, generate_answer


class FakeResponse:
    """Minimal stand-in for a chat-model reply; only ``.content`` is exposed."""

    def __init__(self, content):
        # Kept exactly as given; the tests in this file pass plain strings.
        self.content = content


class FakeStructuredRunnable:
    """Stand-in for the runnable that ``with_structured_output`` hands back.

    ``invoke`` disregards the prompt and returns whatever object is currently
    stored in ``structured_response``.
    """

    def __init__(self, structured_response):
        self.structured_response = structured_response

    def invoke(self, prompt):
        # The prompt is deliberately unused: the reply is canned up front.
        canned = self.structured_response
        return canned


class FakeStructuredLLM:
    """Fake chat model that supports ``with_structured_output``.

    ``with_structured_output`` records the requested schema on ``self.schema``.
    It returns a ``FakeStructuredRunnable`` that yields ``structured_response``.
    Plain ``invoke`` calls are appended to ``self.prompts`` and return a
    ``FakeResponse`` with empty content.
    """

    def __init__(self, structured_response):
        self.structured_response = structured_response
        self.prompts = []
        # Initialised up front so that asserting on ``schema`` fails with a
        # clear assertion error (None is not the expected schema) rather than
        # an AttributeError when ``with_structured_output`` was never called.
        self.schema = None

    def with_structured_output(self, schema):
        self.schema = schema
        return FakeStructuredRunnable(self.structured_response)

    def invoke(self, prompt):
        self.prompts.append(prompt)
        # Empty content, so a caller that wrongly takes the plain-text path
        # gets "" instead of the structured answer.
        return FakeResponse("")


class AnswerStructuredOutputTests(unittest.TestCase):
    """Check that generate_answer prefers structured output and falls back to text."""

    # Shared inputs for both scenarios.
    QUESTION = "What is climate change?"
    CONTEXT = "Source [1] climate note"

    def test_generate_answer_uses_structured_output_when_available(self):
        # The fake's structured runnable returns this draft verbatim.
        canned_draft = AnswerDraftResult(answer="Structured answer [1]")
        fake_llm = FakeStructuredLLM(canned_draft)

        answer = generate_answer(self.QUESTION, self.CONTEXT, fake_llm)

        self.assertEqual(answer, "Structured answer [1]")
        # with_structured_output must have been asked for AnswerDraftResult.
        self.assertIs(fake_llm.schema, AnswerDraftResult)

    def test_generate_answer_falls_back_to_plain_text_when_structured_output_is_missing(self):
        # Test double exposing only `invoke`, with no `with_structured_output`.
        class PlainLLM:
            def __init__(self):
                self.prompts = []

            def invoke(self, prompt):
                self.prompts.append(prompt)
                return FakeResponse("Plain answer [1]")

        plain_llm = PlainLLM()

        answer = generate_answer(self.QUESTION, self.CONTEXT, plain_llm)

        self.assertEqual(answer, "Plain answer [1]")


if __name__ == "__main__":
    # Running this file directly executes every TestCase defined in it.
    unittest.main(module="__main__")