"""
Model output parser.

Responsible for extracting structured information from the raw text
generated by the LLM:

1. Detect tool calls -- look for ``<tool_call>...</tool_call>`` XML tags
   (Qwen style) or bare JSON objects with ``name``/``arguments`` keys.
2. Extract the tool name and argument dict.
3. Validate that the tool name is known. (Argument values are not checked
   against a schema here.)
4. Detect a final answer (no tool call present).
5. Handle malformed output via regex fallback and heuristic correction.
"""

from __future__ import annotations

import json
import logging
import re
from dataclasses import dataclass, field
from typing import Any

logger = logging.getLogger(__name__)

# ====================================================================== #
# Parsed result types
# ====================================================================== #


@dataclass
class ParsedAction:
    """Result of parsing a single model response.

    A successfully parsed response describes either a tool call
    (``is_tool_call`` is True, ``tool_name``/``tool_args`` are filled) or a
    final answer (``answer`` is filled). Both may be empty, e.g. for blank
    input. ``parse_error`` can be set in either case -- including together
    with ``is_tool_call=True``, e.g. when the tool name is not registered.
    """

    # True when a tool invocation was recognized in the output.
    is_tool_call: bool = False
    # Name of the tool to invoke ("" when no tool call was recognized).
    tool_name: str = ""
    # Arguments for the tool; each instance gets its own fresh dict.
    tool_args: dict[str, Any] = field(default_factory=dict)
    # Text after a "Thought:" / "Thinking:" / "Reasoning:" label, if any.
    thought: str = ""
    # Final answer text when no tool call was recognized.
    answer: str = ""
    # The unmodified model output.
    raw: str = ""
    # Human-readable reason parsing failed or the call is invalid; None if OK.
    parse_error: str | None = None


# ====================================================================== #
# Regex patterns
# ====================================================================== #

# <tool_call>{"name": "...", "arguments": {...}}</tool_call>
# The lazy body can only end where the closing tag follows, so nested braces
# inside "arguments" are captured in full.
_TOOL_CALL_XML_RE = re.compile(
    r"<tool_call>\s*(\{.*?\})\s*</tool_call>",
    re.DOTALL,
)

# Bare JSON with "name" and "arguments" keys (fallback).
# NOTE: the lazy arguments group ends at the first "}" that is followed by
# another "}", so nested argument objects may be captured incompletely.
_TOOL_CALL_JSON_RE = re.compile(
    r'\{\s*"name"\s*:\s*"(\w+)"\s*,\s*"arguments"\s*:\s*(\{.*?\})\s*\}',
    re.DOTALL,
)

# Thought: ... -- everything up to the next line that starts an Action, an
# Answer / Final Answer, a <tool_call> tag or a JSON object (or end of text).
_THOUGHT_RE = re.compile(
    r"(?:Thought|Thinking|Reasoning)\s*:\s*(.+?)"
    r"(?=\n\s*(?:Action|(?:Final\s+)?Answer|<tool_call>|\{)|\Z)",
    re.DOTALL | re.IGNORECASE,
)

# Answer: ... or Final Answer: ... (everything up to the end of the text).
_ANSWER_RE = re.compile(
    r"(?:Final\s+)?Answer\s*:\s*(.+)",
    re.DOTALL | re.IGNORECASE,
)

# Action: tool_name(arg=value, ...) -- relaxed fallback.
# NOTE: the argument list cannot contain ")" (not even inside quotes).
_ACTION_FUNC_RE = re.compile(
    r"Action\s*:\s*(\w+)\s*\(([^)]*)\)",
    re.IGNORECASE,
)

# Splits ``key=value`` pairs on a comma that starts a new ``word=`` key.
# NOTE: quotes are not tracked, so a value containing ", word=" is split too.
_KV_SPLIT_RE = re.compile(r",\s*(?=\w+=)")

# Repairs used by ``_attempt_json_repair``.
_TRAILING_COMMA_RE = re.compile(r",\s*([}\]])")
_UNQUOTED_KEY_RE = re.compile(r"([{,]\s*)([A-Za-z_][\w-]*)\s*:")

# ``raw_decode`` decodes one JSON value starting at an offset and ignores
# whatever text follows it.
_JSON_DECODER = json.JSONDecoder()


# ====================================================================== #
# Public API
# ====================================================================== #


def parse_model_output(
    text: str,
    known_tools: list[str] | None = None,
) -> ParsedAction:
    """Parse a model completion into a structured ``ParsedAction``.

    Strategies are tried in order; the first one that matches wins:

    1. XML-tagged tool call (``_TOOL_CALL_XML_RE``). Malformed JSON inside
       the tags yields a ``parse_error`` and does *not* fall through.
    2. Bare ``{"name": ..., "arguments": {...}}`` JSON. Falls through to the
       next strategy if the arguments cannot be decoded.
    3. ``Action: tool_name(key=value, ...)``.
    4. ``Answer:`` / ``Final Answer:`` text.
    5. Any other non-empty text becomes the answer, unless it contains
       ``tool_call``, ``action:`` or ``function`` (case-insensitive); then a
       ``parse_error`` with a format hint is set instead.

    Parameters
    ----------
    text:
        Raw model output.
    known_tools:
        Optional list of registered tool names. When non-empty, a tool call
        naming any other tool gets a ``parse_error`` but keeps
        ``is_tool_call=True``.

    Returns
    -------
    ParsedAction
        Always returned. ``raw`` holds *text*; ``thought`` holds any
        ``Thought:`` / ``Thinking:`` / ``Reasoning:`` section.
    """
    result = ParsedAction(raw=text)

    thought_match = _THOUGHT_RE.search(text)
    if thought_match:
        result.thought = thought_match.group(1).strip()

    # --- Strategy 1: XML-tagged tool call ---
    xml_match = _TOOL_CALL_XML_RE.search(text)
    if xml_match:
        return _parse_json_tool_call(xml_match.group(1), result, known_tools)

    # --- Strategy 2: Bare JSON tool call ---
    json_match = _TOOL_CALL_JSON_RE.search(text)
    if json_match:
        args = _decode_bare_call_args(text, json_match)
        if args is not None:
            result.is_tool_call = True
            result.tool_name = json_match.group(1)
            result.tool_args = args
            return _validate_tool(result, known_tools)
        # Undecodable arguments: fall through to the weaker strategies.

    # --- Strategy 3: Action: func(args) fallback ---
    func_match = _ACTION_FUNC_RE.search(text)
    if func_match:
        result.is_tool_call = True
        result.tool_name = func_match.group(1)
        result.tool_args = _parse_kv_args(func_match.group(2).strip())
        return _validate_tool(result, known_tools)

    # --- No tool call detected -- look for an explicit answer ---
    answer_match = _ANSWER_RE.search(text)
    if answer_match:
        result.answer = answer_match.group(1).strip()
        return result

    # Unstructured output is treated as the answer (common for simple
    # questions) unless it looks like a botched tool-call attempt.
    stripped = text.strip()
    if stripped:
        # NOTE(review): "function" also matches ordinary prose ("this
        # function returns ..."), turning such answers into a parse error --
        # confirm this keyword is intended.
        lowered = text.lower()
        if not any(kw in lowered for kw in ("tool_call", "action:", "function")):
            result.answer = stripped
        else:
            result.parse_error = (
                "Could not parse tool call from model output. "
                "Please use the format: <tool_call>{\"name\": \"tool_name\", "
                "\"arguments\": {...}}</tool_call>"
            )
    return result


# ====================================================================== #
# Internal helpers
# ====================================================================== #


def _decode_bare_call_args(text: str, match: re.Match[str]) -> dict | None:
    """Return the ``arguments`` dict of a bare JSON tool call, or ``None``.

    *match* is a match of ``_TOOL_CALL_JSON_RE`` against *text*. That
    regex's lazy arguments group stops at the first ``}}``, which truncates
    nested argument objects. So the whole object is first decoded directly
    from the match start; the regex capture (plus repair) is only the
    fallback for malformed JSON.
    """
    try:
        obj, _ = _JSON_DECODER.raw_decode(text, match.start())
    except json.JSONDecodeError:
        obj = None
    if isinstance(obj, dict) and isinstance(obj.get("arguments"), dict):
        return obj["arguments"]

    args_str = match.group(2)
    try:
        args = json.loads(args_str)
    except json.JSONDecodeError:
        return _attempt_json_repair(args_str)
    return args if isinstance(args, dict) else None


def _parse_json_tool_call(
    json_str: str,
    result: ParsedAction,
    known_tools: list[str] | None,
) -> ParsedAction:
    """Populate *result* from a JSON tool-call object string.

    Expects ``{"name": ..., "arguments": ...}``; ``arguments`` may also be a
    JSON-encoded string. Handling of bad input:

    - JSON that is still malformed after repair sets ``parse_error`` and
      leaves ``is_tool_call`` False.
    - Non-dict arguments become ``{}``.
    - A missing or non-string name sets ``parse_error``.

    Otherwise the name is checked against *known_tools*. Returns the same
    (mutated) *result* object.
    """
    try:
        obj = json.loads(json_str)
    except json.JSONDecodeError:
        obj = _attempt_json_repair(json_str)
    if not isinstance(obj, dict):
        result.parse_error = f"Malformed JSON in tool_call: {json_str[:200]}"
        return result

    name = obj.get("name", "")
    args = obj.get("arguments", {})
    if isinstance(args, str):
        # Some models emit "arguments" as a JSON-encoded string.
        try:
            args = json.loads(args)
        except json.JSONDecodeError:
            args = _attempt_json_repair(args)

    result.is_tool_call = True
    result.tool_name = name if isinstance(name, str) else ""
    result.tool_args = args if isinstance(args, dict) else {}
    if not result.tool_name:
        result.parse_error = "Tool call is missing a 'name' field."
        return result
    return _validate_tool(result, known_tools)


def _validate_tool(
    result: ParsedAction,
    known_tools: list[str] | None,
) -> ParsedAction:
    """Set ``parse_error`` if the tool name is not in *known_tools*.

    No check is made when *known_tools* is ``None`` or empty. ``is_tool_call``
    is left untouched. Returns the same (mutated) *result* object.
    """
    if known_tools and result.tool_name not in known_tools:
        result.parse_error = (
            f"Unknown tool '{result.tool_name}'. "
            f"Available tools: {known_tools}"
        )
    return result


def _loads_dict(candidate: str) -> dict | None:
    """Return ``json.loads(candidate)`` if it is a JSON object, else ``None``."""
    try:
        obj = json.loads(candidate)
    except json.JSONDecodeError:
        return None
    return obj if isinstance(obj, dict) else None


def _attempt_json_repair(s: str) -> dict | None:
    """Try common fixes for malformed JSON objects from LLM output.

    Repairs are cumulative, and the least invasive one that yields a JSON
    object wins:

    1. Drop trailing commas before ``}`` / ``]``.
    2. Also replace single quotes with double quotes. This is tried only
       after step 1 fails, because it corrupts apostrophes inside strings.
    3. Also quote bare identifier keys (``{a: 1}`` -> ``{"a": 1}``).

    Each candidate that does not start with ``{`` is also retried wrapped in
    braces. Returns the decoded dict, or ``None`` if no candidate decodes to
    a JSON *object* (arrays and scalars are rejected).
    """
    no_trailing = _TRAILING_COMMA_RE.sub(r"\1", s)
    double_quoted = no_trailing.replace("'", '"')
    keys_quoted = _UNQUOTED_KEY_RE.sub(r'\1"\2":', double_quoted)

    tried: set[str] = set()
    for candidate in (no_trailing, double_quoted, keys_quoted):
        if candidate in tried:
            continue
        tried.add(candidate)
        obj = _loads_dict(candidate)
        if obj is not None:
            return obj
        # Bare "key": value pairs without the enclosing braces.
        if not candidate.strip().startswith("{"):
            obj = _loads_dict("{" + candidate + "}")
            if obj is not None:
                return obj
    return None


def _parse_kv_args(raw: str) -> dict[str, Any]:
    """Parse ``key=value, key=value`` style arguments.

    Used as a last-resort fallback for ``Action: func(key=val, ...)``
    patterns. Pieces without ``=`` are skipped, and a repeated key keeps its
    last value. Values are converted by ``_coerce_kv_value``.
    """
    if not raw:
        return {}
    args: dict[str, Any] = {}
    for part in _KV_SPLIT_RE.split(raw):
        key, sep, val = part.partition("=")
        if not sep:
            continue
        args[key.strip().strip('"').strip("'")] = _coerce_kv_value(val.strip())
    return args


def _coerce_kv_value(val: str) -> Any:
    """Convert a raw ``key=value`` value string into a Python value.

    A value wrapped in matching quotes is returned as a plain string, so
    ``"007"`` stays ``"007"`` and ``"true"`` stays ``"true"``. Otherwise,
    after stripping stray quotes, the value is tried in turn as a
    (case-insensitive) bool, an int, a float and JSON (lists / dicts /
    null). If none of these apply, the string itself is returned.
    """
    if len(val) >= 2 and val[0] == val[-1] and val[0] in ("'", '"'):
        return val[1:-1]
    val = val.strip('"').strip("'")
    lowered = val.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    for convert in (int, float, json.loads):
        try:
            return convert(val)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            continue
    return val