Jonas commited on
Commit
e382e80
·
1 Parent(s): a7f1030

Enhance app.py and openfda_client.py to support configurable limits for adverse events, serious outcomes, and report sources; update data retrieval functions to include total report counts for improved output context.

Browse files
Files changed (2) hide show
  1. app.py +56 -14
  2. openfda_client.py +59 -13
app.py CHANGED
@@ -34,12 +34,13 @@ def format_pair_frequency_results(data: dict, drug_name: str, event_name: str) -
34
 
35
  # --- Tool Functions ---
36
 
37
- def top_adverse_events_tool(drug_name: str, patient_sex: str = "all", min_age: int = 0, max_age: int = 120):
38
  """
39
  MCP Tool: Finds the top reported adverse events for a given drug.
40
 
41
  Args:
42
  drug_name (str): The generic name of the drug is preferred! A small sample of brand names (e.g., 'Tylenol') are converted to generic names for demonstration purposes.
 
43
  patient_sex (str): The patient's sex to filter by.
44
  min_age (int): The minimum age for the filter.
45
  max_age (int): The maximum age for the filter.
@@ -64,7 +65,7 @@ def top_adverse_events_tool(drug_name: str, patient_sex: str = "all", min_age: i
64
  if min_age > 0 or max_age < 120:
65
  age_range = (min_age, max_age)
66
 
67
- data = get_top_adverse_events(drug_name, patient_sex=sex_code, age_range=age_range)
68
 
69
  if "error" in data:
70
  error_message = f"An error occurred: {data['error']}"
@@ -79,24 +80,32 @@ def top_adverse_events_tool(drug_name: str, patient_sex: str = "all", min_age: i
79
  df = pd.DataFrame(data["results"])
80
  df = df.rename(columns={"term": "Adverse Event", "count": "Report Count"})
81
 
 
 
 
 
 
 
82
  header = (
83
- f"### Top Adverse Events for '{drug_name.title()}'\n"
 
84
  "**Source**: FDA FAERS via OpenFDA\n"
85
  "**Disclaimer**: Spontaneous reports do not prove causation. Consult a healthcare professional."
86
  )
87
  return chart, df, header
88
 
89
- def serious_outcomes_tool(drug_name: str):
90
  """
91
  MCP Tool: Finds the top reported serious outcomes for a given drug.
92
 
93
  Args:
94
  drug_name (str): The generic name of the drug is preferred. A small sample of brand names (e.g., 'Tylenol') are converted to generic names for demonstration purposes.
 
95
 
96
  Returns:
97
  tuple: A Plotly figure, a Pandas DataFrame, and a summary string.
98
  """
99
- data = get_serious_outcomes(drug_name)
100
 
101
  if "error" in data:
102
  error_message = f"An error occurred: {data['error']}"
@@ -111,8 +120,16 @@ def serious_outcomes_tool(drug_name: str):
111
  df = pd.DataFrame(data["results"])
112
  df = df.rename(columns={"term": "Serious Outcome", "count": "Report Count"})
113
 
 
 
 
 
 
 
114
  header = (
115
- f"### Top Serious Outcomes for '{drug_name.title()}'\n"
 
 
116
  "**Source**: FDA FAERS via OpenFDA\n"
117
  "**Disclaimer**: Spontaneous reports do not prove causation. Consult a healthcare professional."
118
  )
@@ -153,27 +170,43 @@ def time_series_tool(drug_name: str, event_name: str, aggregation: str):
153
  chart = create_time_series_chart(data, drug_name, event_name, time_aggregation=agg_code)
154
  return chart
155
 
156
- def report_source_tool(drug_name: str):
157
  """
158
  MCP Tool: Creates a pie chart of report sources for a given drug.
159
 
160
  Args:
161
  drug_name (str): The generic name of the drug is preferred. A small sample of brand names (e.g., 'Tylenol') are converted to generic names for demonstration purposes.
 
162
 
163
  Returns:
164
- A Plotly figure and a string for the Markdown output.
165
  """
166
- data = get_report_source_data(drug_name)
167
 
168
  if "error" in data:
169
- return None, f"An error occurred: {data['error']}"
 
170
 
171
  if not data or not data.get("results"):
172
  message = f"No report source data found for '{drug_name}'."
173
- return create_placeholder_chart(message), message
174
 
175
  chart = create_pie_chart(data, drug_name)
176
- return chart, ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
  # --- Gradio Interface ---
179
 
@@ -194,6 +227,12 @@ interface1 = gr.Interface(
194
  label="Drug Name",
195
  info="Enter a brand or generic drug name (e.g., 'Aspirin', 'Lisinopril')."
196
  ),
 
 
 
 
 
 
197
  gr.Radio(
198
  ["All", "Male", "Female"],
199
  label="Patient Sex",
@@ -230,7 +269,8 @@ interface3 = gr.Interface(
230
  gr.Textbox(
231
  label="Drug Name",
232
  info="Enter a brand or generic drug name (e.g., 'Aspirin', 'Lisinopril')."
233
- )
 
234
  ],
235
  outputs=[
236
  gr.Plot(label="Top Serious Outcomes Chart"),
@@ -274,10 +314,12 @@ interface4 = gr.Interface(
274
  interface5 = gr.Interface(
275
  fn=report_source_tool,
276
  inputs=[
277
- gr.Textbox(label="Drug Name", info="e.g., 'Aspirin', 'Lisinopril'")
 
278
  ],
279
  outputs=[
280
  gr.Plot(label="Report Source Breakdown"),
 
281
  gr.Markdown()
282
  ],
283
  title="Report Source Breakdown",
 
34
 
35
  # --- Tool Functions ---
36
 
37
+ def top_adverse_events_tool(drug_name: str, top_n: int = 10, patient_sex: str = "all", min_age: int = 0, max_age: int = 120):
38
  """
39
  MCP Tool: Finds the top reported adverse events for a given drug.
40
 
41
  Args:
42
  drug_name (str): The generic name of the drug is preferred! A small sample of brand names (e.g., 'Tylenol') are converted to generic names for demonstration purposes.
43
+ top_n (int): The number of top adverse events to return.
44
  patient_sex (str): The patient's sex to filter by.
45
  min_age (int): The minimum age for the filter.
46
  max_age (int): The maximum age for the filter.
 
65
  if min_age > 0 or max_age < 120:
66
  age_range = (min_age, max_age)
67
 
68
+ data = get_top_adverse_events(drug_name, limit=top_n, patient_sex=sex_code, age_range=age_range)
69
 
70
  if "error" in data:
71
  error_message = f"An error occurred: {data['error']}"
 
80
  df = pd.DataFrame(data["results"])
81
  df = df.rename(columns={"term": "Adverse Event", "count": "Report Count"})
82
 
83
+ total_reports = data.get("meta", {}).get("total_reports_for_query", 0)
84
+ if total_reports > 0:
85
+ df['Relative Frequency (%)'] = ((df['Report Count'] / total_reports) * 100).round(2)
86
+ else:
87
+ df['Relative Frequency (%)'] = 0.0
88
+
89
  header = (
90
+ f"### Top {len(df)} Adverse Events for '{drug_name.title()}'\n"
91
+ f"Based on **{total_reports:,}** total reports matching the given filters.\n"
92
  "**Source**: FDA FAERS via OpenFDA\n"
93
  "**Disclaimer**: Spontaneous reports do not prove causation. Consult a healthcare professional."
94
  )
95
  return chart, df, header
96
 
97
+ def serious_outcomes_tool(drug_name: str, top_n: int = 6):
98
  """
99
  MCP Tool: Finds the top reported serious outcomes for a given drug.
100
 
101
  Args:
102
  drug_name (str): The generic name of the drug is preferred. A small sample of brand names (e.g., 'Tylenol') are converted to generic names for demonstration purposes.
103
+ top_n (int): The number of top serious outcomes to return.
104
 
105
  Returns:
106
  tuple: A Plotly figure, a Pandas DataFrame, and a summary string.
107
  """
108
+ data = get_serious_outcomes(drug_name, limit=top_n)
109
 
110
  if "error" in data:
111
  error_message = f"An error occurred: {data['error']}"
 
120
  df = pd.DataFrame(data["results"])
121
  df = df.rename(columns={"term": "Serious Outcome", "count": "Report Count"})
122
 
123
+ total_serious_reports = data.get("meta", {}).get("total_reports_for_query", 0)
124
+ if total_serious_reports > 0:
125
+ df['% of Serious Reports'] = ((df['Report Count'] / total_serious_reports) * 100).round(2)
126
+ else:
127
+ df['% of Serious Reports'] = 0.0
128
+
129
  header = (
130
+ f"### Top {len(df)} Serious Outcomes for '{drug_name.title()}'\n"
131
+ f"Out of **{total_serious_reports:,}** total serious reports. "
132
+ "Note: a single report may be associated with multiple outcomes.\n"
133
  "**Source**: FDA FAERS via OpenFDA\n"
134
  "**Disclaimer**: Spontaneous reports do not prove causation. Consult a healthcare professional."
135
  )
 
170
  chart = create_time_series_chart(data, drug_name, event_name, time_aggregation=agg_code)
171
  return chart
172
 
173
+ def report_source_tool(drug_name: str, top_n: int = 5):
174
  """
175
  MCP Tool: Creates a pie chart of report sources for a given drug.
176
 
177
  Args:
178
  drug_name (str): The generic name of the drug is preferred. A small sample of brand names (e.g., 'Tylenol') are converted to generic names for demonstration purposes.
179
+ top_n (int): The number of top sources to return.
180
 
181
  Returns:
182
+ tuple: A Plotly figure, a Pandas DataFrame, and a summary string.
183
  """
184
+ data = get_report_source_data(drug_name, limit=top_n)
185
 
186
  if "error" in data:
187
+ error_message = f"An error occurred: {data['error']}"
188
+ return create_placeholder_chart(error_message), pd.DataFrame(), error_message
189
 
190
  if not data or not data.get("results"):
191
  message = f"No report source data found for '{drug_name}'."
192
+ return create_placeholder_chart(message), pd.DataFrame(), message
193
 
194
  chart = create_pie_chart(data, drug_name)
195
+
196
+ df = pd.DataFrame(data['results'])
197
+ df = df.rename(columns={"term": "Source", "count": "Report Count"})
198
+
199
+ total_reports = data.get("meta", {}).get("total_reports_for_query", 0)
200
+ if total_reports > 0:
201
+ df['Percentage'] = ((df['Report Count'] / total_reports) * 100).round(2)
202
+ else:
203
+ df['Percentage'] = 0.0
204
+
205
+ header = (
206
+ f"### Report Sources for '{drug_name.title()}'\n"
207
+ f"Based on **{total_reports:,}** reports with source information."
208
+ )
209
+ return chart, df, header
210
 
211
  # --- Gradio Interface ---
212
 
 
227
  label="Drug Name",
228
  info="Enter a brand or generic drug name (e.g., 'Aspirin', 'Lisinopril')."
229
  ),
230
+ gr.Slider(
231
+ 5, 50,
232
+ value=10,
233
+ label="Number of Events to Show",
234
+ step=1
235
+ ),
236
  gr.Radio(
237
  ["All", "Male", "Female"],
238
  label="Patient Sex",
 
269
  gr.Textbox(
270
  label="Drug Name",
271
  info="Enter a brand or generic drug name (e.g., 'Aspirin', 'Lisinopril')."
272
+ ),
273
+ gr.Slider(1, 6, value=6, label="Number of Outcomes to Show", step=1),
274
  ],
275
  outputs=[
276
  gr.Plot(label="Top Serious Outcomes Chart"),
 
314
  interface5 = gr.Interface(
315
  fn=report_source_tool,
316
  inputs=[
317
+ gr.Textbox(label="Drug Name", info="e.g., 'Aspirin', 'Lisinopril'"),
318
+ gr.Slider(1, 5, value=5, label="Number of Sources to Show", step=1),
319
  ],
320
  outputs=[
321
  gr.Plot(label="Report Source Breakdown"),
322
+ gr.DataFrame(label="Report Source Data", interactive=False),
323
  gr.Markdown()
324
  ],
325
  title="Report Source Breakdown",
openfda_client.py CHANGED
@@ -168,8 +168,9 @@ def get_top_adverse_events(drug_name: str, limit: int = 10, patient_sex: Optiona
168
  if cache_key in cache:
169
  return cache[cache_key]
170
 
171
- query = (
172
- f'search={search_query}'
 
173
  f'&count=patient.reaction.reactionmeddrapt.exact&limit={limit}'
174
  )
175
 
@@ -177,10 +178,22 @@ def get_top_adverse_events(drug_name: str, limit: int = 10, patient_sex: Optiona
177
  # Respect rate limits
178
  time.sleep(REQUEST_DELAY_SECONDS)
179
 
180
- response = requests.get(f"{API_BASE_URL}?{query}")
181
  response.raise_for_status() # Raise an exception for bad status codes (4xx or 5xx)
182
-
183
  data = response.json()
 
 
 
 
 
 
 
 
 
 
 
 
 
184
 
185
  cache[cache_key] = data
186
  return data
@@ -241,14 +254,14 @@ def get_drug_event_pair_frequency(drug_name: str, event_name: str) -> dict:
241
  except Exception as e:
242
  return {"error": f"An unexpected error occurred: {e}"}
243
 
244
- def get_serious_outcomes(drug_name: str, limit: int = 10) -> dict:
245
  """
246
  Query OpenFDA to get the most frequent serious outcomes for a given drug.
247
  This function makes multiple API calls to count different outcome fields.
248
 
249
  Args:
250
  drug_name (str): The name of the drug to search for.
251
- limit (int): This argument is maintained for signature consistency but is not directly used in the multi-query logic.
252
 
253
  Returns:
254
  dict: A dictionary containing aggregated results or an error.
@@ -260,7 +273,7 @@ def get_serious_outcomes(drug_name: str, limit: int = 10) -> dict:
260
  drug_name_processed = DRUG_SYNONYM_MAPPING.get(drug_name_processed, drug_name_processed)
261
 
262
  # Use a cache key for the aggregated result
263
- cache_key = f"serious_outcomes_aggregated_{drug_name_processed}"
264
  if cache_key in cache:
265
  return cache[cache_key]
266
 
@@ -269,6 +282,22 @@ def get_serious_outcomes(drug_name: str, limit: int = 10) -> dict:
269
  # Base search for all serious reports
270
  base_search_query = f'patient.drug.medicinalproduct:"{drug_name_processed}"+AND+serious:1'
271
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
  for field in SERIOUS_OUTCOME_FIELDS:
273
  try:
274
  # Each query counts reports where the specific seriousness field exists
@@ -296,11 +325,14 @@ def get_serious_outcomes(drug_name: str, limit: int = 10) -> dict:
296
 
297
  # Format the results to match the expected structure for plotting
298
  final_data = {
299
- "results": [{"term": k, "count": v} for k, v in aggregated_results.items()]
 
300
  }
301
 
302
- # Sort results by count, descending
303
  final_data["results"] = sorted(final_data["results"], key=lambda x: x['count'], reverse=True)
 
 
304
 
305
  cache[cache_key] = final_data
306
  return final_data
@@ -352,12 +384,13 @@ def get_time_series_data(drug_name: str, event_name: str) -> dict:
352
  except Exception as e:
353
  return {"error": f"An unexpected error occurred: {e}"}
354
 
355
- def get_report_source_data(drug_name: str) -> dict:
356
  """
357
  Query OpenFDA to get the breakdown of report sources for a given drug.
358
 
359
  Args:
360
  drug_name (str): The name of the drug to search for.
 
361
 
362
  Returns:
363
  dict: The JSON response from the API, or an error dictionary.
@@ -368,7 +401,7 @@ def get_report_source_data(drug_name: str) -> dict:
368
  drug_name_processed = drug_name.lower().strip()
369
  drug_name_processed = DRUG_SYNONYM_MAPPING.get(drug_name_processed, drug_name_processed)
370
 
371
- cache_key = f"report_source_{drug_name_processed}"
372
  if cache_key in cache:
373
  return cache[cache_key]
374
 
@@ -385,13 +418,26 @@ def get_report_source_data(drug_name: str) -> dict:
385
 
386
  data = response.json()
387
 
388
- # Translate the qualification codes to human-readable terms
389
  if "results" in data:
 
 
 
 
 
 
 
 
 
 
390
  for item in data["results"]:
391
- # The API returns numeric codes, ensure they are strings for mapping
392
  term_str = str(item["term"])
393
  item["term"] = QUALIFICATION_MAPPING.get(term_str, f"Unknown ({term_str})")
394
 
 
 
 
 
395
  cache[cache_key] = data
396
  return data
397
 
 
168
  if cache_key in cache:
169
  return cache[cache_key]
170
 
171
+ # Query for top events by count
172
+ count_query_url = (
173
+ f'{API_BASE_URL}?search={search_query}'
174
  f'&count=patient.reaction.reactionmeddrapt.exact&limit={limit}'
175
  )
176
 
 
178
  # Respect rate limits
179
  time.sleep(REQUEST_DELAY_SECONDS)
180
 
181
+ response = requests.get(count_query_url)
182
  response.raise_for_status() # Raise an exception for bad status codes (4xx or 5xx)
 
183
  data = response.json()
184
+
185
+ # Query for total reports matching the filters
186
+ total_query_url = f'{API_BASE_URL}?search={search_query}'
187
+ time.sleep(REQUEST_DELAY_SECONDS)
188
+ total_response = requests.get(total_query_url)
189
+ total_response.raise_for_status()
190
+ total_data = total_response.json()
191
+ total_reports = total_data.get("meta", {},).get("results", {}).get("total", 0)
192
+
193
+ # Add total to the main data object
194
+ if 'meta' not in data:
195
+ data['meta'] = {}
196
+ data['meta']['total_reports_for_query'] = total_reports
197
 
198
  cache[cache_key] = data
199
  return data
 
254
  except Exception as e:
255
  return {"error": f"An unexpected error occurred: {e}"}
256
 
257
+ def get_serious_outcomes(drug_name: str, limit: int = 6) -> dict:
258
  """
259
  Query OpenFDA to get the most frequent serious outcomes for a given drug.
260
  This function makes multiple API calls to count different outcome fields.
261
 
262
  Args:
263
  drug_name (str): The name of the drug to search for.
264
+ limit (int): The maximum number of outcomes to return.
265
 
266
  Returns:
267
  dict: A dictionary containing aggregated results or an error.
 
273
  drug_name_processed = DRUG_SYNONYM_MAPPING.get(drug_name_processed, drug_name_processed)
274
 
275
  # Use a cache key for the aggregated result
276
+ cache_key = f"serious_outcomes_aggregated_{drug_name_processed}_{limit}"
277
  if cache_key in cache:
278
  return cache[cache_key]
279
 
 
282
  # Base search for all serious reports
283
  base_search_query = f'patient.drug.medicinalproduct:"{drug_name_processed}"+AND+serious:1'
284
 
285
+ # Get total number of serious reports
286
+ total_serious_reports = 0
287
+ try:
288
+ total_query_url = f"{API_BASE_URL}?search={base_search_query}"
289
+ time.sleep(REQUEST_DELAY_SECONDS)
290
+ response = requests.get(total_query_url)
291
+ if response.status_code == 200:
292
+ total_data = response.json()
293
+ total_serious_reports = total_data.get("meta", {}).get("results", {}).get("total", 0)
294
+ elif response.status_code != 404:
295
+ # If this call fails, we can still proceed, the total will just be 0.
296
+ pass
297
+ except requests.exceptions.RequestException:
298
+ # If fetching total fails, proceed without it.
299
+ pass
300
+
301
  for field in SERIOUS_OUTCOME_FIELDS:
302
  try:
303
  # Each query counts reports where the specific seriousness field exists
 
325
 
326
  # Format the results to match the expected structure for plotting
327
  final_data = {
328
+ "results": [{"term": k, "count": v} for k, v in aggregated_results.items()],
329
+ "meta": {"total_reports_for_query": total_serious_reports}
330
  }
331
 
332
+ # Sort results by count, descending, and then limit
333
  final_data["results"] = sorted(final_data["results"], key=lambda x: x['count'], reverse=True)
334
+ if limit:
335
+ final_data["results"] = final_data["results"][:limit]
336
 
337
  cache[cache_key] = final_data
338
  return final_data
 
384
  except Exception as e:
385
  return {"error": f"An unexpected error occurred: {e}"}
386
 
387
+ def get_report_source_data(drug_name: str, limit: int = 5) -> dict:
388
  """
389
  Query OpenFDA to get the breakdown of report sources for a given drug.
390
 
391
  Args:
392
  drug_name (str): The name of the drug to search for.
393
+ limit (int): The maximum number of sources to return.
394
 
395
  Returns:
396
  dict: The JSON response from the API, or an error dictionary.
 
401
  drug_name_processed = drug_name.lower().strip()
402
  drug_name_processed = DRUG_SYNONYM_MAPPING.get(drug_name_processed, drug_name_processed)
403
 
404
+ cache_key = f"report_source_{drug_name_processed}_{limit}"
405
  if cache_key in cache:
406
  return cache[cache_key]
407
 
 
418
 
419
  data = response.json()
420
 
421
+ # Translate the qualification codes and calculate total before limiting
422
  if "results" in data:
423
+ # Sort by count first
424
+ data['results'] = sorted(data['results'], key=lambda x: x['count'], reverse=True)
425
+
426
+ # Calculate total from all results before limiting
427
+ total_with_source = sum(item['count'] for item in data['results'])
428
+ if 'meta' not in data:
429
+ data['meta'] = {}
430
+ data['meta']['total_reports_for_query'] = total_with_source
431
+
432
+ # Translate codes after processing
433
  for item in data["results"]:
 
434
  term_str = str(item["term"])
435
  item["term"] = QUALIFICATION_MAPPING.get(term_str, f"Unknown ({term_str})")
436
 
437
+ # Apply limit
438
+ if limit:
439
+ data['results'] = data['results'][:limit]
440
+
441
  cache[cache_key] = data
442
  return data
443