Spaces:
Running
Running
| """OpenAI Responses-API compatibility layer for the /v1 developer API. | |
| Pure translation between ml-intern's internal session events and an | |
| OpenAI-Responses-style wire format: SSE event names, the response object, | |
| terminal-status derivation, and artifact extraction. No I/O here — routes | |
| in routes/v1_responses.py own all store/session access. | |
| Internal events arrive either as live broadcaster messages | |
| (``{"event_type", "data", "seq"}``) or persisted Mongo docs (same keys plus | |
| ``_id``/``created_at``) — both shapes are accepted everywhere below. | |
| """ | |
| import json | |
| import re | |
| import uuid | |
| from datetime import UTC, datetime | |
| from typing import Any | |
# Prefix of every id produced by ``new_response_id``.
RESPONSE_ID_PREFIX = "resp_"
# Default character cap ``_truncate`` applies to tool output.
MAX_TOOL_OUTPUT_CHARS = 4096

# Internal terminal event → response status. ``approval_required`` is handled
# separately: it pauses the response (resumable) rather than ending it.
HARD_TERMINAL_STATUS: dict[str, str] = {
    "turn_complete": "completed",
    "error": "failed",
    "interrupted": "cancelled",
    "shutdown": "failed",
}
APPROVAL_EVENT = "approval_required"
TERMINAL_RESPONSE_STATUSES = {"completed", "failed", "cancelled"}

# HF job URLs look like https://huggingface.co/jobs/<namespace>/<job_id>
_JOB_URL_ID_RE = re.compile(r"/jobs/[^/?#]+/([^/?#]+)/?(?:[?#]|$)")

# ``<namespace>/<name>`` as it appears inside free text. Each segment must end
# on a word character, so sentence punctuation right after a link ("….co/a/b.")
# is left out of the id, and no particular character is required after the id:
# links followed by whitespace, ")" or "," are matched as well as links followed
# by "/", "?", "#" or end-of-string.
_HF_REPO_ID = r"([A-Za-z0-9](?:[\w.-]*\w)?/[A-Za-z0-9](?:[\w.-]*\w)?)"

_HF_DATASET_URL_RE = re.compile(r"https://huggingface\.co/datasets/" + _HF_REPO_ID)
_HF_SPACE_URL_RE = re.compile(r"https://huggingface\.co/spaces/" + _HF_REPO_ID)
# Model repos live directly under the host, so first-path-segments that are
# other repo kinds or top-level site sections (docs, blog, API, …) rather than
# repo namespaces have to be excluded explicitly.
_HF_MODEL_URL_RE = re.compile(
    r"https://huggingface\.co/"
    r"(?!(?:datasets|spaces|jobs|papers|collections"
    r"|api|blog|docs|learn|posts|tasks|settings|organizations)/)" + _HF_REPO_ID
)
| class V1APIError(Exception): | |
| """OpenAI-shaped API error: ``{"error": {message, type, code}}``.""" | |
| def __init__( | |
| self, | |
| status_code: int, | |
| message: str, | |
| *, | |
| code: str | None = None, | |
| error_type: str = "invalid_request_error", | |
| ) -> None: | |
| super().__init__(message) | |
| self.status_code = status_code | |
| self.message = message | |
| self.code = code | |
| self.error_type = error_type | |
| def body(self) -> dict[str, Any]: | |
| return { | |
| "error": { | |
| "message": self.message, | |
| "type": self.error_type, | |
| "code": self.code, | |
| } | |
| } | |
def new_response_id() -> str:
    """Return a new response id: ``RESPONSE_ID_PREFIX`` plus a random UUID4 in hex."""
    random_part = uuid.uuid4().hex
    return RESPONSE_ID_PREFIX + random_part
| def _timestamp(value: Any) -> int | None: | |
| if isinstance(value, datetime): | |
| if value.tzinfo is None: | |
| value = value.replace(tzinfo=UTC) | |
| return int(value.timestamp()) | |
| if isinstance(value, (int, float)): | |
| return int(value) | |
| return None | |
| # --------------------------------------------------------------------------- | |
| # Artifacts | |
| # --------------------------------------------------------------------------- | |
def _job_id_from_url(url: str) -> str | None:
    """Return the job id captured by ``_JOB_URL_ID_RE`` in ``url``, or ``None``.

    A falsy ``url`` is searched as the empty string, so it yields ``None``.
    """
    found = _JOB_URL_ID_RE.search(url if url else "")
    if found is None:
        return None
    return found.group(1)
| def _repo_artifacts_from_text(text: str | None) -> list[dict[str, Any]]: | |
| """Extract Hub repo links from assistant/tool text as best-effort artifacts.""" | |
| if not isinstance(text, str) or "https://huggingface.co/" not in text: | |
| return [] | |
| artifacts: list[dict[str, Any]] = [] | |
| for repo_type, pattern, prefix in ( | |
| ("dataset", _HF_DATASET_URL_RE, "datasets/"), | |
| ("space", _HF_SPACE_URL_RE, "spaces/"), | |
| ("model", _HF_MODEL_URL_RE, ""), | |
| ): | |
| for match in pattern.finditer(text): | |
| repo_id = match.group(1) | |
| artifacts.append( | |
| { | |
| "type": repo_type, | |
| "repo_id": repo_id, | |
| "url": f"https://huggingface.co/{prefix}{repo_id}", | |
| } | |
| ) | |
| return merge_artifacts([], artifacts) | |
def artifact_key(artifact: dict[str, Any]) -> str:
    """Build the de-duplication key ``"<type>:<identifier>"`` for an artifact.

    The identifier is the first truthy value among ``id``, ``repo_id``,
    ``space_id``, ``slug`` and ``url`` (in that order), or the empty string
    when none is set.
    """
    for field in ("id", "repo_id", "space_id", "slug", "url"):
        identifier = artifact.get(field)
        if identifier:
            break
    else:
        identifier = ""
    return f"{artifact.get('type')}:{identifier}"
def artifacts_from_event(event_type: str, data: dict[str, Any]) -> list[dict[str, Any]]:
    """Collect the artifacts a single internal event carries, if any.

    * ``tool_state_change``: an ``hf_job`` artifact from ``jobUrl`` and a
      ``trackio_dashboard`` artifact from ``trackioSpaceId`` (with the
      optional ``trackioProject``).
    * ``hub_artifact``: one repo artifact from ``repo_id`` / ``repo_type``
      (type defaults to ``"model"``).

    For every event type, the ``content``, ``final_response`` and ``output``
    fields are additionally scanned for Hub repo links. The result is not
    de-duplicated.
    """
    payload = data or {}
    collected: list[dict[str, Any]] = []

    if event_type == "tool_state_change":
        job_url = payload.get("jobUrl")
        if isinstance(job_url, str) and job_url:
            collected.append(
                {"type": "hf_job", "id": _job_id_from_url(job_url), "url": job_url}
            )

        space_id = payload.get("trackioSpaceId")
        if isinstance(space_id, str) and space_id:
            dashboard: dict[str, Any] = {
                "type": "trackio_dashboard",
                "space_id": space_id,
                "url": f"https://huggingface.co/spaces/{space_id}",
            }
            project = payload.get("trackioProject")
            if isinstance(project, str) and project:
                dashboard["project"] = project
            collected.append(dashboard)
    elif event_type == "hub_artifact":
        repo_id = payload.get("repo_id")
        if isinstance(repo_id, str) and repo_id:
            repo_type = payload.get("repo_type") or "model"
            # Unknown repo types fall back to the bare (model-style) URL.
            url_prefixes = {"model": "", "dataset": "datasets/", "space": "spaces/"}
            prefix = url_prefixes.get(repo_type, "")
            collected.append(
                {
                    "type": repo_type,
                    "repo_id": repo_id,
                    "url": f"https://huggingface.co/{prefix}{repo_id}",
                }
            )

    # Free-text fields may mention repos that no structured field reported.
    for field in ("content", "final_response", "output"):
        collected += _repo_artifacts_from_text(payload.get(field))
    return collected
| def merge_artifacts( | |
| existing: list[dict[str, Any]] | None, | |
| new: list[dict[str, Any]] | None, | |
| ) -> list[dict[str, Any]]: | |
| merged: list[dict[str, Any]] = [] | |
| seen: set[str] = set() | |
| for artifact in [*(existing or []), *(new or [])]: | |
| if not isinstance(artifact, dict) or not artifact.get("type"): | |
| continue | |
| key = artifact_key(artifact) | |
| if key in seen: | |
| continue | |
| seen.add(key) | |
| merged.append(artifact) | |
| return merged | |
def extract_artifacts(events: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Return the de-duplicated artifacts carried by a sequence of events.

    Each event contributes whatever ``artifacts_from_event`` finds for its
    ``event_type`` / ``data``; missing values are treated as ``""`` / ``{}``.
    """
    gathered = [
        artifact
        for event in events or []
        for artifact in artifacts_from_event(
            str(event.get("event_type") or ""), event.get("data") or {}
        )
    ]
    return merge_artifacts([], gathered)
| # --------------------------------------------------------------------------- | |
| # Turn state (event-sourced status derivation) | |
| # --------------------------------------------------------------------------- | |
| def derive_turn_state(events: list[dict[str, Any]]) -> dict[str, Any] | None: | |
| """Derive the response status from a turn's event slice. | |
| Returns None when no verdict can be drawn from the events alone (turn | |
| still running — or the server restarted mid-run; callers disambiguate by | |
| checking the live session). ``approval_required`` only counts as a pause | |
| if nothing ran after it (a resumed turn keeps appending events). | |
| """ | |
| last_event: dict[str, Any] | None = None | |
| for event in events or []: | |
| event_type = str(event.get("event_type") or "") | |
| status = HARD_TERMINAL_STATUS.get(event_type) | |
| if status is not None: | |
| data = event.get("data") or {} | |
| error = None | |
| if status == "failed": | |
| error = { | |
| "code": ( | |
| "session_shutdown" | |
| if event_type == "shutdown" | |
| else "agent_error" | |
| ), | |
| "message": str(data.get("error") or event_type), | |
| } | |
| return { | |
| "status": status, | |
| "terminal_event_type": event_type, | |
| "end_seq": event.get("seq"), | |
| "error": error, | |
| "incomplete_details": None, | |
| "final_response": data.get("final_response"), | |
| } | |
| last_event = event | |
| if last_event and str(last_event.get("event_type")) == APPROVAL_EVENT: | |
| return { | |
| "status": "incomplete", | |
| "terminal_event_type": APPROVAL_EVENT, | |
| "end_seq": None, | |
| "error": None, | |
| "incomplete_details": { | |
| "reason": "approval_required", | |
| "approval": last_event.get("data") or {}, | |
| }, | |
| "final_response": None, | |
| } | |
| return None | |
| # --------------------------------------------------------------------------- | |
| # Output reconstruction | |
| # --------------------------------------------------------------------------- | |
def _truncate(text: str, limit: int = MAX_TOOL_OUTPUT_CHARS) -> str:
    """Cut ``text`` to ``limit`` characters, appending a note with the number removed.

    Text that already fits is returned unchanged.
    """
    overflow = len(text) - limit
    if overflow <= 0:
        return text
    return f"{text[:limit]}… [truncated {overflow} chars]"
def build_output_items(events: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Rebuild OpenAI-style output items from a turn's event slice.

    Assistant text is accumulated from ``assistant_chunk`` events and emitted
    as a ``message`` item whenever the stream ends, a full
    ``assistant_message`` arrives, a tool call starts, or the events run out.
    Each ``tool_call`` becomes a ``custom_tool_call`` item that the matching
    ``tool_output`` (same ``tool_call_id``) later completes; outputs with no
    matching call are ignored.

    ``turn_complete.final_response`` is authoritative for the last message:
    if the reconstructed tail message differs from it (ignoring surrounding
    whitespace) its text is replaced in place, and if the items do not end in
    a message one is appended. A tail that already matches is left untouched.

    Args:
        events: Internal events (live or persisted shape); ``None`` is
            treated as empty.

    Returns:
        The ordered list of output items.
    """
    items: list[dict[str, Any]] = []
    tool_items: dict[str, dict[str, Any]] = {}
    text_parts: list[str] = []
    final_response: str | None = None

    def message_item(text: str) -> dict[str, Any]:
        # The id encodes the item's position in ``items`` at creation time.
        return {
            "type": "message",
            "id": f"msg_{len(items)}",
            "role": "assistant",
            "status": "completed",
            "content": [{"type": "output_text", "text": text}],
        }

    def flush_text() -> None:
        nonlocal text_parts
        text = "".join(text_parts)
        text_parts = []
        # Whitespace-only text never becomes a message.
        if text.strip():
            items.append(message_item(text))

    for event in events or []:
        event_type = str(event.get("event_type") or "")
        data = event.get("data") or {}
        if event_type == "assistant_chunk":
            content = data.get("content")
            if isinstance(content, str):
                text_parts.append(content)
        elif event_type == "assistant_message":
            # Non-streamed full message replaces any partial chunk state.
            content = data.get("content")
            if isinstance(content, str):
                text_parts = [content]
                flush_text()
        elif event_type == "assistant_stream_end":
            flush_text()
        elif event_type == "tool_call":
            # Text streamed before the call is its own message.
            flush_text()
            tool_call_id = str(data.get("tool_call_id") or f"call_{len(items)}")
            item = {
                "type": "custom_tool_call",
                "id": tool_call_id,
                "name": data.get("tool"),
                "input": json.dumps(data.get("arguments") or {}),
                "output": None,
                "status": "in_progress",
            }
            tool_items[tool_call_id] = item
            items.append(item)
        elif event_type == "tool_output":
            tool_call_id = str(data.get("tool_call_id") or "")
            item = tool_items.get(tool_call_id)
            output = data.get("output")
            if item is not None:
                item["output"] = _truncate(str(output)) if output is not None else None
                # A missing ``success`` flag counts as success.
                item["status"] = (
                    "completed" if data.get("success", True) else "incomplete"
                )
        elif event_type == "turn_complete":
            response_text = data.get("final_response")
            if isinstance(response_text, str):
                final_response = response_text
    flush_text()

    if final_response and final_response.strip():
        last = items[-1] if items else None
        if last is not None and last.get("type") == "message":
            tail = last["content"][0]
            # Diverging tail: overwrite it rather than appending, so the final
            # text is not emitted twice.
            if tail["text"].strip() != final_response.strip():
                tail["text"] = final_response
        else:
            # No tail message at all (e.g. the turn ended on a tool call).
            items.append(message_item(final_response))
    return items
| # --------------------------------------------------------------------------- | |
| # Response object | |
| # --------------------------------------------------------------------------- | |
def build_response_object(
    doc: dict[str, Any],
    *,
    output: list[dict[str, Any]] | None = None,
    artifacts: list[dict[str, Any]] | None = None,
    usage: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Render a stored response document as the API's response object.

    Args:
        doc: The persisted response document; its ``_id`` becomes ``id``.
        output: Output items overriding ``doc["output"]`` when not ``None``.
        artifacts: Artifacts overriding ``doc["artifacts"]`` when not ``None``.
        usage: Usage overriding ``doc["usage"]`` when not ``None``.

    Returns:
        The response dict. ``status`` falls back to ``"queued"``; ``output``
        and ``artifacts`` fall back to ``[]`` and ``metadata`` to ``{}``.
    """
    if output is None:
        output = doc.get("output") or []
    if artifacts is None:
        artifacts = doc.get("artifacts") or []
    if usage is None:
        usage = doc.get("usage")

    created = _timestamp(doc.get("created_at"))
    completed = _timestamp(doc.get("completed_at"))

    response: dict[str, Any] = {
        "id": doc.get("_id"),
        "object": "response",
        "created_at": created,
        "completed_at": completed,
        "status": doc.get("status") or "queued",
        "model": doc.get("model"),
        "background": bool(doc.get("background")),
    }
    # Fields copied through from the document as-is.
    for name in ("previous_response_id", "session_id", "max_cost_usd", "instructions"):
        response[name] = doc.get(name)
    response["output"] = output
    response["error"] = doc.get("error")
    response["incomplete_details"] = doc.get("incomplete_details")
    response["usage"] = usage
    response["artifacts"] = artifacts
    response["metadata"] = doc.get("metadata") or {}
    return response
| # --------------------------------------------------------------------------- | |
| # SSE translation | |
| # --------------------------------------------------------------------------- | |
| _DROPPED_EVENTS = { | |
| "ready", | |
| "compacted", | |
| "new_complete", | |
| "resume_complete", | |
| "undo_complete", | |
| "session_terminated", | |
| } | |
| def translate_event( | |
| msg: dict[str, Any], response_id: str | |
| ) -> list[tuple[str, dict[str, Any]]]: | |
| """Translate one internal event into zero or more (name, payload) SSE frames. | |
| Terminal frames (``response.completed`` etc.) carry only a light payload | |
| here; the streaming routes merge in the full response snapshot they have | |
| been accumulating. | |
| """ | |
| event_type = str(msg.get("event_type") or "") | |
| data = msg.get("data") or {} | |
| seq = msg.get("seq") | |
| base: dict[str, Any] = {"response_id": response_id} | |
| if seq is not None: | |
| base["sequence_number"] = seq | |
| if event_type in _DROPPED_EVENTS: | |
| return [] | |
| if event_type == "processing": | |
| return [("response.in_progress", base)] | |
| if event_type == "assistant_chunk": | |
| return [("response.output_text.delta", {**base, "delta": data.get("content")})] | |
| if event_type == "assistant_message": | |
| return [("response.output_text.done", {**base, "text": data.get("content")})] | |
| if event_type == "assistant_stream_end": | |
| return [("response.output_text.done", base)] | |
| if event_type == "tool_call": | |
| return [ | |
| ( | |
| "response.output_item.added", | |
| { | |
| **base, | |
| "item": { | |
| "type": "custom_tool_call", | |
| "id": data.get("tool_call_id"), | |
| "name": data.get("tool"), | |
| "input": json.dumps(data.get("arguments") or {}), | |
| "status": "in_progress", | |
| }, | |
| }, | |
| ) | |
| ] | |
| if event_type == "tool_output": | |
| output = data.get("output") | |
| return [ | |
| ( | |
| "response.output_item.done", | |
| { | |
| **base, | |
| "item": { | |
| "type": "custom_tool_call", | |
| "id": data.get("tool_call_id"), | |
| "name": data.get("tool"), | |
| "output": ( | |
| _truncate(str(output)) if output is not None else None | |
| ), | |
| "status": ( | |
| "completed" if data.get("success", True) else "incomplete" | |
| ), | |
| }, | |
| }, | |
| ) | |
| ] | |
| if event_type == "tool_log": | |
| return [("response.tool_log", {**base, **data})] | |
| if event_type == "tool_state_change": | |
| frames: list[tuple[str, dict[str, Any]]] = [ | |
| ("response.tool_state.changed", {**base, **data}) | |
| ] | |
| for artifact in artifacts_from_event(event_type, data): | |
| frames.append(("response.artifact.created", {**base, "artifact": artifact})) | |
| return frames | |
| if event_type == "hub_artifact": | |
| return [ | |
| ("response.artifact.created", {**base, "artifact": artifact}) | |
| for artifact in artifacts_from_event(event_type, data) | |
| ] | |
| if event_type == APPROVAL_EVENT: | |
| return [("response.approval_required", {**base, **data})] | |
| if event_type == "turn_complete": | |
| return [ | |
| ( | |
| "response.completed", | |
| {**base, "final_response": data.get("final_response")}, | |
| ) | |
| ] | |
| if event_type == "error": | |
| return [("response.failed", {**base, "error": data.get("error")})] | |
| if event_type == "interrupted": | |
| return [("response.cancelled", base)] | |
| if event_type == "shutdown": | |
| return [ | |
| ( | |
| "response.failed", | |
| {**base, "error": {"code": "session_shutdown"}}, | |
| ) | |
| ] | |
| # Unknown event types pass through under a generic name so new internal | |
| # events are visible to API consumers without a translator change. | |
| return [(f"response.{event_type}", {**base, **data})] | |
def format_v1_sse(event_name: str, payload: dict[str, Any]) -> str:
    """Serialize one frame as a Server-Sent Events message.

    The JSON body is ``payload`` plus a leading ``type`` field that always
    equals ``event_name``, even if ``payload`` carries its own ``type`` key
    (pass-through event data may), so the body never disagrees with the
    ``event:`` line. An integer ``sequence_number`` in the payload is also
    emitted as the SSE ``id:`` field.

    Args:
        event_name: SSE event name, repeated as the body's ``type``.
        payload: JSON-serializable frame payload.

    Returns:
        The message text, terminated by a blank line.
    """
    body_fields = {"type": event_name, **payload}
    # Re-assert after the merge: the envelope type wins over payload["type"].
    body_fields["type"] = event_name
    body = json.dumps(body_fields)

    seq = payload.get("sequence_number")
    # ``bool`` is an ``int`` subclass but is not a usable event id.
    if isinstance(seq, int) and not isinstance(seq, bool):
        return f"id: {seq}\nevent: {event_name}\ndata: {body}\n\n"
    return f"event: {event_name}\ndata: {body}\n\n"