File size: 9,548 Bytes
634c038 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 | #!/usr/bin/env python3
"""
Robust JSON Extraction from LLM Output
=======================================
LLMs frequently wrap JSON in markdown, add conversational preamble/postamble,
use Python-style booleans, or output malformed JSON. This module handles all
of those cases with a multi-strategy approach.
Extracted from production tool calling orchestrator. Battle-tested against
Hermes-3, Llama 3.3, Qwen2, and Mistral models.
Usage:
from robust_json_extraction import extract_json, extract_tool_calls
# Handle any LLM output format
data = extract_json('Here is the result: ```json\\n{"key": "value"}\\n``` Hope that helps!')
# Extract tool calls from Hermes-format XML
calls = extract_tool_calls('<tool_call>{"name": "search", "arguments": {"q": "test"}}</tool_call>')
"""
import json
import re
import ast
import xml.etree.ElementTree as ET
from json import JSONDecoder
from typing import Any, Dict, List, Optional
# A fenced Markdown code block with an optional language tag (```json,
# ```JSON, ```python, or none); group 1 captures the body between the fences.
_FENCED_BLOCK_RE = re.compile(r"```[\w-]*\s*\n(.*?)\n\s*```", re.DOTALL)

# Either a complete double-quoted JSON string (copied through untouched) or a
# bare Python literal keyword captured in group 1 (rewritten to JSON).
_PY_LITERAL_RE = re.compile(r'"(?:\\.|[^"\\\n])*"|\b(True|False|None)\b')
_PY_TO_JSON_LITERALS = {"True": "true", "False": "false", "None": "null"}


def _python_literals_to_json(text: str) -> str:
    """Return *text* with bare ``True``/``False``/``None`` spelled as JSON.

    Double-quoted strings are skipped, so a value such as ``"None of these"``
    survives intact (a plain ``str.replace`` would yield ``"null of these"``).
    """
    def _sub(match):
        keyword = match.group(1)
        return _PY_TO_JSON_LITERALS[keyword] if keyword else match.group(0)

    return _PY_LITERAL_RE.sub(_sub, text)


def extract_json(text: str) -> Any:
    """
    Extract JSON from LLM output, handling common issues:
    1. Markdown code blocks (```json ... ```, with any or no language tag)
    2. Preamble text ("Here is the result: {...")
    3. Postamble text ("...} Let me know if you need help!")
    4. Python-style booleans (True/False/None instead of true/false/null),
       rewritten only outside double-quoted strings
    Returns parsed JSON data (dict, list, etc.)
    Raises json.JSONDecodeError if no valid JSON can be extracted; the error
    raised is the one from the first plain ``json.loads`` attempt.
    """
    text = text.strip()
    if not text:
        raise json.JSONDecodeError("Empty input", text, 0)

    # Layer 1: unwrap a Markdown code fence.
    if "```" in text:
        match = _FENCED_BLOCK_RE.search(text)
        if match:
            text = match.group(1).strip()
        else:
            # Unclosed (or single-line) fence: drop everything up to the end
            # of the opening fence line ...
            start = text.find("```")
            first_newline = text.find("\n", start)
            if first_newline != -1:
                text = text[first_newline + 1:]
            # ... and a trailing closing fence, if there is one.
            if text.endswith("```"):
                text = text[:-3].strip()

    # Layer 2: skip preamble by starting at whichever of '{' / '[' comes
    # FIRST, so "Result: [{...}]" yields the list rather than its first item.
    if not text.startswith(("{", "[")):
        starts = [idx for idx in (text.find("{"), text.find("[")) if idx != -1]
        if starts:
            text = text[min(starts):]

    # Layer 3: parse, falling back progressively.
    try:
        return json.loads(text)
    except json.JSONDecodeError as exc:
        original_error = exc

    decoder = JSONDecoder()
    fixed = _python_literals_to_json(text)
    fallbacks = (
        # raw_decode stops at the end of the first JSON value, so trailing
        # postamble text is ignored.
        lambda: decoder.raw_decode(text)[0],
        lambda: json.loads(fixed),
        lambda: decoder.raw_decode(fixed)[0],
    )
    for attempt in fallbacks:
        try:
            return attempt()
        except json.JSONDecodeError:
            continue
    raise original_error
# Exceptions ast.literal_eval can raise on malformed input; e.g. "{[1]: 2}"
# parses but raises TypeError (unhashable dict key).
_LITERAL_EVAL_ERRORS = (SyntaxError, ValueError, TypeError, MemoryError, RecursionError)

# A quoted string in either quote style (matched so it is copied through
# untouched) or a bare JSON/Python literal keyword.
_LITERAL_TOKEN_RE = re.compile(
    r'"(?:\\.|[^"\\\n])*"'
    r"|'(?:\\.|[^'\\\n])*'"
    r"|\b(?:true|false|null|True|False|None)\b"
)
_JSON_TO_PYTHON_WORDS = {"true": "True", "false": "False", "null": "None"}
_PYTHON_TO_JSON_WORDS = {"True": "true", "False": "false", "None": "null"}

# Last-resort field locators used by strategy 4.
_NAME_FIELD_RE = re.compile(r"['\"]?name['\"]?\s*:\s*['\"]([^'\"]+)['\"]")
_ARGUMENTS_FIELD_RE = re.compile(r"['\"]?arguments['\"]?\s*:\s*(?=\{)")
_FLAT_OBJECT_RE = re.compile(r"\{[^}]+\}")


def _swap_bare_literals(text: str, mapping: Dict[str, str]) -> str:
    """Rewrite bare keywords found in *mapping*; quoted strings are left as-is."""
    return _LITERAL_TOKEN_RE.sub(lambda m: mapping.get(m.group(0), m.group(0)), text)


def _salvage_arguments(json_text: str) -> Any:
    """Best-effort recovery of the ``arguments`` object from broken call text.

    Returns {} when no ``arguments: {...}`` field can be recovered.
    """
    field = _ARGUMENTS_FIELD_RE.search(json_text)
    if field is None:
        return {}
    tail = json_text[field.end():]
    # A well-formed JSON object (nested braces included) survives even when the
    # surrounding call is truncated or has trailing junk.
    try:
        arguments, _ = JSONDecoder().raw_decode(tail)
        return arguments
    except json.JSONDecodeError:
        pass
    # Otherwise try a flat (no nested braces) Python-style object.
    flat = _FLAT_OBJECT_RE.match(tail)
    if flat:
        try:
            return ast.literal_eval(flat.group(0))
        except _LITERAL_EVAL_ERRORS:
            pass
    return {}


def parse_single_call(json_text: str) -> Optional[Dict]:
    """
    Parse a single tool call JSON using multiple strategies.

    Strategies, in order:
    1. ``json.loads`` on the text as-is.
    2. ``ast.literal_eval`` after mapping bare JSON keywords (true/false/null)
       to Python (True/False/None); handles single-quoted, Python-style dicts.
    3. ``json.loads`` after replacing every single quote with a double quote
       and mapping bare Python keywords to JSON.
    4. Regex salvage of the ``name`` field plus a best-effort ``arguments``.

    Keyword rewriting in strategies 2-3 skips quoted strings, so argument
    values such as ``'true love'`` are preserved.

    Returns the parsed value, or None if parsing fails. Strategy 4 always
    yields ``{"name": ..., "arguments": ...}``; strategies 1-3 return whatever
    the text parses to (normally a dict, possibly one with a ``tool_calls``
    key, but it can be a list or scalar), so callers should check the type.
    """
    json_text = json_text.strip()
    if not json_text:
        return None

    # Strategy 1: standard JSON.
    try:
        return json.loads(json_text)
    except json.JSONDecodeError:
        pass

    # Strategy 2: Python literal syntax. JSON keywords outside strings are
    # mapped first so mixed output such as {'a': true} still evaluates.
    try:
        return ast.literal_eval(_swap_bare_literals(json_text, _JSON_TO_PYTHON_WORDS))
    except _LITERAL_EVAL_ERRORS:
        pass

    # Strategy 3: coerce Python-ish text into JSON (single quotes, capitalized
    # keywords). The global quote swap is crude but this is a fallback.
    try:
        fixed = _swap_bare_literals(json_text.replace("'", '"'), _PYTHON_TO_JSON_WORDS)
        return json.loads(fixed)
    except ValueError:  # includes json.JSONDecodeError
        pass

    # Strategy 4: regex extraction as last resort.
    name_match = _NAME_FIELD_RE.search(json_text)
    if name_match is None:
        return None
    return {"name": name_match.group(1), "arguments": _salvage_arguments(json_text)}
# Regex fallback for messages that are not well-formed XML.
_TOOL_CALL_RE = re.compile(r'<tool_call>(.*?)</tool_call>', re.DOTALL)


def _decode_json_sequence(text: str) -> Optional[List[Any]]:
    """Decode *text* as one or more whitespace-separated JSON values.

    Returns the list of values when the whole text is consumed; otherwise
    (including for empty text) returns None.
    """
    decoder = JSONDecoder()
    values: List[Any] = []
    rest = text.lstrip()
    while rest:
        try:
            value, end = decoder.raw_decode(rest)
        except json.JSONDecodeError:
            return None
        values.append(value)
        rest = rest[end:].lstrip()
    return values or None


def _collect_calls(json_data: Any, tool_calls: List[Dict]) -> None:
    """Append the call(s) described by one parsed value to *tool_calls*.

    A dict with a 'tool_calls' list contributes its dict entries; otherwise a
    dict with a 'name' key is appended as-is. Anything else is ignored.
    """
    if not isinstance(json_data, dict):
        return
    if 'tool_calls' in json_data:
        nested = json_data['tool_calls']
        # Guard the type: extending with a string or dict would add its
        # characters/keys as bogus "calls".
        if isinstance(nested, list):
            tool_calls.extend(call for call in nested if isinstance(call, dict))
    elif 'name' in json_data:
        tool_calls.append(json_data)


def _collect_from_block(raw_text: str, tool_calls: List[Dict]) -> None:
    """Parse the body of one <tool_call> block and append its calls."""
    raw_text = raw_text.strip()
    if not raw_text:
        return
    # Strict JSON first, allowing several objects back to back. This must
    # precede parse_single_call: its regex salvage returns only the FIRST of
    # several objects, which would hide the remaining calls.
    values = _decode_json_sequence(raw_text)
    if values is not None:
        for value in values:
            _collect_calls(value, tool_calls)
        return
    # Lenient single-call parsing (Python-style dicts, truncated JSON, ...).
    json_data = parse_single_call(raw_text)
    if json_data:
        _collect_calls(json_data, tool_calls)
        return
    # Last resort: one call per line.
    for line in raw_text.split('\n'):
        line = line.strip()
        if line.startswith('{'):
            parsed = parse_single_call(line)
            if parsed:
                _collect_calls(parsed, tool_calls)


def extract_tool_calls(assistant_message: str) -> List[Dict]:
    """
    Extract tool calls from an assistant message containing <tool_call> XML tags.
    Supports:
    - Single tool call: <tool_call>{"name": "fn", "arguments": {...}}</tool_call>
    - Nested format: <tool_call>{"tool_calls": [...]}</tool_call>
    - Multiple JSON objects in one block (whitespace/newline separated)
    - Malformed XML (regex fallback, with the same per-block handling)
    Returns a list of dicts. Directly parsed calls always carry a 'name' key;
    entries from a nested "tool_calls" list are passed through as-is (only
    non-dict entries are dropped).
    """
    tool_calls: List[Dict] = []
    try:
        root = ET.fromstring(f"<root>{assistant_message}</root>")
    except ET.ParseError:
        # Not well-formed XML (e.g. a bare '&' or '<' in the prose or JSON):
        # locate the blocks with a regex instead.
        for raw_text in _TOOL_CALL_RE.findall(assistant_message):
            _collect_from_block(raw_text, tool_calls)
        return tool_calls

    for element in root.findall(".//tool_call"):
        _collect_from_block(element.text or "", tool_calls)
    return tool_calls
# ============================================================================
# Examples / Self-test
# ============================================================================
if __name__ == "__main__":
    # Checks raise AssertionError explicitly instead of using bare ``assert``:
    # ``python -O`` strips assert statements, which would let this self-test
    # report "All tests passed." without verifying anything.
    def _expect(condition: bool, description: str) -> None:
        if not condition:
            raise AssertionError(f"Self-test failed: {description}")

    print("=" * 60)
    print("Testing robust JSON extraction")
    print("=" * 60)

    # (label, LLM output, expected parsed value)
    json_cases = [
        ("Clean JSON", '{"key": "value"}', {"key": "value"}),
        ("Markdown-wrapped JSON", '```json\n{"key": "value"}\n```', {"key": "value"}),
        ("Preamble text", 'Here is the result: {"key": "value"}', {"key": "value"}),
        ("Postamble text", '{"key": "value"} Hope that helps!', {"key": "value"}),
        (
            "Preamble + markdown + postamble",
            'Sure! ```json\n{"key": "value"}\n``` Let me know!',
            {"key": "value"},
        ),
        (
            "Python-style booleans",
            '{"active": True, "deleted": False, "value": None}',
            {"active": True, "deleted": False, "value": None},
        ),
    ]
    for label, llm_output, expected in json_cases:
        _expect(extract_json(llm_output) == expected, label)
        print(f" [PASS] {label}")

    print("\n" + "=" * 60)
    print("Testing tool call extraction")
    print("=" * 60)

    # (label, assistant message, predicate over the extracted calls)
    tool_cases = [
        (
            "Single tool call",
            '<tool_call>{"name": "search", "arguments": {"q": "test"}}</tool_call>',
            lambda calls: len(calls) == 1 and calls[0]["name"] == "search",
        ),
        (
            "Nested tool_calls array",
            '<tool_call>{"tool_calls": [{"name": "a", "arguments": {}}, {"name": "b", "arguments": {}}]}</tool_call>',
            lambda calls: len(calls) == 2,
        ),
        (
            "Mixed content with tool call",
            'I will search for that.\n<tool_call>\n{"name": "search", "arguments": {"q": "hello"}}\n</tool_call>\nDone.',
            lambda calls: len(calls) == 1 and calls[0]["name"] == "search",
        ),
    ]
    for label, message, check in tool_cases:
        _expect(check(extract_tool_calls(message)), label)
        print(f" [PASS] {label}")

    print("\nAll tests passed.")
|