File size: 10,903 Bytes
2247e66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
import json
import os

import pandas as pd

# Assuming retriever and evaluator classes are in these files:
from .retriever import ClinicalCaseRetriever, DummyRetriever
from .evaluator import AnswerEvaluator

class OralExamSimulator:
    """Main class that coordinates the oral board exam simulation.

    Pairs a case retriever (which selects a clinical case either directly by
    dataset index or by similarity to a free-text query) with an answer
    evaluator (which produces feedback on the user's free-text answers), and
    steps the user through a case's questions one at a time while logging
    every exchange in ``session_history``.
    """

    def __init__(self, retriever, evaluator):
        """Store collaborators and initialize empty session state.

        Args:
            retriever: A ClinicalCaseRetriever or DummyRetriever. Must expose
                a ``dataset`` sequence of case dicts and a
                ``retrieve_relevant_case(query, top_k)`` method.
            evaluator: An AnswerEvaluator exposing ``evaluate_answer``.

        Raises:
            TypeError: If either collaborator is not of the expected type.
        """
        if not isinstance(retriever, (ClinicalCaseRetriever, DummyRetriever)):
            raise TypeError("Retriever must be an instance of ClinicalCaseRetriever or DummyRetriever")
        if not isinstance(evaluator, AnswerEvaluator):
            raise TypeError("Evaluator must be an instance of AnswerEvaluator")

        self.retriever = retriever
        self.evaluator = evaluator
        self.current_case = None        # dict for the active case, or None
        self.current_question_idx = 0   # index of the question awaiting an answer
        self.session_history = []       # chronological list of {"role", "content"} dicts

    def start_new_case(self, clinical_query=None, case_idx=None):
        """
        Initialize a new exam case based on query or direct selection.

        Exactly one selection mechanism should be supplied; ``case_idx``
        takes precedence when both are given.

        Args:
            clinical_query (str, optional): Text description of the desired case topic.
            case_idx (int, optional): Direct index of the case to use from the retriever's dataset.

        Returns:
            dict: Case info and the first question, or an ``{"error": ...}`` dict.
        """
        print("-" * 50)
        print(f"Attempting to start new case | Query: '{clinical_query}' | Index: {case_idx}")

        # Reset state for the new case.
        self.current_case = None
        self.current_question_idx = 0
        self.session_history = []

        # Initialize explicitly so the value used in the return dict below is
        # always bound, regardless of which selection branch ran.
        similarity_score = None

        if case_idx is not None:
            # Direct case selection by index.
            try:
                idx = int(case_idx)
                if not 0 <= idx < len(self.retriever.dataset):
                    print(f"Error: Invalid case index {case_idx}. Must be between 0 and {len(self.retriever.dataset)-1}.")
                    return {"error": f"Invalid case index: {case_idx}"}
                self.current_case = self.retriever.dataset[idx]
                similarity_score = 1.0  # Direct selection implies a perfect 'match'.
                print(f"Selected case by index {case_idx}: {self.current_case.get('clinical_presentation', 'Unknown Presentation')}")
            except Exception as e:
                print(f"Error selecting case by index {case_idx}: {e}")
                return {"error": f"Failed to select case by index: {e}"}

        elif clinical_query:
            # RAG-based retrieval.
            try:
                # retrieve_relevant_case returns a list of tuples: [(case_dict, score), ...]
                retrieved_results = self.retriever.retrieve_relevant_case(clinical_query, top_k=1)
                if not retrieved_results:
                    print(f"Error: No case found for query: '{clinical_query}'")
                    return {"error": f"No relevant case found for query: {clinical_query}"}
                self.current_case, similarity_score = retrieved_results[0]
                print(f"Retrieved case via query ('{clinical_query}') with score {similarity_score:.4f}: {self.current_case.get('clinical_presentation', 'Unknown Presentation')}")
            except Exception as e:
                print(f"Error retrieving case for query '{clinical_query}': {e}")
                return {"error": f"Failed to retrieve case by query: {e}"}
        else:
            # No selection method provided.
            print("Error: Must provide either a clinical query or a case index.")
            return {"error": "Please provide either a clinical query or case index."}

        # --- Post-selection setup ---
        if self.current_case is None:
            # Should be unreachable (every failure path above returns), but
            # guard anyway so downstream code can rely on a loaded case.
            print("Error: Failed to set current_case.")
            return {"error": "Failed to load the selected case."}

        # Validate case structure: questions/answers must be parallel lists.
        questions = self.current_case.get('questions')
        answers = self.current_case.get('answers')
        if not isinstance(questions, list) or not isinstance(answers, list) \
           or len(questions) != len(answers):
            print(f"Error: Invalid case structure for case ID {self.current_case.get('case_id', 'N/A')}. Mismatched or missing Q/A lists.")
            return {"error": "Selected case has invalid format."}

        if not questions:
            # A case with no questions cannot be examined; treat as an error.
            print(f"Warning: Selected case ID {self.current_case.get('case_id', 'N/A')} has no questions.")
            return {"error": "Selected case contains no questions."}

        # Start a new session record.
        self.session_history.append({
            "role": "system",
            "content": f"Clinical scenario started: {self.current_case.get('clinical_presentation', 'Unknown Presentation')} (Case ID: {self.current_case.get('case_id', 'N/A')})"
        })

        # Pose the first question and log it.
        first_question = questions[0]
        self.session_history.append({
            "role": "examiner",
            "content": first_question
        })

        print(f"Case successfully started. Total questions: {len(questions)}")
        print("-" * 50)

        return {
            "case_id": self.current_case.get('case_id', 'unknown'),
            "clinical_presentation": self.current_case.get('clinical_presentation', 'Unknown'),
            "similarity_score": similarity_score,  # score from retrieval (1.0 for direct index selection)
            "current_question": first_question,
            "question_number": 1,
            "total_questions": len(questions)
        }

    def process_user_response(self, response):
        """
        Process the user's answer, get feedback, and return the next question or completion status.

        Args:
            response (str): User's answer text.

        Returns:
            dict: Feedback, expected answer, completion status, and next
            question (if applicable), or an ``{"error": ...}`` dict.
        """
        if self.current_case is None:
            print("Error: No active case.")
            return {"error": "No active case. Please start a new case first."}

        total_q = len(self.current_case['questions'])
        if self.current_question_idx >= total_q:
            print("Error: Attempting to process response when case is already complete.")
            return {"error": "Case already completed."}

        print("-" * 50)
        print(f"Processing response for Question {self.current_question_idx + 1}/{total_q}")
        print(f"User Response: {response}")

        # Save the user's (resident's) response to history.
        self.session_history.append({
            "role": "resident",
            "content": response
        })

        # Evaluate against the reference answer for the current question.
        expected_answer = self.current_case['answers'][self.current_question_idx]
        print(f"Expected Answer: {expected_answer}")

        feedback = self.evaluator.evaluate_answer(
            response,
            expected_answer,
            clinical_context=f"Regarding the case '{self.current_case.get('clinical_presentation', 'N/A')}'"
        )
        print(f"Generated Feedback: {feedback}")

        self.session_history.append({
            "role": "feedback",
            "content": feedback
        })

        # Advance to the next question and check for completion.
        self.current_question_idx += 1
        is_complete = self.current_question_idx >= total_q

        result = {
            "feedback": feedback,
            "expected_answer": expected_answer,
            "is_complete": is_complete,
            "question_number": self.current_question_idx
        }

        if not is_complete:
            next_question = self.current_case['questions'][self.current_question_idx]
            result["next_question"] = next_question
            result["total_questions"] = total_q

            self.session_history.append({
                "role": "examiner",
                "content": next_question
            })
            print(f"Next question ({result['question_number']}/{total_q}): {next_question}")
        else:
            print("Case completed.")
            # Summary is generated before the closing system message so the
            # caller receives it immediately; session_history is shared by
            # reference, so the closing message still appears in the log.
            result["session_summary"] = self.generate_session_summary()
            self.session_history.append({
                "role": "system",
                "content": "End of clinical scenario."
            })

        print("-" * 50)
        return result

    def generate_session_summary(self):
        """Generate a summary dictionary of the current or completed session.

        Returns:
            dict: Case metadata plus the full interaction log, or an
            ``{"error": ...}`` dict when no session exists.
        """
        if not self.current_case or not self.session_history:
            return {"error": "No active or completed session to summarize."}

        return {
            "case_id": self.current_case.get('case_id', 'N/A'),
            "case": self.current_case.get('clinical_presentation', 'Unknown'),
            "total_questions_in_case": len(self.current_case.get('questions', [])),
            "interaction_history": self.session_history  # full chronological log
        }

    def save_session(self, filepath):
        """Save the current session summary to a JSON file.

        Args:
            filepath (str): Destination path. Parent directories are created
                as needed.

        Returns:
            dict: ``{"status": ...}`` on success or ``{"error": ...}`` on failure.
        """
        summary = self.generate_session_summary()
        if "error" in summary:
            print(f"Error generating summary for saving: {summary['error']}")
            return {"error": "No session to save"}

        try:
            # Add a timestamp to the saved data.
            summary["timestamp"] = pd.Timestamp.now().isoformat()

            # Ensure the target directory exists. dirname is "" for bare
            # filenames, and os.makedirs("") would raise, so guard first.
            directory = os.path.dirname(filepath)
            if directory:
                os.makedirs(directory, exist_ok=True)

            with open(filepath, 'w') as f:
                json.dump(summary, f, indent=2)
            print(f"Session saved successfully to {filepath}")
            return {"status": "Session saved successfully"}
        except Exception as e:
            print(f"Error saving session to {filepath}: {e}")
            return {"error": f"Failed to save session: {e}"}