# Hugging Face Hub upload metadata (web-UI residue, kept as comments so the
# module remains valid Python):
#   ashish-sarvam's picture
#   Upload folder using huggingface_hub
#   fc1a684 verified
"""
Simplified two-stage parsing for LLM responses.
Stage 1: normalize_llm_response() - Clean and extract valid JSON
Stage 2: parse_action() - Detect tool/KB actions from normalized JSON
"""
import json
import re
from typing import Any, Dict, Optional
def _strip_code_fences(s: str) -> str:
    """Remove surrounding ``` / ```json / ```python fences, if present."""
    if s.startswith("```") and s.endswith("```"):
        s = re.sub(
            r"^```(?:json|python)?\s*|\s*```$", "", s, flags=re.IGNORECASE
        ).strip()
    return s


def _strip_speaker_label(s: str) -> str:
    """Remove a single leading "Agent:", "Assistant:" or "User:" label."""
    return re.sub(
        r"^\s*(agent|assistant|user)\s*:\s*", "", s, flags=re.IGNORECASE
    ).strip()


def _extract_first_json_object(s: str) -> Optional[Dict[str, Any]]:
    """Find and parse the first balanced JSON object embedded in *s*.

    Tracks string literals and backslash escapes so that braces inside
    quoted values (e.g. ``{"a": "b}c"}``) do not corrupt the balance
    count. Returns None if no parseable JSON dict is found.
    """
    start = s.find("{")
    if start == -1:
        return None
    depth = 0
    in_string = False
    escaped = False
    for i in range(start, len(s)):
        ch = s[i]
        if in_string:
            # Inside a string literal: only track escapes and the
            # closing quote; braces here are data, not structure.
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
            continue
        if ch == '"':
            in_string = True
        elif ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                try:
                    obj = json.loads(s[start : i + 1])
                except json.JSONDecodeError:
                    return None
                return obj if isinstance(obj, dict) else None
    return None


def normalize_llm_response(reply: str) -> Dict[str, Any]:
    """
    Normalize LLM response to valid JSON.

    Handles:
    - Chat wrappers: {"role": "...", "content": "..."}
    - Code fences: ```json ... ```
    - Labels: "Agent:", "Assistant:"
    - Plain text (returns as {"text": "..."})

    Args:
        reply: Raw LLM response string (None is treated as empty)

    Returns:
        Dict with at least {"text": "..."} key, or the parsed JSON dict
        when the reply contains one.
    """
    s = (reply or "").strip()
    if not s:
        return {"text": ""}

    # Try to parse as JSON directly.
    try:
        obj = json.loads(s)
    except json.JSONDecodeError:
        obj = None
    if isinstance(obj, dict):
        # Chat wrapper {"role": ..., "content": "..."}: unwrap and
        # recurse, since the content itself may be fenced/labelled JSON.
        # (Terminates: the content is strictly shorter than the wrapper.)
        if isinstance(obj.get("content"), str):
            return normalize_llm_response(obj["content"].strip())
        # Already a valid JSON dict - return as-is.
        return obj

    # Clean common decorations, then retry.
    s = _strip_code_fences(s)
    s = _strip_speaker_label(s)

    if s.startswith("{") and s.endswith("}"):
        try:
            cleaned = json.loads(s)
            if isinstance(cleaned, dict):
                return cleaned
        except json.JSONDecodeError:
            pass

    # Last resort: pull the first balanced JSON object out of the text.
    embedded = _extract_first_json_object(s)
    if embedded is not None:
        return embedded

    # Fallback: wrap plain text.
    return {"text": s}
def _coerce_str(value: Any) -> str:
    """Return *value* stripped if it is a string, else "" (LLM output is
    untrusted, so a field may carry any JSON type)."""
    return value.strip() if isinstance(value, str) else ""


def parse_action(normalized_response: Dict[str, Any]) -> Dict[str, Any]:
    """
    Parse normalized JSON to detect tool calls, KB queries, or plain text.

    Expected formats:
        Tool execution:
            {
                "text": "Let me check that...",
                "tool_execution": [
                    {"function": "...", "params": {...}},
                    ...
                ]
            }
        KB retrieval:
            {
                "text": "Let me look that up...",
                "kb_retrieval": {
                    "query": "...",
                    "kb_name": "..."  # optional
                }
            }
        Plain text:
            {
                "text": "Here's the answer..."
            }

    Args:
        normalized_response: Normalized JSON dict from stage 1

    Returns:
        Dict with:
        - type: "tool_execution" | "kb_retrieval" | "text_only"
        - Additional fields based on type
    """
    if not isinstance(normalized_response, dict):
        return {
            "type": "text_only",
            "text": str(normalized_response),
        }

    # Check for KB retrieval (takes priority over tool execution).
    kb_obj = normalized_response.get("kb_retrieval")
    if isinstance(kb_obj, dict):
        query = _coerce_str(kb_obj.get("query"))
        if query:  # Valid KB query
            return {
                "type": "kb_retrieval",
                "query": query,
                # Empty/absent kb_name maps to None (caller's default KB).
                "kb_name": _coerce_str(kb_obj.get("kb_name")) or None,
                "pre_text": _coerce_str(normalized_response.get("text")),
            }

    # Check for tool execution.
    tool_exec = normalized_response.get("tool_execution")
    if isinstance(tool_exec, list) and tool_exec:
        return {
            "type": "tool_execution",
            "executions": tool_exec,
            "pre_text": _coerce_str(normalized_response.get("text")),
        }

    # Plain text (or invalid format).
    text = _coerce_str(normalized_response.get("text"))
    if not text:
        # No usable text field: serialize the whole dict as text.
        text = json.dumps(normalized_response)
    return {
        "type": "text_only",
        "text": text,
    }
def extract_text(normalized_response: Dict[str, Any]) -> str:
    """
    Extract just the text content from a normalized response.

    Args:
        normalized_response: Normalized JSON dict

    Returns:
        Text string (stripped); "" when the "text" field is missing or None.
    """
    if not isinstance(normalized_response, dict):
        return str(normalized_response).strip()
    text = normalized_response.get("text", "")
    if isinstance(text, str):
        return text.strip()
    # Defensive: LLM output may put a non-string under "text"; previously
    # this raised AttributeError on .strip().
    return "" if text is None else str(text).strip()
def extract_text_from_llm_response(reply: str) -> str:
    """
    One-shot helper: normalize a raw LLM reply and return its text.

    Equivalent to normalize_llm_response() followed by extract_text();
    useful when tool/KB actions are irrelevant and only the text matters.

    Args:
        reply: Raw LLM response string

    Returns:
        Extracted text string
    """
    return extract_text(normalize_llm_response(reply))
def serialize_memory(memory: Any) -> str:
    """Render *memory* as a string, using JSON for dicts and lists.

    Non-ASCII characters are kept as-is (ensure_ascii=False). Falls back
    to str(memory) for other types or when JSON serialization fails
    (e.g. a non-serializable value nested inside).
    """
    if isinstance(memory, (dict, list)):
        try:
            return json.dumps(memory, ensure_ascii=False)
        except Exception:
            pass
    return str(memory)