Spaces:
Sleeping
Sleeping
File size: 5,567 Bytes
"""Utilities for parsing model output into structured tool calls."""
from __future__ import annotations
import json
import re
import warnings
from typing import Any
from .models import ToolAction
from .tool_catalog import (
KNOWN_TOOL_NAMES,
ToolValidationError,
canonicalize_tool_name,
validate_tool_arguments,
)
class ToolParseError(ValueError):
    """Base error for model output that cannot be turned into a ToolAction.

    Subclasses ``ValueError``. Nothing in this module raises it directly;
    the parser raises the more specific ``ParseError`` subclass instead.
    """
class ParseError(ToolParseError):
    """Signal that tool-call JSON could not be extracted or validated safely.

    Inherits from ``ToolParseError`` (and therefore ``ValueError``), so
    handlers written for either base class also catch this one.
    """
class ParseWarning(UserWarning):
    """Warning category used when the parser resorts to a weaker extraction fallback."""
def _format_parse_error(message: str, raw_output: str) -> str:
"""Attach a compact raw-output preview to parser failures for debugging."""
preview = raw_output.strip().replace("\n", "\\n")
if len(preview) > 240:
preview = preview[:237] + "..."
return f"{message} Raw output: {preview}"
def _normalize_action_payload(payload: Any, raw_output: str) -> dict[str, Any]:
    """Reduce the accepted payload shapes to one flat tool-action mapping.

    Two shapes are accepted: a flat object that already carries
    ``tool_name``, or a wrapper object whose ``action`` key holds that flat
    object. For the wrapper shape, a top-level ``reasoning`` value is copied
    into the action unless the action already has its own.

    Args:
        payload: The decoded JSON value.
        raw_output: Original model output, used only for error previews.

    Returns:
        The mapping that contains ``tool_name``.

    Raises:
        ParseError: If ``payload`` is not a dict, if ``action`` is present
            but not a dict, or if no ``tool_name`` key is found.
    """

    def _failure(reason: str) -> ParseError:
        # Every failure carries the same raw-output preview.
        return ParseError(_format_parse_error(reason, raw_output))

    if not isinstance(payload, dict):
        raise _failure("Parsed JSON must be an object.")

    if "action" in payload:
        inner = payload["action"]
        if not isinstance(inner, dict):
            raise _failure("Top-level 'action' must itself be a JSON object.")
        inherit_reasoning = "reasoning" in payload and "reasoning" not in inner
        # Build a new dict only when reasoning is carried over; otherwise
        # the nested action object is used as-is.
        payload = {**inner, "reasoning": payload["reasoning"]} if inherit_reasoning else inner

    if "tool_name" in payload:
        return payload
    raise _failure(
        "Tool-call JSON must contain either a top-level 'tool_name' or 'action' key."
    )
def _decode_json(candidate: str, raw_output: str) -> dict[str, Any]:
    """Parse one JSON candidate string and normalize it to the action schema.

    Args:
        candidate: Text expected to be a complete JSON document.
        raw_output: Original model output, used only for error previews.

    Returns:
        The normalized tool-action mapping.

    Raises:
        ParseError: If ``candidate`` is not valid JSON (chained from the
            ``json.JSONDecodeError``), or if normalization rejects it.
    """
    try:
        decoded = json.loads(candidate)
    except json.JSONDecodeError as exc:
        detail = _format_parse_error(
            f"Could not decode model output as JSON: {exc}",
            raw_output,
        )
        raise ParseError(detail) from exc
    else:
        return _normalize_action_payload(decoded, raw_output)
def _iter_schema_objects(text: str) -> list[dict[str, Any]]:
"""Scan raw text for standalone JSON objects with tool-action top-level keys."""
decoder = json.JSONDecoder()
matches: list[dict[str, Any]] = []
for index, character in enumerate(text):
if character != "{":
continue
try:
payload, _end_index = decoder.raw_decode(text[index:])
except json.JSONDecodeError:
continue
if isinstance(payload, dict) and ("tool_name" in payload or "action" in payload):
matches.append(payload)
return matches
def parse_with_fallback(llm_output: str, log_warnings: bool = True) -> dict[str, Any]:
    """Parse LLM output with a strict extraction hierarchy and visible fallbacks.

    Extraction tiers, tried in order:

    1. Fenced code blocks (optionally tagged ``json``). Blocks that fail to
       decode or normalize are skipped.
    2. JSON objects found anywhere in the text that carry a ``tool_name`` or
       ``action`` key. Candidates that fail normalization are skipped in
       favour of later ones, as in tier 1. If every candidate fails, the
       error from the first one is raised.
    3. The span from the first ``{`` to the last ``}``. A ``ParseWarning``
       is emitted before this attempt when ``log_warnings`` is true.

    Args:
        llm_output: Raw text produced by the model.
        log_warnings: Whether to emit ``ParseWarning`` when tier 3 is used.

    Returns:
        The normalized tool-action mapping (it contains ``tool_name``).

    Raises:
        ParseError: If no tier yields a valid tool-action object.
    """
    candidate = llm_output.strip()

    # Tier 1: fenced blocks.
    for block in re.findall(r"```(?:json)?\s*(.*?)\s*```", candidate, re.DOTALL):
        try:
            return _decode_json(block.strip(), llm_output)
        except ParseError:
            continue

    # Tier 2: schema-keyed objects embedded in the text. Keep the first
    # failure so it can be reported when no later candidate succeeds.
    first_schema_error: ParseError | None = None
    for payload in _iter_schema_objects(candidate):
        try:
            return _normalize_action_payload(payload, llm_output)
        except ParseError as exc:
            if first_schema_error is None:
                first_schema_error = exc
    if first_schema_error is not None:
        raise first_schema_error

    # Tier 3: broad brace extraction.
    # NOTE(review): an object this slice decodes to would start at the first
    # "{" and so should already have been seen by the schema scan above; this
    # tier looks like it mainly yields a precise error message - confirm.
    first = candidate.find("{")
    last = candidate.rfind("}")
    if first != -1 and last > first:
        if log_warnings:
            warnings.warn(
                "Parser fell back to broad brace extraction because no fenced block or schema-keyed JSON object was found.",
                ParseWarning,
                stacklevel=2,
            )
        return _decode_json(candidate[first : last + 1], llm_output)

    raise ParseError(
        _format_parse_error("No JSON object could be extracted from model output.", llm_output)
    )
def extract_json_object(text: str) -> dict[str, Any]:
    """Extract the tool-action object from ``text`` (legacy entry point).

    Delegates to ``parse_with_fallback`` with warnings enabled and returns
    its result unchanged. Any exception it raises propagates.
    """
    parsed = parse_with_fallback(text, log_warnings=True)
    return parsed
def parse_tool_action(
    text: str,
    *,
    allowed_tools: list[str] | None = None,
) -> ToolAction:
    """Turn raw model output into a validated ``ToolAction``.

    Args:
        text: Raw model output expected to contain a tool-call JSON object.
        allowed_tools: Tool names accepted for this call. When it is ``None``
            or empty, every name in ``KNOWN_TOOL_NAMES`` is accepted.

    Returns:
        A ``ToolAction`` built from the canonical tool name, the validated
        arguments, and the optional reasoning string.

    Raises:
        ParseError: If no payload can be extracted, if ``tool_name`` is not
            a non-blank string, if ``reasoning`` is present but not a
            string, or if argument validation raises ``ToolValidationError``.

    Exceptions raised by ``canonicalize_tool_name`` are not wrapped and
    propagate as-is.
    """
    payload = parse_with_fallback(text, log_warnings=True)

    raw_name = payload.get("tool_name")
    raw_arguments = payload.get("arguments", {})
    rationale = payload.get("reasoning")

    name_is_usable = isinstance(raw_name, str) and bool(raw_name.strip())
    if not name_is_usable:
        raise ParseError(_format_parse_error("tool_name must be a non-empty string.", text))
    if not (rationale is None or isinstance(rationale, str)):
        raise ParseError(_format_parse_error("reasoning must be a string when provided.", text))

    # A falsy allow-list (None or empty) means "all known tools".
    permitted = allowed_tools or list(KNOWN_TOOL_NAMES)
    resolved_name = canonicalize_tool_name(raw_name, allowed_tools=permitted)

    try:
        checked_arguments = validate_tool_arguments(
            resolved_name,
            raw_arguments,
            allowed_tools=permitted,
        )
    except ToolValidationError as exc:
        raise ParseError(_format_parse_error(str(exc), text)) from exc

    return ToolAction(
        tool_name=resolved_name,
        arguments=checked_arguments,
        reasoning=rationale,
    )
|