kaiju-coder-7-quantized-runtime / scripts /kaiju_opencode_fast_proxy.py
restokes92's picture
Upload Kaiju Coder 7 runtime quantization recipe
785f3d7 verified
raw
history blame
9.27 kB
#!/usr/bin/env python3
"""Tool-safe OpenAI-compatible fast proxy for Kaiju Coder 7 OpenCode.
The normal Gojira gateway is product/API oriented and aggregates content. OpenCode
needs raw tool-call chunks preserved, so this proxy only patches serving knobs
and then passes upstream responses through unchanged.
"""
from __future__ import annotations
import argparse
import json
import os
import time
import urllib.error
import urllib.request
from http import HTTPStatus
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from typing import Any
# Interface the proxy binds to unless --host says otherwise.
DEFAULT_HOST = "127.0.0.1"
# Listening port; overridable through the environment or --port.
DEFAULT_PORT = int(os.getenv("KAIJU_OPENCODE_FAST_PROXY_PORT", "18181"))

# Base URL of the upstream OpenAI-compatible server (endpoint suffixes are appended to it).
UPSTREAM_BASE_URL = os.getenv("KAIJU_OPENAI_BASE_URL", "http://100.109.109.14:18084/v1")
# Model name written into every forwarded chat request.
DEFAULT_MODEL = os.getenv("KAIJU_DEFAULT_MODEL", "kaiju-coder-7")
# Bearer token for the upstream server; the empty string disables the header.
API_KEY = os.getenv("KAIJU_OPENAI_API_KEY", "")

# max_tokens budgets applied per job class.
NORMAL_MAX_TOKENS = int(os.getenv("KAIJU_NORMAL_MAX_TOKENS", "384"))
WORK_MAX_TOKENS = int(os.getenv("KAIJU_WORK_MAX_TOKENS", "1536"))
ARTIFACT_MAX_TOKENS = int(os.getenv("KAIJU_ARTIFACT_MAX_TOKENS", "4096"))

# Largest request body accepted from a client, in bytes (default 2 MiB).
MAX_REQUEST_BYTES = int(os.getenv("KAIJU_MAX_REQUEST_BYTES", "2097152"))
def normalize_messages(messages: Any) -> list[dict[str, Any]]:
    """Keep only the dict entries of *messages*.

    Anything that is not a list yields an empty list; non-dict entries inside
    a list are dropped while the order of the remaining entries is preserved.
    """
    if isinstance(messages, list):
        return list(filter(lambda entry: isinstance(entry, dict), messages))
    return []
def message_text(messages: list[dict[str, Any]]) -> str:
    """Flatten the ``content`` of every message into one lower-cased string.

    String content is used verbatim; any other content value (including a
    missing key, which defaults to ``""``) is serialized with ``json.dumps``
    without ASCII escaping. Entries are joined with newlines.
    """

    def _as_text(content: Any) -> str:
        # Structured content (e.g. a list of parts) is searched via its JSON form.
        if isinstance(content, str):
            return content
        return json.dumps(content, ensure_ascii=False)

    flattened = "\n".join(_as_text(message.get("content", "")) for message in messages)
    return flattened.lower()
def classify_job(messages: list[dict[str, Any]]) -> str:
    """Classify a conversation as ``"artifact"``, ``"work"`` or ``"normal"``.

    The lower-cased text of all messages is scanned for keywords. Matching is
    plain substring containment, so a keyword such as ``"fix"`` also matches
    inside longer words. Artifact keywords take precedence over work keywords.
    """
    haystack = message_text(messages)
    # Ordered from highest to lowest priority; the first tier with a hit wins.
    keyword_tiers = (
        (
            "artifact",
            (
                "complete html",
                "html file",
                "one-file website",
                "landing page",
                "build a website",
                "make a website",
                "full file",
            ),
        ),
        (
            "work",
            (
                "create ",
                "write ",
                "edit ",
                "implement",
                "debug",
                "fix",
                "refactor",
                "test",
                "repo",
                "file",
            ),
        ),
    )
    for label, keywords in keyword_tiers:
        for keyword in keywords:
            if keyword in haystack:
                return label
    return "normal"
def target_tokens(job_class: str) -> int:
    """Return the ``max_tokens`` budget for *job_class*.

    ``"artifact"`` and ``"work"`` map to their dedicated budgets; every other
    value falls back to the normal budget.
    """
    budgets = (
        ("artifact", ARTIFACT_MAX_TOKENS),
        ("work", WORK_MAX_TOKENS),
    )
    for label, budget in budgets:
        if job_class == label:
            return budget
    return NORMAL_MAX_TOKENS
def patch_chat_payload(body: dict[str, Any]) -> dict[str, Any]:
    """Return a shallow copy of *body* with the proxy's serving knobs applied.

    The copy always carries the configured model name, a ``max_tokens`` value
    chosen from the classified job type, and ``chat_template_kwargs`` with
    both thinking switches forced off. Caller-supplied template kwargs are
    kept when they are a dict; any other value is discarded. *body* itself is
    not modified.
    """
    template_kwargs = body.get("chat_template_kwargs")
    if not isinstance(template_kwargs, dict):
        template_kwargs = {}

    job_class = classify_job(normalize_messages(body.get("messages")))

    return {
        **body,
        "model": DEFAULT_MODEL,
        "max_tokens": target_tokens(job_class),
        "chat_template_kwargs": {
            **template_kwargs,
            "enable_thinking": False,
            "thinking": False,
        },
    }
class Handler(BaseHTTPRequestHandler):
    """HTTP handler that patches chat requests and relays upstream responses.

    Routes served:
      * ``GET /health`` - local status document describing the configuration.
      * ``GET /v1/models`` - relayed from the upstream ``/models`` endpoint.
      * ``POST /v1/chat/completions`` - body patched by ``patch_chat_payload``
        and relayed to the upstream ``/chat/completions`` endpoint.

    Every other path is answered with a JSON 404.
    """

    server_version = "KaijuOpenCodeFastProxy/0.1"

    def log_message(self, fmt: str, *args: Any) -> None:
        """Print a timestamped log line to stdout, flushing immediately."""
        print(f"{time.strftime('%Y-%m-%d %H:%M:%S')} {self.address_string()} - {fmt % args}", flush=True)

    def _json(self, status: int, payload: dict[str, Any]) -> None:
        """Send *payload* as a complete, non-cacheable JSON response."""
        data = json.dumps(payload).encode("utf-8")
        self.send_response(status)
        self.send_header("content-type", "application/json; charset=utf-8")
        self.send_header("cache-control", "no-store")
        self.send_header("content-length", str(len(data)))
        self.end_headers()
        self.wfile.write(data)

    def _read_json(self) -> dict[str, Any]:
        """Read the request body and decode it as a JSON object.

        Returns:
            The decoded object, or ``{}`` when the body is empty.

        Raises:
            ValueError: if ``content-length`` is not a non-negative integer,
                if it exceeds ``MAX_REQUEST_BYTES``, or if the body is not a
                UTF-8 encoded JSON object.
        """
        length = int(self.headers.get("content-length", "0"))
        # A negative length would turn ``rfile.read`` into "read until EOF",
        # blocking this worker for as long as the client keeps the socket open.
        if length < 0:
            raise ValueError("invalid content-length")
        if length > MAX_REQUEST_BYTES:
            raise ValueError("request body too large")
        raw = self.rfile.read(length)
        if not raw:
            return {}
        value = json.loads(raw.decode("utf-8"))
        if not isinstance(value, dict):
            raise ValueError("request body must be a JSON object")
        return value

    def do_GET(self) -> None:  # noqa: N802 - BaseHTTPRequestHandler API.
        """Serve ``/health`` locally and relay ``/v1/models`` upstream."""
        if self.path == "/health":
            self._json(
                HTTPStatus.OK,
                {
                    "ok": True,
                    "model": DEFAULT_MODEL,
                    "upstream": UPSTREAM_BASE_URL,
                    "normal_max_tokens": NORMAL_MAX_TOKENS,
                    "work_max_tokens": WORK_MAX_TOKENS,
                    "artifact_max_tokens": ARTIFACT_MAX_TOKENS,
                },
            )
            return
        if self.path == "/v1/models":
            self._forward_get("/models")
            return
        self._json(HTTPStatus.NOT_FOUND, {"error": {"message": "Not found", "type": "not_found"}})

    def do_POST(self) -> None:  # noqa: N802 - BaseHTTPRequestHandler API.
        """Patch and relay ``/v1/chat/completions``; reject everything else."""
        if self.path != "/v1/chat/completions":
            self._json(HTTPStatus.NOT_FOUND, {"error": {"message": "Not found", "type": "not_found"}})
            return
        try:
            body = patch_chat_payload(self._read_json())
        except Exception as error:  # noqa: BLE001 - return request parse failures.
            self._json(HTTPStatus.BAD_REQUEST, {"error": {"message": str(error), "type": "bad_request"}})
            return
        self._forward_post("/chat/completions", body)

    def _headers(self) -> dict[str, str]:
        """Build the headers for an upstream request, adding auth when configured."""
        headers = {"content-type": "application/json"}
        if API_KEY:
            headers["authorization"] = f"Bearer {API_KEY}"
        return headers

    def _report_failure(self, error: Exception, *, responded: bool) -> None:
        """Report a relay failure without corrupting a response already in flight.

        Before any response bytes were written the client receives a JSON 502.
        Once the status line or headers are out, a second response would be
        injected into the body or event stream, so the failure is only logged
        and the connection is marked for closing.
        """
        if responded:
            self.log_error("relay aborted after the response was started: %s", error)
            self.close_connection = True
            return
        self._json(HTTPStatus.BAD_GATEWAY, {"error": {"message": str(error), "type": "upstream_error"}})

    def _forward_get(self, suffix: str) -> None:
        """Relay a GET for *suffix* (appended to the upstream base URL).

        The upstream body is read in full and returned with the upstream
        status and content type. Upstream HTTP errors are passed on with
        their status code and at most 500 characters of the error body.
        """
        request = urllib.request.Request(
            f"{UPSTREAM_BASE_URL.rstrip('/')}{suffix}",
            headers=self._headers(),
            method="GET",
        )
        responded = False  # flips just before the first byte goes to the client
        try:
            with urllib.request.urlopen(request, timeout=30) as upstream:
                data = upstream.read()
                responded = True
                self.send_response(upstream.status)
                self.send_header("content-type", upstream.headers.get("content-type", "application/json"))
                self.send_header("cache-control", "no-store")
                self.send_header("content-length", str(len(data)))
                self.end_headers()
                self.wfile.write(data)
        except urllib.error.HTTPError as error:
            # Raised by urlopen itself, i.e. always before we started responding.
            detail = error.read().decode("utf-8", errors="replace")[:500]
            self._json(error.code, {"error": {"message": detail, "type": "upstream_error"}})
        except Exception as error:  # noqa: BLE001 - proxy health should surface upstream failures.
            self._report_failure(error, responded=responded)

    def _forward_post(self, suffix: str, body: dict[str, Any]) -> None:
        """Relay *body* as a JSON POST to *suffix* on the upstream server.

        When ``body["stream"]`` is ``True`` the upstream response is passed
        through line by line, flushing after each one, and the connection is
        closed afterwards. Otherwise the full response is buffered and sent
        with a ``content-length``. Conversations classified as ``"artifact"``
        get a 1200 second upstream timeout, all others 600 seconds.
        """
        data = json.dumps(body).encode("utf-8")
        request = urllib.request.Request(
            f"{UPSTREAM_BASE_URL.rstrip('/')}{suffix}",
            data=data,
            headers=self._headers(),
            method="POST",
        )
        responded = False  # flips just before the first byte goes to the client
        try:
            timeout = 1200 if classify_job(normalize_messages(body.get("messages"))) == "artifact" else 600
            with urllib.request.urlopen(request, timeout=timeout) as upstream:
                content_type = upstream.headers.get("content-type", "application/json")
                if body.get("stream") is True:
                    responded = True
                    self.send_response(upstream.status)
                    self.send_header("content-type", content_type)
                    self.send_header("cache-control", "no-store, no-transform")
                    # No content-length is known up front, so the end of the
                    # stream is signalled by closing the connection.
                    self.send_header("connection", "close")
                    self.end_headers()
                    for chunk in upstream:
                        self.wfile.write(chunk)
                        self.wfile.flush()
                    return
                response = upstream.read()
                responded = True
                self.send_response(upstream.status)
                self.send_header("content-type", content_type)
                self.send_header("cache-control", "no-store")
                self.send_header("content-length", str(len(response)))
                self.end_headers()
                self.wfile.write(response)
        except urllib.error.HTTPError as error:
            # Raised by urlopen itself, i.e. always before we started responding.
            detail = error.read().decode("utf-8", errors="replace")[:500]
            self._json(error.code, {"error": {"message": detail, "type": "upstream_error"}})
        except Exception as error:  # noqa: BLE001 - proxy should report upstream failures.
            self._report_failure(error, responded=responded)
def main() -> int:
    """Parse the command line and serve the proxy until interrupted.

    ``--host`` and ``--port`` default to ``DEFAULT_HOST`` and ``DEFAULT_PORT``.

    Returns:
        ``0`` once the server has stopped, including after Ctrl-C.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--host", default=DEFAULT_HOST)
    parser.add_argument("--port", type=int, default=DEFAULT_PORT)
    args = parser.parse_args()
    # The context manager closes the listening socket on every exit path.
    with ThreadingHTTPServer((args.host, args.port), Handler) as server:
        print(f"Kaiju OpenCode fast proxy listening on http://{args.host}:{args.port}", flush=True)
        print(f"Upstream: {UPSTREAM_BASE_URL}", flush=True)
        try:
            server.serve_forever()
        except KeyboardInterrupt:
            # Ctrl-C is the normal way to stop the proxy; exit without a traceback.
            print("Kaiju OpenCode fast proxy shutting down", flush=True)
    return 0
# Script entry point: run the proxy and use main()'s return value as the exit status.
if __name__ == "__main__":
    raise SystemExit(main())