File size: 4,927 Bytes
2247e66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import pandas as pd
import numpy as np
import torch
import re
from transformers import AutoTokenizer, AutoModelForCausalLM

def generate_synthetic_case(clinical_query, model_id="meta-llama/Llama-3.2-3B-Instruct", max_tokens=800):
    """Generate a synthetic clinical oral-board case (presentation + Q&A pairs).

    Loads ``model_id`` fresh, prompts it to produce a short clinical
    presentation followed by numbered examiner questions (Q1..Qn) with
    expected answers (A1..An), then frees the model before returning.

    Args:
        clinical_query: Topic for the synthetic case (e.g. "acute appendicitis").
        model_id: HuggingFace model identifier for the generator LLM.
        max_tokens: Maximum number of new tokens to generate.

    Returns:
        The raw generated text (str), or None if initialization or
        generation failed.
    """
    print(f"Generating synthetic case for '{clinical_query}' using {model_id}...")
    gen_tokenizer = None
    gen_model = None
    try:
        # Initialize generator model components
        gen_tokenizer = AutoTokenizer.from_pretrained(model_id)
        gen_model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        gen_model.eval()
        device = gen_model.device
        if gen_tokenizer.pad_token is None:
            # Reuse EOS as pad so generate() has a valid pad token.
            gen_tokenizer.pad_token = gen_tokenizer.eos_token

    except Exception as e:
        print(f"Error initializing generator model {model_id}: {e}")
        return None

    instruction = f"""You are a board-certified general surgeon simulating a clinical oral board exam.
Create a synthetic case on the topic: "{clinical_query}".
Start by describing the initial clinical presentation in 1–2 sentences.
Then generate a list of 5–8 examiner questions (Q1, Q2...), each paired with the expected examinee answer (A1, A2...). Ensure Q/A pairs are clearly separated.
Output ONLY the presentation and Q&A pairs in this exact format:
Clinical Presentation: ...

Q1: ...
A1: ...

Q2: ...
A2: ...

(continue until Qn/An)
Focus on common scenarios and standard knowledge. Avoid overly complex or rare details."""

    # BUG FIX: the original hard-coded a Mistral/Llama-2 "<s>[INST]...[/INST]</s>"
    # template (even placing the EOS marker before generation), which is the wrong
    # chat format for Llama-3.x models. Let the tokenizer build the correct prompt
    # for whichever model is loaded; fall back to the raw instruction if the
    # tokenizer has no chat template.
    try:
        prompt = gen_tokenizer.apply_chat_template(
            [{"role": "user", "content": instruction}],
            tokenize=False,
            add_generation_prompt=True,
        )
    except Exception:
        prompt = instruction

    output_text = None
    try:
        inputs = gen_tokenizer(prompt, return_tensors="pt").to(device)
        input_ids_length = inputs.input_ids.shape[1]

        with torch.no_grad():
            outputs = gen_model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,  # explicit mask avoids HF warning / padding ambiguity
                max_new_tokens=max_tokens,
                do_sample=True, # Sample to get potentially varied outputs
                temperature=0.7,
                top_p=0.9,
                pad_token_id=gen_tokenizer.eos_token_id
            )

        # Decode only the newly generated tokens
        generated_ids = outputs[0][input_ids_length:]
        output_text = gen_tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
        print("Synthetic case generation complete.")

    except Exception as e:
        print(f"Error during synthetic case generation: {e}")
    finally:
        # Clean up model resources
        del gen_model
        del gen_tokenizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    # BUG FIX: return moved out of `finally` — a return inside finally silently
    # swallows KeyboardInterrupt/SystemExit and any exception raised by cleanup.
    return output_text

def process_synthetic_data(clinical_query, output_text):
    """Process the raw LLM output text into a structured DataFrame for the DummyRetriever.

    Args:
        clinical_query: Topic string; used as the presentation title column
            and as a fallback presentation text.
        output_text: Raw generated transcript containing a
            "Clinical Presentation:" line followed by Qn/An pairs.

    Returns:
        DataFrame with columns case_id, clinical_presentation, turn_id,
        question, answer — one row per Q/A turn, sorted by turn_id, with
        the presentation text prepended to the turn-1 question. Empty
        DataFrame if no valid pairs were found.
    """
    # Presentation paragraph precedes the first blank-line-separated Q1.
    presentation_match = re.search(
        r"Clinical Presentation:(.*?)(?=\n\nQ1:|$)",
        output_text,
        re.DOTALL | re.IGNORECASE,
    )
    if presentation_match:
        presentation = presentation_match.group(1).strip()
    else:
        presentation = "Synthetic Case: " + clinical_query

    # Each hit is (turn number, question, answer); the \1 backreference
    # ties each An to its matching Qn.
    pair_matches = re.findall(
        r"Q(\d+):\s*(.*?)\s*A\1:\s*(.*?)(?=\n*Q\d+:|\Z)",
        output_text,
        flags=re.DOTALL | re.IGNORECASE,
    )

    turns = []
    for triple in pair_matches:
        try:
            entry = {
                'turn_id': int(triple[0]),
                'question': triple[1].strip(),
                'answer': triple[2].strip(),
            }
        except (IndexError, ValueError) as err:
            print(f"Warning: Skipping malformed Q/A match: {triple} due to {err}")
            continue
        # Keep only turns where both sides are non-empty after stripping.
        if entry['question'] and entry['answer']:
            turns.append(entry)

    if not turns:
        print("Warning: No valid Q&A pairs extracted from synthetic text.")
        return pd.DataFrame()

    turns.sort(key=lambda t: t['turn_id'])

    df_synthetic = pd.DataFrame([
        {
            'case_id': 'SYNTH_01',
            'clinical_presentation': clinical_query, # Use query as presentation title
            'turn_id': t['turn_id'],
            'question': t['question'],
            'answer': t['answer'],
        }
        for t in turns
    ])

    if not df_synthetic.empty and presentation:
        # Fold the free-text presentation into the opening question.
        first_turn = df_synthetic.index[df_synthetic['turn_id'] == 1]
        if first_turn.empty:
            print("Warning: Could not find turn_id 1 to prepend presentation.")
        else:
            idx = first_turn[0]
            df_synthetic.loc[idx, 'question'] = presentation + " " + df_synthetic.loc[idx, 'question']

    print(f"Processed synthetic data into DataFrame with {len(df_synthetic)} turns.")
    return df_synthetic