lukasgross's picture
Add unmanged sample with a minimal backend (#98)
69f2236
"""FastAPI entrypoint for exchanging workflow ids for ChatKit client secrets."""
from __future__ import annotations
import json
import os
import uuid
from typing import Any, Mapping
import httpx
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
DEFAULT_CHATKIT_BASE = "https://api.openai.com"
SESSION_COOKIE_NAME = "chatkit_session_id"
SESSION_COOKIE_MAX_AGE_SECONDS = 60 * 60 * 24 * 30 # 30 days
app = FastAPI(title="Managed ChatKit Session API")
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/health")
async def health() -> Mapping[str, str]:
return {"status": "ok"}
@app.post("/api/create-session")
async def create_session(request: Request) -> JSONResponse:
"""Exchange a workflow id for a ChatKit client secret."""
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
return respond({"error": "Missing OPENAI_API_KEY environment variable"}, 500)
body = await read_json_body(request)
workflow_id = resolve_workflow_id(body)
if not workflow_id:
return respond({"error": "Missing workflow id"}, 400)
user_id, cookie_value = resolve_user(request.cookies)
api_base = chatkit_api_base()
try:
async with httpx.AsyncClient(base_url=api_base, timeout=10.0) as client:
upstream = await client.post(
"/v1/chatkit/sessions",
headers={
"Authorization": f"Bearer {api_key}",
"OpenAI-Beta": "chatkit_beta=v1",
"Content-Type": "application/json",
},
json={"workflow": {"id": workflow_id}, "user": user_id},
)
except httpx.RequestError as error:
return respond(
{"error": f"Failed to reach ChatKit API: {error}"},
502,
cookie_value,
)
payload = parse_json(upstream)
if not upstream.is_success:
message = None
if isinstance(payload, Mapping):
message = payload.get("error")
message = message or upstream.reason_phrase or "Failed to create session"
return respond({"error": message}, upstream.status_code, cookie_value)
client_secret = None
expires_after = None
if isinstance(payload, Mapping):
client_secret = payload.get("client_secret")
expires_after = payload.get("expires_after")
if not client_secret:
return respond(
{"error": "Missing client secret in response"},
502,
cookie_value,
)
return respond(
{"client_secret": client_secret, "expires_after": expires_after},
200,
cookie_value,
)
def respond(
payload: Mapping[str, Any], status_code: int, cookie_value: str | None = None
) -> JSONResponse:
response = JSONResponse(payload, status_code=status_code)
if cookie_value:
response.set_cookie(
key=SESSION_COOKIE_NAME,
value=cookie_value,
max_age=SESSION_COOKIE_MAX_AGE_SECONDS,
httponly=True,
samesite="lax",
secure=is_prod(),
path="/",
)
return response
def is_prod() -> bool:
env = (os.getenv("ENVIRONMENT") or os.getenv("NODE_ENV") or "").lower()
return env == "production"
async def read_json_body(request: Request) -> Mapping[str, Any]:
raw = await request.body()
if not raw:
return {}
try:
parsed = json.loads(raw)
except json.JSONDecodeError:
return {}
return parsed if isinstance(parsed, Mapping) else {}
def resolve_workflow_id(body: Mapping[str, Any]) -> str | None:
workflow = body.get("workflow", {})
workflow_id = None
if isinstance(workflow, Mapping):
workflow_id = workflow.get("id")
workflow_id = workflow_id or body.get("workflowId")
env_workflow = os.getenv("CHATKIT_WORKFLOW_ID") or os.getenv(
"VITE_CHATKIT_WORKFLOW_ID"
)
if not workflow_id and env_workflow:
workflow_id = env_workflow
if workflow_id and isinstance(workflow_id, str) and workflow_id.strip():
return workflow_id.strip()
return None
def resolve_user(cookies: Mapping[str, str]) -> tuple[str, str | None]:
existing = cookies.get(SESSION_COOKIE_NAME)
if existing:
return existing, None
user_id = str(uuid.uuid4())
return user_id, user_id
def chatkit_api_base() -> str:
return (
os.getenv("CHATKIT_API_BASE")
or os.getenv("VITE_CHATKIT_API_BASE")
or DEFAULT_CHATKIT_BASE
)
def parse_json(response: httpx.Response) -> Mapping[str, Any]:
try:
parsed = response.json()
return parsed if isinstance(parsed, Mapping) else {}
except (json.JSONDecodeError, httpx.DecodingError):
return {}