File size: 10,903 Bytes
2247e66 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 |
import json
import os

import pandas as pd

# Assuming retriever and evaluator classes are in these files:
from .evaluator import AnswerEvaluator
from .retriever import ClinicalCaseRetriever, DummyRetriever
class OralExamSimulator:
    """Main class that coordinates the oral board exam simulation.

    A retriever supplies a clinical case (a dict holding parallel
    ``questions`` and ``answers`` lists), an evaluator produces feedback for
    each response, and this class tracks the active question together with a
    chronological log of the interaction.
    """

    def __init__(self, retriever, evaluator):
        """Bind the retriever and evaluator and start with no active case.

        Args:
            retriever: A ``ClinicalCaseRetriever`` or ``DummyRetriever``.
            evaluator: An ``AnswerEvaluator``.

        Raises:
            TypeError: If either collaborator has an unexpected type.
        """
        if not isinstance(retriever, (ClinicalCaseRetriever, DummyRetriever)):
            raise TypeError("Retriever must be an instance of ClinicalCaseRetriever or DummyRetriever")
        if not isinstance(evaluator, AnswerEvaluator):
            raise TypeError("Evaluator must be an instance of AnswerEvaluator")
        self.retriever = retriever
        self.evaluator = evaluator
        self.current_case = None        # dict of the active case, or None
        self.current_question_idx = 0   # 0-based index of the next question to answer
        self.session_history = []       # list of {"role": ..., "content": ...} entries

    def start_new_case(self, clinical_query=None, case_idx=None):
        """Initialize a new exam case based on query or direct selection.

        ``case_idx`` takes precedence when both arguments are given. Any
        previous session state is discarded, and the selected case only
        becomes active once it has passed validation.

        Args:
            clinical_query (str, optional): Text description of the desired case topic.
            case_idx (int, optional): Direct index of the case to use from the
                retriever's dataset.

        Returns:
            dict: On success, ``case_id``, ``clinical_presentation``,
            ``similarity_score``, ``current_question``, ``question_number``
            and ``total_questions``. On failure, ``{"error": <message>}``.
        """
        print("-" * 50)
        print(f"Attempting to start new case | Query: '{clinical_query}' | Index: {case_idx}")

        # Reset up front so a failed start never leaves a stale or
        # half-validated case active for process_user_response().
        self.current_case = None
        self.current_question_idx = 0
        self.session_history = []

        if case_idx is not None:
            case, similarity_score, error = self._select_case_by_index(case_idx)
        elif clinical_query:
            case, similarity_score, error = self._select_case_by_query(clinical_query)
        else:
            print("Error: Must provide either a clinical query or a case index.")
            return {"error": "Please provide either a clinical query or case index."}

        if error is not None:
            return {"error": error}

        if case is None:
            # Defensive: the selection helpers should already have reported this.
            print("Error: Failed to set current_case.")
            return {"error": "Failed to load the selected case."}

        error = self._validate_case(case)
        if error is not None:
            return {"error": error}

        # Only a validated case becomes the active one.
        self.current_case = case
        questions = case['questions']
        first_question = questions[0]
        presentation = case.get('clinical_presentation', 'Unknown Presentation')
        case_id = case.get('case_id', 'N/A')

        self.session_history.append({
            "role": "system",
            "content": f"Clinical scenario started: {presentation} (Case ID: {case_id})",
        })
        self.session_history.append({
            "role": "examiner",
            "content": first_question,
        })

        print(f"Case successfully started. Total questions: {len(questions)}")
        print("-" * 50)
        return {
            "case_id": case.get('case_id', 'unknown'),
            "clinical_presentation": case.get('clinical_presentation', 'Unknown'),
            "similarity_score": similarity_score,
            "current_question": first_question,
            "question_number": 1,
            "total_questions": len(questions),
        }

    def _select_case_by_index(self, case_idx):
        """Pick a case by position in ``self.retriever.dataset``.

        Returns:
            tuple: ``(case, score, error)``; ``error`` is ``None`` on success
            and ``score`` is fixed at 1.0 for a direct selection.
        """
        try:
            idx = int(case_idx)
            dataset = self.retriever.dataset
            if not 0 <= idx < len(dataset):
                print(f"Error: Invalid case index {case_idx}. Must be between 0 and {len(dataset)-1}.")
                return None, None, f"Invalid case index: {case_idx}"
            case = dataset[idx]
            print(f"Selected case by index {case_idx}: {case.get('clinical_presentation', 'Unknown Presentation')}")
        except Exception as e:
            # Boundary with an opaque dataset object: report instead of crashing.
            print(f"Error selecting case by index {case_idx}: {e}")
            return None, None, f"Failed to select case by index: {e}"
        return case, 1.0, None

    def _select_case_by_query(self, clinical_query):
        """Pick the top case returned by the retriever for ``clinical_query``.

        The retriever is expected to return a list of ``(case_dict, score)``
        tuples; only the first entry is used.

        Returns:
            tuple: ``(case, score, error)``; ``error`` is ``None`` on success.
        """
        try:
            results = self.retriever.retrieve_relevant_case(clinical_query, top_k=1)
            if not results:
                print(f"Error: No case found for query: '{clinical_query}'")
                return None, None, f"No relevant case found for query: {clinical_query}"
            top_hit = results[0]
            case, score = top_hit[0], top_hit[1]
            print(f"Retrieved case via query ('{clinical_query}') with score {score:.4f}: {case.get('clinical_presentation', 'Unknown Presentation')}")
        except Exception as e:
            # Boundary with an opaque retriever: report instead of crashing.
            print(f"Error retrieving case for query '{clinical_query}': {e}")
            return None, None, f"Failed to retrieve case by query: {e}"
        return case, score, None

    @staticmethod
    def _validate_case(case):
        """Return an error message if ``case`` cannot be run, else ``None``."""
        well_formed = (
            'questions' in case
            and 'answers' in case
            and isinstance(case['questions'], list)
            and isinstance(case['answers'], list)
            and len(case['questions']) == len(case['answers'])
        )
        if not well_formed:
            print(f"Error: Invalid case structure for case ID {case.get('case_id', 'N/A')}. Mismatched or missing Q/A lists.")
            return "Selected case has invalid format."
        if not case['questions']:
            print(f"Warning: Selected case ID {case.get('case_id', 'N/A')} has no questions.")
            return "Selected case contains no questions."
        return None

    def process_user_response(self, response):
        """Process the user's answer, get feedback, and advance the session.

        Args:
            response (str): User's answer text.

        Returns:
            dict: ``feedback``, ``expected_answer``, ``is_complete`` and
            ``question_number`` (the count of questions answered so far).
            While questions remain it also carries ``next_question`` and
            ``total_questions``; once the case is finished it carries
            ``session_summary`` instead. ``{"error": <message>}`` is returned
            when no case is active or the case is already complete.
        """
        if self.current_case is None:
            print("Error: No active case.")
            return {"error": "No active case. Please start a new case first."}

        questions = self.current_case['questions']
        total_q = len(questions)
        if self.current_question_idx >= total_q:
            print("Error: Attempting to process response when case is already complete.")
            return {"error": "Case already completed."}

        print("-" * 50)
        print(f"Processing response for Question {self.current_question_idx + 1}/{total_q}")
        print(f"User Response: {response}")

        self.session_history.append({
            "role": "resident",
            "content": response,
        })

        expected_answer = self.current_case['answers'][self.current_question_idx]
        print(f"Expected Answer: {expected_answer}")

        presentation = self.current_case.get('clinical_presentation', 'N/A')
        feedback = self.evaluator.evaluate_answer(
            response,
            expected_answer,
            clinical_context=f"Regarding the case '{presentation}'",
        )
        print(f"Generated Feedback: {feedback}")

        self.session_history.append({
            "role": "feedback",
            "content": feedback,
        })

        self.current_question_idx += 1
        is_complete = self.current_question_idx >= total_q

        result = {
            "feedback": feedback,
            "expected_answer": expected_answer,
            "is_complete": is_complete,
            "question_number": self.current_question_idx,
        }

        if not is_complete:
            next_question = questions[self.current_question_idx]
            result["next_question"] = next_question
            result["total_questions"] = total_q
            self.session_history.append({
                "role": "examiner",
                "content": next_question,
            })
            # The upcoming question is 1-based, hence index + 1.
            print(f"Next question ({self.current_question_idx + 1}/{total_q}): {next_question}")
        else:
            print("Case completed.")
            # Log the end marker before building the summary so it is
            # clearly part of the reported history.
            self.session_history.append({
                "role": "system",
                "content": "End of clinical scenario.",
            })
            result["session_summary"] = self.generate_session_summary()

        print("-" * 50)
        return result

    def generate_session_summary(self):
        """Generate a summary dictionary of the current session.

        Returns:
            dict: ``case_id``, ``case``, ``total_questions_in_case`` and
            ``interaction_history``, or ``{"error": <message>}`` when there is
            no active case or the history is empty. ``interaction_history``
            is the live ``session_history`` list, not a copy.
        """
        if not self.current_case or not self.session_history:
            return {"error": "No active or completed session to summarize."}

        return {
            "case_id": self.current_case.get('case_id', 'N/A'),
            "case": self.current_case.get('clinical_presentation', 'Unknown'),
            "total_questions_in_case": len(self.current_case.get('questions', [])),
            "interaction_history": self.session_history,
        }

    def save_session(self, filepath):
        """Save the current session summary to a JSON file.

        The parent directory is created when ``filepath`` names one; a bare
        filename is written to the current working directory.

        Args:
            filepath: Destination path of the JSON file.

        Returns:
            dict: ``{"status": ...}`` on success, ``{"error": ...}`` otherwise.
        """
        summary = self.generate_session_summary()
        if "error" in summary:
            print(f"Error generating summary for saving: {summary['error']}")
            return {"error": "No session to save"}

        try:
            summary["timestamp"] = pd.Timestamp.now().isoformat()

            # dirname is '' for a bare filename and os.makedirs('') raises,
            # so only create the directory when there is one to create.
            parent_dir = os.path.dirname(filepath)
            if parent_dir:
                os.makedirs(parent_dir, exist_ok=True)

            with open(filepath, 'w', encoding="utf-8") as f:
                json.dump(summary, f, indent=2)

            print(f"Session saved successfully to {filepath}")
            return {"status": "Session saved successfully"}
        except Exception as e:
            # Best-effort save: report the failure to the caller rather than raise.
            print(f"Error saving session to {filepath}: {e}")
            return {"error": f"Failed to save session: {e}"}