# chmielvu's picture
# feat: add production refinements (Phase 1-3)
# 4454066 verified
"""
Pydantic validation for tool and workflow outputs.
Handles output validation with automatic JSON repair for malformed responses.
"""
from pydantic import BaseModel, Field, ValidationError, ConfigDict
from typing import Any, Dict, Optional
import json
import logging
logger = logging.getLogger(__name__)
class ToolOutput(BaseModel):
    """
    Canonical envelope for the outcome of a single tool call.

    Tools are expected to return JSON matching this schema. ``success`` is
    the only required field; every other field defaults to ``None``.
    """

    # Undeclared keys are accepted rather than rejected, so a tool may attach
    # extra data beyond the fields listed below.
    model_config = ConfigDict(extra="allow")

    # Required outcome flag.
    success: bool
    # Payload of the call; any type is accepted.
    result: Optional[Any] = None
    # Failure details. The validators in this module fill ``error`` with a
    # message, ``error_type`` with an exception/category name and
    # ``recovery_hint`` with a short explanation.
    error: Optional[str] = None
    error_type: Optional[str] = None
    recovery_hint: Optional[str] = None
    fallback_action: Optional[str] = None
    # Free-form diagnostic data keyed by string.
    metadata: Optional[Dict[str, Any]] = None
class WorkflowOutput(BaseModel):
    """
    Envelope for the outcome of one workflow execution.

    Per the original documentation this is what WorkflowExecutor returns
    after running a workflow. ``success`` is the only required field; every
    other field defaults to ``None``.
    """

    # Undeclared keys are accepted. validate_workflow_output relies on this to
    # attach a ``metadata`` key, which is not a declared field of this model.
    model_config = ConfigDict(extra="allow")

    # Required outcome flag.
    success: bool
    # Payload of the workflow; any type is accepted.
    result: Optional[Any] = None
    # Duration of the run (unit not specified in this file).
    execution_time: Optional[float] = None
    # Execution trace entries; element type is unconstrained.
    trace: Optional[list] = None
    # Results keyed by string (presumably one entry per step - confirm
    # against the executor).
    all_results: Optional[Dict[str, Any]] = None
    # Failure message and exception/category name.
    error: Optional[str] = None
    error_type: Optional[str] = None
def validate_tool_output(raw_output: Any) -> ToolOutput:
    """
    Validate raw tool output and always return a ToolOutput.

    Strategies, tried in order:
        1. A dict is validated directly against the ToolOutput schema.
        2. A string is parsed as JSON; if it decodes to a JSON object, that
           object is validated against the schema.
        3. Anything else - including strings that are not JSON and strings
           that decode to a non-object JSON value such as ``"42"`` or
           ``"[1, 2]"`` - is passed through unchanged as ``result`` with
           ``success=True``.

    Schema violations in strategies 1 and 2 do not raise; they are reported
    as a ToolOutput with ``success=False``.

    Note: this function does not call ``repair_json_output``; JSON embedded
    in surrounding text is treated as plain text here.

    Args:
        raw_output: Raw tool output (str, dict, or other)

    Returns:
        Validated ToolOutput instance
    """
    try:
        # Strategy 1: already a dict of fields.
        if isinstance(raw_output, dict):
            return ToolOutput(**raw_output)

        # Strategy 2: string that may contain a JSON object.
        if isinstance(raw_output, str):
            try:
                data = json.loads(raw_output)
            except json.JSONDecodeError as e:
                logger.warning(f"JSON decode failed: {e}")
                # Not JSON: handled as plain text by strategy 3 below.
            else:
                if isinstance(data, dict):
                    return ToolOutput(**data)
                # Valid JSON but not an object (number, list, null, ...).
                # ``ToolOutput(**data)`` would raise TypeError here, so the
                # string is handled like any other plain text instead.

        # Strategy 3: wrap the value unchanged as a successful result.
        return ToolOutput(
            success=True,
            result=raw_output,
            metadata={"original_type": type(raw_output).__name__},
        )
    except ValidationError as e:
        logger.warning(f"Tool output validation failed: {e}")
        # The data had the right container type but the wrong fields:
        # report it in the standard error format instead of raising.
        return ToolOutput(
            success=False,
            error=f"Invalid tool output format: {str(e)}",
            error_type="ValidationError",
            recovery_hint="Tool returned malformed output - expected ToolOutput schema",
            metadata={
                "raw_output": str(raw_output)[:500],  # Truncate to prevent huge logs
                "validation_errors": str(e),
            },
        )
    except Exception as e:
        # Last-resort boundary so callers never see an exception from here.
        logger.error(f"Unexpected error validating tool output: {e}", exc_info=True)
        return ToolOutput(
            success=False,
            error=f"Validation error: {str(e)}",
            error_type=type(e).__name__,
            recovery_hint="Unexpected validation failure",
            metadata={"raw_output": str(raw_output)[:500]},
        )
def validate_workflow_output(raw_output: Any) -> WorkflowOutput:
    """
    Validate raw workflow output and always return a WorkflowOutput.

    Only dict input is accepted. Schema violations and non-dict input do not
    raise; they are reported as a WorkflowOutput with ``success=False`` whose
    extra ``metadata`` key carries a truncated copy of the input.

    Args:
        raw_output: Raw workflow output (dict expected)

    Returns:
        Validated WorkflowOutput instance
    """
    try:
        if isinstance(raw_output, dict):
            return WorkflowOutput(**raw_output)
        # Anything that is not a dict is rejected; the error is converted to
        # a failure result by the generic handler below.
        raise ValueError(f"Workflow output must be dict, got {type(raw_output).__name__}")
    except ValidationError as exc:
        logger.error(f"Workflow output validation failed: {exc}")
        failure_fields = {
            "success": False,
            "error": f"Invalid workflow output format: {str(exc)}",
            "error_type": "ValidationError",
            "metadata": {
                # Truncated so oversized payloads do not bloat the result.
                "raw_output": str(raw_output)[:500],
                "validation_errors": str(exc),
            },
        }
        return WorkflowOutput(**failure_fields)
    except Exception as exc:
        logger.error(f"Unexpected error validating workflow output: {exc}", exc_info=True)
        failure_fields = {
            "success": False,
            "error": f"Validation error: {str(exc)}",
            "error_type": type(exc).__name__,
            "metadata": {"raw_output": str(raw_output)[:500]},
        }
        return WorkflowOutput(**failure_fields)
def repair_json_output(raw_output: str) -> Dict[str, Any]:
    """
    Best-effort conversion of raw tool text into a dict.

    Strategies, tried in order:
        1. Parse the whole string as JSON.
        2. Decode the JSON value that starts at the first ``{`` or ``[``,
           ignoring any text before it and any text after its end.
        3. Wrap the text unchanged in a ``result`` field.

    A value decoded by strategy 1 or 2 that is not a JSON object (a list,
    number, string, ...) is wrapped in a ``result`` field, so the return
    value is always a dict. Decoded JSON objects are returned as-is and are
    not checked against the ToolOutput schema here.

    Args:
        raw_output: Raw string output

    Returns:
        Dict intended to be validated as ToolOutput
    """

    def _as_dict(value: Any) -> Dict[str, Any]:
        # JSON objects pass through untouched; any other JSON value becomes
        # the payload of a successful result so callers always get a dict.
        if isinstance(value, dict):
            return value
        return {
            "success": True,
            "result": value,
            "metadata": {"repaired": True, "original_format": "json_value"},
        }

    # Strategy 1: the whole string is valid JSON.
    try:
        return _as_dict(json.loads(raw_output))
    except json.JSONDecodeError:
        pass

    # Strategy 2: JSON embedded in other text. Start at whichever of '{' or
    # '[' occurs first.
    candidates = [
        pos for pos in (raw_output.find("{"), raw_output.find("[")) if pos >= 0
    ]
    if candidates:
        embedded = raw_output[min(candidates):]
        try:
            # raw_decode stops at the end of the first complete JSON value,
            # so trailing text after the closing brace/bracket is tolerated
            # (json.loads would reject it as "Extra data").
            value, _end = json.JSONDecoder().raw_decode(embedded)
        except json.JSONDecodeError:
            pass
        else:
            return _as_dict(value)

    # Strategy 3: nothing decodable - keep the text as the result.
    logger.warning("Could not parse as JSON - wrapping as plain text")
    return {
        "success": True,
        "result": raw_output,
        "metadata": {"repaired": True, "original_format": "plain_text"},
    }
def ensure_tool_output_schema(func):
    """
    Decorator that makes a tool callable always return ToolOutput JSON.

    The wrapped callable's return value is passed through
    validate_tool_output and serialized with two-space indentation. Any
    exception raised while calling, validating or serializing is logged and
    converted into a ``success=False`` ToolOutput, also returned as JSON.

    Usage:
        @ensure_tool_output_schema
        def forward(self, query: str) -> str:
            ...
    """
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            # Call, validate and serialize in one guarded step so that a
            # failure at any stage yields the standard error format.
            return validate_tool_output(func(*args, **kwargs)).model_dump_json(indent=2)
        except Exception as exc:
            logger.error(f"Tool execution failed: {exc}", exc_info=True)
            failure = ToolOutput(
                success=False,
                error=str(exc),
                error_type=type(exc).__name__,
                recovery_hint="Tool execution raised exception",
            )
            return failure.model_dump_json(indent=2)

    return wrapper