KrishSharma07 committed on
Commit
760d670
·
1 Parent(s): 758f79c

feat(rag): implement secure calculator tool for financial math (Fixes #219)

Browse files
backend/app/rag/agent.py CHANGED
@@ -10,6 +10,7 @@ from huggingface_hub import InferenceClient
10
  from app.config import get_settings
11
  from app.rag.retriever import retrieve
12
  from app.rag.prompts import SYSTEM_PROMPT, RAG_PROMPT_TEMPLATE, GREETING_PROMPT
 
13
  from app.rag.tracing import trace_function
14
 
15
  logger = logging.getLogger(__name__)
@@ -32,6 +33,34 @@ def get_llm_client() -> InferenceClient:
32
  return _llm_client
33
 
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  def is_greeting(question: str) -> bool:
36
  """Detect if the question is a casual greeting rather than a document query."""
37
  greetings = {
@@ -124,12 +153,7 @@ def generate_answer(
124
  # ── Generate answer ──────────────────────────────
125
  # STAGE 3: Send prompt to HuggingFace Inference API and get the generated answer
126
  try:
127
- response = client.chat_completion(
128
- messages=messages,
129
- model=settings.LLM_MODEL,
130
- max_tokens=settings.LLM_MAX_NEW_TOKENS,
131
- temperature=settings.LLM_TEMPERATURE,
132
- )
133
  if response.choices:
134
  answer = response.choices[0].message.content.strip()
135
  else:
@@ -234,15 +258,17 @@ def generate_answer_stream(
234
  user_content = RAG_PROMPT_TEMPLATE.format(context=context, question=question)
235
  messages = _chat_messages(SYSTEM_PROMPT, user_content)
236
 
237
- # ── Stream answer tokens ─────────────────────────
238
- # STAGE 3: Stream tokens from HuggingFace Inference API → forward each as an SSE 'token' event
239
  try:
 
240
  stream = client.chat_completion(
241
  messages=messages,
242
  model=settings.LLM_MODEL,
243
  max_tokens=settings.LLM_MAX_NEW_TOKENS,
244
  temperature=settings.LLM_TEMPERATURE,
245
  stream=True,
 
 
246
  )
247
  for chunk in stream:
248
  if chunk.choices:
 
10
  from app.config import get_settings
11
  from app.rag.retriever import retrieve
12
  from app.rag.prompts import SYSTEM_PROMPT, RAG_PROMPT_TEMPLATE, GREETING_PROMPT
13
+ from app.rag.tools import TOOL_PROMPT, TOOLS, execute_tool
14
  from app.rag.tracing import trace_function
15
 
16
  logger = logging.getLogger(__name__)
 
33
  return _llm_client
34
 
35
 
36
def _execute_tools_if_requested(client: InferenceClient, messages: list[dict[str, Any]]) -> Any:
    """Run the LLM and execute any tool calls until a final answer is produced.

    ``messages`` is extended IN PLACE with the assistant tool-call turns and
    the matching tool results, so a caller that re-sends the same list
    afterwards (e.g. for streaming) sees the resolved tool outputs.

    Args:
        client: Inference client used for the chat completions.
        messages: Chat history to send; mutated as described above.

    Returns:
        The first completion response that does not request a tool (or that
        has no choices). If the model still asks for tools after the allowed
        number of rounds, one last completion is requested without tools so
        the caller receives text rather than another tool call.
    """
    # Local import: the top of this module is not visible from here, so do not
    # rely on ``json`` already being imported there.
    import json

    # NOTE(review): no module-level ``settings`` is visible in this file's
    # diff; resolve it explicitly instead of depending on an outer binding.
    settings = get_settings()
    max_tool_rounds = 3

    def _complete(**extra: Any) -> Any:
        return client.chat_completion(
            messages=messages,
            model=settings.LLM_MODEL,
            max_tokens=settings.LLM_MAX_NEW_TOKENS,
            temperature=settings.LLM_TEMPERATURE,
            **extra,
        )

    for _ in range(max_tool_rounds):
        response = _complete(tools=TOOLS, tool_prompt=TOOL_PROMPT)

        # Callers already handle an empty ``choices`` list; do not index into it.
        if not response.choices:
            return response

        message = response.choices[0].message
        tool_calls = getattr(message, "tool_calls", None)
        if not tool_calls:
            return response

        # Echo the assistant turn that requested the tools so each tool result
        # has a preceding call to answer (OpenAI-style chat protocol).
        # NOTE(review): confirm the target provider accepts echoed tool_calls.
        messages.append(
            {
                "role": "assistant",
                "content": getattr(message, "content", None) or "",
                "tool_calls": list(tool_calls),
            }
        )

        for tool_call in tool_calls:
            tool_name = tool_call.function.name
            raw_arguments = tool_call.function.arguments
            try:
                # Depending on the backend, arguments arrive as a JSON string
                # or as an already-decoded object.
                if isinstance(raw_arguments, str):
                    tool_args = json.loads(raw_arguments)
                else:
                    tool_args = raw_arguments
                tool_result = execute_tool(tool_name, tool_args)
            except (ValueError, TypeError, AttributeError, ArithmeticError) as exc:
                # Report the failure to the model instead of aborting the
                # whole answer; it can retry or answer without the tool.
                logger.warning("Tool %r failed: %s", tool_name, exc)
                tool_result = f"Error: {exc}"

            tool_message = {"role": "tool", "name": tool_name, "content": tool_result}
            call_id = getattr(tool_call, "id", None)
            if call_id is not None:
                tool_message["tool_call_id"] = call_id
            messages.append(tool_message)

    # Tools were still requested after every round: ask once more without
    # tools so the response carries text content the caller can use.
    logger.warning("Tool calls still requested after %d rounds; forcing a final answer.", max_tool_rounds)
    return _complete()
62
+
63
+
64
  def is_greeting(question: str) -> bool:
65
  """Detect if the question is a casual greeting rather than a document query."""
66
  greetings = {
 
153
  # ── Generate answer ──────────────────────────────
154
  # STAGE 3: Send prompt to HuggingFace Inference API and get the generated answer
155
  try:
156
+ response = _execute_tools_if_requested(client, messages)
 
 
 
 
 
157
  if response.choices:
158
  answer = response.choices[0].message.content.strip()
159
  else:
 
258
  user_content = RAG_PROMPT_TEMPLATE.format(context=context, question=question)
259
  messages = _chat_messages(SYSTEM_PROMPT, user_content)
260
 
261
+ # Resolve tool calls before streaming, then stream the final answer.
 
262
  try:
263
+ _execute_tools_if_requested(client, messages)
264
  stream = client.chat_completion(
265
  messages=messages,
266
  model=settings.LLM_MODEL,
267
  max_tokens=settings.LLM_MAX_NEW_TOKENS,
268
  temperature=settings.LLM_TEMPERATURE,
269
  stream=True,
270
+ tools=TOOLS,
271
+ tool_prompt=TOOL_PROMPT,
272
  )
273
  for chunk in stream:
274
  if chunk.choices:
backend/app/rag/prompts.py CHANGED
@@ -12,6 +12,7 @@ IMPORTANT RULES:
12
  4. Be precise, clear, and well-structured in your responses.
13
  5. Use bullet points and formatting when listing multiple items.
14
  6. For numerical data or key facts, quote the relevant text directly.
 
15
 
16
  FORMATTING:
17
  - Use **bold** for key terms and important findings
 
12
  4. Be precise, clear, and well-structured in your responses.
13
  5. Use bullet points and formatting when listing multiple items.
14
  6. For numerical data or key facts, quote the relevant text directly.
15
+ 7. If a question requires arithmetic calculations, use the registered calculator tool instead of guessing or estimating.
16
 
17
  FORMATTING:
18
  - Use **bold** for key terms and important findings
backend/app/rag/tools.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Agent tools for the PDF Assistant RAG backend."""
2
+
3
+ import ast
4
+ import operator as op
5
+ from typing import Any
6
+
7
+ from huggingface_hub.inference._generated.types.chat_completion import (
8
+ ChatCompletionInputFunctionDefinition,
9
+ ChatCompletionInputTool,
10
+ )
11
+
12
# Whitelist mapping AST operator node types to their implementations.
# Anything not listed here is rejected by ``_evaluate_ast``.
_ALLOWED_OPERATORS = {
    ast.Add: op.add,
    ast.Sub: op.sub,
    ast.Mult: op.mul,
    ast.Div: op.truediv,
    ast.FloorDiv: op.floordiv,
    ast.Mod: op.mod,
    ast.Pow: op.pow,
    ast.USub: op.neg,
    ast.UAdd: op.pos,
}


def _evaluate_ast(node: ast.AST) -> float:
    """Recursively evaluate a whitelisted arithmetic AST node.

    Only numeric constants, binary operations and unary operations whose
    operator appears in ``_ALLOWED_OPERATORS`` are accepted. All values are
    converted to ``float`` before any arithmetic is performed.

    Raises:
        ValueError: for any other node type, a non-numeric constant
            (including ``True``/``False``), a disallowed operator, or an
            arithmetic failure (division by zero, overflow, non-real result).
    """
    if isinstance(node, ast.Expression):
        return _evaluate_ast(node.body)

    if isinstance(node, ast.Constant):
        value = node.value
        # bool is a subclass of int; reject it so True/False are not read as 1/0.
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            raise ValueError("Only numeric values are allowed in calculator expressions.")
        try:
            return float(value)
        except OverflowError as exc:
            # Integer literals beyond the float range cannot be converted.
            raise ValueError(f"Invalid arithmetic in calculator expression: {exc}") from exc

    if isinstance(node, ast.BinOp):
        left = _evaluate_ast(node.left)
        right = _evaluate_ast(node.right)
        operator = type(node.op)
        if operator not in _ALLOWED_OPERATORS:
            raise ValueError(f"Operator {operator.__name__} is not allowed.")
        try:
            outcome = _ALLOWED_OPERATORS[operator](left, right)
        except (ZeroDivisionError, OverflowError) as exc:
            raise ValueError(f"Invalid arithmetic in calculator expression: {exc}") from exc
        # A negative base with a fractional exponent produces a complex number.
        if isinstance(outcome, complex):
            raise ValueError("Calculator expression does not produce a real number.")
        return outcome

    if isinstance(node, ast.UnaryOp):
        operator = type(node.op)
        if operator not in _ALLOWED_OPERATORS:
            raise ValueError(f"Operator {operator.__name__} is not allowed.")
        operand = _evaluate_ast(node.operand)
        return _ALLOWED_OPERATORS[operator](operand)

    raise ValueError("Unsupported expression in calculator tool.")


def calculate_expression(expression: str) -> str:
    """Safely evaluate a simple arithmetic expression.

    This tool only permits numeric literals and arithmetic operators.
    It does not execute arbitrary code.

    Args:
        expression: Arithmetic expression text, e.g. ``"10 + 5 * 2"``.
            Surrounding whitespace is ignored.

    Returns:
        The result as a string; whole-number results are rendered without a
        decimal part (``"20"``), others as a float (``"2.5"``).

    Raises:
        ValueError: if the expression cannot be parsed, uses anything other
            than numbers and allowed operators, is nested too deeply, or does
            not evaluate to a finite real number.
    """
    import math

    try:
        # Leading whitespace is an IndentationError in eval mode, so strip it.
        parsed = ast.parse(expression.strip(), mode="eval")
        result = _evaluate_ast(parsed)
    except SyntaxError as exc:
        raise ValueError(f"Invalid calculator expression: {exc}") from exc
    except RecursionError as exc:
        raise ValueError("Calculator expression is too deeply nested.") from exc

    # inf/nan are not meaningful answers for a calculation; fail explicitly.
    if not math.isfinite(result):
        raise ValueError("Calculator result is not a finite number.")

    if result.is_integer():
        return str(int(result))

    return str(result)
72
+
73
+
74
def execute_tool(name: str, arguments: dict[str, Any]) -> str:
    """Execute a registered tool by name.

    Args:
        name: Tool identifier; only ``"calculator"`` is registered.
        arguments: Decoded tool arguments; must contain a non-empty string
            under the ``"expression"`` key.

    Returns:
        The string produced by ``calculate_expression`` for the expression.

    Raises:
        ValueError: if the tool name is unknown, the arguments are not a
            dict, or ``"expression"`` is missing, empty or not a string.
    """
    if name != "calculator":
        raise ValueError(f"Unknown tool: {name}")

    # Arguments originate from model-generated JSON, which may decode to a
    # list, string or null rather than an object.
    if not isinstance(arguments, dict):
        raise ValueError("Tool arguments must be a JSON object.")

    expression = arguments.get("expression")
    if not isinstance(expression, str) or not expression.strip():
        raise ValueError("The calculator tool requires a non-empty 'expression' string.")

    return calculate_expression(expression)
84
+
85
+
86
# Tool schema advertised to the model for the calculator.
CALCULATOR_TOOL = ChatCompletionInputTool(
    function=ChatCompletionInputFunctionDefinition(
        name="calculator",
        description=(
            "Safely evaluate a numeric arithmetic expression for financial calculations. "
            "Use only numeric values and arithmetic operators like +, -, *, /, %, //, and **."
        ),
        parameters={
            "type": "object",
            "properties": {
                "expression": {
                    "type": "string",
                    # The example must be purely numeric: the evaluator
                    # rejects names such as 'revenue'.
                    "description": (
                        "A valid arithmetic expression to evaluate, for example '1000 - 250' or "
                        "'(5000 - 3200) * 0.2'."
                    ),
                }
            },
            "required": ["expression"],
        },
    ),
    # Chat-completion tool definitions use the type "function".
    type="function",
)

# Instruction text passed alongside the tool list on each completion call.
TOOL_PROMPT = (
    "Use the calculator tool for all numeric arithmetic operations in the user query. "
    "The tool accepts a single 'expression' field and returns the evaluated numeric result. "
    "Do not attempt to compute arithmetic without the tool."
)

# All tools exposed to the model.
TOOLS = [CALCULATOR_TOOL]
backend/tests/test_rag_tools.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from app.rag.tools import CALCULATOR_TOOL, calculate_expression, execute_tool
2
+
3
+
4
def test_calculator_tool_evaluates_basic_expression():
    """Basic subtraction, operator precedence and true division give the expected text."""
    expected_by_expression = {
        "1000 - 250": "750",
        "10 + 5 * 2": "20",
        "10 / 4": "2.5",
    }
    for expression, expected in expected_by_expression.items():
        assert calculate_expression(expression) == expected
8
+
9
+
10
def test_calculator_tool_rejects_unsafe_expression():
    """A function-call expression must be rejected with ValueError, never evaluated."""
    try:
        calculate_expression("__import__('os').system('echo x')")
    except ValueError as exc:
        message = str(exc)
        assert "Invalid calculator expression" in message or "Unsupported expression" in message
    else:
        # Raise explicitly: ``assert False`` is removed when Python runs with -O.
        raise AssertionError("Unsafe expressions should not be evaluated")
17
+
18
+
19
def test_execute_tool_with_expression_argument():
    """Dispatching to the calculator by name evaluates the given expression."""
    arguments = {"expression": "12 * 3"}
    assert execute_tool("calculator", arguments) == "36"
22
+
23
+
24
def test_calculator_tool_metadata():
    """The tool schema exposes the expected name and an 'expression' parameter."""
    function_spec = CALCULATOR_TOOL["function"]
    assert function_spec["name"] == "calculator"
    assert "expression" in function_spec["parameters"]["properties"]