# context_pilot_workflow.py
"""
ContextPilot: Context Curation Workflow
=======================================

This module handles topic detection and context curation for the ContextPilot
system. It does NOT generate responses - it only curates the context that will
be sent to the response LLM.

KEY RESPONSIBILITIES
--------------------
1. Topic Detection: Detect when the conversation topic changes
2. Context Storage: Save/load conversation context when switching topics
3. Context Curation: Build the optimal message list for the response LLM

TOPIC DETECTION
---------------
Uses a cheap LLM (CONTEXT_LLM) with function calling to decide:
- Is this the same topic? → No action needed
- Is this a new topic? → Save old context, set new topic
- Is this a previously discussed topic? → Save old context, load old topic

The topic detection LLM has access to these tools:
- save_context(topic, summary, key_facts): Save current topic before switching
- set_current_topic(topic): Set a new topic that hasn't been discussed
- load_context(topic): Load and switch to a previously saved topic
- list_saved_contexts(): See what topics have been discussed

CONTEXT CURATION (build_curated_messages)
-----------------------------------------
Builds the message list that will be sent to the response LLM:

Summary Mode:
- System prompt includes topic summary + key facts
- Includes session messages (for within-session continuity)
- Includes current user message

Full Mode:
- Plain system prompt
- Full message history from stored context
- Current session messages
- Current user message

DATA FLOW
---------
1. User message arrives via MCP server
2. load_context_store() loads persisted state
3. detect_and_handle_topic_change() decides if topic changed
4. If changed: save_context() + set_current_topic()/load_context()
5. build_curated_messages() creates optimized message list
6. Return curated messages + stats to app.py for response generation
"""

import json
import os
import time
from dataclasses import dataclass, field
from pathlib import Path

from dotenv import load_dotenv
from llama_index.core.llms import ChatMessage
from llama_index.core.tools import FunctionTool
from llama_index.core.workflow import Context, StartEvent, StopEvent, Workflow, step
from llama_index.llms.openai_like import OpenAILike

load_dotenv()

# =============================================================================
# Configuration
# =============================================================================

# Persisted state lives next to this module; the system prompt template lives
# in a "prompts" directory one level above it.
CONTEXT_STORE_PATH = Path(__file__).parent / ".context_store.json"
SYSTEM_PROMPT_PATH = Path(__file__).parent.parent / "prompts" / "system_prompt.txt"

# Context management tool names (excluded from full history)
CONTEXT_TOOL_NAMES = {'save_context', 'load_context', 'set_current_topic', 'list_saved_contexts'}

# Tools whose invocation means the topic (or the saved contexts) changed.
# list_saved_contexts is read-only, so calling it alone is NOT a topic change.
_TOPIC_CHANGING_TOOL_NAMES = frozenset(CONTEXT_TOOL_NAMES - {'list_saved_contexts'})

# Substrings of the UI decision messages this module emits (see
# process_tool_sources, detect_and_handle_topic_change and
# ContextPilotWorkflow). A chat message containing one of them is bookkeeping
# output, not conversation, and is filtered out of stored history.
_CONTEXT_TOOL_MARKERS = (
    "💾 **Context saved",
    "📂 **Context loaded",
    "📍 **Topic set",
    "📍 **Topic inferred",
    "🆕 **Topic changed",
    "🧭 Current topic:",
    "📚 **Contexts listed",
)


# =============================================================================
# Events
# =============================================================================

class MessageEvent(StartEvent):
    """Input event for the workflow (MCP compatible).

    Attributes:
        msg: The new user message to curate context for.
        chat_history: Prior UI chat messages; downstream code only uses dict
            entries and reads their "role"/"content" keys.
    """

    msg: str
    # llama_index events are pydantic models, which copy mutable field
    # defaults per instance, so this [] is not shared between events.
    chat_history: list = []


# =============================================================================
# Data Classes
# =============================================================================

@dataclass
class CurationResult:
    """Result of context curation."""

    curated_messages: list[dict]  # Messages ready for LLM
    current_topic: str
    topic_changed: bool
    decisions: list[str]  # UI-friendly decision messages
    logs: list[str]
    stats: dict


@dataclass
class ProcessingState:
    """Mutable state carried through the processing of one user message."""

    msg: str  # The current user message
    chat_history: list  # Raw UI chat history supplied with the event
    store: dict  # Snapshot of the context store (reloaded after topic changes)
    logs: list = field(default_factory=list)  # Developer-facing log lines
    decisions: list = field(default_factory=list)  # UI-facing decision messages
    topic_changed: bool = False
    detection_tokens: int = 0  # Tokens used by topic detection LLM

    @property
    def current_topic(self) -> str:
        """Current topic name, or the string "None" when no topic is set."""
        return self.store.get("current_topic") or "None"

    @property
    def mode(self) -> str:
        """Context mode: "summary" (default) or "full"."""
        return self.store.get("mode", "summary")

    @property
    def contexts(self) -> dict:
        """Saved contexts keyed by topic name."""
        return self.store.get("contexts", {})

    @property
    def saved_topics_str(self) -> str:
        """Comma-separated saved topic names, or "None" if there are none."""
        topics = list(self.contexts.keys())
        return ", ".join(topics) if topics else "None"

    @property
    def context_summaries_str(self) -> str:
        """Build context summaries string including saved topics AND current session."""
        parts = []

        # Include saved topic summaries
        if self.contexts:
            parts.append("### Saved Topics:")
            for topic, ctx in self.contexts.items():
                parts.append(f"- {topic}: {ctx.get('summary', 'No summary')}")

        # Include current session messages (compact summaries) as plain text
        session_messages = self.store.get("current_session_messages", [])
        if session_messages:
            # Read the raw store value: self.current_topic never returns a
            # falsy value, so "unknown topic" would otherwise be unreachable.
            topic_label = self.store.get("current_topic") or 'unknown topic'
            parts.append(f"\n### Current Session ({topic_label}):")
            for msg in session_messages[-5:]:  # Last 5 compact exchanges
                content = msg.get("content", "")
                # Just show the content - these are compact summaries like "Q: ... | A: ..."
                parts.append(f"- {content[:200]}")  # Truncate if somehow long

        if not parts:
            return "No context yet."
        return "\n".join(parts)


# =============================================================================
# Utility Functions
# =============================================================================

def count_tokens(text: str) -> int:
    """Estimate token count (roughly 4 chars per token); 0 for empty/None."""
    return len(text) // 4 if text else 0


def extract_text_content(content) -> str:
    """Safely extract text from various content formats.

    Accepts None, a plain string, or a list of content blocks (dicts with a
    "text"/"content" key, objects with a ``.text`` attribute, or strings).
    Anything else is converted with ``str``.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        texts = []
        for block in content:
            if isinstance(block, dict):
                # `or ""` so a block whose content is None doesn't become "None".
                texts.append(block.get('text') or block.get('content') or "")
            elif hasattr(block, 'text'):
                texts.append(block.text)
            elif isinstance(block, str):
                texts.append(block)
        return "".join(str(t) for t in texts)
    return str(content)


# =============================================================================
# Context Store
# =============================================================================

def _new_store(mode: str = "summary") -> dict:
    """Return a fresh, empty context store in the given mode."""
    return {
        "contexts": {},
        "current_topic": None,
        "mode": mode,  # "summary" or "full"
        "current_session_messages": [],  # Messages for current topic in this session
        "stats": {
            "total_tokens": 0,
            "tokens_saved": 0,
            "context_switches": 0,
            "cumulative_full_tokens": 0,
            "cumulative_tokens_saved": 0,
        },
    }


def load_context_store() -> dict:
    """Load the context store from disk, filling in any missing keys.

    Returns a fresh empty store when the file is missing, unreadable, not
    valid JSON, or not a JSON object.
    """
    try:
        store = json.loads(CONTEXT_STORE_PATH.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # OSError: missing/unreadable file. ValueError: invalid UTF-8 or JSON
        # (json.JSONDecodeError is a ValueError subclass).
        return _new_store()

    if not isinstance(store, dict):
        return _new_store()

    # Backwards compatibility: older store files may lack keys
    # (mode defaults to "summary").
    defaults = _new_store()
    for key, value in defaults.items():
        store.setdefault(key, value)
    if not isinstance(store["stats"], dict):
        store["stats"] = defaults["stats"]
    return store


def get_current_mode() -> str:
    """Get the current context mode."""
    store = load_context_store()
    return store.get("mode", "summary")


def set_mode(mode: str) -> dict:
    """Set the context mode and clear all contexts.

    Returns the new store.

    Raises:
        ValueError: If ``mode`` is not "summary" or "full".
    """
    if mode not in ("summary", "full"):
        raise ValueError(f"Invalid mode: {mode}. Must be 'summary' or 'full'")

    # Switching modes starts over with a fresh store.
    store = _new_store(mode)
    save_context_store(store)
    return store


def append_session_message(role: str, content: str):
    """Append a message to the current session messages.

    Session messages are used in both modes: build_curated_messages sends them
    to the response LLM, and save_context persists them as full history in
    full mode.
    """
    store = load_context_store()
    store["current_session_messages"].append({"role": role, "content": content})
    save_context_store(store)


def clear_session_messages():
    """Clear session messages (called when topic changes)."""
    store = load_context_store()
    store["current_session_messages"] = []
    save_context_store(store)


def save_context_store(store: dict) -> None:
    """Save context store to disk atomically.

    The JSON is written to a sibling temp file and then renamed over the real
    file, so a crash mid-write cannot leave a truncated store behind (which
    load_context_store would otherwise silently replace with an empty one).
    """
    tmp_path = CONTEXT_STORE_PATH.with_name(CONTEXT_STORE_PATH.name + ".tmp")
    tmp_path.write_text(json.dumps(store, indent=2), encoding="utf-8")
    tmp_path.replace(CONTEXT_STORE_PATH)


# =============================================================================
# System Prompt
# =============================================================================

def load_system_prompt(**variables) -> str:
    """Load the system prompt template and ``str.format`` it with ``variables``.

    Falls back to a minimal built-in prompt if the template file is missing.
    """
    try:
        template = SYSTEM_PROMPT_PATH.read_text(encoding="utf-8")
        return template.format(**variables)
    except FileNotFoundError:
        return f"You are a helpful AI assistant. Current topic: {variables.get('current_topic', 'unknown')}"


def build_system_prompt(state: ProcessingState) -> str:
    """Build the topic-detection system prompt from the current state."""
    return load_system_prompt(
        current_topic=state.current_topic,
        saved_topics=state.saved_topics_str,
        context_summaries=state.context_summaries_str,
    )


# =============================================================================
# Context Tools (LLM-callable for topic detection)
# =============================================================================

# Global variable to hold pending full history for save.
# This is set by the workflow before tool calling.
# NOTE(review): nothing in this module reads it (save_context persists
# store["current_session_messages"] instead); presumably kept for external
# callers - verify before removing.
_pending_full_history: list = []


def set_pending_full_history(history: list):
    """Set the full history to be saved when save_context is called in full mode.

    Context-tool bookkeeping messages are filtered out first.
    """
    global _pending_full_history
    # In full mode, we use session messages from the store instead,
    # but we also keep the filtered UI history as a backup.
    _pending_full_history = [
        msg for msg in history
        if not _is_context_tool_message(msg)
    ]


def _is_context_tool_message(msg: dict) -> bool:
    """Return True if ``msg`` is one of this module's context-tool decision outputs."""
    if not isinstance(msg, dict):
        return False
    content = msg.get("content", "")
    if not isinstance(content, str):
        return False
    return any(marker in content for marker in _CONTEXT_TOOL_MARKERS)


def save_context(topic: str, summary: str, key_facts: list[str]) -> str:
    """Save conversation context before switching topics."""
    # NOTE: FunctionTool.from_defaults turns this signature and docstring into
    # the tool schema/description shown to the topic-detection LLM, so they are
    # intentionally terse and must stay stable.
    store = load_context_store()
    mode = store.get("mode", "summary")
    session_messages = store.get("current_session_messages", [])
    # The LLM may send null for the list; treat that as "no key facts".
    key_facts = list(key_facts or [])

    # Rough size of what summary mode injects for this topic.
    tokens = count_tokens(" ".join([summary or "", *map(str, key_facts)]))

    context_data = {
        "topic": topic,
        "summary": summary,
        "key_facts": key_facts,
        "tokens": tokens,
        "saved_at": time.time(),
        "mode": mode,
    }

    # In full mode, save the session messages as full_history, appended to any
    # history already stored for this topic.
    if mode == "full":
        existing_history = store["contexts"].get(topic, {}).get("full_history") or []
        full_history = existing_history + session_messages
        context_data["full_history"] = full_history
        tokens = sum(count_tokens(m.get("content", "")) for m in full_history)
        context_data["tokens"] = tokens

    store["contexts"][topic] = context_data
    stats = store["stats"]
    stats["tokens_saved"] = sum(
        c.get("tokens", 0) for c in store["contexts"].values()
    )
    stats["context_switches"] = stats.get("context_switches", 0) + 1

    # The session's messages now live in the saved context.
    store["current_session_messages"] = []
    save_context_store(store)

    if mode == "full":
        msg_count = len(context_data.get("full_history", []))
        return f"✅ Saved context '{topic}' with {msg_count} messages ({tokens} tokens)"
    return f"✅ Saved context '{topic}' with {len(key_facts)} key facts ({tokens} tokens)"


def load_context(topic: str) -> str:
    """Load a previously saved conversation context."""
    # NOTE: docstring/signature are the LLM tool description (FunctionTool).
    # Returns a JSON summary of the context on success, or a plain
    # "not found" message (without changing the current topic) otherwise.
    store = load_context_store()
    if topic not in store["contexts"]:
        return f"No saved context found for topic '{topic}'"

    ctx = store["contexts"][topic]
    store["current_topic"] = topic
    save_context_store(store)

    result = {
        "topic": ctx["topic"],
        "summary": ctx["summary"],
        "key_facts": ctx["key_facts"],
    }
    # In full mode, include the full history indicator
    if ctx.get("mode") == "full" and ctx.get("full_history"):
        result["has_full_history"] = True
        result["message_count"] = len(ctx["full_history"])
    return json.dumps(result, indent=2)


def list_saved_contexts() -> str:
    """List all saved conversation contexts."""
    # NOTE: docstring/signature are the LLM tool description (FunctionTool).
    store = load_context_store()
    contexts = store.get("contexts", {})
    if not contexts:
        return "No saved contexts yet."
    return "\n".join(
        f"• {topic}: {ctx.get('summary', 'No summary')}..."
        for topic, ctx in contexts.items()
    )


def set_current_topic(topic: str) -> str:
    """Set the current conversation topic."""
    # NOTE: docstring/signature are the LLM tool description (FunctionTool).
    store = load_context_store()
    store["current_topic"] = topic
    save_context_store(store)
    return f"📍 Current topic set to: {topic}"


# Tool objects for topic detection
CONTEXT_TOOLS = [
    FunctionTool.from_defaults(fn=save_context),
    FunctionTool.from_defaults(fn=load_context),
    FunctionTool.from_defaults(fn=list_saved_contexts),
    FunctionTool.from_defaults(fn=set_current_topic),
]


# =============================================================================
# LLM for Topic Detection Only (cheaper model)
# =============================================================================

# Use CONTEXT_LLM for topic detection (cheaper)
# Use RESPONSE_LLM for generation (more capable) - configured in app.py
CONTEXT_LLM_MODEL = os.getenv("CONTEXT_LLM", os.getenv("NEBIUS_MODEL", "openai/gpt-4o-mini"))
NEBIUS_BASE_URL = os.getenv("NEBIUS_BASE_URL")
NEBIUS_API_KEY = os.getenv("NEBIUS_API_KEY")

# Validate required environment variables (warn only; import still succeeds)
if not NEBIUS_BASE_URL:
    print("WARNING: NEBIUS_BASE_URL not set. Topic detection will fail.")
if not NEBIUS_API_KEY:
    print("WARNING: NEBIUS_API_KEY not set. Topic detection will fail.")

print(f"[ContextPilot] Context LLM: {CONTEXT_LLM_MODEL}")
print(f"[ContextPilot] API Base: {NEBIUS_BASE_URL}")
print(f"[ContextPilot] API Key set: {bool(NEBIUS_API_KEY)}")

topic_llm = OpenAILike(
    model=CONTEXT_LLM_MODEL,
    api_base=NEBIUS_BASE_URL,
    api_key=NEBIUS_API_KEY,
    is_chat_model=True,
    is_function_calling_model=True,
    context_window=128000,
)


# =============================================================================
# Message Building
# =============================================================================

def build_detection_messages(state: ProcessingState, system_prompt: str) -> list[ChatMessage]:
    """Build messages for topic detection (minimal context).

    The system prompt already contains compact context summaries from the
    store. We only need to add the current user message - no need to include
    chat history, since that would send full messages and defeat the purpose
    of compact summaries.
    """
    return [
        ChatMessage(role="system", content=system_prompt),
        # Only the current user message - context summaries are in the system prompt
        ChatMessage(role="user", content=state.msg),
    ]


@dataclass
class CurationMetrics:
    """Metrics comparing full context vs curated context."""

    curated_messages: list[dict]
    curated_tokens: int
    full_context_tokens: int  # What it would be without curation
    tokens_saved_this_request: int  # Never negative
    savings_percent: float  # Never negative; rounded to one decimal


def _count_message_tokens(messages: list[dict]) -> int:
    """Estimated token total of the "content" fields of ``messages``."""
    return sum(count_tokens(m.get("content", "")) for m in messages)


def build_full_context_messages(state: ProcessingState) -> list[dict]:
    """
    Build what the FULL context would look like without curation.

    This is for comparison only - to show how many tokens we saved.
    Includes ALL chat history without any summarization.
    """
    store = load_context_store()
    base_prompt = "You are a helpful AI assistant."

    # In a non-curated approach, we'd include ALL stored contexts expanded
    all_contexts = store.get("contexts", {})
    if all_contexts:
        base_prompt += "\n\nFull conversation history from all topics:\n"
        for topic, ctx in all_contexts.items():
            base_prompt += f"\n[Topic: {topic}]\n"
            base_prompt += f"Summary: {ctx.get('summary', '')}\n"
            if ctx.get('key_facts'):
                base_prompt += "Key facts:\n" + "\n".join(f"- {fact}" for fact in ctx['key_facts'])
                base_prompt += "\n"

    messages = [{"role": "system", "content": base_prompt}]

    # Include FULL chat history (no truncation)
    for h in state.chat_history:
        if isinstance(h, dict):
            content = extract_text_content(h.get("content", ""))
            if content:
                messages.append({"role": h.get("role", "user"), "content": content})

    messages.append({"role": "user", "content": state.msg})
    return messages


def build_curated_messages(state: ProcessingState) -> CurationMetrics:
    """
    Build the curated message list for LLM consumption.

    Summary Mode:
    - System prompt with summary + key facts of the CURRENT topic only
    - Current session messages (within-session continuity)
    - The current user message

    Full Mode:
    - Plain system prompt
    - Full message history for CURRENT topic only
    - Current session messages
    - The current user message

    Also builds the uncurated equivalent (build_full_context_messages) and
    returns both token estimates. Appends debug lines to ``state.logs``.
    """
    store = load_context_store()
    current_topic = store.get("current_topic", "None")
    mode = store.get("mode", "summary")
    session_messages = store.get("current_session_messages", [])
    all_contexts = store.get("contexts", {})

    # Debug logging
    state.logs.append(f"🔧 Mode from store: {mode}")
    state.logs.append(f"🔧 Session messages count: {len(session_messages)}")
    state.logs.append(f"🔧 Current topic: {current_topic}")

    base_prompt = "You are a helpful AI assistant."

    # Check if we have stored context for current topic
    has_stored_context = (
        current_topic
        and current_topic != "None"
        and current_topic in all_contexts
    )
    state.logs.append(f"🔧 Has stored context for '{current_topic}': {has_stored_context}")

    if mode == "full":
        # FULL MODE: Plain system prompt + full history for CURRENT topic only
        state.logs.append("🔧 Using FULL mode")
        curated_messages = [{"role": "system", "content": base_prompt}]

        # Restore saved FULL history for current topic only
        if has_stored_context:
            ctx = all_contexts[current_topic]
            if ctx.get("full_history"):
                state.logs.append(f"🔧 Restoring {len(ctx['full_history'])} messages from full_history")
                curated_messages.extend(ctx["full_history"])
            else:
                state.logs.append("🔧 No full_history in stored context")

        # Add session messages (messages from this session for current topic)
        if session_messages:
            state.logs.append(f"🔧 Adding {len(session_messages)} session messages")
            curated_messages.extend(session_messages)
    else:
        # SUMMARY MODE: System prompt with CURRENT topic summary + session messages
        state.logs.append("🔧 Using SUMMARY mode")

        if has_stored_context:
            ctx = all_contexts[current_topic]
            base_prompt += f"\n\n## Context for '{current_topic}':\n"
            base_prompt += f"Summary: {ctx.get('summary', 'No summary')}\n"
            if ctx.get('key_facts'):
                base_prompt += "Key facts:\n" + "\n".join(f"- {fact}" for fact in ctx['key_facts'])
            state.logs.append(f"🔧 Added summary for current topic '{current_topic}'")

        curated_messages = [{"role": "system", "content": base_prompt}]

        # Include session messages for continuation (current session's exchanges)
        if session_messages:
            state.logs.append(f"🔧 Adding {len(session_messages)} session messages for continuation")
            curated_messages.extend(session_messages)

    # Both modes end with the current user message.
    curated_messages.append({"role": "user", "content": state.msg})

    # Calculate full context for comparison (what it would be without ANY curation)
    full_messages = build_full_context_messages(state)

    curated_tokens = _count_message_tokens(curated_messages)
    full_context_tokens = _count_message_tokens(full_messages)
    # Curation can cost more than the baseline (e.g. full mode with long
    # history); report 0 saved in that case, for both the count and the
    # percentage, so the two figures stay consistent.
    tokens_saved = max(0, full_context_tokens - curated_tokens)
    savings_percent = (tokens_saved / full_context_tokens * 100) if full_context_tokens > 0 else 0

    return CurationMetrics(
        curated_messages=curated_messages,
        curated_tokens=curated_tokens,
        full_context_tokens=full_context_tokens,
        tokens_saved_this_request=tokens_saved,
        savings_percent=round(savings_percent, 1),
    )


# =============================================================================
# Tool Processing
# =============================================================================

def process_tool_sources(sources: list, state: ProcessingState) -> set:
    """Log each tool call in ``sources`` and record UI decisions on ``state``.

    Returns the set of called tool names.
    """
    called_tools = set()

    for source in sources:
        tool_name = getattr(source, 'tool_name', 'unknown')
        raw_output = getattr(source, 'raw_output', '')
        called_tools.add(tool_name)

        state.logs.append(f" → {tool_name}: {str(raw_output)[:100]}...")

        # User-friendly decisions (these prefixes are in _CONTEXT_TOOL_MARKERS)
        decisions_map = {
            'save_context': f"💾 **Context saved**: {raw_output}",
            'load_context': f"📂 **Context loaded**: {raw_output}",
            'set_current_topic': f"📍 **Topic set**: {raw_output}",
            'list_saved_contexts': "📚 **Contexts listed**",
        }
        if tool_name in decisions_map:
            state.decisions.append(decisions_map[tool_name])

    return called_tools


async def infer_topic_from_message(msg: str) -> str:
    """Use the topic LLM to infer a one-word, lowercase topic from ``msg``."""
    messages = [
        ChatMessage(
            role="system",
            content="Extract the MAIN SUBJECT (noun) from the user's message. "
                    "Focus on WHAT the question is about, not the action/verb. "
                    "For 'how do cats hunt?' → 'cats' (not 'hunting'). "
                    "For 'what is Python used for?' → 'python' (not 'programming'). "
                    "For 'how to cook pasta?' → 'pasta' (not 'cooking'). "
                    "Respond with ONLY ONE WORD (lowercase, no punctuation)."
        ),
        ChatMessage(role="user", content=msg),
    ]
    response = await topic_llm.achat(messages)
    topic = extract_text_content(response.message.content)
    # Normalise: lowercase, no spaces/underscores, and no stray surrounding
    # punctuation (e.g. "Cats." → "cats") so the same topic matches later.
    normalised = topic.strip().lower().replace(" ", "").replace("_", "")
    return normalised.strip(".,;:!?'\"`")


async def detect_and_handle_topic_change(state: ProcessingState) -> bool:
    """
    Use LLM with tools to detect if topic changed.

    The tools themselves persist any change to the context store. When the
    topic changes, ``state.store`` is reloaded and a decision is recorded.

    Returns True if topic changed, False otherwise.
    """
    system_prompt = build_system_prompt(state)
    messages = build_detection_messages(state, system_prompt)

    # Calculate tokens used for topic detection
    detection_tokens = sum(count_tokens(m.content or "") for m in messages)
    state.detection_tokens = detection_tokens

    state.logs.append(f"🔍 Detecting topic change... (mode: {state.mode})")
    state.logs.append(f"🎯 Detection tokens: {detection_tokens}")

    # Set pending history for full mode (in case save_context is called)
    if state.mode == "full":
        # Build filtered history from chat_history
        filtered_history = []
        for h in state.chat_history:
            if isinstance(h, dict) and not _is_context_tool_message(h):
                content = extract_text_content(h.get("content", ""))
                if content:
                    filtered_history.append({"role": h.get("role", "user"), "content": content})
        set_pending_full_history(filtered_history)

    # Call LLM with context tools
    response = await topic_llm.apredict_and_call(
        tools=CONTEXT_TOOLS,
        chat_history=messages,
        error_on_no_tool_call=False,
    )

    sources = getattr(response, 'sources', []) or []
    if not sources:
        state.logs.append("✅ No topic change detected")
        return False

    state.logs.append(f"🔧 LLM called {len(sources)} tools")
    called_tools = process_tool_sources(sources, state)

    # Only tools that mutate the topic/contexts count as a change; calling
    # list_saved_contexts alone (or an unknown tool) is not a topic change.
    if not called_tools & _TOPIC_CHANGING_TOOL_NAMES:
        return False

    # save_context without a follow-up set/load leaves the old topic current,
    # so infer the new topic from the message itself.
    if 'save_context' in called_tools and not called_tools & {'set_current_topic', 'load_context'}:
        state.logs.append("⚠️ Inferring new topic...")
        new_topic = await infer_topic_from_message(state.msg)
        set_current_topic(new_topic)
        # Must start with a _CONTEXT_TOOL_MARKERS entry so it is filtered
        # out of stored history.
        state.decisions.append(f"📍 **Topic inferred**: {new_topic}")
        state.logs.append(f"📍 New topic: {new_topic}")

    # Tools wrote to disk; refresh our snapshot.
    state.store = load_context_store()
    state.decisions.append(f"🆕 **Topic changed to**: **{state.current_topic}**")
    return True


# =============================================================================
# Result Building
# =============================================================================

def build_stats(state: ProcessingState, metrics: "CurationMetrics | None" = None) -> dict:
    """Update the persisted cumulative counters and return a stats dict.

    Adds this request's curated/full/saved token counts (when ``metrics`` is
    given) and detection tokens to the store's running totals, writes the
    store once, and returns per-request plus cumulative figures, the run
    logs, and a summary of every stored context.
    """
    store = load_context_store()
    # load_context_store guarantees this is a dict inside the store, so
    # mutating it updates the store directly.
    stats = store["stats"]

    stored_contexts_data = [
        {
            "topic": topic,
            "summary": ctx.get("summary", ""),
            "key_facts": ctx.get("key_facts", []),
            "tokens": ctx.get("tokens", 0),
            "is_current": topic == store.get("current_topic"),
        }
        for topic, ctx in store.get("contexts", {}).items()
    ]

    # Update cumulative stats
    if metrics is not None:
        stats["total_tokens"] = stats.get("total_tokens", 0) + metrics.curated_tokens
        stats["cumulative_full_tokens"] = stats.get("cumulative_full_tokens", 0) + metrics.full_context_tokens
        stats["cumulative_tokens_saved"] = stats.get("cumulative_tokens_saved", 0) + metrics.tokens_saved_this_request

    # Update detection token stats
    if state.detection_tokens:
        stats["cumulative_detection_tokens"] = stats.get("cumulative_detection_tokens", 0) + state.detection_tokens

    # One write for all counter updates.
    save_context_store(store)

    return {
        # Per-request metrics
        "curated_tokens": metrics.curated_tokens if metrics is not None else 0,
        "full_context_tokens": metrics.full_context_tokens if metrics is not None else 0,
        "tokens_saved_this_request": metrics.tokens_saved_this_request if metrics is not None else 0,
        "savings_percent": metrics.savings_percent if metrics is not None else 0,
        "detection_tokens": state.detection_tokens,  # Tokens for topic detection
        # Cumulative metrics
        "cumulative_curated_tokens": stats.get("total_tokens", 0),
        "cumulative_full_tokens": stats.get("cumulative_full_tokens", 0),
        "cumulative_tokens_saved": stats.get("cumulative_tokens_saved", 0),
        "cumulative_detection_tokens": stats.get("cumulative_detection_tokens", 0),
        # Other stats
        "context_switches": stats.get("context_switches", 0),
        "stored_contexts": len(store.get("contexts", {})),
        "current_topic": store.get("current_topic", "None"),
        "mode": store.get("mode", "summary"),
        "logs": state.logs,
        "stored_contexts_data": stored_contexts_data,
    }


# =============================================================================
# Workflow
# =============================================================================

class ContextPilotWorkflow(Workflow):
    """
    Context curation workflow.

    Detects topic changes and returns curated messages ready for LLM consumption.
    Does NOT generate the actual response - that's done by the caller.
    """

    @step
    async def process_message(self, ctx: Context, ev: MessageEvent) -> StopEvent:
        """
        Curate context for a message.

        Returns a StopEvent whose result is a JSON string with:
        - curated_messages: List of messages ready for LLM
        - current_topic: The current topic
        - topic_changed: Whether topic changed
        - decisions: UI-friendly decision messages
        - stats: Statistics

        On any exception, the JSON instead carries "error" and "traceback",
        with empty curated_messages.
        """
        try:
            # Initialize state
            store = load_context_store()
            state = ProcessingState(
                msg=ev.msg,
                chat_history=getattr(ev, 'chat_history', []) or [],
                store=store,
            )

            state.logs.append(f"📨 Processing: {state.msg[:50]}...")
            state.logs.append(f"🧭 Current topic: {state.current_topic}")
            state.decisions.append(f"🧭 Current topic: **{state.current_topic}**")

            # Detect topic change
            state.topic_changed = await detect_and_handle_topic_change(state)

            # Build curated messages with metrics
            metrics = build_curated_messages(state)
            state.logs.append(f"📝 Curated {len(metrics.curated_messages)} messages (topic_changed={state.topic_changed})")
            state.logs.append(f"📊 Tokens: {metrics.curated_tokens} curated vs {metrics.full_context_tokens} full ({metrics.savings_percent}% saved)")

            # Build result
            result = {
                "curated_messages": metrics.curated_messages,
                "current_topic": state.current_topic,
                "topic_changed": state.topic_changed,
                "decisions": state.decisions,
                "stats": build_stats(state, metrics),
            }
            return StopEvent(result=json.dumps(result))

        except Exception as e:
            # Top-level boundary: report the failure to the caller instead of
            # crashing the workflow run.
            import traceback
            error_result = {
                "error": str(e),
                "traceback": traceback.format_exc(),
                "curated_messages": [],
                "decisions": [f"❌ Error: {e}"],
                "stats": {"logs": [f"Error: {e}", traceback.format_exc()]},
            }
            return StopEvent(result=json.dumps(error_result))