|
|
""" |
|
|
Calculator Tool - Safe mathematical expression evaluation |
|
|
Author: @mangubee |
|
|
Date: 2026-01-02 |
|
|
|
|
|
Provides safe evaluation of mathematical expressions with: |
|
|
- Whitelisted operations and functions |
|
|
- Timeout protection (POSIX main thread only; disabled elsewhere)
|
|
- Complexity limits |
|
|
- No access to dangerous built-ins |
|
|
|
|
|
Security is prioritized over functionality. |
|
|
""" |
|
|
|
|
|
import ast |
|
|
import math |
|
|
import operator |
|
|
import logging |
|
|
from typing import Any, Dict |
|
|
import signal |
|
|
from contextlib import contextmanager |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Limits applied by safe_eval() and SafeEvaluator.
MAX_EXPRESSION_LENGTH: int = 500      # characters, measured after stripping whitespace
MAX_EVAL_TIME_SECONDS: int = 2        # handed to timeout(); signal.alarm() needs an int
MAX_NUMBER_SIZE: int = 10 ** 100      # largest absolute value allowed for a numeric literal
|
|
|
|
|
|
|
|
# Whitelist of AST operator node types -> implementing function.
# Any operator node whose type is not a key here is rejected by SafeEvaluator.
SAFE_OPERATORS = {
    # binary arithmetic (SafeEvaluator.visit_BinOp)
    ast.Add:      operator.add,
    ast.Sub:      operator.sub,
    ast.Mult:     operator.mul,
    ast.Div:      operator.truediv,
    ast.FloorDiv: operator.floordiv,
    ast.Mod:      operator.mod,
    ast.Pow:      operator.pow,
    # unary sign (SafeEvaluator.visit_UnaryOp)
    ast.USub:     operator.neg,
    ast.UAdd:     operator.pos,
}
|
|
|
|
|
|
|
|
# Names an expression may reference. Callable entries can be invoked as
# plain functions; the non-callable entries (pi, e) are read as constants.
SAFE_FUNCTIONS = {
    # selected Python built-ins
    **{fn.__name__: fn for fn in (abs, round, min, max, sum)},
    # selected functions from the math module
    **{
        name: getattr(math, name)
        for name in (
            'sqrt', 'ceil', 'floor',
            'log', 'log10', 'exp',
            'sin', 'cos', 'tan',
            'asin', 'acos', 'atan',
            'degrees', 'radians',
            'factorial',
        )
    },
    # mathematical constants
    'pi': math.pi,
    'e': math.e,
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level logger, named after this module in the logging hierarchy.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TimeoutError(Exception):
    """Signal that an expression evaluation ran past its time budget.

    Raised by the ``timeout`` context manager's alarm handler. This class
    shadows the built-in ``TimeoutError`` inside this module and derives
    directly from ``Exception`` (not from ``OSError``).
    """
|
|
|
|
|
|
|
|
@contextmanager
def timeout(seconds: int):
    """
    Context manager for timeout protection.

    Installs a SIGALRM handler and schedules an alarm ``seconds`` from now.
    If the alarm fires while the ``with`` body runs, the handler raises
    ``TimeoutError``. On exit the pending alarm is always cancelled and the
    previous SIGALRM handler is restored.

    Args:
        seconds: Maximum execution time in whole seconds
            (``signal.alarm`` only accepts integers).

    Raises:
        TimeoutError: If execution exceeds timeout

    Note:
        signal.alarm() only works in main thread. In threaded contexts
        (Gradio, ThreadPoolExecutor), timeout protection is disabled; the
        same happens on platforms without SIGALRM (e.g. Windows).
        Python runs signal handlers between bytecode instructions, so a
        single long-running C call is not interrupted until it returns.
    """
    def timeout_handler(signum, frame):
        raise TimeoutError(f"Evaluation exceeded {seconds} second timeout")

    armed = True
    try:
        # ValueError: not called from the main thread.
        # AttributeError: signal.SIGALRM does not exist (Windows).
        previous_handler = signal.signal(signal.SIGALRM, timeout_handler)
    except (ValueError, AttributeError):
        logger.warning("Timeout protection disabled (threading/Windows limitation)")
        armed = False
        previous_handler = None

    try:
        if armed:
            # Inside the try so the handler is restored even if alarm() fails.
            signal.alarm(seconds)
        yield
    finally:
        if armed:
            # Cancel unconditionally: a leftover alarm would later fire into
            # timeout_handler and raise TimeoutError at an unrelated point.
            signal.alarm(0)
            # signal.signal() returns None when the previous handler was not
            # installed from Python; fall back to the default disposition.
            signal.signal(
                signal.SIGALRM,
                previous_handler if previous_handler is not None else signal.SIG_DFL,
            )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SafeEvaluator(ast.NodeVisitor):
    """
    AST visitor that evaluates mathematical expressions safely.

    Only allows whitelisted operations and functions. Every node type that
    has no ``visit_*`` method below is rejected by ``generic_visit``, which
    prevents code execution, attribute access, subscripting and other
    dangerous operations.

    A signal-based timeout cannot interrupt a single long-running C call
    (a huge integer power, ``factorial``, ``round`` of an int with a very
    negative ``ndigits``). The cost of those operations is therefore
    bounded *before* they run, and integer results are capped at
    ``MAX_RESULT_BITS``.
    """

    # Arithmetic is only applied to these operand types. This blocks
    # sequence operations such as ``[0] * 10**9`` (memory exhaustion) and
    # list/tuple concatenation.
    NUMERIC_TYPES = (int, float, complex)

    # Largest absolute exponent accepted for ``**``.
    MAX_EXPONENT = 1000

    # Largest integer, in bits (~300,000 decimal digits), that any operation
    # or function call may produce.
    MAX_RESULT_BITS = 1_000_000

    # Largest argument accepted by factorial(); 50000! has about 708,000
    # bits, which stays below MAX_RESULT_BITS.
    MAX_FACTORIAL_ARG = 50_000

    def _check_result(self, value):
        """Return ``value`` unchanged, rejecting integers above MAX_RESULT_BITS."""
        if isinstance(value, int) and value.bit_length() > self.MAX_RESULT_BITS:
            raise ValueError(f"Result too large: exceeds {self.MAX_RESULT_BITS} bits")
        return value

    def _require_numeric(self, value, op_name):
        """Raise ValueError unless ``value`` is an int, float or complex."""
        if not isinstance(value, self.NUMERIC_TYPES):
            raise ValueError(
                f"Unsupported operand type for {op_name}: {type(value).__name__}"
            )

    def _check_call_cost(self, func_name, args):
        """Reject arguments that would make an uninterruptible C call too expensive."""
        if func_name == 'factorial' and args:
            n = args[0]
            if isinstance(n, (int, float)) and n > self.MAX_FACTORIAL_ARG:
                raise ValueError(
                    f"factorial() argument too large (limit: {self.MAX_FACTORIAL_ARG})"
                )
        elif func_name == 'round' and len(args) > 1:
            number, ndigits = args[0], args[1]
            # CPython rounds an int to negative ndigits by computing
            # 10 ** -ndigits, which is unbounded for huge |ndigits|.
            if (
                isinstance(number, int)
                and isinstance(ndigits, int)
                and ndigits < -self.MAX_EXPONENT
            ):
                raise ValueError(
                    f"round() ndigits too negative (limit: -{self.MAX_EXPONENT})"
                )

    def visit_Expression(self, node):
        """Visit Expression node (root of parse tree)"""
        return self.visit(node.body)

    def visit_Constant(self, node):
        """Visit Constant node; only numeric literals are accepted."""
        value = node.value

        # Strings, bytes, None and Ellipsis are rejected (bool passes as an int).
        if not isinstance(value, self.NUMERIC_TYPES):
            raise ValueError(f"Unsupported constant type: {type(value).__name__}")

        # abs() also bounds imaginary literals such as 1e200j.
        if abs(value) > MAX_NUMBER_SIZE:
            raise ValueError(f"Number too large: {value}")

        return value

    def visit_BinOp(self, node):
        """Visit binary operation node (+, -, *, /, etc.)"""
        op_type = type(node.op)

        if op_type not in SAFE_OPERATORS:
            raise ValueError(f"Unsupported operation: {op_type.__name__}")

        left = self.visit(node.left)
        right = self.visit(node.right)

        for operand in (left, right):
            self._require_numeric(operand, op_type.__name__)

        # Explicit check gives a uniform message for int/float/complex zero.
        if op_type in (ast.Div, ast.FloorDiv, ast.Mod) and right == 0:
            raise ZeroDivisionError("Division by zero")

        if op_type is ast.Pow:
            if abs(right) > self.MAX_EXPONENT:
                raise ValueError(f"Exponent too large: {right}")
            # left.bit_length() * right is an upper bound on the result size;
            # check it before running the (uninterruptible) big-integer power.
            if (
                isinstance(left, int)
                and isinstance(right, int)
                and right > 0
                and left.bit_length() * right > self.MAX_RESULT_BITS
            ):
                raise ValueError(
                    f"Result too large: exceeds {self.MAX_RESULT_BITS} bits"
                )

        op_func = SAFE_OPERATORS[op_type]
        return self._check_result(op_func(left, right))

    def visit_UnaryOp(self, node):
        """Visit unary operation node (-, +)"""
        op_type = type(node.op)

        if op_type not in SAFE_OPERATORS:
            raise ValueError(f"Unsupported unary operation: {op_type.__name__}")

        operand = self.visit(node.operand)
        self._require_numeric(operand, op_type.__name__)

        op_func = SAFE_OPERATORS[op_type]
        return op_func(operand)

    def visit_Call(self, node):
        """Visit function call node; only bare names from SAFE_FUNCTIONS."""
        # Rejects attribute calls such as os.system(...) and calls on expressions.
        if not isinstance(node.func, ast.Name):
            raise ValueError("Only direct function calls are allowed")

        func_name = node.func.id

        if func_name not in SAFE_FUNCTIONS:
            raise ValueError(f"Unsupported function: {func_name}")

        # Arguments go through the same whitelist.
        args = [self.visit(arg) for arg in node.args]

        if node.keywords:
            raise ValueError("Keyword arguments not allowed")

        self._check_call_cost(func_name, args)

        func = SAFE_FUNCTIONS[func_name]

        try:
            result = func(*args)
        except (TypeError, ValueError, ArithmeticError) as e:
            # Deliberately narrow: a TimeoutError raised by the alarm handler
            # must propagate unchanged rather than be reported as a bad call.
            raise ValueError(f"Error calling {func_name}: {str(e)}") from e

        return self._check_result(result)

    def visit_Name(self, node):
        """Visit name node; only non-callable SAFE_FUNCTIONS entries (pi, e) resolve."""
        if node.id in SAFE_FUNCTIONS:
            value = SAFE_FUNCTIONS[node.id]

            # A bare function name (e.g. "sqrt" without a call) is not a value.
            if not callable(value):
                return value

        raise ValueError(f"Undefined name: {node.id}")

    def visit_List(self, node):
        """Visit list node"""
        return [self.visit(element) for element in node.elts]

    def visit_Tuple(self, node):
        """Visit tuple node"""
        return tuple(self.visit(element) for element in node.elts)

    def generic_visit(self, node):
        """Catch-all for unsupported node types"""
        raise ValueError(f"Unsupported expression type: {type(node).__name__}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _summarize_result(result: Any) -> Any:
    """Return ``result`` in a form that is cheap and safe to format for logging.

    Very large integers are replaced by a short description. Converting them
    to ``str`` is slow, and on Python 3.11+ it raises ``ValueError`` beyond
    the interpreter's integer string-conversion limit (4300 digits by default).
    """
    if isinstance(result, int) and result.bit_length() > 10_000:
        return f"<integer with {result.bit_length()} bits>"
    return result


def safe_eval(expression: str) -> Dict[str, Any]:
    """
    Safely evaluate a mathematical expression.

    Args:
        expression: Mathematical expression string

    Returns:
        Dict with structure: {
            "result": evaluation result (int, float or complex; a list or
                      tuple when the expression is a list/tuple literal),
            "expression": str,   # The whitespace-stripped expression
            "success": bool      # True if evaluation succeeded
        }
        For empty, whitespace-only, non-string or over-long input no
        exception is raised. Instead "success" is False, "result" is None
        and an extra "error" key describes the problem.

    Raises:
        ValueError: For invalid or unsafe expressions
        ZeroDivisionError: For division by zero
        TimeoutError: If evaluation exceeds timeout
        SyntaxError: For malformed expressions

    Examples:
        >>> safe_eval("2 + 2")
        {'result': 4, 'expression': '2 + 2', 'success': True}

        >>> safe_eval("sqrt(16) + 3")
        {'result': 7.0, 'expression': 'sqrt(16) + 3', 'success': True}

        >>> safe_eval("__import__('os')")  # Raises ValueError (function not whitelisted)
        >>> safe_eval("import os")  # Raises SyntaxError (statements are not expressions)
    """
    if not expression:
        logger.warning("Calculator received empty or non-string expression - returning graceful error")
        return {
            "result": None,
            "expression": "",
            "success": False,
            "error": "Empty expression provided. Calculator requires a mathematical expression string."
        }

    if not isinstance(expression, str):
        # Truthy non-string input (e.g. a number) is a type problem, not an empty one.
        logger.warning(
            "Calculator received non-string expression (%s) - returning graceful error",
            type(expression).__name__,
        )
        return {
            "result": None,
            "expression": str(expression),
            "success": False,
            "error": f"Expression must be a string, got {type(expression).__name__}."
        }

    expression = expression.strip()

    if not expression:
        logger.warning("Calculator expression was only whitespace - returning graceful error")
        return {
            "result": None,
            "expression": "",
            "success": False,
            "error": "Expression was only whitespace. Provide a valid mathematical expression."
        }

    if len(expression) > MAX_EXPRESSION_LENGTH:
        logger.warning("Expression too long (%d chars) - returning graceful error", len(expression))
        return {
            "result": None,
            "expression": expression[:100] + "...",
            "success": False,
            "error": f"Expression too long ({len(expression)} chars). Maximum: {MAX_EXPRESSION_LENGTH} chars"
        }

    logger.info("Evaluating expression: %s", expression)

    try:
        # mode='eval' only accepts a single expression; statements fail to parse.
        tree = ast.parse(expression, mode='eval')

        with timeout(MAX_EVAL_TIME_SECONDS):
            result = SafeEvaluator().visit(tree)

    except SyntaxError as e:
        logger.error("Syntax error in expression: %s", e)
        raise SyntaxError(f"Invalid expression syntax: {str(e)}") from e
    except ZeroDivisionError:
        logger.error("Division by zero: %s", expression)
        raise
    except TimeoutError:
        logger.error("Evaluation timeout: %s", expression)
        raise
    except ValueError as e:
        logger.error("Invalid expression: %s", e)
        raise
    except Exception as e:
        logger.error("Unexpected error evaluating expression: %s", e)
        raise ValueError(f"Evaluation error: {str(e)}") from e

    # Lazy %-formatting plus summarizing: formatting a huge integer result
    # must never turn a successful evaluation into an exception.
    logger.info("Evaluation successful: %s", _summarize_result(result))

    return {
        "result": result,
        "expression": expression,
        "success": True,
    }
|
|
|