melmoheb committed on
Commit
129641e
·
verified ·
1 Parent(s): f97faaa

Upload 6 files

Browse files
src/data_processing.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import pandas as pd
4
+ import numpy as np
5
+ import torch
6
+ from docx import Document
7
+ from sentence_transformers import SentenceTransformer
8
+ from datasets import Dataset
9
+ from tqdm import tqdm
10
+
11
def read_docx(file_path):
    """Return the full text of a .docx file, paragraphs joined by newlines.

    Logs the problem and returns an empty string when the file cannot be
    opened or parsed.
    """
    try:
        document = Document(file_path)
        paragraph_texts = [paragraph.text for paragraph in document.paragraphs]
        return '\n'.join(paragraph_texts)
    except Exception as e:
        print(f"Error reading {file_path}: {e}")
        return ""
19
+
20
def extract_qa_pairs(text):
    """Pull every **Examiner:** / **Examinee:** exchange out of a transcript.

    Returns a list of {"question", "answer"} dicts, whitespace-stripped,
    in the order the exchanges appear in *text*.
    """
    qa_regex = re.compile(
        r"\*\*Examiner:\*\*(.*?)\n\n\*\*Examinee:\*\*(.*?)(?=\n\n\*\*Examiner:\*\*|$)",
        re.DOTALL,
    )
    pairs = []
    for question, answer in qa_regex.findall(text):
        pairs.append({"question": question.strip(), "answer": answer.strip()})
    return pairs
24
+
25
def parse_filename(filename):
    """Parse the case ID and clinical topic from a BTK-style filename.

    Expected format: ``BTK_-_<digits><optional letter>___<topic>.docx``,
    e.g. ``BTK_-_77A___Burn.docx`` -> ("77A", "Burn"). Underscores in the
    topic become spaces.

    Args:
        filename: Bare filename (not a full path) to parse.

    Returns:
        Tuple (case_id, clinical_presentation); ("Unknown", "Unknown") when
        the name does not match the expected pattern.
    """
    base = os.path.splitext(filename)[0]
    match = re.match(r"BTK_-_(\d+[A-Z]?)___(.+)", base)
    if match:
        case_id = match.group(1)
        topic = match.group(2).replace("_", " ").strip()
    else:
        # Bug fix: the warning previously printed a literal "(unknown)";
        # include the actual filename so bad files can be tracked down.
        print(f"Warning: Could not parse filename format: {filename}")
        case_id, topic = "Unknown", "Unknown"
    return case_id, topic
38
+
39
def process_all_cases(folder_path):
    """Read every .docx case file in *folder_path* into a flat DataFrame.

    Each row is one Q&A turn with columns: case_id, clinical_presentation,
    turn_id (1-based), question, answer.

    Args:
        folder_path: Directory containing BTK-formatted .docx case files.

    Returns:
        pandas.DataFrame of Q&A rows; empty when the folder is missing or
        no usable content is found.
    """
    rows = []
    if not os.path.isdir(folder_path):
        print(f"Error: Folder not found at {folder_path}")
        return pd.DataFrame(rows)

    print(f"Processing case files from: {folder_path}")
    for filename in os.listdir(folder_path):
        # Skip Word's temporary lock files, which start with '~'.
        if filename.lower().endswith('.docx') and not filename.startswith('~'):
            file_path = os.path.join(folder_path, filename)
            text = read_docx(file_path)
            if text:
                qa_pairs = extract_qa_pairs(text)
                case_id, presentation = parse_filename(filename)
                if not qa_pairs:
                    # Bug fix: warning previously printed a literal "(unknown)";
                    # name the file so formatting problems are traceable.
                    print(f"Warning: No Q&A pairs extracted from {filename}")
                for i, pair in enumerate(qa_pairs):
                    rows.append({
                        "case_id": case_id,
                        "clinical_presentation": presentation,
                        "turn_id": i + 1,  # 1-based turn numbering
                        "question": pair["question"],
                        "answer": pair["answer"]
                    })
            else:
                # Bug fix: same "(unknown)" placeholder replaced with filename.
                print(f"Warning: Empty content for file {filename}")

    if not rows:
        print("Warning: No data rows were generated. Check input files and formats.")

    return pd.DataFrame(rows)
71
+
72
+
73
+ # --- ClinicalCaseProcessor Class ---
74
class ClinicalCaseProcessor:
    """Handles preprocessing of clinical cases for the RAG system using sentence-transformers."""

    def __init__(self, model_name="all-MiniLM-L6-v2"):
        # Load the embedding model once and move it to GPU when available.
        print(f"Initializing ClinicalCaseProcessor with model: {model_name}")
        self.model = SentenceTransformer(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        self.model.to(self.device)

    def preprocess_data(self, input_data, output_path="./processed_clinical_cases", batch_size=16):
        """
        Convert raw case data (DataFrame or path to CSV) into a vectorized Hugging Face dataset.

        Groups per-turn Q&A rows into one record per case, builds a short
        "case_summary" string per case, embeds the summaries in batches, and
        saves the resulting dataset (with an 'embeddings' column) to disk.

        Args:
            input_data: DataFrame or path to CSV file with clinical cases.
                Must provide columns: case_id, clinical_presentation,
                turn_id, question, answer.
            output_path: Where to save the processed Hugging Face dataset.
            batch_size: Batch size for embedding generation.

        Returns:
            datasets.Dataset: The processed dataset with embeddings, or
            None on any error (this method reports errors by returning
            None rather than raising).
        """
        # Load data: accept an in-memory DataFrame or a CSV path.
        if isinstance(input_data, pd.DataFrame):
            df = input_data
            print("Using provided DataFrame.")
        elif isinstance(input_data, str) and os.path.exists(input_data):
            try:
                df = pd.read_csv(input_data)
                print(f"Data loaded from CSV: {input_data}")
            except Exception as e:
                print(f"Error loading CSV {input_data}: {e}")
                return None
        else:
            print(f"Error: Invalid input_data type or path does not exist: {input_data}")
            return None

        if df.empty:
            print("Error: Input DataFrame is empty. Cannot process.")
            return None

        print(f"Raw data shape: {df.shape}")

        # Validate necessary columns before any grouping work.
        required_cols = ['case_id', 'clinical_presentation', 'turn_id', 'question', 'answer']
        if not all(col in df.columns for col in required_cols):
            print(f"Error: DataFrame missing required columns. Found: {df.columns}. Required: {required_cols}")
            return None

        # Group by case_id to get all Q&A pairs for each case.
        # dropna=False keeps groups whose key contains NaN; those are
        # normalized to "Unknown" strings below.
        grouped = df.groupby(['case_id', 'clinical_presentation'], dropna=False)

        # Create a new dataframe with one row per case.
        case_data = []
        print("Grouping data by case...")
        for (case_id, presentation), group in tqdm(grouped, desc="Processing Cases"):
            # Sort by turn_id to ensure correct order
            group = group.sort_values('turn_id')

            # Extract questions and answers in order
            questions = group['question'].tolist()
            answers = group['answer'].tolist()

            # Handle potential NaN/None in presentation if groupby didn't drop them
            presentation_str = str(presentation) if pd.notna(presentation) else "Unknown Presentation"

            case_data.append({
                'case_id': str(case_id) if pd.notna(case_id) else "Unknown ID",
                'clinical_presentation': presentation_str,
                'questions': questions,
                'answers': answers
            })

        if not case_data:
            print("Error: No cases could be processed after grouping. Check input data integrity.")
            return None

        processed_df = pd.DataFrame(case_data)
        print(f"Processed data into {len(processed_df)} unique cases.")

        # Create a searchable summary of each case (handle empty question lists).
        # Only the presentation and first question are embedded, not the
        # full transcript.
        processed_df['case_summary'] = processed_df.apply(
            lambda x: f"Clinical case: {x['clinical_presentation']}. First question: {x['questions'][0] if x['questions'] else 'No questions available'}",
            axis=1
        )

        # Generate embeddings using sentence-transformers
        texts_to_embed = processed_df['case_summary'].tolist()
        all_embeddings = []

        print(f"Generating embeddings for {len(texts_to_embed)} case summaries...")
        try:
            for i in tqdm(range(0, len(texts_to_embed), batch_size), desc="Embedding Batches"):
                batch_texts = texts_to_embed[i:i+batch_size]
                # Generate embeddings for the batch
                batch_embeddings = self.model.encode(batch_texts, convert_to_numpy=True, device=self.device, show_progress_bar=False)
                all_embeddings.append(batch_embeddings)

            # Combine all batch embeddings into one (num_cases, dim) array.
            if not all_embeddings:
                print("Error: No embeddings were generated.")
                return None
            final_embeddings = np.vstack(all_embeddings)
            print(f"Generated embeddings with shape: {final_embeddings.shape}")

        except Exception as e:
            print(f"Error during embedding generation: {e}")
            return None

        # Convert to HF Dataset and add embeddings
        try:
            dataset = Dataset.from_pandas(processed_df)
            # Ensure embeddings column is compatible (list of lists)
            dataset = dataset.add_column('embeddings', final_embeddings.tolist())
        except Exception as e:
            print(f"Error converting to Hugging Face Dataset or adding embeddings: {e}")
            return None

        # Save processed dataset
        try:
            os.makedirs(output_path, exist_ok=True)  # Ensure directory exists
            dataset.save_to_disk(output_path)
            print(f"Processed dataset saved successfully to {output_path}")
        except Exception as e:
            print(f"Error saving dataset to disk at {output_path}: {e}")
            return None  # Return None if saving failed

        return dataset
src/evaluation_utils.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from sklearn.metrics import ndcg_score
3
+ from src.retriever import ClinicalCaseRetriever, DummyRetriever
4
+
5
def retrieval_metrics(retriever_instance: ClinicalCaseRetriever, queries: list[str], gold_ids: list[str], k: int = 5) -> dict | None:
    """
    Score a retriever over a labelled query set.

    For each query the retriever's top-k results are compared to the expected
    case id, accumulating Hit@k, MRR and NDCG@k; the averages over all
    queries are returned.

    Args:
        retriever_instance: An initialized ClinicalCaseRetriever instance.
        queries: A list of query strings.
        gold_ids: A list of the expected 'case_id' strings for each query.
        k: The number of top results to consider for Hit@k and NDCG@k.

    Returns:
        A dictionary containing Hit@k, MRR, and NDCG@k scores, or None on error.
    """
    hits = []
    reciprocal_ranks = []
    ndcgs = []
    print(f"\nCalculating retrieval metrics for {len(queries)} queries (k={k})...")

    for q_idx, (q, gold) in enumerate(zip(queries, gold_ids)):
        print(f"\nProcessing query {q_idx+1}/{len(queries)}: '{q}' (Expected ID: '{gold}')")
        retrieved_cases, scores = retriever_instance.retrieve_relevant_case(q, top_k=k, return_scores=True)

        # Missing 'case_id' keys degrade to 'N/A' instead of raising.
        retrieved_ids = [c.get('case_id', 'N/A') for c in retrieved_cases]
        print(f"Retrieved IDs: {retrieved_ids}")
        print(f"Retrieved Scores: {[round(s, 4) for s in scores]}")

        # 1-based rank of the gold id in the result list; 0 means a miss.
        rank = retrieved_ids.index(gold) + 1 if gold in retrieved_ids else 0
        is_hit = int(rank > 0)
        hits.append(is_hit)
        reciprocal_ranks.append(1.0 / rank if rank else 0.0)

        # NDCG treats the gold id as the single relevant document.
        true_relevance = np.asarray([[1.0 if gid == gold else 0.0 for gid in retrieved_ids]])
        predicted_scores = np.asarray([scores])

        current_ndcg = 0.0
        if true_relevance.shape[1] > 0:
            # Clamp k so it never exceeds the number of retrieved documents.
            current_ndcg = ndcg_score(true_relevance, predicted_scores, k=min(k, true_relevance.shape[1]))
        ndcgs.append(current_ndcg)

        print(f"Hit: {is_hit}, Rank: {rank if rank > 0 else 'N/A'}, NDCG@{k}: {current_ndcg:.4f}")

    # Average each metric over all queries (0.0 for an empty query set).
    avg_hit = np.mean(hits) if hits else 0.0
    avg_mrr = np.mean(reciprocal_ranks) if reciprocal_ranks else 0.0
    avg_ndcg = np.mean(ndcgs) if ndcgs else 0.0

    print(f"\n--- Overall Retrieval Results (k={k}) --- ")
    print(f"Average Hit@{k}: {avg_hit:.4f}")
    print(f"Average MRR: {avg_mrr:.4f}")
    print(f"Average NDCG@{k}: {avg_ndcg:.4f}")

    return {f"Hit@{k}": avg_hit,
            f"MRR": avg_mrr,
            f"NDCG@{k}": avg_ndcg}
src/evaluator.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+
4
class AnswerEvaluator:
    """Evaluates user answers against expected answers using an LLM."""

    def __init__(self, model_id="meta-llama/Llama-3.2-3B-Instruct"):
        # Load tokenizer + causal LM once; device placement is delegated to
        # accelerate via device_map="auto".
        print(f"Initializing AnswerEvaluator with model: {model_id}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_id)
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
                print("Set pad_token to eos_token")

            self.model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float16,
                device_map="auto"
            )
            self.model.eval()
            # Remember where the (possibly sharded) model's inputs must live.
            self.device = self.model.device
            print(f"AnswerEvaluator model loaded successfully on device: {self.device}")

        except Exception as e:
            # Re-raise: the evaluator is unusable without its model.
            print(f"Error initializing AnswerEvaluator model {model_id}: {e}")
            raise

    def evaluate_answer(self, user_answer, expected_answer, clinical_context=None):
        """
        Compare user answer to expected answer and provide feedback

        Args:
            user_answer: Examinee's response
            expected_answer: Model answer from the dataset
            clinical_context: Optional clinical context to consider

        Returns:
            Feedback string starting with "ASSESSMENT: ..." (LLM-generated),
            or "Error: Could not generate feedback." on failure.
        """
        context_str = f"Clinical context: {clinical_context}\n\n" if clinical_context else ""

        # NOTE(review): the prompt uses Mistral-style [INST] markers although
        # the default model is Llama 3.2 — confirm this template matches the
        # deployed model's chat format.
        prompt = f"""<s>[INST] You are acting as an expert examiner for the American Board of Surgery (ABS) oral board exam. You are evaluating a general surgery resident’s answer to a clinical question. \n
Compare the answer provided by the residents to the correct expected answer, which I will provide you with. \n
Use the grading rubric below to assess their response:

[RUBRIC]
- Correct: Resident includes all major points and clinical reasoning aligns closely with the expected answer.
- Partially Correct: Resident includes some key points but omits others, or reasoning is partially flawed.
- Incorrect: Resident misses most key points or demonstrates incorrect reasoning.

{context_str}Here is the model answer that contains the key points expected from the resident:
{expected_answer}

Now, here is the resident’s actual response:
{user_answer}

Evaluate the resident’s response based **only** on the expected answer above. Do not rely on external knowledge or previous responses.

Focus your evaluation on:
1. Which key points were mentioned vs. missed
2. The accuracy and clarity of the clinical reasoning
3. Any major omissions or misunderstandings

Start your output with:
ASSESSMENT: [Correct / Partially Correct / Incorrect]
Then write 1–2 clear, specific sentences explaining how the resident’s response compares to the expected answer.

[EXAMPLE 1]
Expected answer:
"The differential diagnosis includes acute appendicitis, mesenteric adenitis, gastroenteritis, UTI, and testicular torsion."

Resident’s response:
"My top concern is appendicitis, but I’d also consider things like gastroenteritis or maybe even kidney stones."

ASSESSMENT: Partially Correct
The resident mentioned appendicitis and gastroenteritis but missed several other expected differentials like UTI, testicular torsion, and mesenteric adenitis.

[EXAMPLE 2]
Expected answer:
"Initial labs should include CBC, CMP, lipase, and abdominal ultrasound to assess for gallstones."

Resident’s response:
"I’d start with a full workup including CBC, liver enzymes, lipase, and an abdominal ultrasound."

ASSESSMENT: Correct
The resident included all key labs and the correct imaging modality. Their reasoning aligns well with the expected answer.

[/INST]</s>"""

        try:
            # Truncate long prompts so very verbose answers cannot overflow
            # the model's input window.
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(self.device)  # Added truncation

            with torch.no_grad():
                # Generate feedback using the model
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=150,
                    do_sample=True,
                    temperature = 0.2,
                    pad_token_id=self.tokenizer.eos_token_id  # Ensure pad token ID is set
                )

            # Slice off the echoed prompt so only new tokens are decoded.
            prompt_length_tokens = inputs.input_ids.shape[1]
            generated_ids = outputs[0][prompt_length_tokens:]

            feedback = self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

            return feedback

        except Exception as e:
            print(f"Error during LLM evaluation: {e}")
            return "Error: Could not generate feedback."
src/retriever.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from sentence_transformers import SentenceTransformer
4
+ from sklearn.metrics.pairwise import cosine_similarity
5
+ from datasets import load_from_disk, Dataset
6
+ import os
7
+ import pandas as pd
8
+
9
class ClinicalCaseRetriever:
    """Retrieves relevant clinical cases based on user input using sentence-transformers embeddings."""

    def __init__(self, dataset_path='./processed_clinical_cases', model_name="all-MiniLM-L6-v2"):
        """Load (or accept) an embedded case dataset plus the query encoder.

        Args:
            dataset_path: Despite the name, either a ready-made
                ``datasets.Dataset`` object or a path to a dataset saved
                with ``save_to_disk``. Must contain an 'embeddings' column.
            model_name: sentence-transformers model used to encode queries;
                should match the model used when building the dataset.

        Raises:
            ValueError: On an invalid path/object, a missing 'embeddings'
                column, or embeddings that are not 2-dimensional.
        """
        print(f"Initializing ClinicalCaseRetriever with model: {model_name}")
        if isinstance(dataset_path, Dataset):
            self.dataset = dataset_path
            print("Using provided Hugging Face Dataset object.")
        elif isinstance(dataset_path, str) and os.path.isdir(dataset_path):
            try:
                self.dataset = load_from_disk(dataset_path)
                print(f"Dataset loaded successfully from disk: {dataset_path}")
            except Exception as e:
                print(f"Error loading dataset from disk {dataset_path}: {e}")
                raise ValueError("Failed to load dataset.") from e
        else:
            raise ValueError(f"Invalid dataset_path: Must be a Dataset object or a valid directory path. Got: {dataset_path}")

        if 'embeddings' not in self.dataset.column_names:
            raise ValueError("Dataset must contain an 'embeddings' column.")

        print(f"Dataset features: {self.dataset.features}")
        print(f"Number of cases in dataset: {len(self.dataset)}")

        # Query encoder, pinned to GPU when available.
        self.model = SentenceTransformer(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        self.model.to(self.device)

        try:
            # Ensure embeddings are loaded as a NumPy array
            self.case_embeddings = np.array(self.dataset['embeddings'])
            if self.case_embeddings.ndim != 2:
                raise ValueError(f"Embeddings array must be 2-dimensional. Got shape: {self.case_embeddings.shape}")
            print(f"Loaded {len(self.dataset)} cases with embeddings of shape {self.case_embeddings.shape}")
        except Exception as e:
            print(f"Error processing embeddings from dataset: {e}")
            raise ValueError("Failed to load or process embeddings.") from e

    def get_available_cases(self, n=5):
        """Return a sample of available cases for user selection.

        Returns up to *n* randomly chosen (index, clinical_presentation)
        tuples, sampled without replacement.
        """
        num_cases = len(self.dataset)
        if num_cases == 0:
            return []
        sample_size = min(n, num_cases)
        indices = np.random.choice(num_cases, sample_size, replace=False)
        # Ensure indices are int for slicing dataset
        return [(int(i), self.dataset[int(i)]['clinical_presentation']) for i in indices]

    def encode_query(self, query):
        """Generate embedding for a query string.

        Returns a (1, dim) numpy array, or None when encoding fails.
        """
        # Create a better search query structure
        search_query = f"Clinical case about {query}"
        print(f"Encoding query: '{search_query}'")

        # Generate embedding using sentence-transformers
        try:
            query_embedding = self.model.encode([search_query], convert_to_numpy=True, device=self.device, show_progress_bar=False)
            return query_embedding
        except Exception as e:
            print(f"Error encoding query '{query}': {e}")
            return None  # Or raise an error

    def retrieve_relevant_case(self, query, top_k=1, return_scores=False):
        """Find the most relevant clinical case(s) given a query.

        NOTE(review): the return shape differs by flag —
        return_scores=True yields two parallel lists (cases, scores),
        while return_scores=False yields ONE list of (case, score) tuples.
        Callers must match the form they requested.
        """
        if not isinstance(top_k, int) or top_k < 1:
            print("Warning: top_k must be a positive integer. Defaulting to 1.")
            top_k = 1

        # Get query embedding
        query_embedding = self.encode_query(query)
        if query_embedding is None:
            return [] if not return_scores else ([], [])

        # Calculate similarity scores
        try:
            similarities = cosine_similarity(query_embedding, self.case_embeddings)[0]  # Get the single row of similarities
        except Exception as e:
            print(f"Error calculating cosine similarity: {e}")
            return [] if not return_scores else ([], [])

        # Get indices of top-k most similar cases
        # Ensure we don't request more indices than available cases
        k_actual = min(top_k, len(similarities))
        if k_actual == 0:  # Should not happen if dataset loaded, but safe check
            return [] if not return_scores else ([], [])

        # Use partitioning for efficiency if k is much smaller than N, or argsort otherwise
        # Using argsort is generally simpler and fine for moderate N
        top_indices = np.argsort(similarities)[-k_actual:][::-1].astype(int)  # Get top k indices, sorted descending

        top_scores = similarities[top_indices].tolist()  # Get scores for these indices

        # Return the most relevant case(s)
        try:
            # Retrieve cases safely using integer indices
            retrieved_cases = [self.dataset[int(idx)] for idx in top_indices]
        except IndexError as e:
            print(f"Error retrieving cases using indices {top_indices}: {e}")
            return [] if not return_scores else ([], [])
        except Exception as e:
            print(f"Unexpected error retrieving cases: {e}")
            return [] if not return_scores else ([], [])

        results_with_scores = list(zip(retrieved_cases, top_scores))
        print(f"Retrieved {len(results_with_scores)} cases with similarity scores:")
        for case, score in results_with_scores:
            # Safely access presentation, provide default if missing
            presentation = case.get('clinical_presentation', 'Unknown Presentation')
            print(f"- {presentation}: {score:.4f}")

        if return_scores:
            return retrieved_cases, top_scores
        else:
            # Return list of tuples (case_dict, score)
            return results_with_scores
128
+
129
+
130
class DummyRetriever:
    """A simple retriever that bypasses RAG, taking a pre-formatted DataFrame."""

    def __init__(self, df):
        """Build per-presentation case dicts from a flat per-turn DataFrame.

        Args:
            df: DataFrame with columns 'clinical_presentation', 'turn_id',
                'question', 'answer' (and optionally 'case_id'). Invalid
                input leaves the retriever with an empty dataset.
        """
        self.dataset = []
        if not isinstance(df, pd.DataFrame) or df.empty:
            print("Warning: DummyRetriever initialized with empty or invalid DataFrame.")
            return

        # Expects df to be pre-processed with columns:
        # 'clinical_presentation', 'turn_id', 'question', 'answer'
        required_cols = ['clinical_presentation', 'turn_id', 'question', 'answer']
        if not all(col in df.columns for col in required_cols):
            print(f"Warning: DummyRetriever DataFrame missing required columns. Need: {required_cols}")
            return

        grouped = df.groupby('clinical_presentation')
        print(f"DummyRetriever processing {len(grouped)} unique presentations.")
        for i, (scenario, group) in enumerate(grouped):
            # Sort by turn so questions/answers stay in exam order.
            group_sorted = group.sort_values('turn_id')

            case_dict = {
                "case_id": group_sorted['case_id'].iloc[0] if 'case_id' in group_sorted.columns else f"dummy_{i}",
                "clinical_presentation": scenario,
                "questions": group_sorted["question"].tolist(),
                "answers": group_sorted["answer"].tolist()
            }
            self.dataset.append(case_dict)
        print(f"DummyRetriever initialized with {len(self.dataset)} cases.")

    def retrieve_relevant_case(self, scenario_query, top_k=1, return_scores=False):
        """
        Finds the case matching the query string exactly.
        Ignores 'top_k' but mimics ClinicalCaseRetriever's return structures.

        Bug fix: added 'return_scores' (default False, backward-compatible)
        because evaluation_utils.retrieval_metrics calls this method with
        return_scores=True, which previously raised TypeError.

        Returns:
            return_scores=False -> [(case_dict, 1.0)] or [] on no match.
            return_scores=True  -> ([case_dict], [1.0]) or ([], []).
        """
        print(f"DummyRetriever searching for exact match: '{scenario_query}'")
        for case_dict in self.dataset:
            if case_dict["clinical_presentation"] == scenario_query:
                print(f"DummyRetriever found match: {case_dict['clinical_presentation']}")
                if return_scores:
                    return [case_dict], [1.0]
                return [(case_dict, 1.0)]

        print(f"DummyRetriever: No exact match found for '{scenario_query}'")
        if return_scores:
            return [], []
        return []
src/simulator.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import pandas as pd
3
+ # Assuming retriever and evaluator classes are in these files:
4
+ from .retriever import ClinicalCaseRetriever, DummyRetriever
5
+ from .evaluator import AnswerEvaluator
6
+
7
class OralExamSimulator:
    """Main class that coordinates the oral board exam simulation."""

    def __init__(self, retriever, evaluator):
        """Wire up the case source and the answer grader.

        Args:
            retriever: ClinicalCaseRetriever or DummyRetriever instance.
            evaluator: AnswerEvaluator instance.

        Raises:
            TypeError: If either collaborator has the wrong type.
        """
        if not isinstance(retriever, (ClinicalCaseRetriever, DummyRetriever)):
            raise TypeError("Retriever must be an instance of ClinicalCaseRetriever or DummyRetriever")
        if not isinstance(evaluator, AnswerEvaluator):
            raise TypeError("Evaluator must be an instance of AnswerEvaluator")

        self.retriever = retriever
        self.evaluator = evaluator
        self.current_case = None        # dict for the active case, or None
        self.current_question_idx = 0   # index of the next unanswered question
        self.session_history = []       # chronological log of every turn

    def start_new_case(self, clinical_query=None, case_idx=None):
        """
        Initialize a new exam case based on query or direct selection.

        Args:
            clinical_query (str, optional): Text description of the desired case topic.
            case_idx (int, optional): Direct index of the case to use from the retriever's dataset.

        Returns:
            dict: Contains case info and first question, or an error message
            under the "error" key.
        """
        print("-" * 50)
        print(f"Attempting to start new case | Query: '{clinical_query}' | Index: {case_idx}")

        # Reset state for the new case
        self.current_case = None
        self.current_question_idx = 0
        self.session_history = []

        # Case selection logic: an explicit index takes priority over a query.
        retrieved_info = None
        if case_idx is not None:
            try:
                # Direct case selection by index; validate bounds first.
                if 0 <= int(case_idx) < len(self.retriever.dataset):
                    self.current_case = self.retriever.dataset[int(case_idx)]
                    similarity_score = 1.0  # Direct selection implies perfect 'match'
                    print(f"Selected case by index {case_idx}: {self.current_case.get('clinical_presentation', 'Unknown Presentation')}")
                    retrieved_info = (self.current_case, similarity_score)
                else:
                    print(f"Error: Invalid case index {case_idx}. Must be between 0 and {len(self.retriever.dataset)-1}.")
                    return {"error": f"Invalid case index: {case_idx}"}
            except Exception as e:
                print(f"Error selecting case by index {case_idx}: {e}")
                return {"error": f"Failed to select case by index: {e}"}

        elif clinical_query:
            # RAG-based retrieval
            try:
                # retrieve_relevant_case returns a list of (case_dict, score) tuples
                retrieved_results = self.retriever.retrieve_relevant_case(clinical_query, top_k=1)
                if retrieved_results:  # Check if list is not empty
                    retrieved_info = retrieved_results[0]
                    self.current_case = retrieved_info[0]
                    similarity_score = retrieved_info[1]
                    print(f"Retrieved case via query ('{clinical_query}') with score {similarity_score:.4f}: {self.current_case.get('clinical_presentation', 'Unknown Presentation')}")
                else:
                    print(f"Error: No case found for query: '{clinical_query}'")
                    return {"error": f"No relevant case found for query: {clinical_query}"}
            except Exception as e:
                print(f"Error retrieving case for query '{clinical_query}': {e}")
                return {"error": f"Failed to retrieve case by query: {e}"}
        else:
            # No selection method provided
            print("Error: Must provide either a clinical query or a case index.")
            return {"error": "Please provide either a clinical query or case index."}

        # --- Post-selection setup ---
        if self.current_case is None:
            # Should be caught above; defensive double-check.
            print("Error: Failed to set current_case.")
            return {"error": "Failed to load the selected case."}

        # Validate case structure: parallel question/answer lists required.
        if 'questions' not in self.current_case or 'answers' not in self.current_case or \
           not isinstance(self.current_case['questions'], list) or \
           not isinstance(self.current_case['answers'], list) or \
           len(self.current_case['questions']) != len(self.current_case['answers']):
            print(f"Error: Invalid case structure for case ID {self.current_case.get('case_id', 'N/A')}. Mismatched or missing Q/A lists.")
            return {"error": "Selected case has invalid format."}

        if not self.current_case['questions']:
            print(f"Warning: Selected case ID {self.current_case.get('case_id', 'N/A')} has no questions.")
            return {"error": "Selected case contains no questions."}

        # Start a new session record
        self.session_history.append({
            "role": "system",
            "content": f"Clinical scenario started: {self.current_case.get('clinical_presentation', 'Unknown Presentation')} (Case ID: {self.current_case.get('case_id', 'N/A')})"
        })

        # Get the first question
        first_question = self.current_case['questions'][0]

        # Record this interaction
        self.session_history.append({
            "role": "examiner",
            "content": first_question
        })

        print(f"Case successfully started. Total questions: {len(self.current_case['questions'])}")
        print("-" * 50)

        return {
            "case_id": self.current_case.get('case_id', 'unknown'),
            "clinical_presentation": self.current_case.get('clinical_presentation', 'Unknown'),
            "similarity_score": similarity_score,  # Score from retrieval (1.0 for direct index)
            "current_question": first_question,
            "question_number": 1,
            "total_questions": len(self.current_case['questions'])
        }

    def process_user_response(self, response):
        """
        Process the user's answer, get feedback, and return the next question or completion status.

        Args:
            response (str): User's answer text.

        Returns:
            dict: Contains feedback, expected answer, completion status, and next question
            (if applicable), or an error message under the "error" key.
        """
        if self.current_case is None:
            print("Error: No active case.")
            return {"error": "No active case. Please start a new case first."}

        if self.current_question_idx >= len(self.current_case['questions']):
            print("Error: Attempting to process response when case is already complete.")
            return {"error": "Case already completed."}

        print("-" * 50)
        current_q_num = self.current_question_idx + 1
        total_q = len(self.current_case['questions'])
        print(f"Processing response for Question {current_q_num}/{total_q}")
        print(f"User Response: {response}")

        # Save the user's response to history
        self.session_history.append({
            "role": "resident",
            "content": response
        })

        # Get the expected answer for the current question
        expected_answer = self.current_case['answers'][self.current_question_idx]
        print(f"Expected Answer: {expected_answer}")

        # Evaluate the answer with the LLM grader
        feedback = self.evaluator.evaluate_answer(
            response,
            expected_answer,
            clinical_context=f"Regarding the case '{self.current_case.get('clinical_presentation', 'N/A')}'"
        )
        print(f"Generated Feedback: {feedback}")

        # Add feedback to history
        self.session_history.append({
            "role": "feedback",
            "content": feedback
        })

        # Move to the next question index
        self.current_question_idx += 1

        # Check if the case is complete
        is_complete = self.current_question_idx >= len(self.current_case['questions'])

        result = {
            "feedback": feedback,
            "expected_answer": expected_answer,
            # question_number is the 1-based number of the question just answered.
            "is_complete": is_complete,
            "question_number": self.current_question_idx
        }

        # Add next question if not complete
        if not is_complete:
            next_question = self.current_case['questions'][self.current_question_idx]
            result["next_question"] = next_question
            result["total_questions"] = total_q

            # Add next question to history
            self.session_history.append({
                "role": "examiner",
                "content": next_question
            })
            # Bug fix: previously logged the answered question's number as the
            # "next" question number; show the upcoming question's number.
            print(f"Next question ({self.current_question_idx + 1}/{total_q}): {next_question}")
        else:
            print("Case completed.")
            summary = self.generate_session_summary()
            result["session_summary"] = summary
            self.session_history.append({
                "role": "system",
                "content": "End of clinical scenario."
            })

        print("-" * 50)
        return result

    def generate_session_summary(self):
        """Generate a summary dictionary of the completed session."""
        if not self.current_case or not self.session_history:
            return {"error": "No active or completed session to summarize."}

        # Simple summary structure
        return {
            "case_id": self.current_case.get('case_id', 'N/A'),
            "case": self.current_case.get('clinical_presentation', 'Unknown'),
            "total_questions_in_case": len(self.current_case.get('questions', [])),
            "interaction_history": self.session_history  # Include the full log
        }

    def save_session(self, filepath):
        """Save the current session summary to a JSON file.

        Args:
            filepath: Destination path; parent directories are created as
                needed.

        Returns:
            dict: {"status": ...} on success or {"error": ...} on failure.
        """
        summary = self.generate_session_summary()
        if "error" in summary:
            print(f"Error generating summary for saving: {summary['error']}")
            return {"error": "No session to save"}

        try:
            # Add a timestamp to the saved data
            summary["timestamp"] = pd.Timestamp.now().isoformat()

            # Bug fix: 'os' was used here without being imported anywhere in
            # this module (NameError at runtime); import locally so the
            # method is self-contained.
            import os

            # Ensure the target directory exists. dirname() is '' for bare
            # filenames, in which case there is nothing to create.
            target_dir = os.path.dirname(filepath)
            if target_dir:
                os.makedirs(target_dir, exist_ok=True)

            with open(filepath, 'w') as f:
                json.dump(summary, f, indent=2)
            print(f"Session saved successfully to {filepath}")
            return {"status": "Session saved successfully"}
        except Exception as e:
            print(f"Error saving session to {filepath}: {e}")
            return {"error": f"Failed to save session: {e}"}
src/synthetic_generator.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import torch
4
+ import re
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM
6
+
7
def generate_synthetic_case(clinical_query, model_id="meta-llama/Llama-3.2-3B-Instruct", max_tokens=800):
    """Generate a synthetic clinical case (presentation + Q/A pairs) with an LLM.

    Args:
        clinical_query: Topic of the case, e.g. "appendicitis".
        model_id: Hugging Face model id of an instruction-tuned causal LM.
        max_tokens: Maximum number of new tokens to generate.

    Returns:
        str | None: the raw generated text, or None if model loading or
        generation fails.
    """
    print(f"Generating synthetic case for '{clinical_query}' using {model_id}...")
    gen_tokenizer = None
    gen_model = None
    try:
        # Initialize generator model components.
        gen_tokenizer = AutoTokenizer.from_pretrained(model_id)
        gen_model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        gen_model.eval()
        device = gen_model.device
        if gen_tokenizer.pad_token is None:
            gen_tokenizer.pad_token = gen_tokenizer.eos_token

    except Exception as e:
        print(f"Error initializing generator model {model_id}: {e}")
        return None

    instruction = f"""You are a board-certified general surgeon simulating a clinical oral board exam.
Create a synthetic case on the topic: "{clinical_query}".
Start by describing the initial clinical presentation in 1–2 sentences.
Then generate a list of 5–8 examiner questions (Q1, Q2...), each paired with the expected examinee answer (A1, A2...). Ensure Q/A pairs are clearly separated.
Output ONLY the presentation and Q&A pairs in this exact format:
Clinical Presentation: ...

Q1: ...
A1: ...

Q2: ...
A2: ...

(continue until Qn/An)
Focus on common scenarios and standard knowledge. Avoid overly complex or rare details.
"""

    output_text = None
    try:
        # Use the tokenizer's chat template rather than a hand-rolled
        # "<s>[INST]...[/INST]</s>" wrapper: that markup is Llama-2 style and
        # is wrong for Llama-3.x Instruct models; apply_chat_template emits
        # the correct special tokens for whatever model is loaded.
        messages = [{"role": "user", "content": instruction}]
        if getattr(gen_tokenizer, "chat_template", None):
            input_ids = gen_tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,
                return_tensors="pt"
            ).to(device)
        else:
            # Fallback for tokenizers without a chat template (base models).
            input_ids = gen_tokenizer(instruction, return_tensors="pt").input_ids.to(device)
        prompt_length = input_ids.shape[1]

        with torch.no_grad():
            outputs = gen_model.generate(
                input_ids,
                max_new_tokens=max_tokens,
                do_sample=True,  # Sample to get potentially varied outputs
                temperature=0.7,
                top_p=0.9,
                pad_token_id=gen_tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens (skip the echoed prompt).
        generated_ids = outputs[0][prompt_length:]
        output_text = gen_tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
        print("Synthetic case generation complete.")

    except Exception as e:
        print(f"Error during synthetic case generation: {e}")
    finally:
        # Clean up model resources to free GPU memory for the caller.
        del gen_model
        del gen_tokenizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return output_text
75
+
76
def process_synthetic_data(clinical_query, output_text, case_id='SYNTH_01'):
    """Parse raw LLM output into a structured DataFrame of exam turns.

    Args:
        clinical_query: Topic string; stored as the 'clinical_presentation'
            title on every row.
        output_text: Raw generated text ("Clinical Presentation: ..." plus
            "Qn:"/"An:" pairs). May be None when generation failed upstream.
        case_id: Identifier stored in every row; the default 'SYNTH_01'
            keeps the original behavior.

    Returns:
        pandas.DataFrame with columns case_id, clinical_presentation,
        turn_id, question, answer — empty when nothing valid was parsed.
        The extracted presentation sentence is prepended to the turn-1
        question so the first turn carries the scenario.
    """
    # generate_synthetic_case returns None on failure; guard before regexing.
    if not output_text:
        print("Warning: No valid Q&A pairs extracted from synthetic text.")
        return pd.DataFrame()

    # Extract clinical presentation (everything up to the first Q1 block).
    match = re.search(r"Clinical Presentation:(.*?)(?=\n\nQ1:|$)", output_text, re.DOTALL | re.IGNORECASE)
    clinical_presentation_text = match.group(1).strip() if match else "Synthetic Case: " + clinical_query

    # Extract Q&A pairs; the \1 backreference forces matching Q/A numbers.
    qa_pattern = r"Q(\d+):\s*(.*?)\s*A\1:\s*(.*?)(?=\n*Q\d+:|\Z)"
    qa_matches = re.findall(qa_pattern, output_text, flags=re.DOTALL | re.IGNORECASE)

    qa_list = []
    for num_str, q_raw, a_raw in qa_matches:
        q_text = q_raw.strip()
        a_text = a_raw.strip()
        # int() cannot fail here: num_str is guaranteed \d+ by the pattern.
        if q_text and a_text:
            qa_list.append({'turn_id': int(num_str), 'question': q_text, 'answer': a_text})

    if not qa_list:
        print("Warning: No valid Q&A pairs extracted from synthetic text.")
        return pd.DataFrame()

    qa_list.sort(key=lambda item: item['turn_id'])

    df_synthetic = pd.DataFrame([
        {
            'case_id': case_id,
            'clinical_presentation': clinical_query,  # query as presentation title
            'turn_id': item['turn_id'],
            'question': item['question'],
            'answer': item['answer']
        }
        for item in qa_list
    ])

    if clinical_presentation_text:
        # Prepend the scenario description to the first turn's question.
        first_turn_index = df_synthetic[df_synthetic['turn_id'] == 1].index
        if not first_turn_index.empty:
            idx = first_turn_index[0]
            df_synthetic.loc[idx, 'question'] = clinical_presentation_text + " " + df_synthetic.loc[idx, 'question']
        else:
            print("Warning: Could not find turn_id 1 to prepend presentation.")

    print(f"Processed synthetic data into DataFrame with {len(df_synthetic)} turns.")
    return df_synthetic