ml-intern-api / backend /openai_compat.py
abidlabs's picture
abidlabs HF Staff
Extract Hub repo artifacts from response text
7c6e674
Raw
History Blame Contribute Delete
18.4 kB
"""OpenAI Responses-API compatibility layer for the /v1 developer API.
Pure translation between ml-intern's internal session events and an
OpenAI-Responses-style wire format: SSE event names, the response object,
terminal-status derivation, and artifact extraction. No I/O here — routes
in routes/v1_responses.py own all store/session access.
Internal events arrive either as live broadcaster messages
(``{"event_type", "data", "seq"}``) or persisted Mongo docs (same keys plus
``_id``/``created_at``) — both shapes are accepted everywhere below.
"""
import json
import re
import uuid
from datetime import UTC, datetime
from typing import Any
# Prefix prepended to the random hex string produced by new_response_id().
RESPONSE_ID_PREFIX = "resp_"
# Default character budget used by _truncate() before it cuts tool output and
# appends its "[truncated N chars]" marker.
MAX_TOOL_OUTPUT_CHARS = 4096
# Internal terminal event → response status. ``approval_required`` is handled
# separately: it pauses the response (resumable) rather than ending it.
# Note that both ``error`` and ``shutdown`` map to "failed"; derive_turn_state
# tells them apart through the error code it attaches.
HARD_TERMINAL_STATUS: dict[str, str] = {
"turn_complete": "completed",
"error": "failed",
"interrupted": "cancelled",
"shutdown": "failed",
}
# Internal event type that pauses a turn. derive_turn_state reports it as
# status "incomplete" only when it is the last event in the slice.
APPROVAL_EVENT = "approval_required"
# The distinct status values HARD_TERMINAL_STATUS can produce. Not referenced
# in the visible part of this module — presumably consumed by the routes
# (TODO confirm).
TERMINAL_RESPONSE_STATUSES = {"completed", "failed", "cancelled"}
# HF job URLs look like https://huggingface.co/jobs/<namespace>/<job_id>.
# The captured job id is the path segment after the namespace; only an
# optional trailing slash, a query string or a fragment may follow it.
_JOB_URL_ID_RE = re.compile(r"/jobs/[^/?#]+/([^/?#]+)/?(?:[?#]|$)")

# ``<namespace>/<name>`` as captured from a Hub URL embedded in free text.
# The name must end on a word character so that sentence punctuation directly
# after a link ("…/user/model." or "…/user/model)") stays out of the repo id.
# No end-of-URL anchor is required: links in prose are usually followed by
# whitespace or punctuation, and anchoring on ``/``, ``?``, ``#`` or
# end-of-string silently dropped every such link.
_HF_REPO_ID_PATTERN = r"([A-Za-z0-9][\w.-]*/[A-Za-z0-9](?:[\w.-]*\w)?)"

# First path segments that name a section of the Hub website rather than a
# user/org namespace. Without this list ``https://huggingface.co/docs/hub``
# would be reported as the model repo ``docs/hub``.
# NOTE(review): list is best-effort — extend it if other site sections show
# up as bogus model artifacts.
_HF_NON_MODEL_SECTIONS = (
    "datasets",
    "spaces",
    "jobs",
    "papers",
    "collections",
    "docs",
    "blog",
    "learn",
    "tasks",
    "posts",
    "api",
    "settings",
    "organizations",
    "models",
    "chat",
)

_HF_DATASET_URL_RE = re.compile(
    r"https://huggingface\.co/datasets/" + _HF_REPO_ID_PATTERN
)
_HF_SPACE_URL_RE = re.compile(
    r"https://huggingface\.co/spaces/" + _HF_REPO_ID_PATTERN
)
# Model repos live directly under the site root, so anything whose first
# segment is a known site section is rejected by the negative lookahead.
_HF_MODEL_URL_RE = re.compile(
    r"https://huggingface\.co/"
    r"(?!(?:" + "|".join(_HF_NON_MODEL_SECTIONS) + r")/)"
    + _HF_REPO_ID_PATTERN
)
class V1APIError(Exception):
"""OpenAI-shaped API error: ``{"error": {message, type, code}}``."""
def __init__(
self,
status_code: int,
message: str,
*,
code: str | None = None,
error_type: str = "invalid_request_error",
) -> None:
super().__init__(message)
self.status_code = status_code
self.message = message
self.code = code
self.error_type = error_type
def body(self) -> dict[str, Any]:
return {
"error": {
"message": self.message,
"type": self.error_type,
"code": self.code,
}
}
def new_response_id() -> str:
    """Return a new response id: ``RESPONSE_ID_PREFIX`` plus a random UUID4 in hex."""
    random_hex = uuid.uuid4().hex
    return RESPONSE_ID_PREFIX + random_hex
def _timestamp(value: Any) -> int | None:
if isinstance(value, datetime):
if value.tzinfo is None:
value = value.replace(tzinfo=UTC)
return int(value.timestamp())
if isinstance(value, (int, float)):
return int(value)
return None
# ---------------------------------------------------------------------------
# Artifacts
# ---------------------------------------------------------------------------
def _job_id_from_url(url: str) -> str | None:
    """Return the job id ``_JOB_URL_ID_RE`` captures from *url*.

    Yields ``None`` when *url* is empty/``None`` or the pattern does not match.
    """
    found = _JOB_URL_ID_RE.search(url if url else "")
    if found is None:
        return None
    return found.group(1)
def _repo_artifacts_from_text(text: str | None) -> list[dict[str, Any]]:
    """Scan free text for Hub repo URLs and return them as artifacts.

    Best-effort: each dataset, space and model URL pattern is run over *text*
    and every hit becomes ``{"type", "repo_id", "url"}`` with a canonical
    URL. Results are grouped dataset → space → model and de-duplicated via
    ``merge_artifacts``. Non-string input, or text that contains no
    ``https://huggingface.co/`` link, returns ``[]``.
    """
    if not (isinstance(text, str) and "https://huggingface.co/" in text):
        return []

    scans = (
        ("dataset", _HF_DATASET_URL_RE, "datasets/"),
        ("space", _HF_SPACE_URL_RE, "spaces/"),
        ("model", _HF_MODEL_URL_RE, ""),
    )
    found: list[dict[str, Any]] = []
    for kind, regex, url_prefix in scans:
        found.extend(
            {
                "type": kind,
                "repo_id": hit.group(1),
                "url": f"https://huggingface.co/{url_prefix}{hit.group(1)}",
            }
            for hit in regex.finditer(text)
        )
    return merge_artifacts([], found)
def artifact_key(artifact: dict[str, Any]) -> str:
    """Build the ``"<type>:<identifier>"`` key used to de-duplicate artifacts.

    The identifier is the first truthy value among ``id``, ``repo_id``,
    ``space_id``, ``slug`` and ``url`` (in that order), or an empty string
    when none is set.
    """
    ident: Any = ""
    for field in ("id", "repo_id", "space_id", "slug", "url"):
        candidate = artifact.get(field)
        if candidate:
            ident = candidate
            break
    return f"{artifact.get('type')}:{ident}"
def artifacts_from_event(event_type: str, data: dict[str, Any]) -> list[dict[str, Any]]:
    """Collect the artifacts a single internal event carries, if any.

    * ``tool_state_change``: a ``jobUrl`` string becomes an ``hf_job``
      artifact; a ``trackioSpaceId`` string becomes a ``trackio_dashboard``
      artifact (with ``project`` when ``trackioProject`` is a non-empty
      string).
    * ``hub_artifact``: ``repo_id`` (plus optional ``repo_type``, default
      ``"model"``) becomes a repo artifact with its Hub URL.
    * Any event type: the ``content``, ``final_response`` and ``output``
      fields are additionally scanned for Hub repo links.

    The returned list is not de-duplicated across these sources.
    """
    payload = data or {}
    found: list[dict[str, Any]] = []

    if event_type == "tool_state_change":
        job_url = payload.get("jobUrl")
        if isinstance(job_url, str) and job_url != "":
            found.append(
                {
                    "type": "hf_job",
                    "id": _job_id_from_url(job_url),
                    "url": job_url,
                }
            )

        space_id = payload.get("trackioSpaceId")
        if isinstance(space_id, str) and space_id != "":
            dashboard: dict[str, Any] = {
                "type": "trackio_dashboard",
                "space_id": space_id,
                "url": f"https://huggingface.co/spaces/{space_id}",
            }
            project = payload.get("trackioProject")
            if isinstance(project, str) and project != "":
                dashboard["project"] = project
            found.append(dashboard)

    elif event_type == "hub_artifact":
        repo_id = payload.get("repo_id")
        if isinstance(repo_id, str) and repo_id != "":
            repo_type = payload.get("repo_type") or "model"
            # Unknown repo types fall back to the root (model-style) URL.
            url_prefixes = {"model": "", "dataset": "datasets/", "space": "spaces/"}
            url_prefix = url_prefixes.get(repo_type, "")
            found.append(
                {
                    "type": repo_type,
                    "repo_id": repo_id,
                    "url": f"https://huggingface.co/{url_prefix}{repo_id}",
                }
            )

    # Free-text fields may mention repos that no structured event reported.
    for field in ("content", "final_response", "output"):
        found += _repo_artifacts_from_text(payload.get(field))
    return found
def merge_artifacts(
    existing: list[dict[str, Any]] | None,
    new: list[dict[str, Any]] | None,
) -> list[dict[str, Any]]:
    """Concatenate two artifact lists, dropping invalid and duplicate entries.

    Entries that are not dicts or have no truthy ``type`` are skipped.
    Duplicates are detected with ``artifact_key``; the first occurrence wins,
    so items from *existing* take precedence over items from *new*. Either
    argument may be ``None``.
    """
    # dicts keep insertion order, so setdefault gives "first one wins" while
    # preserving the order in which keys were first seen.
    unique: dict[str, dict[str, Any]] = {}
    for source in (existing, new):
        for candidate in source or ():
            if isinstance(candidate, dict) and candidate.get("type"):
                unique.setdefault(artifact_key(candidate), candidate)
    return list(unique.values())
def extract_artifacts(events: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Return the de-duplicated artifacts carried by a list of events.

    Each event is passed to ``artifacts_from_event`` (missing ``event_type``
    becomes ``""``, missing ``data`` becomes ``{}``) and the combined result
    is de-duplicated with ``merge_artifacts``. ``None`` or an empty list
    yields ``[]``.
    """
    collected: list[dict[str, Any]] = []
    for evt in events or []:
        kind = str(evt.get("event_type") or "")
        payload = evt.get("data") or {}
        collected += artifacts_from_event(kind, payload)
    return merge_artifacts([], collected)
# ---------------------------------------------------------------------------
# Turn state (event-sourced status derivation)
# ---------------------------------------------------------------------------
def derive_turn_state(events: list[dict[str, Any]]) -> dict[str, Any] | None:
    """Work out a turn's response status from its event slice.

    The first event whose type appears in ``HARD_TERMINAL_STATUS`` decides
    the verdict; later events are ignored. Failed verdicts carry an ``error``
    dict whose code is ``session_shutdown`` for ``shutdown`` events and
    ``agent_error`` otherwise.

    Without a hard terminal event, the turn counts as paused
    (``"incomplete"``) only if the very last event is ``approval_required``
    — a resumed turn keeps appending events after it.

    Returns None when no verdict can be drawn from the events alone (turn
    still running — or the server restarted mid-run; callers disambiguate by
    checking the live session).
    """
    trailing: dict[str, Any] | None = None
    for evt in events or []:
        kind = str(evt.get("event_type") or "")
        verdict = HARD_TERMINAL_STATUS.get(kind)
        if verdict is None:
            trailing = evt
            continue

        payload = evt.get("data") or {}
        if verdict == "failed":
            failure: dict[str, Any] | None = {
                "code": "session_shutdown" if kind == "shutdown" else "agent_error",
                # Fall back to the event type when no error text was recorded.
                "message": str(payload.get("error") or kind),
            }
        else:
            failure = None
        return {
            "status": verdict,
            "terminal_event_type": kind,
            "end_seq": evt.get("seq"),
            "error": failure,
            "incomplete_details": None,
            "final_response": payload.get("final_response"),
        }

    if not trailing or str(trailing.get("event_type")) != APPROVAL_EVENT:
        return None
    return {
        "status": "incomplete",
        "terminal_event_type": APPROVAL_EVENT,
        "end_seq": None,
        "error": None,
        "incomplete_details": {
            "reason": "approval_required",
            "approval": trailing.get("data") or {},
        },
        "final_response": None,
    }
# ---------------------------------------------------------------------------
# Output reconstruction
# ---------------------------------------------------------------------------
def _truncate(text: str, limit: int = MAX_TOOL_OUTPUT_CHARS) -> str:
    """Cut *text* to *limit* characters, appending a note with the dropped count.

    Text that already fits within *limit* is returned unchanged.
    """
    overflow = len(text) - limit
    if overflow <= 0:
        return text
    return f"{text[:limit]}… [truncated {overflow} chars]"
def build_output_items(events: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Rebuild OpenAI-style output items from a turn's event slice.

    Assistant text is accumulated from ``assistant_chunk`` events and emitted
    as a ``message`` item whenever the stream ends, a full
    ``assistant_message`` arrives, a tool call starts, or the events run out.
    Each ``tool_call`` becomes a ``custom_tool_call`` item that a later
    ``tool_output`` with the same ``tool_call_id`` completes in place.

    ``turn_complete.final_response`` is authoritative for the last message:
    if the reconstructed tail message differs from it (ignoring surrounding
    whitespace) the tail's text is overwritten; if the tail is not a message
    at all, a new message is appended.

    Args:
        events: Event dicts with ``event_type`` and ``data`` keys; ``None``
            or an empty list yields ``[]``.

    Returns:
        The ordered list of ``message`` and ``custom_tool_call`` items.
    """
    items: list[dict[str, Any]] = []
    tool_items: dict[str, dict[str, Any]] = {}
    text_parts: list[str] = []
    final_response: str | None = None

    def message_item(text: str) -> dict[str, Any]:
        # The id encodes the position the item is about to take in ``items``.
        return {
            "type": "message",
            "id": f"msg_{len(items)}",
            "role": "assistant",
            "status": "completed",
            "content": [{"type": "output_text", "text": text}],
        }

    def flush_text() -> None:
        nonlocal text_parts
        text = "".join(text_parts)
        text_parts = []
        # Whitespace-only text is dropped rather than emitted as a message.
        if text.strip():
            items.append(message_item(text))

    for event in events or []:
        event_type = str(event.get("event_type") or "")
        data = event.get("data") or {}
        if event_type == "assistant_chunk":
            content = data.get("content")
            if isinstance(content, str):
                text_parts.append(content)
        elif event_type == "assistant_message":
            # Non-streamed full message replaces any partial chunk state.
            content = data.get("content")
            if isinstance(content, str):
                text_parts = [content]
            flush_text()
        elif event_type == "assistant_stream_end":
            flush_text()
        elif event_type == "tool_call":
            # Text streamed before the call belongs to its own message item.
            flush_text()
            tool_call_id = str(data.get("tool_call_id") or f"call_{len(items)}")
            item = {
                "type": "custom_tool_call",
                "id": tool_call_id,
                "name": data.get("tool"),
                "input": json.dumps(data.get("arguments") or {}),
                "output": None,
                "status": "in_progress",
            }
            tool_items[tool_call_id] = item
            items.append(item)
        elif event_type == "tool_output":
            tool_call_id = str(data.get("tool_call_id") or "")
            item = tool_items.get(tool_call_id)
            output = data.get("output")
            # Outputs without a matching tool_call in this slice are ignored.
            if item is not None:
                item["output"] = _truncate(str(output)) if output is not None else None
                item["status"] = (
                    "completed" if data.get("success", True) else "incomplete"
                )
        elif event_type == "turn_complete":
            response_text = data.get("final_response")
            if isinstance(response_text, str):
                final_response = response_text
    flush_text()

    # turn_complete.final_response is authoritative for the last message: if
    # the chunk-reconstructed tail diverges, overwrite it; if it is missing,
    # append it. Appending next to a diverging tail (the previous behaviour)
    # left two copies of the final answer in the output.
    if final_response and final_response.strip():
        last = items[-1] if items else None
        if last is not None and last.get("type") == "message":
            if last["content"][0]["text"].strip() != final_response.strip():
                last["content"][0]["text"] = final_response
        else:
            items.append(message_item(final_response))
    return items
# ---------------------------------------------------------------------------
# Response object
# ---------------------------------------------------------------------------
def build_response_object(
    doc: dict[str, Any],
    *,
    output: list[dict[str, Any]] | None = None,
    artifacts: list[dict[str, Any]] | None = None,
    usage: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Render a stored response document as the API response object.

    ``output``, ``artifacts`` and ``usage`` override the values stored on
    *doc* when given (i.e. not ``None``). Defaults applied when *doc* lacks a
    value: status ``"queued"``, output/artifacts ``[]``, metadata ``{}``.
    ``created_at``/``completed_at`` are converted with ``_timestamp`` and
    ``background`` is coerced to ``bool``.
    """
    if output is None:
        output = doc.get("output") or []
    if artifacts is None:
        artifacts = doc.get("artifacts") or []
    if usage is None:
        usage = doc.get("usage")

    response: dict[str, Any] = {
        "id": doc.get("_id"),
        "object": "response",
        "created_at": _timestamp(doc.get("created_at")),
        "completed_at": _timestamp(doc.get("completed_at")),
        "status": doc.get("status") or "queued",
        "model": doc.get("model"),
        "background": bool(doc.get("background")),
    }
    # Fields copied from the document unchanged.
    for field in (
        "previous_response_id",
        "session_id",
        "max_cost_usd",
        "instructions",
    ):
        response[field] = doc.get(field)
    response["output"] = output
    response["error"] = doc.get("error")
    response["incomplete_details"] = doc.get("incomplete_details")
    response["usage"] = usage
    response["artifacts"] = artifacts
    response["metadata"] = doc.get("metadata") or {}
    return response
# ---------------------------------------------------------------------------
# SSE translation
# ---------------------------------------------------------------------------
# Internal event types that translate_event() swallows: it returns no SSE
# frame for them, so they never reach API consumers.
_DROPPED_EVENTS = {
"ready",
"compacted",
"new_complete",
"resume_complete",
"undo_complete",
"session_terminated",
}
def translate_event(
msg: dict[str, Any], response_id: str
) -> list[tuple[str, dict[str, Any]]]:
"""Translate one internal event into zero or more (name, payload) SSE frames.
Terminal frames (``response.completed`` etc.) carry only a light payload
here; the streaming routes merge in the full response snapshot they have
been accumulating.
"""
event_type = str(msg.get("event_type") or "")
data = msg.get("data") or {}
seq = msg.get("seq")
base: dict[str, Any] = {"response_id": response_id}
if seq is not None:
base["sequence_number"] = seq
if event_type in _DROPPED_EVENTS:
return []
if event_type == "processing":
return [("response.in_progress", base)]
if event_type == "assistant_chunk":
return [("response.output_text.delta", {**base, "delta": data.get("content")})]
if event_type == "assistant_message":
return [("response.output_text.done", {**base, "text": data.get("content")})]
if event_type == "assistant_stream_end":
return [("response.output_text.done", base)]
if event_type == "tool_call":
return [
(
"response.output_item.added",
{
**base,
"item": {
"type": "custom_tool_call",
"id": data.get("tool_call_id"),
"name": data.get("tool"),
"input": json.dumps(data.get("arguments") or {}),
"status": "in_progress",
},
},
)
]
if event_type == "tool_output":
output = data.get("output")
return [
(
"response.output_item.done",
{
**base,
"item": {
"type": "custom_tool_call",
"id": data.get("tool_call_id"),
"name": data.get("tool"),
"output": (
_truncate(str(output)) if output is not None else None
),
"status": (
"completed" if data.get("success", True) else "incomplete"
),
},
},
)
]
if event_type == "tool_log":
return [("response.tool_log", {**base, **data})]
if event_type == "tool_state_change":
frames: list[tuple[str, dict[str, Any]]] = [
("response.tool_state.changed", {**base, **data})
]
for artifact in artifacts_from_event(event_type, data):
frames.append(("response.artifact.created", {**base, "artifact": artifact}))
return frames
if event_type == "hub_artifact":
return [
("response.artifact.created", {**base, "artifact": artifact})
for artifact in artifacts_from_event(event_type, data)
]
if event_type == APPROVAL_EVENT:
return [("response.approval_required", {**base, **data})]
if event_type == "turn_complete":
return [
(
"response.completed",
{**base, "final_response": data.get("final_response")},
)
]
if event_type == "error":
return [("response.failed", {**base, "error": data.get("error")})]
if event_type == "interrupted":
return [("response.cancelled", base)]
if event_type == "shutdown":
return [
(
"response.failed",
{**base, "error": {"code": "session_shutdown"}},
)
]
# Unknown event types pass through under a generic name so new internal
# events are visible to API consumers without a translator change.
return [(f"response.{event_type}", {**base, **data})]
def format_v1_sse(event_name: str, payload: dict[str, Any]) -> str:
    """Serialise one frame in Server-Sent Events wire format.

    The ``data:`` line is the JSON of ``{"type": event_name, **payload}``. An
    ``id:`` line is emitted first only when ``payload["sequence_number"]`` is
    an ``int``. The frame ends with a blank line.
    """
    body = json.dumps({"type": event_name, **payload})
    lines = [f"event: {event_name}", f"data: {body}"]
    seq = payload.get("sequence_number")
    if isinstance(seq, int):
        lines.insert(0, f"id: {seq}")
    return "\n".join(lines) + "\n\n"