# ask-the-web-agent/src/tools/calculator.py
# debashis2007 - "Upload folder using huggingface_hub" (commit 75bea1c, verified)
"""Calculator tool implementation."""
from __future__ import annotations
import ast
import math
import operator
from dataclasses import dataclass, field
from typing import Any
from src.tools.base import Tool, ToolParameter, ToolResult
from src.utils.logging import get_logger
logger = get_logger(__name__)
# Whitelist of operator AST node types the evaluator will apply, each mapped
# to the function from ``operator`` that implements it. Any node type missing
# from this table is refused.
SAFE_OPERATORS = {
    # Binary arithmetic
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.FloorDiv: operator.floordiv,
    ast.Mod: operator.mod,
    ast.Pow: operator.pow,
    # Unary sign operators
    ast.USub: operator.neg,
    ast.UAdd: operator.pos,
}
# Names an expression is allowed to reference. Callable entries can be
# invoked as functions; numeric entries (pi, e) are read as named constants.
SAFE_FUNCTIONS = {
    # Built-ins
    "abs": abs,
    "round": round,
    "min": min,
    "max": max,
    "sum": sum,
    # Functions from the math module
    "sqrt": math.sqrt,
    "sin": math.sin,
    "cos": math.cos,
    "tan": math.tan,
    "log": math.log,
    "log10": math.log10,
    "exp": math.exp,
    "floor": math.floor,
    "ceil": math.ceil,
    # Numeric constants
    "pi": math.pi,
    "e": math.e,
}
@dataclass
class CalculatorTool(Tool):
    """Tool for performing mathematical calculations.

    Expressions are parsed with ``ast.parse`` and evaluated by walking the
    resulting tree. Only numeric constants, the operators listed in
    ``SAFE_OPERATORS`` and the names listed in ``SAFE_FUNCTIONS`` are
    accepted; ``eval`` is never called on the input.
    """

    name: str = "calculator"
    description: str = (
        "Evaluate mathematical expressions and perform calculations. "
        "Supports basic arithmetic, powers, and common math functions "
        "(sqrt, sin, cos, log, etc.)."
    )
    parameters: list[ToolParameter] = field(default_factory=lambda: [
        ToolParameter(
            name="expression",
            type="string",
            description="Mathematical expression to evaluate (e.g., '2 + 2 * 3', 'sqrt(16)', '15% of 200')",
            required=True,
        ),
        ToolParameter(
            name="precision",
            type="integer",
            description="Number of decimal places for the result",
            required=False,
            default=4,
        ),
    ])

    # Upper bound on the estimated bit size of an int ** int result. Left
    # un-annotated on purpose so the dataclass machinery does not turn it
    # into an instance field.
    _MAX_POW_BITS = 100_000

    async def execute(self, **kwargs: Any) -> ToolResult:
        """Execute a calculation.

        Args:
            expression: Mathematical expression to evaluate
            precision: Decimal places for result (defaults to 4 when
                missing or None)

        Returns:
            ToolResult with calculation result, or a failed ToolResult
            describing why the expression could not be evaluated
        """
        expression = kwargs.get("expression", "")
        precision = kwargs.get("precision")
        if precision is None:
            # An explicit None would make round() return an int and silently
            # discard the fractional part, so fall back to the default.
            precision = 4

        is_blank = isinstance(expression, str) and not expression.strip()
        if not expression or is_blank:
            return ToolResult.fail("Expression cannot be empty")

        try:
            # Pre-process expression
            processed = self._preprocess(expression)
            # Evaluate safely
            result = self._safe_eval(processed)
            # Format result
            if isinstance(result, float):
                result = round(result, precision)
            return ToolResult.ok({
                "expression": expression,
                "processed": processed,
                "result": result,
                "result_type": type(result).__name__,
            })
        except ZeroDivisionError:
            return ToolResult.fail("Division by zero")
        except ValueError as e:
            return ToolResult.fail(f"Math error: {e}")
        except Exception as e:
            # Boundary handler: anything else (TypeError, OverflowError, ...)
            # is reported to the caller instead of propagating.
            logger.error(f"Calculation failed: {e}")
            return ToolResult.fail(f"Could not evaluate expression: {e}")

    def _preprocess(self, expression: str) -> str:
        """Preprocess expression for evaluation.

        Rewrites the notations Python does not understand: "X% of Y",
        trailing percentages, "^" as the power operator and implicit
        multiplication of a number by a parenthesised group.

        Args:
            expression: Raw expression

        Returns:
            Processed expression
        """
        import re  # local: the module-level imports do not provide ``re``

        number = r"\d+(?:\.\d+)?"
        expr = expression.strip()

        # "15% of 200" -> "((15 / 100) * 200)". The outer parentheses keep
        # the product together when it sits next to another operator, e.g.
        # "1000 / 10% of 200".
        expr = re.sub(
            rf"({number})\s*%\s*of\s*({number})",
            lambda m: f"(({m.group(1)} / 100) * {m.group(2)})",
            expr,
        )

        # Standalone percentages: "15%" -> "(15 / 100)". A "%" that is
        # followed by an operand is the modulo operator ("10 % 3") and must
        # be left alone.
        expr = re.sub(rf"({number})\s*%(?!\s*[\w.(])", r"(\1 / 100)", expr)

        # Handle "^" as power operator
        expr = expr.replace("^", "**")

        def _with_explicit_mul(match):
            """Insert "*" only when the token before "(" is a number."""
            token = match.group(1)
            if token[0] in "0123456789.":
                return f"{token}*("
            # Identifier such as "sqrt" or "log10": a real function call.
            return match.group(0)

        # Implicit multiplication: "2(3)" -> "2*(3)". The whole token in
        # front of the parenthesis is inspected so that names ending in a
        # digit (log10) are not mistaken for numbers.
        return re.sub(r"(?<![\w.])([\w.]+)\s*\(", _with_explicit_mul, expr)

    def _safe_eval(self, expression: str) -> float | int:
        """Safely evaluate a mathematical expression.

        Args:
            expression: Pre-processed expression

        Returns:
            Calculation result

        Raises:
            ValueError: If expression is invalid or unsafe
        """
        try:
            tree = ast.parse(expression, mode="eval")
        except SyntaxError as e:
            raise ValueError(f"Invalid expression syntax: {e}") from e
        return self._eval_node(tree.body)

    def _eval_node(self, node: ast.AST) -> float | int:
        """Recursively evaluate an AST node.

        Args:
            node: AST node to evaluate

        Returns:
            Evaluated value

        Raises:
            ValueError: If the node is not a whitelisted construct
        """
        if isinstance(node, ast.Constant):
            if isinstance(node.value, (int, float)):
                return node.value
            raise ValueError(f"Unsupported constant type: {type(node.value)}")

        if isinstance(node, ast.Name):
            # Named constants like pi, e; callables are not values.
            value = SAFE_FUNCTIONS.get(node.id)
            if isinstance(value, (int, float)):
                return value
            raise ValueError(f"Unknown variable: {node.id}")

        if isinstance(node, ast.BinOp):
            left = self._eval_node(node.left)
            right = self._eval_node(node.right)
            op = type(node.op)
            if op not in SAFE_OPERATORS:
                raise ValueError(f"Unsupported operator: {op.__name__}")
            if op is ast.Pow:
                return self._safe_pow(left, right)
            return SAFE_OPERATORS[op](left, right)

        if isinstance(node, ast.UnaryOp):
            operand = self._eval_node(node.operand)
            op = type(node.op)
            if op in SAFE_OPERATORS:
                return SAFE_OPERATORS[op](operand)
            raise ValueError(f"Unsupported unary operator: {op.__name__}")

        if isinstance(node, ast.Call):
            return self._eval_call(node)

        raise ValueError(f"Unsupported expression type: {type(node).__name__}")

    def _safe_pow(self, base: float | int, exponent: float | int) -> float | int:
        """Apply the power operator, refusing huge or complex results."""
        if (
            isinstance(base, int)
            and isinstance(exponent, int)
            and exponent > 0
            and abs(base) > 1
            and base.bit_length() * exponent > self._MAX_POW_BITS
        ):
            # int ** int is exact in Python, so something like 9**9**9 would
            # otherwise try to build an astronomically large integer.
            raise ValueError("Exponent too large")
        result = SAFE_OPERATORS[ast.Pow](base, exponent)
        if isinstance(result, complex):
            # A negative base with a fractional exponent yields a complex.
            raise ValueError("Result is a complex number")
        return result

    def _eval_argument(self, node: ast.AST) -> float | int | list[float | int]:
        """Evaluate a call argument; list/tuple literals become number lists."""
        if isinstance(node, (ast.List, ast.Tuple)):
            # Needed for sum([...]), which accepts nothing but an iterable.
            return [self._eval_node(element) for element in node.elts]
        return self._eval_node(node)

    def _eval_call(self, node: ast.Call) -> float | int:
        """Evaluate a call to one of the whitelisted functions."""
        func = None
        if isinstance(node.func, ast.Name):
            func = SAFE_FUNCTIONS.get(node.func.id)
        if not callable(func):
            raise ValueError("Unsupported function call")

        args = [self._eval_argument(arg) for arg in node.args]
        kwargs = {}
        for keyword in node.keywords:
            if keyword.arg is None:  # "**mapping" unpacking
                raise ValueError("Unsupported function call")
            # Keywords are evaluated rather than dropped, so that
            # round(x, ndigits=2) is honoured.
            kwargs[keyword.arg] = self._eval_argument(keyword.value)

        result = func(*args, **kwargs)
        if not isinstance(result, (int, float)):
            # e.g. max([1, 2], [3]) would hand a list back to the caller.
            raise ValueError("Unsupported function call")
        return result