# general-reasoning-agent / test_validation.py
# Commit 4454066 (verified) by chmielvu: "feat: add production refinements (Phase 1-3)"
"""
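Test script for output validation with Pydantic.

Tests:
1. Valid JSON string parsing
2. Malformed and non-JSON string handling
3. Dict input validation
4. Invalid schema (missing required field) handling
5. Primitive type wrapping
6. Workflow output validation
7. Invalid workflow output handling
8. JSON repair strategies
9. ensure_tool_output_schema decorator
10. Error output format

Run directly as a script; the process exits with status 1 if any test fails.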
"""
import sys
import json
from core.validation import (
ToolOutput,
WorkflowOutput,
validate_tool_output,
validate_workflow_output,
repair_json_output,
ensure_tool_output_schema
)
def test_valid_json():
    """Validate a well-formed JSON tool payload.

    Serializes a dict with ``success``, ``result`` and ``metadata`` keys and
    checks that ``validate_tool_output`` reports success and returns the
    original ``result`` value unchanged.
    """
    print("\n=== Test 1: Valid JSON ===")
    payload = {
        "success": True,
        "result": {"data": "test"},
        "metadata": {"key": "value"},
    }
    raw_text = json.dumps(payload)
    output = validate_tool_output(raw_text)

    print(f"Success: {output.success}")
    print(f"Result: {output.result}")
    print(f"Metadata: {output.metadata}")

    assert output.success is True
    assert output.result == {"data": "test"}
    print("✓ Valid JSON test passed")
def test_malformed_json():
    """Check that malformed or non-JSON strings are handled without raising.

    Each input counts as handled when the validated output either reports
    ``success is False`` or carries a plain-string ``result`` (i.e. the raw
    text was wrapped rather than parsed).
    """
    print("\n=== Test 2: Malformed JSON ===")
    bad_inputs = (
        '{"success": true, "result": "missing closing brace"',
        '{"success": true, "result": undefined}',
        '{success: true}',  # keys are not quoted
        'not json at all',
    )
    for idx, raw in enumerate(bad_inputs, start=1):
        print(f"\nCase {idx}: {raw[:50]}...")
        out = validate_tool_output(raw)
        print(f"Success: {out.success}")
        print(f"Error: {out.error}")
        # Either an explicit failure or a string-wrapped result is acceptable;
        # `or` short-circuits so `result` is only inspected on success.
        handled = out.success is False or isinstance(out.result, str)
        assert handled
        print(f"✓ Case {idx} handled correctly")
def test_dict_input():
    """Validate a tool payload passed directly as a dict (not a JSON string).

    Expects a successful ToolOutput whose ``result`` equals the list given.
    """
    print("\n=== Test 3: Dict Input ===")
    expected_result = [1, 2, 3]
    source = dict(
        success=True,
        result=expected_result,
        metadata={"source": "test"},
    )

    checked = validate_tool_output(source)
    print(f"Success: {checked.success}")
    print(f"Result: {checked.result}")

    assert checked.success is True
    assert checked.result == [1, 2, 3]
    print("✓ Dict input test passed")
def test_invalid_schema():
    """Check that a dict lacking the required ``success`` field is rejected.

    The validator is expected to wrap the problem in an error-shaped output
    with ``success is False`` and ``error_type == "ValidationError"``.
    """
    print("\n=== Test 4: Invalid Schema ===")
    # No 'success' key at all.
    missing_success = dict(result="data", metadata={})

    outcome = validate_tool_output(missing_success)
    for line in (
        f"Success: {outcome.success}",
        f"Error: {outcome.error}",
        f"Error type: {outcome.error_type}",
    ):
        print(line)

    assert outcome.success is False
    assert outcome.error_type == "ValidationError"
    print("✓ Invalid schema test passed")
def test_primitive_types():
    """Check that non-dict values are wrapped as a successful ``result``.

    Covers an int, a plain string, a bool, ``None`` and a list; each must come
    back with ``success is True`` and ``result`` equal to the input.
    """
    print("\n=== Test 5: Primitive Types ===")
    samples = (42, "plain string", True, None, [1, 2, 3])
    for number, value in enumerate(samples, start=1):
        kind = type(value).__name__
        print(f"\nCase {number}: {value} ({kind})")
        wrapped = validate_tool_output(value)
        print(f"Success: {wrapped.success}")
        print(f"Result: {wrapped.result}")
        assert wrapped.success is True
        assert wrapped.result == value
        print(f"✓ Case {number} handled correctly")
def test_workflow_output():
    """Validate a complete workflow payload with ``validate_workflow_output``.

    The payload includes a result, execution time, a task trace and per-task
    results; success and the execution time must be preserved.
    """
    print("\n=== Test 6: Workflow Output ===")
    trace_entries = [{"task": "t1", "status": "completed"}]
    workflow = {
        "success": True,
        "result": "final result",
        "execution_time": 1.5,
        "trace": trace_entries,
        "all_results": {"t1": "data"},
    }

    wf = validate_workflow_output(workflow)
    print(f"Success: {wf.success}")
    print(f"Result: {wf.result}")
    print(f"Execution time: {wf.execution_time}")
    print(f"Trace: {wf.trace}")

    assert wf.success is True
    assert wf.execution_time == 1.5
    print("✓ Workflow output test passed")
def test_workflow_output_invalid():
    """Check that a workflow payload without ``success`` is reported as failed."""
    print("\n=== Test 7: Invalid Workflow Output ===")
    # Only 'result' is supplied; the 'success' field is absent.
    wf = validate_workflow_output(dict(result="data"))
    print(f"Success: {wf.success}")
    print(f"Error: {wf.error}")
    assert wf.success is False
    print("✓ Invalid workflow output test passed")
def test_json_repair():
    """Exercise ``repair_json_output`` on clean, embedded and JSON-free text.

    For every case flagged as expected to work, the repaired value must be a
    dict containing a ``success`` or ``result`` key.
    """
    print("\n=== Test 8: JSON Repair ===")
    clean = '{"success": true, "result": "data"}'
    embedded = 'Some text before {"success": true, "result": "data"} and after'
    no_json = 'This is just plain text without any JSON'
    scenarios = [(clean, True), (embedded, True), (no_json, True)]

    for n, scenario in enumerate(scenarios, 1):
        text, expected_ok = scenario
        print(f"\nCase {n}: {text[:50]}...")
        fixed = repair_json_output(text)
        print(f"Repaired: {fixed}")
        if expected_ok:
            assert isinstance(fixed, dict)
            assert "success" in fixed or "result" in fixed
        print(f"✓ Case {n} repaired successfully")
def test_decorator():
    """Check ``ensure_tool_output_schema`` on a succeeding and a raising tool.

    The decorated callables return JSON strings: the succeeding tool must
    report ``success: true``, and the raising tool must report
    ``success: false`` with ``error_type`` set to the exception class name.
    """
    print("\n=== Test 9: Decorator ===")

    @ensure_tool_output_schema
    def mock_tool_success():
        return dict(success=True, result="test data")

    @ensure_tool_output_schema
    def mock_tool_error():
        raise ValueError("Test error")

    ok_json = mock_tool_success()
    print(f"Success result: {ok_json[:100]}...")
    assert json.loads(ok_json)["success"] is True

    err_json = mock_tool_error()
    print(f"Error result: {err_json[:100]}...")
    err_payload = json.loads(err_json)
    assert err_payload["success"] is False
    assert err_payload["error_type"] == "ValueError"
    print("✓ Decorator test passed")
def test_error_output_format():
    """Check that an invalid JSON string yields a fully-populated error output.

    ``success`` must be False, ``error`` and ``recovery_hint`` must be set,
    and ``error_type`` must be ``"ValidationError"``.
    """
    print("\n=== Test 10: Error Output Format ===")
    broken = "malformed json {{{"
    result = validate_tool_output(broken)

    print(f"Success: {result.success}")
    print(f"Error: {result.error}")
    print(f"Error type: {result.error_type}")
    print(f"Recovery hint: {result.recovery_hint}")
    print(f"Metadata: {result.metadata}")

    # Every error-describing field must be filled in.
    assert result.success is False
    assert result.error is not None
    assert result.error_type == "ValidationError"
    assert result.recovery_hint is not None
    print("✓ Error output format test passed")
def run_all_tests(tests=None):
    """Run validation tests sequentially and print a pass/fail summary.

    Args:
        tests: Optional iterable of zero-argument callables to run. When
            omitted (``None``), the full validation suite defined in this
            module is run.

    Returns:
        True if every test completed without raising, False otherwise.
    """
    import traceback

    if tests is None:
        tests = [
            test_valid_json,
            test_malformed_json,
            test_dict_input,
            test_invalid_schema,
            test_primitive_types,
            test_workflow_output,
            test_workflow_output_invalid,
            test_json_repair,
            test_decorator,
            test_error_output_format,
        ]

    print("=" * 60)
    print("Testing Output Validation with Pydantic")
    print("=" * 60)

    passed = 0
    failed = 0
    for test in tests:
        # Callables such as functools.partial have no __name__.
        name = getattr(test, "__name__", repr(test))
        try:
            test()
        except AssertionError as e:
            # A bare `assert` has an empty message, so the traceback is the
            # only way to tell which assertion failed.
            print(f"\n✗ {name} FAILED: {e}")
            print(traceback.format_exc(), end="")
            failed += 1
        except Exception as e:
            # Top-level runner boundary: report the exception type too, since
            # str(e) alone can be ambiguous (e.g. KeyError shows only the key).
            print(f"\n✗ {name} ERROR: {type(e).__name__}: {e}")
            print(traceback.format_exc(), end="")
            failed += 1
        else:
            passed += 1

    print("\n" + "=" * 60)
    print(f"Test Results: {passed} passed, {failed} failed")
    print("=" * 60)
    return failed == 0
if __name__ == "__main__":
    # Exit status is 0 only when every test in the suite passed.
    exit_code = 0 if run_all_tests() else 1
    sys.exit(exit_code)