"""
Comprehensive tests for:
1. Per-Tool Latency Prediction
2. Context-Aware MCP Routing
3. Tool Output Schemas

Tests all three new features for intelligent tool selection and output validation.
"""

import pytest
from unittest.mock import AsyncMock
from backend.api.services.tool_metadata import (
    get_tool_latency_estimate,
    estimate_path_latency,
    get_fastest_path,
    validate_tool_output,
    get_tool_schema,
    TOOL_LATENCY_METADATA,
    TOOL_OUTPUT_SCHEMAS
)
from backend.api.services.tool_selector import ToolSelector
from backend.api.services.agent_orchestrator import AgentOrchestrator


class TestLatencyPrediction:
    """Test per-tool latency prediction"""
    
    def test_get_tool_latency_estimate_basic(self):
        """Test basic latency estimation without context"""
        rag_latency = get_tool_latency_estimate("rag")
        web_latency = get_tool_latency_estimate("web")
        admin_latency = get_tool_latency_estimate("admin")
        llm_latency = get_tool_latency_estimate("llm")
        
        # Check that latencies are within expected ranges
        assert 60 <= rag_latency <= 120
        assert 400 <= web_latency <= 1800
        assert 5 <= admin_latency <= 20
        assert 500 <= llm_latency <= 5000
    
    def test_get_tool_latency_estimate_with_context(self):
        """Test latency estimation with context"""
        # RAG with long query
        rag_long = get_tool_latency_estimate("rag", {"query_length": 200})
        rag_short = get_tool_latency_estimate("rag", {"query_length": 10})
        
        assert rag_long >= rag_short  # Longer queries should take more time
        
        # Web with complexity
        web_complex = get_tool_latency_estimate("web", {"query_complexity": "high"})
        web_simple = get_tool_latency_estimate("web", {"query_complexity": "low"})
        
        assert web_complex >= web_simple  # Complex queries should take more time
    
    def test_estimate_path_latency(self):
        """Test total latency estimation for tool sequences"""
        # Single tool
        single = estimate_path_latency(["admin"])
        assert single > 0
        assert single <= 20
        
        # Multiple tools
        multi = estimate_path_latency(["rag", "web", "llm"])
        assert multi > 0
        # Path latency is the sum of the per-tool estimates, so it must be
        # at least as large as any single tool's estimate
        assert multi >= get_tool_latency_estimate("rag")
        assert multi >= get_tool_latency_estimate("web")
        assert multi >= get_tool_latency_estimate("llm")
    
    def test_get_fastest_path(self):
        """Test fastest path optimization"""
        tools = ["llm", "admin", "rag", "web"]
        fastest = get_fastest_path(tools)
        
        # Should be sorted by latency (fastest first)
        assert len(fastest) == len(tools)
        assert "admin" in fastest  # Fastest tool
        assert fastest[0] == "admin"  # Should be first
        
        # Verify order is optimized
        latencies = [get_tool_latency_estimate(t) for t in fastest]
        assert latencies == sorted(latencies)  # Should be in ascending order
    
    def test_latency_metadata_structure(self):
        """Test that latency metadata has correct structure"""
        for tool_name, metadata in TOOL_LATENCY_METADATA.items():
            assert metadata.tool_name == tool_name
            assert metadata.min_ms > 0
            assert metadata.max_ms >= metadata.min_ms
            assert metadata.avg_ms >= metadata.min_ms
            assert metadata.avg_ms <= metadata.max_ms
            assert len(metadata.description) > 0


class TestToolOutputSchemas:
    """Test tool output schema validation"""
    
    def test_get_tool_schema(self):
        """Test schema retrieval"""
        rag_schema = get_tool_schema("rag")
        web_schema = get_tool_schema("web")
        admin_schema = get_tool_schema("admin")
        llm_schema = get_tool_schema("llm")
        
        assert rag_schema is not None
        assert web_schema is not None
        assert admin_schema is not None
        assert llm_schema is not None
        
        assert rag_schema.tool_name == "rag"
        assert web_schema.tool_name == "web"
        assert admin_schema.tool_name == "admin"
        assert llm_schema.tool_name == "llm"
    
    def test_validate_rag_output_valid(self):
        """Test validation of valid RAG output"""
        valid_rag = {
            "results": [
                {
                    "text": "Document chunk",
                    "similarity": 0.85,
                    "metadata": {"title": "Test"},
                    "doc_id": "doc123"
                }
            ],
            "query": "test query",
            "tenant_id": "tenant1",
            "hits_count": 1,
            "avg_score": 0.85,
            "top_score": 0.85,
            "latency_ms": 90
        }
        
        is_valid, error = validate_tool_output("rag", valid_rag)
        assert is_valid is True
        assert error is None
    
    def test_validate_rag_output_missing_field(self):
        """Test validation catches missing required fields"""
        invalid_rag = {
            "results": [],
            # Missing "query" and "tenant_id"
            "hits_count": 0
        }
        
        is_valid, error = validate_tool_output("rag", invalid_rag)
        assert is_valid is False
        assert "Missing required field" in error
    
    def test_validate_web_output_valid(self):
        """Test validation of valid Web output"""
        valid_web = {
            "results": [
                {
                    "title": "Result Title",
                    "snippet": "Result snippet",
                    "link": "https://example.com",
                    "displayLink": "example.com"
                }
            ],
            "query": "search query",
            "total_results": 10,
            "latency_ms": 800
        }
        
        is_valid, error = validate_tool_output("web", valid_web)
        assert is_valid is True
        assert error is None
    
    def test_validate_admin_output_valid(self):
        """Test validation of valid Admin output"""
        valid_admin = {
            "violations": [
                {
                    "rule_id": "rule1",
                    "rule_pattern": ".*password.*",
                    "severity": "high",
                    "matched_text": "password",
                    "confidence": 0.95,
                    "message_preview": "User asked for password"
                }
            ],
            "checked": True,
            "rules_count": 5,
            "latency_ms": 10
        }
        
        is_valid, error = validate_tool_output("admin", valid_admin)
        assert is_valid is True
        assert error is None
    
    def test_validate_llm_output_valid(self):
        """Test validation of valid LLM output"""
        valid_llm = {
            "text": "Generated response",
            "tokens_used": 150,
            "latency_ms": 2000,
            "model": "llama3.1:latest",
            "temperature": 0.0
        }
        
        is_valid, error = validate_tool_output("llm", valid_llm)
        assert is_valid is True
        assert error is None
    
    def test_validate_type_mismatch(self):
        """Test validation catches type mismatches"""
        invalid_rag = {
            "results": "not an array",  # Should be array
            "query": "test",
            "tenant_id": "tenant1"
        }
        
        is_valid, error = validate_tool_output("rag", invalid_rag)
        assert is_valid is False
        assert "must be array" in error
    
    def test_schema_examples(self):
        """Test that all schemas have examples"""
        for tool_name, schema in TOOL_OUTPUT_SCHEMAS.items():
            assert schema.example is not None
            assert isinstance(schema.example, dict)
            # Example should be valid
            is_valid, error = validate_tool_output(tool_name, schema.example)
            assert is_valid is True, f"Schema example for {tool_name} is invalid: {error}"


class TestContextAwareRouting:
    """Test context-aware MCP routing"""
    
    @pytest.fixture
    def tool_selector(self):
        """Create a ToolSelector instance"""
        return ToolSelector(llm_client=None)
    
    def test_analyze_context_rag_high_score(self, tool_selector):
        """Test context analysis when RAG returns high score"""
        rag_results = [
            {"similarity": 0.85, "text": "High quality result"},
            {"similarity": 0.90, "text": "Another high quality result"}
        ]
        memory = []
        admin_violations = []
        tool_scores = {"rag_fitness": 0.8, "web_fitness": 0.5}
        
        hints = tool_selector._analyze_context(rag_results, memory, admin_violations, tool_scores)
        
        assert hints.get("skip_web_if_rag_high") is True
        assert hints.get("rag_high_confidence") is True
    
    def test_analyze_context_rag_low_score(self, tool_selector):
        """Test context analysis when RAG returns low score"""
        rag_results = [
            {"similarity": 0.3, "text": "Low quality result"}
        ]
        memory = []
        admin_violations = []
        tool_scores = {"rag_fitness": 0.3, "web_fitness": 0.7}
        
        hints = tool_selector._analyze_context(rag_results, memory, admin_violations, tool_scores)
        
        # Should not skip web if RAG score is low
        assert hints.get("skip_web_if_rag_high") is not True
    
    def test_analyze_context_memory_relevant(self, tool_selector):
        """Test context analysis when relevant memory exists"""
        rag_results = []
        memory = [
            {
                "tool": "rag",
                "result": {
                    "results": [
                        {"similarity": 0.80, "text": "Recent RAG result"}
                    ]
                }
            }
        ]
        admin_violations = []
        tool_scores = {}
        
        hints = tool_selector._analyze_context(rag_results, memory, admin_violations, tool_scores)
        
        assert hints.get("has_relevant_memory") is True
        # The fixture's memory hit (similarity 0.80) clears the 0.75 threshold,
        # so the selector should suggest skipping a redundant RAG call
        assert hints.get("skip_rag_if_memory") is True
    
    def test_analyze_context_admin_critical(self, tool_selector):
        """Test context analysis when admin violation is critical"""
        rag_results = []
        memory = []
        admin_violations = [
            {
                "severity": "critical",
                "rule_id": "rule1",
                "matched_text": "sensitive data"
            }
        ]
        tool_scores = {}
        
        hints = tool_selector._analyze_context(rag_results, memory, admin_violations, tool_scores)
        
        assert hints.get("skip_agent_reasoning") is True
        assert hints.get("critical_violation") is True
    
    def test_analyze_context_admin_low_severity(self, tool_selector):
        """Test context analysis when admin violation is low severity"""
        rag_results = []
        memory = []
        admin_violations = [
            {
                "severity": "low",
                "rule_id": "rule1",
                "matched_text": "minor issue"
            }
        ]
        tool_scores = {}
        
        hints = tool_selector._analyze_context(rag_results, memory, admin_violations, tool_scores)
        
        # Low severity should not skip reasoning
        assert hints.get("skip_agent_reasoning") is not True
    
    @pytest.mark.asyncio
    async def test_tool_selection_with_context_hints(self, tool_selector):
        """Test tool selection uses context hints"""
        # Mock LLM client
        tool_selector.llm_client = AsyncMock()
        
        # Context with high RAG score
        ctx = {
            "tenant_id": "test_tenant",
            "rag_results": [
                {"similarity": 0.85, "text": "High quality result"}
            ],
            "tool_scores": {
                "rag_fitness": 0.8,
                "web_fitness": 0.6,
                "llm_only": 0.3
            },
            "memory": [],
            "admin_violations": []
        }
        
        decision = await tool_selector.select("general", "What is our company policy?", ctx)
        
        # Should include latency estimates in reason
        assert "latency" in decision.reason.lower() or "est." in decision.reason.lower()
        
        # Check that steps have latency estimates (for non-LLM tools)
        if decision.tool_input and "steps" in decision.tool_input:
            steps = decision.tool_input["steps"]
            for step in steps:
                if isinstance(step, dict) and "input" in step and step.get("tool") != "llm":
                    # Non-LLM steps should carry a latency estimate, unless they
                    # are annotated as part of a parallel group
                    assert "_estimated_latency_ms" in step["input"] or "parallel" in step
    
    @pytest.mark.asyncio
    async def test_tool_selection_skips_web_on_high_rag(self, tool_selector):
        """Test that tool selection skips web when RAG has high score"""
        tool_selector.llm_client = AsyncMock()
        
        ctx = {
            "tenant_id": "test_tenant",
            "rag_results": [
                {"similarity": 0.90, "text": "Very high quality result"}
            ],
            "tool_scores": {
                "rag_fitness": 0.9,
                "web_fitness": 0.7,
                "llm_only": 0.2
            },
            "memory": [],
            "admin_violations": []
        }
        
        decision = await tool_selector.select("general", "What is our internal policy?", ctx)
        
        # Check reason includes context hint
        assert "skip web" in decision.reason.lower() or "rag high" in decision.reason.lower() or "context" in decision.reason.lower()
    
    @pytest.mark.asyncio
    async def test_tool_selection_admin_critical_skip_reasoning(self, tool_selector):
        """Test that tool selection skips reasoning for critical admin violations"""
        tool_selector.llm_client = None  # No LLM needed for admin-only path
        
        ctx = {
            "tenant_id": "test_tenant",
            "rag_results": [],
            "tool_scores": {},
            "memory": [],
            "admin_violations": [
                {
                    "severity": "critical",
                    "rule_id": "rule1",
                    "matched_text": "critical violation"
                }
            ]
        }
        
        decision = await tool_selector.select("admin", "User trying to access sensitive data", ctx)
        
        # Should skip LLM reasoning for critical violations
        if decision.tool_input and "steps" in decision.tool_input:
            steps = decision.tool_input["steps"]
            # Should have admin step but may skip LLM
            has_admin = any(s.get("tool") == "admin" for s in steps if isinstance(s, dict))
            assert has_admin


class TestOrchestratorIntegration:
    """Test orchestrator integration with new features"""
    
    @pytest.fixture
    def orchestrator(self):
        """Create an AgentOrchestrator instance"""
        return AgentOrchestrator(
            rag_mcp_url="http://localhost:8900/rag",
            web_mcp_url="http://localhost:8900/web",
            admin_mcp_url="http://localhost:8900/admin",
            llm_backend="ollama"
        )
    
    def test_format_rag_output(self, orchestrator):
        """Test RAG output formatting"""
        raw_output = {
            "results": [
                {"text": "Chunk 1", "similarity": 0.85},
                {"text": "Chunk 2", "similarity": 0.75}
            ],
            "query": "test query"
        }
        
        formatted = orchestrator._format_tool_output("rag", raw_output, 90)
        
        # Check schema compliance
        assert "results" in formatted
        assert "query" in formatted
        assert "tenant_id" in formatted
        assert "hits_count" in formatted
        assert "avg_score" in formatted
        assert "top_score" in formatted
        assert "latency_ms" in formatted
        
        # Validate against schema
        is_valid, error = validate_tool_output("rag", formatted)
        assert is_valid is True, f"Formatted RAG output invalid: {error}"
    
    def test_format_web_output(self, orchestrator):
        """Test Web output formatting"""
        raw_output = {
            "items": [
                {
                    "title": "Result Title",
                    "snippet": "Result snippet",
                    "link": "https://example.com"
                }
            ]
        }
        
        formatted = orchestrator._format_tool_output("web", raw_output, 800)
        
        # Check schema compliance
        assert "results" in formatted
        assert "query" in formatted
        assert "total_results" in formatted
        assert "latency_ms" in formatted
        
        # Validate against schema
        is_valid, error = validate_tool_output("web", formatted)
        assert is_valid is True, f"Formatted Web output invalid: {error}"
    
    def test_format_admin_output(self, orchestrator):
        """Test Admin output formatting"""
        raw_output = {
            "matches": [
                {
                    "rule_id": "rule1",
                    "pattern": ".*password.*",
                    "severity": "high",
                    "text": "password",
                    "confidence": 0.95
                }
            ]
        }
        
        formatted = orchestrator._format_tool_output("admin", raw_output, 10)
        
        # Check schema compliance
        assert "violations" in formatted
        assert "checked" in formatted
        assert "rules_count" in formatted
        assert "latency_ms" in formatted
        
        # Validate against schema
        is_valid, error = validate_tool_output("admin", formatted)
        assert is_valid is True, f"Formatted Admin output invalid: {error}"
    
    def test_format_llm_output(self, orchestrator):
        """Test LLM output formatting"""
        raw_output = "This is a generated response from the LLM."
        
        formatted = orchestrator._format_tool_output("llm", raw_output, 2000)
        
        # Check schema compliance
        assert "text" in formatted
        assert "tokens_used" in formatted
        assert "latency_ms" in formatted
        assert "model" in formatted
        assert "temperature" in formatted
        
        # Validate against schema
        is_valid, error = validate_tool_output("llm", formatted)
        assert is_valid is True, f"Formatted LLM output invalid: {error}"
    
    def test_format_output_handles_missing_fields(self, orchestrator):
        """Test output formatting handles missing fields gracefully"""
        # Minimal RAG output
        minimal = {"results": []}
        
        formatted = orchestrator._format_tool_output("rag", minimal, 90)
        
        # Should have all required fields with defaults
        assert "query" in formatted
        assert "tenant_id" in formatted
        assert "hits_count" in formatted
        assert formatted["hits_count"] == 0


class TestEndToEndRouting:
    """End-to-end tests for context-aware routing"""
    
    @pytest.mark.asyncio
    async def test_routing_with_high_rag_score(self):
        """Test that high RAG score prevents web search"""
        selector = ToolSelector(llm_client=None)
        
        ctx = {
            "tenant_id": "test",
            "rag_results": [{"similarity": 0.92, "text": "Perfect match"}],
            "tool_scores": {"rag_fitness": 0.9, "web_fitness": 0.7},
            "memory": [],
            "admin_violations": []
        }
        
        decision = await selector.select("general", "What is our policy?", ctx)
        
        # Check that context hints are applied
        if decision.tool_input and "steps" in decision.tool_input:
            steps = decision.tool_input["steps"]
            tool_names = [s.get("tool") for s in steps if isinstance(s, dict) and "tool" in s]
            
            # RAG (or a direct LLM answer) should be present; web may be
            # skipped because of the high RAG score
            assert "rag" in tool_names or "llm" in tool_names
    
    @pytest.mark.asyncio
    async def test_routing_with_memory(self):
        """Test that relevant memory prevents redundant RAG call"""
        selector = ToolSelector(llm_client=None)
        
        ctx = {
            "tenant_id": "test",
            "rag_results": [],
            "tool_scores": {"rag_fitness": 0.6},
            "memory": [
                {
                    "tool": "rag",
                    "result": {
                        "results": [{"similarity": 0.85, "text": "Recent result"}]
                    }
                }
            ],
            "admin_violations": []
        }
        
        decision = await selector.select("general", "Tell me about our policy", ctx)
        
        # Context should be analyzed
        # (Actual behavior depends on implementation, but should use memory)
        assert decision is not None


if __name__ == "__main__":
    pytest.main([__file__, "-v", "--tb=short"])