Adibvafa committed on
Commit
fbb4118
·
1 Parent(s): a85df01

Ensure all tools follow dict, dict return

Browse files
medrax/tools/llava_med.py CHANGED
@@ -117,7 +117,7 @@ class LlavaMedTool(BaseTool):
117
  question: str,
118
  image_path: Optional[str] = None,
119
  run_manager: Optional[CallbackManagerForToolRun] = None,
120
- ) -> Tuple[str, Dict]:
121
  """Answer a medical question, optionally based on an input image.
122
 
123
  Args:
@@ -126,7 +126,7 @@ class LlavaMedTool(BaseTool):
126
  run_manager (Optional[CallbackManagerForToolRun]): The callback manager for the tool run.
127
 
128
  Returns:
129
- Tuple[str, Dict]: A tuple containing the model's answer and any additional metadata.
130
 
131
  Raises:
132
  Exception: If there's an error processing the input or generating the answer.
@@ -146,7 +146,12 @@ class LlavaMedTool(BaseTool):
146
  use_cache=True,
147
  )
148
 
149
- output = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
 
 
 
 
 
150
  metadata = {
151
  "question": question,
152
  "image_path": image_path,
@@ -154,18 +159,20 @@ class LlavaMedTool(BaseTool):
154
  }
155
  return output, metadata
156
  except Exception as e:
157
- return f"Error generating answer: {str(e)}", {
 
158
  "question": question,
159
  "image_path": image_path,
160
  "analysis_status": "failed",
161
  }
 
162
 
163
  async def _arun(
164
  self,
165
  question: str,
166
  image_path: Optional[str] = None,
167
  run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
168
- ) -> Tuple[str, Dict]:
169
  """Asynchronously answer a medical question, optionally based on an input image.
170
 
171
  This method currently calls the synchronous version, as the model inference
@@ -178,7 +185,7 @@ class LlavaMedTool(BaseTool):
178
  run_manager (Optional[AsyncCallbackManagerForToolRun]): The async callback manager for the tool run.
179
 
180
  Returns:
181
- Tuple[str, Dict]: A tuple containing the model's answer and any additional metadata.
182
 
183
  Raises:
184
  Exception: If there's an error processing the input or generating the answer.
 
117
  question: str,
118
  image_path: Optional[str] = None,
119
  run_manager: Optional[CallbackManagerForToolRun] = None,
120
+ ) -> Tuple[Dict[str, Any], Dict]:
121
  """Answer a medical question, optionally based on an input image.
122
 
123
  Args:
 
126
  run_manager (Optional[CallbackManagerForToolRun]): The callback manager for the tool run.
127
 
128
  Returns:
129
+ Tuple[Dict[str, Any], Dict]: A tuple containing the output dictionary and metadata dictionary.
130
 
131
  Raises:
132
  Exception: If there's an error processing the input or generating the answer.
 
146
  use_cache=True,
147
  )
148
 
149
+ answer = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
150
+
151
+ output = {
152
+ "answer": answer,
153
+ }
154
+
155
  metadata = {
156
  "question": question,
157
  "image_path": image_path,
 
159
  }
160
  return output, metadata
161
  except Exception as e:
162
+ output = {"error": f"Error generating answer: {str(e)}"}
163
+ metadata = {
164
  "question": question,
165
  "image_path": image_path,
166
  "analysis_status": "failed",
167
  }
168
+ return output, metadata
169
 
170
  async def _arun(
171
  self,
172
  question: str,
173
  image_path: Optional[str] = None,
174
  run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
175
+ ) -> Tuple[Dict[str, Any], Dict]:
176
  """Asynchronously answer a medical question, optionally based on an input image.
177
 
178
  This method currently calls the synchronous version, as the model inference
 
185
  run_manager (Optional[AsyncCallbackManagerForToolRun]): The async callback manager for the tool run.
186
 
187
  Returns:
188
+ Tuple[Dict[str, Any], Dict]: A tuple containing the output dictionary and metadata dictionary.
189
 
190
  Raises:
191
  Exception: If there's an error processing the input or generating the answer.
medrax/tools/report_generation.py CHANGED
@@ -158,7 +158,7 @@ class ChestXRayReportGeneratorTool(BaseTool):
158
  self,
159
  image_path: str,
160
  run_manager: Optional[CallbackManagerForToolRun] = None,
161
- ) -> Tuple[str, Dict]:
162
  """Generate a comprehensive chest X-ray report containing both findings and impression.
163
 
164
  Args:
@@ -166,7 +166,7 @@ class ChestXRayReportGeneratorTool(BaseTool):
166
  run_manager (Optional[CallbackManagerForToolRun]): The callback manager.
167
 
168
  Returns:
169
- Tuple[str, Dict]: A tuple containing the complete report and metadata.
170
  """
171
  try:
172
  # Process image for both models
@@ -193,25 +193,33 @@ class ChestXRayReportGeneratorTool(BaseTool):
193
  f"IMPRESSION:\n{impression_text}"
194
  )
195
 
 
 
 
 
 
 
196
  metadata = {
197
  "image_path": image_path,
198
  "analysis_status": "completed",
199
  "sections_generated": ["findings", "impression"],
200
  }
201
 
202
- return report, metadata
203
 
204
  except Exception as e:
205
- return f"Error generating report: {str(e)}", {
 
206
  "image_path": image_path,
207
  "analysis_status": "failed",
208
  "error": str(e),
209
  }
 
210
 
211
  async def _arun(
212
  self,
213
  image_path: str,
214
  run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
215
- ) -> Tuple[str, Dict]:
216
  """Asynchronously generate a comprehensive chest X-ray report."""
217
  return self._run(image_path)
 
158
  self,
159
  image_path: str,
160
  run_manager: Optional[CallbackManagerForToolRun] = None,
161
+ ) -> Tuple[Dict[str, Any], Dict]:
162
  """Generate a comprehensive chest X-ray report containing both findings and impression.
163
 
164
  Args:
 
166
  run_manager (Optional[CallbackManagerForToolRun]): The callback manager.
167
 
168
  Returns:
169
+ Tuple[Dict[str, Any], Dict]: A tuple containing the output dictionary and metadata dictionary.
170
  """
171
  try:
172
  # Process image for both models
 
193
  f"IMPRESSION:\n{impression_text}"
194
  )
195
 
196
+ output = {
197
+ "report": report,
198
+ "findings": findings_text,
199
+ "impression": impression_text,
200
+ }
201
+
202
  metadata = {
203
  "image_path": image_path,
204
  "analysis_status": "completed",
205
  "sections_generated": ["findings", "impression"],
206
  }
207
 
208
+ return output, metadata
209
 
210
  except Exception as e:
211
+ output = {"error": f"Error generating report: {str(e)}"}
212
+ metadata = {
213
  "image_path": image_path,
214
  "analysis_status": "failed",
215
  "error": str(e),
216
  }
217
+ return output, metadata
218
 
219
  async def _arun(
220
  self,
221
  image_path: str,
222
  run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
223
+ ) -> Tuple[Dict[str, Any], Dict]:
224
  """Asynchronously generate a comprehensive chest X-ray report."""
225
  return self._run(image_path)