# boardgpt-llm / src/simulator.py
# (uploaded by melmoheb — "Upload 6 files", commit 129641e)
import json
import os

import pandas as pd

# Assuming retriever and evaluator classes are in these files:
from .retriever import ClinicalCaseRetriever, DummyRetriever
from .evaluator import AnswerEvaluator
class OralExamSimulator:
    """Main class that coordinates the oral board exam simulation.

    Walks a learner through one clinical case at a time: a case is selected
    (by dataset index or retrieval query), its questions are posed in order,
    each answer is scored by the evaluator, and the whole exchange is logged
    in ``session_history`` as ``{"role": ..., "content": ...}`` records.
    """

    def __init__(self, retriever, evaluator):
        """
        Args:
            retriever (ClinicalCaseRetriever | DummyRetriever): supplies cases
                via its ``dataset`` list and ``retrieve_relevant_case``.
            evaluator (AnswerEvaluator): scores answers via ``evaluate_answer``.

        Raises:
            TypeError: if either collaborator has the wrong type.
        """
        if not isinstance(retriever, (ClinicalCaseRetriever, DummyRetriever)):
            raise TypeError("Retriever must be an instance of ClinicalCaseRetriever or DummyRetriever")
        if not isinstance(evaluator, AnswerEvaluator):
            raise TypeError("Evaluator must be an instance of AnswerEvaluator")
        self.retriever = retriever
        self.evaluator = evaluator
        self.current_case = None        # dict of the active case, or None
        self.current_question_idx = 0   # index of the question awaiting an answer
        self.session_history = []       # chronological interaction log

    def start_new_case(self, clinical_query=None, case_idx=None):
        """
        Initialize a new exam case based on query or direct selection.

        Args:
            clinical_query (str, optional): Text description of the desired
                case topic (used for RAG-based retrieval).
            case_idx (int, optional): Direct index of the case to use from the
                retriever's dataset. Takes precedence over ``clinical_query``.

        Returns:
            dict: Case info and the first question, or an ``{"error": ...}``
            dict when selection/validation fails.
        """
        print("-" * 50)
        print(f"Attempting to start new case | Query: '{clinical_query}' | Index: {case_idx}")

        # Reset state for the new case.
        self.current_case = None
        self.current_question_idx = 0
        self.session_history = []

        similarity_score = None
        if case_idx is not None:
            try:
                # Direct case selection: convert once and validate against the dataset.
                idx = int(case_idx)
                if 0 <= idx < len(self.retriever.dataset):
                    self.current_case = self.retriever.dataset[idx]
                    similarity_score = 1.0  # Direct selection implies perfect 'match'
                    print(f"Selected case by index {case_idx}: {self.current_case.get('clinical_presentation', 'Unknown Presentation')}")
                else:
                    print(f"Error: Invalid case index {case_idx}. Must be between 0 and {len(self.retriever.dataset)-1}.")
                    return {"error": f"Invalid case index: {case_idx}"}
            except Exception as e:
                print(f"Error selecting case by index {case_idx}: {e}")
                return {"error": f"Failed to select case by index: {e}"}
        elif clinical_query:
            # RAG-based retrieval.
            try:
                # retrieve_relevant_case returns a list of (case_dict, score) tuples.
                retrieved_results = self.retriever.retrieve_relevant_case(clinical_query, top_k=1)
                if retrieved_results:  # non-empty → take the best match
                    self.current_case, similarity_score = retrieved_results[0]
                    print(f"Retrieved case via query ('{clinical_query}') with score {similarity_score:.4f}: {self.current_case.get('clinical_presentation', 'Unknown Presentation')}")
                else:
                    print(f"Error: No case found for query: '{clinical_query}'")
                    return {"error": f"No relevant case found for query: {clinical_query}"}
            except Exception as e:
                print(f"Error retrieving case for query '{clinical_query}': {e}")
                return {"error": f"Failed to retrieve case by query: {e}"}
        else:
            # No selection method provided.
            print("Error: Must provide either a clinical query or a case index.")
            return {"error": "Please provide either a clinical query or case index."}

        # --- Post-selection setup ---
        if self.current_case is None:
            # Should be unreachable (both branches above set it or return),
            # kept as a defensive check.
            print("Error: Failed to set current_case.")
            return {"error": "Failed to load the selected case."}

        # Validate case structure: questions/answers must be parallel lists.
        # .get() returns None for a missing key, which fails the isinstance
        # checks, so this also covers absent keys.
        questions = self.current_case.get('questions')
        answers = self.current_case.get('answers')
        if not isinstance(questions, list) or not isinstance(answers, list) \
                or len(questions) != len(answers):
            print(f"Error: Invalid case structure for case ID {self.current_case.get('case_id', 'N/A')}. Mismatched or missing Q/A lists.")
            return {"error": "Selected case has invalid format."}
        if not questions:
            print(f"Warning: Selected case ID {self.current_case.get('case_id', 'N/A')} has no questions.")
            return {"error": "Selected case contains no questions."}

        # Start a new session record.
        self.session_history.append({
            "role": "system",
            "content": f"Clinical scenario started: {self.current_case.get('clinical_presentation', 'Unknown Presentation')} (Case ID: {self.current_case.get('case_id', 'N/A')})"
        })

        # Pose the first question and log it.
        first_question = questions[0]
        self.session_history.append({
            "role": "examiner",
            "content": first_question
        })

        print(f"Case successfully started. Total questions: {len(questions)}")
        print("-" * 50)
        return {
            "case_id": self.current_case.get('case_id', 'unknown'),
            "clinical_presentation": self.current_case.get('clinical_presentation', 'Unknown'),
            "similarity_score": similarity_score,  # 1.0 for direct selection, retrieval score otherwise
            "current_question": first_question,
            "question_number": 1,
            "total_questions": len(questions)
        }

    def process_user_response(self, response):
        """
        Process the user's answer, get feedback, and advance the exam.

        Args:
            response (str): User's answer text.

        Returns:
            dict: feedback, expected answer, completion status, and — when
            more questions remain — the next question; on the final answer a
            ``session_summary`` is included instead. Returns an
            ``{"error": ...}`` dict if no case is active or it is finished.
        """
        if self.current_case is None:
            print("Error: No active case.")
            return {"error": "No active case. Please start a new case first."}
        if self.current_question_idx >= len(self.current_case['questions']):
            print("Error: Attempting to process response when case is already complete.")
            return {"error": "Case already completed."}

        print("-" * 50)
        current_q_num = self.current_question_idx + 1
        total_q = len(self.current_case['questions'])
        print(f"Processing response for Question {current_q_num}/{total_q}")
        print(f"User Response: {response}")

        # Log the resident's answer.
        self.session_history.append({
            "role": "resident",
            "content": response
        })

        # Model answer for the question currently being answered.
        expected_answer = self.current_case['answers'][self.current_question_idx]
        print(f"Expected Answer: {expected_answer}")

        # Score the answer against the model answer, giving case context.
        feedback = self.evaluator.evaluate_answer(
            response,
            expected_answer,
            clinical_context=f"Regarding the case '{self.current_case.get('clinical_presentation', 'N/A')}'"
        )
        print(f"Generated Feedback: {feedback}")
        self.session_history.append({
            "role": "feedback",
            "content": feedback
        })

        # Advance to the next question and check for completion.
        self.current_question_idx += 1
        is_complete = self.current_question_idx >= total_q

        result = {
            "feedback": feedback,
            "expected_answer": expected_answer,
            "is_complete": is_complete,
            "question_number": self.current_question_idx
        }
        if not is_complete:
            next_question = self.current_case['questions'][self.current_question_idx]
            result["next_question"] = next_question
            result["total_questions"] = total_q
            # Log the next examiner question.
            self.session_history.append({
                "role": "examiner",
                "content": next_question
            })
            print(f"Next question ({result['question_number']}/{total_q}): {next_question}")
        else:
            print("Case completed.")
            # NOTE: the summary holds a reference to session_history, so the
            # closing system record appended below is visible in it too
            # (matches the original ordering).
            result["session_summary"] = self.generate_session_summary()
            self.session_history.append({
                "role": "system",
                "content": "End of clinical scenario."
            })
        print("-" * 50)
        return result

    def generate_session_summary(self):
        """Generate a summary dictionary of the completed session.

        Returns:
            dict: case id/presentation, question count, and the full
            interaction log, or an ``{"error": ...}`` dict when there is no
            session to summarize.
        """
        if not self.current_case or not self.session_history:
            return {"error": "No active or completed session to summarize."}
        return {
            "case_id": self.current_case.get('case_id', 'N/A'),
            "case": self.current_case.get('clinical_presentation', 'Unknown'),
            "total_questions_in_case": len(self.current_case.get('questions', [])),
            "interaction_history": self.session_history  # Include the full log
        }

    def save_session(self, filepath):
        """Save the current session summary to a JSON file.

        Args:
            filepath (str): Destination path; parent directories are created
                as needed.

        Returns:
            dict: ``{"status": ...}`` on success, ``{"error": ...}`` otherwise.
        """
        summary = self.generate_session_summary()
        if "error" in summary:
            print(f"Error generating summary for saving: {summary['error']}")
            return {"error": "No session to save"}
        try:
            # Add a timestamp to the saved data.
            summary["timestamp"] = pd.Timestamp.now().isoformat()
            # Bug fix: os.path.dirname() is "" for a bare filename and
            # os.makedirs("") raises FileNotFoundError — only create the
            # directory when there actually is one.
            directory = os.path.dirname(filepath)
            if directory:
                os.makedirs(directory, exist_ok=True)
            with open(filepath, 'w') as f:
                json.dump(summary, f, indent=2)
            print(f"Session saved successfully to {filepath}")
            return {"status": "Session saved successfully"}
        except Exception as e:
            print(f"Error saving session to {filepath}: {e}")
            return {"error": f"Failed to save session: {e}"}