Paramjit Singh commited on
Commit
d6e540c
·
unverified ·
2 Parent(s): 4defd96760d670

Merge pull request #256 from krishsharma-code/feat/rag-calculator-tool

Browse files
backend/app/rag/agent.py CHANGED
@@ -11,6 +11,7 @@ from app.config import get_settings
11
  from app.rag.retriever import retrieve
12
  from app.rag.graph_retriever import get_entity_context
13
  from app.rag.prompts import SYSTEM_PROMPT, RAG_PROMPT_TEMPLATE, GREETING_PROMPT
 
14
  from app.rag.tracing import trace_function
15
 
16
  logger = logging.getLogger(__name__)
@@ -23,6 +24,34 @@ def get_llm_client(hf_token: Optional[str] = None) -> InferenceClient:
23
  )
24
 
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  def is_greeting(question: str) -> bool:
27
  """Detect if the question is a casual greeting rather than a document query."""
28
  greetings = {
@@ -141,12 +170,7 @@ def generate_answer(
141
  # ── Generate answer ──────────────────────────────
142
  # STAGE 3: Send prompt to HuggingFace Inference API and get the generated answer
143
  try:
144
- response = client.chat_completion(
145
- messages=messages,
146
- model=settings.LLM_MODEL,
147
- max_tokens=settings.LLM_MAX_NEW_TOKENS,
148
- temperature=settings.LLM_TEMPERATURE,
149
- )
150
  if response.choices:
151
  answer = response.choices[0].message.content.strip()
152
  else:
@@ -257,15 +281,17 @@ def generate_answer_stream(
257
  user_content = RAG_PROMPT_TEMPLATE.format(context=context, question=question)
258
  messages = _chat_messages(SYSTEM_PROMPT, user_content)
259
 
260
- # ── Stream answer tokens ─────────────────────────
261
- # STAGE 3: Stream tokens from HuggingFace Inference API → forward each as an SSE 'token' event
262
  try:
 
263
  stream = client.chat_completion(
264
  messages=messages,
265
  model=settings.LLM_MODEL,
266
  max_tokens=settings.LLM_MAX_NEW_TOKENS,
267
  temperature=settings.LLM_TEMPERATURE,
268
  stream=True,
 
 
269
  )
270
  for chunk in stream:
271
  if chunk.choices:
 
11
  from app.rag.retriever import retrieve
12
  from app.rag.graph_retriever import get_entity_context
13
  from app.rag.prompts import SYSTEM_PROMPT, RAG_PROMPT_TEMPLATE, GREETING_PROMPT
14
+ from app.rag.tools import TOOL_PROMPT, TOOLS, execute_tool
15
  from app.rag.tracing import trace_function
16
 
17
  logger = logging.getLogger(__name__)
 
24
  )
25
 
26
 
27
+ def _execute_tools_if_requested(client: InferenceClient, messages: list[dict[str, Any]]) -> Any:
28
+ """Run the LLM and execute any tool call responses until the final answer is produced."""
29
+ for _ in range(3):
30
+ response = client.chat_completion(
31
+ messages=messages,
32
+ model=settings.LLM_MODEL,
33
+ max_tokens=settings.LLM_MAX_NEW_TOKENS,
34
+ temperature=settings.LLM_TEMPERATURE,
35
+ tools=TOOLS,
36
+ tool_prompt=TOOL_PROMPT,
37
+ )
38
+
39
+ choice = response.choices[0]
40
+ tool_calls = getattr(choice.message, "tool_calls", None)
41
+ if not tool_calls:
42
+ return response
43
+
44
+ tool_call = tool_calls[0]
45
+ tool_name = tool_call.function.name
46
+ tool_args = json.loads(tool_call.function.arguments)
47
+ tool_result = execute_tool(tool_name, tool_args)
48
+
49
+ messages.append({"role": "tool", "name": tool_name, "content": tool_result})
50
+
51
+ # If tools are still requested after several rounds, return the latest response anyway.
52
+ return response
53
+
54
+
55
  def is_greeting(question: str) -> bool:
56
  """Detect if the question is a casual greeting rather than a document query."""
57
  greetings = {
 
170
  # ── Generate answer ──────────────────────────────
171
  # STAGE 3: Send prompt to HuggingFace Inference API and get the generated answer
172
  try:
173
+ response = _execute_tools_if_requested(client, messages)
 
 
 
 
 
174
  if response.choices:
175
  answer = response.choices[0].message.content.strip()
176
  else:
 
281
  user_content = RAG_PROMPT_TEMPLATE.format(context=context, question=question)
282
  messages = _chat_messages(SYSTEM_PROMPT, user_content)
283
 
284
+ # Resolve tool calls before streaming, then stream the final answer.
 
285
  try:
286
+ _execute_tools_if_requested(client, messages)
287
  stream = client.chat_completion(
288
  messages=messages,
289
  model=settings.LLM_MODEL,
290
  max_tokens=settings.LLM_MAX_NEW_TOKENS,
291
  temperature=settings.LLM_TEMPERATURE,
292
  stream=True,
293
+ tools=TOOLS,
294
+ tool_prompt=TOOL_PROMPT,
295
  )
296
  for chunk in stream:
297
  if chunk.choices:
backend/app/rag/prompts.py CHANGED
@@ -12,6 +12,7 @@ IMPORTANT RULES:
12
  4. Be precise, clear, and well-structured in your responses.
13
  5. Use bullet points and formatting when listing multiple items.
14
  6. For numerical data or key facts, quote the relevant text directly.
 
15
 
16
  FORMATTING:
17
  - Use **bold** for key terms and important findings
 
12
  4. Be precise, clear, and well-structured in your responses.
13
  5. Use bullet points and formatting when listing multiple items.
14
  6. For numerical data or key facts, quote the relevant text directly.
15
+ 7. If a question requires arithmetic calculations, use the registered calculator tool instead of guessing or estimating.
16
 
17
  FORMATTING:
18
  - Use **bold** for key terms and important findings
backend/app/rag/tools.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Agent tools for the PDF Assistant RAG backend."""
2
+
3
+ import ast
4
+ import operator as op
5
+ from typing import Any
6
+
7
+ from huggingface_hub.inference._generated.types.chat_completion import (
8
+ ChatCompletionInputFunctionDefinition,
9
+ ChatCompletionInputTool,
10
+ )
11
+
12
+ _ALLOWED_OPERATORS = {
13
+ ast.Add: op.add,
14
+ ast.Sub: op.sub,
15
+ ast.Mult: op.mul,
16
+ ast.Div: op.truediv,
17
+ ast.FloorDiv: op.floordiv,
18
+ ast.Mod: op.mod,
19
+ ast.Pow: op.pow,
20
+ ast.USub: op.neg,
21
+ ast.UAdd: op.pos,
22
+ }
23
+
24
+
25
+ def _evaluate_ast(node: ast.AST) -> float:
26
+ if isinstance(node, ast.Expression):
27
+ return _evaluate_ast(node.body)
28
+
29
+ if isinstance(node, ast.Constant):
30
+ if isinstance(node.value, (int, float)):
31
+ return float(node.value)
32
+ raise ValueError("Only numeric values are allowed in calculator expressions.")
33
+
34
+ if isinstance(node, ast.BinOp):
35
+ left = _evaluate_ast(node.left)
36
+ right = _evaluate_ast(node.right)
37
+ operator = type(node.op)
38
+ if operator not in _ALLOWED_OPERATORS:
39
+ raise ValueError(f"Operator {operator.__name__} is not allowed.")
40
+ return _ALLOWED_OPERATORS[operator](left, right)
41
+
42
+ if isinstance(node, ast.UnaryOp):
43
+ operator = type(node.op)
44
+ if operator not in _ALLOWED_OPERATORS:
45
+ raise ValueError(f"Operator {operator.__name__} is not allowed.")
46
+ operand = _evaluate_ast(node.operand)
47
+ return _ALLOWED_OPERATORS[operator](operand)
48
+
49
+ raise ValueError("Unsupported expression in calculator tool.")
50
+
51
+
52
+ def calculate_expression(expression: str) -> str:
53
+ """Safely evaluate a simple arithmetic expression.
54
+
55
+ This tool only permits numeric literals and arithmetic operators.
56
+ It does not execute arbitrary code.
57
+ """
58
+ try:
59
+ parsed = ast.parse(expression, mode="eval")
60
+ except SyntaxError as exc:
61
+ raise ValueError(f"Invalid calculator expression: {exc}") from exc
62
+
63
+ if not isinstance(parsed, ast.Expression):
64
+ raise ValueError("Expression must be a single arithmetic expression.")
65
+
66
+ result = _evaluate_ast(parsed)
67
+
68
+ if result.is_integer():
69
+ return str(int(result))
70
+
71
+ return str(result)
72
+
73
+
74
+ def execute_tool(name: str, arguments: dict[str, Any]) -> str:
75
+ """Execute a registered tool by name."""
76
+ if name != "calculator":
77
+ raise ValueError(f"Unknown tool: {name}")
78
+
79
+ expression = arguments.get("expression")
80
+ if not isinstance(expression, str) or not expression.strip():
81
+ raise ValueError("The calculator tool requires a non-empty 'expression' string.")
82
+
83
+ return calculate_expression(expression)
84
+
85
+
86
+ CALCULATOR_TOOL = ChatCompletionInputTool(
87
+ function=ChatCompletionInputFunctionDefinition(
88
+ name="calculator",
89
+ description=(
90
+ "Safely evaluate a numeric arithmetic expression for financial calculations. "
91
+ "Use only numeric values and arithmetic operators like +, -, *, /, %, //, and **."
92
+ ),
93
+ parameters={
94
+ "type": "object",
95
+ "properties": {
96
+ "expression": {
97
+ "type": "string",
98
+ "description": (
99
+ "A valid arithmetic expression to evaluate, for example '1000 - 250' or "
100
+ "'(revenue - expenses) * 0.2'."
101
+ ),
102
+ }
103
+ },
104
+ "required": ["expression"],
105
+ },
106
+ ),
107
+ type="tool",
108
+ )
109
+
110
+ TOOL_PROMPT = (
111
+ "Use the calculator tool for all numeric arithmetic operations in the user query. "
112
+ "The tool accepts a single 'expression' field and returns the evaluated numeric result. "
113
+ "Do not attempt to compute arithmetic without the tool."
114
+ )
115
+
116
+ TOOLS = [CALCULATOR_TOOL]
backend/tests/test_rag_tools.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from app.rag.tools import CALCULATOR_TOOL, calculate_expression, execute_tool
2
+
3
+
4
+ def test_calculator_tool_evaluates_basic_expression():
5
+ assert calculate_expression("1000 - 250") == "750"
6
+ assert calculate_expression("10 + 5 * 2") == "20"
7
+ assert calculate_expression("10 / 4") == "2.5"
8
+
9
+
10
+ def test_calculator_tool_rejects_unsafe_expression():
11
+ try:
12
+ calculate_expression("__import__('os').system('echo x')")
13
+ except ValueError as exc:
14
+ assert "Invalid calculator expression" in str(exc) or "Unsupported expression" in str(exc)
15
+ else:
16
+ assert False, "Unsafe expressions should not be evaluated"
17
+
18
+
19
+ def test_execute_tool_with_expression_argument():
20
+ result = execute_tool("calculator", {"expression": "12 * 3"})
21
+ assert result == "36"
22
+
23
+
24
+ def test_calculator_tool_metadata():
25
+ assert CALCULATOR_TOOL["function"]["name"] == "calculator"
26
+ assert "expression" in CALCULATOR_TOOL["function"]["parameters"]["properties"]