Adibvafa commited on
Commit
b5eed77
·
1 Parent(s): fbb4118

Ensure interface is comparible with new agent

Browse files
Files changed (3) hide show
  1. interface.py +92 -37
  2. main.py +4 -0
  3. medrax/agent/agent.py +3 -0
interface.py CHANGED
@@ -1,11 +1,14 @@
1
  import re
2
  import base64
 
 
3
  import gradio as gr
4
  from pathlib import Path
5
  import time
6
  import shutil
7
  from typing import AsyncGenerator, List, Optional, Tuple
8
  from gradio import ChatMessage
 
9
 
10
 
11
  class ChatInterface:
@@ -32,6 +35,7 @@ class ChatInterface:
32
  # Separate storage for original and display paths
33
  self.original_file_path = None # For LLM (.dcm or other)
34
  self.display_file_path = None # For UI (always viewable format)
 
35
 
36
  def handle_upload(self, file_path: str) -> str:
37
  """
@@ -132,48 +136,99 @@ class ChatInterface:
132
  messages.append({"role": "user", "content": [{"type": "text", "text": message}]})
133
 
134
  try:
135
- for event in self.agent.workflow.stream(
136
- {"messages": messages}, {"configurable": {"thread_id": self.current_thread_id}}
 
 
 
 
 
137
  ):
138
- if isinstance(event, dict):
139
- if "agent" in event:
140
- content = event["agent"]["messages"][-1].content
141
- if content:
142
- content = re.sub(r"temp/[^\s]*", "", content)
143
- chat_history.append(ChatMessage(role="assistant", content=content))
 
 
 
 
 
 
 
 
 
 
 
144
  yield chat_history, self.display_file_path, ""
145
 
146
- elif "tools" in event:
147
- for message in event["tools"]["messages"]:
148
- tool_name = message.name
149
- tool_result = eval(message.content)[0]
150
-
151
- if tool_result:
152
- metadata = {"title": f"🖼️ Image from tool: {tool_name}"}
153
- formatted_result = " ".join(
154
- line.strip() for line in str(tool_result).splitlines()
155
- ).strip()
156
- metadata["description"] = formatted_result
157
- chat_history.append(
158
- ChatMessage(
159
- role="assistant",
160
- content=formatted_result,
161
- metadata=metadata,
162
  )
163
- )
164
-
165
- # For image_visualizer, use display path
166
- if tool_name == "image_visualizer":
167
- self.display_file_path = tool_result["image_path"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  chat_history.append(
169
  ChatMessage(
170
  role="assistant",
171
- # content=gr.Image(value=self.display_file_path),
172
- content={"path": self.display_file_path},
173
  )
174
  )
175
-
176
- yield chat_history, self.display_file_path, ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
  except Exception as e:
179
  chat_history.append(
@@ -181,7 +236,7 @@ class ChatInterface:
181
  role="assistant", content=f"❌ Error: {str(e)}", metadata={"title": "Error"}
182
  )
183
  )
184
- yield chat_history, self.display_file_path
185
 
186
 
187
  def create_demo(agent, tools_dict):
@@ -207,10 +262,10 @@ def create_demo(agent, tools_dict):
207
  )
208
 
209
  with gr.Row():
210
- with gr.Column(scale=3):
211
  chatbot = gr.Chatbot(
212
  [],
213
- height=800,
214
  container=True,
215
  show_label=True,
216
  elem_classes="chat-box",
@@ -231,7 +286,7 @@ def create_demo(agent, tools_dict):
231
 
232
  with gr.Column(scale=3):
233
  image_display = gr.Image(
234
- label="Image", type="filepath", height=700, container=True
235
  )
236
  with gr.Row():
237
  upload_button = gr.UploadButton(
 
1
  import re
2
  import base64
3
+ import json
4
+ import ast
5
  import gradio as gr
6
  from pathlib import Path
7
  import time
8
  import shutil
9
  from typing import AsyncGenerator, List, Optional, Tuple
10
  from gradio import ChatMessage
11
+ from langchain_core.messages import AIMessage, AIMessageChunk, ToolMessage
12
 
13
 
14
  class ChatInterface:
 
35
  # Separate storage for original and display paths
36
  self.original_file_path = None # For LLM (.dcm or other)
37
  self.display_file_path = None # For UI (always viewable format)
38
+ self.pending_tool_calls = {}
39
 
40
  def handle_upload(self, file_path: str) -> str:
41
  """
 
136
  messages.append({"role": "user", "content": [{"type": "text", "text": message}]})
137
 
138
  try:
139
+ accumulated_content = ""
140
+ final_message = None
141
+
142
+ for chunk in self.agent.workflow.stream(
143
+ {"messages": messages},
144
+ {"configurable": {"thread_id": self.current_thread_id}},
145
+ stream_mode="updates",
146
  ):
147
+ if not isinstance(chunk, dict):
148
+ continue
149
+
150
+ for node_name, node_output in chunk.items():
151
+ if "messages" not in node_output:
152
+ continue
153
+
154
+ for msg in node_output["messages"]:
155
+ if isinstance(msg, AIMessageChunk) and msg.content:
156
+ accumulated_content += msg.content
157
+ if final_message is None:
158
+ final_message = ChatMessage(
159
+ role="assistant", content=accumulated_content
160
+ )
161
+ chat_history.append(final_message)
162
+ else:
163
+ final_message.content = accumulated_content
164
  yield chat_history, self.display_file_path, ""
165
 
166
+ elif isinstance(msg, AIMessage):
167
+ if msg.content:
168
+ final_content = re.sub(r"temp/[^\s]*", "", msg.content).strip()
169
+ if final_message:
170
+ final_message.content = final_content
171
+ else:
172
+ chat_history.append(
173
+ ChatMessage(role="assistant", content=final_content)
 
 
 
 
 
 
 
 
174
  )
175
+ yield chat_history, self.display_file_path, ""
176
+
177
+ if msg.tool_calls:
178
+ for tool_call in msg.tool_calls:
179
+ self.pending_tool_calls[tool_call["id"]] = {
180
+ "name": tool_call["name"],
181
+ "args": tool_call["args"],
182
+ }
183
+
184
+ final_message = None
185
+ accumulated_content = ""
186
+
187
+ elif isinstance(msg, ToolMessage):
188
+ tool_call_id = msg.tool_call_id
189
+ if tool_call_id in self.pending_tool_calls:
190
+ pending_call = self.pending_tool_calls.pop(tool_call_id)
191
+ tool_name = pending_call["name"]
192
+ tool_args = pending_call["args"]
193
+
194
+ try:
195
+ tool_output_json = json.loads(msg.content)
196
+ tool_output_str = json.dumps(tool_output_json, indent=2)
197
+ except (json.JSONDecodeError, TypeError):
198
+ tool_output_str = str(msg.content)
199
+
200
+ tool_args_str = json.dumps(tool_args, indent=2)
201
+
202
+ description = f"**Input:**\n```json\n{tool_args_str}\n```\n\n**Output:**\n```json\n{tool_output_str}\n```"
203
+
204
+ metadata = {
205
+ "title": f"⚒️ Tool: {tool_name}",
206
+ "description": description,
207
+ "status": "done",
208
+ }
209
  chat_history.append(
210
  ChatMessage(
211
  role="assistant",
212
+ content=description,
213
+ metadata=metadata,
214
  )
215
  )
216
+ yield chat_history, self.display_file_path, ""
217
+
218
+ if tool_name == "image_visualizer":
219
+ try:
220
+ result = json.loads(msg.content)
221
+ if isinstance(result, dict) and "image_path" in result:
222
+ self.display_file_path = result["image_path"]
223
+ chat_history.append(
224
+ ChatMessage(
225
+ role="assistant",
226
+ content={"path": self.display_file_path},
227
+ )
228
+ )
229
+ yield chat_history, self.display_file_path, ""
230
+ except (json.JSONDecodeError, TypeError):
231
+ pass
232
 
233
  except Exception as e:
234
  chat_history.append(
 
236
  role="assistant", content=f"❌ Error: {str(e)}", metadata={"title": "Error"}
237
  )
238
  )
239
+ yield chat_history, self.display_file_path, ""
240
 
241
 
242
  def create_demo(agent, tools_dict):
 
262
  )
263
 
264
  with gr.Row():
265
+ with gr.Column(scale=5):
266
  chatbot = gr.Chatbot(
267
  [],
268
+ height=1000,
269
  container=True,
270
  show_label=True,
271
  elem_classes="chat-box",
 
286
 
287
  with gr.Column(scale=3):
288
  image_display = gr.Image(
289
+ label="Image", type="filepath", height=600, container=True
290
  )
291
  with gr.Row():
292
  upload_button = gr.UploadButton(
main.py CHANGED
@@ -43,6 +43,7 @@ def initialize_agent(
43
  top_p: float = 0.95,
44
  rag_config: Optional[RAGConfig] = None,
45
  model_kwargs: Dict[str, Any] = {},
 
46
  ):
47
  """Initialize the MedRAX agent with specified tools and configuration.
48
 
@@ -57,6 +58,7 @@ def initialize_agent(
57
  top_p (float, optional): Top P for the model. Defaults to 0.95.
58
  rag_config (RAGConfig, optional): Configuration for the RAG tool. Defaults to None.
59
  model_kwargs (dict, optional): Additional keyword arguments for model.
 
60
 
61
  Returns:
62
  Tuple[Agent, Dict[str, BaseTool]]: Initialized agent and dictionary of tool instances
@@ -119,6 +121,7 @@ def initialize_agent(
119
  log_dir="logs",
120
  system_prompt=prompt,
121
  checkpointer=checkpointer,
 
122
  )
123
  print("Agent initialized")
124
 
@@ -188,6 +191,7 @@ if __name__ == "__main__":
188
  top_p=0.95,
189
  model_kwargs=model_kwargs,
190
  rag_config=rag_config,
 
191
  )
192
 
193
  # Create and launch the web interface
 
43
  top_p: float = 0.95,
44
  rag_config: Optional[RAGConfig] = None,
45
  model_kwargs: Dict[str, Any] = {},
46
+ debug: bool = False,
47
  ):
48
  """Initialize the MedRAX agent with specified tools and configuration.
49
 
 
58
  top_p (float, optional): Top P for the model. Defaults to 0.95.
59
  rag_config (RAGConfig, optional): Configuration for the RAG tool. Defaults to None.
60
  model_kwargs (dict, optional): Additional keyword arguments for model.
61
+ debug (bool, optional): Whether to enable debug mode. Defaults to False.
62
 
63
  Returns:
64
  Tuple[Agent, Dict[str, BaseTool]]: Initialized agent and dictionary of tool instances
 
121
  log_dir="logs",
122
  system_prompt=prompt,
123
  checkpointer=checkpointer,
124
+ debug=debug,
125
  )
126
  print("Agent initialized")
127
 
 
191
  top_p=0.95,
192
  model_kwargs=model_kwargs,
193
  rag_config=rag_config,
194
+ debug=True,
195
  )
196
 
197
  # Create and launch the web interface
medrax/agent/agent.py CHANGED
@@ -71,6 +71,7 @@ class Agent:
71
  system_prompt: str = "",
72
  log_tools: bool = True,
73
  log_dir: Optional[str] = "logs",
 
74
  ):
75
  """
76
  Initialize the Agent.
@@ -82,6 +83,7 @@ class Agent:
82
  system_prompt (str, optional): System instructions. Defaults to "".
83
  log_tools (bool, optional): Whether to log tool calls. Defaults to True.
84
  log_dir (str, optional): Directory to save logs. Defaults to 'logs'.
 
85
  """
86
  self.system_prompt = system_prompt
87
  self.log_tools = log_tools
@@ -96,5 +98,6 @@ class Agent:
96
  checkpointer=checkpointer,
97
  state_schema=State,
98
  prompt=system_prompt if system_prompt else None,
 
99
  )
100
  self.tools = {t.name: t for t in tools}
 
71
  system_prompt: str = "",
72
  log_tools: bool = True,
73
  log_dir: Optional[str] = "logs",
74
+ debug: bool = False,
75
  ):
76
  """
77
  Initialize the Agent.
 
83
  system_prompt (str, optional): System instructions. Defaults to "".
84
  log_tools (bool, optional): Whether to log tool calls. Defaults to True.
85
  log_dir (str, optional): Directory to save logs. Defaults to 'logs'.
86
+ debug (bool, optional): Whether to enable debug mode. Defaults to False.
87
  """
88
  self.system_prompt = system_prompt
89
  self.log_tools = log_tools
 
98
  checkpointer=checkpointer,
99
  state_schema=State,
100
  prompt=system_prompt if system_prompt else None,
101
+ debug=debug,
102
  )
103
  self.tools = {t.name: t for t in tools}