import re

import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


def generate_synthetic_case(clinical_query, model_id="meta-llama/Llama-3.2-3B-Instruct", max_tokens=800):
    """Generate a synthetic clinical case with examiner questions and expected answers."""
    print(f"Generating synthetic case for '{clinical_query}' using {model_id}...")

    gen_tokenizer = None
    gen_model = None
    try:
        # Load the generator in half precision; device_map="auto" places the weights on
        # the available GPU(s) (requires the `accelerate` package).
        gen_tokenizer = AutoTokenizer.from_pretrained(model_id)
        gen_model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        gen_model.eval()
        device = gen_model.device
        # Fall back to the EOS token for padding if the tokenizer defines no pad token.
        if gen_tokenizer.pad_token is None:
            gen_tokenizer.pad_token = gen_tokenizer.eos_token
    except Exception as e:
        print(f"Error initializing generator model {model_id}: {e}")
        return None

prompt = f"""<s>[INST] You are a board-certified general surgeon simulating a clinical oral board exam. |
|
|
Create a synthetic case on the topic: "{clinical_query}". |
|
|
Start by describing the initial clinical presentation in 1–2 sentences. |
|
|
Then generate a list of 5–8 examiner questions (Q1, Q2...), each paired with the expected examinee answer (A1, A2...). Ensure Q/A pairs are clearly separated. |
|
|
Output ONLY the presentation and Q&A pairs in this exact format: |
|
|
Clinical Presentation: ... |
|
|
|
|
|
Q1: ... |
|
|
A1: ... |
|
|
|
|
|
Q2: ... |
|
|
A2: ... |
|
|
|
|
|
(continue until Qn/An) |
|
|
Focus on common scenarios and standard knowledge. Avoid overly complex or rare details. |
|
|
[/INST]</s>""" |
|
|
|
|
|
    output_text = None
    try:
        inputs = gen_tokenizer(prompt, return_tensors="pt").to(device)
        input_ids_length = inputs.input_ids.shape[1]

        with torch.no_grad():
            outputs = gen_model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=max_tokens,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=gen_tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens (everything after the prompt).
        generated_ids = outputs[0][input_ids_length:]
        output_text = gen_tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
        print("Synthetic case generation complete.")
    except Exception as e:
        print(f"Error during synthetic case generation: {e}")
    finally:
        # Free the generator model and tokenizer and release cached GPU memory.
        del gen_model
        del gen_tokenizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return output_text


def process_synthetic_data(clinical_query, output_text):
    """Process the raw LLM output text into a structured DataFrame for the DummyRetriever."""
    # Extract the free-text clinical presentation (everything between the header and Q1);
    # the lookahead tolerates either a blank line or a single newline before Q1.
    match = re.search(r"Clinical Presentation:(.*?)(?=\n\s*Q1:|$)", output_text, re.DOTALL | re.IGNORECASE)
    clinical_presentation_text = match.group(1).strip() if match else "Synthetic Case: " + clinical_query

    # Capture numbered Q/A pairs; the backreference \1 keeps each Qn matched with its An.
    qa_pattern = r"Q(\d+):\s*(.*?)\s*A\1:\s*(.*?)(?=\n*Q\d+:|\Z)"
    qa_matches = re.findall(qa_pattern, output_text, flags=re.DOTALL | re.IGNORECASE)

    qa_list = []
    for match_tuple in qa_matches:
        try:
            q_num = int(match_tuple[0])
            q_text = match_tuple[1].strip()
            a_text = match_tuple[2].strip()
            if q_text and a_text:
                qa_list.append({'turn_id': q_num, 'question': q_text, 'answer': a_text})
        except (IndexError, ValueError) as e:
            print(f"Warning: Skipping malformed Q/A match: {match_tuple} due to {e}")

    if not qa_list:
        print("Warning: No valid Q&A pairs extracted from synthetic text.")
        return pd.DataFrame()

    qa_list.sort(key=lambda item: item['turn_id'])

    rows = []
    for item in qa_list:
        rows.append({
            'case_id': 'SYNTH_01',
            'clinical_presentation': clinical_query,
            'turn_id': item['turn_id'],
            'question': item['question'],
            'answer': item['answer']
        })

    df_synthetic = pd.DataFrame(rows)

    # Prepend the generated presentation to the first question so the opening turn
    # carries the full case stem.
    if not df_synthetic.empty and clinical_presentation_text:
        first_turn_index = df_synthetic[df_synthetic['turn_id'] == 1].index
        if not first_turn_index.empty:
            idx = first_turn_index[0]
            df_synthetic.loc[idx, 'question'] = clinical_presentation_text + " " + df_synthetic.loc[idx, 'question']
        else:
            print("Warning: Could not find turn_id 1 to prepend presentation.")

    print(f"Processed synthetic data into DataFrame with {len(df_synthetic)} turns.")
    return df_synthetic
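

# Usage sketch: a minimal, illustrative way to chain the two functions above —
# generate a raw synthetic case for a topic, then parse it into the turn-level
# DataFrame. The query string below is a hypothetical example, not from the
# original pipeline; substitute any clinical topic of interest.
if __name__ == "__main__":
    query = "management of acute appendicitis"  # hypothetical example topic
    raw_case_text = generate_synthetic_case(query)
    if raw_case_text:
        df_case = process_synthetic_data(query, raw_case_text)
        if not df_case.empty:
            print(df_case[['turn_id', 'question', 'answer']].head())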