VictorLJZ committed on
Commit
c90e4b6
·
1 Parent(s): ec85157

final fixed version

Browse files
benchmarking/llm_providers/medrax_provider.py CHANGED
@@ -6,7 +6,7 @@ import re
6
  from pathlib import Path
7
 
8
  from .base import LLMProvider, LLMRequest, LLMResponse
9
- from langchain_core.messages import AIMessage, AIMessageChunk, ToolMessage
10
 
11
  from medrax.rag.rag import RAGConfig
12
  from main import initialize_agent
@@ -37,17 +37,19 @@ class MedRAXProvider(LLMProvider):
37
  selected_tools = [
38
  # "ImageVisualizerTool", # For displaying images in the UI
39
  # "DicomProcessorTool", # For processing DICOM medical image files
40
- "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
41
- "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
42
  # "ChestXRaySegmentationTool", # For segmenting anatomical regions in chest X-rays
43
- "ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
44
- "XRayVQATool", # For visual question answering on X-rays
45
  # "LlavaMedTool", # For multimodal medical image understanding
46
- "XRayPhraseGroundingTool", # For locating described features in X-rays
47
  # "ChestXRayGeneratorTool", # For generating synthetic chest X-rays
48
- "WebBrowserTool", # For web browsing and search capabilities
49
- "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
50
  # "PythonSandboxTool", # Add the Python sandbox tool
 
 
 
 
 
 
 
 
 
51
  ]
52
 
53
  rag_config = RAGConfig(
@@ -70,9 +72,9 @@ class MedRAXProvider(LLMProvider):
70
  agent, tools_dict = initialize_agent(
71
  prompt_file="medrax/docs/system_prompts.txt",
72
  tools_to_use=selected_tools,
73
- model_dir="/model-weights",
74
  temp_dir="temp", # Change this to the path of the temporary directory
75
- device="cuda:0",
76
  model=self.model_name, # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
77
  temperature=0.3,
78
  top_p=0.95,
@@ -133,38 +135,31 @@ class MedRAXProvider(LLMProvider):
133
  print(f"File successfully copied: {dest_path}")
134
 
135
  # Add image path message for tools
136
- messages.append({
137
- "role": "user",
138
- "content": f"image_path: {dest_path}"
139
- })
140
 
141
  # Add image content for multimodal LLM
142
  with open(image_path, "rb") as img_file:
143
  img_base64 = self._encode_image(image_path)
144
 
145
- messages.append({
146
- "role": "user",
147
- "content": [{
148
- "type": "image_url",
149
- "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"}
150
- }]
151
- })
152
 
153
  # Add text message
154
- messages.append({
155
- "role": "user",
156
- "content": [{
157
  "type": "text",
158
  "text": request.text
159
- }]
160
- })
 
 
161
 
162
  # Run the agent with proper message type handling
163
- accumulated_content = ""
164
  final_response = ""
165
- chat_history = []
166
  chunk_history = []
167
-
168
 
169
  for chunk in self.agent.workflow.stream(
170
  {"messages": messages},
@@ -182,7 +177,6 @@ class MedRAXProvider(LLMProvider):
182
  serializable_chunk = {
183
  "node_name": node_name,
184
  "node_type": type(node_output).__name__,
185
- "has_messages": "messages" in node_output if isinstance(node_output, dict) else False
186
  }
187
 
188
  # Log messages in this chunk
@@ -203,39 +197,13 @@ class MedRAXProvider(LLMProvider):
203
  continue
204
 
205
  for msg in node_output["messages"]:
206
- if isinstance(msg, AIMessageChunk) and msg.content:
207
- # Accumulate streaming LLM content
208
- accumulated_content += msg.content
209
- chat_history.append({
210
- "role": "AI message chunk",
211
- "content": msg.content
212
- })
213
-
214
- elif isinstance(msg, AIMessage):
215
- # Handle final LLM response
216
- if msg.content:
217
- # Clean up the content (remove temp paths, etc.)
218
- final_response = re.sub(r"temp/[^\s]*", "", msg.content).strip()
219
- # Reset accumulated content since we have the final response
220
- accumulated_content = ""
221
- chat_history.append({
222
- "role": "AI message",
223
- "content": msg.content
224
- })
225
- elif isinstance(msg, ToolMessage):
226
- # Handle tool outputs (store for debugging but don't use as final answer)
227
- chat_history.append({
228
- "role": "tool message",
229
- "content": msg.content
230
- })
231
 
232
  # Determine the final response
233
- # Priority: final_response > accumulated_content > fallback
234
  if final_response:
235
  response_content = final_response
236
- elif accumulated_content:
237
- # If no final AIMessage was received, use accumulated content
238
- response_content = accumulated_content.strip()
239
  else:
240
  # Fallback if no LLM response was received
241
  response_content = "No response generated"
@@ -249,7 +217,6 @@ class MedRAXProvider(LLMProvider):
249
  raw_response={
250
  "thread_id": thread_id,
251
  "image_paths": image_paths,
252
- "chat_history": chat_history,
253
  "chunk_history": chunk_history,
254
  }
255
  )
 
6
  from pathlib import Path
7
 
8
  from .base import LLMProvider, LLMRequest, LLMResponse
9
+ from langchain_core.messages import AIMessage, HumanMessage
10
 
11
  from medrax.rag.rag import RAGConfig
12
  from main import initialize_agent
 
37
  selected_tools = [
38
  # "ImageVisualizerTool", # For displaying images in the UI
39
  # "DicomProcessorTool", # For processing DICOM medical image files
 
 
40
  # "ChestXRaySegmentationTool", # For segmenting anatomical regions in chest X-rays
 
 
41
  # "LlavaMedTool", # For multimodal medical image understanding
 
42
  # "ChestXRayGeneratorTool", # For generating synthetic chest X-rays
 
 
43
  # "PythonSandboxTool", # Add the Python sandbox tool
44
+
45
+ # "ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
46
+ # "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
47
+ # "WebBrowserTool", # For web browsing and search capabilities
48
+ # "XRayVQATool", # For visual question answering on X-rays
49
+ "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
50
+
51
+ # "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
52
+ # "XRayPhraseGroundingTool", # For locating described features in X-rays
53
  ]
54
 
55
  rag_config = RAGConfig(
 
72
  agent, tools_dict = initialize_agent(
73
  prompt_file="medrax/docs/system_prompts.txt",
74
  tools_to_use=selected_tools,
75
+ model_dir="model-weights",
76
  temp_dir="temp", # Change this to the path of the temporary directory
77
+ device="cpu",
78
  model=self.model_name, # Change this to the model you want to use, e.g. gpt-4.1-2025-04-14, gemini-2.5-pro
79
  temperature=0.3,
80
  top_p=0.95,
 
135
  print(f"File successfully copied: {dest_path}")
136
 
137
  # Add image path message for tools
138
+ messages.append(HumanMessage(content=f"image_path: {dest_path}"))
 
 
 
139
 
140
  # Add image content for multimodal LLM
141
  with open(image_path, "rb") as img_file:
142
  img_base64 = self._encode_image(image_path)
143
 
144
+ messages.append(HumanMessage(content=[{
145
+ "type": "image_url",
146
+ "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"}
147
+ }]))
 
 
 
148
 
149
  # Add text message
150
+ if request.images:
151
+ # If there are images, add text as part of multimodal content
152
+ messages.append(HumanMessage(content=[{
153
  "type": "text",
154
  "text": request.text
155
+ }]))
156
+ else:
157
+ # If no images, add text as simple string
158
+ messages.append(HumanMessage(content=request.text))
159
 
160
  # Run the agent with proper message type handling
 
161
  final_response = ""
 
162
  chunk_history = []
 
163
 
164
  for chunk in self.agent.workflow.stream(
165
  {"messages": messages},
 
177
  serializable_chunk = {
178
  "node_name": node_name,
179
  "node_type": type(node_output).__name__,
 
180
  }
181
 
182
  # Log messages in this chunk
 
197
  continue
198
 
199
  for msg in node_output["messages"]:
200
+ if isinstance(msg, AIMessage) and msg.content:
201
+ # Clean up the content (remove temp paths, etc.)
202
+ final_response = re.sub(r"temp/[^\s]*", "", msg.content).strip()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
 
204
  # Determine the final response
 
205
  if final_response:
206
  response_content = final_response
 
 
 
207
  else:
208
  # Fallback if no LLM response was received
209
  response_content = "No response generated"
 
217
  raw_response={
218
  "thread_id": thread_id,
219
  "image_paths": image_paths,
 
220
  "chunk_history": chunk_history,
221
  }
222
  )