# agentbee / test/test_calculator.py
# mangubee's picture
# fix: correct author name formatting in multiple files
# e7b4937
"""
Tests for calculator tool (safe mathematical evaluation)
Author: @mangubee
Date: 2026-01-02
Tests cover:
- Basic arithmetic operations
- Mathematical functions
- Safety checks (no code execution, no imports, etc.)
- Timeout protection
- Complexity limits
- Error handling
"""
import pytest
from src.tools.calculator import safe_eval
# ============================================================================
# Basic Arithmetic Tests
# ============================================================================
def test_addition():
    """Adding 2 and 3 must yield 5 with the success flag set to True."""
    outcome = safe_eval("2 + 3")
    assert outcome["result"] == 5
    assert outcome["success"] is True
def test_subtraction():
    """Subtracting 4 from 10 must evaluate to 6."""
    assert safe_eval("10 - 4")["result"] == 6
def test_multiplication():
    """Multiplying 6 by 7 must evaluate to 42."""
    product = safe_eval("6 * 7")["result"]
    assert product == 42
def test_division():
    """True division of 15 by 3 must evaluate to 5.0."""
    quotient = safe_eval("15 / 3")["result"]
    assert quotient == 5.0
def test_floor_division():
    """Floor-dividing 17 by 5 must evaluate to 3."""
    assert safe_eval("17 // 5")["result"] == 3
def test_modulo():
    """The remainder of 17 divided by 5 must evaluate to 2."""
    remainder = safe_eval("17 % 5")["result"]
    assert remainder == 2
def test_exponentiation():
    """Raising 2 to the 8th power must evaluate to 256."""
    assert safe_eval("2 ** 8")["result"] == 256
def test_negative_numbers():
    """A leading unary minus is honoured: -5 plus 3 must evaluate to -2."""
    outcome = safe_eval("-5 + 3")
    assert outcome["result"] == -2
def test_complex_expression():
    """A mix of parentheses, product, difference and quotient gives 15.0."""
    expression = "(2 + 3) * 4 - 10 / 2"
    assert safe_eval(expression)["result"] == 15.0
# ============================================================================
# Mathematical Function Tests
# ============================================================================
def test_sqrt():
    """The square root of 16 must evaluate to 4.0."""
    root = safe_eval("sqrt(16)")["result"]
    assert root == 4.0
def test_abs():
    """The absolute value of -42 must evaluate to 42."""
    assert safe_eval("abs(-42)")["result"] == 42
def test_round():
    """Rounding 3.7 must evaluate to 4."""
    rounded = safe_eval("round(3.7)")["result"]
    assert rounded == 4
def test_min():
    """The smallest of the arguments 5, 2, 8, 1 must be reported as 1."""
    smallest = safe_eval("min(5, 2, 8, 1)")["result"]
    assert smallest == 1
def test_max():
    """The largest of the arguments 5, 2, 8, 1 must be reported as 8."""
    largest = safe_eval("max(5, 2, 8, 1)")["result"]
    assert largest == 8
def test_trigonometric():
    """Sine and cosine at zero must evaluate to 0.0 and 1.0 respectively."""
    # Checked in the same order as before: sine first, then cosine.
    cases = (("sin(0)", 0.0), ("cos(0)", 1.0))
    for expression, expected in cases:
        assert safe_eval(expression)["result"] == expected
def test_logarithm():
    """The base-10 logarithm of 100 must evaluate to 2.0."""
    assert safe_eval("log10(100)")["result"] == 2.0
def test_constants():
    """Check that the names ``pi`` and ``e`` evaluate to the math constants.

    The results are compared against ``math.pi`` and ``math.e`` through
    ``pytest.approx`` instead of against five-decimal literals with a 1e-3
    tolerance, which would also accept a visibly wrong value such as 3.1415.
    """
    # Local import keeps this test self-contained; the module itself does
    # not import ``math``.
    import math

    # NOTE(review): assumes safe_eval exposes the constants at (close to)
    # full float precision -- confirm against the calculator implementation.
    pi_outcome = safe_eval("pi")
    assert pi_outcome["result"] == pytest.approx(math.pi)

    e_outcome = safe_eval("e")
    assert e_outcome["result"] == pytest.approx(math.e)
def test_factorial():
    """The factorial of 5 must evaluate to 120."""
    assert safe_eval("factorial(5)")["result"] == 120
def test_nested_functions():
    """A call nested inside another call is evaluated: sqrt(abs(-16)) is 4.0."""
    expression = "sqrt(abs(-16))"
    assert safe_eval(expression)["result"] == 4.0
# ============================================================================
# Security Tests
# ============================================================================
def test_no_import():
    """An import statement must be rejected with SyntaxError."""
    statement = "import os"
    with pytest.raises(SyntaxError):
        safe_eval(statement)
def test_no_exec():
    """A call to exec must be rejected with ValueError or SyntaxError."""
    rejected = (ValueError, SyntaxError)
    with pytest.raises(rejected):
        safe_eval("exec('print(1)')")
def test_no_eval():
    """A call to eval must be rejected with ValueError or SyntaxError."""
    rejected = (ValueError, SyntaxError)
    with pytest.raises(rejected):
        safe_eval("eval('1+1')")
def test_no_lambda():
    """A lambda expression must be rejected with ValueError or SyntaxError."""
    rejected = (ValueError, SyntaxError)
    with pytest.raises(rejected):
        safe_eval("lambda x: x + 1")
def test_no_attribute_access():
    """Reading a dunder attribute off a literal must raise ValueError."""
    expression = "(1).__class__"
    with pytest.raises(ValueError):
        safe_eval(expression)
def test_no_list_comprehension():
    """A list comprehension must be rejected with ValueError."""
    expression = "[x for x in range(10)]"
    with pytest.raises(ValueError):
        safe_eval(expression)
def test_no_dict_access():
    """A dict literal must be rejected with ValueError or SyntaxError."""
    rejected = (ValueError, SyntaxError)
    with pytest.raises(rejected):
        safe_eval("{'a': 1}")
def test_no_undefined_names():
    """An unknown identifier must raise ValueError mentioning 'Undefined name'."""
    expression = "undefined_variable + 1"
    with pytest.raises(ValueError, match="Undefined name"):
        safe_eval(expression)
def test_no_dangerous_functions():
    """Calling open must raise ValueError mentioning 'Unsupported function'."""
    expression = "open('file.txt')"
    with pytest.raises(ValueError, match="Unsupported function"):
        safe_eval(expression)
# ============================================================================
# Error Handling Tests
# ============================================================================
def test_division_by_zero():
    """Dividing by zero must propagate ZeroDivisionError to the caller."""
    expression = "10 / 0"
    with pytest.raises(ZeroDivisionError):
        safe_eval(expression)
def test_invalid_syntax():
    """A malformed operator sequence must raise SyntaxError."""
    malformed = "2 +* 3"
    with pytest.raises(SyntaxError):
        safe_eval(malformed)
def test_empty_expression():
    """An empty string must yield an error dict instead of an exception.

    The dict must carry ``success`` False, an ``error`` message containing
    'Empty expression', and a ``result`` of None.
    """
    outcome = safe_eval("")
    assert outcome["success"] is False
    assert "Empty expression" in outcome["error"]
    assert outcome["result"] is None
def test_too_long_expression():
    """An over-long expression must yield an error dict, not an exception.

    The input is 300 repetitions of '1 + ' followed by '1' (1201 characters).
    The returned dict must carry ``success`` False, an ``error`` message
    containing 'too long', and a ``result`` of None.
    """
    oversized_expr = "1 + " * 300 + "1"
    outcome = safe_eval(oversized_expr)
    assert outcome["success"] is False
    assert "too long" in outcome["error"]
    assert outcome["result"] is None
def test_huge_exponent():
    """An exponent of 10000 must raise ValueError mentioning 'Exponent too large'."""
    expression = "2 ** 10000"
    with pytest.raises(ValueError, match="Exponent too large"):
        safe_eval(expression)
def test_sqrt_negative():
    """Taking the square root of -1 must raise ValueError."""
    expression = "sqrt(-1)"
    with pytest.raises(ValueError):
        safe_eval(expression)
def test_factorial_negative():
    """Taking the factorial of -5 must raise ValueError."""
    expression = "factorial(-5)"
    with pytest.raises(ValueError):
        safe_eval(expression)
# ============================================================================
# Edge Case Tests
# ============================================================================
def test_whitespace_handling():
    """Leading and trailing spaces around the expression must not matter."""
    padded = " 2 + 3 "
    assert safe_eval(padded)["result"] == 5
def test_floating_point():
    """Multiplying a float literal by 2 must give the expected float product.

    The comparison uses ``pytest.approx`` (default relative tolerance 1e-6)
    rather than a hand-rolled ``abs(diff) < 0.01`` check, so a result that is
    off in the third decimal place no longer passes.
    """
    outcome = safe_eval("3.14 * 2")
    assert outcome["result"] == pytest.approx(6.28)
def test_very_small_numbers():
    """Adding two small floats must give their sum up to rounding error.

    The comparison uses ``pytest.approx`` (default relative tolerance 1e-6)
    because the previous absolute tolerance of 1e-5 was roughly 3% of the
    expected value 0.0003 and would hide a noticeably wrong result.
    """
    outcome = safe_eval("0.0001 + 0.0002")
    assert outcome["result"] == pytest.approx(0.0003)
def test_scientific_notation():
    """Literals written with an exponent are accepted: 1e3 + 2e2 is 1200.0."""
    total = safe_eval("1e3 + 2e2")["result"]
    assert total == 1200.0
def test_parentheses_precedence():
    """Parentheses must override the default operator precedence.

    Without parentheses the product binds tighter (14); with them the sum
    is evaluated first (20).
    """
    unparenthesised = safe_eval("2 + 3 * 4")
    assert unparenthesised["result"] == 14

    parenthesised = safe_eval("(2 + 3) * 4")
    assert parenthesised["result"] == 20
def test_multiple_operations():
    """A chain of mixed operators must evaluate to 28.0."""
    expression = "10 + 20 - 5 * 2 / 2 + 3"
    assert safe_eval(expression)["result"] == 28.0