Spaces:
Running
Running
| """ | |
| Strict Token Budget Management | |
| Implements sliding window conversation history, aggressive compression, | |
| and emergency context truncation to prevent context window overflow. | |
| """ | |
| from typing import List, Dict, Any, Optional, Tuple | |
| import json | |
| import tiktoken | |
| from pathlib import Path | |
class ConversationMessage:
    """A single chat message plus the metadata used for history pruning."""

    def __init__(self, role: str, content: str, message_type: str = "normal",
                 priority: int = 5, tokens: Optional[int] = None):
        self.role = role
        self.content = content
        # Category used when prioritizing: system, tool_result, assistant, user, normal.
        self.message_type = message_type
        # Drop order: 1 is discarded first, 10 is kept the longest.
        self.priority = priority
        # Cached token count; None until it has been measured.
        self.tokens = tokens
        # Reserved for future use; never assigned anywhere in this module.
        self.timestamp = None

    def to_dict(self) -> Dict[str, str]:
        """Render this message as an OpenAI-style chat dict."""
        return dict(role=self.role, content=self.content)
class TokenBudgetManager:
    """
    Manages conversation history with strict token budget enforcement.

    Features:
    - Accurate token counting using tiktoken (char-based fallback if unavailable)
    - Priority-based message dropping
    - Sliding window with smart compression
    - Emergency context truncation
    - Keeps recent tool results, drops old assistant messages
    """

    def __init__(self, model: str = "gpt-4", max_tokens: int = 128000,
                 reserve_tokens: int = 8000):
        """
        Initialize token budget manager.

        Args:
            model: Model name for token counting
            max_tokens: Maximum context window size
            reserve_tokens: Tokens to reserve for the model's response
        """
        self.model = model
        self.max_tokens = max_tokens
        self.reserve_tokens = reserve_tokens
        # Budget actually usable by the prompt itself.
        self.available_tokens = max_tokens - reserve_tokens

        # Initialize tokenizer; tiktoken raises KeyError for unknown model names.
        try:
            self.encoding = tiktoken.encoding_for_model(model)
        except Exception:
            # Fallback to cl100k_base (GPT-4/GPT-3.5)
            self.encoding = tiktoken.get_encoding("cl100k_base")

        print(f"📊 Token Budget: {self.available_tokens:,} tokens available ({self.max_tokens:,} - {self.reserve_tokens:,} reserve)")

    def count_tokens(self, text: str) -> int:
        """Count tokens in text using tiktoken; falls back to ~4 chars/token."""
        try:
            return len(self.encoding.encode(text))
        except Exception:
            # Fallback estimation: ~4 chars per token
            return len(text) // 4

    def count_message_tokens(self, message) -> int:
        """
        Count tokens in a message (includes role overhead).

        Args:
            message: Either a dict or a Pydantic ChatMessage object

        Returns:
            Content tokens + role tokens + ~4 tokens of per-message framing
            (<|role|>content<|endofmessage|>).
        """
        # Handle both dict and Pydantic object formats.
        if isinstance(message, dict):
            content = message.get("content", "")
            role = message.get("role", "")
        else:
            # Pydantic object (like ChatMessage from Mistral SDK)
            content = getattr(message, "content", "")
            role = getattr(message, "role", "")

        content_tokens = self.count_tokens(str(content))
        role_tokens = self.count_tokens(str(role))
        return content_tokens + role_tokens + 4

    def count_messages_tokens(self, messages: List) -> int:
        """Count total tokens in a message list."""
        return sum(self.count_message_tokens(msg) for msg in messages)

    def compress_tool_result(self, tool_result: str, max_tokens: int = 500) -> str:
        """
        Aggressively compress a tool result while keeping key information.

        Keeps: success/failure status, key metrics/numbers, error messages.
        Drops: verbose logs, duplicate information, large data structures.

        Args:
            tool_result: Raw tool output (JSON string or plain text)
            max_tokens: Budget the compressed result should fit in

        Returns:
            The original string if already within budget, otherwise a
            compressed/truncated version.
        """
        if self.count_tokens(tool_result) <= max_tokens:
            return tool_result

        try:
            # Try to parse as JSON and keep only the essential fields.
            result_dict = json.loads(tool_result)
            compressed = {
                "success": result_dict.get("success", True),
            }
            # Add error if present (capped at 200 chars).
            if "error" in result_dict:
                compressed["error"] = str(result_dict["error"])[:200]
            # Add key metrics (numbers, scores, paths).
            for key in ["score", "accuracy", "best_score", "n_rows", "n_cols",
                        "output_path", "best_model", "result_summary"]:
                if key in result_dict:
                    compressed[key] = result_dict[key]
            # Add result only if it is small.
            if "result" in result_dict:
                result_str = str(result_dict["result"])
                if len(result_str) < 300:
                    compressed["result"] = result_str[:300]
            # NOTE: the compressed JSON is assumed to fit the budget; its
            # fields are all hard-capped above.
            return json.dumps(compressed, indent=None)
        except json.JSONDecodeError:
            # Not JSON - truncate intelligently: keep first 5 and last 5 lines.
            lines = tool_result.split('\n')
            if len(lines) > 15:
                compressed_lines = lines[:5] + ["... (truncated) ..."] + lines[-5:]
                result = '\n'.join(compressed_lines)
            else:
                result = tool_result
            # Hard truncate if still too long (rough char limit from token budget).
            token_count = self.count_tokens(result)
            if token_count > max_tokens:
                char_limit = max_tokens * 4
                result = result[:char_limit] + "... (truncated)"
            return result

    def prioritize_messages(self, messages: List["ConversationMessage"]) -> List["ConversationMessage"]:
        """
        Assign priorities to messages in place, based on type and recency.

        Priority levels:
        - 10: System prompt, recent user messages
        - 9: Recent tool results
        - 8: Recent assistant responses
        - 7: Older user messages
        - 5: Recent normal messages
        - 3: Old tool results
        - 2: Old assistant responses
        - 1: Very old normal messages

        Returns:
            The same list, with each message's ``priority`` updated.
        """
        # "Recent" means within the last 5 messages.
        recent_threshold = max(0, len(messages) - 5)
        for i, msg in enumerate(messages):
            if msg.message_type == "system":
                msg.priority = 10
            elif msg.role == "user":
                msg.priority = 10 if i >= recent_threshold else 7
            elif msg.message_type == "tool_result":
                msg.priority = 9 if i >= recent_threshold else 3
            elif msg.role == "assistant":
                msg.priority = 8 if i >= recent_threshold else 2
            else:
                msg.priority = 5 if i >= recent_threshold else 1
        return messages

    def apply_sliding_window(self, messages: List["ConversationMessage"],
                             target_tokens: int) -> List["ConversationMessage"]:
        """
        Apply a sliding window to fit within the token budget.

        Strategy:
        1. Always keep the system prompt (first message, if any)
        2. Keep high-priority (recent) messages
        3. Drop low-priority messages from the middle
        4. Compress critical tool results if they don't fit whole

        Args:
            messages: List of ConversationMessage objects
            target_tokens: Target token count

        Returns:
            Filtered message list within budget, in chronological order.
        """
        if not messages:
            return []

        # Map object identity -> original position so we can restore
        # chronological order after priority-based selection.  BUGFIX: the
        # old code matched by content equality (broken for duplicate
        # contents) and skipped reordering entirely when there was no
        # system prompt, returning messages in priority order.
        original_order = {id(m): i for i, m in enumerate(messages)}

        # Always keep system prompt.
        system_msg = messages[0] if messages[0].message_type == "system" else None
        other_messages = messages[1:] if system_msg else messages

        # Prioritize, then consider messages from highest to lowest priority.
        other_messages = self.prioritize_messages(other_messages)
        sorted_messages = sorted(other_messages, key=lambda m: m.priority, reverse=True)

        # Ensure every candidate has a token count.
        for msg in sorted_messages:
            if msg.tokens is None:
                msg.tokens = self.count_message_tokens(msg.to_dict())

        # Greedily add messages until the budget is exhausted.
        kept_messages = []
        current_tokens = 0

        # System prompt is admitted unconditionally.
        if system_msg:
            system_msg.tokens = self.count_message_tokens(system_msg.to_dict())
            kept_messages.append(system_msg)
            current_tokens += system_msg.tokens

        for msg in sorted_messages:
            if current_tokens + msg.tokens <= target_tokens:
                kept_messages.append(msg)
                current_tokens += msg.tokens
            elif msg.message_type == "tool_result" and msg.priority >= 8:
                # Try compressing critical tool results to squeeze them in.
                compressed_content = self.compress_tool_result(msg.content, max_tokens=300)
                compressed_tokens = self.count_tokens(compressed_content)
                if current_tokens + compressed_tokens <= target_tokens:
                    msg.content = compressed_content
                    msg.tokens = compressed_tokens
                    kept_messages.append(msg)
                    current_tokens += compressed_tokens

        # Restore chronological order (system prompt naturally sorts first
        # since it was at index 0). Identity-based, so duplicates are safe.
        kept_messages.sort(key=lambda m: original_order[id(m)])

        print(f"📊 Sliding window: {len(messages)} → {len(kept_messages)} messages ({current_tokens:,} tokens)")
        return kept_messages

    def emergency_truncate(self, messages: List[Dict[str, str]],
                           max_tokens: int) -> List[Dict[str, str]]:
        """
        Emergency truncation when context is about to overflow.

        Aggressive strategy:
        - Keep system prompt (first message)
        - Keep the last 2 messages
        - Drop everything else; truncate the system prompt as a last resort

        Args:
            messages: Message list
            max_tokens: Hard token limit

        Returns:
            Truncated message list
        """
        if not messages:
            return []

        print("⚠️ EMERGENCY TRUNCATION: Context overflow imminent")

        essential_messages = []
        # System prompt (first message).
        if messages:
            essential_messages.append(messages[0])
        # Last 2 messages (everything after the first, if fewer than 3 total).
        if len(messages) > 2:
            essential_messages.extend(messages[-2:])
        else:
            essential_messages.extend(messages[1:])

        total_tokens = self.count_messages_tokens(essential_messages)
        if total_tokens <= max_tokens:
            return essential_messages

        # Still too large - truncate the system prompt itself.
        print("⚠️ Truncating system prompt to fit budget")
        system_msg = essential_messages[0]
        # Handle both dict and Pydantic object formats.
        if isinstance(system_msg, dict):
            system_content = system_msg["content"]
        else:
            system_content = getattr(system_msg, "content", "")
        # Keep first 1000 chars of the system prompt.
        truncated_system = {
            "role": "system",
            "content": str(system_content)[:1000] + "\n\n... (truncated due to context limit) ..."
        }
        return [truncated_system] + essential_messages[1:]

    def enforce_budget(self, messages: List[Dict[str, str]],
                       system_prompt: Optional[str] = None) -> Tuple[List[Dict[str, str]], int]:
        """
        Main entry point: enforce the token budget on a message list.

        Args:
            messages: List of messages (dicts or Pydantic objects)
            system_prompt: Optional new system prompt to prepend

        Returns:
            (filtered_messages, total_tokens)
        """
        # Add system prompt if provided.
        if system_prompt:
            messages = [{"role": "system", "content": system_prompt}] + messages

        current_tokens = self.count_messages_tokens(messages)
        print(f"📊 Token Budget Check: {current_tokens:,} / {self.available_tokens:,} tokens")

        # If within budget, return as-is.
        if current_tokens <= self.available_tokens:
            print("✅ Within budget")
            return messages, current_tokens

        print(f"⚠️ Over budget by {current_tokens - self.available_tokens:,} tokens")

        # Convert to ConversationMessage objects for prioritized windowing.
        conv_messages = []
        for i, msg in enumerate(messages):
            # Handle both dict and Pydantic object formats.
            if isinstance(msg, dict):
                role = msg.get("role", "")
                content = msg.get("content", "")
            else:
                role = getattr(msg, "role", "")
                content = getattr(msg, "content", "")

            msg_type = "system" if i == 0 and role == "system" else "normal"
            # Heuristic: content mentioning tools/functions is treated as a
            # tool result so it benefits from compression instead of dropping.
            if "tool" in str(content).lower() or "function" in str(content).lower():
                msg_type = "tool_result"

            conv_msg = ConversationMessage(
                role=role,
                content=str(content),
                message_type=msg_type
            )
            conv_messages.append(conv_msg)

        # Apply the sliding window, then convert back to dict format.
        filtered = self.apply_sliding_window(conv_messages, self.available_tokens)
        result_messages = [msg.to_dict() for msg in filtered]
        final_tokens = self.count_messages_tokens(result_messages)

        # Emergency truncation if still over budget.
        if final_tokens > self.available_tokens:
            result_messages = self.emergency_truncate(result_messages, self.available_tokens)
            final_tokens = self.count_messages_tokens(result_messages)

        print(f"✅ Budget enforced: {final_tokens:,} tokens ({len(result_messages)} messages)")
        return result_messages, final_tokens
# Module-level singleton so every caller shares one budget manager.
_token_manager = None


def get_token_manager(model: str = "gpt-4", max_tokens: int = 128000) -> TokenBudgetManager:
    """Return the shared TokenBudgetManager, creating it on first use.

    Note: the arguments are honored only on the first call; subsequent
    calls return the already-created instance unchanged.
    """
    global _token_manager
    if _token_manager is None:
        _token_manager = TokenBudgetManager(model=model, max_tokens=max_tokens)
    return _token_manager