Charles Grandjean commited on
Commit
ace7446
Β·
1 Parent(s): 88ceb8e

fix tool calling chat_agent et doc_assistant

Browse files
Files changed (2) hide show
  1. agents/chat_agent.py +22 -21
  2. agents/doc_assistant.py +14 -12
agents/chat_agent.py CHANGED
@@ -105,27 +105,28 @@ class CyberLegalAgent:
105
  args["conversation_history"] = state.get("conversation_history", [])
106
  logger.info(f"πŸ“ Injecting conversation_history to {tool_call['name']}: {len(args['conversation_history'])} messages")
107
 
108
- # Inject jurisdiction for query_knowledge_graph tool
109
- if tool_call['name'] == "query_knowledge_graph":
110
- args["jurisdiction"] = state.get("jurisdiction")
111
- logger.info(f"🌍 Injecting jurisdiction: {args['jurisdiction']}")
112
-
113
- # Inject user_id for message_lawyer tool
114
- if tool_call['name'] == "message_lawyer":
115
- args["user_id"] = state.get("user_id")
116
- logger.info(f"πŸ‘€ Injecting user_id: {args['user_id']}")
117
-
118
- # Inject user_id for retrieve_lawyer_document tool
119
- if tool_call['name'] == "retrieve_lawyer_document":
120
- args["user_id"] = state.get("user_id")
121
- logger.info(f"πŸ“„ Injecting user_id for retrieve_lawyer_document: {args['user_id']}")
122
-
123
- # Inject user_id for create_draft_document tool
124
- if tool_call['name'] == "create_draft_document":
125
- args["user_id"] = state.get("user_id")
126
- logger.info(f"πŸ“ Injecting user_id for create_draft_document: {args['user_id']}")
127
-
128
- tool_call['name']="_" + tool_call['name']
 
129
 
130
  result = await tool_func.ainvoke(args)
131
  logger.info(f"πŸ”§ Tool {tool_call} returned: {result}")
 
105
  args["conversation_history"] = state.get("conversation_history", [])
106
  logger.info(f"πŸ“ Injecting conversation_history to {tool_call['name']}: {len(args['conversation_history'])} messages")
107
 
108
+ # Inject jurisdiction for query_knowledge_graph tool
109
+ if tool_call['name'] == "query_knowledge_graph":
110
+ args["jurisdiction"] = state.get("jurisdiction")
111
+ logger.info(f"🌍 Injecting jurisdiction: {args['jurisdiction']}")
112
+
113
+ # Inject user_id for message_lawyer tool
114
+ if tool_call['name'] == "message_lawyer":
115
+ args["user_id"] = state.get("user_id")
116
+ logger.info(f"πŸ‘€ Injecting user_id: {args['user_id']}")
117
+
118
+ # Inject user_id for retrieve_lawyer_document tool
119
+ if tool_call['name'] == "retrieve_lawyer_document":
120
+ args["user_id"] = state.get("user_id")
121
+ logger.info(f"πŸ“„ Injecting user_id for retrieve_lawyer_document: {args['user_id']}")
122
+
123
+ # Inject user_id for create_draft_document tool
124
+ if tool_call['name'] == "create_draft_document":
125
+ args["user_id"] = state.get("user_id")
126
+ logger.info(f"πŸ“ Injecting user_id for create_draft_document: {args['user_id']}")
127
+
128
+ # Convert to implementation name
129
+ tool_call['name'] = "_" + tool_call['name']
130
 
131
  result = await tool_func.ainvoke(args)
132
  logger.info(f"πŸ”§ Tool {tool_call} returned: {result}")
agents/doc_assistant.py CHANGED
@@ -132,7 +132,6 @@ class DocAssistant:
132
 
133
  for tool_call in last_message.tool_calls:
134
  tool_name = tool_call['name']
135
- args = tool_call['args'].copy()
136
 
137
  # Get the tool function directly from self.tools (add underscore prefix)
138
  tool_func = next((t for t in self.tools if t.name == "_" + tool_name), None)
@@ -140,6 +139,8 @@ class DocAssistant:
140
  if tool_func:
141
  try:
142
  args = tool_call['args'].copy()
 
 
143
  if tool_name == "edit_document":
144
  logger.info("πŸ“ edit_document tool called - invoking doc_editor_agent")
145
 
@@ -150,26 +151,27 @@ class DocAssistant:
150
  args["max_iterations"] = 10
151
  args["document_id"] = state.get("document_id")
152
  args["user_id"] = state.get("user_id")
153
- result = await tool_func.ainvoke(args)
154
- doc_text=result['doc_text']
155
- message=f"Document was edited with this summary :{result['final_summary']}"
156
- state['modified_document']=doc_text
157
- state['message']=message
158
-
159
- logger.info(f"βœ… edit_document called - ending router workflow")
160
 
161
  elif tool_name == "retrieve_lawyer_document":
162
  logger.info(f"πŸ“„ retrieve_lawyer_document tool called: {args.get('file_path')}")
163
 
164
  if "user_id" not in args and state.get("user_id"):
165
  args["user_id"] = state["user_id"]
166
-
167
-
168
- logger.info(f"Lauching tool:{tool_name} with args {json.dumps(args)}")
169
  result = await tool_func.ainvoke(args)
 
 
 
 
 
 
 
 
 
 
170
  intermediate_steps.append(
171
  ToolMessage(
172
- content=message,
173
  tool_call_id=tool_call['id'],
174
  name=tool_name
175
  )
 
132
 
133
  for tool_call in last_message.tool_calls:
134
  tool_name = tool_call['name']
 
135
 
136
  # Get the tool function directly from self.tools (add underscore prefix)
137
  tool_func = next((t for t in self.tools if t.name == "_" + tool_name), None)
 
139
  if tool_func:
140
  try:
141
  args = tool_call['args'].copy()
142
+ logger.info(f"Launching tool: {tool_name} with args {json.dumps(args, default=str)}")
143
+
144
  if tool_name == "edit_document":
145
  logger.info("πŸ“ edit_document tool called - invoking doc_editor_agent")
146
 
 
151
  args["max_iterations"] = 10
152
  args["document_id"] = state.get("document_id")
153
  args["user_id"] = state.get("user_id")
 
 
 
 
 
 
 
154
 
155
  elif tool_name == "retrieve_lawyer_document":
156
  logger.info(f"πŸ“„ retrieve_lawyer_document tool called: {args.get('file_path')}")
157
 
158
  if "user_id" not in args and state.get("user_id"):
159
  args["user_id"] = state["user_id"]
160
+
 
 
161
  result = await tool_func.ainvoke(args)
162
+
163
+ if tool_name == "edit_document":
164
+ doc_text = result['doc_text']
165
+ tool_result = f"Document was edited with this summary: {result['final_summary']}"
166
+ state['modified_document'] = doc_text
167
+ state['message'] = tool_result
168
+ logger.info(f"βœ… edit_document completed - ending router workflow")
169
+ else:
170
+ tool_result = result
171
+
172
  intermediate_steps.append(
173
  ToolMessage(
174
+ content=tool_result,
175
  tool_call_id=tool_call['id'],
176
  name=tool_name
177
  )