File size: 3,873 Bytes
d4faa2c
 
 
 
 
 
 
 
 
 
666d4f6
d4faa2c
666d4f6
 
d4faa2c
 
 
 
 
 
 
 
 
 
 
666d4f6
d4faa2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
666d4f6
 
d4faa2c
 
 
666d4f6
d4faa2c
666d4f6
 
d4faa2c
 
666d4f6
 
d4faa2c
666d4f6
 
d4faa2c
666d4f6
 
d4faa2c
 
 
 
 
666d4f6
 
d4faa2c
666d4f6
 
d4faa2c
 
 
 
 
 
 
 
 
 
 
666d4f6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""Tests for the routing module."""

import pytest
from cascade.router.heuristics import classify_by_heuristics, estimate_token_count
from cascade.router.routing_engine import RoutingEngine, RoutingDecision


class TestHeuristics:
    """Checks for the rule-based complexity classifier and the token estimator."""

    def test_simple_greetings(self):
        """Greetings and one-word replies should fall on the simple side."""
        # Either an explicit "simple" label or a sub-0.5 score is acceptable.
        for text in ("Hello", "Hi there", "Thanks!", "yes", "no"):
            heuristic_score, heuristic_label = classify_by_heuristics(text)
            assert heuristic_label == "simple" or heuristic_score < 0.5, (
                f"Failed for: {text}"
            )

    def test_complex_coding_queries(self, sample_queries):
        """Every query in the fixture's "complex" bucket should rate as complex."""
        # Either an explicit "complex" label or a score above 0.7 is acceptable.
        complex_bucket = sample_queries["complex"]
        for text in complex_bucket:
            heuristic_score, heuristic_label = classify_by_heuristics(text)
            assert heuristic_label == "complex" or heuristic_score > 0.7, (
                f"Failed for: {text}"
            )

    def test_code_block_detection(self):
        """A fenced code block must be labelled complex with score >= 0.85."""
        fenced = "```python\ndef foo():\n    pass\n```"
        result_score, result_label = classify_by_heuristics(fenced)
        assert result_label == "complex"
        assert result_score >= 0.85

    def test_keyword_detection(self):
        """Trigger keywords should steer the label in the expected direction."""
        # Requests to produce code -> complex.
        for phrase in ("write a function", "implement an algorithm"):
            assert classify_by_heuristics(phrase)[1] == "complex"

        # A greeting and a basic definitional question -> simple.
        for phrase in ("hello there", "what is Python"):
            assert classify_by_heuristics(phrase)[1] == "simple"

    def test_length_based_classification(self):
        """A very short query ("Hi") should be simple with a score under 0.3."""
        result_score, result_label = classify_by_heuristics("Hi")
        assert result_label == "simple"
        assert result_score < 0.3

    def test_estimate_token_count(self):
        """The token estimate for a two-word phrase should land in [2, 5]."""
        # "Hello world" is roughly 3 tokens.
        approx_tokens = estimate_token_count("Hello world")
        assert 2 <= approx_tokens <= 5


class TestRoutingEngine:
    """Tests for the routing engine."""

    def test_routing_decision_creation(self):
        """RoutingDecision should store every field it is constructed with."""
        decision = RoutingDecision(
            complexity_score=0.8,
            complexity_label="complex",
            model="gpt-4o",
            reason="High complexity query",
        )
        assert decision.complexity_score == 0.8
        assert decision.complexity_label == "complex"
        assert decision.model == "gpt-4o"
        # ``reason`` is passed in as well, so check it round-trips too.
        assert decision.reason == "High complexity query"

    def test_complexity_label_thresholds(self):
        """Complexity labels should be determined by thresholds."""
        engine = RoutingEngine()

        # Simple -> score < 0.35
        assert engine._get_complexity_label(0.2) == "simple"

        # Medium -> 0.35 <= score <= 0.70
        assert engine._get_complexity_label(0.5) == "medium"

        # Complex -> score > 0.70
        assert engine._get_complexity_label(0.85) == "complex"

    def test_threshold_boundaries(self):
        """Test values exactly on, and immediately around, each threshold.

        Expected bands: simple < 0.35 <= medium <= 0.70 < complex.
        """
        engine = RoutingEngine()

        # Just below lower boundary - simple
        assert engine._get_complexity_label(0.34) == "simple"

        # At lower boundary - still medium
        assert engine._get_complexity_label(0.35) == "medium"

        # At upper boundary - still medium (complex requires score > 0.70)
        assert engine._get_complexity_label(0.70) == "medium"

        # Just above upper boundary - complex
        assert engine._get_complexity_label(0.71) == "complex"

    @pytest.mark.asyncio
    async def test_route_query_returns_decision(self):
        """route_query should return a valid RoutingDecision."""
        # Imported inside the test, as in the original; presumably to exercise
        # the package-level export — confirm if this matters.
        from cascade.router import route_query

        decision = await route_query("Hello world")

        assert isinstance(decision, RoutingDecision)
        assert 0 <= decision.complexity_score <= 1
        assert decision.complexity_label in ("simple", "medium", "complex")
        assert decision.model in ("llama3.2", "gpt-4o-mini", "gpt-4o")