| import unittest | |
| from rag.retrieve import AnswerDraftResult, generate_answer | |
class FakeResponse:
    """Minimal stand-in for an LLM reply: an object exposing only ``content``."""

    def __init__(self, content: str) -> None:
        # Stored unchanged; the tests in this file pass plain strings here.
        self.content = content
class FakeStructuredRunnable:
    """Fake of the object handed back by ``with_structured_output``.

    ``invoke`` ignores the prompt's content and always returns the canned
    ``structured_response`` supplied at construction time. Every prompt is
    also recorded in ``prompts`` (oldest first), matching the other fakes in
    this file, so a test can inspect what the structured path was asked.
    """

    def __init__(self, structured_response: object) -> None:
        self.structured_response = structured_response
        # Prompts received by invoke(), in call order.
        self.prompts = []

    def invoke(self, prompt: object) -> object:
        """Record ``prompt`` and return the canned structured response."""
        self.prompts.append(prompt)
        return self.structured_response
class FakeStructuredLLM:
    """Fake LLM that offers both ``with_structured_output`` and plain ``invoke``.

    Attributes set here:
        structured_response: canned value the structured runnable will return.
        prompts: prompts passed to the plain ``invoke`` path, in call order.
        schema: the schema last passed to ``with_structured_output``, or
            ``None`` if that method has not been called yet.
    """

    def __init__(self, structured_response: object) -> None:
        self.structured_response = structured_response
        self.prompts = []
        # Initialised up front so a test that inspects ``schema`` gets a clean
        # assertion failure (None is not the expected schema) instead of an
        # AttributeError when with_structured_output() was never called.
        self.schema = None

    def with_structured_output(self, schema: object) -> "FakeStructuredRunnable":
        """Remember ``schema`` and return a runnable yielding the canned response."""
        self.schema = schema
        return FakeStructuredRunnable(self.structured_response)

    def invoke(self, prompt: object) -> "FakeResponse":
        """Record ``prompt`` and return an empty-content plain response."""
        self.prompts.append(prompt)
        return FakeResponse("")
class AnswerStructuredOutputTests(unittest.TestCase):
    """Exercise ``generate_answer`` against LLM fakes with and without
    ``with_structured_output``."""

    # Inputs shared by both tests.
    _QUESTION = "What is climate change?"
    _CONTEXT = "Source [1] climate note"

    def test_generate_answer_uses_structured_output_when_available(self):
        """The draft's answer is returned and the expected schema is requested."""
        draft = AnswerDraftResult(answer="Structured answer [1]")
        structured_llm = FakeStructuredLLM(draft)

        answer = generate_answer(self._QUESTION, self._CONTEXT, structured_llm)

        self.assertEqual(answer, "Structured answer [1]")
        self.assertIs(structured_llm.schema, AnswerDraftResult)

    def test_generate_answer_falls_back_to_plain_text_when_structured_output_is_missing(self):
        """An LLM exposing only ``invoke`` yields its plain response content."""

        class PlainLLM:
            """LLM fake that deliberately lacks ``with_structured_output``."""

            def __init__(self):
                self.prompts = []

            def invoke(self, prompt):
                self.prompts.append(prompt)
                return FakeResponse("Plain answer [1]")

        plain_llm = PlainLLM()

        answer = generate_answer(self._QUESTION, self._CONTEXT, plain_llm)

        self.assertEqual(answer, "Plain answer [1]")
# Run the suite only when this file is executed as a script, not on import.
if __name__ == "__main__":
    unittest.main()