|
|
"""Tests for the routing module.""" |
|
|
|
|
|
import pytest |
|
|
from cascade.router.heuristics import classify_by_heuristics, estimate_token_count |
|
|
from cascade.router.routing_engine import RoutingEngine, RoutingDecision |
|
|
|
|
|
|
|
|
class TestHeuristics:
    """Tests for heuristic-based classification (``classify_by_heuristics``)."""

    # Parametrized so every greeting is reported as its own test case; a loop
    # inside a single test stops at the first failing query and hides the rest.
    @pytest.mark.parametrize("query", ["Hello", "Hi there", "Thanks!", "yes", "no"])
    def test_simple_greetings(self, query):
        """Simple greetings should be classified as simple.

        The check is deliberately lenient: the query passes if it is labelled
        "simple" OR its score is below 0.5.
        """
        score, label = classify_by_heuristics(query)
        assert label == "simple" or score < 0.5, f"Failed for: {query}"

    def test_complex_coding_queries(self, sample_queries):
        """Coding queries should be classified as complex.

        ``sample_queries`` is a fixture not defined in this module (presumably
        in ``conftest.py``); its ``"complex"`` entry is an iterable of queries.
        A query passes if it is labelled "complex" OR its score exceeds 0.7.
        """
        for query in sample_queries["complex"]:
            score, label = classify_by_heuristics(query)
            assert label == "complex" or score > 0.7, f"Failed for: {query}"

    def test_code_block_detection(self):
        """Queries with code blocks should be complex, with a high score."""
        query = "```python\ndef foo():\n pass\n```"
        score, label = classify_by_heuristics(query)
        assert label == "complex"
        assert score >= 0.85

    # One case per keyword query so each mismatch is reported independently.
    @pytest.mark.parametrize(
        ("query", "expected"),
        [
            ("write a function", "complex"),
            ("implement an algorithm", "complex"),
            ("hello there", "simple"),
            ("what is Python", "simple"),
        ],
    )
    def test_keyword_detection(self, query, expected):
        """Keywords should trigger appropriate classification."""
        _, label = classify_by_heuristics(query)
        assert label == expected, f"Failed for: {query}"

    def test_length_based_classification(self):
        """Very short queries should be simple, with a low score."""
        score, label = classify_by_heuristics("Hi")
        assert label == "simple"
        assert score < 0.3

    def test_estimate_token_count(self):
        """Token estimation for a two-word phrase should land in [2, 5]."""
        estimate = estimate_token_count("Hello world")
        assert 2 <= estimate <= 5
|
|
|
|
|
|
|
|
class TestRoutingEngine:
    """Tests for the routing engine and its ``RoutingDecision`` result type."""

    def test_routing_decision_creation(self):
        """RoutingDecision should expose every field it is constructed with."""
        decision = RoutingDecision(
            complexity_score=0.8,
            complexity_label="complex",
            model="gpt-4o",
            reason="High complexity query",
        )
        assert decision.complexity_score == 0.8
        assert decision.complexity_label == "complex"
        assert decision.model == "gpt-4o"
        # ``reason`` is passed to the constructor, so it must round-trip too.
        assert decision.reason == "High complexity query"

    @pytest.mark.parametrize(
        ("score", "expected"),
        [(0.2, "simple"), (0.5, "medium"), (0.85, "complex")],
    )
    def test_complexity_label_thresholds(self, score, expected):
        """Complexity labels should be determined by thresholds."""
        engine = RoutingEngine()
        assert engine._get_complexity_label(score) == expected

    @pytest.mark.parametrize(
        ("score", "expected"),
        [
            # NOTE(review): presumably the simple/medium cut-off itself,
            # i.e. the threshold is inclusive — confirm against RoutingEngine.
            (0.35, "medium"),
            # NOTE(review): presumably just above a ~0.7 medium/complex
            # cut-off — confirm against RoutingEngine.
            (0.71, "complex"),
        ],
    )
    def test_threshold_boundaries(self, score, expected):
        """Scores at or just past a label threshold should map to the higher label."""
        engine = RoutingEngine()
        assert engine._get_complexity_label(score) == expected

    @pytest.mark.asyncio
    async def test_route_query_returns_decision(self):
        """route_query should return a valid RoutingDecision.

        Exercises the public ``cascade.router.route_query`` coroutine end to
        end; whether this touches external model backends depends on the
        router's configuration (not visible here).
        """
        from cascade.router import route_query

        decision = await route_query("Hello world")

        assert isinstance(decision, RoutingDecision)
        assert 0 <= decision.complexity_score <= 1
        assert decision.complexity_label in ["simple", "medium", "complex"]
        assert decision.model in ["llama3.2", "gpt-4o-mini", "gpt-4o"]
|
|
|