| from __future__ import annotations |
|
|
| import json |
| from dataclasses import dataclass |
| from typing import Any, Dict, Iterable, Iterator, List |
|
|
| from .config import BASE_INSTRUCTIONS, GPT5_CODEX_INSTRUCTIONS |
| from .fast_mode import ServiceTierResolution, resolve_service_tier |
| from .model_registry import ( |
| allowed_efforts_for_model, |
| extract_reasoning_from_model_name, |
| normalize_model_name, |
| uses_codex_instructions, |
| ) |
| from .reasoning import build_reasoning_param |
| from .session import ensure_session_id |
|
|
|
|
@dataclass(unsafe_hash=True)
class ResponsesRequestError(Exception):
    """Error raised when a Responses request cannot be normalized.

    Attributes:
        message: Human-readable description; this is what ``str(error)``
            returns.
        status_code: Status code to report for the failure (defaults to 400;
            presumably an HTTP status -- confirm against the caller).
        code: Optional machine-readable error code.

    The class is intentionally *not* a frozen dataclass. ``frozen=True``
    installs a ``__setattr__`` that rejects every attribute assignment,
    including the bookkeeping attributes Python code sets on exceptions
    (``__traceback__``, ``__cause__``, ``__notes__``).  That makes a frozen
    exception blow up with ``FrozenInstanceError`` when it passes through
    ``contextlib.contextmanager``, ``add_note`` or unpickling.
    ``unsafe_hash=True`` keeps the field-based ``__eq__``/``__hash__`` that
    the frozen variant provided; treat instances as immutable by convention.
    """

    message: str
    status_code: int = 400
    code: str | None = None

    def __post_init__(self) -> None:
        # The dataclass-generated __init__ never calls Exception.__init__, so
        # ``args`` would stay empty for keyword construction; populate it.
        super().__init__(self.message)

    def __str__(self) -> str:
        return self.message
|
|
|
|
@dataclass(frozen=True)
class NormalizedResponsesRequest:
    """Immutable bundle produced by ``normalize_responses_payload``."""

    # Request body after normalization.
    payload: dict[str, Any]
    # Model string exactly as the client sent it, or None when the request
    # carried no string ``model`` value.
    requested_model: str | None
    # Model name after ``normalize_model_name`` has been applied.
    normalized_model: str
    # Session identifier obtained from ``ensure_session_id``.
    session_id: str
    # Result object returned by ``resolve_service_tier`` for this request.
    service_tier_resolution: ServiceTierResolution
|
|
|
|
def instructions_for_model(config: Dict[str, Any], model: str) -> str:
    """Pick the default instructions text for ``model``.

    Args:
        config: Settings mapping; ``BASE_INSTRUCTIONS`` and
            ``GPT5_CODEX_INSTRUCTIONS`` are consulted.
        model: Model name passed to ``uses_codex_instructions``.

    Returns:
        The codex instructions when ``uses_codex_instructions(model)`` is
        true and a non-blank string is available (from ``config`` or the
        module default); otherwise the base instructions, taken from
        ``config`` when the key is present and from the module default when
        it is not.  The base value is returned as-is, without validation.
    """
    fallback = config.get("BASE_INSTRUCTIONS", BASE_INSTRUCTIONS)
    if not uses_codex_instructions(model):
        return fallback

    override = config.get("GPT5_CODEX_INSTRUCTIONS") or GPT5_CODEX_INSTRUCTIONS
    usable = isinstance(override, str) and bool(override.strip())
    return override if usable else fallback
|
|
|
|
def extract_client_session_id(headers: Any) -> str | None:
    """Return the client-supplied session id from ``headers``, if any.

    ``X-Session-Id`` is consulted first, then ``session_id``; the first
    truthy value wins.  Any failure while reading ``headers`` (for example
    when it has no ``get`` method) is treated as "no session id".
    """
    try:
        for header_name in ("X-Session-Id", "session_id"):
            candidate = headers.get(header_name)
            if candidate:
                return candidate
    except Exception:
        # Best-effort lookup: a malformed headers object must not fail the request.
        return None
    return None
|
|
|
|
def _input_items_for_session(raw_input: Any) -> List[Dict[str, Any]]:
    """Coerce a request ``input`` value into a list of item dicts.

    * dict -> one-element list holding that dict
    * list -> the dict entries of the list, in order (other entries dropped)
    * non-blank str -> a single user ``message`` item wrapping the text
    * anything else (including blank strings) -> empty list
    """
    if isinstance(raw_input, dict):
        return [raw_input]
    if isinstance(raw_input, list):
        return [entry for entry in raw_input if isinstance(entry, dict)]
    if not isinstance(raw_input, str) or not raw_input.strip():
        return []

    # The original text is kept verbatim; strip() above is only a blank check.
    text_part = {"type": "input_text", "text": raw_input}
    return [{"type": "message", "role": "user", "content": [text_part]}]
|
|
|
|
def canonicalize_responses_input(raw_input: Any) -> Any:
    """Bring a request ``input`` value into list-of-items form where possible.

    Lists, dicts and strings are converted by ``_input_items_for_session``
    (a list keeps only its dict entries, a dict is wrapped in a list, a
    non-blank string becomes one user message and a blank string becomes an
    empty list).  Values of any other type are returned unchanged.
    """
    if not isinstance(raw_input, (list, dict, str)):
        return raw_input
    return _input_items_for_session(raw_input)
|
|
|
|
def normalize_responses_payload(
    payload: Dict[str, Any],
    *,
    config: Dict[str, Any],
    client_session_id: str | None = None,
) -> NormalizedResponsesRequest:
    """Build the normalized request body from a client Responses payload.

    The incoming ``payload`` is shallow-copied and the copy is adjusted:

    * ``model`` is replaced by the normalized model name.
    * ``max_output_tokens`` and ``fast_mode`` are removed.
    * ``input`` (when present) is canonicalized.
    * ``store`` defaults to ``False``.
    * blank or missing ``instructions`` fall back to the model defaults.
    * ``reasoning`` is rebuilt from config defaults plus request/model-name
      overrides.
    * ``include`` always contains ``"reasoning.encrypted_content"``.
    * a default web-search tool is added when configured and the request
      has no tools and did not set ``tool_choice`` to ``"none"``.
    * ``service_tier`` is set or removed according to the resolution.
    * ``prompt_cache_key`` defaults to the session id.

    Args:
        payload: Request body as received from the client.
        config: Settings mapping (``DEBUG_MODEL``, ``REASONING_EFFORT``,
            ``REASONING_SUMMARY``, ``DEFAULT_WEB_SEARCH``, ``FAST_MODE`` and
            the instruction keys are read).
        client_session_id: Optional session id supplied by the client,
            forwarded to ``ensure_session_id``.

    Returns:
        A ``NormalizedResponsesRequest`` carrying the adjusted body and the
        values derived along the way.

    Raises:
        ResponsesRequestError: If the service-tier resolution reports an
            error message.
    """
    raw_model = payload.get("model")
    requested_model = raw_model if isinstance(raw_model, str) else None
    normalized_model = normalize_model_name(requested_model, config.get("DEBUG_MODEL"))

    body = dict(payload)
    body["model"] = normalized_model
    body.pop("max_output_tokens", None)

    if "input" in body:
        body["input"] = canonicalize_responses_input(body["input"])

    body.setdefault("store", False)

    # Instructions: keep a non-blank client value, otherwise use the defaults.
    instructions = body.get("instructions")
    has_own_instructions = isinstance(instructions, str) and bool(instructions.strip())
    if not has_own_instructions:
        instructions = instructions_for_model(config, normalized_model)
    body["instructions"] = instructions

    # Reasoning: an explicit dict in the request wins over hints encoded in
    # the requested model name.
    client_reasoning = body.get("reasoning")
    if not isinstance(client_reasoning, dict):
        client_reasoning = extract_reasoning_from_model_name(requested_model)
    body["reasoning"] = build_reasoning_param(
        config.get("REASONING_EFFORT", "medium"),
        config.get("REASONING_SUMMARY", "auto"),
        client_reasoning,
        allowed_efforts=allowed_efforts_for_model(normalized_model),
    )

    # Include: keep the string entries and make sure encrypted reasoning
    # content is requested.
    requested_include = body.get("include")
    include_list: List[str] = []
    if isinstance(requested_include, list):
        include_list = [entry for entry in requested_include if isinstance(entry, str)]
    if "reasoning.encrypted_content" not in include_list:
        include_list.append("reasoning.encrypted_content")
    body["include"] = include_list

    # Tools: optionally inject web search when the request brought none.
    tools = body.get("tools")
    has_tools = isinstance(tools, list) and bool(tools)
    if not has_tools and bool(config.get("DEFAULT_WEB_SEARCH")):
        tool_choice = body.get("tool_choice")
        tools_disabled = isinstance(tool_choice, str) and tool_choice.strip().lower() == "none"
        if not tools_disabled:
            body["tools"] = [{"type": "web_search"}]

    # Service tier: must be resolved before ``fast_mode`` is dropped below.
    tier = resolve_service_tier(
        normalized_model,
        request_fast_mode=body.get("fast_mode"),
        request_service_tier=body.get("service_tier"),
        server_fast_mode=bool(config.get("FAST_MODE")),
    )
    if tier.error_message:
        raise ResponsesRequestError(tier.error_message)
    if tier.service_tier is not None:
        body["service_tier"] = tier.service_tier
    else:
        body.pop("service_tier", None)
    body.pop("fast_mode", None)

    # Session id, reused as the prompt cache key unless the client set one.
    session_id = ensure_session_id(
        instructions,
        _input_items_for_session(body.get("input")),
        client_session_id,
    )
    cache_key = body.get("prompt_cache_key")
    if not (isinstance(cache_key, str) and cache_key.strip()):
        body["prompt_cache_key"] = session_id

    return NormalizedResponsesRequest(
        payload=body,
        requested_model=requested_model,
        normalized_model=normalized_model,
        session_id=session_id,
        service_tier_resolution=tier,
    )
|
|
|
|
def iter_sse_event_payloads(upstream: Any) -> Iterator[Dict[str, Any]]:
    """Yield the JSON object carried by each ``data:`` line of an SSE stream.

    Args:
        upstream: Object exposing ``iter_lines(decode_unicode=False)`` that
            yields the stream line by line as ``bytes``/``bytearray`` or
            ``str`` (looks like a streaming HTTP response -- confirm against
            the caller).

    Yields:
        Every ``data`` payload that parses as a JSON object.  Blank lines,
        non-``data`` fields, unparsable payloads and JSON values that are
        not objects are skipped.

    Iteration ends when the upstream is exhausted or a ``[DONE]`` payload
    is seen.
    """
    prefix = "data:"
    for raw in upstream.iter_lines(decode_unicode=False):
        if not raw:
            continue
        line = raw.decode("utf-8", errors="ignore") if isinstance(raw, (bytes, bytearray)) else raw
        # SSE makes the space after the colon optional, so accept both
        # "data: {...}" and "data:{...}"; strip() removes the space if present.
        if not line.startswith(prefix):
            continue
        data = line[len(prefix):].strip()
        if data == "[DONE]":
            break
        if not data:
            continue
        try:
            event = json.loads(data)
        except (ValueError, RecursionError):
            # Malformed (or pathologically nested) payload: skip, keep streaming.
            continue
        if isinstance(event, dict):
            yield event
|
|
|
|
def aggregate_response_from_sse(
    upstream: Any,
    *,
    on_event: Any | None = None,
) -> tuple[Dict[str, Any] | None, Dict[str, Any] | None]:
    """Consume an SSE stream and return the last response object seen.

    Args:
        upstream: Stream handed to ``iter_sse_event_payloads``; it is always
            closed before this function returns or raises.
        on_event: Optional callable invoked with every event dict.
            Exceptions it raises are ignored.

    Returns:
        ``(response, error)`` where ``response`` is the most recent dict
        found under an event's ``"response"`` key (or ``None``) and
        ``error`` is ``None`` unless a ``response.failed`` event arrived, in
        which case it is ``{"error": ...}`` built from the failed response
        or a generic ``response.failed`` message.

    Reading stops at the first ``response.completed`` or ``response.failed``
    event.
    """
    latest_response: Dict[str, Any] | None = None
    failure: Dict[str, Any] | None = None
    notify = on_event if callable(on_event) else None
    try:
        for event in iter_sse_event_payloads(upstream):
            if notify is not None:
                try:
                    notify(event)
                except Exception:
                    # Observer errors must not abort aggregation.
                    pass

            snapshot = event.get("response")
            has_snapshot = isinstance(snapshot, dict)
            if has_snapshot:
                latest_response = snapshot

            event_type = event.get("type")
            if event_type == "response.completed":
                break
            if event_type == "response.failed":
                error_detail = snapshot.get("error") if has_snapshot else None
                if isinstance(error_detail, dict):
                    failure = {"error": error_detail}
                else:
                    failure = {"error": {"message": "response.failed"}}
                break
    finally:
        upstream.close()
    return latest_response, failure
|
|
|
|
def _dispatch_sse_line(line: bytes, on_event: Any) -> None:
    """Parse one raw SSE line and pass its JSON object, if any, to ``on_event``."""
    line = line.rstrip(b"\r")
    # SSE makes the space after the colon optional; strip() below drops it.
    if not line.startswith(b"data:"):
        return
    data = line[len(b"data:"):].strip()
    if not data or data == b"[DONE]":
        return
    try:
        event = json.loads(data.decode("utf-8", errors="ignore"))
    except (ValueError, RecursionError):
        return
    if not isinstance(event, dict):
        return
    try:
        on_event(event)
    except Exception:
        # Observers are best-effort: a failing callback must never
        # interrupt the bytes being relayed to the client.
        pass


def stream_upstream_bytes(
    upstream: Any,
    *,
    on_event: Any | None = None,
) -> Iterable[bytes]:
    """Relay upstream chunks unchanged while optionally observing SSE events.

    Args:
        upstream: Object exposing ``iter_content(chunk_size=None)`` and
            ``close()`` (looks like a streaming HTTP response -- confirm
            against the caller).
        on_event: Optional callable invoked with each JSON object found on
            a complete ``data:`` line.  Exceptions it raises are ignored.
            When it is not callable, no parsing or buffering happens.

    Yields:
        Every non-empty chunk exactly as produced by ``upstream``.

    ``upstream.close()`` runs when the generator finishes, fails or is
    closed.  A trailing line that is never terminated by a newline is not
    dispatched to ``on_event``.
    """
    observe = callable(on_event)
    pending = b""
    try:
        for chunk in upstream.iter_content(chunk_size=None):
            if not chunk:
                continue
            if observe:
                if isinstance(chunk, (bytes, bytearray, memoryview)):
                    # bytes() keeps the raw content; str() on a bytearray
                    # would buffer its repr instead.
                    pending += bytes(chunk)
                else:
                    pending += str(chunk).encode("utf-8", errors="ignore")
                # One split per chunk; the last piece is the incomplete
                # tail carried over to the next chunk.
                *complete_lines, pending = pending.split(b"\n")
                for line in complete_lines:
                    _dispatch_sse_line(line, on_event)
            yield chunk
    finally:
        upstream.close()
|
|