Spaces:
Sleeping
Sleeping
| """Memory optimization strategies for agent conversations.""" | |
| import json | |
| from typing import Optional, Dict, List | |
| from .models import ExecutionContext, Message, ToolCall, ToolResult, ContentItem | |
| from .llm import LlmRequest, LlmResponse, LlmClient, build_messages | |
def apply_sliding_window(
    context: ExecutionContext,
    request: LlmRequest,
    window_size: int = 20
) -> None:
    """Sliding window that keeps only the most recent N messages.

    Everything up to and including the first user message is always
    preserved; of the items after it, only the last ``window_size`` are
    kept.

    Args:
        context: Execution context (unused; kept for the callback signature).
        request: LLM request whose ``contents`` are trimmed in place.
        window_size: Maximum number of items kept after the user message.
    """
    contents = request.contents
    # Find the first user message; without one there is nothing to anchor on.
    user_message_idx = None
    for i, item in enumerate(contents):
        if isinstance(item, Message) and item.role == "user":
            user_message_idx = i
            break
    if user_message_idx is None:
        return
    # Preserve up to and including the user message.
    preserved = contents[:user_message_idx + 1]
    # Keep only the most recent N of the remaining items.
    remaining = contents[user_message_idx + 1:]
    if len(remaining) > window_size:
        remaining = remaining[-window_size:]
        # Bug fix: the cut can land between a ToolCall and its ToolResult,
        # leaving an orphaned ToolResult at the head of the window.  Most
        # chat-completion APIs reject a tool result without the assistant
        # tool call that produced it, so drop such leading results.
        while remaining and isinstance(remaining[0], ToolResult):
            remaining = remaining[1:]
    request.contents = preserved + remaining
def count_tokens(request: LlmRequest, model_id: str = "gpt-4") -> int:
    """Calculate total token count of LlmRequest.

    Args:
        request: The LLM request to count tokens for
        model_id: Model identifier for selecting encoding (default: "gpt-4")

    Returns:
        Estimated total token count
    """
    # tiktoken is imported lazily so the module loads without it.
    import tiktoken

    # Pick the encoding for the model; fall back to a default on failure.
    try:
        enc = tiktoken.encoding_for_model(model_id)
    except KeyError:
        enc = tiktoken.get_encoding("o200k_base")

    def tokens(text: str) -> int:
        return len(enc.encode(text))

    total = 0
    # Convert to the API message format, then count per-message tokens.
    for msg in build_messages(request):
        # Fixed per-message overhead (role, separators, etc.).
        total += 4
        content = msg.get("content")
        if content:
            total += tokens(content)
        # Tokens contributed by tool_calls, if any.
        for call in msg.get("tool_calls") or []:
            fn = call.get("function", {})
            if fn.get("name"):
                total += tokens(fn["name"])
            if fn.get("arguments"):
                total += tokens(fn["arguments"])
    # Tool definitions count against the budget as serialized JSON.
    if request.tools:
        total += sum(tokens(json.dumps(t.tool_definition)) for t in request.tools)
    return total
# Tools whose ToolCall *arguments* are compacted: the "content" argument's
# value is replaced with the placeholder string (see apply_compaction).
TOOLCALL_COMPACTION_RULES = {
    "create_file": "[Content saved to file]",
}

# Tools whose ToolResult *content* is compacted: the full result is replaced
# by the template, formatted with the originating call's arguments
# (file_path / query) so the agent can re-run the tool if needed.
TOOLRESULT_COMPACTION_RULES = {
    "read_file": "File content from {file_path}. Re-read if needed.",
    "search_web": "Search results processed. Query: {query}. Re-search if needed.",
    "tavily_search": "Search results processed. Query: {query}. Re-search if needed.",
}
def apply_compaction(context: ExecutionContext, request: LlmRequest) -> None:
    """Compress tool calls and results into reference messages"""
    # Arguments of every ToolCall seen so far, keyed by tool_call_id,
    # so that a later ToolResult can be rewritten with them.
    args_by_id: Dict[str, Dict] = {}
    rewritten = []
    for entry in request.contents:
        if isinstance(entry, ToolCall):
            args_by_id[entry.tool_call_id] = entry.arguments
            # Pass through calls that are not compaction targets.
            if entry.name not in TOOLCALL_COMPACTION_RULES:
                rewritten.append(entry)
                continue
            # Replace the "content" argument with the placeholder text.
            placeholder = TOOLCALL_COMPACTION_RULES[entry.name]
            slimmed = {
                key: placeholder if key == "content" else value
                for key, value in entry.arguments.items()
            }
            rewritten.append(ToolCall(
                tool_call_id=entry.tool_call_id,
                name=entry.name,
                arguments=slimmed
            ))
        elif isinstance(entry, ToolResult) and entry.name in TOOLRESULT_COMPACTION_RULES:
            # Rebuild the result as a short reference message, filling the
            # template from the original call's arguments.
            call_args = args_by_id.get(entry.tool_call_id, {})
            stub = TOOLRESULT_COMPACTION_RULES[entry.name].format(
                file_path=call_args.get("file_path", call_args.get("path", "unknown")),
                query=call_args.get("query", "unknown")
            )
            rewritten.append(ToolResult(
                tool_call_id=entry.tool_call_id,
                name=entry.name,
                status=entry.status,
                content=[stub]
            ))
        else:
            rewritten.append(entry)
    request.contents = rewritten
# Instruction template for the auxiliary summarization call made by
# generate_summary(); {history} is filled with the rendered execution
# history from format_history_for_summary().
SUMMARIZATION_PROMPT = """You are summarizing an AI agent's work progress.
Given the following execution history, extract:
1. Key findings: Important information discovered
2. Tools used: List of tools that were called
3. Current status: What has been accomplished and what remains
Be concise. Focus on information that will help the agent continue its work.
Execution History:
{history}
Provide a structured summary."""
async def apply_summarization(
    context: ExecutionContext,
    request: LlmRequest,
    llm_client: LlmClient,
    keep_recent: int = 5
) -> None:
    """Replace old messages with a summary.

    Items between the last summarized position and the most recent
    ``keep_recent`` items are rendered to text, summarized by the LLM,
    appended to the request instructions, and removed from
    ``request.contents``.

    Args:
        context: Execution context; ``context.state["last_summary_idx"]``
            records how far previous summaries reached.
        request: The LLM request to shrink in place.
        llm_client: Client used for the auxiliary summarization call.
        keep_recent: Number of trailing items kept verbatim.
    """
    contents = request.contents
    # Find the first user message; nothing to anchor on without one.
    user_idx = None
    for i, item in enumerate(contents):
        if isinstance(item, Message) and item.role == "user":
            user_idx = i
            break
    if user_idx is None:
        return
    # Check previous summary position (skip already-summarized portions).
    last_summary_idx = context.state.get("last_summary_idx", user_idx)
    # Calculate summarization target range.
    summary_start = last_summary_idx + 1
    summary_end = len(contents) - keep_recent
    # Bug fix: never cut between a ToolCall and its ToolResult.  If the
    # boundary lands on a ToolResult, pull it back into the preserved tail
    # until the boundary sits on a non-result item, so the originating
    # ToolCall and its result stay together.  (The upper bound also guards
    # indexing when keep_recent == 0.)
    while summary_start < summary_end < len(contents) and \
            isinstance(contents[summary_end], ToolResult):
        summary_end -= 1
    # Overlap prevention: exit if nothing to summarize or range is invalid.
    if summary_end <= summary_start:
        return
    # Determine portions to preserve (no overlap).
    preserved_start = contents[:last_summary_idx + 1]
    preserved_end = contents[summary_end:]
    to_summarize = contents[summary_start:summary_end]
    # Generate summary.
    history_text = format_history_for_summary(to_summarize)
    summary = await generate_summary(llm_client, history_text)
    # Add summary to instructions.
    request.append_instructions(f"[Previous work summary]\n{summary}")
    # Keep only preserved portions in contents.
    request.contents = preserved_start + preserved_end
    # Record summary position for the next invocation.
    context.state["last_summary_idx"] = len(preserved_start) - 1
def format_history_for_summary(items: List[ContentItem]) -> str:
    """Render a ContentItem sequence as plain text for the summarizer.

    Messages are truncated to 500 characters, tool results to 200, so the
    summarization prompt stays small.
    """
    rendered = []
    for entry in items:
        if isinstance(entry, Message):
            rendered.append(f"[{entry.role}]: {entry.content[:500]}...")
        elif isinstance(entry, ToolCall):
            rendered.append(f"[Tool Call]: {entry.name}({entry.arguments})")
        elif isinstance(entry, ToolResult):
            preview = str(entry.content[0])[:200] if entry.content else ""
            rendered.append(f"[Tool Result]: {entry.name} -> {preview}...")
    return "\n".join(rendered)
async def generate_summary(llm_client: LlmClient, history: str) -> str:
    """Generate a history summary via a dedicated LLM call.

    Returns the content of the first Message in the response, or an empty
    string when the response contains no Message.
    """
    summary_request = LlmRequest(
        instructions=[SUMMARIZATION_PROMPT.format(history=history)],
        contents=[Message(role="user", content="Please summarize.")]
    )
    response = await llm_client.generate(summary_request)
    # First Message wins; anything else in the response is ignored.
    return next(
        (part.content for part in response.content if isinstance(part, Message)),
        ""
    )
class ContextOptimizer:
    """Hierarchical context optimization strategy.

    Applies progressively more aggressive reductions, and only while the
    request remains over the configured token budget:

    1. Compaction of tool calls/results (cheap, reference-based).
    2. LLM-based summarization of older history (lossy, expensive).
    """

    def __init__(
        self,
        llm_client: LlmClient,
        token_threshold: int = 50000,
        enable_compaction: bool = True,
        enable_summarization: bool = True,
        keep_recent: int = 5
    ):
        # Client used for the summarization step.
        self.llm_client = llm_client
        # Token budget that triggers optimization when exceeded.
        self.token_threshold = token_threshold
        self.enable_compaction = enable_compaction
        self.enable_summarization = enable_summarization
        # Number of trailing items summarization keeps verbatim.
        self.keep_recent = keep_recent

    async def __call__(
        self,
        context: ExecutionContext,
        request: LlmRequest
    ) -> Optional[LlmResponse]:
        """Register as before_llm_callback"""

        def within_budget() -> bool:
            return count_tokens(request) < self.token_threshold

        # Step 1: already under budget — nothing to do.
        if within_budget():
            return None
        # Step 2: cheap escalation — compaction.
        if self.enable_compaction:
            apply_compaction(context, request)
            if within_budget():
                return None
        # Step 3: last resort — summarize older history.
        if self.enable_summarization:
            await apply_summarization(
                context,
                request,
                self.llm_client,
                self.keep_recent
            )
        return None