# app.py
"""
ContextPilot Gradio App
=======================

This is the main application file for ContextPilot - an autonomous context
engineering system for LLM conversations.

ARCHITECTURE OVERVIEW
---------------------
ContextPilot uses a two-LLM architecture to optimize token usage:

1. CONTEXT_LLM (smaller, efficient):
   - Configured in context_pilot_workflow.py
   - Used for topic detection and context management decisions
   - Example: openai/gpt-oss-20b
   
2. RESPONSE_LLM (capable, quality):
   - Configured here in app.py
   - Used for generating actual responses to user queries
   - Example: openai/gpt-oss-120b

TWO CONTEXT MODES
-----------------
ContextPilot supports two modes for storing conversation context:

SUMMARY MODE (default):
    - Stores a summary + key facts for each topic
    - When returning to a topic, the summary is injected into system prompt
    - Most token-efficient option
    - Best for: general conversations, FAQ-style interactions
    - Trade-off: some detail may be lost in summarization
    
    Flow:
    1. User message → Topic detection (no change)
    2. Curated context = System prompt + topic summary + session messages + user message
    3. Response LLM generates answer
    4. Session messages buffered for continuity
    5. On topic change → LLM generates summary → stored in contexts

FULL MODE:
    - Stores the complete message history for each topic
    - When returning to a topic, full history is restored as chat messages
    - Maximum context preservation
    - Best for: technical discussions, debugging, detailed Q&A
    - Trade-off: uses more tokens when returning to topics with long history
    
    Flow:
    1. User message → Topic detection (no change)
    2. Curated context = System prompt + stored full history + session messages + user message
    3. Response LLM generates answer
    4. Session messages buffered (full user/assistant messages)
    5. On topic change → full session history → stored in contexts
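
In both modes, the curated request sent to the response LLM is an ordinary
chat-completion message list. An illustrative shape (summary mode; the exact
wording of the injected summary is an assumption, not verbatim output):

    [
        {"role": "system", "content": "<base system prompt + injected topic summary>"},
        {"role": "user", "content": "earlier question in this session"},
        {"role": "assistant", "content": "earlier answer"},
        {"role": "user", "content": "current question"},
    ]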

CONTEXT SUMMARY EXTRACTION
--------------------------
The response LLM is instructed to generate a compact context summary at the
end of each response using special tags:

    <context_summary>Q: user question | A: brief answer</context_summary>

This summary is:
- Extracted and hidden from the user
- Used for efficient context storage
- Helps the topic detection LLM understand conversation flow
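
Example (illustrative): a raw model response of

    Paris is the capital. <context_summary>Q: capital of France? | A: Paris</context_summary>

is shown to the user as "Paris is the capital." while only
"Q: capital of France? | A: Paris" is kept for context tracking.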

TOKEN SAVINGS TRACKING
----------------------
The app tracks and displays:
- Curated Tokens: What was sent to the response LLM
- Full Context Tokens: What would be sent without curation
- Tokens Saved: The difference (your savings!)
- Detection Tokens: Tokens used by topic detection LLM
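
For example (illustrative numbers): a curated request of 1,200 tokens against
a full history of 4,800 tokens is reported as 3,600 tokens saved (75%).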
"""

import gradio as gr
import asyncio
import os
import sys
import json
from pathlib import Path
from dotenv import load_dotenv
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from openai import OpenAI

load_dotenv()

# =============================================================================
# Configuration
# =============================================================================

CONTEXT_STORE_PATH = Path(__file__).parent / ".context_store.json"

# LLM Client for generating responses (more capable model)
# RESPONSE_LLM: Used for generating actual responses (higher quality)
# CONTEXT_LLM: Used for topic detection in workflow (cheaper, configured there)
RESPONSE_LLM_MODEL = os.getenv("RESPONSE_LLM", "openai/gpt-oss-120b")
NEBIUS_BASE_URL = os.getenv("NEBIUS_BASE_URL")
NEBIUS_API_KEY = os.getenv("NEBIUS_API_KEY")
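
# A minimal .env sketch (illustrative; the endpoint and key are placeholders,
# not real values; RESPONSE_LLM and SUMMARIZE_THRESHOLD show the defaults used below):
#
#   NEBIUS_BASE_URL=https://<your-nebius-endpoint>/v1
#   NEBIUS_API_KEY=<your-api-key>
#   RESPONSE_LLM=openai/gpt-oss-120b
#   SUMMARIZE_THRESHOLD=500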

# Validate required environment variables
if not NEBIUS_BASE_URL:
    print("WARNING: NEBIUS_BASE_URL not set. Response generation will fail.")
if not NEBIUS_API_KEY:
    print("WARNING: NEBIUS_API_KEY not set. Response generation will fail.")

print(f"[ContextPilot] Response LLM: {RESPONSE_LLM_MODEL}")
print(f"[ContextPilot] API Base: {NEBIUS_BASE_URL}")
print(f"[ContextPilot] API Key set: {bool(NEBIUS_API_KEY)}")

# Threshold for summarizing responses before storing in context (in characters)
# Responses longer than this will be summarized to reduce context size
SUMMARIZE_THRESHOLD = int(os.getenv("SUMMARIZE_THRESHOLD", "500"))

llm_client = OpenAI(
    api_key=NEBIUS_API_KEY,
    base_url=NEBIUS_BASE_URL,
)
LLM_MODEL = RESPONSE_LLM_MODEL

# Context summary tag that the LLM will use
CONTEXT_SUMMARY_TAG = "<context_summary>"
CONTEXT_SUMMARY_END_TAG = "</context_summary>"


def add_context_summary_instruction(messages: list[dict]) -> list[dict]:
    """
    Add instruction to the system prompt asking the LLM to generate
    a compact context summary at the end of its response.
    """
    if not messages:
        return messages
    
    instruction = (
        "\n\n[IMPORTANT: At the very end of your response, add a brief context line in this exact format:\n"
        f"{CONTEXT_SUMMARY_TAG}Q: <user's question in 5-10 words> | A: <your answer in 10-20 words>{CONTEXT_SUMMARY_END_TAG}\n"
        "This helps track conversation context. The user won't see this tag.]"
    )
    
    # Clone messages and modify system prompt
    modified = []
    for msg in messages:
        if msg.get("role") == "system":
            modified.append({
                "role": "system",
                "content": msg.get("content", "") + instruction
            })
        else:
            modified.append(msg)
    
    return modified


def extract_context_summary(response: str) -> tuple[str, str]:
    """
    Extract the context summary from the response and return
    (clean_response, context_summary).
    """
    if CONTEXT_SUMMARY_TAG not in response:
        # No summary tag; return the response unchanged with an empty summary
        return response, ""
    
    try:
        start = response.index(CONTEXT_SUMMARY_TAG)
        # Search for the end tag from the start tag onward, so a stray end tag
        # earlier in the text can't produce a negative slice
        end = response.index(CONTEXT_SUMMARY_END_TAG, start) + len(CONTEXT_SUMMARY_END_TAG)
        
        context_summary = response[start + len(CONTEXT_SUMMARY_TAG):end - len(CONTEXT_SUMMARY_END_TAG)].strip()
        clean_response = (response[:start] + response[end:]).strip()
        
        return clean_response, context_summary
    except ValueError:
        # Malformed tags (e.g. missing end tag): return the response as-is
        return response, ""
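
# Example behavior (illustrative input, not from a real model):
#   extract_context_summary(
#       "Paris. <context_summary>Q: capital of France? | A: Paris</context_summary>"
#   )
#   -> ("Paris.", "Q: capital of France? | A: Paris")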


def summarize_for_context(text: str, max_chars: int = 300) -> str:
    """
    Summarize long text before storing in context.
    This keeps the context small for the smaller context detection model.
    Fallback when LLM doesn't generate a context summary tag.
    """
    if len(text) <= SUMMARIZE_THRESHOLD:
        return text
    
    try:
        result = llm_client.chat.completions.create(
            model=LLM_MODEL,
            messages=[
                {
                    "role": "system",
                    "content": f"Summarize the following text in under {max_chars} characters. "
                               "Keep the key information and main points. Be concise."
                },
                {"role": "user", "content": text}
            ],
            max_tokens=150,
        )
        summary = result.choices[0].message.content.strip()
        return f"[Summary] {summary}"
    except Exception as e:
        # If summarization fails, log the error and truncate instead
        print(f"Summarization failed ({e}); truncating instead")
        return text[:max_chars] + "..."


# =============================================================================
# Context Store Helpers
# =============================================================================

def clear_context_store():
    """Clear the persistent context store."""
    if CONTEXT_STORE_PATH.exists():
        CONTEXT_STORE_PATH.unlink()
    # Returns: chatbot, current_messages, curated, full, saved_this, pct, detection, topic, switches, stored, contexts, logs
    return ([], [], 0, 0, 0, 0, 0, "", 0, 0, [], "🗑️ All contexts cleared!")


def get_current_mode() -> str:
    """Get the current context mode from store."""
    if CONTEXT_STORE_PATH.exists():
        try:
            store = json.loads(CONTEXT_STORE_PATH.read_text())
            return store.get("mode", "summary")
        except (json.JSONDecodeError, IOError):
            pass
    return "summary"


def set_context_mode(mode: str):
    """Set the context mode and clear all contexts."""
    # Create fresh store with new mode
    store = {
        "contexts": {},
        "current_topic": None,
        "mode": mode,
        "current_session_messages": [],
        "stats": {"total_tokens": 0, "tokens_saved": 0, "context_switches": 0,
                  "cumulative_full_tokens": 0, "cumulative_tokens_saved": 0}
    }
    CONTEXT_STORE_PATH.write_text(json.dumps(store, indent=2))
    return store


def load_current_contexts():
    """Load current contexts from disk for display."""
    if CONTEXT_STORE_PATH.exists():
        try:
            store = json.loads(CONTEXT_STORE_PATH.read_text())
            contexts = store.get("contexts", {})
            stats = store.get("stats", {})
            current_topic = store.get("current_topic", "")
            
            stored_contexts_data = [
                {
                    "topic": topic,
                    "summary": ctx.get("summary", ""),
                    "key_facts": ctx.get("key_facts", []),
                    "tokens": ctx.get("tokens", 0),
                    "is_current": topic == current_topic
                }
                for topic, ctx in contexts.items()
            ]
            
            return (
                current_topic,
                stats.get("total_tokens", 0),
                stats.get("tokens_saved", 0),
                stats.get("context_switches", 0),
                len(contexts),
                stored_contexts_data,
            )
        except (json.JSONDecodeError, IOError):
            pass
    return "", 0, 0, 0, 0, []


# =============================================================================
# MCP Client (for context curation)
# =============================================================================

async def call_context_curator(message: str, chat_history: list) -> dict:
    """Call the MCP server to curate context (detect topic, build messages)."""
    # Pass current environment to subprocess so it gets the API keys
    env = os.environ.copy()
    
    server_params = StdioServerParameters(
        command=sys.executable,
        args=[os.path.join(os.path.dirname(__file__), "mcp_server.py")],
        env=env,  # Pass environment variables to subprocess
    )
    
    async with stdio_client(server_params, errlog=sys.stderr) as (read_stream, write_stream):
        async with ClientSession(read_stream, write_stream) as session:
            await session.initialize()
            
            result = await session.call_tool("context-pilot", {
                "run_args": {
                    "msg": message,
                    "chat_history": chat_history
                }
            })
            
            # Extract text content from MCP response
            if result.content:
                for content in result.content:
                    if hasattr(content, 'text') and content.text:
                        try:
                            return json.loads(content.text)
                        except json.JSONDecodeError as e:
                            print(f"JSON decode error: {e}")
                            print(f"Raw content: {content.text[:500]}")
                            return {"error": f"Invalid JSON response: {e}"}
            
            print(f"No valid content in MCP response: {result}")
            return {"error": "No content in MCP response"}


# =============================================================================
# LLM Response Generation
# =============================================================================

def generate_response_stream(curated_messages: list[dict]):
    """Generate streaming LLM response from curated messages."""
    # Add context summary instruction to get compact context from LLM
    messages_with_instruction = add_context_summary_instruction(curated_messages)
    
    try:
        stream = llm_client.chat.completions.create(
            model=LLM_MODEL,
            messages=messages_with_instruction,
            stream=True,
        )
        for chunk in stream:
            if chunk.choices and chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content
    except Exception as e:
        yield f"Error generating response: {e}"


# =============================================================================
# Chat Handler
# =============================================================================

async def curate_context(message: str, history: list) -> tuple[list[dict], list, dict]:
    """
    Curate context for the message (async).
    Returns curated_messages, decisions, and stats.
    """
    # Keep only history entries already in dict ("messages") format
    chat_history = [h for h in history if isinstance(h, dict)]
    
    # Call context curator
    curation_result = await call_context_curator(message, chat_history)
    
    if not curation_result or "error" in curation_result:
        error_msg = curation_result.get("error", "Unknown error") if curation_result else "No response"
        print(f"Curation error: {error_msg}")
        return [], [f"⚠️ Error: {error_msg}"], {"logs": [f"Error: {error_msg}"]}
    
    return (
        curation_result.get("curated_messages", []),
        curation_result.get("decisions", []),
        curation_result.get("stats", {})
    )


def append_to_session_buffer(message: str, response: str):
    """
    Append the exchange to the session buffer.
    
    This is a temporary buffer for the current topic's conversation.
    When topic changes, the workflow's save_context tool will move this
    buffer to permanent storage.
    
    Both modes buffer messages so the LLM has context for continuation.
    """
    if not CONTEXT_STORE_PATH.exists():
        return
    
    try:
        store = json.loads(CONTEXT_STORE_PATH.read_text())
        
        if "current_session_messages" not in store:
            store["current_session_messages"] = []
        
        # Store user message and response
        store["current_session_messages"].append({"role": "user", "content": message})
        store["current_session_messages"].append({"role": "assistant", "content": response})
        
        CONTEXT_STORE_PATH.write_text(json.dumps(store, indent=2))
    except (json.JSONDecodeError, IOError) as e:
        print(f"Error appending to session buffer: {e}")


def count_tokens(text: str) -> int:
    """Estimate token count."""
    return len(text) // 4 if text else 0
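
# e.g. count_tokens("hello world") -> 2  (11 chars // 4; a rough heuristic,
# not a real tokenizer)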


def respond_stream(message: str, chat_history: list):
    """
    Streaming chat handler:
    1. Curate context (non-streaming)
    2. Stream LLM response
    3. Buffer the exchange in the session store (both modes; used on topic change)
    """
    if not message.strip():
        yield (chat_history, [], 0, 0, 0, 0, 0, "", 0, 0, [], "")
        return
    
    # Step 1: Curate context (blocking)
    curated_messages, decisions, stats = asyncio.run(
        curate_context(message, chat_history)
    )
    
    if not curated_messages:
        chat_history.append({"role": "user", "content": message})
        chat_history.append({"role": "assistant", "content": "Error: Could not curate context"})
        yield (chat_history, [], 0, 0, 0, 0, 0, "", 0, 0, [], "")
        return
    
    # Add user message to UI chat
    chat_history.append({"role": "user", "content": message})
    
    # Add decisions as assistant message if any
    if decisions:
        decisions_text = "\n".join(decisions)
        chat_history.append({"role": "assistant", "content": decisions_text})
    
    # Step 2: Stream LLM response
    chat_history.append({"role": "assistant", "content": ""})
    
    logs = "\n".join(stats.get("logs", []))
    stored_contexts_data = stats.get("stored_contexts_data", [])
    
    # Collect full response for session messages
    full_response = ""
    
    # Stream the response
    for chunk in generate_response_stream(curated_messages):
        chat_history[-1]["content"] += chunk
        full_response += chunk
        yield (
            chat_history,
            curated_messages,  # Current messages sent to LLM
            stats.get("curated_tokens", 0),
            stats.get("full_context_tokens", 0),
            stats.get("tokens_saved_this_request", 0),
            stats.get("savings_percent", 0),
            stats.get("detection_tokens", 0),
            stats.get("current_topic", ""),
            stats.get("context_switches", 0),
            stats.get("stored_contexts", 0),
            stored_contexts_data,
            logs,
        )
    
    # After streaming completes:
    # 1. Extract the context summary tag from the response (the LLM generates it)
    # 2. Re-render the chat with the cleaned response (tag removed)
    # 3. Buffer the cleaned response for context storage on topic change
    
    clean_response, context_summary = extract_context_summary(full_response)
    
    # Update the chat history with clean response (remove context summary tag)
    if context_summary:
        chat_history[-1]["content"] = clean_response
        # Final yield with cleaned response
        yield (
            chat_history,
            curated_messages,
            stats.get("curated_tokens", 0),
            stats.get("full_context_tokens", 0),
            stats.get("tokens_saved_this_request", 0),
            stats.get("savings_percent", 0),
            stats.get("detection_tokens", 0),
            stats.get("current_topic", ""),
            stats.get("context_switches", 0),
            stats.get("stored_contexts", 0),
            stored_contexts_data,
            logs,
        )
    
    # Buffer the exchange for the current topic (both modes); on topic change,
    # the workflow moves this buffer into permanent context storage.
    append_to_session_buffer(message, clean_response)


# =============================================================================
# Gradio Interface
# =============================================================================

with gr.Blocks() as demo:
    gr.Markdown(
        """
        # 🧭 ContextPilot
        ### Autonomous Context Engineering for LLM Conversations
        
        ContextPilot automatically detects topic shifts and manages conversation context 
        to provide more relevant responses.
        """
    )
    
    # Mode toggle section
    with gr.Row():
        with gr.Column(scale=3):
            mode_radio = gr.Radio(
                choices=["summary", "full"],
                value=get_current_mode(),
                label="Context Mode",
                info="Summary: saves key facts only | Full: saves complete message history",
                interactive=True,
            )
        with gr.Column(scale=1):
            mode_status = gr.Textbox(
                value=f"Current mode: {get_current_mode()}",
                label="Status",
                interactive=False,
                show_label=False,
            )
    
    with gr.Row():
        with gr.Column(scale=3):
            # type="messages" so the dict-format history built by the handlers renders correctly
            chatbot = gr.Chatbot(label="Chat", height=450, type="messages")
            
            with gr.Row():
                msg = gr.Textbox(
                    label="Message",
                    placeholder="Type your message here...",
                    scale=4,
                    show_label=False,
                )
                submit = gr.Button("Send", variant="primary", scale=1)
            
            with gr.Row():
                clear = gr.Button("Clear Chat")
                clear_contexts = gr.Button("🗑️ Clear All Contexts", variant="stop")
            
            with gr.Accordion("📨 Current Request (Messages Sent to LLM)", open=False):
                current_messages_display = gr.JSON(label="Curated Messages", show_label=False)
        
        with gr.Column(scale=1):
            with gr.Tabs():
                with gr.Tab("πŸ’° Token Savings"):
                    gr.Markdown("### Response LLM (This Request)")
                    with gr.Row():
                        curated_tokens = gr.Number(label="Curated (sent)", interactive=False)
                        full_tokens = gr.Number(label="Full (would be)", interactive=False)
                    with gr.Row():
                        saved_this = gr.Number(label="Tokens Saved", interactive=False)
                        savings_pct = gr.Number(label="Savings %", interactive=False)
                    
                    gr.Markdown("### Topic Detection")
                    detection_tokens = gr.Number(label="Detection Tokens", interactive=False)
                
                with gr.Tab("πŸ“Š Stats"):
                    current_topic = gr.Textbox(label="Current Topic", interactive=False)
                    context_switches = gr.Number(label="Context Switches", interactive=False)
                    stored_contexts = gr.Number(label="Stored Contexts", interactive=False)
                
                with gr.Tab("πŸ“š Stored Contexts"):
                    contexts_display = gr.JSON(label="Context Store", show_label=False)
                    refresh_contexts = gr.Button("🔄 Refresh")
                
                with gr.Tab("πŸ“‹ Logs"):
                    logs_display = gr.Textbox(
                        label="Workflow Logs",
                        interactive=False,
                        lines=12,
                        show_label=False,
                    )
    
    # Event handlers
    def respond(message: str, chat_history: list):
        """Wrapper that clears input and starts streaming."""
        # Return empty message immediately to clear input
        # The actual response is handled by respond_stream
        for result in respond_stream(message, chat_history):
            yield ("",) + result
    
    def clear_chat():
        return ([], [], 0, 0, 0, 0, 0, "", 0, 0, [], "")
    
    def refresh_contexts_display():
        """Refresh the contexts display from disk."""
        current, total, saved, switches, count, contexts = load_current_contexts()
        return current, switches, count, contexts
    
    def switch_mode(new_mode: str):
        """Switch context mode and clear all data."""
        set_context_mode(new_mode)
        status = f"✅ Switched to {new_mode} mode. All contexts cleared."
        # Return cleared state for all outputs
        return (
            status,                    # mode_status
            [],                        # chatbot
            [],                        # current_messages_display
            0, 0, 0, 0,               # token stats this request (curated, full, saved, pct)
            0,                         # detection_tokens
            "",                        # current_topic
            0, 0,                     # switches, stored
            [],                        # contexts_display
            f"πŸ”„ Mode switched to: {new_mode}",  # logs_display
        )
    
    # Wire up events
    outputs = [
        msg, chatbot, current_messages_display,
        curated_tokens, full_tokens, saved_this, savings_pct,
        detection_tokens,
        current_topic, context_switches, stored_contexts,
        contexts_display, logs_display
    ]
    
    clear_outputs = [
        chatbot, current_messages_display,
        curated_tokens, full_tokens, saved_this, savings_pct,
        detection_tokens,
        current_topic, context_switches, stored_contexts,
        contexts_display, logs_display
    ]
    
    msg.submit(respond, [msg, chatbot], outputs)
    submit.click(respond, [msg, chatbot], outputs)
    clear.click(clear_chat, None, clear_outputs)
    clear_contexts.click(clear_context_store, None, clear_outputs)
    refresh_contexts.click(refresh_contexts_display, None, [current_topic, context_switches, stored_contexts, contexts_display])
    
    # Mode switch clears everything
    mode_outputs = [
        mode_status, chatbot, current_messages_display,
        curated_tokens, full_tokens, saved_this, savings_pct,
        detection_tokens,
        current_topic, context_switches, stored_contexts,
        contexts_display, logs_display
    ]
    mode_radio.change(switch_mode, [mode_radio], mode_outputs)


if __name__ == "__main__":
    demo.launch()