agent-cost-optimizer / aco /context_compression.py
narcolepticchicken's picture
Upload aco/context_compression.py
6063fe9 verified
Raw
History Blame Contribute Delete
9.04 kB
"""Context Compression Module — ACON pattern: guideline-optimized compression.

Based on ACON (arXiv:2510.00615): compress agent interaction history via
iteratively optimized natural-language guidelines, then distill into
a small model for near-zero-overhead compression.

Key insight: compress in natural-language space, not token space.
Remove distracting context to IMPROVE success rate, not just reduce tokens.

Also implements:
- Cache-Aware Prompt Layout: static content (system prompts, tool schemas)
  is placed at the prefix to maximize Anthropic/OpenAI automatic prompt
  caching. This alone gives a 50-90% cost reduction on cached tokens.
- TALE-style Token Budget Estimation: predict the per-query optimal
  reasoning length and inject it as a budget constraint.
"""
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
# Natural-language compression guidelines, keyed by task type
# ("coding", "research", "tool_use", "default"). Each value is a prompt
# that tells a compressor what to PRESERVE and what to OMIT from an
# agent's interaction history for that kind of task.
# The text is runtime data: edit it deliberately, not cosmetically.
COMPRESSION_GUIDELINES = {
"coding": """
You are compressing a coding agent's interaction history. PRESERVE:
- Current error messages and stack traces (EXACT)
- File paths being modified
- Test output that failed
- Current git status / branch
- The specific change being attempted
OMIT:
- Completed successful actions (unless relevant to current step)
- Redundant observations (same error repeated)
- Tool output that was successfully parsed
- Reasoning about already-completed steps
- Boilerplate system messages
""",
"research": """
You are compressing a research agent's interaction history. PRESERVE:
- Sources already consulted (URLs, paper IDs)
- Key findings extracted so far
- Current search queries being explored
- Contradictions found between sources
- Open questions still unanswered
OMIT:
- Full source text (keep citations only)
- Exhausted search queries
- Redundant findings across sources
- Dead-end searches
""",
"tool_use": """
You are compressing a tool-using agent's interaction history. PRESERVE:
- Current API/tool being used
- Last successful and failed calls
- Key data points extracted
- Rate limits or errors encountered
OMIT:
- Raw API responses (summarize extracted data)
- Successful calls that completed their purpose
- Repeated retry attempts with the same parameters
""",
"default": """
You are compressing an agent's interaction history. PRESERVE:
- Current task state and progress
- Critical observations and errors
- Pending actions to complete
- Recent decisions and their rationale
OMIT:
- Completed actions that are no longer relevant
- Redundant observations
- Verbose reasoning about finished steps
- Boilerplate system/tool descriptions
""",
}
@dataclass
class CompressionResult:
    """Outcome of one history-compression call.

    Token counts are estimates produced by the compressor, not exact
    tokenizer output.
    """

    # The history text after compression.
    compressed_text: str
    # Estimated token count of the input history.
    original_tokens: int
    # Estimated token count of ``compressed_text``.
    compressed_tokens: int
    # compressed_tokens relative to original_tokens; lower means the
    # text shrank more.
    compression_ratio: float
    # Items the compressor reports as deliberately kept.
    preserved_items: List[str]
    # Items the compressor reports as removed.
    omitted_items: List[str]
@dataclass
class LayoutResult:
    """Outcome of one cache-aware prompt layout.

    Carries the ordered chat messages together with the token and cost
    estimates that were computed for them.
    """

    # Chat messages in send order; each is a {"role": ..., "content": ...} dict.
    messages: List[Dict[str, str]]
    # Estimated tokens in the leading, cacheable part of the prompt.
    cache_prefix_tokens: int
    # Estimated tokens in the trailing, per-request part of the prompt.
    dynamic_suffix_tokens: int
    # Share of the prompt's estimated tokens that sit in the cacheable prefix.
    estimated_cache_hit_rate: float
    # Estimated input cost when nothing is served from cache.
    estimated_cost_without_cache: float
    # Estimated input cost when the prefix is served from cache.
    estimated_cost_with_cache: float
    # estimated_cost_without_cache minus estimated_cost_with_cache.
    cost_savings: float
class ContextCompressor:
    """Line-oriented compressor for agent interaction histories.

    The compressor works on the raw history text, one line at a time:

    * a line that looks like an error report (it contains one of
      ``_ERROR_KEYWORDS``, case-insensitively) is kept the first time it
      appears and dropped on every repeat;
    * every other line is passed through unchanged;
    * when a token budget is given and the text is still too large, the
      middle of the history is replaced by an ``...[omitted]...`` marker.

    Token counts are estimates: four characters per token.

    Attributes:
        model_id: Identifier supplied by the caller. It is stored but not
            consulted by the heuristics in this class.
        guidelines: The module-level ``COMPRESSION_GUIDELINES`` mapping.
        stats: Running totals with the keys ``"total_compressions"`` and
            ``"total_tokens_saved"``.
    """

    # Substrings that mark a line as an error report.
    _ERROR_KEYWORDS = ("error", "traceback", "exception", "failed", "segfault")
    # Two error lines count as "the same" when their first N stripped
    # characters match.
    _ERROR_KEY_LENGTH = 200
    # Line inserted where the middle of the history was cut out.
    _OMISSION_MARKER = "...[omitted]..."
    # Share of the lines initially reserved for the start of the history.
    _HEAD_FRACTION = 0.2

    def __init__(self, model_id: Optional[str] = None):
        self.model_id = model_id
        self.guidelines = COMPRESSION_GUIDELINES
        self.stats = {"total_compressions": 0, "total_tokens_saved": 0}

    def compress(self, history: str, task_type: str = "default",
                 max_tokens: Optional[int] = None) -> "CompressionResult":
        """Compress an agent interaction history.

        Args:
            history: Full interaction history as newline-separated text.
            task_type: Accepted for interface compatibility; the current
                heuristics treat every task type the same way.
            max_tokens: Optional budget in estimated tokens. ``None`` and
                ``0`` both mean "no budget". When the de-duplicated text
                exceeds the budget, lines are removed from the middle until
                it fits or only the most recent line is left. That last
                line is never cut, so a single over-long line can still
                exceed the budget.

        Returns:
            A ``CompressionResult``. ``preserved_items`` lists the distinct
            error keys in order of first appearance; ``omitted_items`` lists
            every line that was removed.

        Raises:
            ValueError: If ``max_tokens`` is negative.
        """
        if max_tokens is not None and max_tokens < 0:
            raise ValueError(
                f"max_tokens must be non-negative, got {max_tokens}")

        original_tokens = self._estimate_tokens(history)
        kept, error_keys, omitted = self._drop_repeated_errors(
            history.split("\n"))
        compressed = "\n".join(kept)

        if max_tokens and self._estimate_tokens(compressed) > max_tokens:
            head, tail, cut = self._cut_middle(kept, max_tokens)
            # Only insert the marker when something was actually removed;
            # otherwise it would just make the text longer.
            if cut:
                compressed = "\n".join(
                    head + [self._OMISSION_MARKER] + tail)
                omitted.extend(cut)

        compressed_tokens = self._estimate_tokens(compressed)
        # Guard against division by zero for an (almost) empty history.
        ratio = compressed_tokens / max(original_tokens, 1)
        self.stats["total_compressions"] += 1
        self.stats["total_tokens_saved"] += original_tokens - compressed_tokens
        return CompressionResult(
            compressed_text=compressed,
            original_tokens=original_tokens,
            compressed_tokens=compressed_tokens,
            compression_ratio=ratio,
            preserved_items=error_keys,
            omitted_items=omitted,
        )

    def get_stats(self) -> Dict[str, int]:
        """Return a copy of the running compression totals."""
        return dict(self.stats)

    @classmethod
    def _drop_repeated_errors(
            cls, lines: List[str]) -> Tuple[List[str], List[str], List[str]]:
        """Remove repeated error lines; return (kept, error_keys, dropped)."""
        kept: List[str] = []
        dropped: List[str] = []
        # A dict is used as an ordered set so error_keys is deterministic.
        seen: Dict[str, None] = {}
        for line in lines:
            stripped = line.strip()
            lowered = stripped.lower()
            if any(keyword in lowered for keyword in cls._ERROR_KEYWORDS):
                key = stripped[:cls._ERROR_KEY_LENGTH]
                if key in seen:
                    dropped.append(line)
                    continue
                seen[key] = None
            kept.append(line)
        return kept, list(seen), dropped

    @classmethod
    def _cut_middle(
            cls, lines: List[str],
            max_tokens: int) -> Tuple[List[str], List[str], List[str]]:
        """Pick head and tail lines that fit the budget; return (head, tail, cut)."""
        total = len(lines)
        if total == 0:
            return [], [], []

        # Starting point: the first 20% of the lines plus the last
        # max_tokens // 4 lines. The tail is clamped so that it never
        # overlaps the head and always holds at least the most recent line
        # (a plain lines[-0:] slice would return the whole list).
        head_end = int(total * cls._HEAD_FRACTION)
        tail_start = min(max(head_end, total - max_tokens // 4), total - 1)

        head_chars = sum(len(line) for line in lines[:head_end])
        tail_chars = sum(len(line) for line in lines[tail_start:])
        marker_chars = len(cls._OMISSION_MARKER)

        # Widen the gap one line at a time until the rendered text fits.
        while head_end + (total - tail_start) > 1:
            head_count = head_end
            tail_count = total - tail_start
            # Every kept line is followed by exactly one "\n" once the
            # marker line is joined in between head and tail.
            rendered_chars = (head_chars + tail_chars + marker_chars
                              + head_count + tail_count)
            if rendered_chars // 4 <= max_tokens:
                break
            # Keep roughly one head line per four tail lines, so recent
            # context is favoured without losing the opening lines first.
            if head_count * 4 > tail_count:
                head_end -= 1
                head_chars -= len(lines[head_end])
            else:
                tail_chars -= len(lines[tail_start])
                tail_start += 1

        return lines[:head_end], lines[tail_start:], lines[head_end:tail_start]

    def _estimate_tokens(self, text: str) -> int:
        """Estimate the token count of *text* as four characters per token."""
        return len(text) // 4
class CacheAwareLayout:
    """Optimize prompt structure for Anthropic/OpenAI automatic prompt caching.

    Both providers cache prefixes >1024 tokens:
    - Anthropic: 90% cost reduction on cached tokens
    - OpenAI: 50% discount

    Strategy: ALL static content at prefix, dynamic at suffix.

    Only the source names listed in ``STATIC_ORDER`` and ``DYNAMIC_ORDER``
    are laid out; any other key in ``sources`` is ignored. Token counts are
    estimates: four characters per token.
    """

    STATIC_ORDER = ["system_prompt", "tool_definitions", "few_shot_examples",
                    "project_context", "user_preferences"]
    DYNAMIC_ORDER = ["user_request", "retrieved_docs", "task_plan",
                     "recent_messages", "artifacts", "tool_results"]

    def __init__(self):
        self.stats = {"total_layouts": 0, "estimated_total_savings": 0.0}

    def layout(self, sources: Dict[str, str],
               max_prefix_tokens: int = 32000,
               cost_per_1k_input: float = 3.0,
               cache_discount: float = 0.9) -> "LayoutResult":
        """Arrange prompt sources for caching and estimate the cost effect.

        Args:
            sources: Mapping from source name to its text. Empty or missing
                entries are skipped.
            max_prefix_tokens: Largest estimated size of the cacheable
                prefix. Static sections that do not fit are still sent, as
                ``system`` messages right after the prefix, but are costed
                as uncached.
            cost_per_1k_input: Price of 1000 input tokens.
            cache_discount: Fraction of the price saved on cached tokens
                (0.9 means cached tokens cost 10% of the normal price).

        Returns:
            A ``LayoutResult`` with the ordered messages and the token and
            cost estimates.
        """
        messages, prefix_tokens, suffix_tokens = self._arrange(
            sources, max_prefix_tokens)

        total_tokens = prefix_tokens + suffix_tokens
        cost_per_token = cost_per_1k_input / 1000
        no_cache_cost = total_tokens * cost_per_token
        cached_cost = prefix_tokens * cost_per_token * (1 - cache_discount)
        non_cached_cost = suffix_tokens * cost_per_token
        with_cache = cached_cost + non_cached_cost
        savings = no_cache_cost - with_cache
        # Guard against division by zero when no source produced any tokens.
        hit_rate = prefix_tokens / max(total_tokens, 1)

        self.stats["total_layouts"] += 1
        self.stats["estimated_total_savings"] += savings
        return LayoutResult(
            messages=messages,
            cache_prefix_tokens=prefix_tokens,
            dynamic_suffix_tokens=suffix_tokens,
            estimated_cache_hit_rate=hit_rate,
            estimated_cost_without_cache=no_cache_cost,
            estimated_cost_with_cache=with_cache,
            cost_savings=savings,
        )

    def get_stats(self):
        """Return a copy of the running layout totals."""
        return dict(self.stats)

    def _arrange(self, sources: Dict[str, str],
                 max_prefix_tokens: int
                 ) -> Tuple[List[Dict[str, str]], int, int]:
        """Order the sources; return (messages, prefix_tokens, suffix_tokens)."""
        prefix: List[Dict[str, str]] = []
        overflow: List[Dict[str, str]] = []
        dynamic: List[Dict[str, str]] = []
        prefix_tokens = 0
        suffix_tokens = 0

        for name in self.STATIC_ORDER:
            content = sources.get(name)
            if not content:
                continue
            tokens = self._estimate_tokens(content)
            message = {"role": "system", "content": content}
            if prefix_tokens + tokens <= max_prefix_tokens:
                prefix.append(message)
                prefix_tokens += tokens
            else:
                # Never drop static content (e.g. tool schemas) just because
                # the cache budget is full: send it, but count it as uncached.
                overflow.append(message)
                suffix_tokens += tokens

        for name in self.DYNAMIC_ORDER:
            content = sources.get(name)
            if not content:
                continue
            dynamic.append({"role": "user", "content": content})
            suffix_tokens += self._estimate_tokens(content)

        return prefix + overflow + dynamic, prefix_tokens, suffix_tokens

    @staticmethod
    def _estimate_tokens(text: str) -> int:
        """Estimate the token count of *text* as four characters per token."""
        return len(text) // 4
class TokenBudgetEstimator:
    """TALE-style per-query token budget.

    ``budget_map`` maps a task type to a ``(low, high)`` pair of token
    counts. ``estimate`` scales ``high`` by the task's complexity and turns
    the result into a response-length limit plus a prompt hint.
    """

    # Bounds used when neither the task type nor "unknown" is in budget_map.
    _DEFAULT_BOUNDS = (200, 1000)
    # Multiplier applied to the upper bound for each complexity level.
    _COMPLEXITY_SCALE = {1: 0.5, 2: 0.75, 3: 1.0, 4: 1.5, 5: 2.0}
    # Hard ceiling on any budget, whatever the task type and complexity.
    _BUDGET_CEILING = 8000
    # Share of the budget that remains when the agent has tools available.
    _TOOLS_FACTOR = 0.8

    def __init__(self):
        self.budget_map = {
            "quick_answer": (50, 200), "document_drafting": (200, 1000),
            "tool_use": (100, 500), "retrieval": (200, 800),
            "coding": (200, 2000), "research": (500, 3000),
            "long_horizon": (500, 4000), "legal_regulated": (300, 2000),
            "unknown": (200, 1000),
        }

    def estimate(self, task_type: str, complexity: int,
                 has_tools: bool = False) -> Tuple[int, str]:
        """Estimate a response token budget for one query.

        Args:
            task_type: Task category. An exact ``budget_map`` key wins;
                otherwise the longest underscore-delimited prefix that is a
                key is used (``"coding_bugfix"`` -> ``"coding"``); otherwise
                the ``"unknown"`` entry.
            complexity: Difficulty level 1-5. Any other value is treated
                like level 3 (no scaling).
            has_tools: When true the budget is reduced to 80%.

        Returns:
            ``(max_tokens, hint)`` where ``hint`` is a sentence stating the
            limit, ready to be added to a prompt.
        """
        low, high = self._lookup_bounds(task_type)
        scale = self._COMPLEXITY_SCALE.get(complexity, 1.0)
        # Scale the upper bound, cap it, but never go below the lower bound.
        max_tokens = max(low, min(int(high * scale), self._BUDGET_CEILING))
        if has_tools:
            max_tokens = int(max_tokens * self._TOOLS_FACTOR)
        hint = f"Respond in at most {max_tokens} tokens. Be concise."
        return max_tokens, hint

    def _lookup_bounds(self, task_type: str) -> Tuple[int, int]:
        """Return the (low, high) bounds for the best-matching task type."""
        parts = task_type.split("_")
        # Try the full name first, then ever shorter prefixes. Looking up
        # only the first word would make keys such as "quick_answer" or
        # "tool_use" unreachable.
        for end in range(len(parts), 0, -1):
            bounds = self.budget_map.get("_".join(parts[:end]))
            if bounds is not None:
                return bounds
        return self.budget_map.get("unknown", self._DEFAULT_BOUNDS)