ayushm98 commited on
Commit
d4faa2c
·
1 Parent(s): e65c395

Add tests for routing and heuristics module

Browse files
Files changed (1) hide show
  1. tests/test_router.py +101 -0
tests/test_router.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for the routing module."""
2
+
3
+ import pytest
4
+ from cascade.router.heuristics import classify_by_heuristics, estimate_token_count
5
+ from cascade.router.routing_engine import RoutingEngine, RoutingDecision
6
+
7
+
8
+ class TestHeuristics:
9
+ """Tests for heuristic-based classification."""
10
+
11
+ def test_simple_greetings(self, sample_queries):
12
+ """Simple greetings should be classified as simple."""
13
+ for query in sample_queries["simple"]:
14
+ score, label = classify_by_heuristics(query)
15
+ assert label == "simple" or score < 0.5, f"Failed for: {query}"
16
+
17
+ def test_complex_coding_queries(self, sample_queries):
18
+ """Coding queries should be classified as complex."""
19
+ for query in sample_queries["complex"]:
20
+ score, label = classify_by_heuristics(query)
21
+ assert label == "complex" or score > 0.7, f"Failed for: {query}"
22
+
23
+ def test_code_block_detection(self):
24
+ """Queries with code blocks should be complex."""
25
+ query = "Can you fix this?\n```python\ndef foo():\n pass\n```"
26
+ score, label = classify_by_heuristics(query)
27
+ assert label == "complex"
28
+ assert score >= 0.85
29
+
30
+ def test_keyword_detection(self):
31
+ """Keywords should trigger appropriate classification."""
32
+ # Complex keywords
33
+ assert classify_by_heuristics("write a function")[1] == "complex"
34
+ assert classify_by_heuristics("implement an algorithm")[1] == "complex"
35
+
36
+ # Simple keywords
37
+ assert classify_by_heuristics("hello there")[1] == "simple"
38
+ assert classify_by_heuristics("what is Python")[1] == "simple"
39
+
40
+ def test_length_based_classification(self):
41
+ """Very short queries should be simple."""
42
+ score, label = classify_by_heuristics("Hi")
43
+ assert label == "simple"
44
+ assert score < 0.3
45
+
46
+ def test_estimate_token_count(self):
47
+ """Token estimation should be reasonable."""
48
+ text = "Hello world" # ~3 tokens
49
+ estimate = estimate_token_count(text)
50
+ assert 2 <= estimate <= 5
51
+
52
+
53
+ class TestRoutingEngine:
54
+ """Tests for the routing engine."""
55
+
56
+ def test_routing_decision_creation(self):
57
+ """RoutingDecision should be created correctly."""
58
+ decision = RoutingDecision(
59
+ complexity_score=0.8,
60
+ complexity_label="complex",
61
+ recommended_model="gpt-4o",
62
+ routing_reason="High complexity query",
63
+ )
64
+ assert decision.complexity_score == 0.8
65
+ assert decision.complexity_label == "complex"
66
+ assert decision.recommended_model == "gpt-4o"
67
+
68
+ def test_model_selection_by_threshold(self):
69
+ """Models should be selected based on complexity thresholds."""
70
+ engine = RoutingEngine()
71
+
72
+ # Simple -> local model
73
+ assert engine._select_model(0.2) == "llama3.2"
74
+
75
+ # Medium -> mini model
76
+ assert engine._select_model(0.5) == "gpt-4o-mini"
77
+
78
+ # Complex -> full model
79
+ assert engine._select_model(0.85) == "gpt-4o"
80
+
81
+ def test_threshold_boundaries(self):
82
+ """Test exact threshold boundaries."""
83
+ engine = RoutingEngine()
84
+
85
+ # At lower boundary
86
+ assert engine._select_model(0.35) == "gpt-4o-mini"
87
+
88
+ # At upper boundary
89
+ assert engine._select_model(0.70) == "gpt-4o"
90
+
91
+ @pytest.mark.asyncio
92
+ async def test_route_query_returns_decision(self):
93
+ """route_query should return a valid RoutingDecision."""
94
+ from cascade.router import route_query
95
+
96
+ decision = await route_query("Hello world")
97
+
98
+ assert isinstance(decision, RoutingDecision)
99
+ assert 0 <= decision.complexity_score <= 1
100
+ assert decision.complexity_label in ["simple", "medium", "complex"]
101
+ assert decision.recommended_model in ["llama3.2", "gpt-4o-mini", "gpt-4o"]