VictorLJZ committed on
Commit
d6cb1b4
·
1 Parent(s): c90e4b6
benchmarking/llm_providers/base.py CHANGED
@@ -25,7 +25,7 @@ class LLMResponse:
25
  content: str
26
  usage: Optional[Dict[str, Any]] = None
27
  duration: Optional[float] = None
28
- raw_response: Optional[Any] = None
29
 
30
 
31
  class LLMProvider(ABC):
 
25
  content: str
26
  usage: Optional[Dict[str, Any]] = None
27
  duration: Optional[float] = None
28
+ chunk_history: Optional[Any] = None
29
 
30
 
31
  class LLMProvider(ABC):
benchmarking/llm_providers/google_provider.py CHANGED
@@ -92,13 +92,11 @@ class GoogleProvider(LLMProvider):
92
  return LLMResponse(
93
  content=content,
94
  usage=usage,
95
- duration=duration,
96
- raw_response=response
97
  )
98
 
99
  except Exception as e:
100
  return LLMResponse(
101
  content=f"Error: {str(e)}",
102
- duration=time.time() - start_time,
103
- raw_response=None
104
  )
 
92
  return LLMResponse(
93
  content=content,
94
  usage=usage,
95
+ duration=duration
 
96
  )
97
 
98
  except Exception as e:
99
  return LLMResponse(
100
  content=f"Error: {str(e)}",
101
+ duration=time.time() - start_time
 
102
  )
benchmarking/llm_providers/medrax_provider.py CHANGED
@@ -35,21 +35,13 @@ class MedRAXProvider(LLMProvider):
35
  print("Starting server...")
36
 
37
  selected_tools = [
38
- # "ImageVisualizerTool", # For displaying images in the UI
39
- # "DicomProcessorTool", # For processing DICOM medical image files
40
- # "ChestXRaySegmentationTool", # For segmenting anatomical regions in chest X-rays
41
- # "LlavaMedTool", # For multimodal medical image understanding
42
- # "ChestXRayGeneratorTool", # For generating synthetic chest X-rays
43
- # "PythonSandboxTool", # Add the Python sandbox tool
44
-
45
- # "ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
46
- # "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
47
- # "WebBrowserTool", # For web browsing and search capabilities
48
- # "XRayVQATool", # For visual question answering on X-rays
49
  "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
50
-
51
- # "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
52
- # "XRayPhraseGroundingTool", # For locating described features in X-rays
53
  ]
54
 
55
  rag_config = RAGConfig(
@@ -106,8 +98,7 @@ class MedRAXProvider(LLMProvider):
106
  if self.agent is None:
107
  return LLMResponse(
108
  content="Error: MedRAX agent not initialized",
109
- duration=time.time() - start_time,
110
- raw_response=None
111
  )
112
 
113
  try:
@@ -115,27 +106,12 @@ class MedRAXProvider(LLMProvider):
115
  messages = []
116
  thread_id = str(int(time.time() * 1000)) # Unique thread ID
117
 
118
- # Copy images to session temp directory and provide paths
119
- image_paths = []
120
  if request.images:
121
  valid_images = self._validate_image_paths(request.images)
122
  print(f"Processing {len(valid_images)} images")
123
  for i, image_path in enumerate(valid_images):
124
- print(f"Original image path: {image_path}")
125
- # Copy image to session temp directory
126
- dest_path = Path("temp") / f"image_{i}_{Path(image_path).name}"
127
- print(f"Destination path: {dest_path}")
128
- shutil.copy2(image_path, dest_path)
129
- image_paths.append(str(dest_path))
130
-
131
- # Verify file exists after copy
132
- if not dest_path.exists():
133
- print(f"ERROR: File not found after copy: {dest_path}")
134
- else:
135
- print(f"File successfully copied: {dest_path}")
136
-
137
  # Add image path message for tools
138
- messages.append(HumanMessage(content=f"image_path: {dest_path}"))
139
 
140
  # Add image content for multimodal LLM
141
  with open(image_path, "rb") as img_file:
@@ -214,16 +190,11 @@ class MedRAXProvider(LLMProvider):
214
  content=response_content,
215
  usage={"agent_tools": list(self.tools_dict.keys())},
216
  duration=duration,
217
- raw_response={
218
- "thread_id": thread_id,
219
- "image_paths": image_paths,
220
- "chunk_history": chunk_history,
221
- }
222
  )
223
 
224
  except Exception as e:
225
  return LLMResponse(
226
  content=f"Error: {str(e)}",
227
- duration=time.time() - start_time,
228
- raw_response=None
229
  )
 
35
  print("Starting server...")
36
 
37
  selected_tools = [
38
+ "ChestXRayReportGeneratorTool", # For generating medical reports from X-rays
39
+ "MedicalRAGTool", # For retrieval-augmented generation with medical knowledge
40
+ "WebBrowserTool", # For web browsing and search capabilities
 
 
 
 
 
 
 
 
41
  "TorchXRayVisionClassifierTool", # For classifying chest X-ray images using TorchXRayVision
42
+ "XRayVQATool", # For visual question answering on X-rays
43
+ "ArcPlusClassifierTool", # For advanced chest X-ray classification using ArcPlus
44
+ "XRayPhraseGroundingTool", # For locating described features in X-rays
45
  ]
46
 
47
  rag_config = RAGConfig(
 
98
  if self.agent is None:
99
  return LLMResponse(
100
  content="Error: MedRAX agent not initialized",
101
+ duration=time.time() - start_time
 
102
  )
103
 
104
  try:
 
106
  messages = []
107
  thread_id = str(int(time.time() * 1000)) # Unique thread ID
108
 
 
 
109
  if request.images:
110
  valid_images = self._validate_image_paths(request.images)
111
  print(f"Processing {len(valid_images)} images")
112
  for i, image_path in enumerate(valid_images):
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  # Add image path message for tools
114
+ messages.append(HumanMessage(content=f"image_path: {image_path}"))
115
 
116
  # Add image content for multimodal LLM
117
  with open(image_path, "rb") as img_file:
 
190
  content=response_content,
191
  usage={"agent_tools": list(self.tools_dict.keys())},
192
  duration=duration,
193
+ chunk_history=chunk_history
 
 
 
 
194
  )
195
 
196
  except Exception as e:
197
  return LLMResponse(
198
  content=f"Error: {str(e)}",
199
+ duration=time.time() - start_time
 
200
  )
benchmarking/llm_providers/openai_provider.py CHANGED
@@ -101,13 +101,11 @@ class OpenAIProvider(LLMProvider):
101
  return LLMResponse(
102
  content=content,
103
  usage=usage,
104
- duration=duration,
105
- raw_response=response
106
  )
107
 
108
  except Exception as e:
109
  return LLMResponse(
110
  content=f"Error: {str(e)}",
111
- duration=time.time() - start_time,
112
- raw_response=None
113
  )
 
101
  return LLMResponse(
102
  content=content,
103
  usage=usage,
104
+ duration=duration
 
105
  )
106
 
107
  except Exception as e:
108
  return LLMResponse(
109
  content=f"Error: {str(e)}",
110
+ duration=time.time() - start_time
 
111
  )
benchmarking/llm_providers/openrouter_provider.py CHANGED
@@ -78,12 +78,10 @@ class OpenRouterProvider(LLMProvider):
78
  return LLMResponse(
79
  content=content,
80
  usage=usage,
81
- duration=duration,
82
- raw_response=response
83
  )
84
  except Exception as e:
85
  return LLMResponse(
86
  content=f"Error: {str(e)}",
87
- duration=time.time() - start_time,
88
- raw_response=None
89
  )
 
78
  return LLMResponse(
79
  content=content,
80
  usage=usage,
81
+ duration=duration
 
82
  )
83
  except Exception as e:
84
  return LLMResponse(
85
  content=f"Error: {str(e)}",
86
+ duration=time.time() - start_time
 
87
  )
benchmarking/runner.py CHANGED
@@ -24,7 +24,7 @@ class BenchmarkResult:
24
  duration: float
25
  usage: Optional[Dict[str, Any]] = None
26
  error: Optional[str] = None
27
- raw_response: Optional[Dict[str, Any]] = None
28
  metadata: Optional[Dict[str, Any]] = None
29
 
30
 
@@ -226,7 +226,7 @@ class BenchmarkRunner:
226
  is_correct=is_correct,
227
  duration=duration,
228
  usage=response.usage,
229
- raw_response=response.raw_response,
230
  metadata={
231
  "data_point_metadata": data_point.metadata,
232
  "case_id": data_point.case_id,
@@ -245,7 +245,7 @@ class BenchmarkRunner:
245
  is_correct=False,
246
  duration=duration,
247
  error=str(e),
248
- raw_response=None,
249
  metadata={
250
  "data_point_metadata": data_point.metadata,
251
  "case_id": data_point.case_id,
@@ -318,6 +318,8 @@ class BenchmarkRunner:
318
 
319
  # Convert result to serializable format
320
  result_data = {
 
 
321
  "data_point_id": result.data_point_id,
322
  "question": result.question,
323
  "model_answer": result.model_answer,
@@ -326,10 +328,8 @@ class BenchmarkRunner:
326
  "duration": result.duration,
327
  "usage": result.usage,
328
  "error": result.error,
329
- "raw_response": result.raw_response,
330
- "metadata": result.metadata,
331
- "timestamp": datetime.now().isoformat(),
332
- "run_id": self.run_id,
333
  }
334
 
335
  # Save to file
 
24
  duration: float
25
  usage: Optional[Dict[str, Any]] = None
26
  error: Optional[str] = None
27
+ chunk_history: Optional[Dict[str, Any]] = None
28
  metadata: Optional[Dict[str, Any]] = None
29
 
30
 
 
226
  is_correct=is_correct,
227
  duration=duration,
228
  usage=response.usage,
229
+ chunk_history=response.chunk_history,
230
  metadata={
231
  "data_point_metadata": data_point.metadata,
232
  "case_id": data_point.case_id,
 
245
  is_correct=False,
246
  duration=duration,
247
  error=str(e),
248
+ chunk_history=None,
249
  metadata={
250
  "data_point_metadata": data_point.metadata,
251
  "case_id": data_point.case_id,
 
318
 
319
  # Convert result to serializable format
320
  result_data = {
321
+ "timestamp": datetime.now().isoformat(),
322
+ "run_id": self.run_id,
323
  "data_point_id": result.data_point_id,
324
  "question": result.question,
325
  "model_answer": result.model_answer,
 
328
  "duration": result.duration,
329
  "usage": result.usage,
330
  "error": result.error,
331
+ "chunk_history": result.chunk_history,
332
+ "metadata": result.metadata
 
 
333
  }
334
 
335
  # Save to file