Spaces:
Sleeping
Sleeping
| """ | |
| Pydantic validation for tool and workflow outputs. | |
| Handles output validation with automatic JSON repair for malformed responses. | |
| """ | |
| from pydantic import BaseModel, Field, ValidationError, ConfigDict | |
| from typing import Any, Dict, Optional | |
| import json | |
| import logging | |
| logger = logging.getLogger(__name__) | |
class ToolOutput(BaseModel):
    """
    Envelope that every tool's JSON output is expected to match.

    ``success`` is the only required field; all other declared fields
    default to ``None``. Keys that are not declared here are accepted
    rather than rejected, because the model is configured with
    ``extra="allow"``.
    """

    # Accept undeclared keys so tools can attach additional data.
    model_config = ConfigDict(extra="allow")

    # Outcome flag and payload.
    success: bool
    result: Optional[Any] = None

    # Failure description; left as None when not supplied.
    error: Optional[str] = None
    error_type: Optional[str] = None
    recovery_hint: Optional[str] = None
    fallback_action: Optional[str] = None

    # Free-form side information keyed by string.
    metadata: Optional[Dict[str, Any]] = None
class WorkflowOutput(BaseModel):
    """
    Result envelope for a workflow execution.

    ``success`` is the only required field; all other declared fields
    default to ``None``. Undeclared keys are accepted because the model is
    configured with ``extra="allow"``.
    """

    # Accept undeclared keys instead of rejecting them.
    model_config = ConfigDict(extra="allow")

    # Outcome flag and payload.
    success: bool
    result: Optional[Any] = None

    # Execution details; left as None when not supplied.
    execution_time: Optional[float] = None
    trace: Optional[list] = None
    all_results: Optional[Dict[str, Any]] = None

    # Failure description; left as None when not supplied.
    error: Optional[str] = None
    error_type: Optional[str] = None
def validate_tool_output(raw_output: Any) -> ToolOutput:
    """
    Validate and parse tool output, always returning a ToolOutput.

    Strategies, tried in order:
        1. A dict is validated directly against the ToolOutput schema.
        2. A string is parsed as JSON; if it decodes to a JSON object,
           that object is validated against the schema.
        3. Everything else is wrapped unchanged as a successful result:
           non-string values, strings that are not JSON, and strings whose
           JSON decodes to a non-object (list, number, bool, null, string).

    Any exception raised along the way is converted into a ToolOutput with
    ``success=False`` instead of propagating to the caller.

    Args:
        raw_output: Raw tool output (str, dict, or other)
    Returns:
        Validated ToolOutput instance
    """
    try:
        # Strategy 1: Already a dict
        if isinstance(raw_output, dict):
            return ToolOutput(**raw_output)

        # Strategy 2: JSON string
        if isinstance(raw_output, str):
            try:
                data = json.loads(raw_output)
            except json.JSONDecodeError as e:
                logger.warning("JSON decode failed: %s", e)
            else:
                # Only a JSON object can be unpacked into the schema. A
                # string such as "42" or "[1, 2]" is valid JSON but not an
                # object; treat it like any other plain string below rather
                # than failing on `**data`.
                if isinstance(data, dict):
                    return ToolOutput(**data)

        # Strategy 3: Anything else - wrap the original value as the result
        return ToolOutput(
            success=True,
            result=raw_output,
            metadata={"original_type": type(raw_output).__name__},
        )
    except ValidationError as e:
        logger.warning("Tool output validation failed: %s", e)
        # Repair strategy: wrap in error format
        return ToolOutput(
            success=False,
            error=f"Invalid tool output format: {str(e)}",
            error_type="ValidationError",
            recovery_hint="Tool returned malformed output - expected ToolOutput schema",
            metadata={
                "raw_output": str(raw_output)[:500],  # Truncate to prevent huge logs
                "validation_errors": str(e),
            },
        )
    except Exception as e:
        logger.error("Unexpected error validating tool output: %s", e, exc_info=True)
        return ToolOutput(
            success=False,
            error=f"Validation error: {str(e)}",
            error_type=type(e).__name__,
            recovery_hint="Unexpected validation failure",
            metadata={"raw_output": str(raw_output)[:500]},
        )
def validate_workflow_output(raw_output: Any) -> WorkflowOutput:
    """
    Turn a raw workflow result into a WorkflowOutput without raising.

    A dict is validated against the WorkflowOutput schema. Any failure
    (non-dict input, schema violation, or another exception) is reported as
    a WorkflowOutput with ``success=False``, carrying the message in
    ``error``, the failure kind in ``error_type``, and a copy of the input
    truncated to 500 characters under the extra ``metadata`` key.

    Args:
        raw_output: Raw workflow output (dict expected)
    Returns:
        Validated WorkflowOutput instance
    """
    try:
        if isinstance(raw_output, dict):
            return WorkflowOutput(**raw_output)
        # Non-dict input is reported through the generic handler below.
        raise ValueError(f"Workflow output must be dict, got {type(raw_output).__name__}")
    except ValidationError as exc:
        logger.error(f"Workflow output validation failed: {exc}")
        failure = {
            "success": False,
            "error": f"Invalid workflow output format: {exc!s}",
            "error_type": "ValidationError",
            "metadata": {
                "raw_output": str(raw_output)[:500],
                "validation_errors": str(exc),
            },
        }
        return WorkflowOutput(**failure)
    except Exception as exc:
        logger.error(f"Unexpected error validating workflow output: {exc}", exc_info=True)
        failure = {
            "success": False,
            "error": f"Validation error: {exc!s}",
            "error_type": type(exc).__name__,
            "metadata": {"raw_output": str(raw_output)[:500]},
        }
        return WorkflowOutput(**failure)
def repair_json_output(raw_output: str) -> Dict[str, Any]:
    """
    Attempt to recover a dict from tool output that may not be clean JSON.

    Strategies, tried in order:
        1. Parse the whole string as JSON.
        2. Parse the JSON document that starts at the first ``{`` or ``[``,
           ignoring any text before it and after its end.
        3. Wrap the original text in a successful result envelope.

    A decoded JSON object is returned as-is. A decoded value that is not an
    object (list, number, string, bool, null) is wrapped in a result
    envelope so that the return value is always a dict.

    Args:
        raw_output: Raw string output
    Returns:
        Dict intended to be validated as ToolOutput
    """
    decoded: Any = None
    found = False

    try:
        # Strategy 1: the whole string is JSON
        decoded = json.loads(raw_output)
        found = True
    except json.JSONDecodeError:
        # Strategy 2: JSON embedded in surrounding text
        candidates = [
            pos
            for pos in (raw_output.find("{"), raw_output.find("["))
            if pos >= 0
        ]
        if candidates:
            try:
                # raw_decode stops at the end of the first complete JSON
                # document, so trailing text no longer breaks extraction.
                decoded, _ = json.JSONDecoder().raw_decode(raw_output[min(candidates):])
                found = True
            except json.JSONDecodeError:
                pass

    if found:
        if isinstance(decoded, dict):
            return decoded
        # Valid JSON that is not an object: keep the parsed value, but wrap
        # it so the declared Dict return type holds.
        return {
            "success": True,
            "result": decoded,
            "metadata": {"repaired": True, "original_format": "json_value"},
        }

    # Strategy 3: Wrap as plain text result
    logger.warning("Could not parse as JSON - wrapping as plain text")
    return {
        "success": True,
        "result": raw_output,
        "metadata": {"repaired": True, "original_format": "plain_text"},
    }
def ensure_tool_output_schema(func):
    """
    Decorator to ensure tool output conforms to ToolOutput schema.

    The wrapped callable's return value is passed through
    validate_tool_output and returned as an indented JSON string. If the
    call raises, the exception is logged and a JSON-encoded ToolOutput with
    ``success=False`` is returned instead of propagating the exception.

    The wrapper keeps the decorated function's name, docstring and other
    metadata (via functools.wraps).

    Usage:
        @ensure_tool_output_schema
        def forward(self, query: str) -> str:
            ...

    Args:
        func: Callable to wrap (for example a tool's forward() method)
    Returns:
        Wrapper that returns a JSON string in ToolOutput format
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            result = func(*args, **kwargs)
            # Validate and return JSON
            validated = validate_tool_output(result)
            return validated.model_dump_json(indent=2)
        except Exception as e:
            logger.error(f"Tool execution failed: {e}", exc_info=True)
            # Return error in standard format
            error_output = ToolOutput(
                success=False,
                error=str(e),
                error_type=type(e).__name__,
                recovery_hint="Tool execution raised exception",
            )
            return error_output.model_dump_json(indent=2)

    return wrapper