|
|
import json
import os

import pandas as pd

from .evaluator import AnswerEvaluator
from .retriever import ClinicalCaseRetriever, DummyRetriever
|
|
|
|
|
class OralExamSimulator:
    """Main class that coordinates the oral board exam simulation.

    Drives a question/answer loop: a clinical case is selected (either by
    free-text query through the retriever or by direct dataset index),
    questions are posed one at a time, and each resident response is graded
    by the evaluator. A chronological transcript of the whole exchange is
    kept in ``session_history`` and can be summarized/saved at the end.
    """

    def __init__(self, retriever, evaluator):
        """
        Args:
            retriever: A ClinicalCaseRetriever or DummyRetriever instance.
                Must expose a ``dataset`` list of case dicts and a
                ``retrieve_relevant_case(query, top_k)`` method.
            evaluator: An AnswerEvaluator used to grade resident responses.

        Raises:
            TypeError: If either collaborator is not of the expected type.
        """
        if not isinstance(retriever, (ClinicalCaseRetriever, DummyRetriever)):
            raise TypeError("Retriever must be an instance of ClinicalCaseRetriever or DummyRetriever")
        if not isinstance(evaluator, AnswerEvaluator):
            raise TypeError("Evaluator must be an instance of AnswerEvaluator")

        self.retriever = retriever
        self.evaluator = evaluator
        self.current_case = None        # dict for the active case, or None when idle
        self.current_question_idx = 0   # index of the question currently awaiting an answer
        self.session_history = []       # chronological transcript: list of {"role", "content"} dicts

    def start_new_case(self, clinical_query=None, case_idx=None):
        """
        Initialize a new exam case based on query or direct selection.

        An explicit ``case_idx`` takes precedence over ``clinical_query``.
        Any previous session state is discarded before the new case loads.

        Args:
            clinical_query (str, optional): Text description of the desired case topic.
            case_idx (int, optional): Direct index of the case to use from the retriever's dataset.

        Returns:
            dict: Contains case info and first question, or an error message
            under the ``"error"`` key.
        """
        print("-" * 50)
        print(f"Attempting to start new case | Query: '{clinical_query}' | Index: {case_idx}")

        # Reset prior state so histories never bleed between cases.
        self.current_case = None
        self.current_question_idx = 0
        self.session_history = []

        if case_idx is not None:
            try:
                # int() may raise for non-numeric input; caught below.
                idx = int(case_idx)
                if 0 <= idx < len(self.retriever.dataset):
                    self.current_case = self.retriever.dataset[idx]
                    # Direct selection is treated as a perfect match.
                    similarity_score = 1.0
                    print(f"Selected case by index {case_idx}: {self.current_case.get('clinical_presentation', 'Unknown Presentation')}")
                else:
                    print(f"Error: Invalid case index {case_idx}. Must be between 0 and {len(self.retriever.dataset)-1}.")
                    return {"error": f"Invalid case index: {case_idx}"}
            except Exception as e:
                print(f"Error selecting case by index {case_idx}: {e}")
                return {"error": f"Failed to select case by index: {e}"}

        elif clinical_query:
            try:
                retrieved_results = self.retriever.retrieve_relevant_case(clinical_query, top_k=1)
                if retrieved_results:
                    # retrieve_relevant_case returns (case, score) pairs; take the best hit.
                    self.current_case, similarity_score = retrieved_results[0]
                    print(f"Retrieved case via query ('{clinical_query}') with score {similarity_score:.4f}: {self.current_case.get('clinical_presentation', 'Unknown Presentation')}")
                else:
                    print(f"Error: No case found for query: '{clinical_query}'")
                    return {"error": f"No relevant case found for query: {clinical_query}"}
            except Exception as e:
                print(f"Error retrieving case for query '{clinical_query}': {e}")
                return {"error": f"Failed to retrieve case by query: {e}"}

        else:
            print("Error: Must provide either a clinical query or a case index.")
            return {"error": "Please provide either a clinical query or case index."}

        if self.current_case is None:
            print("Error: Failed to set current_case.")
            return {"error": "Failed to load the selected case."}

        # Validate structure: parallel questions/answers lists are required.
        questions = self.current_case.get('questions')
        answers = self.current_case.get('answers')
        if not isinstance(questions, list) or not isinstance(answers, list) or \
                len(questions) != len(answers):
            print(f"Error: Invalid case structure for case ID {self.current_case.get('case_id', 'N/A')}. Mismatched or missing Q/A lists.")
            return {"error": "Selected case has invalid format."}

        if not questions:
            print(f"Warning: Selected case ID {self.current_case.get('case_id', 'N/A')} has no questions.")
            return {"error": "Selected case contains no questions."}

        # Open the transcript with a system message describing the scenario.
        self.session_history.append({
            "role": "system",
            "content": f"Clinical scenario started: {self.current_case.get('clinical_presentation', 'Unknown Presentation')} (Case ID: {self.current_case.get('case_id', 'N/A')})"
        })

        first_question = questions[0]

        self.session_history.append({
            "role": "examiner",
            "content": first_question
        })

        print(f"Case successfully started. Total questions: {len(questions)}")
        print("-" * 50)

        return {
            "case_id": self.current_case.get('case_id', 'unknown'),
            "clinical_presentation": self.current_case.get('clinical_presentation', 'Unknown'),
            "similarity_score": similarity_score,
            "current_question": first_question,
            "question_number": 1,
            "total_questions": len(questions)
        }

    def process_user_response(self, response):
        """
        Process the user's answer, get feedback, and return the next question or completion status.

        Appends the response, the evaluator's feedback, and (if any) the next
        question to ``session_history``, then advances the question pointer.

        Args:
            response (str): User's answer text.

        Returns:
            dict: Contains feedback, expected answer, completion status, and
            next question (if applicable), or an error message.
        """
        if self.current_case is None:
            print("Error: No active case.")
            return {"error": "No active case. Please start a new case first."}

        if self.current_question_idx >= len(self.current_case['questions']):
            print("Error: Attempting to process response when case is already complete.")
            return {"error": "Case already completed."}

        print("-" * 50)
        current_q_num = self.current_question_idx + 1
        total_q = len(self.current_case['questions'])
        print(f"Processing response for Question {current_q_num}/{total_q}")
        print(f"User Response: {response}")

        self.session_history.append({
            "role": "resident",
            "content": response
        })

        # questions/answers are parallel lists (validated in start_new_case).
        expected_answer = self.current_case['answers'][self.current_question_idx]
        print(f"Expected Answer: {expected_answer}")

        feedback = self.evaluator.evaluate_answer(
            response,
            expected_answer,
            clinical_context=f"Regarding the case '{self.current_case.get('clinical_presentation', 'N/A')}'"
        )
        print(f"Generated Feedback: {feedback}")

        self.session_history.append({
            "role": "feedback",
            "content": feedback
        })

        # Advance to the next question; after the last one the case is complete.
        self.current_question_idx += 1
        is_complete = self.current_question_idx >= len(self.current_case['questions'])

        result = {
            "feedback": feedback,
            "expected_answer": expected_answer,
            # 1-based number of the question just answered.
            "is_complete": is_complete,
            "question_number": self.current_question_idx
        }

        if not is_complete:
            next_question = self.current_case['questions'][self.current_question_idx]
            result["next_question"] = next_question
            result["total_questions"] = total_q

            self.session_history.append({
                "role": "examiner",
                "content": next_question
            })
            print(f"Next question ({result['question_number']}/{total_q}): {next_question}")
        else:
            print("Case completed.")
            summary = self.generate_session_summary()
            result["session_summary"] = summary
            self.session_history.append({
                "role": "system",
                "content": "End of clinical scenario."
            })

        print("-" * 50)
        return result

    def generate_session_summary(self):
        """Generate a summary dictionary of the completed session.

        Returns:
            dict: Case identifiers, question count, and the full interaction
            history, or an ``"error"`` entry when no session exists.
        """
        if not self.current_case or not self.session_history:
            return {"error": "No active or completed session to summarize."}

        return {
            "case_id": self.current_case.get('case_id', 'N/A'),
            "case": self.current_case.get('clinical_presentation', 'Unknown'),
            "total_questions_in_case": len(self.current_case.get('questions', [])),
            "interaction_history": self.session_history
        }

    def save_session(self, filepath):
        """Save the current session summary to a JSON file.

        Parent directories of ``filepath`` are created when needed.

        Args:
            filepath (str): Destination path for the JSON summary.

        Returns:
            dict: ``{"status": ...}`` on success or ``{"error": ...}`` on failure.
        """
        summary = self.generate_session_summary()
        if "error" in summary:
            print(f"Error generating summary for saving: {summary['error']}")
            return {"error": "No session to save"}

        try:
            # Timestamp the summary so saved sessions can be ordered later.
            summary["timestamp"] = pd.Timestamp.now().isoformat()

            # BUGFIX: `os` was never imported, so every save died with a
            # NameError that the broad except silently converted to an error
            # return. Also guard against a bare filename: os.path.dirname("x")
            # is "" and os.makedirs("") raises.
            parent_dir = os.path.dirname(filepath)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)

            with open(filepath, 'w') as f:
                json.dump(summary, f, indent=2)
            print(f"Session saved successfully to {filepath}")
            return {"status": "Session saved successfully"}
        except Exception as e:
            print(f"Error saving session to {filepath}: {e}")
            return {"error": f"Failed to save session: {e}"}