ayushm98 commited on
Commit
6362af9
Β·
1 Parent(s): fc527f2

Add routing visualization with complexity gauge

Browse files
Files changed (1) hide show
  1. src/cascade/ui/components/routing.py +150 -0
src/cascade/ui/components/routing.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Routing visualization component."""
2
+
3
+ import streamlit as st
4
+ import plotly.graph_objects as go
5
+ import httpx
6
+ from typing import Optional
7
+
8
+ # Import local router for demo
9
+ try:
10
+ from cascade.router import route_query, classify_by_heuristics
11
+ HAS_ROUTER = True
12
+ except ImportError:
13
+ HAS_ROUTER = False
14
+
15
+
16
+ def classify_query_demo(query: str) -> dict:
17
+ """Classify query using local router or heuristics."""
18
+ if HAS_ROUTER:
19
+ try:
20
+ import asyncio
21
+ result = asyncio.run(route_query(query))
22
+ return {
23
+ "score": result.complexity_score,
24
+ "label": result.complexity_label,
25
+ "model": result.recommended_model,
26
+ "reason": result.routing_reason,
27
+ }
28
+ except Exception:
29
+ pass
30
+
31
+ # Fallback to simple heuristics
32
+ score, label = classify_by_heuristics(query) if HAS_ROUTER else (0.5, "medium")
33
+ models = {"simple": "llama3.2", "medium": "gpt-4o-mini", "complex": "gpt-4o"}
34
+ return {
35
+ "score": score,
36
+ "label": label,
37
+ "model": models.get(label, "gpt-4o-mini"),
38
+ "reason": "Classified using heuristics",
39
+ }
40
+
41
+
42
+ def render_complexity_gauge(score: float):
43
+ """Render a gauge chart for complexity score."""
44
+ fig = go.Figure(go.Indicator(
45
+ mode="gauge+number",
46
+ value=score * 100,
47
+ domain={"x": [0, 1], "y": [0, 1]},
48
+ title={"text": "Complexity Score"},
49
+ gauge={
50
+ "axis": {"range": [0, 100], "tickwidth": 1},
51
+ "bar": {"color": "#667eea"},
52
+ "steps": [
53
+ {"range": [0, 35], "color": "#27ae60"},
54
+ {"range": [35, 70], "color": "#f39c12"},
55
+ {"range": [70, 100], "color": "#e74c3c"},
56
+ ],
57
+ "threshold": {
58
+ "line": {"color": "black", "width": 4},
59
+ "thickness": 0.75,
60
+ "value": score * 100,
61
+ },
62
+ },
63
+ ))
64
+ fig.update_layout(height=250, margin=dict(l=20, r=20, t=40, b=20))
65
+ return fig
66
+
67
+
68
+ def render_routing_demo():
69
+ """Render the routing demonstration page."""
70
+ st.markdown('<h1 class="main-header">Routing Demo</h1>', unsafe_allow_html=True)
71
+ st.markdown(
72
+ "See how Cascade classifies query complexity and routes to the optimal model."
73
+ )
74
+
75
+ # Example queries
76
+ st.markdown("### Try Example Queries")
77
+ examples = {
78
+ "Simple": "What is the capital of France?",
79
+ "Medium": "Explain the difference between TCP and UDP protocols.",
80
+ "Complex": "Write a Python function that implements a binary search tree with insert, delete, and search operations, including balancing.",
81
+ }
82
+
83
+ col1, col2, col3 = st.columns(3)
84
+ with col1:
85
+ if st.button("🟒 Simple Query"):
86
+ st.session_state["demo_query"] = examples["Simple"]
87
+ with col2:
88
+ if st.button("🟑 Medium Query"):
89
+ st.session_state["demo_query"] = examples["Medium"]
90
+ with col3:
91
+ if st.button("πŸ”΄ Complex Query"):
92
+ st.session_state["demo_query"] = examples["Complex"]
93
+
94
+ st.divider()
95
+
96
+ # Query input
97
+ query = st.text_area(
98
+ "Enter a query to classify",
99
+ value=st.session_state.get("demo_query", ""),
100
+ height=100,
101
+ placeholder="Type or select an example query above...",
102
+ )
103
+
104
+ if st.button("Analyze Query", type="primary") or query:
105
+ if query:
106
+ with st.spinner("Analyzing..."):
107
+ result = classify_query_demo(query)
108
+
109
+ # Display results
110
+ col1, col2 = st.columns([1, 1])
111
+
112
+ with col1:
113
+ st.markdown("### Classification Result")
114
+ st.plotly_chart(
115
+ render_complexity_gauge(result["score"]),
116
+ use_container_width=True,
117
+ )
118
+
119
+ with col2:
120
+ st.markdown("### Routing Decision")
121
+ st.markdown(f"**Complexity Label:** `{result['label'].upper()}`")
122
+ st.markdown(f"**Recommended Model:** `{result['model']}`")
123
+ st.markdown(f"**Reasoning:** {result['reason']}")
124
+
125
+ # Model info
126
+ model_info = {
127
+ "llama3.2": ("🟒", "Free (Local)", "~50ms"),
128
+ "gpt-4o-mini": ("🟑", "$0.15/1M tokens", "~200ms"),
129
+ "gpt-4o": ("πŸ”΄", "$2.50/1M tokens", "~500ms"),
130
+ }
131
+ info = model_info.get(result["model"], ("βšͺ", "Unknown", "Unknown"))
132
+ st.markdown(f"""
133
+ **Model Details:**
134
+ - Status: {info[0]}
135
+ - Cost: {info[1]}
136
+ - Typical Latency: {info[2]}
137
+ """)
138
+
139
+ # Explanation
140
+ st.divider()
141
+ st.markdown("### How It Works")
142
+ st.markdown("""
143
+ 1. **Query Analysis**: The ML classifier (DistilBERT) analyzes the query text
144
+ 2. **Complexity Score**: Outputs a score from 0.0 (simple) to 1.0 (complex)
145
+ 3. **Threshold Routing**:
146
+ - Score < 0.35 β†’ Route to local Llama (free)
147
+ - Score 0.35-0.70 β†’ Route to GPT-4o-mini (cheap)
148
+ - Score > 0.70 β†’ Route to GPT-4o (powerful)
149
+ 4. **Cost Savings**: Simple queries use free/cheap models, saving 60%+ on API costs
150
+ """)