"""
Model output parser.
Responsible for extracting structured information from the raw text
generated by the LLM:
1. Detect tool calls -- look for ``<tool_call>...</tool_call>`` XML tags
(Qwen style) or bare JSON objects with ``name``/``arguments`` keys.
2. Extract the tool name and argument dict.
3. Validate that the tool name is known and arguments match the schema.
4. Detect a final answer (no tool call present).
5. Handle malformed output via regex fallback and heuristic correction.
"""
from __future__ import annotations
import json
import logging
import re
from dataclasses import dataclass, field
from typing import Any
logger = logging.getLogger(__name__)
# ====================================================================== #
# Parsed result types
# ====================================================================== #
@dataclass
class ParsedAction:
    """Structured outcome of parsing one model response.

    Every field has a default, so a bare ``ParsedAction()`` describes a
    response in which nothing has been recognised yet.
    """

    # True once a tool invocation has been recognised in the output.
    is_tool_call: bool = False
    # Name of the tool to invoke ("" when no tool call was found).
    tool_name: str = ""
    # Arguments for the tool; every instance gets its own fresh dict.
    tool_args: dict[str, Any] = field(default_factory=dict)
    # Reasoning text captured from a "Thought:"-style prefix.
    thought: str = ""
    # Final answer text when the response is not a tool call.
    answer: str = ""
    # The unmodified model output the other fields were derived from.
    raw: str = ""
    # Human-readable description of a parsing/validation problem, or None.
    parse_error: str | None = None
# ====================================================================== #
# Regex patterns
# ====================================================================== #
# <tool_call>{"name": "...", "arguments": {...}}</tool_call>
# The literal tags anchor the match: without them the pattern would accept
# any brace-delimited text, and the lazy ``.*?`` would stop at the first
# ``}`` and truncate nested argument objects. Group 1 is the JSON payload.
_TOOL_CALL_XML_RE = re.compile(
    r"<tool_call>\s*(\{.*?\})\s*</tool_call>",
    re.DOTALL,
)
# Fallback for an untagged JSON object whose first key is "name" and whose
# second key is "arguments". Group 1 captures the tool name, group 2 the
# raw text of the arguments object.
_TOOL_CALL_JSON_RE = re.compile(
    r'\{\s*"name"\s*:\s*"(\w+)"\s*,\s*"arguments"\s*:\s*(\{.*?\})\s*\}',
    flags=re.S,
)
# Thought: ... (everything before the tool call or answer).
# The lookahead ends the thought at the first following line that opens an
# action, a (final) answer, a tagged tool call or a bare JSON object; if no
# such line exists the thought runs to the end of the text. Every
# alternative must be non-empty, otherwise the lookahead would succeed at
# the first newline and cut multi-line thoughts down to one line.
_THOUGHT_RE = re.compile(
    r"(?:Thought|Thinking|Reasoning)\s*:\s*(.+?)"
    r"(?=\n\s*(?:Action|(?:Final\s+)?Answer|<tool_call>|\{)|\Z)",
    re.DOTALL | re.IGNORECASE,
)
# "Answer: ..." with an optional "Final " prefix, matched case-insensitively.
# Group 1 is everything after the colon up to the end of the text, newlines
# included.
_ANSWER_RE = re.compile(
    r"(?:Final\s+)?Answer\s*:\s*(.+)",
    flags=re.S | re.I,
)
# Relaxed fallback for "Action: tool_name(arg=value, ...)". Group 1 is the
# tool name; group 2 is the raw text between the parentheses, which ends at
# the first ")".
_ACTION_FUNC_RE = re.compile(r"Action\s*:\s*(\w+)\s*\(([^)]*)\)", flags=re.I)
# ====================================================================== #
# Public API
# ====================================================================== #
def parse_model_output(
    text: str,
    known_tools: list[str] | None = None,
) -> ParsedAction:
    """Parse a model completion into a structured ``ParsedAction``.

    Tool-call strategies are tried in order; the first that recognises a
    call decides the result:

    1. ``_TOOL_CALL_XML_RE`` -- a tagged JSON tool call.
    2. ``_TOOL_CALL_JSON_RE`` -- a bare JSON object with ``name`` and
       ``arguments`` keys.
    3. ``_ACTION_FUNC_RE`` -- ``Action: tool(key=value, ...)``.

    When none applies, the text after ``Answer:`` / ``Final Answer:``
    becomes the answer. Failing that, the whole stripped text becomes the
    answer, unless it mentions ``tool_call``, ``action:`` or ``function``;
    in that case ``parse_error`` is set instead. Empty or whitespace-only
    text yields a result with neither an answer nor an error.

    Parameters
    ----------
    text:
        Raw model output.
    known_tools:
        Optional list of registered tool names for validation.

    Returns
    -------
    ParsedAction
        ``raw`` always holds *text*; ``thought`` is filled whenever a
        thought prefix is found, whichever branch produces the result.
    """
    result = ParsedAction(raw=text)

    # The thought is independent of the action, so capture it up front.
    thought_match = _THOUGHT_RE.search(text)
    if thought_match:
        result.thought = thought_match.group(1).strip()

    # --- Strategy 1: XML-tagged tool call ---
    xml_match = _TOOL_CALL_XML_RE.search(text)
    if xml_match:
        return _parse_json_tool_call(xml_match.group(1), result, known_tools)

    # --- Strategy 2: Bare JSON tool call ---
    json_match = _TOOL_CALL_JSON_RE.search(text)
    if json_match:
        try:
            # Decode from the opening brace of the arguments object and let
            # the JSON decoder find its end. The regex group is lazy and
            # stops at the first closing brace, which truncates nested
            # objects such as {"filter": {"a": 1}}.
            args, _ = json.JSONDecoder().raw_decode(text[json_match.start(2):])
        except json.JSONDecodeError:
            args = _attempt_json_repair(json_match.group(2))
        # Unrepairable arguments fall through to the next strategy.
        if isinstance(args, dict):
            result.is_tool_call = True
            result.tool_name = json_match.group(1)
            result.tool_args = args
            return _validate_tool(result, known_tools)

    # --- Strategy 3: Action: func(args) fallback ---
    func_match = _ACTION_FUNC_RE.search(text)
    if func_match:
        result.is_tool_call = True
        result.tool_name = func_match.group(1)
        result.tool_args = _parse_kv_args(func_match.group(2).strip())
        return _validate_tool(result, known_tools)

    # --- No tool call detected -- look for an answer ---
    answer_match = _ANSWER_RE.search(text)
    if answer_match:
        result.answer = answer_match.group(1).strip()
        return result

    # No recognisable pattern: treat the entire output as the answer
    # (common for simple questions), but only if there is no apparent
    # intent to call a tool.
    stripped = text.strip()
    if stripped:
        lowered = text.lower()
        if any(kw in lowered for kw in ("tool_call", "action:", "function")):
            result.parse_error = (
                "Could not parse tool call from model output. "
                "Please use the format: <tool_call>{\"name\": \"tool_name\", "
                "\"arguments\": {...}}</tool_call>"
            )
        else:
            result.answer = stripped
    return result
# ====================================================================== #
# Internal helpers
# ====================================================================== #
def _parse_json_tool_call(
    json_str: str,
    result: ParsedAction,
    known_tools: list[str] | None,
) -> ParsedAction:
    """Fill *result* from a JSON tool-call object and validate the tool name.

    Parameters
    ----------
    json_str:
        Text expected to hold a JSON object with ``name`` and ``arguments``.
    result:
        The action to update in place; the same object is returned.
    known_tools:
        Optional list of registered tool names, forwarded to validation.

    Returns
    -------
    ParsedAction
        *result*, with ``parse_error`` set when the JSON is not an object
        (even after repair), when ``name`` is missing or not a string, or
        when the tool is unknown.
    """
    try:
        obj = json.loads(json_str)
    except json.JSONDecodeError:
        obj = _attempt_json_repair(json_str)
    # Valid JSON that is not an object (list, string, number) is as unusable
    # as text that failed to parse, and would crash on ``.get`` below.
    if not isinstance(obj, dict):
        result.parse_error = f"Malformed JSON in tool_call: {json_str[:200]}"
        return result

    name = obj.get("name", "")
    raw_args = obj.get("arguments", {})
    args = raw_args
    if isinstance(args, str):
        # Some models serialise the arguments object as a JSON string.
        try:
            args = json.loads(args)
        except json.JSONDecodeError:
            args = _attempt_json_repair(args)
    if not isinstance(args, dict):
        if raw_args is not None:
            # Best effort: still report the call, but leave a trace that the
            # arguments could not be used.
            logger.warning(
                "Ignoring tool arguments that are not a JSON object (got %s)",
                type(raw_args).__name__,
            )
        args = {}

    result.is_tool_call = True
    result.tool_name = name if isinstance(name, str) else ""
    result.tool_args = args
    if not result.tool_name:
        # Without a name there is nothing to dispatch; report it even when
        # no list of known tools was supplied.
        result.parse_error = "Tool call is missing a string 'name' field."
        return result
    return _validate_tool(result, known_tools)
def _validate_tool(
    result: ParsedAction,
    known_tools: list[str] | None,
) -> ParsedAction:
    """Flag *result* when its tool name is absent from *known_tools*.

    The check is skipped when *known_tools* is ``None`` or empty. The same
    ``result`` object is returned in every case; on a failed check only its
    ``parse_error`` is overwritten.
    """
    # Nothing to check against, or the tool is registered: pass through.
    if not known_tools or result.tool_name in known_tools:
        return result
    result.parse_error = (
        f"Unknown tool '{result.tool_name}'. "
        f"Available tools: {known_tools}"
    )
    return result
def _attempt_json_repair(s: str) -> dict | None:
"""Try common fixes for malformed JSON from LLM output.
Handles: trailing commas, single quotes, unquoted keys.
"""
# Remove trailing commas before } or ].
cleaned = re.sub(r",\s*([}\]])", r"\1", s)
# Replace single quotes with double quotes.
cleaned = cleaned.replace("'", '"')
try:
return json.loads(cleaned)
except json.JSONDecodeError:
pass
# Try wrapping in braces.
if not cleaned.strip().startswith("{"):
try:
return json.loads("{" + cleaned + "}")
except json.JSONDecodeError:
pass
return None
# A comma starts a new ``key=value`` pair only when a ``word=`` follows it,
# so commas inside unquoted free text ("hello, world") do not split.
_KV_BOUNDARY_RE = re.compile(r",\s*(?=\w+=)")


def _closing_quote_index(raw: str, open_idx: int) -> int:
    """Return the index of the quote closing the one at *open_idx*, or -1.

    A quote only counts as closing when it is followed by optional
    whitespace and then a comma or the end of *raw*, so apostrophes inside
    the value are skipped. Backslash-escaped characters are skipped too.
    """
    quote = raw[open_idx]
    size = len(raw)
    idx = open_idx + 1
    while idx < size:
        char = raw[idx]
        if char == "\\":
            idx += 2
            continue
        if char == quote:
            after = idx + 1
            while after < size and raw[after].isspace():
                after += 1
            if after == size or raw[after] == ",":
                return idx
        idx += 1
    return -1


def _split_kv_pairs(raw: str) -> list[str]:
    """Split *raw* into ``key=value`` chunks, keeping quoted values whole."""
    pairs: list[str] = []
    size = len(raw)
    start = 0
    while True:
        boundary = _KV_BOUNDARY_RE.search(raw, start)
        if boundary is None:
            break
        # A boundary match guarantees an "=" further on, so find() succeeds.
        eq_idx = raw.find("=", start)
        if eq_idx < boundary.start():
            value_idx = eq_idx + 1
            while value_idx < size and raw[value_idx].isspace():
                value_idx += 1
            if value_idx < size and raw[value_idx] in "\"'":
                close_idx = _closing_quote_index(raw, value_idx)
                if close_idx > boundary.start():
                    # The candidate comma sits inside the quoted value;
                    # resume the search after the closing quote. With an
                    # unterminated quote (-1) the plain split is kept.
                    boundary = _KV_BOUNDARY_RE.search(raw, close_idx + 1)
                    if boundary is None:
                        break
        pairs.append(raw[start:boundary.start()])
        start = boundary.end()
    pairs.append(raw[start:])
    return pairs


def _coerce_kv_value(text: str) -> Any:
    """Coerce *text* to bool, int, float or a JSON value, else return it."""
    lowered = text.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    # json.JSONDecodeError subclasses ValueError, so one handler covers all
    # three converters. JSON comes last and handles lists / dicts.
    for convert in (int, float, json.loads):
        try:
            return convert(text)
        except ValueError:
            continue
    return text


def _parse_kv_args(raw: str) -> dict[str, Any]:
    """Parse ``key=value, key=value`` style arguments.

    Used as a last-resort fallback for ``Action: func(key=val, ...)``
    patterns. Pairs are separated by commas that are followed by ``word=``
    and are not inside a quoted value; chunks without ``=`` are ignored.
    Surrounding quotes are stripped from keys and values before values are
    coerced (bool, int, float, JSON, then plain string).

    Parameters
    ----------
    raw:
        The text between the parentheses of the action call.

    Returns
    -------
    dict[str, Any]
        Parsed arguments; empty when *raw* is empty.
    """
    if not raw:
        return {}
    args: dict[str, Any] = {}
    for pair in _split_kv_pairs(raw):
        key, sep, val = pair.partition("=")
        if not sep:
            continue
        key = key.strip().strip('"').strip("'")
        val = val.strip().strip('"').strip("'")
        args[key] = _coerce_kv_value(val)
    return args