# agentbee / src/tools/calculator.py
# Last change: e7b4937 "fix: correct author name formatting in multiple files" (mangubee)
"""
Calculator Tool - Safe mathematical expression evaluation
Author: @mangubee
Date: 2026-01-02
Provides safe evaluation of mathematical expressions with:
- Whitelisted operations and functions
- Timeout protection
- Complexity limits
- No access to dangerous built-ins
Security is prioritized over functionality.
"""
import ast
import math
import operator
import logging
from typing import Any, Dict
import signal
from contextlib import contextmanager
# ============================================================================
# CONFIG
# ============================================================================
MAX_EXPRESSION_LENGTH = 500  # Longest (stripped) expression safe_eval will parse
MAX_EVAL_TIME_SECONDS = 2  # Evaluation budget handed to timeout() by safe_eval
MAX_NUMBER_SIZE = 10**100  # Largest magnitude allowed for an int/float literal
# Largest argument accepted by factorial(). math.factorial runs in C, and
# Python signal handlers only run between bytecode instructions, so the
# SIGALRM-based timeout cannot stop e.g. factorial(10**8) once it has started.
MAX_FACTORIAL_ARG = 10_000


def _bounded_factorial(n):
    """Return math.factorial(n), refusing integer arguments above MAX_FACTORIAL_ARG."""
    if isinstance(n, int) and n > MAX_FACTORIAL_ARG:
        raise ValueError(f"factorial() argument exceeds {MAX_FACTORIAL_ARG}")
    return math.factorial(n)


def _bounded_round(number, ndigits=None):
    """
    Behave like built-in round(), minus its pathological slow path.

    For an int, round(n, -k) computes 10**k internally before rounding, so
    round(1, -10**8) would spend minutes in C code the timeout cannot
    interrupt. Since |n| < 2**bit_length, whenever k > n.bit_length() we have
    10**k >= 2**(bit_length + 1) > 2*|n| and the exact answer is 0, so it is
    returned directly.
    """
    if ndigits is None:
        return round(number)
    if (
        isinstance(number, int)
        and isinstance(ndigits, int)
        and -ndigits > number.bit_length()
    ):
        return 0
    return round(number, ndigits)


# Whitelist of safe operations (AST operator node type -> implementation)
SAFE_OPERATORS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.FloorDiv: operator.floordiv,
    ast.Mod: operator.mod,
    ast.Pow: operator.pow,
    ast.USub: operator.neg,
    ast.UAdd: operator.pos,
}

# Whitelist of safe mathematical functions and constants. Non-callable entries
# are constants that may be referenced by bare name.
SAFE_FUNCTIONS = {
    'abs': abs,
    'round': _bounded_round,
    'min': min,
    'max': max,
    'sum': sum,
    # Math module functions
    'sqrt': math.sqrt,
    'ceil': math.ceil,
    'floor': math.floor,
    'log': math.log,
    'log10': math.log10,
    'exp': math.exp,
    'sin': math.sin,
    'cos': math.cos,
    'tan': math.tan,
    'asin': math.asin,
    'acos': math.acos,
    'atan': math.atan,
    'degrees': math.degrees,
    'radians': math.radians,
    'factorial': _bounded_factorial,
    # Constants
    'pi': math.pi,
    'e': math.e,
}
# ============================================================================
# Logging Setup
# ============================================================================
# Module-level logger, registered under this module's import name.
logger = logging.getLogger(name=__name__)
# ============================================================================
# Timeout Context Manager
# ============================================================================
# NOTE: this name shadows the built-in TimeoutError inside this module; the
# timeout() helper and safe_eval raise/catch this class, not the built-in one.
class TimeoutError(Exception):
    """Signal that an expression evaluation ran past its time budget."""
@contextmanager
def timeout(seconds: int):
    """
    Context manager for timeout protection.

    Installs a SIGALRM handler that raises TimeoutError after ``seconds``,
    and on exit (normal or exceptional) cancels the alarm and restores the
    previously installed handler.

    Args:
        seconds: Maximum execution time in whole seconds (passed to signal.alarm).

    Raises:
        TimeoutError: If execution exceeds timeout.

    Note:
        signal.signal() only works in the main thread. In threaded contexts
        (Gradio, ThreadPoolExecutor), or where SIGALRM does not exist
        (Windows), a warning is logged and the block runs unprotected.
        Python runs signal handlers between bytecode instructions, so a single
        long-running C-level call is not interrupted until it returns.
    """
    def timeout_handler(signum, frame):
        raise TimeoutError(f"Evaluation exceeded {seconds} second timeout")

    try:
        # ValueError: signal.signal() called from a non-main thread
        # AttributeError: signal.SIGALRM not available (Windows)
        old_handler = signal.signal(signal.SIGALRM, timeout_handler)
    except (ValueError, AttributeError):
        logger.warning("Timeout protection disabled (threading/Windows limitation)")
        alarm_set = False
        old_handler = None
    else:
        alarm_set = True

    try:
        if alarm_set:
            signal.alarm(seconds)
        yield
    finally:
        if alarm_set:
            # Cancel first so the alarm cannot fire while we restore the handler.
            signal.alarm(0)
            # signal.signal() returns None when the previous handler was not
            # installed from Python; SIG_DFL is the closest restorable value.
            # (Previously a None handler skipped cleanup entirely and left a
            # live alarm that later raised TimeoutError in unrelated code.)
            signal.signal(
                signal.SIGALRM,
                signal.SIG_DFL if old_handler is None else old_handler,
            )
# ============================================================================
# Safe AST Evaluator
# ============================================================================
class SafeEvaluator(ast.NodeVisitor):
    """
    AST visitor that evaluates mathematical expressions safely.

    Only node types with an explicit ``visit_*`` method below are evaluated;
    anything else (attribute access, subscripts, comprehensions, lambdas, ...)
    reaches ``generic_visit`` and is rejected. Operators are looked up in
    SAFE_OPERATORS, and names/callables in SAFE_FUNCTIONS.

    Python runs signal handlers only between bytecode instructions, so a
    SIGALRM-based timeout cannot interrupt one long C-level operation such as
    a huge integer power or a huge list repetition. Those operations are
    therefore size-checked *before* they run.
    """

    # Largest absolute exponent accepted by ``**``.
    MAX_EXPONENT = 1000
    # Cap on the estimated bit length of an ``int ** int`` result. The estimate
    # base.bit_length() * exponent is an upper bound on the real size. This
    # admits the largest power of two allowed literals ((10**100) ** 1000,
    # about 333k bits) but rejects chained powers such as
    # ((9**1000)**1000)**1000 that would need gigabits of memory.
    MAX_POWER_RESULT_BITS = 1_000_000

    def visit_Expression(self, node):
        """Visit Expression node (root of parse tree)"""
        return self.visit(node.body)

    def visit_Constant(self, node):
        """Accept numeric literals only (int, float, complex; bool is an int)."""
        value = node.value
        if not isinstance(value, (int, float, complex)):
            raise ValueError(f"Unsupported constant type: {type(value).__name__}")
        # Prevent huge numbers (complex literals are not size-checked)
        if isinstance(value, (int, float)) and abs(value) > MAX_NUMBER_SIZE:
            raise ValueError(f"Number too large: {value}")
        return value

    def visit_BinOp(self, node):
        """Evaluate a whitelisted binary operation (+, -, *, /, //, %, **)."""
        op_type = type(node.op)
        if op_type not in SAFE_OPERATORS:
            raise ValueError(f"Unsupported operation: {op_type.__name__}")

        left = self.visit(node.left)
        right = self.visit(node.right)

        # Check for division by zero
        if op_type in (ast.Div, ast.FloorDiv, ast.Mod) and right == 0:
            raise ZeroDivisionError("Division by zero")
        if op_type is ast.Mult:
            self._check_repetition(left, right)
        elif op_type is ast.Pow:
            self._check_power(left, right)

        return SAFE_OPERATORS[op_type](left, right)

    @staticmethod
    def _check_repetition(left, right):
        """Reject list/tuple repetition such as ``[0] * 10**9``, which can exhaust memory."""
        if isinstance(left, (list, tuple)) or isinstance(right, (list, tuple)):
            raise ValueError("Sequence repetition is not allowed")

    def _check_power(self, base, exponent):
        """Reject exponentiations whose exponent or integer result would be too large."""
        # Prevent huge exponentiations
        if abs(exponent) > self.MAX_EXPONENT:
            raise ValueError(f"Exponent too large: {exponent}")
        # |base| < 2**bit_length, hence |base|**exponent < 2**(bit_length * exponent).
        if (
            isinstance(base, int)
            and isinstance(exponent, int)
            and exponent > 0
            and base.bit_length() * exponent > self.MAX_POWER_RESULT_BITS
        ):
            raise ValueError("Exponentiation result too large")

    def visit_UnaryOp(self, node):
        """Evaluate a whitelisted unary operation (-x, +x)."""
        op_type = type(node.op)
        if op_type not in SAFE_OPERATORS:
            raise ValueError(f"Unsupported unary operation: {op_type.__name__}")
        operand = self.visit(node.operand)
        return SAFE_OPERATORS[op_type](operand)

    def visit_Call(self, node):
        """Evaluate a call to a whitelisted function with positional arguments only."""
        # Only allow simple function names, not attribute access
        if not isinstance(node.func, ast.Name):
            raise ValueError("Only direct function calls are allowed")
        func_name = node.func.id
        if func_name not in SAFE_FUNCTIONS:
            raise ValueError(f"Unsupported function: {func_name}")
        # Reject keyword arguments before spending any work on the positional ones.
        if node.keywords:
            raise ValueError("Keyword arguments not allowed")

        func = SAFE_FUNCTIONS[func_name]
        if not callable(func):
            # e.g. "pi(2)": SAFE_FUNCTIONS also holds plain constants.
            raise ValueError(f"'{func_name}' is a constant, not a function")

        args = [self.visit(arg) for arg in node.args]
        try:
            return func(*args)
        except Exception as e:
            raise ValueError(f"Error calling {func_name}: {str(e)}") from e

    def visit_Name(self, node):
        """Resolve a bare name to a whitelisted (non-callable) constant."""
        if node.id in SAFE_FUNCTIONS:
            value = SAFE_FUNCTIONS[node.id]
            # If it's a constant (not a function), return it
            if not callable(value):
                return value
        raise ValueError(f"Undefined name: {node.id}")

    def visit_List(self, node):
        """Visit list node"""
        return [self.visit(element) for element in node.elts]

    def visit_Tuple(self, node):
        """Visit tuple node"""
        return tuple(self.visit(element) for element in node.elts)

    def generic_visit(self, node):
        """Catch-all for unsupported node types"""
        raise ValueError(f"Unsupported expression type: {type(node).__name__}")
# ============================================================================
# Public API
# ============================================================================
def _error_result(expression_text: str, message: str) -> Dict[str, Any]:
    """Build the non-raising failure dict that safe_eval returns for rejected input."""
    return {
        "result": None,
        "expression": expression_text,
        "success": False,
        "error": message,
    }


def safe_eval(expression: str) -> Dict[str, Any]:
    """
    Safely evaluate a mathematical expression.

    Args:
        expression: Mathematical expression string.

    Returns:
        On success: {
            "result": the evaluated value (int, float or complex; a list or
                      tuple if the expression is one),
            "expression": the stripped expression,
            "success": True,
        }
        For empty, whitespace-only, non-string or over-long input no exception
        is raised; instead: {"result": None, "expression": <str>,
        "success": False, "error": <message>}.

    Raises:
        SyntaxError: For malformed expressions (statements such as
            ``import os`` do not parse in eval mode).
        ValueError: For unsupported or unsafe constructs and evaluation errors.
        ZeroDivisionError: For division by zero.
        TimeoutError: If evaluation exceeds MAX_EVAL_TIME_SECONDS (only where
            the SIGALRM timeout can be installed).

    Examples:
        >>> safe_eval("2 + 2")
        {'result': 4, 'expression': '2 + 2', 'success': True}
        >>> safe_eval("sqrt(16) + 3")
        {'result': 7.0, 'expression': 'sqrt(16) + 3', 'success': True}
        >>> safe_eval("import os")  # Raises SyntaxError
    """
    # Input validation - relaxed to avoid crashes: problems are reported in
    # the returned dict rather than raised.
    if not expression:
        logger.warning("Calculator received empty or non-string expression - returning graceful error")
        return _error_result(
            "",
            "Empty expression provided. Calculator requires a mathematical expression string.",
        )
    if not isinstance(expression, str):
        logger.warning("Calculator received empty or non-string expression - returning graceful error")
        return _error_result(
            str(expression),
            f"Expression must be a string, got {type(expression).__name__}.",
        )

    expression = expression.strip()
    # Handle case where expression becomes empty after stripping whitespace
    if not expression:
        logger.warning("Calculator expression was only whitespace - returning graceful error")
        return _error_result(
            "",
            "Expression was only whitespace. Provide a valid mathematical expression.",
        )

    if len(expression) > MAX_EXPRESSION_LENGTH:
        logger.warning("Expression too long (%d chars) - returning graceful error", len(expression))
        return _error_result(
            expression[:100] + "...",
            f"Expression too long ({len(expression)} chars). Maximum: {MAX_EXPRESSION_LENGTH} chars",
        )

    logger.info("Evaluating expression: %s", expression)

    try:
        # Parse expression into AST
        tree = ast.parse(expression, mode='eval')
        # Evaluate with timeout protection
        with timeout(MAX_EVAL_TIME_SECONDS):
            result = SafeEvaluator().visit(tree)
    except SyntaxError as e:
        logger.error("Syntax error in expression: %s", e)
        raise SyntaxError(f"Invalid expression syntax: {e}") from e
    except ZeroDivisionError:
        logger.error("Division by zero: %s", expression)
        raise
    except TimeoutError:
        logger.error("Evaluation timeout: %s", expression)
        raise
    except ValueError as e:
        logger.error("Invalid expression: %s", e)
        raise
    except Exception as e:
        logger.error("Unexpected error evaluating expression: %s", e)
        raise ValueError(f"Evaluation error: {e}") from e

    # Lazy %-style args (not an f-string): str() of a very large int raises
    # ValueError on Python versions with the int-to-str digit limit
    # (sys.set_int_max_str_digits). Formatted lazily, such a failure is handled
    # by logging itself instead of turning a successful evaluation into an error.
    logger.info("Evaluation successful: %s", result)
    return {
        "result": result,
        "expression": expression,
        "success": True,
    }