Commit
·
9126c2d
1
Parent(s):
4144607
handling of duplicate calls for tools
Browse files- __pycache__/agent.cpython-313.pyc +0 -0
- agent.py +10 -5
- app.py +64 -21
- test_deduplication.py +43 -0
- tools/__pycache__/tavily_search_tool.cpython-313.pyc +0 -0
- tools/tavily_search_tool.py +48 -0
__pycache__/agent.cpython-313.pyc
CHANGED
|
Binary files a/__pycache__/agent.cpython-313.pyc and b/__pycache__/agent.cpython-313.pyc differ
|
|
|
agent.py
CHANGED
|
@@ -23,7 +23,10 @@ class TeacherStudentAgentWorkflow:
|
|
| 23 |
description="Useful for searching the web for information on a given topic and recording notes on the topic.",
|
| 24 |
system_prompt=(
|
| 25 |
"You are the ResearchAgent that can search the web for information on a given topic and record notes on the topic. "
|
| 26 |
-
"
|
|
|
|
|
|
|
|
|
|
| 27 |
"You should have at least some notes on a topic before handing off control to the WriteAgent."
|
| 28 |
),
|
| 29 |
llm=self.llm,
|
|
@@ -36,8 +39,9 @@ class TeacherStudentAgentWorkflow:
|
|
| 36 |
description="Useful for writing a report on a given topic.",
|
| 37 |
system_prompt=(
|
| 38 |
"You are the WriteAgent that can write a report on a given topic. "
|
|
|
|
| 39 |
"Your report should be in a markdown format. The content should be grounded in the research notes. "
|
| 40 |
-
"Once the report is written,
|
| 41 |
),
|
| 42 |
llm=self.llm,
|
| 43 |
tools=[write_report],
|
|
@@ -48,9 +52,10 @@ class TeacherStudentAgentWorkflow:
|
|
| 48 |
name="ReviewAgent",
|
| 49 |
description="Useful for reviewing a report and providing feedback.",
|
| 50 |
system_prompt=(
|
| 51 |
-
"You are the ReviewAgent that can review the
|
| 52 |
-
"
|
| 53 |
-
"
|
|
|
|
| 54 |
),
|
| 55 |
llm=self.llm,
|
| 56 |
tools=[review_report],
|
|
|
|
| 23 |
description="Useful for searching the web for information on a given topic and recording notes on the topic.",
|
| 24 |
system_prompt=(
|
| 25 |
"You are the ResearchAgent that can search the web for information on a given topic and record notes on the topic. "
|
| 26 |
+
"IMPORTANT: Never make duplicate tool calls. Each tool call should be unique and purposeful. "
|
| 27 |
+
"Process: 1) Search for information ONCE with a clear query, 2) Record the notes ONCE with a descriptive title, "
|
| 28 |
+
"3) Only search again if you need different/additional information with a different query. "
|
| 29 |
+
"Once you have sufficient notes recorded, immediately hand off control to the WriteAgent. "
|
| 30 |
"You should have at least some notes on a topic before handing off control to the WriteAgent."
|
| 31 |
),
|
| 32 |
llm=self.llm,
|
|
|
|
| 39 |
description="Useful for writing a report on a given topic.",
|
| 40 |
system_prompt=(
|
| 41 |
"You are the WriteAgent that can write a report on a given topic. "
|
| 42 |
+
"IMPORTANT: Never make duplicate tool calls. Write the report only ONCE with all available research. "
|
| 43 |
"Your report should be in a markdown format. The content should be grounded in the research notes. "
|
| 44 |
+
"Once the report is written ONCE, immediately hand off control to the ReviewAgent for feedback."
|
| 45 |
),
|
| 46 |
llm=self.llm,
|
| 47 |
tools=[write_report],
|
|
|
|
| 52 |
name="ReviewAgent",
|
| 53 |
description="Useful for reviewing a report and providing feedback.",
|
| 54 |
system_prompt=(
|
| 55 |
+
"You are the ReviewAgent that can review the report and provide feedback. "
|
| 56 |
+
"IMPORTANT: Never make duplicate tool calls. Review the report only ONCE and provide clear feedback. "
|
| 57 |
+
"Your review should either approve the current report or request specific changes for the WriteAgent to implement. "
|
| 58 |
+
"If you have feedback that requires changes, hand off control to the WriteAgent to implement the changes after submitting the review ONCE."
|
| 59 |
),
|
| 60 |
llm=self.llm,
|
| 61 |
tools=[review_report],
|
app.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
from gradio import ChatMessage
|
| 3 |
import asyncio
|
|
|
|
|
|
|
| 4 |
from agent import TeacherStudentAgentWorkflow
|
| 5 |
from llama_index.core.agent.workflow import (
|
| 6 |
AgentInput,
|
|
@@ -43,6 +45,10 @@ async def chat_with_agent(message, history):
|
|
| 43 |
final_report = None
|
| 44 |
workflow_state = {}
|
| 45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
async for event in handler.stream_events():
|
| 47 |
# Check if we switched to a new agent
|
| 48 |
if (
|
|
@@ -51,6 +57,9 @@ async def chat_with_agent(message, history):
|
|
| 51 |
):
|
| 52 |
current_agent = event.current_agent_name
|
| 53 |
|
|
|
|
|
|
|
|
|
|
| 54 |
# Add agent header message
|
| 55 |
agent_header = ChatMessage(
|
| 56 |
role="assistant",
|
|
@@ -84,14 +93,36 @@ async def chat_with_agent(message, history):
|
|
| 84 |
yield history, final_report
|
| 85 |
|
| 86 |
elif isinstance(event, ToolCall):
|
| 87 |
-
#
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
)
|
| 93 |
-
|
| 94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
elif isinstance(event, ToolCallResult):
|
| 97 |
# Show tool results
|
|
@@ -99,21 +130,33 @@ async def chat_with_agent(message, history):
|
|
| 99 |
if len(result_content) > 500:
|
| 100 |
result_content = result_content[:500] + "..."
|
| 101 |
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
history.append(result_msg)
|
| 108 |
|
| 109 |
-
# Track tool results to detect report writing and review approval
|
| 110 |
-
if
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
|
|
|
| 117 |
|
| 118 |
yield history, final_report
|
| 119 |
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
from gradio import ChatMessage
|
| 3 |
import asyncio
|
| 4 |
+
import json
|
| 5 |
+
import hashlib
|
| 6 |
from agent import TeacherStudentAgentWorkflow
|
| 7 |
from llama_index.core.agent.workflow import (
|
| 8 |
AgentInput,
|
|
|
|
| 45 |
final_report = None
|
| 46 |
workflow_state = {}
|
| 47 |
|
| 48 |
+
# Track recent tool calls to prevent UI duplicates
|
| 49 |
+
recent_tool_calls = set()
|
| 50 |
+
max_cache_size = 100 # Limit cache size to prevent memory issues
|
| 51 |
+
|
| 52 |
async for event in handler.stream_events():
|
| 53 |
# Check if we switched to a new agent
|
| 54 |
if (
|
|
|
|
| 57 |
):
|
| 58 |
current_agent = event.current_agent_name
|
| 59 |
|
| 60 |
+
# Clear tool call tracking when switching agents
|
| 61 |
+
recent_tool_calls.clear()
|
| 62 |
+
|
| 63 |
# Add agent header message
|
| 64 |
agent_header = ChatMessage(
|
| 65 |
role="assistant",
|
|
|
|
| 93 |
yield history, final_report
|
| 94 |
|
| 95 |
elif isinstance(event, ToolCall):
|
| 96 |
+
# Create a unique identifier for this tool call using a more robust approach
|
| 97 |
+
try:
|
| 98 |
+
# Sort the arguments to ensure consistent hashing
|
| 99 |
+
sorted_kwargs = json.dumps(event.tool_kwargs, sort_keys=True, default=str)
|
| 100 |
+
tool_call_id = f"{event.tool_name}_{hashlib.md5(sorted_kwargs.encode()).hexdigest()}"
|
| 101 |
+
except (TypeError, ValueError):
|
| 102 |
+
# Fallback for non-serializable arguments
|
| 103 |
+
tool_call_id = f"{event.tool_name}_{hash(str(event.tool_kwargs))}"
|
| 104 |
+
|
| 105 |
+
# Only show if we haven't seen this exact tool call recently
|
| 106 |
+
if tool_call_id not in recent_tool_calls:
|
| 107 |
+
recent_tool_calls.add(tool_call_id)
|
| 108 |
+
|
| 109 |
+
# Clean up cache if it gets too large
|
| 110 |
+
if len(recent_tool_calls) > max_cache_size:
|
| 111 |
+
# Remove some old entries (keep the most recent half)
|
| 112 |
+
recent_tool_calls = set(list(recent_tool_calls)[-max_cache_size//2:])
|
| 113 |
+
|
| 114 |
+
# Show tool being called
|
| 115 |
+
tool_msg = ChatMessage(
|
| 116 |
+
role="assistant",
|
| 117 |
+
content=f"🔨 **Calling Tool:** {event.tool_name}\n**Arguments:** {event.tool_kwargs}",
|
| 118 |
+
metadata={"title": f"{current_agent} - Tool Call"}
|
| 119 |
+
)
|
| 120 |
+
history.append(tool_msg)
|
| 121 |
+
yield history, final_report
|
| 122 |
+
else:
|
| 123 |
+
# Debug: Log duplicate detection (remove this in production)
|
| 124 |
+
print(f"🚫 Duplicate tool call detected and skipped: {event.tool_name} with args {event.tool_kwargs}")
|
| 125 |
+
# If it's a duplicate, we simply skip displaying it
|
| 126 |
|
| 127 |
elif isinstance(event, ToolCallResult):
|
| 128 |
# Show tool results
|
|
|
|
| 130 |
if len(result_content) > 500:
|
| 131 |
result_content = result_content[:500] + "..."
|
| 132 |
|
| 133 |
+
# Check if this is a duplicate detection message
|
| 134 |
+
is_duplicate = any(word in result_content.lower() for word in ["duplicate", "skipping"])
|
| 135 |
+
|
| 136 |
+
if is_duplicate:
|
| 137 |
+
result_msg = ChatMessage(
|
| 138 |
+
role="assistant",
|
| 139 |
+
content=f"⚠️ **Duplicate Detection ({event.tool_name}):**\n{result_content}",
|
| 140 |
+
metadata={"title": f"{current_agent} - Duplicate Skipped"}
|
| 141 |
+
)
|
| 142 |
+
else:
|
| 143 |
+
result_msg = ChatMessage(
|
| 144 |
+
role="assistant",
|
| 145 |
+
content=f"🔧 **Tool Result ({event.tool_name}):**\n{result_content}",
|
| 146 |
+
metadata={"title": f"{current_agent} - Tool Result"}
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
history.append(result_msg)
|
| 150 |
|
| 151 |
+
# Track tool results to detect report writing and review approval (only for non-duplicates)
|
| 152 |
+
if not is_duplicate:
|
| 153 |
+
if event.tool_name == "write_report":
|
| 154 |
+
workflow_state["has_report"] = True
|
| 155 |
+
elif event.tool_name == "review_report" and current_agent == "ReviewAgent":
|
| 156 |
+
workflow_state["has_review"] = True
|
| 157 |
+
# Check if review indicates approval
|
| 158 |
+
if any(word in result_content.lower() for word in ["approved", "ready", "good", "excellent"]):
|
| 159 |
+
workflow_state["review_approved"] = True
|
| 160 |
|
| 161 |
yield history, final_report
|
| 162 |
|
test_deduplication.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Test script to verify that the deduplication mechanism works correctly.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import asyncio
|
| 7 |
+
import sys
|
| 8 |
+
import os
|
| 9 |
+
sys.path.append(os.path.dirname(__file__))
|
| 10 |
+
|
| 11 |
+
from tools.tavily_search_tool import search_web, _should_execute_call
|
| 12 |
+
|
| 13 |
+
async def test_deduplication():
|
| 14 |
+
"""Test that duplicate tool calls are properly detected and prevented."""
|
| 15 |
+
print("Testing deduplication mechanism...")
|
| 16 |
+
|
| 17 |
+
# Test 1: Same query should be deduplicated
|
| 18 |
+
print("\n1. Testing search_web deduplication:")
|
| 19 |
+
query = "test query for deduplication"
|
| 20 |
+
|
| 21 |
+
print(f"First call with query: '{query}'")
|
| 22 |
+
result1 = await search_web(query)
|
| 23 |
+
print(f"Result: {result1[:100]}...")
|
| 24 |
+
|
| 25 |
+
print(f"Second call with same query: '{query}'")
|
| 26 |
+
result2 = await search_web(query)
|
| 27 |
+
print(f"Result: {result2}")
|
| 28 |
+
|
| 29 |
+
# Test 2: Direct deduplication function
|
| 30 |
+
print("\n2. Testing _should_execute_call function:")
|
| 31 |
+
should_execute_1 = _should_execute_call("test_tool", arg1="value1", arg2="value2")
|
| 32 |
+
print(f"First call should execute: {should_execute_1}")
|
| 33 |
+
|
| 34 |
+
should_execute_2 = _should_execute_call("test_tool", arg1="value1", arg2="value2")
|
| 35 |
+
print(f"Second call should execute: {should_execute_2}")
|
| 36 |
+
|
| 37 |
+
should_execute_3 = _should_execute_call("test_tool", arg1="value1", arg2="different_value")
|
| 38 |
+
print(f"Third call with different args should execute: {should_execute_3}")
|
| 39 |
+
|
| 40 |
+
print("\n✅ Deduplication test completed!")
|
| 41 |
+
|
| 42 |
+
if __name__ == "__main__":
|
| 43 |
+
asyncio.run(test_deduplication())
|
tools/__pycache__/tavily_search_tool.cpython-313.pyc
CHANGED
|
Binary files a/tools/__pycache__/tavily_search_tool.cpython-313.pyc and b/tools/__pycache__/tavily_search_tool.cpython-313.pyc differ
|
|
|
tools/tavily_search_tool.py
CHANGED
|
@@ -2,17 +2,57 @@ from tavily import AsyncTavilyClient
|
|
| 2 |
from llama_index.core.workflow import Context
|
| 3 |
from dotenv import load_dotenv
|
| 4 |
import os
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
load_dotenv(os.path.join(os.path.dirname(__file__), '../env.local'))
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
async def search_web(query: str) -> str:
|
| 9 |
"""Useful for using the web to answer questions."""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
client = AsyncTavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
|
| 11 |
return str(await client.search(query))
|
| 12 |
|
| 13 |
|
| 14 |
async def record_notes(ctx: Context, notes: str, notes_title: str) -> str:
|
| 15 |
"""Useful for recording notes on a given topic. Your input should be notes with a title to save the notes under."""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
current_state = await ctx.get("state")
|
| 17 |
if "research_notes" not in current_state:
|
| 18 |
current_state["research_notes"] = {}
|
|
@@ -23,6 +63,10 @@ async def record_notes(ctx: Context, notes: str, notes_title: str) -> str:
|
|
| 23 |
|
| 24 |
async def write_report(ctx: Context, report_content: str) -> str:
|
| 25 |
"""Useful for writing a report on a given topic. Your input should be a markdown formatted report."""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
current_state = await ctx.get("state")
|
| 27 |
current_state["report_content"] = report_content
|
| 28 |
await ctx.set("state", current_state)
|
|
@@ -31,6 +75,10 @@ async def write_report(ctx: Context, report_content: str) -> str:
|
|
| 31 |
|
| 32 |
async def review_report(ctx: Context, review: str) -> str:
|
| 33 |
"""Useful for reviewing a report and providing feedback. Your input should be a review of the report."""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
current_state = await ctx.get("state")
|
| 35 |
current_state["review"] = review
|
| 36 |
await ctx.set("state", current_state)
|
|
|
|
| 2 |
from llama_index.core.workflow import Context
|
| 3 |
from dotenv import load_dotenv
|
| 4 |
import os
|
| 5 |
+
import time
|
| 6 |
+
import hashlib
|
| 7 |
+
import json
|
| 8 |
|
| 9 |
load_dotenv(os.path.join(os.path.dirname(__file__), '../env.local'))
|
| 10 |
|
| 11 |
+
# Global cache to track recent tool calls and prevent duplicates
|
| 12 |
+
_tool_call_cache = {}
|
| 13 |
+
_cache_timeout = 30 # 30 seconds timeout for deduplication
|
| 14 |
+
|
| 15 |
+
def _generate_call_hash(tool_name: str, **kwargs) -> str:
|
| 16 |
+
"""Generate a hash for tool call deduplication."""
|
| 17 |
+
# Create a stable hash from tool name and arguments
|
| 18 |
+
call_data = {"tool": tool_name, "args": kwargs}
|
| 19 |
+
call_str = json.dumps(call_data, sort_keys=True)
|
| 20 |
+
return hashlib.md5(call_str.encode()).hexdigest()
|
| 21 |
+
|
| 22 |
+
def _should_execute_call(tool_name: str, **kwargs) -> bool:
|
| 23 |
+
"""Check if a tool call should be executed or if it's a duplicate."""
|
| 24 |
+
current_time = time.time()
|
| 25 |
+
call_hash = _generate_call_hash(tool_name, **kwargs)
|
| 26 |
+
|
| 27 |
+
# Clean up old cache entries
|
| 28 |
+
expired_keys = [k for k, v in _tool_call_cache.items() if current_time - v > _cache_timeout]
|
| 29 |
+
for key in expired_keys:
|
| 30 |
+
del _tool_call_cache[key]
|
| 31 |
+
|
| 32 |
+
# Check if this call was made recently
|
| 33 |
+
if call_hash in _tool_call_cache:
|
| 34 |
+
return False
|
| 35 |
+
|
| 36 |
+
# Record this call
|
| 37 |
+
_tool_call_cache[call_hash] = current_time
|
| 38 |
+
return True
|
| 39 |
+
|
| 40 |
async def search_web(query: str) -> str:
|
| 41 |
"""Useful for using the web to answer questions."""
|
| 42 |
+
# Check for duplicate calls
|
| 43 |
+
if not _should_execute_call("search_web", query=query):
|
| 44 |
+
return f"Duplicate search call detected for query: '{query}'. Skipping to avoid redundant API calls."
|
| 45 |
+
|
| 46 |
client = AsyncTavilyClient(api_key=os.getenv("TAVILY_API_KEY"))
|
| 47 |
return str(await client.search(query))
|
| 48 |
|
| 49 |
|
| 50 |
async def record_notes(ctx: Context, notes: str, notes_title: str) -> str:
|
| 51 |
"""Useful for recording notes on a given topic. Your input should be notes with a title to save the notes under."""
|
| 52 |
+
# Check for duplicate calls
|
| 53 |
+
if not _should_execute_call("record_notes", notes=notes, notes_title=notes_title):
|
| 54 |
+
return f"Duplicate notes recording detected for title: '{notes_title}'. Skipping to avoid redundant recording."
|
| 55 |
+
|
| 56 |
current_state = await ctx.get("state")
|
| 57 |
if "research_notes" not in current_state:
|
| 58 |
current_state["research_notes"] = {}
|
|
|
|
| 63 |
|
| 64 |
async def write_report(ctx: Context, report_content: str) -> str:
|
| 65 |
"""Useful for writing a report on a given topic. Your input should be a markdown formatted report."""
|
| 66 |
+
# Check for duplicate calls
|
| 67 |
+
if not _should_execute_call("write_report", report_content=report_content):
|
| 68 |
+
return "Duplicate report writing detected. Skipping to avoid redundant report generation."
|
| 69 |
+
|
| 70 |
current_state = await ctx.get("state")
|
| 71 |
current_state["report_content"] = report_content
|
| 72 |
await ctx.set("state", current_state)
|
|
|
|
| 75 |
|
| 76 |
async def review_report(ctx: Context, review: str) -> str:
|
| 77 |
"""Useful for reviewing a report and providing feedback. Your input should be a review of the report."""
|
| 78 |
+
# Check for duplicate calls
|
| 79 |
+
if not _should_execute_call("review_report", review=review):
|
| 80 |
+
return "Duplicate review detected. Skipping to avoid redundant review submission."
|
| 81 |
+
|
| 82 |
current_state = await ctx.get("state")
|
| 83 |
current_state["review"] = review
|
| 84 |
await ctx.set("state", current_state)
|