File size: 8,866 Bytes
35205e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
from __future__ import annotations

import json
import os
import ssl
from typing import Any, Dict

import certifi
from flask import current_app, request
from flask_sock import Sock
from websockets.sync.client import connect as websocket_connect
from websockets.exceptions import ConnectionClosed

from .responses_api import (
    ResponsesRequestError,
    extract_client_session_id,
    normalize_responses_payload,
)
from .session import (
    clear_responses_reuse_state,
    note_responses_stream_event,
    prepare_responses_request_for_session,
)
from .upstream import build_upstream_headers, build_upstream_websocket_url
from .utils import get_effective_chatgpt_auth


def _log_json(prefix: str, payload: Any) -> None:
    try:
        print(f"{prefix}\n{json.dumps(payload, indent=2, ensure_ascii=False)}")
    except Exception:
        try:
            print(f"{prefix}\n{payload}")
        except Exception:
            pass


def _error_event(message: str, *, status_code: int = 400, code: str | None = None) -> Dict[str, Any]:
    error: Dict[str, Any] = {"message": message}
    if code:
        error["code"] = code
    return {"type": "error", "status_code": status_code, "error": error}


def _is_terminal_event(event: Any) -> bool:
    if not isinstance(event, dict):
        return False
    kind = event.get("type")
    return kind in ("response.completed", "response.failed", "error")


def _build_websocket_ssl_context() -> ssl.SSLContext:
    """Create the TLS context used for the upstream websocket.

    The CA bundle comes from the first non-empty environment variable among
    ``CODEX_CA_CERTIFICATE`` and ``SSL_CERT_FILE``. If neither is set, the
    bundle shipped with ``certifi`` is used.
    """
    for env_name in ("CODEX_CA_CERTIFICATE", "SSL_CERT_FILE"):
        override = os.getenv(env_name)
        if override:
            return ssl.create_default_context(cafile=override)
    return ssl.create_default_context(cafile=certifi.where())


def connect_upstream_websocket(url: str, headers: Dict[str, str]):
    """Open a synchronous websocket connection to *url*.

    *headers* are sent as additional handshake headers. The opening handshake
    is limited to 15 seconds. TLS verification uses the context built by
    ``_build_websocket_ssl_context``.
    """
    tls_context = _build_websocket_ssl_context()
    return websocket_connect(
        url,
        ssl=tls_context,
        open_timeout=15,
        additional_headers=headers,
    )


def register_websocket_routes(sock: Sock) -> None:
    """Register the ``/v1/responses`` websocket proxy route on *sock*.

    Each client connection is relayed to an upstream websocket opened with
    ``connect_upstream_websocket``.

    Client frames:
    - Every frame must be a JSON object; otherwise an error event is sent and
      the connection ends.
    - ``response.create`` frames are normalized and prepared for their
      session before being forwarded.
    - Any other frame is forwarded verbatim on the already-open upstream
      connection. When no upstream is open, the frame must be
      ``response.create``.

    After each forwarded frame, upstream messages are relayed back to the
    client until ``_is_terminal_event`` reports a terminal event. On
    ``response.failed`` or ``error`` the upstream connection is dropped, so
    the next ``response.create`` opens a fresh one.
    """

    @sock.route("/v1/responses")
    def responses_websocket(ws) -> None:
        verbose = bool(current_app.config.get("VERBOSE"))
        upstream_ws = None
        # Session id the current upstream connection was opened with.
        upstream_session_id: str | None = None
        # Session id of the most recent response.create; used for reuse-state bookkeeping.
        active_session_id: str | None = None

        def _send_error(message: str, *, status_code: int = 400, code: str | None = None) -> None:
            # Best-effort: the client may already have disconnected.
            evt = _error_event(message, status_code=status_code, code=code)
            if verbose:
                _log_json("STREAM OUT WS /v1/responses (error)", evt)
            try:
                ws.send(json.dumps(evt))
            except Exception:
                pass

        def _close_quietly(conn) -> None:
            # Close an upstream connection, ignoring errors (it may already be dead).
            if conn is None:
                return
            try:
                conn.close()
            except Exception:
                pass

        def _fail_upstream_closed() -> None:
            # Upstream went away: drop reuse state for the active session and notify the client.
            if active_session_id:
                clear_responses_reuse_state(active_session_id)
            _send_error("Upstream websocket closed unexpectedly.", status_code=502)

        try:
            while True:
                incoming = ws.receive()
                if incoming is None:
                    break

                if isinstance(incoming, bytes):
                    incoming_text = incoming.decode("utf-8", errors="ignore")
                else:
                    incoming_text = str(incoming)
                if verbose:
                    print("IN WS /v1/responses\n" + incoming_text)

                try:
                    payload = json.loads(incoming_text)
                except Exception:
                    _send_error("Websocket frames must be valid JSON objects.", status_code=400)
                    break

                if not isinstance(payload, dict):
                    _send_error("Websocket frames must be JSON objects.", status_code=400)
                    break

                client_session_id = extract_client_session_id(request.headers)
                outbound_text = incoming_text
                session_id = upstream_session_id

                if payload.get("type") == "response.create":
                    try:
                        normalized = normalize_responses_payload(
                            payload,
                            config=current_app.config,
                            client_session_id=client_session_id,
                        )
                    except ResponsesRequestError as exc:
                        _send_error(str(exc), status_code=exc.status_code, code=exc.code)
                        continue

                    if normalized.service_tier_resolution.warning_message and verbose:
                        print(f"[FastMode] {normalized.service_tier_resolution.warning_message}")
                    prepared = prepare_responses_request_for_session(
                        normalized.session_id,
                        normalized.payload,
                        allow_previous_response_id=True,
                    )
                    outbound_text = json.dumps(prepared.payload)
                    session_id = normalized.session_id
                    active_session_id = normalized.session_id
                    if verbose:
                        _log_json("OUTBOUND >> ChatGPT Responses WS payload", prepared.payload)
                elif upstream_ws is None:
                    _send_error(
                        "The first websocket message must be a response.create request.",
                        status_code=400,
                    )
                    break

                # (Re)connect when nothing is open, or when the request targets a
                # different session than the one the open connection belongs to.
                if upstream_ws is None or (session_id and session_id != upstream_session_id):
                    access_token, account_id = get_effective_chatgpt_auth()
                    if not access_token or not account_id:
                        if session_id:
                            clear_responses_reuse_state(session_id)
                        _send_error(
                            "Missing ChatGPT credentials. Run 'python3 chatmock.py login' first.",
                            status_code=401,
                        )
                        break

                    # Drop the connection opened for the previous session; forget it so a
                    # failed reconnect below doesn't leave a closed socket referenced.
                    _close_quietly(upstream_ws)
                    upstream_ws = None

                    effective_session_id = session_id or client_session_id or ""
                    try:
                        upstream_ws = connect_upstream_websocket(
                            build_upstream_websocket_url(),
                            build_upstream_headers(
                                access_token,
                                account_id,
                                effective_session_id,
                                accept="application/json",
                            ),
                        )
                    except Exception as exc:
                        if session_id:
                            clear_responses_reuse_state(session_id)
                        _send_error(
                            f"Upstream websocket connection failed: {exc}",
                            status_code=502,
                        )
                        break
                    upstream_session_id = effective_session_id

                try:
                    upstream_ws.send(outbound_text)
                except ConnectionClosed:
                    # The upstream may have closed a reused connection while it sat idle
                    # between turns; report it like a closed recv instead of crashing.
                    _fail_upstream_closed()
                    return

                # Relay upstream events to the client until the response finishes.
                while True:
                    try:
                        upstream_message = upstream_ws.recv()
                    except ConnectionClosed:
                        upstream_message = None
                    if upstream_message is None:
                        _fail_upstream_closed()
                        return
                    if verbose:
                        try:
                            print("STREAM OUT WS /v1/responses\n" + str(upstream_message))
                        except Exception:
                            pass
                    ws.send(upstream_message)

                    try:
                        parsed = json.loads(upstream_message)
                    except Exception:
                        parsed = None
                    if isinstance(parsed, dict) and active_session_id:
                        note_responses_stream_event(active_session_id, parsed)
                    if _is_terminal_event(parsed):
                        # _is_terminal_event only accepts dicts, so .get() is safe here.
                        if parsed.get("type") in ("response.failed", "error"):
                            # Failed turns tear down the upstream so the next
                            # response.create starts on a fresh connection.
                            _close_quietly(upstream_ws)
                            upstream_ws = None
                            upstream_session_id = None
                        break
        finally:
            _close_quietly(upstream_ws)