Spaces:
Sleeping
Sleeping
| """Calculator tool implementation.""" | |
| from __future__ import annotations | |
| import ast | |
| import math | |
| import operator | |
| from dataclasses import dataclass, field | |
| from typing import Any | |
| from src.tools.base import Tool, ToolParameter, ToolResult | |
| from src.utils.logging import get_logger | |
# Module-scoped logger; CalculatorTool.execute uses it to record unexpected
# evaluation failures.
logger = get_logger(__name__)
# AST operator node types the evaluator accepts (binary and unary), mapped to
# the functions that implement them.
SAFE_OPERATORS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.FloorDiv: operator.floordiv,
    ast.Mod: operator.mod,
    ast.Pow: operator.pow,
    ast.USub: operator.neg,
    ast.UAdd: operator.pos,
}


def _sum_values(*values):
    """Sum numbers given as separate arguments, or the items of one iterable.

    Expressions can only produce numbers (there are no list/tuple literals),
    so the builtin ``sum``, which requires an iterable, could never be called
    successfully from an expression. Accepting separate arguments makes
    ``sum(1, 2, 3)`` work, while ``sum(iterable)`` keeps working for direct
    Python callers.
    """
    if len(values) == 1 and not isinstance(values[0], (int, float)):
        return sum(values[0])
    return sum(values)


# Names usable in expressions: callable entries may be invoked as functions,
# numeric entries (pi, e) may be used as bare constants.
SAFE_FUNCTIONS = {
    "abs": abs,
    "round": round,
    "min": min,
    "max": max,
    "sum": _sum_values,
    "sqrt": math.sqrt,
    "sin": math.sin,
    "cos": math.cos,
    "tan": math.tan,
    "log": math.log,
    "log10": math.log10,
    "exp": math.exp,
    "floor": math.floor,
    "ceil": math.ceil,
    "pi": math.pi,
    "e": math.e,
}
# NOTE(review): ``parameters`` uses ``dataclasses.field``, which only takes
# effect when the class is processed by ``@dataclass``. No decorator is
# visible here even though ``dataclass`` is imported; presumably ``Tool`` (or
# a decorator missing from this copy) handles it -- confirm, otherwise
# ``parameters`` is a bare ``Field`` object at runtime.
class CalculatorTool(Tool):
    """Tool for performing mathematical calculations.

    Expressions are parsed with :mod:`ast` and evaluated by walking the tree;
    only numeric constants, the operators in ``SAFE_OPERATORS`` and the names
    in ``SAFE_FUNCTIONS`` are accepted, and nothing is passed to ``eval``.
    """

    name: str = "calculator"
    description: str = "Evaluate mathematical expressions and perform calculations. Supports basic arithmetic, powers, and common math functions (sqrt, sin, cos, log, etc.)."
    parameters: list[ToolParameter] = field(default_factory=lambda: [
        ToolParameter(
            name="expression",
            type="string",
            description="Mathematical expression to evaluate (e.g., '2 + 2 * 3', 'sqrt(16)', '15% of 200')",
            required=True,
        ),
        ToolParameter(
            name="precision",
            type="integer",
            description="Number of decimal places for the result",
            required=False,
            default=4,
        ),
    ])

    async def execute(self, **kwargs: Any) -> ToolResult:
        """Execute a calculation.

        Args:
            expression: Mathematical expression to evaluate. Non-string values
                (e.g. a bare number) are converted with ``str``.
            precision: Decimal places for float results (default 4). ``None``
                falls back to the default; integer strings such as ``"2"``
                are accepted.

        Returns:
            ToolResult with the original expression, the preprocessed
            expression, the result and the result's type name, or a failed
            ToolResult describing why evaluation did not succeed.
        """
        raw_expression = kwargs.get("expression", "")
        expression = "" if raw_expression is None else str(raw_expression)
        if not expression.strip():
            return ToolResult.fail("Expression cannot be empty")

        # Callers may pass precision as None or as a string: round(x, None)
        # silently returns an int and round(x, "2") raises TypeError, so
        # normalise it before evaluating anything.
        precision = kwargs.get("precision")
        if precision is None:
            precision = 4
        try:
            precision = int(precision)
        except (TypeError, ValueError):
            return ToolResult.fail(f"Invalid precision: {precision!r}")

        try:
            processed = self._preprocess(expression)
            result = self._safe_eval(processed)
            if isinstance(result, float):
                result = round(result, precision)
            return ToolResult.ok({
                "expression": expression,
                "processed": processed,
                "result": result,
                "result_type": type(result).__name__,
            })
        except ZeroDivisionError:
            return ToolResult.fail("Division by zero")
        except (ValueError, OverflowError) as e:
            # OverflowError comes from inputs such as exp(1000) or 10.0 ** 400;
            # it is a problem with the expression, not a tool fault.
            return ToolResult.fail(f"Math error: {e}")
        except Exception as e:
            logger.error(f"Calculation failed: {e}")
            return ToolResult.fail(f"Could not evaluate expression: {e}")

    def _preprocess(self, expression: str) -> str:
        """Rewrite calculator-style notation into Python expression syntax.

        * ``"15% of 200"`` -> ``"(15 / 100) * 200"`` (``of`` matched
          case-insensitively)
        * ``"15%"`` -> ``"(15 / 100)"`` only when the ``%`` is not followed by
          an operand, so ``"10 % 3"`` stays a modulo operation
        * ``"^"`` -> ``"**"`` (power)
        * ``"2(3)"`` -> ``"2*(3)"`` (implicit multiplication after a number
          literal, but not after an identifier such as ``log10``)

        Args:
            expression: Raw expression.

        Returns:
            Processed expression.
        """
        import re  # local import: ``re`` is not imported at module level

        expr = expression.strip()

        percent_of = r"(\d+(?:\.\d+)?)\s*%\s*of\s*(\d+(?:\.\d+)?)"
        expr = re.sub(percent_of, r"(\1 / 100) * \2", expr, flags=re.IGNORECASE)

        # A "%" followed by an operand (digit, name, "(" or ".") is modulo;
        # otherwise it is a percentage sign.
        expr = re.sub(r"(\d+(?:\.\d+)?)\s*%(?!\s*[\w(.])", r"(\1 / 100)", expr)

        expr = expr.replace("^", "**")

        # "\b" requires the number to start a token, so digits at the end of
        # identifiers like "log10(" are not split into "log10*(".
        expr = re.sub(r"\b(\d+(?:\.\d+)?)\s*\(", r"\1*(", expr)
        return expr

    def _safe_eval(self, expression: str) -> float | int:
        """Safely evaluate a mathematical expression.

        Args:
            expression: Pre-processed expression.

        Returns:
            Calculation result.

        Raises:
            ValueError: If the expression is invalid or uses anything outside
                the whitelist.
        """
        try:
            tree = ast.parse(expression, mode="eval")
        except SyntaxError as e:
            raise ValueError(f"Invalid expression syntax: {e}") from e
        return self._eval_node(tree.body)

    @staticmethod
    def _check_pow_size(base: float | int, exponent: float | int) -> None:
        """Reject an ``int ** int`` whose exact result would be impractically large.

        Python integers are unbounded, so ``9 ** 9 ** 9`` would otherwise run
        until memory is exhausted. Float powers need no check: they raise
        ``OverflowError`` instead of growing.

        Raises:
            ValueError: If the result would exceed about 10,000 bits
                (roughly 3,000 decimal digits).
        """
        max_bits = 10_000
        if not (isinstance(base, int) and isinstance(exponent, int)):
            return
        if exponent <= 0 or abs(base) <= 1:
            return
        # For |base| >= 2 the result has at least ``exponent`` bits, so a huge
        # exponent is rejected before it reaches the float multiplication.
        if exponent > max_bits or exponent * math.log2(abs(base)) > max_bits:
            raise ValueError("Result too large")

    def _eval_node(self, node: ast.AST) -> float | int:
        """Recursively evaluate a whitelisted AST node.

        Args:
            node: AST node to evaluate.

        Returns:
            The numeric value of ``node``.

        Raises:
            ValueError: If the node uses a construct, name, operator or
                function that is not whitelisted, or if a result would be a
                complex number or an impractically large integer.
        """
        if isinstance(node, ast.Constant):
            if isinstance(node.value, (int, float)):
                return node.value
            raise ValueError(f"Unsupported constant type: {type(node.value)}")

        if isinstance(node, ast.Name):
            # Only numeric entries (pi, e) may be used as bare names.
            value = SAFE_FUNCTIONS.get(node.id)
            if isinstance(value, (int, float)):
                return value
            raise ValueError(f"Unknown variable: {node.id}")

        if isinstance(node, ast.BinOp):
            op_type = type(node.op)
            if op_type not in SAFE_OPERATORS:
                raise ValueError(f"Unsupported operator: {op_type.__name__}")
            left = self._eval_node(node.left)
            right = self._eval_node(node.right)
            if op_type is ast.Pow:
                self._check_pow_size(left, right)
            value = SAFE_OPERATORS[op_type](left, right)
            # A negative base with a fractional exponent, e.g. (-8) ** 0.5,
            # yields a complex number instead of raising.
            if isinstance(value, complex):
                raise ValueError("Result is a complex number")
            return value

        if isinstance(node, ast.UnaryOp):
            op_type = type(node.op)
            if op_type not in SAFE_OPERATORS:
                raise ValueError(f"Unsupported unary operator: {op_type.__name__}")
            return SAFE_OPERATORS[op_type](self._eval_node(node.operand))

        if isinstance(node, ast.Call):
            if not isinstance(node.func, ast.Name):
                raise ValueError("Unsupported function call")
            func_name = node.func.id
            func = SAFE_FUNCTIONS.get(func_name)
            if not callable(func):
                raise ValueError(f"Unsupported function call: {func_name}")
            args = [self._eval_node(arg) for arg in node.args]
            # Evaluate keyword arguments (e.g. round(x, ndigits=2)) rather
            # than silently ignoring them.
            kwargs: dict[str, float | int] = {}
            for keyword in node.keywords:
                if keyword.arg is None:  # "**mapping" unpacking
                    raise ValueError("Unsupported function call")
                kwargs[keyword.arg] = self._eval_node(keyword.value)
            if func is round:
                ndigits = args[1] if len(args) > 1 else kwargs.get("ndigits")
                # int.__round__ computes 10 ** -ndigits for negative ndigits,
                # so an enormous negative value stalls like an unbounded power.
                if isinstance(ndigits, int) and ndigits < -10_000:
                    raise ValueError("Rounding precision out of range")
            return func(*args, **kwargs)

        raise ValueError(f"Unsupported expression type: {type(node).__name__}")