Fix API routes and tests to use correct attribute names
Browse files- poetry.lock +0 -0
- src/cascade/api/routes.py +12 -2
- src/cascade/config.py +4 -0
- tests/test_router.py +20 -19
poetry.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
src/cascade/api/routes.py
CHANGED
|
@@ -105,7 +105,7 @@ async def chat_completions(request: ChatCompletionRequest):
|
|
| 105 |
|
| 106 |
# Determine final model
|
| 107 |
if request.model == "auto" or request.model is None:
|
| 108 |
-
final_model = routing.
|
| 109 |
else:
|
| 110 |
final_model = request.model
|
| 111 |
|
|
@@ -121,12 +121,22 @@ async def chat_completions(request: ChatCompletionRequest):
|
|
| 121 |
provider = await get_provider(final_model)
|
| 122 |
|
| 123 |
try:
|
| 124 |
-
|
| 125 |
model=final_model,
|
| 126 |
messages=[{"role": m.role, "content": m.content} for m in request.messages],
|
| 127 |
temperature=request.temperature,
|
| 128 |
max_tokens=request.max_tokens,
|
| 129 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
except Exception as e:
|
| 131 |
raise HTTPException(status_code=500, detail=str(e))
|
| 132 |
|
|
|
|
| 105 |
|
| 106 |
# Determine final model
|
| 107 |
if request.model == "auto" or request.model is None:
|
| 108 |
+
final_model = routing.model
|
| 109 |
else:
|
| 110 |
final_model = request.model
|
| 111 |
|
|
|
|
| 121 |
provider = await get_provider(final_model)
|
| 122 |
|
| 123 |
try:
|
| 124 |
+
llm_response = await provider.complete(
|
| 125 |
model=final_model,
|
| 126 |
messages=[{"role": m.role, "content": m.content} for m in request.messages],
|
| 127 |
temperature=request.temperature,
|
| 128 |
max_tokens=request.max_tokens,
|
| 129 |
)
|
| 130 |
+
# Convert LLMResponse to dict format for compatibility
|
| 131 |
+
response = {
|
| 132 |
+
"id": f"cascade-{int(time.time())}",
|
| 133 |
+
"choices": [{"message": {"content": llm_response.content}}],
|
| 134 |
+
"usage": {
|
| 135 |
+
"prompt_tokens": llm_response.prompt_tokens,
|
| 136 |
+
"completion_tokens": llm_response.completion_tokens,
|
| 137 |
+
"total_tokens": llm_response.prompt_tokens + llm_response.completion_tokens,
|
| 138 |
+
},
|
| 139 |
+
}
|
| 140 |
except Exception as e:
|
| 141 |
raise HTTPException(status_code=500, detail=str(e))
|
| 142 |
|
src/cascade/config.py
CHANGED
|
@@ -43,3 +43,7 @@ class Settings(BaseSettings):
|
|
| 43 |
def get_settings() -> Settings:
|
| 44 |
"""Get cached settings instance."""
|
| 45 |
return Settings()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
def get_settings() -> Settings:
|
| 44 |
"""Get cached settings instance."""
|
| 45 |
return Settings()
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# Convenience instance for direct import
|
| 49 |
+
settings = get_settings()
|
tests/test_router.py
CHANGED
|
@@ -8,9 +8,10 @@ from cascade.router.routing_engine import RoutingEngine, RoutingDecision
|
|
| 8 |
class TestHeuristics:
|
| 9 |
"""Tests for heuristic-based classification."""
|
| 10 |
|
| 11 |
-
def test_simple_greetings(self
|
| 12 |
"""Simple greetings should be classified as simple."""
|
| 13 |
-
|
|
|
|
| 14 |
score, label = classify_by_heuristics(query)
|
| 15 |
assert label == "simple" or score < 0.5, f"Failed for: {query}"
|
| 16 |
|
|
@@ -22,7 +23,7 @@ class TestHeuristics:
|
|
| 22 |
|
| 23 |
def test_code_block_detection(self):
|
| 24 |
"""Queries with code blocks should be complex."""
|
| 25 |
-
query = "
|
| 26 |
score, label = classify_by_heuristics(query)
|
| 27 |
assert label == "complex"
|
| 28 |
assert score >= 0.85
|
|
@@ -58,35 +59,35 @@ class TestRoutingEngine:
|
|
| 58 |
decision = RoutingDecision(
|
| 59 |
complexity_score=0.8,
|
| 60 |
complexity_label="complex",
|
| 61 |
-
|
| 62 |
-
|
| 63 |
)
|
| 64 |
assert decision.complexity_score == 0.8
|
| 65 |
assert decision.complexity_label == "complex"
|
| 66 |
-
assert decision.
|
| 67 |
|
| 68 |
-
def
|
| 69 |
-
"""
|
| 70 |
engine = RoutingEngine()
|
| 71 |
|
| 72 |
-
# Simple ->
|
| 73 |
-
assert engine.
|
| 74 |
|
| 75 |
-
# Medium ->
|
| 76 |
-
assert engine.
|
| 77 |
|
| 78 |
-
# Complex ->
|
| 79 |
-
assert engine.
|
| 80 |
|
| 81 |
def test_threshold_boundaries(self):
|
| 82 |
"""Test exact threshold boundaries."""
|
| 83 |
engine = RoutingEngine()
|
| 84 |
|
| 85 |
-
# At lower boundary
|
| 86 |
-
assert engine.
|
| 87 |
|
| 88 |
-
#
|
| 89 |
-
assert engine.
|
| 90 |
|
| 91 |
@pytest.mark.asyncio
|
| 92 |
async def test_route_query_returns_decision(self):
|
|
@@ -98,4 +99,4 @@ class TestRoutingEngine:
|
|
| 98 |
assert isinstance(decision, RoutingDecision)
|
| 99 |
assert 0 <= decision.complexity_score <= 1
|
| 100 |
assert decision.complexity_label in ["simple", "medium", "complex"]
|
| 101 |
-
assert decision.
|
|
|
|
| 8 |
class TestHeuristics:
|
| 9 |
"""Tests for heuristic-based classification."""
|
| 10 |
|
| 11 |
+
def test_simple_greetings(self):
|
| 12 |
"""Simple greetings should be classified as simple."""
|
| 13 |
+
simple_queries = ["Hello", "Hi there", "Thanks!", "yes", "no"]
|
| 14 |
+
for query in simple_queries:
|
| 15 |
score, label = classify_by_heuristics(query)
|
| 16 |
assert label == "simple" or score < 0.5, f"Failed for: {query}"
|
| 17 |
|
|
|
|
| 23 |
|
| 24 |
def test_code_block_detection(self):
|
| 25 |
"""Queries with code blocks should be complex."""
|
| 26 |
+
query = "```python\ndef foo():\n pass\n```"
|
| 27 |
score, label = classify_by_heuristics(query)
|
| 28 |
assert label == "complex"
|
| 29 |
assert score >= 0.85
|
|
|
|
| 59 |
decision = RoutingDecision(
|
| 60 |
complexity_score=0.8,
|
| 61 |
complexity_label="complex",
|
| 62 |
+
model="gpt-4o",
|
| 63 |
+
reason="High complexity query",
|
| 64 |
)
|
| 65 |
assert decision.complexity_score == 0.8
|
| 66 |
assert decision.complexity_label == "complex"
|
| 67 |
+
assert decision.model == "gpt-4o"
|
| 68 |
|
| 69 |
+
def test_complexity_label_thresholds(self):
|
| 70 |
+
"""Complexity labels should be determined by thresholds."""
|
| 71 |
engine = RoutingEngine()
|
| 72 |
|
| 73 |
+
# Simple -> score < 0.35
|
| 74 |
+
assert engine._get_complexity_label(0.2) == "simple"
|
| 75 |
|
| 76 |
+
# Medium -> 0.35 <= score <= 0.70
|
| 77 |
+
assert engine._get_complexity_label(0.5) == "medium"
|
| 78 |
|
| 79 |
+
# Complex -> score > 0.70
|
| 80 |
+
assert engine._get_complexity_label(0.85) == "complex"
|
| 81 |
|
| 82 |
def test_threshold_boundaries(self):
|
| 83 |
"""Test exact threshold boundaries."""
|
| 84 |
engine = RoutingEngine()
|
| 85 |
|
| 86 |
+
# At lower boundary - still medium
|
| 87 |
+
assert engine._get_complexity_label(0.35) == "medium"
|
| 88 |
|
| 89 |
+
# Just above upper boundary - complex
|
| 90 |
+
assert engine._get_complexity_label(0.71) == "complex"
|
| 91 |
|
| 92 |
@pytest.mark.asyncio
|
| 93 |
async def test_route_query_returns_decision(self):
|
|
|
|
| 99 |
assert isinstance(decision, RoutingDecision)
|
| 100 |
assert 0 <= decision.complexity_score <= 1
|
| 101 |
assert decision.complexity_label in ["simple", "medium", "complex"]
|
| 102 |
+
assert decision.model in ["llama3.2", "gpt-4o-mini", "gpt-4o"]
|