# aco/proxy.py — uploaded by narcolepticchicken (commit 7aa55cc, verified; 26.1 kB)
"""
ACO Proxy Server — OpenAI-compatible HTTP server that applies ACO cost
optimizations transparently to any agent's LLM calls.
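Start: aco-proxy --port 8080
Use: point any OpenAI-compatible client at http://localhost:8080/v1
     (openai>=1.0: OpenAI(base_url="http://localhost:8080/v1");
     legacy openai<1.0: openai.api_base = "http://localhost:8080/v1")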
The proxy intercepts POST /v1/chat/completions and:
1. Routes to cheapest adequate model
2. Gates unnecessary tool calls (v1 tool-gater, F1=0.92)
3. Lays out prompts for cache reuse (system + tools in prefix)
4. Compresses verbose error traces and thinking-only turns
5. Collects telemetry: cost, tokens, latency, cache hits
6. Live dashboard at GET /dashboard
7. JSON telemetry at GET /telemetry
Zero agent code changes needed.
"""
import hashlib
import json
import os
import re
import threading
import time
from collections import defaultdict, deque
from dataclasses import dataclass, field
from datetime import datetime
from typing import Optional, Dict, List, Any

# ── FastAPI ──────────────────────────────────────────────────────────
try:
    from fastapi import FastAPI, Request, HTTPException
    from fastapi.responses import JSONResponse, StreamingResponse, HTMLResponse
    import uvicorn
    import httpx
    FASTAPI_AVAILABLE = True
except ImportError:
    FASTAPI_AVAILABLE = False
# ── Model Registry ───────────────────────────────────────────────────
# Per model: routing tier (route_model treats higher tiers as the stronger,
# pricier models), list price in USD per million input/output tokens
# (compute_cost divides token counts by 1_000_000), and context window in tokens.
MODEL_REGISTRY = {
    "deepseek-v4-flash": {"tier": 1, "cost_in": 0.14, "cost_out": 0.28, "ctx": 128000},
    "gpt-5-nano": {"tier": 1, "cost_in": 0.15, "cost_out": 0.60, "ctx": 128000},
    "gpt-5-mini": {"tier": 2, "cost_in": 0.15, "cost_out": 0.60, "ctx": 128000},
    "deepseek-v3.2": {"tier": 2, "cost_in": 0.27, "cost_out": 1.10, "ctx": 131072},
    "gemini-2.5-flash": {"tier": 2, "cost_in": 0.15, "cost_out": 0.60, "ctx": 1048576},
    "gemini-2.5-pro": {"tier": 3, "cost_in": 1.25, "cost_out": 10.00, "ctx": 1048576},
    "claude-opus-4.7": {"tier": 4, "cost_in": 15.00, "cost_out": 75.00, "ctx": 200000},
    "gpt-5.2": {"tier": 4, "cost_in": 1.75, "cost_out": 14.00, "ctx": 272000},
    "gemini-3-pro": {"tier": 5, "cost_in": 2.00, "cost_out": 12.50, "ctx": 1048576},
}

# OpenAI-compatible base URL per provider; the proxy appends
# "/chat/completions". Each can be overridden via the matching *_BASE_URL env var.
PROVIDER_ENDPOINTS = {
    "openai": os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1"),
    "anthropic": os.environ.get("ANTHROPIC_BASE_URL", "https://api.anthropic.com/v1"),
    # Gemini's OpenAI-compatible surface lives under /v1beta/openai; the bare
    # /v1beta root only serves the native generateContent API, so
    # "/v1beta/chat/completions" does not exist.
    "google": os.environ.get("GOOGLE_BASE_URL",
                             "https://generativelanguage.googleapis.com/v1beta/openai"),
    "deepseek": os.environ.get("DEEPSEEK_BASE_URL", "https://api.deepseek.com/v1"),
}

# Which provider (key of PROVIDER_ENDPOINTS) serves each registered model.
MODEL_PROVIDER = {
    "deepseek-v4-flash": "deepseek",
    "gpt-5-nano": "openai", "gpt-5-mini": "openai", "gpt-5.2": "openai",
    "claude-opus-4.7": "anthropic",
    "gemini-2.5-flash": "google", "gemini-2.5-pro": "google", "gemini-3-pro": "google",
    "deepseek-v3.2": "deepseek",
}
# ── Tool-Gater Classifier (v1: DistilBERT, F1=0.92) ─────────────────
_tool_gater = None  # Lazy-loaded singleton: (model, tokenizer, device) once loaded
_tool_gater_load_failed = False  # Set after a failed load so it is not retried per request


def _get_tool_gater():
    """Lazy-load the v1 DistilBERT tool-gater.

    Returns:
        ``(model, tokenizer, device)`` on success, or ``None`` when
        ``transformers``/``torch`` are missing or loading fails. A failure is
        remembered: later calls return ``None`` immediately instead of
        re-attempting the slow load (and re-printing the error) on every
        request. Restart the process to retry after fixing the environment.
    """
    global _tool_gater, _tool_gater_load_failed
    if _tool_gater is not None:
        return _tool_gater
    if _tool_gater_load_failed:
        return None
    try:
        from transformers import AutoTokenizer, AutoModelForSequenceClassification
        import torch
        model_id = "narcolepticchicken/aco-specialists-tool-gater"
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=2)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        model.eval()
    except Exception as e:  # best-effort: any load failure -> heuristic fallback
        _tool_gater_load_failed = True
        print(f"[ACO] Failed to load tool-gater: {e}. Using heuristic fallback.")
        return None
    _tool_gater = (model, tokenizer, device)
    return _tool_gater
def should_gate_tools_ml(messages: List[Dict]) -> bool:
    """
    Decide whether to suppress (gate) tool definitions for this request.

    The classifier input is the first user message (up to 1500 chars),
    prefixed by the last system message seen before it (up to 500 chars).
    The v1 DistilBERT classifier (F1=0.92) is used when it loads; otherwise,
    or if inference raises, ``heuristic_tool_gate`` decides.

    Returns:
        True to strip tools from the request, False to keep them. Always
        False when there is no non-empty user message, or when the
        conversation already shows tool activity. Removing tools mid-loop
        would break an agent that is still calling them.
    """
    user_text = ""
    system_text = ""
    for msg in messages:
        role = msg.get("role")
        if role == "user":
            user_text = str(msg.get("content") or "")[:1500]
            break
        if role == "system":
            system_text = str(msg.get("content") or "")[:500]
    if not user_text:
        return False

    # Prior tool activity in any form means the agent is mid-loop.
    # Forms checked: native OpenAI `tool_calls`, legacy `function_call`,
    # tool/function result messages, or XML-style "<function=" calls in
    # assistant text.
    for msg in messages:
        role = msg.get("role")
        if role in ("tool", "function"):
            return False
        if role == "assistant" and (
            msg.get("tool_calls")
            or msg.get("function_call")
            or '<function=' in str(msg.get("content") or "")
        ):
            return False

    text = f"Query: {user_text}"
    if system_text:
        text = f"System: {system_text}\n\n{text}"

    gater = _get_tool_gater()
    if gater:
        model, tokenizer, device = gater
        try:
            import torch
            inputs = tokenizer(text[:2000], truncation=True, max_length=512,
                               return_tensors="pt").to(device)
            with torch.no_grad():
                logits = model(**inputs).logits
            probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]
            # Label 0 = "skip_tool", label 1 = "call_tool": gate tools OFF
            # when skipping is at least as likely. bool() avoids leaking a
            # numpy.bool_ to callers.
            return bool(probs[0] >= probs[1])
        except Exception as e:  # best-effort: fall through to the heuristic
            print(f"[ACO] Tool-gater inference failed: {e}")
    return heuristic_tool_gate(user_text)
# Phrases marking a question that normally needs no tools (definitions,
# explanations, look-ups). Compiled once as a single alternation.
_SIMPLE_QUERY_PATTERNS = (
    r'\bwhat is\b', r'\bwho (is|was)\b', r'\bwhen (is|was)\b',
    r'\bdefine\b', r'\bexplain\b', r'\bsummarize\b',
    r'\bhow (do|does|to)\b', r'\bdifference between\b',
    r'\bcapital of\b', r'\bmeaning of\b', r'\btranslate\b',
)
_SIMPLE_QUERY_RE = re.compile("|".join(_SIMPLE_QUERY_PATTERNS))


def heuristic_tool_gate(user_text: str) -> bool:
    """Return True when *user_text* looks answerable without tools (so tools are gated).

    The text is lower-cased, then searched for any of the whole-word phrases
    above ("what is", "explain", "capital of", "translate", ...).
    """
    return _SIMPLE_QUERY_RE.search(user_text.lower()) is not None
# ── Cache-Aware Prompt Layout ────────────────────────────────────────
def layout_cache_prompt(messages: List[Dict], tools: Optional[List[Dict]] = None) -> List[Dict]:
    """
    Reorder messages for maximum prefix-cache reuse.

    Stable content goes first and dynamic content last:

    1. All system messages, in their original relative order. ISO-8601
       timestamps become ``[TIME]`` and ``run_<hex>`` ids become ``run_xxx``.
    2. When *tools* is non-empty, one extra system message: ``[TOOL_DEFS]``
       followed by a newline and a JSON list of up to 20 tool names.
    3. Every other message in its original order. In user text,
       ``req_``/``trace_``/``run_<hex>`` ids become ``xxx`` and timestamps
       become ``[TIME]``. Assistant and tool messages pass through untouched.
       As a result, tool results keep their ``tool_call_id`` and stay right
       after the assistant turn that requested them.

    Only string contents are normalised. ``None`` and multi-part (list)
    contents are left as-is, and message keys other than ``content`` are kept.
    """
    iso_timestamp = r'\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}[^\s]*'
    prefix: List[Dict] = []
    rest: List[Dict] = []
    for msg in messages:
        role = msg.get("role")
        content = msg.get("content")
        if role == "system":
            if isinstance(content, str):
                content = re.sub(iso_timestamp, '[TIME]', content)
                content = re.sub(r'run_[a-f0-9]{8,}', 'run_xxx', content)
            prefix.append({**msg, "content": content})
        elif role == "user" and isinstance(content, str):
            # NOTE(review): this rewrites the user's own text, so a model asked
            # about a specific run id or timestamp will not see the real value.
            content = re.sub(r'(?:req|trace|run)_[a-f0-9]{8,32}', 'xxx', content)
            content = re.sub(iso_timestamp, '[TIME]', content)
            rest.append({**msg, "content": content})
        else:
            # Assistant turns, tool results and non-text user content are kept
            # verbatim and in place; moving a tool result would orphan its call.
            rest.append(msg)
    if tools:
        tool_names = [t.get("function", {}).get("name", t.get("name", "?"))
                      for t in tools][:20]
        prefix.append({"role": "system",
                       "content": f"[TOOL_DEFS]\n{json.dumps(tool_names)}"})
    return prefix + rest
# ── Context Compression ──────────────────────────────────────────────
def compress_context(messages: List[Dict]) -> tuple:
    """Compress verbose agent messages while preserving signal.

    Rules:
      * user: error output (> 2000 chars, mentions traceback / error: /
        exception) with more than 13 lines keeps its first 8 and last 5
        lines. Any other user text over 3000 chars is cut to its first 2000.
      * assistant: a "thinking-only" turn is cut to 200 chars. That means
        > 800 chars, no code fence, no ``<function=`` call and no action verb.
        Otherwise, text over 4000 chars is cut to 3000.
      * Non-string contents are passed through unchanged. That covers
        ``None`` on tool-call-only assistant turns and lists of content parts.
        Other roles' text is never shortened.

    Returns:
        ``(compressed_messages, ratio)``, where *ratio* is the compressed
        over the original total length of string contents (0.0 if none).
    """
    head_n, tail_n = 8, 5
    compressed: List[Dict] = []
    total_orig = 0
    total_comp = 0
    for msg in messages:
        content = msg.get("content")
        if not isinstance(content, str):
            # Stringifying would inject "None" or a Python repr into the prompt.
            compressed.append(dict(msg))
            continue
        role = msg.get("role", "")
        original = content
        total_orig += len(original)
        if role == "user":
            lowered = original.lower()
            is_trace = (len(original) > 2000
                        and any(k in lowered for k in ('traceback', 'error:', 'exception')))
            lines = original.split('\n') if is_trace else []
            if is_trace and len(lines) > head_n + tail_n:
                # Keep the first and last lines; with fewer lines, head and
                # tail would overlap and the "trimmed" text would grow.
                head = '\n'.join(lines[:head_n])
                tail = '\n'.join(lines[-tail_n:])
                content = (f"{head}\n... [{len(lines) - head_n - tail_n} lines trimmed] ...\n"
                           f"{tail}")
            elif len(original) > 3000:
                content = original[:2000] + '\n... [output trimmed] ...'
        elif role == "assistant":
            thinking_only = (
                len(original) > 800
                and '```' not in original
                and '<function=' not in original
                and not re.search(
                    r'\b(?:execute|run|apply|create|delete|modify|write|patch|fix|submit)\b',
                    original, re.IGNORECASE)
            )
            if thinking_only:
                content = original[:200] + '\n... [thinking trimmed] ...'
            elif len(original) > 4000:
                # Only when not already thinking-trimmed; previously this
                # overwrote the 200-char trim with a 3000-char one.
                content = original[:3000] + '\n... [truncated] ...'
        total_comp += len(content)
        compressed.append({**msg, "content": content})
    ratio = total_comp / max(total_orig, 1)
    return compressed, ratio
# ── Model Router ─────────────────────────────────────────────────────
def route_model(requested_model: str, messages: List[Dict]) -> str:
    """Route to the cheapest model that can plausibly handle this request.

    Rules, based on the most recent user message:
      * Models not in MODEL_REGISTRY are returned unchanged.
      * A tier-3+ request whose latest user message is under 300 chars is
        downgraded. It goes to ``deepseek-v4-flash`` (tier 1), or to
        ``gpt-5-mini`` (tier 2) when the message looks like coding work,
        which is the coding floor. The downgrade is skipped when the whole
        conversation (estimated at ~4 chars/token) would not fit the cheaper
        model's context window.
      * Coding-looking requests addressed to a tier-1 model are raised to
        ``gpt-5-mini``.

    Code detection is case-insensitive, so "Traceback" and "Error:" count.
    """
    info = MODEL_REGISTRY.get(requested_model)
    if not info:
        return requested_model
    tier = info["tier"]

    user_text = ""
    for msg in reversed(messages):
        if msg.get("role") == "user":
            user_text = str(msg.get("content") or "")
            break

    code_words = ('def ', 'class ', 'function', 'import ', '```', 'fix ', 'bug',
                  'implement', 'refactor', 'test_', 'pytest', 'traceback', 'error:')
    lowered = user_text.lower()
    looks_like_code = any(w in lowered for w in code_words)

    if tier >= 3 and len(user_text) < 300:
        # Apply the coding floor *before* choosing the downgrade target.
        target = "gpt-5-mini" if looks_like_code else "deepseek-v4-flash"
        # Rough size check (~4 chars/token); a conversation that cannot fit
        # the cheaper model would just fail upstream.
        approx_tokens = sum(len(str(m.get("content") or "")) for m in messages) // 4
        if approx_tokens < MODEL_REGISTRY.get(target, {}).get("ctx", 0):
            return target
        return requested_model

    if looks_like_code and tier < 2:
        return "gpt-5-mini"
    return requested_model
# ── Cost Calculator ──────────────────────────────────────────────────
def compute_cost(model: str, input_tokens: int, output_tokens: int,
                 cache_hit_tokens: int = 0) -> float:
    """Estimate the USD cost of one call from MODEL_REGISTRY list prices.

    Prices are per million tokens. Cache-hit tokens are subtracted from the
    billable input (never below zero), i.e. treated as free.
    NOTE(review): providers typically bill cached reads at a discount rather
    than zero, so this may under-count cached traffic. Confirm the intent.

    Returns:
        Cost rounded to 6 decimal places; 0.0 for unknown models.
    """
    pricing = MODEL_REGISTRY.get(model)
    if not pricing:
        return 0.0
    billable_in = input_tokens - cache_hit_tokens
    if billable_in < 0:
        billable_in = 0
    per_million = 1_000_000
    input_part = (billable_in / per_million) * pricing["cost_in"]
    output_part = (output_tokens / per_million) * pricing["cost_out"]
    return round(input_part + output_part, 6)
# ── Telemetry Store ──────────────────────────────────────────────────
@dataclass
class TraceRecord:
    """One proxied chat-completion call, as shown by /dashboard and /telemetry."""

    # Identity
    request_id: str
    timestamp: str                   # presumably datetime.isoformat() output (as chat_completions writes it)
    # Where the call went
    model: str                       # model the request was finally sent to
    provider: str
    tier: int
    # Usage and cost
    input_tokens: int
    output_tokens: int
    cache_hit_tokens: int
    latency_ms: float
    cost: float                      # USD
    # Which optimizations fired
    tool_gated: bool
    gated_by: str                    # "ml" | "heuristic" | "none"
    context_compressed: float        # compression ratio
    cache_layout_applied: bool
    model_routed: bool               # was the model changed?
    original_model: str
    # Outcome
    success: bool
    error: Optional[str] = field(default=None)
# ══════════════════════════════════════════════════════════════════════
# FastAPI App
# ══════════════════════════════════════════════════════════════════════
if FASTAPI_AVAILABLE:
    app = FastAPI(title="ACO Proxy", version="1.1.0")
    # Recent TraceRecords, oldest first. Bounded by ACO_TELEMETRY_MAX
    # (default 10000) so a long-running proxy does not grow memory without
    # limit. Once full, each append evicts the oldest trace, so dashboard and
    # /telemetry totals cover only the retained window.
    telemetry_store: "deque[TraceRecord]" = deque(
        maxlen=max(0, int(os.environ.get("ACO_TELEMETRY_MAX", "10000"))))
    telemetry_lock = threading.Lock()  # guards telemetry_store
    start_time = datetime.utcnow()  # naive UTC; /health subtracts it from utcnow()
@app.get("/health")
async def health():
return {"status": "ok", "uptime_seconds": (datetime.utcnow() - start_time).total_seconds()}
@app.get("/v1/models")
async def list_models():
return {
"object": "list",
"data": [{"id": m, "object": "model",
"owned_by": MODEL_PROVIDER.get(m, "unknown")}
for m in MODEL_REGISTRY]
}
@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
"""OpenAI-compatible endpoint. ACO optimizations applied transparently."""
body = await request.json()
import uuid
request_id = body.get("user", str(uuid.uuid4())[:8])
# ── Extract params ──
messages = body.get("messages", [])
tools = body.get("tools")
requested_model = body.get("model", "gpt-5-mini")
stream = body.get("stream", False)
original_model = requested_model
# ── Route model ──
routed_model = route_model(requested_model, messages)
model_routed = routed_model != requested_model
provider = MODEL_PROVIDER.get(routed_model, "openai")
# ── Gate tools (ML classifier with heuristic fallback) ──
tools_gated = False
gated_by = "none"
if tools and should_gate_tools_ml(messages):
tools = None
tools_gated = True
gated_by = "ml" if _get_tool_gater() else "heuristic"
# ── Layout for cache ──
laid_out = layout_cache_prompt(messages, tools)
cache_applied = laid_out != messages
# ── Compress context ──
compressed_messages, compression_ratio = compress_context(laid_out)
# ── Forward request ──
forward_body = {**body}
forward_body["model"] = routed_model
forward_body["messages"] = compressed_messages
if tools is None and "tools" in forward_body:
del forward_body["tools"]
endpoint = PROVIDER_ENDPOINTS.get(provider, PROVIDER_ENDPOINTS["openai"])
target_url = f"{endpoint}/chat/completions"
api_key_map = {
"openai": os.environ.get("OPENAI_API_KEY"),
"anthropic": os.environ.get("ANTHROPIC_API_KEY"),
"google": os.environ.get("GOOGLE_API_KEY"),
"deepseek": os.environ.get("DEEPSEEK_API_KEY"),
}
auth = request.headers.get("authorization", "")
headers = {"content-type": "application/json"}
if not auth:
key = api_key_map.get(provider)
if key:
headers["authorization"] = f"Bearer {key}"
else:
headers["authorization"] = auth
# ── Make upstream call ──
t_start = time.time()
error = None
success = True
response_data = {}
try:
async with httpx.AsyncClient(timeout=300.0) as client:
upstream = await client.post(target_url, json=forward_body, headers=headers)
latency = (time.time() - t_start) * 1000
if upstream.status_code != 200 and model_routed:
# Fall back to original model
forward_body["model"] = original_model
upstream2 = await client.post(target_url, json=forward_body, headers=headers)
latency = (time.time() - t_start) * 1000
upstream = upstream2
routed_model = original_model
model_routed = False
if stream:
return StreamingResponse(
upstream.aiter_bytes(),
media_type="text/event-stream",
headers={"x-aco-model": routed_model, "x-aco-tier": str(
MODEL_REGISTRY.get(routed_model, {}).get("tier", "?"))}
)
if upstream.status_code == 200:
response_data = upstream.json()
else:
error = f"Upstream {upstream.status_code}: {upstream.text[:200]}"
success = False
except Exception as e:
latency = (time.time() - t_start) * 1000
error = str(e)
success = False
if not success:
response_data = {
"id": f"aco-err-{request_id}",
"object": "chat.completion",
"created": int(time.time()),
"model": requested_model,
"choices": [{"index": 0, "message": {"role": "assistant",
"content": f"[ACO proxy error: {error}]"},
"finish_reason": "error"}],
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
}
# ── Compute cost ──
usage = response_data.get("usage", {})
input_tokens = usage.get("prompt_tokens", 0)
output_tokens = usage.get("completion_tokens", 0)
cache_hit = (usage.get("cache_read_input_tokens", 0) or
usage.get("prompt_tokens_details", {}).get("cached_tokens", 0))
cost = compute_cost(routed_model, input_tokens, output_tokens, cache_hit)
# ── Rewrite response ──
response_data.setdefault("usage", {})
response_data["usage"]["aco_cost_usd"] = cost
response_data["usage"]["aco_model"] = routed_model
response_data["usage"]["aco_tier"] = MODEL_REGISTRY.get(routed_model, {}).get("tier", 0)
response_data["usage"]["aco_cache_hit_tokens"] = cache_hit
response_data["usage"]["aco_compression_ratio"] = round(compression_ratio, 2)
response_data["usage"]["aco_tool_gated"] = tools_gated
response_data["model"] = requested_model # Agent sees original model
# ── Record telemetry ──
trace = TraceRecord(
request_id=request_id,
timestamp=datetime.utcnow().isoformat(),
model=routed_model,
provider=provider,
tier=MODEL_REGISTRY.get(routed_model, {}).get("tier", 0),
input_tokens=input_tokens,
output_tokens=output_tokens,
cache_hit_tokens=cache_hit,
latency_ms=round(latency, 1),
cost=cost,
tool_gated=tools_gated,
gated_by=gated_by,
context_compressed=round(compression_ratio, 3),
cache_layout_applied=cache_applied,
model_routed=model_routed,
original_model=original_model,
success=success,
error=error,
)
with telemetry_lock:
telemetry_store.append(trace)
return JSONResponse(response_data)
@app.get("/dashboard")
async def dashboard():
"""Live HTML cost dashboard."""
with telemetry_lock:
traces = list(telemetry_store)
n = len(traces)
if n == 0:
return HTMLResponse("<h2>No traffic yet. Send requests to /v1/chat/completions</h2>")
total_cost = sum(t.cost for t in traces)
successful = sum(1 for t in traces if t.success)
total_in = sum(t.input_tokens for t in traces)
total_out = sum(t.output_tokens for t in traces)
total_cache = sum(t.cache_hit_tokens for t in traces)
avg_lat = sum(t.latency_ms for t in traces) / n
gated = sum(1 for t in traces if t.tool_gated)
routed = sum(1 for t in traces if t.model_routed)
tier_calls = defaultdict(int)
tier_cost = defaultdict(float)
model_calls = defaultdict(int)
for t in traces:
tier_calls[t.tier] += 1
tier_cost[t.tier] += t.cost
model_calls[t.model] += 1
html = f"""<!DOCTYPE html><html><head>
<title>ACO Proxy</title><meta charset="utf-8"><meta http-equiv="refresh" content="3">
<style>
* {{ margin:0; padding:0; box-sizing:border-box; }}
body {{ font-family: system-ui; background: #0d1117; color: #c9d1d9; padding: 1.5rem; }}
h1 {{ font-size: 1.2rem; margin-bottom: 0.5rem; }}
.grid {{ display: grid; grid-template-columns: repeat(auto-fit, minmax(180px,1fr)); gap: 0.5rem; margin-bottom: 1rem; }}
.card {{ background: #161b22; border: 1px solid #30363d; border-radius: 6px; padding: 0.8rem; }}
.card .val {{ font-size: 1.6rem; font-weight: 700; color: #58a6ff; }}
.card .lbl {{ font-size: 0.7rem; color: #8b949e; text-transform: uppercase; letter-spacing: 0.5px; }}
table {{ width: 100%; border-collapse: collapse; font-size: 0.8rem; margin-bottom: 1rem; }}
th, td {{ padding: 0.4rem 0.5rem; text-align: right; border-bottom: 1px solid #21262d; }}
th {{ color: #8b949e; font-weight: 500; text-transform: uppercase; font-size: 0.65rem; }}
td:first-child, th:first-child {{ text-align: left; }}
.good {{ color: #3fb950; }} .bad {{ color: #f85149; }} .dim {{ color: #8b949e; }}
</style></head><body>
<h1>🤖 ACO Proxy <span class="dim">— {n} calls, ${total_cost:.4f} total</span></h1>
<div class="grid">
<div class="card"><div class="val">{successful/n*100:.0f}%</div><div class="lbl">Success Rate</div></div>
<div class="card"><div class="val">${total_cost:.4f}</div><div class="lbl">Total Cost</div></div>
<div class="card"><div class="val">${total_cost/max(n,1):.5f}</div><div class="lbl">Avg Cost/Call</div></div>
<div class="card"><div class="val">{avg_lat:.0f}ms</div><div class="lbl">Avg Latency</div></div>
<div class="card"><div class="val">{total_in//1000}k</div><div class="lbl">Tokens In</div></div>
<div class="card"><div class="val">{total_out//1000}k</div><div class="lbl">Tokens Out</div></div>
<div class="card"><div class="val">{total_cache//1000}k</div><div class="lbl">Cache Hits</div></div>
<div class="card"><div class="val">{gated}</div><div class="lbl">Tools Gated</div></div>
<div class="card"><div class="val">{routed}</div><div class="lbl">Models Rerouted</div></div>
</div>
<table><tr><th>Model</th><th>Calls</th></tr>"""
for m, c in sorted(model_calls.items(), key=lambda x: -x[1]):
html += f"<tr><td>{m}</td><td>{c}</td></tr>"
html += """</table>
<table><tr><th>Time</th><th>Model</th><th>Tier</th><th>Tokens</th><th>Cost</th><th>Lat</th><th>Gated</th><th>Routed</th></tr>"""
for t in reversed(traces[-30:]):
s = 'good' if t.success else 'bad'
html += f"<tr><td class='dim'>{t.timestamp[-8:]}</td>"
html += f"<td>{t.model[:20]}</td><td>{t.tier}</td>"
html += f"<td>{t.input_tokens}+{t.output_tokens}</td>"
html += f"<td>${t.cost:.5f}</td><td>{t.latency_ms:.0f}ms</td>"
html += f"<td>{'✓' if t.tool_gated else ''}</td>"
html += f"<td>{'✓' if t.model_routed else ''}</td></tr>"
html += "</table></body></html>"
return HTMLResponse(html)
@app.get("/telemetry")
async def telemetry_json():
"""JSON telemetry for programmatic consumption."""
with telemetry_lock:
traces = [{"model": t.model, "tier": t.tier, "cost": t.cost,
"input_tokens": t.input_tokens, "output_tokens": t.output_tokens,
"cache_hit_tokens": t.cache_hit_tokens, "latency_ms": t.latency_ms,
"tool_gated": t.tool_gated, "gated_by": t.gated_by,
"model_routed": t.model_routed, "original_model": t.original_model,
"context_compressed": t.context_compressed,
"success": t.success, "error": t.error,
"timestamp": t.timestamp}
for t in telemetry_store]
return {
"total_calls": len(traces),
"total_cost": round(sum(t["cost"] for t in traces), 6),
"calls": traces,
}
@app.get("/telemetry/reset")
async def reset_telemetry():
with telemetry_lock:
telemetry_store.clear()
return {"status": "ok"}
def serve(host: str = "0.0.0.0", port: int = 8080):
    """Run the ACO proxy under uvicorn; blocks until the server stops.

    Args:
        host: Interface to bind. The default ``0.0.0.0`` listens on every
            interface; a warning is printed for any non-loopback bind.
        port: TCP port to listen on.

    Prints an install hint and returns without serving when FastAPI, uvicorn
    or httpx could not be imported.
    """
    if not FASTAPI_AVAILABLE:
        print("ERROR: pip install fastapi uvicorn httpx")
        return
    if host not in ("127.0.0.1", "localhost", "::1"):
        # Two reasons this bind is risky. First, requests arriving without an
        # Authorization header are forwarded with this server's *_API_KEY
        # credentials. Second, the dashboard/telemetry endpoints have no auth.
        # So anyone who can reach the port can spend those keys.
        print(f"⚠️  Listening on {host}: anyone who can reach port {port} can use "
              f"this server's provider API keys. Use --host 127.0.0.1 to keep it local.")
    print(f"🚀 ACO Proxy → http://{host}:{port}")
    print(f" Dashboard: http://localhost:{port}/dashboard")
    print(f" Telemetry: http://localhost:{port}/telemetry")
    print(f" Agent usage: openai.api_base = 'http://localhost:{port}/v1'")
    uvicorn.run(app, host=host, port=port, log_level="warning")
def main():
    """CLI entry point: ``aco-proxy [--host HOST] [--port PORT]``."""
    import argparse
    parser = argparse.ArgumentParser(description="ACO Proxy Server")
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8080)
    args = parser.parse_args()
    serve(host=args.host, port=args.port)
if __name__ == "__main__":
    # Executed directly as a script: run the CLI.
    main()