Upload 6 files
Browse files- src/data_processing.py +201 -0
- src/evaluation_utils.py +67 -0
- src/evaluator.py +113 -0
- src/retriever.py +172 -0
- src/simulator.py +247 -0
- src/synthetic_generator.py +125 -0
src/data_processing.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
from docx import Document
|
| 7 |
+
from sentence_transformers import SentenceTransformer
|
| 8 |
+
from datasets import Dataset
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
|
| 11 |
+
def read_docx(file_path):
    """Return the full text of a .docx file, one line per paragraph.

    Any failure (missing file, corrupt document, etc.) is reported to stdout
    and an empty string is returned instead of raising.
    """
    try:
        doc = Document(file_path)
        paragraph_texts = [para.text for para in doc.paragraphs]
        return '\n'.join(paragraph_texts)
    except Exception as e:
        # Degrade gracefully: callers treat "" as "no content".
        print(f"Error reading {file_path}: {e}")
        return ""
|
| 19 |
+
|
| 20 |
+
def extract_qa_pairs(text):
    """Extract alternating **Examiner:**/**Examinee:** turns from *text*.

    Returns a list of {"question": ..., "answer": ...} dicts, both sides
    whitespace-stripped, in document order.
    """
    qa_regex = re.compile(
        r"\*\*Examiner:\*\*(.*?)\n\n\*\*Examinee:\*\*(.*?)(?=\n\n\*\*Examiner:\*\*|$)",
        re.DOTALL,
    )
    pairs = []
    for question, answer in qa_regex.findall(text):
        pairs.append({"question": question.strip(), "answer": answer.strip()})
    return pairs
|
| 24 |
+
|
| 25 |
+
def parse_filename(filename):
    """Parse the case ID and clinical topic out of a BTK-style filename.

    Example: "BTK_-_77A___Burn.docx" -> ("77A", "Burn").

    Args:
        filename: Bare filename (extension is stripped before matching).

    Returns:
        (case_id, topic) tuple; ("Unknown", "Unknown") when the name does
        not match the expected BTK pattern.
    """
    base = os.path.splitext(filename)[0]
    match = re.match(r"BTK_-_(\d+[A-Z]?)___(.+)", base)
    if match:
        case_id = match.group(1)
        topic = match.group(2).replace("_", " ").strip()
    else:
        # Fix: include the offending filename in the warning (previously the
        # message printed a literal placeholder, making the file untraceable).
        print(f"Warning: Could not parse filename format: {filename}")
        case_id, topic = "Unknown", "Unknown"
    return case_id, topic
|
| 38 |
+
|
| 39 |
+
def process_all_cases(folder_path):
    """Read every .docx case file in *folder_path* into a long-format DataFrame.

    Each output row is one Q&A turn with columns:
    case_id, clinical_presentation, turn_id, question, answer.

    Args:
        folder_path: Directory containing the BTK .docx transcripts.

    Returns:
        pandas.DataFrame; empty when the folder is missing or yields no rows.
    """
    rows = []
    if not os.path.isdir(folder_path):
        print(f"Error: Folder not found at {folder_path}")
        return pd.DataFrame(rows)

    print(f"Processing case files from: {folder_path}")
    # sorted() makes row order deterministic (os.listdir order is arbitrary).
    for filename in sorted(os.listdir(folder_path)):
        if filename.lower().endswith('.docx') and not filename.startswith('~'):  # Skip Word lock/temp files
            file_path = os.path.join(folder_path, filename)
            text = read_docx(file_path)
            if text:
                qa_pairs = extract_qa_pairs(text)
                case_id, presentation = parse_filename(filename)
                if not qa_pairs:
                    # Fix: name the file (previously printed a literal placeholder).
                    print(f"Warning: No Q&A pairs extracted from {filename}")
                for i, pair in enumerate(qa_pairs):
                    rows.append({
                        "case_id": case_id,
                        "clinical_presentation": presentation,
                        "turn_id": i + 1,  # 1-based exam turn number
                        "question": pair["question"],
                        "answer": pair["answer"]
                    })
            else:
                # Fix: name the file (previously printed a literal placeholder).
                print(f"Warning: Empty content for file {filename}")

    if not rows:
        print("Warning: No data rows were generated. Check input files and formats.")

    return pd.DataFrame(rows)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
# --- ClinicalCaseProcessor Class ---
|
| 74 |
+
class ClinicalCaseProcessor:
    """Handles preprocessing of clinical cases for the RAG system using sentence-transformers.

    Turns a long-format Q&A DataFrame (one row per exam turn) into a Hugging Face
    dataset with one row per case plus an embedding of a short case summary,
    saved to disk for later retrieval by ClinicalCaseRetriever.
    """

    def __init__(self, model_name="all-MiniLM-L6-v2"):
        # Load the sentence-transformers encoder once and move it to GPU when available.
        print(f"Initializing ClinicalCaseProcessor with model: {model_name}")
        self.model = SentenceTransformer(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        self.model.to(self.device)

    def preprocess_data(self, input_data, output_path="./processed_clinical_cases", batch_size=16):
        """
        Convert raw case data (DataFrame or path to CSV) into a vectorized Hugging Face dataset.

        Args:
            input_data: DataFrame or path to CSV file with clinical cases.
            output_path: Where to save the processed Hugging Face dataset.
            batch_size: Batch size for embedding generation.

        Returns:
            datasets.Dataset: The processed dataset with embeddings, or None on any failure
            (invalid input, missing columns, embedding/conversion/save errors).
        """
        # --- Load data: accept either an in-memory DataFrame or a CSV path ---
        if isinstance(input_data, pd.DataFrame):
            df = input_data
            print("Using provided DataFrame.")
        elif isinstance(input_data, str) and os.path.exists(input_data):
            try:
                df = pd.read_csv(input_data)
                print(f"Data loaded from CSV: {input_data}")
            except Exception as e:
                print(f"Error loading CSV {input_data}: {e}")
                return None
        else:
            print(f"Error: Invalid input_data type or path does not exist: {input_data}")
            return None

        if df.empty:
            print("Error: Input DataFrame is empty. Cannot process.")
            return None

        print(f"Raw data shape: {df.shape}")

        # Validate necessary columns before any processing.
        required_cols = ['case_id', 'clinical_presentation', 'turn_id', 'question', 'answer']
        if not all(col in df.columns for col in required_cols):
            print(f"Error: DataFrame missing required columns. Found: {df.columns}. Required: {required_cols}")
            return None

        # Group by case_id to get all Q&A pairs for each case.
        # dropna=False keeps rows whose keys are NaN; they are mapped to "Unknown" below.
        grouped = df.groupby(['case_id', 'clinical_presentation'], dropna=False)

        # Collapse the long format into one record per case.
        case_data = []
        print("Grouping data by case...")
        for (case_id, presentation), group in tqdm(grouped, desc="Processing Cases"):
            # Sort by turn_id so questions/answers stay in exam order.
            group = group.sort_values('turn_id')

            # Extract questions and answers in order.
            questions = group['question'].tolist()
            answers = group['answer'].tolist()

            # Handle potential NaN/None in presentation if groupby didn't drop them.
            presentation_str = str(presentation) if pd.notna(presentation) else "Unknown Presentation"

            case_data.append({
                'case_id': str(case_id) if pd.notna(case_id) else "Unknown ID",
                'clinical_presentation': presentation_str,
                'questions': questions,
                'answers': answers
            })

        if not case_data:
            print("Error: No cases could be processed after grouping. Check input data integrity.")
            return None

        processed_df = pd.DataFrame(case_data)
        print(f"Processed data into {len(processed_df)} unique cases.")

        # Create a searchable summary of each case (handle empty question lists).
        # Only the presentation plus the first question is embedded — not the
        # full transcript — so retrieval matches on case topic, not content.
        processed_df['case_summary'] = processed_df.apply(
            lambda x: f"Clinical case: {x['clinical_presentation']}. First question: {x['questions'][0] if x['questions'] else 'No questions available'}",
            axis=1
        )

        # --- Generate embeddings using sentence-transformers, in batches ---
        texts_to_embed = processed_df['case_summary'].tolist()
        all_embeddings = []

        print(f"Generating embeddings for {len(texts_to_embed)} case summaries...")
        try:
            for i in tqdm(range(0, len(texts_to_embed), batch_size), desc="Embedding Batches"):
                batch_texts = texts_to_embed[i:i+batch_size]
                # Generate embeddings for the batch.
                batch_embeddings = self.model.encode(batch_texts, convert_to_numpy=True, device=self.device, show_progress_bar=False)
                all_embeddings.append(batch_embeddings)

            # Combine all batch embeddings into one (n_cases, dim) matrix.
            if not all_embeddings:
                print("Error: No embeddings were generated.")
                return None
            final_embeddings = np.vstack(all_embeddings)
            print(f"Generated embeddings with shape: {final_embeddings.shape}")

        except Exception as e:
            print(f"Error during embedding generation: {e}")
            return None

        # --- Convert to HF Dataset and attach embeddings ---
        try:
            dataset = Dataset.from_pandas(processed_df)
            # Ensure embeddings column is compatible (list of lists).
            dataset = dataset.add_column('embeddings', final_embeddings.tolist())
        except Exception as e:
            print(f"Error converting to Hugging Face Dataset or adding embeddings: {e}")
            return None

        # --- Persist to disk for later loading by the retriever ---
        try:
            os.makedirs(output_path, exist_ok=True)  # Ensure directory exists
            dataset.save_to_disk(output_path)
            print(f"Processed dataset saved successfully to {output_path}")
        except Exception as e:
            print(f"Error saving dataset to disk at {output_path}: {e}")
            return None  # Return None if saving failed

        return dataset
|
src/evaluation_utils.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from sklearn.metrics import ndcg_score
|
| 3 |
+
from src.retriever import ClinicalCaseRetriever, DummyRetriever
|
| 4 |
+
|
| 5 |
+
def retrieval_metrics(retriever_instance: ClinicalCaseRetriever, queries: list[str], gold_ids: list[str], k: int = 5) -> dict | None:
    """
    Score a retriever against a labelled query set.

    Args:
        retriever_instance: An initialized ClinicalCaseRetriever instance.
        queries: A list of query strings.
        gold_ids: The expected 'case_id' string for each query (parallel list).
        k: How many top results to consider for Hit@k and NDCG@k.

    Returns:
        A dictionary with averaged Hit@k, MRR, and NDCG@k scores.
    """
    hits = []
    reciprocal_ranks = []
    ndcgs = []
    print(f"\nCalculating retrieval metrics for {len(queries)} queries (k={k})...")

    for q_idx, (q, gold) in enumerate(zip(queries, gold_ids)):
        print(f"\nProcessing query {q_idx+1}/{len(queries)}: '{q}' (Expected ID: '{gold}')")
        retrieved_cases, scores = retriever_instance.retrieve_relevant_case(q, top_k=k, return_scores=True)

        # Missing 'case_id' keys degrade to 'N/A' rather than raising.
        retrieved_ids = [c.get('case_id', 'N/A') for c in retrieved_cases]
        print(f"Retrieved IDs: {retrieved_ids}")
        print(f"Retrieved Scores: {[round(s, 4) for s in scores]}")

        is_hit = int(gold in retrieved_ids)
        hits.append(is_hit)

        # Reciprocal rank: 1/position of the gold id, 0 when it was not retrieved.
        rank = retrieved_ids.index(gold) + 1 if is_hit else 0
        reciprocal_ranks.append(1.0 / rank if rank else 0.0)

        # Binary relevance vector aligned with the retrieved ranking.
        true_relevance = np.asarray([[1.0 if gid == gold else 0.0 for gid in retrieved_ids]])
        predicted_scores = np.asarray([scores])

        current_ndcg = 0.0
        if true_relevance.shape[1] > 0:
            # Clamp k so it never exceeds the number of retrieved items.
            ndcg_k = min(k, true_relevance.shape[1])
            current_ndcg = ndcg_score(true_relevance, predicted_scores, k=ndcg_k)
        ndcgs.append(current_ndcg)

        print(f"Hit: {is_hit}, Rank: {rank if rank > 0 else 'N/A'}, NDCG@{k}: {current_ndcg:.4f}")

    # Average each metric over all queries (0.0 for an empty query set).
    avg_hit = np.mean(hits) if hits else 0.0
    avg_mrr = np.mean(reciprocal_ranks) if reciprocal_ranks else 0.0
    avg_ndcg = np.mean(ndcgs) if ndcgs else 0.0

    print(f"\n--- Overall Retrieval Results (k={k}) --- ")
    print(f"Average Hit@{k}: {avg_hit:.4f}")
    print(f"Average MRR: {avg_mrr:.4f}")
    print(f"Average NDCG@{k}: {avg_ndcg:.4f}")

    return {f"Hit@{k}": avg_hit,
            f"MRR": avg_mrr,
            f"NDCG@{k}": avg_ndcg}
|
src/evaluator.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 3 |
+
|
| 4 |
+
class AnswerEvaluator:
    """Evaluates user answers against expected answers using an LLM.

    Loads a Hugging Face causal LM once at construction and uses it to grade
    a resident's free-text answer against the dataset's model answer.
    """

    def __init__(self, model_id="meta-llama/Llama-3.2-3B-Instruct"):
        # Load tokenizer + model once; device_map="auto" lets HF place/shard the
        # model (GPU when available). Any load failure is re-raised so callers fail fast.
        print(f"Initializing AnswerEvaluator with model: {model_id}")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_id)
            if self.tokenizer.pad_token is None:
                # Llama-family tokenizers ship without a pad token; reuse EOS for padding.
                self.tokenizer.pad_token = self.tokenizer.eos_token
                print("Set pad_token to eos_token")

            self.model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.float16,  # half precision: smaller footprint, inference-only
                device_map="auto"
            )
            self.model.eval()
            self.device = self.model.device
            print(f"AnswerEvaluator model loaded successfully on device: {self.device}")

        except Exception as e:
            print(f"Error initializing AnswerEvaluator model {model_id}: {e}")
            raise


    def evaluate_answer(self, user_answer, expected_answer, clinical_context=None):
        """
        Compare user answer to expected answer and provide feedback.

        Args:
            user_answer: Examinee's response
            expected_answer: Model answer from the dataset
            clinical_context: Optional clinical context to consider

        Returns:
            Feedback string beginning with "ASSESSMENT: ..." on success, or an
            error message string when generation fails.
        """
        # Optional context is prepended to the rubric section of the prompt.
        context_str = f"Clinical context: {clinical_context}\n\n" if clinical_context else ""

        # NOTE(review): the <s>[INST] ... [/INST]</s> markers below are the
        # Mistral/Llama-2 instruction format, but the default model is
        # Llama-3.2-Instruct, which uses a different chat template — confirm
        # the model handles this prompt style (tokenizer.apply_chat_template
        # would be the template-safe alternative).
        prompt = f"""<s>[INST] You are acting as an expert examiner for the American Board of Surgery (ABS) oral board exam. You are evaluating a general surgery resident’s answer to a clinical question. \n
Compare the answer provided by the residents to the correct expected answer, which I will provide you with. \n
Use the grading rubric below to assess their response:

[RUBRIC]
- Correct: Resident includes all major points and clinical reasoning aligns closely with the expected answer.
- Partially Correct: Resident includes some key points but omits others, or reasoning is partially flawed.
- Incorrect: Resident misses most key points or demonstrates incorrect reasoning.

{context_str}Here is the model answer that contains the key points expected from the resident:
{expected_answer}

Now, here is the resident’s actual response:
{user_answer}

Evaluate the resident’s response based **only** on the expected answer above. Do not rely on external knowledge or previous responses.

Focus your evaluation on:
1. Which key points were mentioned vs. missed
2. The accuracy and clarity of the clinical reasoning
3. Any major omissions or misunderstandings

Start your output with:
ASSESSMENT: [Correct / Partially Correct / Incorrect]
Then write 1–2 clear, specific sentences explaining how the resident’s response compares to the expected answer.

[EXAMPLE 1]
Expected answer:
"The differential diagnosis includes acute appendicitis, mesenteric adenitis, gastroenteritis, UTI, and testicular torsion."

Resident’s response:
"My top concern is appendicitis, but I’d also consider things like gastroenteritis or maybe even kidney stones."

ASSESSMENT: Partially Correct
The resident mentioned appendicitis and gastroenteritis but missed several other expected differentials like UTI, testicular torsion, and mesenteric adenitis.

[EXAMPLE 2]
Expected answer:
"Initial labs should include CBC, CMP, lipase, and abdominal ultrasound to assess for gallstones."

Resident’s response:
"I’d start with a full workup including CBC, liver enzymes, lipase, and an abdominal ultrasound."

ASSESSMENT: Correct
The resident included all key labs and the correct imaging modality. Their reasoning aligns well with the expected answer.

[/INST]</s>"""

        try:
            # NOTE(review): max_length=1024 truncates the *prompt*, which is
            # already long before the answers are inserted — a long resident
            # answer may be silently cut off. Verify the limit is sufficient.
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(self.device)  # Added truncation

            with torch.no_grad():
                # Generate feedback using the model (low temperature for near-deterministic grading).
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=150,
                    do_sample=True,
                    temperature=0.2,
                    pad_token_id=self.tokenizer.eos_token_id  # Ensure pad token ID is set
                )

            # Strip the echoed prompt: keep only tokens generated after the input.
            prompt_length_tokens = inputs.input_ids.shape[1]
            generated_ids = outputs[0][prompt_length_tokens:]

            feedback = self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

            return feedback

        except Exception as e:
            print(f"Error during LLM evaluation: {e}")
            return "Error: Could not generate feedback."
|
src/retriever.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
from sentence_transformers import SentenceTransformer
|
| 4 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
| 5 |
+
from datasets import load_from_disk, Dataset
|
| 6 |
+
import os
|
| 7 |
+
import pandas as pd
|
| 8 |
+
|
| 9 |
+
class ClinicalCaseRetriever:
    """Retrieves relevant clinical cases based on user input using sentence-transformers embeddings.

    Loads the dataset produced by ClinicalCaseProcessor (which must contain an
    'embeddings' column) and ranks cases by cosine similarity to a query.
    """

    def __init__(self, dataset_path='./processed_clinical_cases', model_name="all-MiniLM-L6-v2"):
        # Accept either a ready-made Dataset object or a save_to_disk() directory.
        print(f"Initializing ClinicalCaseRetriever with model: {model_name}")
        if isinstance(dataset_path, Dataset):
            self.dataset = dataset_path
            print("Using provided Hugging Face Dataset object.")
        elif isinstance(dataset_path, str) and os.path.isdir(dataset_path):
            try:
                self.dataset = load_from_disk(dataset_path)
                print(f"Dataset loaded successfully from disk: {dataset_path}")
            except Exception as e:
                print(f"Error loading dataset from disk {dataset_path}: {e}")
                raise ValueError("Failed to load dataset.") from e
        else:
            raise ValueError(f"Invalid dataset_path: Must be a Dataset object or a valid directory path. Got: {dataset_path}")

        if 'embeddings' not in self.dataset.column_names:
            raise ValueError("Dataset must contain an 'embeddings' column.")

        print(f"Dataset features: {self.dataset.features}")
        print(f"Number of cases in dataset: {len(self.dataset)}")


        # Query encoder: must be the same model the dataset was embedded with,
        # otherwise query and case vectors live in different spaces.
        self.model = SentenceTransformer(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        self.model.to(self.device)

        try:
            # Ensure embeddings are loaded as a NumPy array (n_cases, dim).
            self.case_embeddings = np.array(self.dataset['embeddings'])
            if self.case_embeddings.ndim != 2:
                raise ValueError(f"Embeddings array must be 2-dimensional. Got shape: {self.case_embeddings.shape}")
            print(f"Loaded {len(self.dataset)} cases with embeddings of shape {self.case_embeddings.shape}")
        except Exception as e:
            print(f"Error processing embeddings from dataset: {e}")
            raise ValueError("Failed to load or process embeddings.") from e


    def get_available_cases(self, n=5):
        """Return a random sample of up to *n* (index, clinical_presentation) pairs for user selection."""
        num_cases = len(self.dataset)
        if num_cases == 0:
            return []
        sample_size = min(n, num_cases)
        indices = np.random.choice(num_cases, sample_size, replace=False)
        # Ensure indices are int for slicing dataset (np ints are not accepted everywhere).
        return [(int(i), self.dataset[int(i)]['clinical_presentation']) for i in indices]

    def encode_query(self, query):
        """Generate an embedding (1, dim) for a query string, or None on failure."""
        # Prefix the raw query so it resembles the stored case summaries.
        search_query = f"Clinical case about {query}"
        print(f"Encoding query: '{search_query}'")

        # Generate embedding using sentence-transformers.
        try:
            query_embedding = self.model.encode([search_query], convert_to_numpy=True, device=self.device, show_progress_bar=False)
            return query_embedding
        except Exception as e:
            print(f"Error encoding query '{query}': {e}")
            return None  # Or raise an error

    def retrieve_relevant_case(self, query, top_k=1, return_scores=False):
        """Find the most relevant clinical case(s) given a query.

        Args:
            query: Free-text topic description.
            top_k: Number of cases to return (clamped to the dataset size).
            return_scores: Selects the return shape — note the asymmetry:
                True  -> (cases, scores) as two parallel lists;
                False -> a single list of (case_dict, score) tuples.

        Returns:
            See return_scores above; empty list(s) on any failure.
        """
        if not isinstance(top_k, int) or top_k < 1:
            print("Warning: top_k must be a positive integer. Defaulting to 1.")
            top_k = 1

        # Get query embedding; bail out with the shape-appropriate empty result on failure.
        query_embedding = self.encode_query(query)
        if query_embedding is None:
            return [] if not return_scores else ([], [])

        # Calculate similarity scores between the query and every case.
        try:
            similarities = cosine_similarity(query_embedding, self.case_embeddings)[0]  # Get the single row of similarities
        except Exception as e:
            print(f"Error calculating cosine similarity: {e}")
            return [] if not return_scores else ([], [])

        # Get indices of top-k most similar cases.
        # Ensure we don't request more indices than available cases.
        k_actual = min(top_k, len(similarities))
        if k_actual == 0:  # Should not happen if dataset loaded, but safe check
            return [] if not return_scores else ([], [])

        # Use partitioning for efficiency if k is much smaller than N, or argsort otherwise.
        # Using argsort is generally simpler and fine for moderate N.
        top_indices = np.argsort(similarities)[-k_actual:][::-1].astype(int)  # Get top k indices, sorted descending

        top_scores = similarities[top_indices].tolist()  # Get scores for these indices

        # Return the most relevant case(s).
        try:
            # Retrieve cases safely using integer indices.
            retrieved_cases = [self.dataset[int(idx)] for idx in top_indices]
        except IndexError as e:
            print(f"Error retrieving cases using indices {top_indices}: {e}")
            return [] if not return_scores else ([], [])
        except Exception as e:
            print(f"Unexpected error retrieving cases: {e}")
            return [] if not return_scores else ([], [])


        results_with_scores = list(zip(retrieved_cases, top_scores))
        print(f"Retrieved {len(results_with_scores)} cases with similarity scores:")
        for case, score in results_with_scores:
            # Safely access presentation, provide default if missing.
            presentation = case.get('clinical_presentation', 'Unknown Presentation')
            print(f"- {presentation}: {score:.4f}")

        if return_scores:
            return retrieved_cases, top_scores
        else:
            # Return list of tuples (case_dict, score).
            return results_with_scores
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class DummyRetriever:
    """A simple retriever that bypasses RAG, taking a pre-formatted DataFrame.

    Mirrors ClinicalCaseRetriever's retrieve_relevant_case() interface so the two
    are interchangeable (e.g. for retrieval_metrics in evaluation_utils, which
    calls retrieve_relevant_case(..., return_scores=True)).
    """

    def __init__(self, df):
        """Build one case dict per unique clinical_presentation in *df*.

        Args:
            df: DataFrame with columns 'clinical_presentation', 'turn_id',
                'question', 'answer' (and optionally 'case_id'). Invalid or
                empty input yields an empty dataset with a warning.
        """
        self.dataset = []
        if not isinstance(df, pd.DataFrame) or df.empty:
            print("Warning: DummyRetriever initialized with empty or invalid DataFrame.")
            return

        # Expects df to be pre-processed with columns:
        # 'clinical_presentation', 'turn_id', 'question', 'answer'
        required_cols = ['clinical_presentation', 'turn_id', 'question', 'answer']
        if not all(col in df.columns for col in required_cols):
            print(f"Warning: DummyRetriever DataFrame missing required columns. Need: {required_cols}")
            return

        grouped = df.groupby('clinical_presentation')
        print(f"DummyRetriever processing {len(grouped)} unique presentations.")
        for i, (scenario, group) in enumerate(grouped):
            group_sorted = group.sort_values('turn_id')  # keep Q&A turns in exam order

            case_dict = {
                "case_id": group_sorted['case_id'].iloc[0] if 'case_id' in group_sorted.columns else f"dummy_{i}",
                "clinical_presentation": scenario,
                "questions": group_sorted["question"].tolist(),
                "answers": group_sorted["answer"].tolist()
            }
            self.dataset.append(case_dict)
        print(f"DummyRetriever initialized with {len(self.dataset)} cases.")

    def retrieve_relevant_case(self, scenario_query, top_k=1, return_scores=False):
        """
        Finds the case matching the query string exactly.

        Ignores 'top_k' but mimics ClinicalCaseRetriever's return structure.

        Fix: accepts the return_scores keyword (default False preserves the old
        behavior); previously callers passing return_scores=True — as
        retrieval_metrics does — raised TypeError.

        Args:
            scenario_query: Exact 'clinical_presentation' string to match.
            top_k: Accepted for interface compatibility; unused.
            return_scores: True -> (cases, scores) parallel lists;
                           False -> list of (case_dict, score) tuples.

        Returns:
            [(case_dict, 1.0)] / ([case_dict], [1.0]) on a match;
            [] / ([], []) when nothing matches.
        """
        print(f"DummyRetriever searching for exact match: '{scenario_query}'")
        for case_dict in self.dataset:
            if case_dict["clinical_presentation"] == scenario_query:
                print(f"DummyRetriever found match: {case_dict['clinical_presentation']}")
                if return_scores:
                    return [case_dict], [1.0]
                return [(case_dict, 1.0)]

        print(f"DummyRetriever: No exact match found for '{scenario_query}'")
        if return_scores:
            return [], []
        return []
|
src/simulator.py
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import pandas as pd
|
| 3 |
+
# Assuming retriever and evaluator classes are in these files:
|
| 4 |
+
from .retriever import ClinicalCaseRetriever, DummyRetriever
|
| 5 |
+
from .evaluator import AnswerEvaluator
|
| 6 |
+
|
| 7 |
+
class OralExamSimulator:
    """Main class that coordinates the oral board exam simulation.

    Holds the active case, tracks which question comes next, and records a
    chronological session history of system events, examiner questions,
    resident answers, and evaluator feedback.
    """

    def __init__(self, retriever, evaluator):
        """
        Args:
            retriever: ClinicalCaseRetriever or DummyRetriever supplying cases.
            evaluator: AnswerEvaluator used to grade resident answers.

        Raises:
            TypeError: If either collaborator is of the wrong type.
        """
        if not isinstance(retriever, (ClinicalCaseRetriever, DummyRetriever)):
            raise TypeError("Retriever must be an instance of ClinicalCaseRetriever or DummyRetriever")
        if not isinstance(evaluator, AnswerEvaluator):
            raise TypeError("Evaluator must be an instance of AnswerEvaluator")

        self.retriever = retriever
        self.evaluator = evaluator
        self.current_case = None        # dict with parallel 'questions'/'answers' lists once a case starts
        self.current_question_idx = 0   # index of the next question to ask
        self.session_history = []       # chronological list of {"role": ..., "content": ...} entries

    def _select_case(self, clinical_query, case_idx):
        """Resolve a case either by direct index or by RAG-based retrieval.

        Returns:
            tuple: (case_dict, similarity_score) on success, or a dict with an
            "error" key describing the failure.
        """
        if case_idx is not None:
            try:
                # Direct case selection by index; validate bounds first.
                if 0 <= int(case_idx) < len(self.retriever.dataset):
                    case = self.retriever.dataset[int(case_idx)]
                    print(f"Selected case by index {case_idx}: {case.get('clinical_presentation', 'Unknown Presentation')}")
                    # Direct selection implies perfect 'match'.
                    return case, 1.0
                else:
                    print(f"Error: Invalid case index {case_idx}. Must be between 0 and {len(self.retriever.dataset)-1}.")
                    return {"error": f"Invalid case index: {case_idx}"}
            except Exception as e:
                print(f"Error selecting case by index {case_idx}: {e}")
                return {"error": f"Failed to select case by index: {e}"}

        if clinical_query:
            try:
                # retrieve_relevant_case returns a list of (case_dict, score) tuples.
                retrieved_results = self.retriever.retrieve_relevant_case(clinical_query, top_k=1)
                if retrieved_results:
                    case, score = retrieved_results[0]
                    print(f"Retrieved case via query ('{clinical_query}') with score {score:.4f}: {case.get('clinical_presentation', 'Unknown Presentation')}")
                    return case, score
                print(f"Error: No case found for query: '{clinical_query}'")
                return {"error": f"No relevant case found for query: {clinical_query}"}
            except Exception as e:
                print(f"Error retrieving case for query '{clinical_query}': {e}")
                return {"error": f"Failed to retrieve case by query: {e}"}

        # No selection method provided.
        print("Error: Must provide either a clinical query or a case index.")
        return {"error": "Please provide either a clinical query or case index."}

    def start_new_case(self, clinical_query=None, case_idx=None):
        """
        Initialize a new exam case based on query or direct selection.

        Args:
            clinical_query (str, optional): Text description of the desired case topic.
            case_idx (int, optional): Direct index of the case to use from the retriever's dataset.

        Returns:
            dict: Contains case info and first question, or an error message.
        """
        print("-" * 50)
        print(f"Attempting to start new case | Query: '{clinical_query}' | Index: {case_idx}")

        # Reset state for the new case.
        self.current_case = None
        self.current_question_idx = 0
        self.session_history = []

        selection = self._select_case(clinical_query, case_idx)
        if isinstance(selection, dict):
            # Selection failed; propagate the error dict unchanged.
            return selection
        self.current_case, similarity_score = selection

        # Validate case structure: parallel, non-empty question/answer lists.
        questions = self.current_case.get('questions')
        answers = self.current_case.get('answers')
        if not isinstance(questions, list) or not isinstance(answers, list) \
                or len(questions) != len(answers):
            print(f"Error: Invalid case structure for case ID {self.current_case.get('case_id', 'N/A')}. Mismatched or missing Q/A lists.")
            return {"error": "Selected case has invalid format."}

        if not questions:
            print(f"Warning: Selected case ID {self.current_case.get('case_id', 'N/A')} has no questions.")
            return {"error": "Selected case contains no questions."}

        # Start a new session record.
        self.session_history.append({
            "role": "system",
            "content": f"Clinical scenario started: {self.current_case.get('clinical_presentation', 'Unknown Presentation')} (Case ID: {self.current_case.get('case_id', 'N/A')})"
        })

        # Ask the first question and log it.
        first_question = questions[0]
        self.session_history.append({
            "role": "examiner",
            "content": first_question
        })

        print(f"Case successfully started. Total questions: {len(questions)}")
        print("-" * 50)

        return {
            "case_id": self.current_case.get('case_id', 'unknown'),
            "clinical_presentation": self.current_case.get('clinical_presentation', 'Unknown'),
            "similarity_score": similarity_score,  # 1.0 for direct selection, retrieval score otherwise
            "current_question": first_question,
            "question_number": 1,
            "total_questions": len(questions)
        }

    def process_user_response(self, response):
        """
        Process the user's answer, get feedback, and return the next question or completion status.

        Args:
            response (str): User's answer text.

        Returns:
            dict: Contains feedback, expected answer, completion status, total question
            count, and next question (or session summary) — or an error message.
        """
        if self.current_case is None:
            print("Error: No active case.")
            return {"error": "No active case. Please start a new case first."}

        questions = self.current_case['questions']
        if self.current_question_idx >= len(questions):
            print("Error: Attempting to process response when case is already complete.")
            return {"error": "Case already completed."}

        print("-" * 50)
        current_q_num = self.current_question_idx + 1
        total_q = len(questions)
        print(f"Processing response for Question {current_q_num}/{total_q}")
        print(f"User Response: {response}")

        # Save the user's response to history.
        self.session_history.append({
            "role": "resident",
            "content": response
        })

        # Expected answer for the question just answered.
        expected_answer = self.current_case['answers'][self.current_question_idx]
        print(f"Expected Answer: {expected_answer}")

        # Evaluate the answer against the model answer.
        feedback = self.evaluator.evaluate_answer(
            response,
            expected_answer,
            clinical_context=f"Regarding the case '{self.current_case.get('clinical_presentation', 'N/A')}'"
        )
        print(f"Generated Feedback: {feedback}")

        self.session_history.append({
            "role": "feedback",
            "content": feedback
        })

        # Advance to the next question and check for completion.
        self.current_question_idx += 1
        is_complete = self.current_question_idx >= total_q

        result = {
            "feedback": feedback,
            "expected_answer": expected_answer,
            "is_complete": is_complete,
            "question_number": self.current_question_idx,
            # FIX: previously only present on incomplete turns, which made the
            # result schema inconsistent between turns.
            "total_questions": total_q,
        }

        if not is_complete:
            next_question = questions[self.current_question_idx]
            result["next_question"] = next_question

            self.session_history.append({
                "role": "examiner",
                "content": next_question
            })
            # FIX: log the number of the *next* question (idx+1), not the count
            # of questions answered so far.
            print(f"Next question ({self.current_question_idx + 1}/{total_q}): {next_question}")
        else:
            print("Case completed.")
            # Summary is generated before the closing system entry, so the
            # "End of clinical scenario." line is not part of the summary log.
            result["session_summary"] = self.generate_session_summary()
            self.session_history.append({
                "role": "system",
                "content": "End of clinical scenario."
            })

        print("-" * 50)
        return result

    def generate_session_summary(self):
        """Generate a summary dictionary of the completed session."""
        if not self.current_case or not self.session_history:
            return {"error": "No active or completed session to summarize."}

        # Simple summary structure.
        return {
            "case_id": self.current_case.get('case_id', 'N/A'),
            "case": self.current_case.get('clinical_presentation', 'Unknown'),
            "total_questions_in_case": len(self.current_case.get('questions', [])),
            "interaction_history": self.session_history  # full chronological log
        }

    def save_session(self, filepath):
        """Save the current session summary (plus a timestamp) to a JSON file."""
        summary = self.generate_session_summary()
        if "error" in summary:
            print(f"Error generating summary for saving: {summary['error']}")
            return {"error": "No session to save"}

        try:
            # Add a timestamp to the saved data.
            summary["timestamp"] = pd.Timestamp.now().isoformat()

            # Ensure the target directory exists.
            # FIX: os.path.dirname is "" for a bare filename, and
            # os.makedirs("") raises — only create when a directory is given.
            directory = os.path.dirname(filepath)
            if directory:
                os.makedirs(directory, exist_ok=True)

            with open(filepath, 'w') as f:
                json.dump(summary, f, indent=2)
            print(f"Session saved successfully to {filepath}")
            return {"status": "Session saved successfully"}
        except Exception as e:
            print(f"Error saving session to {filepath}: {e}")
            return {"error": f"Failed to save session: {e}"}
|
src/synthetic_generator.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import re
|
| 5 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 6 |
+
|
| 7 |
+
def generate_synthetic_case(clinical_query, model_id="meta-llama/Llama-3.2-3B-Instruct", max_tokens=800):
    """Generate a synthetic clinical case with examiner questions and expected answers.

    Args:
        clinical_query: Topic the synthetic case should cover.
        model_id: Hugging Face model identifier used for generation.
        max_tokens: Cap on the number of newly generated tokens.

    Returns:
        str | None: Raw generated text, or None when model loading or
        generation fails.
    """
    print(f"Generating synthetic case for '{clinical_query}' using {model_id}...")
    tokenizer = None
    model = None
    try:
        # Load tokenizer and model, letting HF place weights on available devices.
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        model.eval()
        device = model.device
        # Some checkpoints ship without a pad token; fall back to EOS.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

    except Exception as e:
        print(f"Error initializing generator model {model_id}: {e}")
        return None

    prompt = f"""<s>[INST] You are a board-certified general surgeon simulating a clinical oral board exam.
Create a synthetic case on the topic: "{clinical_query}".
Start by describing the initial clinical presentation in 1–2 sentences.
Then generate a list of 5–8 examiner questions (Q1, Q2...), each paired with the expected examinee answer (A1, A2...). Ensure Q/A pairs are clearly separated.
Output ONLY the presentation and Q&A pairs in this exact format:
Clinical Presentation: ...

Q1: ...
A1: ...

Q2: ...
A2: ...

(continue until Qn/An)
Focus on common scenarios and standard knowledge. Avoid overly complex or rare details.
[/INST]</s>"""

    generated_text = None
    try:
        encoded = tokenizer(prompt, return_tensors="pt").to(device)
        prompt_len = encoded.input_ids.shape[1]

        with torch.no_grad():
            sequences = model.generate(
                encoded.input_ids,
                max_new_tokens=max_tokens,
                do_sample=True,  # Sample to get potentially varied outputs
                temperature=0.7,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Keep only the continuation, dropping the echoed prompt tokens.
        continuation_ids = sequences[0][prompt_len:]
        generated_text = tokenizer.decode(continuation_ids, skip_special_tokens=True).strip()
        print("Synthetic case generation complete.")

    except Exception as e:
        print(f"Error during synthetic case generation: {e}")
    finally:
        # Release model resources regardless of outcome.
        del model
        del tokenizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return generated_text
|
| 75 |
+
|
| 76 |
+
def process_synthetic_data(clinical_query, output_text, case_id='SYNTH_01'):
    """Process the raw LLM output text into a structured DataFrame for the DummyRetriever.

    Args:
        clinical_query (str): Topic used for generation; stored in the
            'clinical_presentation' column of every row.
        output_text (str | None): Raw model output containing an optional
            "Clinical Presentation:" header followed by Qn:/An: pairs.
        case_id (str): Identifier assigned to every row. Defaults to
            'SYNTH_01', preserving the previous hard-coded behavior.

    Returns:
        pd.DataFrame: One row per Q/A turn with columns case_id,
        clinical_presentation, turn_id, question, answer; an empty DataFrame
        when nothing could be parsed.
    """
    # FIX: generate_synthetic_case returns None on failure; guard here instead
    # of crashing inside re.search with a TypeError.
    if not output_text:
        print("Warning: Empty synthetic output; nothing to process.")
        return pd.DataFrame()

    # Extract clinical presentation (everything up to the first Q1 block).
    match = re.search(r"Clinical Presentation:(.*?)(?=\n\nQ1:|$)", output_text, re.DOTALL | re.IGNORECASE)
    clinical_presentation_text = match.group(1).strip() if match else "Synthetic Case: " + clinical_query

    # Extract Q&A pairs; the \1 backreference forces matched Qn/An numbering.
    qa_pattern = r"Q(\d+):\s*(.*?)\s*A\1:\s*(.*?)(?=\n*Q\d+:|\Z)"
    qa_matches = re.findall(qa_pattern, output_text, flags=re.DOTALL | re.IGNORECASE)

    qa_list = []
    for match_tuple in qa_matches:
        try:
            q_num = int(match_tuple[0])
            q_text = match_tuple[1].strip()
            a_text = match_tuple[2].strip()
            # Drop pairs where either side is empty after stripping.
            if q_text and a_text:
                qa_list.append({'turn_id': q_num, 'question': q_text, 'answer': a_text})
        except (IndexError, ValueError) as e:
            print(f"Warning: Skipping malformed Q/A match: {match_tuple} due to {e}")

    if not qa_list:
        print("Warning: No valid Q&A pairs extracted from synthetic text.")
        return pd.DataFrame()

    # Ensure turns appear in exam order even if the model emitted them shuffled.
    qa_list.sort(key=lambda item: item['turn_id'])

    rows = [
        {
            'case_id': case_id,
            'clinical_presentation': clinical_query,  # query doubles as the presentation title
            'turn_id': item['turn_id'],
            'question': item['question'],
            'answer': item['answer'],
        }
        for item in qa_list
    ]
    df_synthetic = pd.DataFrame(rows)

    # Prepend the narrative presentation to the first question so the opening
    # turn reads like a real case stem.
    if not df_synthetic.empty and clinical_presentation_text:
        first_turn_index = df_synthetic[df_synthetic['turn_id'] == 1].index
        if not first_turn_index.empty:
            idx = first_turn_index[0]
            df_synthetic.loc[idx, 'question'] = clinical_presentation_text + " " + df_synthetic.loc[idx, 'question']
        else:
            print("Warning: Could not find turn_id 1 to prepend presentation.")

    print(f"Processed synthetic data into DataFrame with {len(df_synthetic)} turns.")
    return df_synthetic
|