| | |
| | """ |
| | Robust JSON Extraction from LLM Output |
| | ======================================= |
| | |
| | LLMs frequently wrap JSON in markdown, add conversational preamble/postamble, |
| | use Python-style booleans, or output malformed JSON. This module handles all |
| | of those cases with a multi-strategy approach. |
| | |
| | Extracted from a production tool-calling orchestrator. Battle-tested against |
| | Hermes-3, Llama 3.3, Qwen2, and Mistral models. |
| | |
| | Usage: |
| | from robust_json_extraction import extract_json, extract_tool_calls |
| | |
| | # Handle any LLM output format |
| | data = extract_json('Here is the result: ```json\\n{"key": "value"}\\n``` Hope that helps!') |
| | |
| | # Extract tool calls from Hermes-format XML |
| | calls = extract_tool_calls('<tool_call>{"name": "search", "arguments": {"q": "test"}}</tool_call>') |
| | """ |
| |
|
| | import json |
| | import re |
| | import ast |
| | import xml.etree.ElementTree as ET |
| | from json import JSONDecoder |
| | from typing import Any, Dict, List, Optional |
| |
|
| |
|
# Markdown fence with an optional language tag (```json, ```JSON, ```python, ...).
_FENCED_BLOCK_RE = re.compile(r'```[\w+-]*\s*\n(.*?)\n\s*```', re.DOTALL)

# A double-quoted JSON string, or a bare Python literal outside any string.
# Strings are matched first so the substitution skips over their contents.
_JSON_STRING_OR_PY_LITERAL_RE = re.compile(r'"(?:\\.|[^"\\])*"|\b(True|False|None)\b')
_PY_TO_JSON_LITERALS = {'True': 'true', 'False': 'false', 'None': 'null'}


def _python_literals_to_json(text: str) -> str:
    """Rewrite bare True/False/None as true/false/null, leaving strings intact.

    Each replacement has the same length as the word it replaces, so character
    offsets into *text* stay valid in the result.
    """
    def _replace(match):
        literal = match.group(1)
        return _PY_TO_JSON_LITERALS[literal] if literal else match.group(0)

    return _JSON_STRING_OR_PY_LITERAL_RE.sub(_replace, text)


def _decode_json_at(text: str, start: int) -> Optional[tuple]:
    """Decode the JSON value beginning at ``text[start]``, ignoring trailing text.

    Tries the text verbatim first, then with Python literals converted.
    Returns ``(value, end)``, where *end* is the index in *text* just past the
    decoded value, or ``None`` if neither attempt decodes.
    """
    decoder = JSONDecoder()
    tail = text[start:]
    for candidate in (tail, _python_literals_to_json(tail)):
        try:
            value, consumed = decoder.raw_decode(candidate)
        except json.JSONDecodeError:
            continue
        return value, start + consumed
    return None


def extract_json(text: str) -> Any:
    """
    Extract JSON from LLM output, handling common issues:
    1. Markdown code blocks (```json ... ```, with any or no language tag)
    2. Preamble text ("Here is the result: {...")
    3. Postamble text ("...} Let me know if you need help!")
    4. Python-style literals (True/False/None instead of true/false/null),
       converted only outside string values

    When the text contains both '{' and '[', the object at the first '{' is
    preferred. The exception is when the first '[' comes earlier and its array
    encloses that '{' (e.g. "Results: [{...}, {...}]"), or when the object does
    not decode.

    Returns parsed JSON data (dict, list, etc.)
    Raises json.JSONDecodeError if no valid JSON can be extracted.
    """
    text = text.strip()
    if not text:
        raise json.JSONDecodeError("Empty input", text, 0)

    # Unwrap a markdown code fence.
    if "```" in text:
        fenced = _FENCED_BLOCK_RE.search(text)
        if fenced:
            text = fenced.group(1).strip()
        else:
            # No "\n```" close (e.g. truncated output or a one-line fence).
            # Drop everything through the opening fence's line, plus a
            # trailing fence if one is present.
            first_newline = text.find('\n', text.find('```'))
            if first_newline != -1:
                text = text[first_newline + 1:]
            if text.endswith("```"):
                text = text[:-3].strip()

    # Fast path: the whole (unwrapped) text is a single JSON document.
    try:
        return json.loads(text)
    except json.JSONDecodeError as err:
        # Rebind: the ``as`` name is deleted when the except clause ends.
        original_error = err

    brace = text.find('{')
    bracket = text.find('[')
    if brace == -1 and bracket == -1:
        # No container to anchor on; the text may still be a Python-style
        # scalar such as "True".
        decoded = _decode_json_at(text, 0)
        if decoded is None:
            raise original_error
        return decoded[0]

    as_object = _decode_json_at(text, brace) if brace != -1 else None
    as_array = _decode_json_at(text, bracket) if bracket != -1 else None

    # Prefer the object (historical behaviour), unless the array opens first
    # and encloses it, or the object fails to decode at all.
    if as_array is not None and (as_object is None or bracket < brace < as_array[1]):
        return as_array[0]
    if as_object is not None:
        return as_object[0]
    raise original_error
| |
|
| |
|
# A quoted string (double or single quotes), or a bare boolean/null literal in
# either JSON or Python spelling. Strings are matched first so literal swaps
# never rewrite text inside them (e.g. a tool named "nullify").
_CALL_TOKEN_RE = re.compile(
    r'"(?:\\.|[^"\\])*"'
    r"|'(?:\\.|[^'\\])*'"
    r"|\b(true|false|null|True|False|None)\b"
)
_JSON_TO_PY_WORDS = {'true': 'True', 'false': 'False', 'null': 'None'}
_PY_TO_JSON_WORDS = {'True': 'true', 'False': 'false', 'None': 'null'}

_NAME_VALUE_RE = re.compile(r"['\"]?name['\"]?\s*:\s*['\"]([^'\"]+)['\"]")
# Position just before the '{' that opens an "arguments" object.
_ARGUMENTS_KEY_RE = re.compile(r"['\"]?arguments['\"]?\s*:\s*(?=\{)")
# A flat object with no nested '}' (legacy salvage pattern).
_FLAT_OBJECT_RE = re.compile(r"\{[^}]+\}")

# Everything ast.literal_eval may raise on malformed input (per its docs).
_LITERAL_EVAL_ERRORS = (SyntaxError, ValueError, TypeError, MemoryError, RecursionError)


def _swap_bare_literals(text: str, mapping: Dict[str, str]) -> str:
    """Replace bare literal words according to *mapping*; quoted strings are untouched."""
    def _replace(match):
        word = match.group(1)
        if not word:
            return match.group(0)
        return mapping.get(word, word)

    return _CALL_TOKEN_RE.sub(_replace, text)


def _salvage_arguments(text: str) -> Dict:
    """Best-effort recovery of the "arguments" object from malformed call text.

    Returns the decoded dict, or {} if nothing usable is found.
    """
    key = _ARGUMENTS_KEY_RE.search(text)
    if not key:
        return {}
    start = key.end()

    # A valid JSON object (nested ones included), even if the surrounding
    # call text is broken.
    try:
        value, _ = JSONDecoder().raw_decode(text[start:])
        return value  # always a dict: decoding starts at '{'
    except json.JSONDecodeError:
        pass

    # Legacy fallback: a flat object in JSON or Python-literal syntax.
    flat = _FLAT_OBJECT_RE.match(text, start)
    if not flat:
        return {}
    for parse in (json.loads, ast.literal_eval):
        try:
            value = parse(flat.group(0))
        except _LITERAL_EVAL_ERRORS:  # covers json.JSONDecodeError (a ValueError)
            continue
        if isinstance(value, dict):
            return value
    return {}


def parse_single_call(json_text: str) -> Optional[Dict]:
    """
    Parse a single tool call JSON using multiple strategies.

    Strategies, in order:
    1. Strict JSON.
    2. Python literal syntax via ast.literal_eval (single quotes, trailing
       commas). Bare true/false/null outside strings first become
       True/False/None.
    3. Naive swap of every ' to ", then bare True/False/None to JSON.
    4. Regex salvage of a quoted "name" value and, if present, its
       "arguments" object (nested objects supported when valid JSON).

    Returns the parsed value, or None if parsing fails. Normally this is a
    dict with 'name' and 'arguments' keys, or a {"tool_calls": [...]}
    wrapper; strategies 1-3 return whatever they decode, which is not
    guaranteed to be a dict.
    """
    json_text = json_text.strip()
    if not json_text:
        return None

    # 1. Strict JSON.
    try:
        return json.loads(json_text)
    except json.JSONDecodeError:
        pass

    # 2. Python literal syntax.
    try:
        return ast.literal_eval(_swap_bare_literals(json_text, _JSON_TO_PY_WORDS))
    except _LITERAL_EVAL_ERRORS:
        # TypeError comes from e.g. unhashable keys such as {[1]: 2}.
        pass

    # 3. Naive single-to-double quote swap (breaks on apostrophes in values).
    try:
        swapped = json_text.replace("'", '"')
        return json.loads(_swap_bare_literals(swapped, _PY_TO_JSON_WORDS))
    except (json.JSONDecodeError, ValueError):
        pass

    # 4. Regex salvage.
    name_match = _NAME_VALUE_RE.search(json_text)
    if name_match:
        return {"name": name_match.group(1), "arguments": _salvage_arguments(json_text)}

    return None
| |
|
| |
|
_TOOL_CALL_TAG_RE = re.compile(r'<tool_call>(.*?)</tool_call>', re.DOTALL)


def _tool_calls_from_payload(payload: Any) -> List[Dict]:
    """Normalize one parsed payload into a list of tool-call dicts.

    Accepts {"tool_calls": [...]}, a single {"name": ...} object, or a list of
    such objects. Anything else (including None) yields [].
    """
    if isinstance(payload, dict):
        if 'tool_calls' in payload:
            nested = payload.get('tool_calls')
            # Guard against e.g. {"tool_calls": "none"}: list.extend would
            # split a string into characters, and None would raise.
            if isinstance(nested, list):
                return [call for call in nested if isinstance(call, dict)]
            return []
        if 'name' in payload:
            return [payload]
        return []
    if isinstance(payload, list):
        return [call for call in payload if isinstance(call, dict) and 'name' in call]
    return []


def _json_documents(text: str) -> List[Any]:
    """Decode *text* as a run of whitespace-separated strict JSON values.

    Returns [] if any part of *text* is not valid JSON.
    """
    decoder = JSONDecoder()
    documents = []
    rest = text.strip()
    while rest:
        try:
            value, end = decoder.raw_decode(rest)
        except json.JSONDecodeError:
            return []
        documents.append(value)
        rest = rest[end:].lstrip()
    return documents


def _tool_calls_from_block(raw_text: str) -> List[Dict]:
    """Extract every tool call from the text of one <tool_call> element."""
    raw_text = raw_text.strip()
    if not raw_text:
        return []

    # 1. Strict JSON: a single document, or several back to back / one per line.
    documents = _json_documents(raw_text)
    if documents:
        calls = []
        for document in documents:
            calls.extend(_tool_calls_from_payload(document))
        return calls

    # 2. Lenient parse of the whole block.
    calls = _tool_calls_from_payload(parse_single_call(raw_text))
    if calls:
        return calls

    # 3. Lenient parse of each line that opens an object.
    calls = []
    for line in raw_text.splitlines():
        line = line.strip()
        if line.startswith('{'):
            calls.extend(_tool_calls_from_payload(parse_single_call(line)))
    return calls


def extract_tool_calls(assistant_message: str) -> List[Dict]:
    """
    Extract tool calls from an assistant message containing <tool_call> XML tags.

    Supports:
    - Single tool call: <tool_call>{"name": "fn", "arguments": {...}}</tool_call>
    - Nested format: <tool_call>{"tool_calls": [...]}</tool_call>
    - A JSON array of calls: <tool_call>[{"name": ...}, ...]</tool_call>
    - Multiple JSON objects in one block (back to back or one per line)
    - Malformed JSON (lenient parse_single_call, then line by line)
    - Malformed XML (regex fallback)

    Returns a list of dicts. Normally each has 'name' and 'arguments' keys;
    entries taken from a nested "tool_calls" list are passed through as long
    as they are dicts.
    """
    try:
        root = ET.fromstring(f"<root>{assistant_message}</root>")
    except ET.ParseError:
        # Not well-formed XML (JSON payloads often contain '&' or '<'):
        # fall back to a non-greedy regex over the raw message.
        blocks = _TOOL_CALL_TAG_RE.findall(assistant_message)
    else:
        blocks = [element.text or "" for element in root.findall(".//tool_call")]

    tool_calls: List[Dict] = []
    for block in blocks:
        tool_calls.extend(_tool_calls_from_block(block))
    return tool_calls
| |
|
| |
|
| | |
| | |
| | |
| |
|
if __name__ == "__main__":
    divider = "=" * 60

    print(divider)
    print("Testing robust JSON extraction")
    print(divider)

    # Each case: (label, raw LLM output, expected parsed value).
    json_cases = [
        ("Clean JSON", '{"key": "value"}', {"key": "value"}),
        ("Markdown-wrapped JSON", '```json\n{"key": "value"}\n```', {"key": "value"}),
        ("Preamble text", 'Here is the result: {"key": "value"}', {"key": "value"}),
        ("Postamble text", '{"key": "value"} Hope that helps!', {"key": "value"}),
        (
            "Preamble + markdown + postamble",
            'Sure! ```json\n{"key": "value"}\n``` Let me know!',
            {"key": "value"},
        ),
        (
            "Python-style booleans",
            '{"active": True, "deleted": False, "value": None}',
            {"active": True, "deleted": False, "value": None},
        ),
    ]
    for label, raw, expected in json_cases:
        assert extract_json(raw) == expected
        print(f" [PASS] {label}")

    print("\n" + divider)
    print("Testing tool call extraction")
    print(divider)

    # Each case: (label, assistant message, expected call count,
    #             expected name of the first call, or None to skip that check).
    tool_cases = [
        (
            "Single tool call",
            '<tool_call>{"name": "search", "arguments": {"q": "test"}}</tool_call>',
            1,
            "search",
        ),
        (
            "Nested tool_calls array",
            '<tool_call>{"tool_calls": [{"name": "a", "arguments": {}}, {"name": "b", "arguments": {}}]}</tool_call>',
            2,
            None,
        ),
        (
            "Mixed content with tool call",
            'I will search for that.\n<tool_call>\n{"name": "search", "arguments": {"q": "hello"}}\n</tool_call>\nDone.',
            1,
            "search",
        ),
    ]
    for label, message, expected_count, first_name in tool_cases:
        found = extract_tool_calls(message)
        assert len(found) == expected_count
        if first_name is not None:
            assert found[0]["name"] == first_name
        print(f" [PASS] {label}")

    print("\nAll tests passed.")
| |
|