File size: 8,299 Bytes
8b30412
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
"""Integration tests for Langgraph state manager."""

import pytest
from langchain_core.messages import HumanMessage

from chatassistant_retail.state import ConversationState, LanggraphManager


class MockLLMClient:
    """Fake LLM client that answers every request with one canned reply.

    All methods are coroutines so callers can ``await`` them. Responses use
    a plain-dict, chat-completion-like layout:
    ``{"choices": [{"message": {...}}]}``.
    """

    async def call_llm(self, messages, tools=None):
        """Ignore *messages* and *tools* and return the fixed dict response."""
        assistant_message = {
            "content": "This is a test response.",
            "role": "assistant",
            "tool_calls": None,
        }
        return {"choices": [{"message": assistant_message}]}

    async def extract_response_content(self, response):
        """Return the first choice's message content, or "" if unavailable.

        Non-dict inputs, dicts without a ``"choices"`` key and empty/falsy
        ``choices`` all yield the empty string.
        """
        if not isinstance(response, dict) or "choices" not in response:
            return ""
        choices = response["choices"]
        if not choices:
            return ""
        first_message = choices[0].get("message", {})
        return first_message.get("content", "")

    async def extract_tool_calls(self, response):
        """Flatten the first choice's tool calls into ``{"name", "arguments"}`` dicts.

        Missing pieces fall back to ``""`` for the name and ``{}`` for the
        arguments; an absent or falsy ``tool_calls`` entry yields ``[]``.
        """
        if not isinstance(response, dict) or "choices" not in response:
            return []
        choices = response["choices"]
        if not choices:
            return []
        raw_calls = choices[0].get("message", {}).get("tool_calls", [])
        if not raw_calls:
            return []
        parsed = []
        for call in raw_calls:
            function = call.get("function", {})
            parsed.append(
                {
                    "name": function.get("name", ""),
                    "arguments": function.get("arguments", {}),
                }
            )
        return parsed


class MockRAGRetriever:
    """Fake retriever whose search always finds the same single product."""

    async def retrieve(self, query, top_k=5):
        """Return a fresh one-element product list; *query* and *top_k* are ignored."""
        product = dict(
            sku="SKU-10000",
            name="Test Product",
            category="Electronics",
            price=99.99,
            current_stock=5,
            reorder_level=10,
        )
        return [product]


class MockToolExecutor:
    """Fake tool executor that never runs anything and always reports success."""

    async def execute_tool(self, tool_name, args):
        """Echo *tool_name* and *args* back inside a success payload."""
        summary = f"Executed {tool_name} with args {args}"
        return dict(success=True, message=summary)


class TestLanggraphManager:
    """Exercise LanggraphManager's individual nodes and its end-to-end ``process``."""

    @staticmethod
    def _build_manager(llm_client=None):
        """Wire a manager to fresh mocks; *llm_client* replaces the default mock LLM."""
        if llm_client is None:
            llm_client = MockLLMClient()
        return LanggraphManager(llm_client, MockRAGRetriever(), MockToolExecutor())

    @staticmethod
    def _state_for(text, **extra):
        """Build a one-message ConversationState for the shared test session."""
        return ConversationState(
            session_id="test-session",
            messages=[HumanMessage(content=text)],
            **extra,
        )

    @pytest.mark.asyncio
    async def test_greeting_classification(self):
        """"Hello" should be classified with the greeting intent."""
        manager = self._build_manager()
        classified = await manager._classify_intent_node(self._state_for("Hello"))
        assert classified.current_intent == "greeting"

    @pytest.mark.asyncio
    async def test_rag_classification(self):
        """A product search request should be routed to RAG."""
        manager = self._build_manager()
        classified = await manager._classify_intent_node(
            self._state_for("Find me a wireless mouse")
        )
        assert classified.current_intent == "rag"
        assert classified.needs_rag is True

    @pytest.mark.asyncio
    async def test_tool_classification(self):
        """An inventory-check request should be routed to a tool."""
        manager = self._build_manager()
        classified = await manager._classify_intent_node(
            self._state_for("Check low stock items")
        )
        assert classified.current_intent == "tool"
        assert classified.needs_tool is True

    @pytest.mark.asyncio
    async def test_rag_retrieval_node(self):
        """The retrieval node should put the mock retriever's products in context."""
        manager = self._build_manager()
        retrieved = await manager._rag_retrieval_node(
            self._state_for("Find wireless mouse")
        )

        assert "products" in retrieved.context
        products = retrieved.context["products"]
        assert len(products) > 0
        assert products[0]["name"] == "Test Product"

    @pytest.mark.asyncio
    async def test_generate_response_node(self):
        """The response node should append the mock LLM's reply as a second message."""
        manager = self._build_manager()
        answered = await manager._generate_response_node(
            self._state_for("Hello", context={"products": []})
        )

        # One human message in, so exactly one AI reply should follow it.
        assert len(answered.messages) == 2
        assert answered.messages[1].content == "This is a test response."

    @pytest.mark.asyncio
    async def test_full_workflow_greeting(self):
        """A greeting should run through ``process`` cleanly and get a reply."""
        manager = self._build_manager()
        final_state = await manager.process(self._state_for("Hi there"))

        assert len(final_state.messages) == 2
        assert final_state.current_intent == "greeting"
        assert final_state.error is None

    @pytest.mark.asyncio
    async def test_full_workflow_rag(self):
        """A product query should retrieve products and produce a reply via ``process``."""
        manager = self._build_manager()
        final_state = await manager.process(
            self._state_for("Find electronics products")
        )

        assert "products" in final_state.context
        assert len(final_state.messages) == 2
        assert final_state.current_intent == "rag"

    @pytest.mark.asyncio
    async def test_error_handling(self):
        """An LLM that raises should leave ``error`` set instead of propagating."""

        class FailingLLMClient:
            async def call_llm(self, messages, tools=None):
                raise Exception("LLM error")

        manager = self._build_manager(FailingLLMClient())
        final_state = await manager.process(self._state_for("Hello"))

        assert final_state.error is not None


if __name__ == "__main__":
    # Run this module's tests verbosely when executed as a script. pytest.main
    # returns the session's exit code. Raising SystemExit with it means a
    # failing run exits non-zero instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))