"""OpenAI Responses-API compatibility layer for the /v1 developer API. Pure translation between ml-intern's internal session events and an OpenAI-Responses-style wire format: SSE event names, the response object, terminal-status derivation, and artifact extraction. No I/O here — routes in routes/v1_responses.py own all store/session access. Internal events arrive either as live broadcaster messages (``{"event_type", "data", "seq"}``) or persisted Mongo docs (same keys plus ``_id``/``created_at``) — both shapes are accepted everywhere below. """ import json import re import uuid from datetime import UTC, datetime from typing import Any RESPONSE_ID_PREFIX = "resp_" MAX_TOOL_OUTPUT_CHARS = 4096 # Internal terminal event → response status. ``approval_required`` is handled # separately: it pauses the response (resumable) rather than ending it. HARD_TERMINAL_STATUS: dict[str, str] = { "turn_complete": "completed", "error": "failed", "interrupted": "cancelled", "shutdown": "failed", } APPROVAL_EVENT = "approval_required" TERMINAL_RESPONSE_STATUSES = {"completed", "failed", "cancelled"} # HF job URLs look like https://huggingface.co/jobs// _JOB_URL_ID_RE = re.compile(r"/jobs/[^/?#]+/([^/?#]+)/?(?:[?#]|$)") _HF_DATASET_URL_RE = re.compile( r"https://huggingface\.co/datasets/" r"([A-Za-z0-9][\w.-]*/[A-Za-z0-9][\w.-]*)(?:[/?#]|$)" ) _HF_SPACE_URL_RE = re.compile( r"https://huggingface\.co/spaces/" r"([A-Za-z0-9][\w.-]*/[A-Za-z0-9][\w.-]*)(?:[/?#]|$)" ) _HF_MODEL_URL_RE = re.compile( r"https://huggingface\.co/" r"(?!(?:datasets|spaces|jobs|papers|collections)/)" r"([A-Za-z0-9][\w.-]*/[A-Za-z0-9][\w.-]*)(?:[/?#]|$)" ) class V1APIError(Exception): """OpenAI-shaped API error: ``{"error": {message, type, code}}``.""" def __init__( self, status_code: int, message: str, *, code: str | None = None, error_type: str = "invalid_request_error", ) -> None: super().__init__(message) self.status_code = status_code self.message = message self.code = code self.error_type = error_type def body(self) -> dict[str, Any]: return { "error": { "message": self.message, "type": self.error_type, "code": self.code, } } def new_response_id() -> str: return f"{RESPONSE_ID_PREFIX}{uuid.uuid4().hex}" def _timestamp(value: Any) -> int | None: if isinstance(value, datetime): if value.tzinfo is None: value = value.replace(tzinfo=UTC) return int(value.timestamp()) if isinstance(value, (int, float)): return int(value) return None # --------------------------------------------------------------------------- # Artifacts # --------------------------------------------------------------------------- def _job_id_from_url(url: str) -> str | None: match = _JOB_URL_ID_RE.search(url or "") return match.group(1) if match else None def _repo_artifacts_from_text(text: str | None) -> list[dict[str, Any]]: """Extract Hub repo links from assistant/tool text as best-effort artifacts.""" if not isinstance(text, str) or "https://huggingface.co/" not in text: return [] artifacts: list[dict[str, Any]] = [] for repo_type, pattern, prefix in ( ("dataset", _HF_DATASET_URL_RE, "datasets/"), ("space", _HF_SPACE_URL_RE, "spaces/"), ("model", _HF_MODEL_URL_RE, ""), ): for match in pattern.finditer(text): repo_id = match.group(1) artifacts.append( { "type": repo_type, "repo_id": repo_id, "url": f"https://huggingface.co/{prefix}{repo_id}", } ) return merge_artifacts([], artifacts) def artifact_key(artifact: dict[str, Any]) -> str: ident = ( artifact.get("id") or artifact.get("repo_id") or artifact.get("space_id") or artifact.get("slug") or artifact.get("url") or "" ) return f"{artifact.get('type')}:{ident}" def artifacts_from_event(event_type: str, data: dict[str, Any]) -> list[dict[str, Any]]: """Structured artifacts carried by a single internal event, if any.""" artifacts: list[dict[str, Any]] = [] data = data or {} if event_type == "tool_state_change": job_url = data.get("jobUrl") if isinstance(job_url, str) and job_url: artifacts.append( { "type": "hf_job", "id": _job_id_from_url(job_url), "url": job_url, } ) space_id = data.get("trackioSpaceId") if isinstance(space_id, str) and space_id: artifact: dict[str, Any] = { "type": "trackio_dashboard", "space_id": space_id, "url": f"https://huggingface.co/spaces/{space_id}", } project = data.get("trackioProject") if isinstance(project, str) and project: artifact["project"] = project artifacts.append(artifact) elif event_type == "hub_artifact": repo_id = data.get("repo_id") repo_type = data.get("repo_type") or "model" if isinstance(repo_id, str) and repo_id: prefix = {"model": "", "dataset": "datasets/", "space": "spaces/"}.get( repo_type, "" ) artifacts.append( { "type": repo_type, "repo_id": repo_id, "url": f"https://huggingface.co/{prefix}{repo_id}", } ) for key in ("content", "final_response", "output"): artifacts.extend(_repo_artifacts_from_text(data.get(key))) return artifacts def merge_artifacts( existing: list[dict[str, Any]] | None, new: list[dict[str, Any]] | None, ) -> list[dict[str, Any]]: merged: list[dict[str, Any]] = [] seen: set[str] = set() for artifact in [*(existing or []), *(new or [])]: if not isinstance(artifact, dict) or not artifact.get("type"): continue key = artifact_key(artifact) if key in seen: continue seen.add(key) merged.append(artifact) return merged def extract_artifacts(events: list[dict[str, Any]]) -> list[dict[str, Any]]: found: list[dict[str, Any]] = [] for event in events or []: found.extend( artifacts_from_event( str(event.get("event_type") or ""), event.get("data") or {} ) ) return merge_artifacts([], found) # --------------------------------------------------------------------------- # Turn state (event-sourced status derivation) # --------------------------------------------------------------------------- def derive_turn_state(events: list[dict[str, Any]]) -> dict[str, Any] | None: """Derive the response status from a turn's event slice. Returns None when no verdict can be drawn from the events alone (turn still running — or the server restarted mid-run; callers disambiguate by checking the live session). ``approval_required`` only counts as a pause if nothing ran after it (a resumed turn keeps appending events). """ last_event: dict[str, Any] | None = None for event in events or []: event_type = str(event.get("event_type") or "") status = HARD_TERMINAL_STATUS.get(event_type) if status is not None: data = event.get("data") or {} error = None if status == "failed": error = { "code": ( "session_shutdown" if event_type == "shutdown" else "agent_error" ), "message": str(data.get("error") or event_type), } return { "status": status, "terminal_event_type": event_type, "end_seq": event.get("seq"), "error": error, "incomplete_details": None, "final_response": data.get("final_response"), } last_event = event if last_event and str(last_event.get("event_type")) == APPROVAL_EVENT: return { "status": "incomplete", "terminal_event_type": APPROVAL_EVENT, "end_seq": None, "error": None, "incomplete_details": { "reason": "approval_required", "approval": last_event.get("data") or {}, }, "final_response": None, } return None # --------------------------------------------------------------------------- # Output reconstruction # --------------------------------------------------------------------------- def _truncate(text: str, limit: int = MAX_TOOL_OUTPUT_CHARS) -> str: if len(text) <= limit: return text return text[:limit] + f"… [truncated {len(text) - limit} chars]" def build_output_items(events: list[dict[str, Any]]) -> list[dict[str, Any]]: """Rebuild OpenAI-style output items from a turn's event slice.""" items: list[dict[str, Any]] = [] tool_items: dict[str, dict[str, Any]] = {} text_parts: list[str] = [] final_response: str | None = None def flush_text() -> None: nonlocal text_parts text = "".join(text_parts) text_parts = [] if text.strip(): items.append( { "type": "message", "id": f"msg_{len(items)}", "role": "assistant", "status": "completed", "content": [{"type": "output_text", "text": text}], } ) for event in events or []: event_type = str(event.get("event_type") or "") data = event.get("data") or {} if event_type == "assistant_chunk": content = data.get("content") if isinstance(content, str): text_parts.append(content) elif event_type == "assistant_message": # Non-streamed full message replaces any partial chunk state. content = data.get("content") if isinstance(content, str): text_parts = [content] flush_text() elif event_type == "assistant_stream_end": flush_text() elif event_type == "tool_call": flush_text() tool_call_id = str(data.get("tool_call_id") or f"call_{len(items)}") item = { "type": "custom_tool_call", "id": tool_call_id, "name": data.get("tool"), "input": json.dumps(data.get("arguments") or {}), "output": None, "status": "in_progress", } tool_items[tool_call_id] = item items.append(item) elif event_type == "tool_output": tool_call_id = str(data.get("tool_call_id") or "") item = tool_items.get(tool_call_id) output = data.get("output") if item is not None: item["output"] = _truncate(str(output)) if output is not None else None item["status"] = ( "completed" if data.get("success", True) else "incomplete" ) elif event_type == "turn_complete": response_text = data.get("final_response") if isinstance(response_text, str): final_response = response_text flush_text() # turn_complete.final_response is authoritative for the last message: if # the chunk-reconstructed tail diverges (or is missing), replace it. if final_response and final_response.strip(): last = items[-1] if items else None if ( last is not None and last.get("type") == "message" and last["content"][0]["text"].strip() == final_response.strip() ): pass else: items.append( { "type": "message", "id": f"msg_{len(items)}", "role": "assistant", "status": "completed", "content": [{"type": "output_text", "text": final_response}], } ) return items # --------------------------------------------------------------------------- # Response object # --------------------------------------------------------------------------- def build_response_object( doc: dict[str, Any], *, output: list[dict[str, Any]] | None = None, artifacts: list[dict[str, Any]] | None = None, usage: dict[str, Any] | None = None, ) -> dict[str, Any]: return { "id": doc.get("_id"), "object": "response", "created_at": _timestamp(doc.get("created_at")), "completed_at": _timestamp(doc.get("completed_at")), "status": doc.get("status") or "queued", "model": doc.get("model"), "background": bool(doc.get("background")), "previous_response_id": doc.get("previous_response_id"), "session_id": doc.get("session_id"), "max_cost_usd": doc.get("max_cost_usd"), "instructions": doc.get("instructions"), "output": output if output is not None else (doc.get("output") or []), "error": doc.get("error"), "incomplete_details": doc.get("incomplete_details"), "usage": usage if usage is not None else doc.get("usage"), "artifacts": ( artifacts if artifacts is not None else (doc.get("artifacts") or []) ), "metadata": doc.get("metadata") or {}, } # --------------------------------------------------------------------------- # SSE translation # --------------------------------------------------------------------------- _DROPPED_EVENTS = { "ready", "compacted", "new_complete", "resume_complete", "undo_complete", "session_terminated", } def translate_event( msg: dict[str, Any], response_id: str ) -> list[tuple[str, dict[str, Any]]]: """Translate one internal event into zero or more (name, payload) SSE frames. Terminal frames (``response.completed`` etc.) carry only a light payload here; the streaming routes merge in the full response snapshot they have been accumulating. """ event_type = str(msg.get("event_type") or "") data = msg.get("data") or {} seq = msg.get("seq") base: dict[str, Any] = {"response_id": response_id} if seq is not None: base["sequence_number"] = seq if event_type in _DROPPED_EVENTS: return [] if event_type == "processing": return [("response.in_progress", base)] if event_type == "assistant_chunk": return [("response.output_text.delta", {**base, "delta": data.get("content")})] if event_type == "assistant_message": return [("response.output_text.done", {**base, "text": data.get("content")})] if event_type == "assistant_stream_end": return [("response.output_text.done", base)] if event_type == "tool_call": return [ ( "response.output_item.added", { **base, "item": { "type": "custom_tool_call", "id": data.get("tool_call_id"), "name": data.get("tool"), "input": json.dumps(data.get("arguments") or {}), "status": "in_progress", }, }, ) ] if event_type == "tool_output": output = data.get("output") return [ ( "response.output_item.done", { **base, "item": { "type": "custom_tool_call", "id": data.get("tool_call_id"), "name": data.get("tool"), "output": ( _truncate(str(output)) if output is not None else None ), "status": ( "completed" if data.get("success", True) else "incomplete" ), }, }, ) ] if event_type == "tool_log": return [("response.tool_log", {**base, **data})] if event_type == "tool_state_change": frames: list[tuple[str, dict[str, Any]]] = [ ("response.tool_state.changed", {**base, **data}) ] for artifact in artifacts_from_event(event_type, data): frames.append(("response.artifact.created", {**base, "artifact": artifact})) return frames if event_type == "hub_artifact": return [ ("response.artifact.created", {**base, "artifact": artifact}) for artifact in artifacts_from_event(event_type, data) ] if event_type == APPROVAL_EVENT: return [("response.approval_required", {**base, **data})] if event_type == "turn_complete": return [ ( "response.completed", {**base, "final_response": data.get("final_response")}, ) ] if event_type == "error": return [("response.failed", {**base, "error": data.get("error")})] if event_type == "interrupted": return [("response.cancelled", base)] if event_type == "shutdown": return [ ( "response.failed", {**base, "error": {"code": "session_shutdown"}}, ) ] # Unknown event types pass through under a generic name so new internal # events are visible to API consumers without a translator change. return [(f"response.{event_type}", {**base, **data})] def format_v1_sse(event_name: str, payload: dict[str, Any]) -> str: seq = payload.get("sequence_number") body = json.dumps({"type": event_name, **payload}) if isinstance(seq, int): return f"id: {seq}\nevent: {event_name}\ndata: {body}\n\n" return f"event: {event_name}\ndata: {body}\n\n"