"""
Medical X-ray Question Generation Benchmark (ChestAgentBench)

This script generates clinical questions from X-ray case data in the Eurorad
dataset using GPT-4o. It structures questions across different analytical
categories and saves them as JSON.
"""

import os
import re
import json
from typing import Any, Dict, List, Optional
from pprint import pprint

import openai
from tqdm import tqdm

from benchmark.utils import load_eurorad_dataset
from benchmark.llm import get_llm_response

DATA_DIR = "set your data directory here, e.g. /home/MedRAX/data"
DATASET_PATH = os.path.join(DATA_DIR, "eurorad_metadata.json")

SYSTEM_PROMPT = """
You are an expert medical benchmark creation assistant.
Your goal is to generate questions that evaluate a multimodal medical AI agent's ability to interpret and reason about chest X-rays.
""".strip()

CATEGORIES_META = {
    "detection": "Identify and locate specific findings in the chest X-ray.",
    "classification": "Determine whether specific findings are present or absent in the chest X-ray.",
    "enumeration": "Count the number of target findings in the chest X-ray.",
    "localization": "Locate a given finding in the chest X-ray.",
    "comparison": "Compare the size or position of a specific finding in the chest X-ray.",
    "relationship": "Determine the relationship between two or more findings in the chest X-ray.",
    "diagnosis": "Make a diagnosis or determine a treatment plan by interpreting the chest X-ray.",
    "characterization": "Describe specific attributes (shape, density, margins, etc.) of findings.",
    "reasoning": "Explain the medical rationale and thought process behind findings and conclusions.",
}
CATEGORIES = list(CATEGORIES_META.keys())

CATEGORY_COMBINATIONS = [
    ["detection", "localization", "characterization", "reasoning"],
    ["detection", "classification", "relationship", "reasoning"],
    ["localization", "comparison", "relationship", "reasoning"],
    ["classification", "comparison", "diagnosis", "reasoning"],
    ["classification", "characterization", "diagnosis", "reasoning"],
]
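# Each combination above yields one question per case in generate_questions(),
# so every case produces five questions mixing different analytical skills.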

DEFAULT_SECTIONS = [
    "history",
    "image_finding",
    "discussion",
    "differential_diagnosis",
    "diagnosis",
    "figures",
]

class Question:
    """A class to generate clinical questions from case data.

    Handles creating structured clinical questions by combining case data with
    specified categories and difficulty levels.

    Attributes:
        type (str): The type of question (e.g. multiple choice)
        difficulty (str): Difficulty level of the question
        case_data (Dict[str, Any]): Dictionary containing the clinical case data
        case_content (str): Formatted case data from the selected sections
        case_id (str): Unique identifier for the case
        categories (List[str]): Analytical categories this question tests
        sections (List[str]): Case sections to include in the question
        raw_content (Optional[str]): Raw LLM response to the question prompt
        content (Optional[Dict[str, str]]): Content extracted from the raw LLM response
    """

    def __init__(
        self,
        type: str,
        difficulty: str,
        case_data: Dict[str, Any],
        categories: List[str],
        sections: Optional[List[str]] = None,
        system_prompt: str = "You are an expert medical benchmark creation assistant.",
    ) -> None:
        self.type = type
        self.difficulty = difficulty
        self.case_data = case_data
        self.case_id = case_data["case_id"]
        self.categories = categories
        # Fall back to the module-level default rather than using a mutable
        # default argument.
        self.sections = sections if sections is not None else DEFAULT_SECTIONS
        self.system_prompt = system_prompt
        self.case_content = self.select_case_sections()
        self.raw_content: Optional[str] = None
        self.content: Optional[Dict[str, str]] = None

    def create_question_prompt(self) -> str:
        """Creates a formatted prompt for generating a clinical question.

        Returns:
            str: A structured prompt containing the question parameters and clinical data
        """
        category_descriptions = "\n".join(
            f"{category}: {desc}"
            for category, desc in CATEGORIES_META.items()
            if category in self.categories
        )

        # The template is flush-left, so a plain strip() yields the final prompt.
        return f"""
You must follow these guidelines:
1. Questions must be answerable using only the provided context and the chest X-rays.
- Questions must explicitly mention the referenced figures
- Questions can only reference the chest X-ray figures

2. Questions must have unambiguous, verifiable answers, and should:
- Challenge the agent's analytical capabilities
- Require multi-step reasoning
- Test the ability to make precise observations
- Evaluate the capability to derive insights and findings from the chest X-ray

3. The agent has access to tools such as classification, report generation, segmentation, grounding, and visual question answering. Your question should be complex enough to require the use of such tools.

Create a {self.difficulty} {self.type} clinical question that integrates the following:

{category_descriptions}

based on the following clinical case:

{self.case_content}

Do not use any information derived from the CT and MRI images. Do not reveal any information or findings about the chest X-rays.
Your question should require the agent to derive insights and findings from the chest X-ray by itself.
Your answer should be verifiable directly in the context of the case.
You can only use the image findings that come from the chest X-ray figures.

Your response must follow this exact format:
THOUGHTS: [Think about the reasoning steps and tools the agent should use to answer the question]
QUESTION: [complete question with relevant context. Incorrect choices should be very close to the correct answer.]
FIGURES: [list of required figures, e.g. ["Figure 1", "Figure 2a"]]
EXPLANATION: [short explanation of why your answer is verifiable in the case]
ANSWER: [correct answer e.g. "A"]
""".strip()

    def select_case_sections(self) -> str:
        """Extract and format selected sections from case data into paragraphs.

        Returns:
            str: Formatted string with case sections and content
        """
        # Maps each section name to its case_data key and a default used when
        # that key is missing from the case.
        section_mapping = {
            "history": ("history", "No history provided."),
            "image_finding": ("image_finding", "No findings provided."),
            "discussion": ("discussion", "No discussion provided."),
            "differential_diagnosis": (
                "differential_diagnosis",
                "No differential diagnosis provided.",
            ),
            "diagnosis": ("diagnosis", "No diagnosis provided."),
            "figures": ("figures", "No figures provided."),
        }

        formatted = []
        for section in self.sections:
            if section in section_mapping:
                key, default = section_mapping[section]
                content = self.case_data.get(key, default)

                # Figures are structured records; flatten subfigure captions into
                # lines. Skip the flattening when the default string was used.
                if key == "figures" and not isinstance(content, str):
                    figures_text = []
                    for figure in content:
                        for subfig in figure["subfigures"]:
                            figures_text.append(f"{subfig['number']}: {subfig['caption']}")
                    content = "\n".join(figures_text)

                formatted.append(f"{section}:\n{content}")

        return "\n\n".join(formatted)

    def create_question(
        self,
        client: openai.OpenAI,
        temperature: float = 0.7,
        top_p: float = 0.95,
        max_tokens: int = 500,
        model: str = "gpt-4o",
    ) -> str:
        """Create a clinical question using the LLM.

        Args:
            client (openai.OpenAI): OpenAI client instance
            temperature (float): Controls randomness in responses. Defaults to 0.7.
            top_p (float): Controls diversity via nucleus sampling. Defaults to 0.95.
            max_tokens (int): Max tokens in the model response. Defaults to 500.
            model (str): OpenAI model to use. Defaults to "gpt-4o".

        Returns:
            str: LLM response containing the formatted question components
        """
        self.raw_content = get_llm_response(
            client=client,
            prompt=self.create_question_prompt(),
            system_prompt=self.system_prompt,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            model=model,
        )
        self.content = self.extract_content()

        return self.raw_content

    def extract_content(self) -> Dict[str, Optional[str]]:
        """Extract sections from the raw LLM response using regex patterns.

        Returns:
            Dict[str, Optional[str]]: Extracted thoughts, question, figures,
                explanation, and answer; a value is None if its section is missing.
        """
        keywords = ["THOUGHTS", "QUESTION", "FIGURES", "EXPLANATION", "ANSWER"]

        content = {}
        for kw in keywords:
            # Capture everything after "KEYWORD:" up to the next all-caps header
            # (e.g. "ANSWER:") or the end of the response.
            pattern = rf"{kw}:\s*(.*?)(?=\n[A-Z]+:|$)"
            match = re.search(pattern, self.raw_content, re.DOTALL)
            content[kw.lower()] = match.group(1).strip() if match else None

        return content
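
    # For example, a response of the form
    #   THOUGHTS: ...
    #   QUESTION: ...
    #   FIGURES: ["Figure 1"]
    #   EXPLANATION: ...
    #   ANSWER: A
    # parses to {"thoughts": "...", "question": "...", "figures": '["Figure 1"]',
    # "explanation": "...", "answer": "A"}.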

    def save(self, output_path: str) -> Dict[str, Any]:
        """Save question content and metadata as a JSON file.

        Args:
            output_path (str): Directory path where the JSON file will be saved

        Returns:
            Dict[str, Any]: Question data including content (thoughts, question,
                figures, explanation, answer) and metadata (type, difficulty,
                categories, etc.)
        """
        question_metadata = self.content.copy()

        question_metadata["metadata"] = {
            "case_id": self.case_id,
            "type": self.type,
            "difficulty": self.difficulty,
            "categories": self.categories,
            "sections": self.sections,
        }

        # Questions are grouped into one directory per case.
        case_dir = os.path.join(output_path, str(self.case_id))
        os.makedirs(case_dir, exist_ok=True)

        # The object hash keeps filenames unique across a case's questions; note
        # that it changes between runs, so reruns write new files.
        output_file = os.path.join(case_dir, f"{self.case_id}_{hash(self)}.json")
        with open(output_file, "w") as f:
            json.dump(question_metadata, f, indent=2)

        return question_metadata
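
    # Illustrative contents of a saved JSON file:
    # {
    #   "thoughts": "...",
    #   "question": "...",
    #   "figures": "[\"Figure 1\"]",
    #   "explanation": "...",
    #   "answer": "A",
    #   "metadata": {"case_id": "16798", "type": "multiple choice (A/B/C/D/E/F)", ...}
    # }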


def generate_questions(
    dataset: Dict[str, Any],
    client: openai.OpenAI,
    output_dir: str,
    skip_first: int = 100,
    temperature: float = 0.7,
    top_p: float = 0.95,
    max_tokens: int = 1200,
    model: str = "gpt-4o",
) -> None:
    """Generate questions for each case and category combination.

    Args:
        dataset: Dictionary of case data
        client: OpenAI client instance
        output_dir: Directory to save generated questions
        skip_first: Number of initial cases to skip
        temperature: LLM temperature parameter
        top_p: LLM top_p parameter
        max_tokens: Maximum tokens for LLM response
        model: LLM model name
    """
    # Process cases in numeric order, skipping the first `skip_first` cases.
    target_cases = sorted(dataset.keys(), key=int)[skip_first:]

    for case_id in tqdm(target_cases, desc="Processing cases"):
        case_data = dataset[case_id]

        for combination in tqdm(CATEGORY_COMBINATIONS, desc=f"Categories for case {case_id}"):
            question = Question(
                type="multiple choice (A/B/C/D/E/F)",
                difficulty="complex",
                case_data=case_data,
                categories=combination,
                sections=DEFAULT_SECTIONS,
                system_prompt=SYSTEM_PROMPT,
            )

            question.create_question(
                client=client,
                temperature=temperature,
                top_p=top_p,
                max_tokens=max_tokens,
                model=model,
            )
            question.save(output_dir)


def main():
    """Main execution function."""
    client = openai.OpenAI()

    dataset = load_eurorad_dataset(
        DATASET_PATH,
        section="Chest Imaging",
        as_dict=True,
        filter_by_caption=[
            "xray",
            "x-ray",
            "x ray",
            "ray",
            "xr",
            "radiograph",
        ],
    )
    print(f"\n---\nFound {len(dataset)} cases with X-ray mentions\n---\n")

    # Preview a sample case before generating questions.
    case_data = dataset["16798"]
    pprint(case_data, sort_dicts=False)

    generate_questions(dataset=dataset, client=client, output_dir="benchmark/questions")


if __name__ == "__main__":
    main()