# CheckMat / chatmock / oauth.py
# aiqknow's picture
# Upload 97 files
# 35205e8 verified
from __future__ import annotations
import datetime
import ssl
import http.server
import json
import secrets
import threading
import time
import urllib.parse
import urllib.request
from typing import Any, Dict, Tuple
import certifi
from .config import OAUTH_ISSUER_DEFAULT
from .models import AuthBundle, PkceCodes, TokenData
from .utils import eprint, generate_pkce, parse_jwt_claims, write_auth_file
# Port number baked into URL_BASE below.
REQUIRED_PORT = 1455

# Base URL of the local server; the "/success" landing URLs are built from it.
URL_BASE = "http://localhost:" + str(REQUIRED_PORT)

# Issuer used for the authorize and token endpoints unless overridden;
# re-exported from the package configuration.
DEFAULT_ISSUER = OAUTH_ISSUER_DEFAULT
# Static page served to the browser once the login flow has finished.
# Double quotes need no escaping inside a triple-quoted literal.
LOGIN_SUCCESS_HTML = """<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8" />
<title>Login successful</title>
</head>
<body>
<div style="max-width: 640px; margin: 80px auto; font-family: system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial, sans-serif;">
<h1>Login successful</h1>
<p>You can now close this window and return to the terminal and run <code>python3 chatmock.py serve</code> to start the server.</p>
</div>
</body>
</html>
"""
# Shared TLS context for outbound HTTPS requests, verifying servers against
# the CA bundle shipped with certifi.
_SSL_CONTEXT = ssl.create_default_context(
    ssl.Purpose.SERVER_AUTH,
    cafile=certifi.where(),
)
class OAuthHTTPServer(http.server.HTTPServer):
    """Local HTTP server that drives an OAuth authorization-code login with PKCE.

    Constructing the server binds and activates the listening socket, then
    generates a fresh PKCE pair and a random ``state`` value. The instance
    offers helpers to build the authorization URL, exchange the returned code
    for tokens, optionally obtain an API key, and persist the result.

    Attributes:
        exit_code: Status of the login flow; starts at 1 and is set to 0 by
            the request handler once credentials have been persisted.
        home_dir: Value passed by the caller; stored but not read by this class.
        verbose: When true, the request handler emits access logs.
        issuer: Base URL of the authorization server.
        token_endpoint: ``{issuer}/oauth/token``.
        client_id: OAuth client identifier sent with every request.
        redirect_uri: Callback URL built from the bound port.
        pkce: PKCE code verifier/challenge pair for this login attempt.
        state: Random hex value sent as the OAuth ``state`` parameter.
    """

    # Upper bound (seconds) for each blocking request to the token endpoint,
    # so a stalled network cannot hang the login flow indefinitely.
    request_timeout: float = 30.0

    def __init__(
        self,
        server_address: tuple[str, int],
        request_handler_class: type[http.server.BaseHTTPRequestHandler],
        *,
        home_dir: str,
        client_id: str,
        verbose: bool = False,
    ) -> None:
        """Bind the server and prepare per-login secrets.

        Args:
            server_address: ``(host, port)`` to listen on; the port is also
                used to build ``redirect_uri``.
            request_handler_class: Handler class instantiated per request.
            home_dir: Stored on the instance as ``home_dir``.
            client_id: OAuth client identifier.
            verbose: Enable request logging in the handler.
        """
        super().__init__(server_address, request_handler_class, bind_and_activate=True)
        self.exit_code = 1
        self.home_dir = home_dir
        self.verbose = verbose
        self.issuer = DEFAULT_ISSUER
        self.token_endpoint = f"{self.issuer}/oauth/token"
        self.client_id = client_id
        port = server_address[1]
        self.redirect_uri = f"http://localhost:{port}/auth/callback"
        self.pkce = generate_pkce()
        self.state = secrets.token_hex(32)

    def auth_url(self) -> str:
        """Return the authorization URL the user must open in a browser."""
        params = {
            "response_type": "code",
            "client_id": self.client_id,
            "redirect_uri": self.redirect_uri,
            "scope": "openid profile email offline_access",
            "code_challenge": self.pkce.code_challenge,
            "code_challenge_method": "S256",
            "id_token_add_organizations": "true",
            "codex_cli_simplified_flow": "true",
            "state": self.state,
        }
        return f"{self.issuer}/oauth/authorize?" + urllib.parse.urlencode(params)

    def _post_form(self, fields: Dict[str, Any]) -> Dict[str, Any]:
        """POST *fields* as a URL-encoded form to the token endpoint.

        Returns the decoded JSON response body. Network and HTTP errors raised
        by ``urllib`` (including timeouts) propagate to the caller.
        """
        request = urllib.request.Request(
            self.token_endpoint,
            data=urllib.parse.urlencode(fields).encode(),
            method="POST",
            headers={"Content-Type": "application/x-www-form-urlencoded"},
        )
        with urllib.request.urlopen(
            request,
            timeout=self.request_timeout,
            context=_SSL_CONTEXT,
        ) as resp:
            return json.loads(resp.read().decode())

    @staticmethod
    def _success_url(query: Dict[str, Any]) -> str:
        """Build the local ``/success`` URL, rendering ``None`` values as empty."""
        # urlencode() would otherwise emit the literal text "None".
        cleaned = {key: "" if value is None else value for key, value in query.items()}
        return f"{URL_BASE}/success?{urllib.parse.urlencode(cleaned)}"

    def exchange_code(self, code: str) -> tuple[AuthBundle, str]:
        """Exchange an authorization *code* for tokens.

        Args:
            code: Authorization code received on the redirect URI.

        Returns:
            A tuple of the assembled ``AuthBundle`` and the success URL
            (falling back to ``{URL_BASE}/success``).

        Raises:
            Whatever ``urllib``/``json`` raise on network, HTTP or decoding
            failure.
        """
        payload = self._post_form(
            {
                "grant_type": "authorization_code",
                "code": code,
                "redirect_uri": self.redirect_uri,
                "client_id": self.client_id,
                "code_verifier": self.pkce.code_verifier,
            }
        )
        id_token = payload.get("id_token", "")
        access_token = payload.get("access_token", "")
        refresh_token = payload.get("refresh_token", "")

        id_token_claims = parse_jwt_claims(id_token) or {}
        access_token_claims = parse_jwt_claims(access_token) or {}

        # The claim may be absent, null, or malformed; treat all as "no data"
        # instead of failing on a non-dict value.
        auth_claims = id_token_claims.get("https://api.openai.com/auth")
        if not isinstance(auth_claims, dict):
            auth_claims = {}
        chatgpt_account_id = auth_claims.get("chatgpt_account_id", "")

        token_data = TokenData(
            id_token=id_token,
            access_token=access_token,
            refresh_token=refresh_token,
            account_id=chatgpt_account_id,
        )
        api_key, success_url = self.maybe_obtain_api_key(
            id_token_claims, access_token_claims, token_data
        )
        # ISO-8601 UTC timestamp with a trailing "Z" instead of "+00:00".
        last_refresh_str = (
            datetime.datetime.now(datetime.timezone.utc).isoformat().replace("+00:00", "Z")
        )
        bundle = AuthBundle(api_key=api_key, token_data=token_data, last_refresh=last_refresh_str)
        return bundle, success_url or f"{URL_BASE}/success"

    def maybe_obtain_api_key(
        self,
        token_claims: Dict[str, Any],
        access_claims: Dict[str, Any],
        token_data: TokenData,
    ) -> tuple[str | None, str | None]:
        """Try to trade the ID token for an API key.

        Args:
            token_claims: Claims decoded from the ID token.
            access_claims: Claims decoded from the access token.
            token_data: Tokens obtained from the code exchange.

        Returns:
            ``(api_key, success_url)``. ``api_key`` is ``None`` when the ID
            token carries no ``organization_id``/``project_id`` (no request is
            made in that case) or when the exchange response has no
            ``access_token``.
        """
        org_id = token_claims.get("organization_id")
        project_id = token_claims.get("project_id")
        plan_type = access_claims.get("chatgpt_plan_type")

        if not org_id or not project_id:
            return None, self._success_url(
                {
                    "id_token": token_data.id_token,
                    "needs_setup": "false",
                    "org_id": org_id or "",
                    "project_id": project_id or "",
                    "plan_type": plan_type,
                    "platform_url": "https://platform.openai.com",
                }
            )

        today = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d")
        exchange_payload = self._post_form(
            {
                "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
                "client_id": self.client_id,
                "requested_token": "openai-api-key",
                "subject_token": token_data.id_token,
                "subject_token_type": "urn:ietf:params:oauth:token-type:id_token",
                "name": f"ChatMock [auto-generated] ({today})",
            }
        )
        exchanged_access_token = exchange_payload.get("access_token")

        # NOTE(review): this URL embeds bearer tokens in its query string;
        # callers should avoid logging or redirecting a browser to it.
        success_url = self._success_url(
            {
                "id_token": token_data.id_token,
                "access_token": token_data.access_token,
                "refresh_token": token_data.refresh_token,
                "exchanged_access_token": exchanged_access_token,
                "org_id": org_id,
                "project_id": project_id,
                "plan_type": plan_type,
                "platform_url": "https://platform.openai.com",
            }
        )
        return exchanged_access_token, success_url

    def persist_auth(self, bundle: AuthBundle) -> bool:
        """Write *bundle* to the auth file; return ``write_auth_file``'s result."""
        auth_json_contents = {
            "OPENAI_API_KEY": bundle.api_key,
            "tokens": {
                "id_token": bundle.token_data.id_token,
                "access_token": bundle.token_data.access_token,
                "refresh_token": bundle.token_data.refresh_token,
                "account_id": bundle.token_data.account_id,
            },
            "last_refresh": bundle.last_refresh,
        }
        return write_auth_file(auth_json_contents)
class OAuthHandler(http.server.BaseHTTPRequestHandler):
    """Request handler for the local OAuth callback server.

    Handles ``GET /auth/callback`` (verifies ``state``, exchanges the code,
    persists credentials) and ``GET /success`` (static confirmation page).
    Every other request is answered with 404. Each handled request schedules
    a server shutdown, so the server is effectively single-use.
    """

    server: "OAuthHTTPServer"

    # Upper bound (seconds) for the blocking token-endpoint request made by
    # _maybe_obtain_api_key.
    _TOKEN_REQUEST_TIMEOUT = 30.0

    def do_GET(self) -> None:
        """Serve the success page or process the OAuth redirect callback."""
        parsed = urllib.parse.urlparse(self.path)

        if parsed.path == "/success":
            self._send_html(LOGIN_SUCCESS_HTML)
            try:
                self.wfile.flush()
            except Exception as e:  # best effort: the page was already written
                eprint(f"Failed to flush response: {e}")
            self._shutdown_after_delay(2.0)
            return

        if parsed.path != "/auth/callback":
            self.send_error(404, "Not Found")
            self._shutdown()
            return

        params = urllib.parse.parse_qs(parsed.query)

        # Reject callbacks whose ``state`` does not match the value sent in
        # the authorization request. Compare as bytes so non-ASCII input
        # cannot make compare_digest raise.
        returned_state = params.get("state", [""])[0]
        expected_state = self.server.state
        if not returned_state or not secrets.compare_digest(
            returned_state.encode("utf-8"), expected_state.encode("utf-8")
        ):
            self.send_error(400, "State mismatch")
            self._shutdown()
            return

        code = params.get("code", [None])[0]
        if not code:
            self.send_error(400, "Missing auth code")
            self._shutdown()
            return

        try:
            auth_bundle, _ = self._exchange_code(code)
        except Exception as exc:
            # Keep the reason phrase short and ASCII: the status line is
            # encoded as strict latin-1. The detail goes into the escaped body.
            self.log_error("Token exchange failed: %s", exc)
            self.send_error(500, "Token exchange failed", str(exc))
            self._shutdown()
            return

        try:
            if self.server.persist_auth(auth_bundle):
                self.server.exit_code = 0
                self._send_html(LOGIN_SUCCESS_HTML)
            else:
                self.send_error(500, "Unable to persist auth file")
        finally:
            # Always stop the server, even if the browser disconnected while
            # the response was being written.
            self._shutdown_after_delay(2.0)

    def do_POST(self) -> None:
        """Reject POST requests with 404 and stop the server."""
        self.send_error(404, "Not Found")
        self._shutdown()

    def log_message(self, fmt: str, *args: Any) -> None:
        """Emit the standard access log only when the server is verbose."""
        if getattr(self.server, "verbose", False):
            super().log_message(fmt, *args)

    def _send_redirect(self, url: str) -> None:
        """Send a 302 redirect to *url*."""
        self.send_response(302)
        self.send_header("Location", url)
        self.end_headers()

    def _send_html(self, body: str) -> None:
        """Send *body* as a 200 response with a UTF-8 HTML content type."""
        encoded = body.encode()
        self.send_response(200)
        self.send_header("Content-Type", "text/html; charset=utf-8")
        self.send_header("Content-Length", str(len(encoded)))
        self.end_headers()
        self.wfile.write(encoded)

    def _shutdown(self) -> None:
        """Ask the server to shut down from a separate daemon thread."""
        # shutdown() is run on its own thread rather than on the thread
        # handling this request.
        threading.Thread(target=self.server.shutdown, daemon=True).start()

    def _shutdown_after_delay(self, seconds: float = 2.0) -> None:
        """Shut the server down after *seconds*, without blocking the request."""

        def _later() -> None:
            try:
                time.sleep(seconds)
            finally:
                self._shutdown()

        threading.Thread(target=_later, daemon=True).start()

    def _exchange_code(self, code: str) -> Tuple[AuthBundle, str]:
        """Delegate the authorization-code exchange to the server."""
        return self.server.exchange_code(code)

    def _maybe_obtain_api_key(
        self,
        token_claims: Dict[str, Any],
        access_claims: Dict[str, Any],
        token_data: TokenData,
    ) -> Tuple[str | None, str | None]:
        """Try to trade the ID token for an API key.

        NOTE(review): nothing in this module calls this method; it mirrors
        ``OAuthHTTPServer.maybe_obtain_api_key`` but builds a success URL
        without tokens. Kept for compatibility.

        Returns:
            ``(api_key, success_url)``; ``api_key`` is ``None`` when the ID
            token has no ``organization_id``/``project_id``.
        """
        org_id = token_claims.get("organization_id")
        project_id = token_claims.get("project_id")
        if not org_id or not project_id:
            query = {
                "id_token": token_data.id_token,
                "needs_setup": "false",
                "org_id": org_id or "",
                "project_id": project_id or "",
                "plan_type": access_claims.get("chatgpt_plan_type"),
                "platform_url": "https://platform.openai.com",
            }
            return None, f"{URL_BASE}/success?{urllib.parse.urlencode(query)}"

        today = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d")
        exchange_data = urllib.parse.urlencode(
            {
                "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
                "client_id": self.server.client_id,
                "requested_token": "openai-api-key",
                "subject_token": token_data.id_token,
                "subject_token_type": "urn:ietf:params:oauth:token-type:id_token",
                "name": f"ChatMock [auto-generated] ({today})",
            }
        ).encode()
        request = urllib.request.Request(
            self.server.token_endpoint,
            data=exchange_data,
            method="POST",
            headers={"Content-Type": "application/x-www-form-urlencoded"},
        )
        with urllib.request.urlopen(
            request,
            timeout=self._TOKEN_REQUEST_TIMEOUT,
            context=_SSL_CONTEXT,
        ) as resp:
            exchange_payload = json.loads(resp.read().decode())
        exchanged_access_token = exchange_payload.get("access_token")

        chatgpt_plan_type = access_claims.get("chatgpt_plan_type")
        success_url_query = {
            "id_token": token_data.id_token,
            "needs_setup": "false",
            "org_id": org_id,
            "project_id": project_id,
            "plan_type": chatgpt_plan_type,
            "platform_url": "https://platform.openai.com",
        }
        success_url = f"{URL_BASE}/success?{urllib.parse.urlencode(success_url_query)}"
        return exchanged_access_token, success_url