"""Memory optimization strategies for agent conversations.""" import json from typing import Optional, Dict, List from .models import ExecutionContext, Message, ToolCall, ToolResult, ContentItem from .llm import LlmRequest, LlmResponse, LlmClient, build_messages def apply_sliding_window( context: ExecutionContext, request: LlmRequest, window_size: int = 20 ) -> None: """Sliding window that keeps only the most recent N messages""" contents = request.contents # Find user message position user_message_idx = None for i, item in enumerate(contents): if isinstance(item, Message) and item.role == "user": user_message_idx = i break if user_message_idx is None: return # Preserve up to user message preserved = contents[:user_message_idx + 1] # Keep only the most recent N from remaining items remaining = contents[user_message_idx + 1:] if len(remaining) > window_size: remaining = remaining[-window_size:] request.contents = preserved + remaining def count_tokens(request: LlmRequest, model_id: str = "gpt-4") -> int: """Calculate total token count of LlmRequest. Args: request: The LLM request to count tokens for model_id: Model identifier for selecting encoding (default: "gpt-4") Returns: Estimated total token count """ import tiktoken # Select encoding for model, use default on failure try: encoding = tiktoken.encoding_for_model(model_id) except KeyError: encoding = tiktoken.get_encoding("o200k_base") # Convert to API message format then count tokens messages = build_messages(request) total_tokens = 0 for message in messages: # Per-message overhead (role, separators, etc.) total_tokens += 4 # Content tokens if message.get("content"): total_tokens += len(encoding.encode(message["content"])) # tool_calls tokens if message.get("tool_calls"): for tool_call in message["tool_calls"]: func = tool_call.get("function", {}) if func.get("name"): total_tokens += len(encoding.encode(func["name"])) if func.get("arguments"): total_tokens += len(encoding.encode(func["arguments"])) # Tool definition tokens if request.tools: for tool in request.tools: tool_def = tool.tool_definition total_tokens += len(encoding.encode(json.dumps(tool_def))) return total_tokens # Tools to compress ToolCall arguments TOOLCALL_COMPACTION_RULES = { "create_file": "[Content saved to file]", } # Tools to compress ToolResult content TOOLRESULT_COMPACTION_RULES = { "read_file": "File content from {file_path}. Re-read if needed.", "search_web": "Search results processed. Query: {query}. Re-search if needed.", "tavily_search": "Search results processed. Query: {query}. Re-search if needed.", } def apply_compaction(context: ExecutionContext, request: LlmRequest) -> None: """Compress tool calls and results into reference messages""" tool_call_args: Dict[str, Dict] = {} compacted = [] for item in request.contents: if isinstance(item, ToolCall): # Save arguments (for use when compressing ToolResult later) tool_call_args[item.tool_call_id] = item.arguments # If the ToolCall itself is a compression target (create_file, etc.) if item.name in TOOLCALL_COMPACTION_RULES: compressed_args = { k: TOOLCALL_COMPACTION_RULES[item.name] if k == "content" else v for k, v in item.arguments.items() } compacted.append(ToolCall( tool_call_id=item.tool_call_id, name=item.name, arguments=compressed_args )) else: compacted.append(item) elif isinstance(item, ToolResult): # If ToolResult is a compression target (read_file, search_web, etc.) 
            if item.name in TOOLRESULT_COMPACTION_RULES:
                args = tool_call_args.get(item.tool_call_id, {})
                template = TOOLRESULT_COMPACTION_RULES[item.name]
                compressed_content = template.format(
                    file_path=args.get("file_path", args.get("path", "unknown")),
                    query=args.get("query", "unknown")
                )
                compacted.append(ToolResult(
                    tool_call_id=item.tool_call_id,
                    name=item.name,
                    status=item.status,
                    content=[compressed_content]
                ))
            else:
                compacted.append(item)
        else:
            compacted.append(item)

    request.contents = compacted


SUMMARIZATION_PROMPT = """You are summarizing an AI agent's work progress.

Given the following execution history, extract:
1. Key findings: Important information discovered
2. Tools used: List of tools that were called
3. Current status: What has been accomplished and what remains

Be concise. Focus on information that will help the agent continue its work.

Execution History:
{history}

Provide a structured summary."""


async def apply_summarization(
    context: ExecutionContext,
    request: LlmRequest,
    llm_client: LlmClient,
    keep_recent: int = 5
) -> None:
    """Replace old messages with a summary."""
    contents = request.contents

    # Find the first user message
    user_idx = None
    for i, item in enumerate(contents):
        if isinstance(item, Message) and item.role == "user":
            user_idx = i
            break

    if user_idx is None:
        return

    # Check the previous summary position (skip already-summarized portions)
    last_summary_idx = context.state.get("last_summary_idx", user_idx)

    # Calculate the summarization target range
    summary_start = last_summary_idx + 1
    summary_end = len(contents) - keep_recent

    # Overlap prevention: exit if there is nothing to summarize or the range is invalid
    if summary_end <= summary_start:
        return

    # Determine the portions to preserve (no overlap)
    preserved_start = contents[:last_summary_idx + 1]
    preserved_end = contents[summary_end:]
    to_summarize = contents[summary_start:summary_end]

    # Generate the summary
    history_text = format_history_for_summary(to_summarize)
    summary = await generate_summary(llm_client, history_text)

    # Add the summary to the instructions
    request.append_instructions(f"[Previous work summary]\n{summary}")

    # Keep only the preserved portions in contents
    request.contents = preserved_start + preserved_end

    # Record the summary position
    context.state["last_summary_idx"] = len(preserved_start) - 1


def format_history_for_summary(items: List[ContentItem]) -> str:
    """Convert a ContentItem list to text for summarization."""
    lines = []
    for item in items:
        if isinstance(item, Message):
            lines.append(f"[{item.role}]: {item.content[:500]}...")
        elif isinstance(item, ToolCall):
            lines.append(f"[Tool Call]: {item.name}({item.arguments})")
        elif isinstance(item, ToolResult):
            content_preview = str(item.content[0])[:200] if item.content else ""
            lines.append(f"[Tool Result]: {item.name} -> {content_preview}...")
    return "\n".join(lines)


async def generate_summary(llm_client: LlmClient, history: str) -> str:
    """Generate a history summary using the LLM."""
    request = LlmRequest(
        instructions=[SUMMARIZATION_PROMPT.format(history=history)],
        contents=[Message(role="user", content="Please summarize.")]
    )
    response = await llm_client.generate(request)

    for item in response.content:
        if isinstance(item, Message):
            return item.content
    return ""


class ContextOptimizer:
    """Hierarchical context optimization strategy."""

    def __init__(
        self,
        llm_client: LlmClient,
        token_threshold: int = 50000,
        enable_compaction: bool = True,
        enable_summarization: bool = True,
        keep_recent: int = 5
    ):
        self.llm_client = llm_client
        self.token_threshold = token_threshold
        self.enable_compaction = enable_compaction
        self.enable_summarization = enable_summarization
        self.keep_recent = keep_recent

    async def __call__(
        self,
        context: ExecutionContext,
        request: LlmRequest
    ) -> Optional[LlmResponse]:
        """Optimize the request in place. Register this as a before_llm_callback."""
        # Step 1: Measure tokens
        if count_tokens(request) < self.token_threshold:
            return None

        # Step 2: Apply compaction
        if self.enable_compaction:
            apply_compaction(context, request)
            if count_tokens(request) < self.token_threshold:
                return None

        # Step 3: Apply summarization
        if self.enable_summarization:
            await apply_summarization(
                context, request, self.llm_client, self.keep_recent
            )

        return None
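

# Usage sketch (illustrative): ContextOptimizer is an async callable with the
# (ExecutionContext, LlmRequest) -> Optional[LlmResponse] shape expected of a
# before-LLM hook. The `Agent` class and its `before_llm_callback` parameter
# below are assumptions for illustration; only the optimizer itself comes
# from this module.
#
#     optimizer = ContextOptimizer(
#         llm_client=llm_client,    # any LlmClient implementation
#         token_threshold=50000,    # start optimizing near this token budget
#         keep_recent=5,            # newest items survive summarization
#     )
#     agent = Agent(llm_client=llm_client, before_llm_callback=optimizer)
#
# Each turn, the hook tries cheap compaction first and only pays for an LLM
# summarization call if the request is still over the token threshold.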