File size: 8,613 Bytes
35205e8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 | from __future__ import annotations
import json
from dataclasses import dataclass
from typing import Any, Dict, Iterable, Iterator, List
from .config import BASE_INSTRUCTIONS, GPT5_CODEX_INSTRUCTIONS
from .fast_mode import ServiceTierResolution, resolve_service_tier
from .model_registry import (
allowed_efforts_for_model,
extract_reasoning_from_model_name,
normalize_model_name,
uses_codex_instructions,
)
from .reasoning import build_reasoning_param
from .session import ensure_session_id
@dataclass(unsafe_hash=True)
class ResponsesRequestError(Exception):
    """Error raised when a Responses request cannot be accepted.

    The class is deliberately *not* a frozen dataclass: a frozen dataclass
    rejects every attribute assignment, including the ones the interpreter
    and standard library perform on exception objects (for example
    ``exc.__traceback__ = tb`` or ``exc.add_note(...)``), which turns the
    original error into a ``FrozenInstanceError``.  ``unsafe_hash=True``
    keeps the field-based equality and hashing the frozen variant provided.

    Attributes:
        message: Human-readable description; also returned by ``str(exc)``.
        status_code: Numeric status associated with the error (default 400).
        code: Optional machine-readable error code.
    """

    message: str
    status_code: int = 400
    code: str | None = None

    def __post_init__(self) -> None:
        # Mirror the fields into ``args`` so the exception can be rebuilt
        # from ``args`` (copy / pickle) no matter whether it was constructed
        # with positional or keyword arguments.
        super().__init__(self.message, self.status_code, self.code)

    def __str__(self) -> str:
        return self.message
@dataclass(frozen=True)
class NormalizedResponsesRequest:
    """Immutable bundle describing a normalised Responses request."""

    # Request body after normalisation.
    payload: Dict[str, Any]
    # Model name exactly as the client sent it; None when no string was given.
    requested_model: str | None
    # Model name after normalisation.
    normalized_model: str
    # Session identifier associated with this request.
    session_id: str
    # Outcome of the service-tier resolution for this request.
    service_tier_resolution: ServiceTierResolution
def instructions_for_model(config: Dict[str, Any], model: str) -> str:
    """Choose the instructions text to use for ``model``.

    Args:
        config: Settings mapping; ``BASE_INSTRUCTIONS`` and
            ``GPT5_CODEX_INSTRUCTIONS`` entries override the module defaults.
        model: Model name used to decide whether codex instructions apply.

    Returns:
        The codex instructions when ``uses_codex_instructions(model)`` is
        truthy and a non-blank string is available, otherwise the base
        instructions.
    """
    default_prompt = config.get("BASE_INSTRUCTIONS", BASE_INSTRUCTIONS)
    if not uses_codex_instructions(model):
        return default_prompt

    # A missing or falsy config entry falls back to the module constant.
    codex_prompt = config.get("GPT5_CODEX_INSTRUCTIONS") or GPT5_CODEX_INSTRUCTIONS
    usable = isinstance(codex_prompt, str) and bool(codex_prompt.strip())
    return codex_prompt if usable else default_prompt
def extract_client_session_id(headers: Any) -> str | None:
    """Return the client-supplied session id found in ``headers``, if any.

    ``X-Session-Id`` is consulted first, then ``session_id``; the first
    truthy value wins.  Any failure while reading ``headers`` (for example
    an object without ``get``) yields ``None`` instead of raising.
    """
    try:
        for header_name in ("X-Session-Id", "session_id"):
            candidate = headers.get(header_name)
            if candidate:
                return candidate
    except Exception:
        return None
    return None
def _input_items_for_session(raw_input: Any) -> List[Dict[str, Any]]:
    """Coerce a request ``input`` value into a list of item dicts.

    * a dict is wrapped in a one-element list;
    * a list keeps only its dict entries;
    * a non-blank string becomes a single user message holding that text;
    * anything else (including a blank string) produces an empty list.
    """
    if isinstance(raw_input, dict):
        return [raw_input]
    if isinstance(raw_input, list):
        return [entry for entry in raw_input if isinstance(entry, dict)]
    if not isinstance(raw_input, str) or not raw_input.strip():
        return []

    # The original text is kept verbatim; only the emptiness check strips it.
    text_part = {"type": "input_text", "text": raw_input}
    return [{"type": "message", "role": "user", "content": [text_part]}]
def canonicalize_responses_input(raw_input: Any) -> Any:
    """Bring a request ``input`` value into list-of-items form where possible.

    Strings are converted through ``_input_items_for_session``, a dict is
    wrapped in a list, and a list is reduced to its dict entries.  Values of
    any other type are handed back unchanged.
    """
    if isinstance(raw_input, str):
        return _input_items_for_session(raw_input)
    if isinstance(raw_input, dict):
        return [raw_input]
    if not isinstance(raw_input, list):
        return raw_input
    return [entry for entry in raw_input if isinstance(entry, dict)]
def normalize_responses_payload(
    payload: Dict[str, Any],
    *,
    config: Dict[str, Any],
    client_session_id: str | None = None,
) -> NormalizedResponsesRequest:
    """Build the normalised request body for an incoming Responses payload.

    Works on a shallow copy of ``payload``; top-level keys of the caller's
    dict are not modified.

    Args:
        payload: Request body as received from the client.
        config: Settings mapping read for ``DEBUG_MODEL``,
            ``REASONING_EFFORT``, ``REASONING_SUMMARY``,
            ``DEFAULT_WEB_SEARCH``, ``FAST_MODE`` and the instruction
            overrides.
        client_session_id: Optional session id supplied by the client,
            forwarded to ``ensure_session_id``.

    Returns:
        A ``NormalizedResponsesRequest`` holding the rewritten payload, the
        requested and normalised model names, the session id and the
        service-tier resolution.

    Raises:
        ResponsesRequestError: When the service-tier resolution reports an
            error message.
    """
    raw_model = payload.get("model")
    requested_model = raw_model if isinstance(raw_model, str) else None
    normalized_model = normalize_model_name(requested_model, config.get("DEBUG_MODEL"))

    body = dict(payload)
    body["model"] = normalized_model
    body.pop("max_output_tokens", None)
    if "input" in body:
        body["input"] = canonicalize_responses_input(body["input"])
    body.setdefault("store", False)

    # Missing, non-string or blank instructions are replaced by the defaults
    # chosen for the normalised model.
    instructions = body.get("instructions")
    has_instructions = isinstance(instructions, str) and bool(instructions.strip())
    if not has_instructions:
        instructions = instructions_for_model(config, normalized_model)
    body["instructions"] = instructions

    default_effort = config.get("REASONING_EFFORT", "medium")
    default_summary = config.get("REASONING_SUMMARY", "auto")
    explicit_reasoning = body.get("reasoning")
    if isinstance(explicit_reasoning, dict):
        reasoning_overrides = explicit_reasoning
    else:
        # No usable "reasoning" object: derive overrides from the model name.
        reasoning_overrides = extract_reasoning_from_model_name(requested_model)
    body["reasoning"] = build_reasoning_param(
        default_effort,
        default_summary,
        reasoning_overrides,
        allowed_efforts=allowed_efforts_for_model(normalized_model),
    )

    # Keep only string entries and always request encrypted reasoning content.
    requested_include = body.get("include")
    include_list = []
    if isinstance(requested_include, list):
        include_list = [entry for entry in requested_include if isinstance(entry, str)]
    if "reasoning.encrypted_content" not in include_list:
        include_list.append("reasoning.encrypted_content")
    body["include"] = include_list

    # Inject the default web-search tool only when the request brings no
    # tools of its own and has not set tool_choice to "none".
    tools = body.get("tools")
    has_tools = isinstance(tools, list) and bool(tools)
    if not has_tools and bool(config.get("DEFAULT_WEB_SEARCH")):
        tool_choice = body.get("tool_choice")
        opted_out = isinstance(tool_choice, str) and tool_choice.strip().lower() == "none"
        if not opted_out:
            body["tools"] = [{"type": "web_search"}]

    tier_resolution = resolve_service_tier(
        normalized_model,
        request_fast_mode=body.get("fast_mode"),
        request_service_tier=body.get("service_tier"),
        server_fast_mode=bool(config.get("FAST_MODE")),
    )
    if tier_resolution.error_message:
        raise ResponsesRequestError(tier_resolution.error_message)
    if tier_resolution.service_tier is not None:
        body["service_tier"] = tier_resolution.service_tier
    else:
        body.pop("service_tier", None)
    # "fast_mode" is consumed by the resolution above and removed from the body.
    body.pop("fast_mode", None)

    session_id = ensure_session_id(
        instructions,
        _input_items_for_session(body.get("input")),
        client_session_id,
    )
    cache_key = body.get("prompt_cache_key")
    if not (isinstance(cache_key, str) and cache_key.strip()):
        body["prompt_cache_key"] = session_id

    return NormalizedResponsesRequest(
        payload=body,
        requested_model=requested_model,
        normalized_model=normalized_model,
        session_id=session_id,
        service_tier_resolution=tier_resolution,
    )
def iter_sse_event_payloads(upstream: Any) -> Iterator[Dict[str, Any]]:
    """Yield the JSON-object payloads of ``data`` lines in an SSE stream.

    Args:
        upstream: Object exposing ``iter_lines(decode_unicode=False)`` that
            produces the lines of the stream as ``bytes``/``bytearray`` or
            ``str``.

    Yields:
        Each ``data`` payload that parses as a JSON object.  Empty lines,
        non-``data`` lines, empty payloads, malformed JSON and JSON values
        that are not objects are skipped.  Iteration stops at the ``[DONE]``
        sentinel.
    """
    for raw in upstream.iter_lines(decode_unicode=False):
        if not raw:
            continue
        line = raw.decode("utf-8", errors="ignore") if isinstance(raw, (bytes, bytearray)) else raw
        # The space after the colon is optional in the event-stream format,
        # so both "data: {...}" and "data:{...}" must be accepted.
        if not line.startswith("data:"):
            continue
        data = line[len("data:"):].strip()
        if data == "[DONE]":
            break
        if not data:
            continue
        try:
            evt = json.loads(data)
        except (ValueError, RecursionError):
            # Malformed (or absurdly nested) payloads are skipped, not fatal.
            continue
        if isinstance(evt, dict):
            yield evt
def aggregate_response_from_sse(
    upstream: Any,
    *,
    on_event: Any | None = None,
) -> tuple[Dict[str, Any] | None, Dict[str, Any] | None]:
    """Consume an SSE stream and reduce it to a final response or an error.

    Args:
        upstream: Stream handed to ``iter_sse_event_payloads``; it is always
            closed before this function returns.
        on_event: Optional callable invoked with every event dict.
            Exceptions it raises are ignored.

    Returns:
        ``(response, error)``: ``response`` is the most recent ``"response"``
        dict seen (or ``None``); ``error`` is ``None`` unless a
        ``response.failed`` event was seen, in which case it is
        ``{"error": ...}``.
    """
    latest_response: Dict[str, Any] | None = None
    failure: Dict[str, Any] | None = None
    notify = on_event if callable(on_event) else None
    try:
        for event in iter_sse_event_payloads(upstream):
            if notify is not None:
                try:
                    notify(event)
                except Exception:
                    # Observer failures must not interrupt aggregation.
                    pass

            snapshot = event.get("response")
            snapshot_is_dict = isinstance(snapshot, dict)
            if snapshot_is_dict:
                latest_response = snapshot

            event_type = event.get("type")
            if event_type == "response.completed":
                break
            if event_type != "response.failed":
                continue

            # Terminal failure: surface the upstream error object when there
            # is one, otherwise a generic placeholder.
            upstream_error = snapshot.get("error") if snapshot_is_dict else None
            if isinstance(upstream_error, dict):
                failure = {"error": upstream_error}
            else:
                failure = {"error": {"message": "response.failed"}}
            break
    finally:
        upstream.close()
    return latest_response, failure
def _sse_data_event(line: bytes) -> Dict[str, Any] | None:
    """Decode one raw SSE line into its JSON-object payload, else ``None``."""
    line = line.rstrip(b"\r")
    # The space after "data:" is optional in the event-stream format.
    if not line.startswith(b"data:"):
        return None
    data = line[len(b"data:"):].strip()
    if not data or data == b"[DONE]":
        return None
    try:
        evt = json.loads(data.decode("utf-8", errors="ignore"))
    except (ValueError, RecursionError):
        return None
    return evt if isinstance(evt, dict) else None


def stream_upstream_bytes(
    upstream: Any,
    *,
    on_event: Any | None = None,
) -> Iterable[bytes]:
    """Relay upstream chunks unchanged while optionally observing SSE events.

    Args:
        upstream: Object exposing ``iter_content(chunk_size=None)`` and
            ``close()``.  It is closed when the generator finishes or is
            closed by the consumer.
        on_event: Optional callable invoked with every JSON-object ``data``
            payload found in the stream.  Exceptions it raises are ignored
            so that observing never interrupts the relay.

    Yields:
        Every non-empty chunk exactly as produced by ``upstream``.
    """
    observe = callable(on_event)
    pending = b""
    try:
        for chunk in upstream.iter_content(chunk_size=None):
            if not chunk:
                continue
            if observe:
                if isinstance(chunk, (bytes, bytearray)):
                    pending += bytes(chunk)
                else:
                    pending += str(chunk).encode("utf-8", errors="ignore")
                # Everything before the last newline is a complete line; the
                # remainder stays buffered until later chunks finish it.
                *complete_lines, pending = pending.split(b"\n")
                for line in complete_lines:
                    evt = _sse_data_event(line)
                    if evt is None:
                        continue
                    try:
                        on_event(evt)
                    except Exception:
                        # Best effort: a failing observer must not break the
                        # byte stream being relayed.
                        pass
            yield chunk
    finally:
        upstream.close()
|