#!/usr/bin/env python3
"""Tool-safe OpenAI-compatible fast proxy for Kaiju Coder 7 OpenCode.

The normal Gojira gateway is product/API oriented and aggregates content.
OpenCode needs raw tool-call chunks preserved, so this proxy only patches
serving knobs and then passes upstream responses through unchanged.
"""

from __future__ import annotations

import argparse
import json
import os
import time
import urllib.error
import urllib.request
from http import HTTPStatus
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from typing import Any

DEFAULT_HOST = "127.0.0.1"
DEFAULT_PORT = int(os.environ.get("KAIJU_OPENCODE_FAST_PROXY_PORT", "18181"))
UPSTREAM_BASE_URL = os.environ.get("KAIJU_OPENAI_BASE_URL", "http://100.109.109.14:18084/v1")
DEFAULT_MODEL = os.environ.get("KAIJU_DEFAULT_MODEL", "kaiju-coder-7")
API_KEY = os.environ.get("KAIJU_OPENAI_API_KEY", "")
NORMAL_MAX_TOKENS = int(os.environ.get("KAIJU_NORMAL_MAX_TOKENS", "384"))
WORK_MAX_TOKENS = int(os.environ.get("KAIJU_WORK_MAX_TOKENS", "1536"))
ARTIFACT_MAX_TOKENS = int(os.environ.get("KAIJU_ARTIFACT_MAX_TOKENS", "4096"))
MAX_REQUEST_BYTES = int(os.environ.get("KAIJU_MAX_REQUEST_BYTES", "2097152"))

# Upstream timeouts in seconds. urllib applies these to blocking socket operations
# (connect / each read), not to the total duration of a response.
MODELS_TIMEOUT_SECONDS = 30
CHAT_TIMEOUT_SECONDS = 600
ARTIFACT_CHAT_TIMEOUT_SECONDS = 1200

# Maximum number of characters of an upstream error body echoed back to the client.
_ERROR_DETAIL_CHARS = 500

# Substring markers used by classify_job(); artifact markers are checked first.
_ARTIFACT_TERMS = (
    "complete html",
    "html file",
    "one-file website",
    "landing page",
    "build a website",
    "make a website",
    "full file",
)
_WORK_TERMS = (
    "create ",
    "write ",
    "edit ",
    "implement",
    "debug",
    "fix",
    "refactor",
    "test",
    "repo",
    "file",
)


def normalize_messages(messages: Any) -> list[dict[str, Any]]:
    """Return only the dict entries of *messages*; any non-list value yields ``[]``."""
    if not isinstance(messages, list):
        return []
    return [message for message in messages if isinstance(message, dict)]


def message_text(messages: list[dict[str, Any]]) -> str:
    """Flatten message contents into one lowercase, newline-joined string for keyword matching.

    String content is used verbatim, other content (e.g. lists of parts) is JSON-encoded,
    and missing/null content contributes no text.
    """
    parts: list[str] = []
    for message in messages:
        content = message.get("content")
        if content is None:
            continue
        parts.append(content if isinstance(content, str) else json.dumps(content, ensure_ascii=False))
    return "\n".join(parts).lower()


def classify_job(messages: list[dict[str, Any]]) -> str:
    """Classify a chat request as ``"artifact"``, ``"work"`` or ``"normal"``.

    Classification is plain substring matching over message_text(): any artifact marker wins,
    otherwise any work marker, otherwise "normal".

    NOTE(review): every message is scanned (including any system prompt the client sends), and
    short terms such as "fix", "test" or "file" also match inside longer words, so requests lean
    towards "work". Confirm this bias is intended.
    """
    text = message_text(messages)
    if any(term in text for term in _ARTIFACT_TERMS):
        return "artifact"
    if any(term in text for term in _WORK_TERMS):
        return "work"
    return "normal"


def target_tokens(job_class: str) -> int:
    """Map a job class from classify_job() to its max_tokens budget (unknown -> normal budget)."""
    budgets = {"artifact": ARTIFACT_MAX_TOKENS, "work": WORK_MAX_TOKENS}
    return budgets.get(job_class, NORMAL_MAX_TOKENS)


def patch_chat_payload(body: dict[str, Any]) -> dict[str, Any]:
    """Return a shallow copy of a chat-completions body with the serving knobs overridden.

    - ``model`` is forced to DEFAULT_MODEL.
    - ``max_tokens`` is replaced by the job-class budget (any client value is discarded).
    - ``chat_template_kwargs`` keeps the client's keys but forces thinking off.

    The input dict itself is not modified.
    """
    patched = dict(body)
    patched["model"] = DEFAULT_MODEL
    patched["max_tokens"] = target_tokens(classify_job(normalize_messages(patched.get("messages"))))
    existing_kwargs = patched.get("chat_template_kwargs")
    patched["chat_template_kwargs"] = {
        **(existing_kwargs if isinstance(existing_kwargs, dict) else {}),
        "enable_thinking": False,
        "thinking": False,
    }
    return patched


class Handler(BaseHTTPRequestHandler):
    """Serve /health locally and forward /v1/models and /v1/chat/completions upstream."""

    server_version = "KaijuOpenCodeFastProxy/0.1"

    def log_message(self, fmt: str, *args: Any) -> None:
        """Print one timestamped log line to stdout, flushed immediately."""
        print(f"{time.strftime('%Y-%m-%d %H:%M:%S')} {self.address_string()} - {fmt % args}", flush=True)

    @staticmethod
    def _route(raw_path: str) -> str:
        """Return the request path without its query string, for endpoint matching."""
        return raw_path.partition("?")[0]

    @staticmethod
    def _upstream_url(suffix: str) -> str:
        """Join UPSTREAM_BASE_URL with an endpoint suffix such as ``/models``."""
        return f"{UPSTREAM_BASE_URL.rstrip('/')}{suffix}"

    def _json(self, status: int, payload: dict[str, Any]) -> None:
        """Send *payload* as a complete, non-cacheable JSON response with a content-length."""
        data = json.dumps(payload).encode("utf-8")
        self.send_response(status)
        self.send_header("content-type", "application/json; charset=utf-8")
        self.send_header("cache-control", "no-store")
        self.send_header("content-length", str(len(data)))
        self.end_headers()
        self.wfile.write(data)

    def _send_error_json(self, status: int, payload: dict[str, Any]) -> None:
        """Like _json(), but log instead of raising if the client connection is already gone."""
        try:
            self._json(status, payload)
        except OSError as error:
            self.log_message("could not deliver error response: %s", error)

    def _report_upstream_failure(self, error: Exception, headers_sent: bool) -> None:
        """Report a forwarding failure to the client, or only log it once headers are out.

        After headers (and possibly part of a streamed body) have been written, a second HTTP
        response would be injected into the body and corrupt it, so the failure is just logged
        and the connection is left to close.
        """
        if headers_sent:
            self.log_message("response aborted after headers were sent: %s", error)
            return
        self._send_error_json(HTTPStatus.BAD_GATEWAY, {"error": {"message": str(error), "type": "upstream_error"}})

    @staticmethod
    def _http_error_detail(error: urllib.error.HTTPError) -> str:
        """Return at most _ERROR_DETAIL_CHARS characters of an upstream error body."""
        try:
            raw = error.read()
        except Exception:  # noqa: BLE001 - an unreadable error body must not mask the status code.
            raw = b""
        return raw.decode("utf-8", errors="replace")[:_ERROR_DETAIL_CHARS]

    def _read_json(self) -> dict[str, Any]:
        """Read and decode the request body as a JSON object.

        Returns ``{}`` for an empty body.

        Raises:
            ValueError: the content-length is not a non-negative integer, the body exceeds
                MAX_REQUEST_BYTES, it is not valid UTF-8/JSON, or it is not a JSON object.
        """
        raw_length = self.headers.get("content-length", "0")
        try:
            length = int(raw_length)
        except ValueError:
            raise ValueError(f"invalid content-length: {raw_length!r}") from None
        if length < 0:
            # rfile.read(-1) would block until the client closes the connection.
            raise ValueError("invalid content-length: must not be negative")
        if length > MAX_REQUEST_BYTES:
            raise ValueError("request body too large")
        raw = self.rfile.read(length)
        if not raw:
            return {}
        value = json.loads(raw.decode("utf-8"))
        if not isinstance(value, dict):
            raise ValueError("request body must be a JSON object")
        return value

    def do_GET(self) -> None:  # noqa: N802 - BaseHTTPRequestHandler API.
        """Answer /health locally, proxy /v1/models upstream, and 404 everything else."""
        route = self._route(self.path)
        if route == "/health":
            self._send_error_json(
                HTTPStatus.OK,
                {
                    "ok": True,
                    "model": DEFAULT_MODEL,
                    "upstream": UPSTREAM_BASE_URL,
                    "normal_max_tokens": NORMAL_MAX_TOKENS,
                    "work_max_tokens": WORK_MAX_TOKENS,
                    "artifact_max_tokens": ARTIFACT_MAX_TOKENS,
                },
            )
            return
        if route == "/v1/models":
            self._forward_get("/models")
            return
        self._send_error_json(HTTPStatus.NOT_FOUND, {"error": {"message": "Not found", "type": "not_found"}})

    def do_POST(self) -> None:  # noqa: N802 - BaseHTTPRequestHandler API.
        """Patch and forward /v1/chat/completions; any parse/patch failure becomes a 400."""
        if self._route(self.path) != "/v1/chat/completions":
            self._send_error_json(HTTPStatus.NOT_FOUND, {"error": {"message": "Not found", "type": "not_found"}})
            return
        try:
            body = patch_chat_payload(self._read_json())
        except Exception as error:  # noqa: BLE001 - return request parse failures.
            self._send_error_json(HTTPStatus.BAD_REQUEST, {"error": {"message": str(error), "type": "bad_request"}})
            return
        self._forward_post("/chat/completions", body)

    def _headers(self) -> dict[str, str]:
        """Build upstream request headers, adding a bearer token when API_KEY is set."""
        headers = {"content-type": "application/json"}
        if API_KEY:
            headers["authorization"] = f"Bearer {API_KEY}"
        return headers

    def _forward_get(self, suffix: str) -> None:
        """GET an upstream endpoint and relay its buffered body, status and content-type."""
        request = urllib.request.Request(self._upstream_url(suffix), headers=self._headers(), method="GET")
        headers_sent = False
        try:
            with urllib.request.urlopen(request, timeout=MODELS_TIMEOUT_SECONDS) as upstream:
                data = upstream.read()
                self.send_response(upstream.status)
                self.send_header("content-type", upstream.headers.get("content-type", "application/json"))
                self.send_header("cache-control", "no-store")
                self.send_header("content-length", str(len(data)))
                self.end_headers()
                headers_sent = True
                self.wfile.write(data)
        except urllib.error.HTTPError as error:
            self._send_error_json(
                error.code, {"error": {"message": self._http_error_detail(error), "type": "upstream_error"}}
            )
        except Exception as error:  # noqa: BLE001 - proxy health should surface upstream failures.
            self._report_upstream_failure(error, headers_sent)

    def _forward_post(self, suffix: str, body: dict[str, Any]) -> None:
        """POST *body* upstream and relay the response unchanged.

        When ``body["stream"] is True`` the upstream body is relayed as it is read, flushing after
        every chunk, without a content-length and with ``connection: close`` marking the end.
        Otherwise the response is buffered and sent with a content-length. Upstream HTTP errors
        are returned as JSON with the upstream status. Other failures become a 502 if no headers
        have been sent yet, and are only logged otherwise.
        """
        data = json.dumps(body).encode("utf-8")
        request = urllib.request.Request(
            self._upstream_url(suffix),
            data=data,
            headers=self._headers(),
            method="POST",
        )
        job_class = classify_job(normalize_messages(body.get("messages")))
        timeout = ARTIFACT_CHAT_TIMEOUT_SECONDS if job_class == "artifact" else CHAT_TIMEOUT_SECONDS
        headers_sent = False
        try:
            with urllib.request.urlopen(request, timeout=timeout) as upstream:
                content_type = upstream.headers.get("content-type", "application/json")
                if body.get("stream") is True:
                    self.send_response(upstream.status)
                    self.send_header("content-type", content_type)
                    self.send_header("cache-control", "no-store, no-transform")
                    self.send_header("connection", "close")
                    self.end_headers()
                    headers_sent = True
                    for chunk in upstream:
                        self.wfile.write(chunk)
                        self.wfile.flush()
                    return
                response = upstream.read()
                self.send_response(upstream.status)
                self.send_header("content-type", content_type)
                self.send_header("cache-control", "no-store")
                self.send_header("content-length", str(len(response)))
                self.end_headers()
                headers_sent = True
                self.wfile.write(response)
        except urllib.error.HTTPError as error:
            self._send_error_json(
                error.code, {"error": {"message": self._http_error_detail(error), "type": "upstream_error"}}
            )
        except Exception as error:  # noqa: BLE001 - proxy should report upstream failures.
            self._report_upstream_failure(error, headers_sent)


def main(argv: list[str] | None = None) -> int:
    """Parse CLI options and serve until interrupted; returns the process exit code."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--host", default=DEFAULT_HOST)
    parser.add_argument("--port", type=int, default=DEFAULT_PORT)
    args = parser.parse_args(argv)
    # The context manager closes the listening socket on exit, including after Ctrl-C.
    with ThreadingHTTPServer((args.host, args.port), Handler) as server:
        print(f"Kaiju OpenCode fast proxy listening on http://{args.host}:{args.port}", flush=True)
        print(f"Upstream: {UPSTREAM_BASE_URL}", flush=True)
        try:
            server.serve_forever()
        except KeyboardInterrupt:
            print("Shutting down", flush=True)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())