"""Tiny pass-through HTTP proxy that clamps max_tokens to <=128000.

Why: aibuildai's bundled Claude SDK (claude-cli 2.1.44) hardcodes
max_tokens=128001 in its first /v1/messages request — one over Anthropic's
current cap for claude-sonnet-4-6. The SDK ignores
CLAUDE_CODE_MAX_OUTPUT_TOKENS in this version and exposes no CLI flag, so
we intercept on the network and clamp before forwarding to CLIProxyAPI.

Usage:
    .venv/bin/python -m agents.cliproxyapi.clamp_proxy
    # then in the agent's env:
    #   ANTHROPIC_BASE_URL=http://127.0.0.1:8318
    #   ANTHROPIC_API_KEY=<the proxy's api-key>

Env vars:
    CLIPROXYAPI_HOST/PORT       upstream CLIProxyAPI (default 127.0.0.1:8317)
    CLIPROXYAPI_CLAMP_PORT      this server's port (default 8318)
    CLIPROXYAPI_MAX_TOKENS_CAP  cap value (default 128000)
"""

from __future__ import annotations

import json
import os

import requests
from flask import Flask, Response, request, stream_with_context

UPSTREAM_HOST = os.environ.get("CLIPROXYAPI_HOST", "127.0.0.1")
UPSTREAM_PORT = int(os.environ.get("CLIPROXYAPI_PORT", "8317"))
UPSTREAM = f"http://{UPSTREAM_HOST}:{UPSTREAM_PORT}"
LISTEN_PORT = int(os.environ.get("CLIPROXYAPI_CLAMP_PORT", "8318"))
MAX_TOKENS_CAP = int(os.environ.get("CLIPROXYAPI_MAX_TOKENS_CAP", "128000"))

app = Flask(__name__)

_HOP_BY_HOP = {
    "connection", "keep-alive", "proxy-authenticate", "proxy-authorization",
    "te", "trailers", "transfer-encoding", "upgrade",
    "content-encoding", "content-length", "host",
}


def _clamp_max_tokens(body: bytes) -> tuple[bytes, bool]:
    """If JSON body has max_tokens > cap, clamp it. Returns (body, clamped?)."""
    if not body:
        return body, False
    try:
        obj = json.loads(body)
    except (ValueError, TypeError):
        return body, False
    if not isinstance(obj, dict):
        return body, False
    mt = obj.get("max_tokens")
    if isinstance(mt, int) and mt > MAX_TOKENS_CAP:
        obj["max_tokens"] = MAX_TOKENS_CAP
        return json.dumps(obj).encode(), True
    return body, False


@app.route("/", defaults={"path": ""}, methods=[
    "GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"])
@app.route("/<path:path>", methods=[
    "GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"])
def forward(path: str):
    url = f"{UPSTREAM}/{path}"
    if request.query_string:
        url += "?" + request.query_string.decode()

    body = request.get_data() if request.method in ("POST", "PUT", "PATCH") else None
    clamped = False
    if body is not None:
        body, clamped = _clamp_max_tokens(body)

    headers = {k: v for k, v in request.headers
               if k.lower() not in _HOP_BY_HOP}
    if clamped:
        # The hop-by-hop filter above already dropped Content-Length; these pops
        # are belt-and-braces so requests recomputes it from the clamped body.
        headers.pop("Content-Length", None)
        headers.pop("content-length", None)

    upstream = requests.request(
        method=request.method,
        url=url,
        headers=headers,
        data=body,
        stream=True,
        timeout=900,
        allow_redirects=False,
    )

    resp_headers = [(k, v) for k, v in upstream.headers.items()
                    if k.lower() not in _HOP_BY_HOP]

    def gen():
        for chunk in upstream.iter_content(chunk_size=8192):
            if chunk:
                yield chunk

    return Response(stream_with_context(gen()),
                    status=upstream.status_code,
                    headers=resp_headers)


if __name__ == "__main__":
    print(f"clamp-proxy listening on :{LISTEN_PORT}{UPSTREAM} "
          f"(max_tokens cap = {MAX_TOKENS_CAP})", flush=True)
    app.run(host="0.0.0.0", port=LISTEN_PORT, threaded=True, use_reloader=False)