File size: 5,567 Bytes
9b1756a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
"""Utilities for parsing model output into structured tool calls."""

from __future__ import annotations

import json
import re
import warnings
from typing import Any

from .models import ToolAction
from .tool_catalog import (
    KNOWN_TOOL_NAMES,
    ToolValidationError,
    canonicalize_tool_name,
    validate_tool_arguments,
)


class ToolParseError(ValueError):
    """Signals that model output could not be converted into a ToolAction."""


class ParseError(ToolParseError):
    """Signals that tool-call JSON could not be extracted or validated safely."""


class ParseWarning(UserWarning):
    """Emitted when the parser has to resort to a weaker extraction fallback."""


def _format_parse_error(message: str, raw_output: str) -> str:
    """Attach a compact raw-output preview to parser failures for debugging."""

    preview = raw_output.strip().replace("\n", "\\n")
    if len(preview) > 240:
        preview = preview[:237] + "..."
    return f"{message} Raw output: {preview}"


def _normalize_action_payload(payload: Any, raw_output: str) -> dict[str, Any]:
    """Reduce the accepted payload shapes to one flat tool-action mapping.

    Two shapes are accepted: a flat object that carries ``tool_name`` directly,
    or a wrapper whose ``action`` key holds that object. For the wrapper form,
    a top-level ``reasoning`` is copied onto the nested object unless the
    nested object already has its own.

    Args:
        payload: The decoded JSON value to normalize.
        raw_output: The original model output, used only in error messages.

    Returns:
        The flat mapping, which is guaranteed to contain a ``tool_name`` key.

    Raises:
        ParseError: If *payload* is not a dict, if ``action`` is present but is
            not a dict, or if the resulting mapping has no ``tool_name`` key.
    """

    def _fail(message: str) -> ParseError:
        # Build (not raise) the error so every failure path shares one format.
        return ParseError(_format_parse_error(message, raw_output))

    if not isinstance(payload, dict):
        raise _fail("Parsed JSON must be an object.")

    flattened = payload
    if "action" in payload:
        wrapped = payload["action"]
        if not isinstance(wrapped, dict):
            raise _fail("Top-level 'action' must itself be a JSON object.")
        flattened = wrapped
        if "reasoning" in payload and "reasoning" not in wrapped:
            # Copy rather than mutate so the decoded input object stays untouched.
            flattened = dict(wrapped)
            flattened["reasoning"] = payload["reasoning"]

    if "tool_name" in flattened:
        return flattened

    raise _fail("Tool-call JSON must contain either a top-level 'tool_name' or 'action' key.")


def _decode_json(candidate: str, raw_output: str) -> dict[str, Any]:
    """Parse *candidate* as JSON and hand the result to the schema normalizer.

    Args:
        candidate: The text expected to hold exactly one JSON document.
        raw_output: The original model output, used only in error messages.

    Returns:
        The normalized tool-action mapping.

    Raises:
        ParseError: If *candidate* is not valid JSON, or if normalization
            rejects the decoded value.
    """

    try:
        decoded = json.loads(candidate)
    except json.JSONDecodeError as exc:
        detail = f"Could not decode model output as JSON: {exc}"
        raise ParseError(_format_parse_error(detail, raw_output)) from exc
    else:
        return _normalize_action_payload(decoded, raw_output)


def _iter_schema_objects(text: str) -> list[dict[str, Any]]:
    """Scan raw text for standalone JSON objects with tool-action top-level keys."""

    decoder = json.JSONDecoder()
    matches: list[dict[str, Any]] = []
    for index, character in enumerate(text):
        if character != "{":
            continue
        try:
            payload, _end_index = decoder.raw_decode(text[index:])
        except json.JSONDecodeError:
            continue
        if isinstance(payload, dict) and ("tool_name" in payload or "action" in payload):
            matches.append(payload)
    return matches


def parse_with_fallback(llm_output: str, log_warnings: bool = True) -> dict[str, Any]:
    """Extract a tool-action payload from raw model output, most specific source first.

    Sources are consulted in this order:

    1. Fenced code blocks (optionally tagged ``json``); a block that fails to
       decode or normalize is skipped and the next block is tried.
    2. The first JSON object in the text that has a ``tool_name`` or ``action``
       key. Only that first object is considered, so a normalization failure
       here propagates instead of falling through to the next source.
    3. The span from the first ``{`` to the last ``}``. A ``ParseWarning`` is
       emitted beforehand unless *log_warnings* is false.

    Args:
        llm_output: Raw text produced by the model.
        log_warnings: Whether to emit a ``ParseWarning`` when source 3 is used.

    Returns:
        The normalized payload mapping.

    Raises:
        ParseError: If the selected source cannot be decoded or normalized, or
            if the text holds no ``{`` followed by a later ``}``.
    """

    stripped = llm_output.strip()

    for fence in re.finditer(r"```(?:json)?\s*(.*?)\s*```", stripped, re.DOTALL):
        try:
            return _decode_json(fence.group(1).strip(), llm_output)
        except ParseError:
            # An unusable fenced block is not fatal; try the next one.
            pass

    schema_objects = _iter_schema_objects(stripped)
    if schema_objects:
        return _normalize_action_payload(schema_objects[0], llm_output)

    open_at = stripped.find("{")
    close_at = stripped.rfind("}")
    if open_at == -1 or close_at <= open_at:
        raise ParseError(_format_parse_error("No JSON object could be extracted from model output.", llm_output))

    if log_warnings:
        # stacklevel=2 attributes the warning to whoever called this function.
        warnings.warn(
            "Parser fell back to broad brace extraction because no fenced block or schema-keyed JSON object was found.",
            ParseWarning,
            stacklevel=2,
        )
    return _decode_json(stripped[open_at : close_at + 1], llm_output)


def extract_json_object(text: str) -> dict[str, Any]:
    """Compatibility alias that parses *text* via ``parse_with_fallback`` with warnings enabled."""

    return parse_with_fallback(llm_output=text, log_warnings=True)


def parse_tool_action(
    text: str,
    *,
    allowed_tools: list[str] | None = None,
) -> ToolAction:
    """Turn raw model output into a validated ToolAction.

    The payload is extracted with ``parse_with_fallback``. Its ``tool_name`` is
    canonicalized and its ``arguments`` (``{}`` when the key is absent) are
    validated against the permitted tools.

    Args:
        text: Raw model output.
        allowed_tools: Tool names to accept. ``None`` or an empty list selects
            ``KNOWN_TOOL_NAMES``.

    Returns:
        A ToolAction built from the canonical tool name, the validated
        arguments, and the optional reasoning string.

    Raises:
        ParseError: If extraction fails, if ``tool_name`` is not a non-empty
            string, if ``reasoning`` is present and non-null but not a string,
            or if argument validation raises ToolValidationError.
    """

    def _reject(message: str) -> ParseError:
        # Build (not raise) the error so each check shares one message format.
        return ParseError(_format_parse_error(message, text))

    payload = parse_with_fallback(text, log_warnings=True)

    raw_name = payload.get("tool_name")
    if not (isinstance(raw_name, str) and raw_name.strip()):
        raise _reject("tool_name must be a non-empty string.")

    reasoning = payload.get("reasoning")
    if not (reasoning is None or isinstance(reasoning, str)):
        raise _reject("reasoning must be a string when provided.")

    # NOTE(review): an empty allowed_tools list is falsy, so it falls back to
    # KNOWN_TOOL_NAMES exactly like None does — confirm that is intended.
    permitted = allowed_tools if allowed_tools else list(KNOWN_TOOL_NAMES)

    # NOTE(review): canonicalization runs outside the try below, so anything it
    # raises is not converted to ParseError — confirm against tool_catalog.
    canonical = canonicalize_tool_name(raw_name, allowed_tools=permitted)
    try:
        checked_arguments = validate_tool_arguments(
            canonical,
            payload.get("arguments", {}),
            allowed_tools=permitted,
        )
    except ToolValidationError as exc:
        raise ParseError(_format_parse_error(str(exc), text)) from exc

    return ToolAction(
        tool_name=canonical,
        arguments=checked_arguments,
        reasoning=reasoning,
    )