| from __future__ import annotations |
|
|
| import json |
| import os |
| import ssl |
| from typing import Any, Dict |
|
|
| import certifi |
| from flask import current_app, request |
| from flask_sock import Sock |
| from websockets.sync.client import connect as websocket_connect |
| from websockets.exceptions import ConnectionClosed |
|
|
| from .responses_api import ( |
| ResponsesRequestError, |
| extract_client_session_id, |
| normalize_responses_payload, |
| ) |
| from .session import ( |
| clear_responses_reuse_state, |
| note_responses_stream_event, |
| prepare_responses_request_for_session, |
| ) |
| from .upstream import build_upstream_headers, build_upstream_websocket_url |
| from .utils import get_effective_chatgpt_auth |
|
|
|
|
def _log_json(prefix: str, payload: Any) -> None:
    """Print *prefix* on its own line followed by *payload*, best effort.

    The payload is first rendered as indented JSON (non-ASCII kept as-is).
    If that rendering or the print fails, the payload's plain ``str`` form
    is printed instead. If that also fails, the error is swallowed so that
    logging never interrupts the caller.
    """
    renderers = (
        lambda: json.dumps(payload, indent=2, ensure_ascii=False),
        lambda: payload,
    )
    for render in renderers:
        try:
            print(f"{prefix}\n{render()}")
        except Exception:
            # Fall through to the next, simpler rendering.
            continue
        return
|
|
|
|
def _error_event(message: str, *, status_code: int = 400, code: str | None = None) -> Dict[str, Any]:
    """Build an ``error`` event dict to send to the websocket client.

    Args:
        message: Human-readable error text, stored under ``error.message``.
        status_code: HTTP-like status stored at the top level of the event.
        code: Optional machine-readable code; omitted when falsy.

    Returns:
        ``{"type": "error", "status_code": ..., "error": {...}}``.
    """
    body: Dict[str, Any] = dict(message=message)
    if code:
        body["code"] = code
    return dict(type="error", status_code=status_code, error=body)
|
|
|
|
def _is_terminal_event(event: Any) -> bool:
    """Return True when *event* ends the current upstream response stream.

    Non-dict events, such as unparseable frames, are never terminal.
    ``response.incomplete`` is included because the Responses stream can
    finish with it (for example when an output-token limit is reached).
    Without it, the relay loop would keep waiting for frames after the turn
    has already finished.
    """
    if not isinstance(event, dict):
        return False
    return event.get("type") in (
        "response.completed",
        "response.failed",
        "response.incomplete",
        "error",
    )
|
|
|
|
def _build_websocket_ssl_context() -> ssl.SSLContext:
    """Create the TLS context used for the upstream websocket connection.

    The CA bundle is taken from the first non-empty value among the
    ``CODEX_CA_CERTIFICATE`` and ``SSL_CERT_FILE`` environment variables.
    If neither is set, certifi's bundled CA file is used.
    """
    for env_name in ("CODEX_CA_CERTIFICATE", "SSL_CERT_FILE"):
        override = os.getenv(env_name)
        if override:
            return ssl.create_default_context(cafile=override)
    return ssl.create_default_context(cafile=certifi.where())
|
|
|
|
def connect_upstream_websocket(url: str, headers: Dict[str, str]):
    """Open a synchronous websocket client connection to *url*.

    *headers* are sent with the opening handshake. The handshake must
    complete within 15 seconds. TLS uses the context from
    ``_build_websocket_ssl_context``.
    """
    connect_options = {
        "additional_headers": headers,
        "open_timeout": 15,
        "ssl": _build_websocket_ssl_context(),
    }
    return websocket_connect(url, **connect_options)
|
|
|
|
def register_websocket_routes(sock: Sock) -> None:
    """Register the ``/v1/responses`` websocket proxy route on *sock*.

    The handler reads JSON-object frames from the client:

    * ``response.create`` frames are normalized and prepared for their
      session before being forwarded upstream.
    * Any other frame is forwarded verbatim, but only once an upstream
      connection exists.

    A new upstream connection is opened for the first request, or whenever
    the resolved session id differs from the one the current connection
    was opened with. Upstream frames are relayed to the client until a
    terminal event arrives (see ``_is_terminal_event``). After that, the
    handler waits for the next client frame.
    """

    @sock.route("/v1/responses")
    def responses_websocket(ws) -> None:
        verbose = bool(current_app.config.get("VERBOSE"))
        upstream_ws = None
        upstream_session_id: str | None = None
        active_session_id: str | None = None
        # The handshake request (and so its headers) is fixed for the
        # lifetime of this websocket, so resolve the client session id once.
        client_session_id = extract_client_session_id(request.headers)

        def _send_error(message: str, *, status_code: int = 400, code: str | None = None) -> None:
            """Send an error event to the client, ignoring send failures."""
            evt = _error_event(message, status_code=status_code, code=code)
            if verbose:
                _log_json("STREAM OUT WS /v1/responses (error)", evt)
            try:
                ws.send(json.dumps(evt))
            except Exception:
                pass

        def _report_upstream_closed() -> None:
            """Drop reuse state for the active session and tell the client upstream died."""
            if active_session_id:
                clear_responses_reuse_state(active_session_id)
            _send_error("Upstream websocket closed unexpectedly.", status_code=502)

        def _close_quietly(conn) -> None:
            """Close *conn*, ignoring any error raised while closing."""
            try:
                conn.close()
            except Exception:
                pass

        try:
            while True:
                incoming = ws.receive()
                if incoming is None:
                    break

                if isinstance(incoming, bytes):
                    incoming_text = incoming.decode("utf-8", errors="ignore")
                else:
                    incoming_text = str(incoming)
                if verbose:
                    print("IN WS /v1/responses\n" + incoming_text)

                try:
                    payload = json.loads(incoming_text)
                except Exception:
                    _send_error("Websocket frames must be valid JSON objects.", status_code=400)
                    break

                if not isinstance(payload, dict):
                    _send_error("Websocket frames must be JSON objects.", status_code=400)
                    break

                outbound_text = incoming_text
                session_id = upstream_session_id

                if payload.get("type") == "response.create":
                    try:
                        normalized = normalize_responses_payload(
                            payload,
                            config=current_app.config,
                            client_session_id=client_session_id,
                        )
                    except ResponsesRequestError as exc:
                        # A bad request is reported but does not end the socket.
                        _send_error(str(exc), status_code=exc.status_code, code=exc.code)
                        continue

                    if normalized.service_tier_resolution.warning_message and verbose:
                        print(f"[FastMode] {normalized.service_tier_resolution.warning_message}")
                    prepared = prepare_responses_request_for_session(
                        normalized.session_id,
                        normalized.payload,
                        allow_previous_response_id=True,
                    )
                    outbound_text = json.dumps(prepared.payload)
                    session_id = normalized.session_id
                    active_session_id = normalized.session_id
                    if verbose:
                        _log_json("OUTBOUND >> ChatGPT Responses WS payload", prepared.payload)
                elif upstream_ws is None:
                    _send_error(
                        "The first websocket message must be a response.create request.",
                        status_code=400,
                    )
                    break

                # (Re)connect on the first request or when the session changed.
                if upstream_ws is None or (session_id and session_id != upstream_session_id):
                    access_token, account_id = get_effective_chatgpt_auth()
                    if not access_token or not account_id:
                        if session_id:
                            clear_responses_reuse_state(session_id)
                        _send_error(
                            "Missing ChatGPT credentials. Run 'python3 chatmock.py login' first.",
                            status_code=401,
                        )
                        break

                    if upstream_ws is not None:
                        _close_quietly(upstream_ws)
                        upstream_ws = None

                    effective_session_id = session_id or client_session_id or ""
                    try:
                        upstream_ws = connect_upstream_websocket(
                            build_upstream_websocket_url(),
                            build_upstream_headers(
                                access_token,
                                account_id,
                                effective_session_id,
                                accept="application/json",
                            ),
                        )
                    except Exception as exc:
                        if session_id:
                            clear_responses_reuse_state(session_id)
                        _send_error(
                            f"Upstream websocket connection failed: {exc}",
                            status_code=502,
                        )
                        break
                    upstream_session_id = effective_session_id

                # A reused upstream connection may have been closed by the
                # server while idle between turns; report it like a failed
                # recv instead of letting ConnectionClosed escape the handler.
                try:
                    upstream_ws.send(outbound_text)
                except ConnectionClosed:
                    _report_upstream_closed()
                    return

                # Relay upstream frames until the turn ends.
                while True:
                    try:
                        upstream_message = upstream_ws.recv()
                    except ConnectionClosed:
                        _report_upstream_closed()
                        return
                    if upstream_message is None:
                        _report_upstream_closed()
                        return
                    if verbose:
                        try:
                            print("STREAM OUT WS /v1/responses\n" + str(upstream_message))
                        except Exception:
                            pass
                    ws.send(upstream_message)

                    try:
                        parsed = json.loads(upstream_message)
                    except Exception:
                        parsed = None
                    if isinstance(parsed, dict) and active_session_id:
                        note_responses_stream_event(active_session_id, parsed)
                    if _is_terminal_event(parsed):
                        # Failed turns drop the upstream connection so the
                        # next response.create opens a fresh one.
                        if isinstance(parsed, dict) and parsed.get("type") in ("response.failed", "error"):
                            if upstream_ws is not None:
                                _close_quietly(upstream_ws)
                            upstream_ws = None
                            upstream_session_id = None
                        break
        finally:
            if upstream_ws is not None:
                _close_quietly(upstream_ws)
|
|