# boardllm/src/synthetic_generator.py
# Uploaded by melmoheb via huggingface_hub (commit 2247e66, verified).
import pandas as pd
import numpy as np
import torch
import re
from transformers import AutoTokenizer, AutoModelForCausalLM
def generate_synthetic_case(clinical_query, model_id="meta-llama/Llama-3.2-3B-Instruct", max_tokens=800):
    """Generate a synthetic clinical case with examiner questions and expected answers.

    Loads an instruction-tuned causal LM, prompts it to produce a clinical
    presentation followed by numbered Q/A pairs, and returns the raw text.
    The model is released (and the CUDA cache cleared) before returning.

    Args:
        clinical_query: Topic of the case, e.g. "acute appendicitis".
        model_id: Hugging Face model id of the generator model.
        max_tokens: Maximum number of new tokens to generate.

    Returns:
        The generated text (presentation + Q&A pairs) as a string, or None
        if model initialization or generation failed.
    """
    print(f"Generating synthetic case for '{clinical_query}' using {model_id}...")
    gen_tokenizer = None
    gen_model = None
    try:
        # Initialize generator model components
        gen_tokenizer = AutoTokenizer.from_pretrained(model_id)
        gen_model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        gen_model.eval()
        device = gen_model.device
        # Some causal LMs ship without a pad token; reuse EOS so padding works.
        if gen_tokenizer.pad_token is None:
            gen_tokenizer.pad_token = gen_tokenizer.eos_token
    except Exception as e:
        print(f"Error initializing generator model {model_id}: {e}")
        return None

    instruction = f"""You are a board-certified general surgeon simulating a clinical oral board exam.
Create a synthetic case on the topic: "{clinical_query}".
Start by describing the initial clinical presentation in 1–2 sentences.
Then generate a list of 5–8 examiner questions (Q1, Q2...), each paired with the expected examinee answer (A1, A2...). Ensure Q/A pairs are clearly separated.
Output ONLY the presentation and Q&A pairs in this exact format:
Clinical Presentation: ...
Q1: ...
A1: ...
Q2: ...
A2: ...
(continue until Qn/An)
Focus on common scenarios and standard knowledge. Avoid overly complex or rare details."""

    output_text = None
    try:
        # FIX: the original hard-coded Mistral-style "<s>[INST] ... [/INST]</s>"
        # tags, but Llama 3.x (the default model) uses a different chat format.
        # Let the tokenizer render the correct template for whatever model is
        # loaded; fall back to [INST] tags only if no chat template exists.
        if getattr(gen_tokenizer, "chat_template", None):
            prompt = gen_tokenizer.apply_chat_template(
                [{"role": "user", "content": instruction}],
                tokenize=False,
                add_generation_prompt=True,
            )
        else:
            prompt = f"<s>[INST] {instruction} [/INST]"
        # add_special_tokens=False: both prompt branches already carry their own
        # special tokens; letting the tokenizer add BOS again would double it.
        inputs = gen_tokenizer(
            prompt, return_tensors="pt", add_special_tokens=False
        ).to(device)
        input_ids_length = inputs.input_ids.shape[1]
        with torch.no_grad():
            outputs = gen_model.generate(
                **inputs,  # pass attention_mask along with input_ids
                max_new_tokens=max_tokens,
                do_sample=True,  # Sample to get potentially varied outputs
                temperature=0.7,
                top_p=0.9,
                pad_token_id=gen_tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens (skip the echoed prompt)
        generated_ids = outputs[0][input_ids_length:]
        output_text = gen_tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
        print("Synthetic case generation complete.")
    except Exception as e:
        print(f"Error during synthetic case generation: {e}")
    finally:
        # Clean up model resources so GPU memory is freed before returning
        del gen_model
        del gen_tokenizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return output_text
def process_synthetic_data(clinical_query, output_text):
    """Process the raw LLM output text into a structured DataFrame for the DummyRetriever.

    Args:
        clinical_query: The topic string the case was generated for; used as
            the 'clinical_presentation' title on every row.
        output_text: Raw model output containing "Clinical Presentation: ..."
            followed by "Qn:" / "An:" pairs.

    Returns:
        A DataFrame with columns case_id, clinical_presentation, turn_id,
        question, answer — one row per Q/A turn, sorted by turn_id, with the
        extracted presentation text prepended to the turn-1 question. Empty
        DataFrame if no valid Q&A pairs were found.
    """
    # Extract clinical presentation.
    # FIX: the lookahead previously required a blank line ("\n\nQ1:") before
    # the first question; with single-newline separation (common in model
    # output) the lazy DOTALL match expanded to end-of-string and the
    # "presentation" swallowed the entire Q&A text. Accept any newline run
    # before any question number instead.
    match = re.search(
        r"Clinical Presentation:(.*?)(?=\n+Q\d+:|$)",
        output_text,
        re.DOTALL | re.IGNORECASE,
    )
    clinical_presentation_text = match.group(1).strip() if match else "Synthetic Case: " + clinical_query

    # Extract Q&A pairs; the \1 backreference pairs An with its matching Qn.
    qa_pattern = r"Q(\d+):\s*(.*?)\s*A\1:\s*(.*?)(?=\n*Q\d+:|\Z)"
    qa_matches = re.findall(qa_pattern, output_text, flags=re.DOTALL | re.IGNORECASE)
    qa_list = []
    for match_tuple in qa_matches:
        try:
            q_num = int(match_tuple[0])
            q_text = match_tuple[1].strip()
            a_text = match_tuple[2].strip()
            # Keep only pairs where both sides are non-empty after stripping.
            if q_text and a_text:
                qa_list.append({'turn_id': q_num, 'question': q_text, 'answer': a_text})
        except (IndexError, ValueError) as e:
            print(f"Warning: Skipping malformed Q/A match: {match_tuple} due to {e}")

    if not qa_list:
        print("Warning: No valid Q&A pairs extracted from synthetic text.")
        return pd.DataFrame()

    # Model output may list turns out of order; sort so turn 1 comes first.
    qa_list.sort(key=lambda item: item['turn_id'])
    rows = [
        {
            'case_id': 'SYNTH_01',
            'clinical_presentation': clinical_query,  # Use query as presentation title
            'turn_id': item['turn_id'],
            'question': item['question'],
            'answer': item['answer'],
        }
        for item in qa_list
    ]
    df_synthetic = pd.DataFrame(rows)

    if not df_synthetic.empty and clinical_presentation_text:
        # Prepend the extracted presentation text to the first turn's question
        # so the retriever sees the case setup alongside question 1.
        first_turn_index = df_synthetic[df_synthetic['turn_id'] == 1].index
        if not first_turn_index.empty:
            idx = first_turn_index[0]
            df_synthetic.loc[idx, 'question'] = clinical_presentation_text + " " + df_synthetic.loc[idx, 'question']
        else:
            print("Warning: Could not find turn_id 1 to prepend presentation.")

    print(f"Processed synthetic data into DataFrame with {len(df_synthetic)} turns.")
    return df_synthetic