LiamKhoaLe committed on
Commit
b0a3faf
·
1 Parent(s): e138b0e

Enhance data enrichment

Browse files
Files changed (3) hide show
  1. utils/augment.py +59 -0
  2. utils/llm.py +12 -6
  3. utils/processor.py +81 -28
utils/augment.py CHANGED
@@ -167,3 +167,62 @@ def retry_invalid_response(text: str, paraphraser, max_retries: int = 3) -> str:
167
 
168
  # If all retries failed, return empty string to indicate drop
169
  return ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
  # If all retries failed, return empty string to indicate drop
169
  return ""
170
+
171
def validate_medical_accuracy(question: str, answer: str, paraphraser) -> bool:
    """Validate a Q&A pair using the paraphraser's LLM consistency check.

    This delegates to ``paraphraser.consistency_check(question, answer)``;
    no additional medical-specific logic is applied here.

    Args:
        question: The question text of the pair.
        answer: The answer text of the pair.
        paraphraser: Object exposing ``consistency_check(question, answer)``.

    Returns:
        False when either side is empty or whitespace-only (the check is
        not invoked in that case). Otherwise the truthiness of the
        consistency check result, coerced to ``bool``. True when the check
        itself raises, so that a failing validator never drops data.
    """
    # Whitespace-only strings carry nothing to validate; skip the LLM call.
    if not question or not answer:
        return False
    if not question.strip() or not answer.strip():
        return False

    try:
        verdict = paraphraser.consistency_check(question, answer)
    except Exception as e:
        # Deliberately fail-open: an unavailable validator must not reject rows.
        logger.warning(f"Medical accuracy validation failed: {e}")
        return True

    # Honour the declared ``-> bool`` contract regardless of what the check returns.
    return bool(verdict)
182
+
183
def enhance_medical_terminology(text: str, paraphraser) -> str:
    """Ask the paraphraser to rewrite *text* with improved medical terminology.

    Args:
        text: Source text to enhance.
        paraphraser: Object exposing
            ``paraphrase(text, difficulty=..., custom_prompt=...)``.

    Returns:
        The rewritten text when the paraphraser produces a non-empty result
        that ``is_invalid_response`` does not reject. The original *text* is
        returned unchanged when it is empty or shorter than 20 characters,
        when the result is empty or rejected, or when any step raises.
    """
    # Inputs that are empty or under 20 characters are returned untouched.
    too_short = not text or len(text) < 20
    if too_short:
        return text

    try:
        instruction = (
            "Improve the medical terminology in this text while preserving all factual information:\n\n"
            f"{text}\n\n"
            "Return only the improved text with better medical terminology:"
        )
        candidate = paraphraser.paraphrase(
            text,
            difficulty="hard",
            custom_prompt=instruction,
        )
        if not candidate:
            return text
        if is_invalid_response(candidate):
            return text
        return candidate
    except Exception as e:
        # Best-effort enhancement: log and fall back to the input.
        logger.warning(f"Medical terminology enhancement failed: {e}")
        return text
202
+
203
def create_clinical_scenarios(question: str, answer: str, paraphraser) -> list:
    """Create clinical-context rewrites of a question, paired with its answer.

    The question is rewritten for four contexts (emergency room, routine
    checkup, chronic conditions, family member) through
    ``paraphraser.paraphrase(..., custom_prompt=...)``.

    Args:
        question: Original question text.
        answer: Answer text reused verbatim for every scenario.
        paraphraser: Object exposing
            ``paraphrase(text, difficulty=..., custom_prompt=...)``.

    Returns:
        A list of ``(scenario_question, answer, tag)`` tuples, where ``tag`` is
        ``"clinical_scenario_<n>"`` and ``n`` is the 1-based position of the
        context prompt. Rewrites are skipped when they are empty, rejected by
        ``is_invalid_response``, identical to the original question, or
        identical to an earlier scenario. The list is empty when *question*
        or *answer* is empty.
    """
    scenarios = []

    # Nothing meaningful can be generated from an empty pair.
    if not question or not answer:
        return scenarios

    context_prompts = [
        f"Rewrite this medical question as if asked by a patient in an emergency room:\n\n{question}",
        f"Rewrite this medical question as if asked by a patient in a routine checkup:\n\n{question}",
        f"Rewrite this medical question as if asked by a patient with chronic conditions:\n\n{question}",
        f"Rewrite this medical question as if asked by a patient's family member:\n\n{question}",
    ]

    # Seed with the original so a paraphraser that hands its input back
    # unchanged does not produce duplicate rows tagged as new scenarios.
    seen = {question.strip()}

    for i, prompt in enumerate(context_prompts, start=1):
        try:
            scenario_question = paraphraser.paraphrase(question, difficulty="hard", custom_prompt=prompt)
            if not scenario_question:
                continue

            normalized = scenario_question.strip()
            if not normalized or normalized in seen:
                continue
            if is_invalid_response(scenario_question):
                continue

            seen.add(normalized)
            scenarios.append((scenario_question, answer, f"clinical_scenario_{i}"))
        except Exception as e:
            # One failing context must not prevent the remaining ones.
            logger.warning(f"Failed to create clinical scenario {i}: {e}")

    return scenarios
utils/llm.py CHANGED
@@ -142,14 +142,20 @@ class Paraphraser:
142
  return txt.strip()
143
 
144
  # ————— Paraphrase —————
145
- def paraphrase(self, text: str, difficulty: str = "easy") -> str:
146
  if not text or len(text) < 12:
147
  return text
148
- prompt = (
149
- "Paraphrase the following medical text concisely, preserve meaning and clinical terms.\n"
150
- "Do not fabricate or remove factual claims.\n"
151
- "Return ONLY the rewritten text, without any introduction, commentary.\n"+ text
152
- )
 
 
 
 
 
 
153
  # Always try NVIDIA first
154
  out = self.nv.generate(prompt, temperature=0.1, max_tokens=min(600, max(128, len(text)//2)))
155
  if out:
 
142
  return txt.strip()
143
 
144
  # ————— Paraphrase —————
145
+ def paraphrase(self, text: str, difficulty: str = "easy", custom_prompt: str = None) -> str:
146
  if not text or len(text) < 12:
147
  return text
148
+
149
+ # Use custom prompt if provided, otherwise use default
150
+ if custom_prompt:
151
+ prompt = custom_prompt
152
+ else:
153
+ prompt = (
154
+ "Paraphrase the following medical text concisely, preserve meaning and clinical terms.\n"
155
+ "Do not fabricate or remove factual claims.\n"
156
+ "Return ONLY the rewritten text, without any introduction, commentary.\n"+ text
157
+ )
158
+
159
  # Always try NVIDIA first
160
  out = self.nv.generate(prompt, temperature=0.1, max_tokens=min(600, max(128, len(text)//2)))
161
  if out:
utils/processor.py CHANGED
@@ -52,7 +52,11 @@ def process_file_into_sft(
52
  "backtranslated_input": 0,
53
  "backtranslated_output": 0,
54
  "dedup_skipped": 0,
55
- "consistency_failed": 0
 
 
 
 
56
  }
57
  # Start processing SFT
58
  key_summary = {k: augment_opts.get(k) for k in (
@@ -114,47 +118,63 @@ def _build_variants(user: str, out: str, paraphraser, opts: Dict, stats: Dict):
114
  return variants
115
 
116
  def _build_enriched_variants(user: str, out: str, paraphraser, opts: Dict, stats: Dict, translator=None):
117
- """Build multiple paraphrased variants for SFT enrichment (2-3 answers per question, 2-3 questions per answer)"""
118
  variants = []
119
 
120
- # Generate 2-3 different answers for the same question
121
  answer_variants = []
122
- for i in range(3):
123
- if i == 0:
124
- # Original answer
125
- answer_variants.append((out, ["original_answer"]))
 
 
 
 
 
 
 
126
  else:
127
- # Paraphrased answers with different difficulties
128
- difficulty = "easy" if i == 1 else "hard"
129
  try:
130
- paraphrased_out = paraphraser.paraphrase(out, difficulty=difficulty)
131
- if paraphrased_out and not A.is_invalid_response(paraphrased_out):
 
 
 
132
  if opts.get("style_standardize", True):
133
- paraphrased_out = A.style_standardize_answer(paraphrased_out)
134
- paraphrased_out = A.ensure_terminal_punct(paraphrased_out)
135
- answer_variants.append((paraphrased_out, [f"paraphrase_answer_{difficulty}"]))
136
  stats["paraphrased_output"] += 1
137
  except Exception as e:
138
- logger.warning(f"Failed to paraphrase answer variant {i}: {e}")
139
  continue
140
 
141
- # Generate 2-3 different questions for the same answer
142
  question_variants = []
143
- for i in range(3):
144
- if i == 0:
145
- # Original question
146
- question_variants.append((user, ["original_question"]))
 
 
 
 
 
 
 
147
  else:
148
- # Paraphrased questions with different difficulties
149
- difficulty = "easy" if i == 1 else "hard"
150
  try:
151
- paraphrased_user = paraphraser.paraphrase(user, difficulty=difficulty)
152
- if paraphrased_user and not A.is_invalid_response(paraphrased_user):
153
- paraphrased_user = A.ensure_terminal_punct(paraphrased_user)
154
- question_variants.append((paraphrased_user, [f"paraphrase_question_{difficulty}"]))
 
 
 
155
  stats["paraphrased_input"] += 1
156
  except Exception as e:
157
- logger.warning(f"Failed to paraphrase question variant {i}: {e}")
158
  continue
159
 
160
  # Create combinations: each question variant with each answer variant
@@ -166,7 +186,7 @@ def _build_enriched_variants(user: str, out: str, paraphraser, opts: Dict, stats
166
  # Add Vietnamese variants if translator is available
167
  if translator and translator.is_loaded():
168
  vi_variants = []
169
- for q_user, a_out, tags in variants[:3]: # Limit to first 3 to avoid too many variants
170
  try:
171
  # Translate question and answer
172
  vi_q = translator.translate_text(q_user)
@@ -184,6 +204,26 @@ def _build_enriched_variants(user: str, out: str, paraphraser, opts: Dict, stats
184
 
185
  return variants
186
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  def _apply_aug(instr: str, user: str, out: str, source: str, opts: Dict, paraphraser, stats: Dict):
188
  # Base cleanup & caps (returns cleaned strings)
189
  user = A.base_cleanup(user, opts.get("max_chars", 5000), opts.get("deidentify", True))
@@ -272,6 +312,11 @@ def _proc_med_dialog(source, path, writer, paraphraser, opts, sample_limit, stat
272
  continue
273
 
274
  # 1) ALWAYS write the original (cleaned/style-standardised only)
 
 
 
 
 
275
  # Optional consistency spot-check (cheap)
276
  if not A.consistency_ok(user, out, opts.get("consistency_check_ratio", 0.0), paraphraser):
277
  stats["consistency_failed"] += 1
@@ -288,6 +333,14 @@ def _proc_med_dialog(source, path, writer, paraphraser, opts, sample_limit, stat
288
  for (u_aug, o_aug, aug_tags) in enriched_variants:
289
  rid_aug = f"{rid}-enriched{random.randint(1000,9999)}"
290
  _commit_row(writer, source, rid_aug, "medical_dialogue", instr, u_aug, o_aug, opts, stats, aug_tags, dedupe_seen=dedupe_seen, translator=translator)
 
 
 
 
 
 
 
 
291
 
292
  # Increment count only on success
293
  count += 1
 
52
  "backtranslated_input": 0,
53
  "backtranslated_output": 0,
54
  "dedup_skipped": 0,
55
+ "consistency_failed": 0,
56
+ "medical_accuracy_failed": 0,
57
+ "clinical_scenarios_created": 0,
58
+ "enhanced_terminology": 0,
59
+ "vietnamese_variants": 0
60
  }
61
  # Start processing SFT
62
  key_summary = {k: augment_opts.get(k) for k in (
 
118
  return variants
119
 
120
  def _build_enriched_variants(user: str, out: str, paraphraser, opts: Dict, stats: Dict, translator=None):
121
+ """Build multiple paraphrased variants for SFT enrichment with enhanced diversity strategies"""
122
  variants = []
123
 
124
+ # Enhanced answer generation with different perspectives
125
  answer_variants = []
126
+ answer_strategies = [
127
+ ("original", out, ["original_answer"]),
128
+ ("concise", None, ["concise_answer"]),
129
+ ("detailed", None, ["detailed_answer"]),
130
+ ("clinical", None, ["clinical_answer"]),
131
+ ("patient_friendly", None, ["patient_friendly_answer"])
132
+ ]
133
+
134
+ for strategy, original_text, tags in answer_strategies:
135
+ if strategy == "original":
136
+ answer_variants.append((original_text, tags))
137
  else:
 
 
138
  try:
139
+ # Generate different answer styles
140
+ style_prompt = _get_answer_style_prompt(strategy, user, out)
141
+ enhanced_out = paraphraser.paraphrase(out, difficulty="hard", custom_prompt=style_prompt)
142
+
143
+ if enhanced_out and not A.is_invalid_response(enhanced_out):
144
  if opts.get("style_standardize", True):
145
+ enhanced_out = A.style_standardize_answer(enhanced_out)
146
+ enhanced_out = A.ensure_terminal_punct(enhanced_out)
147
+ answer_variants.append((enhanced_out, tags))
148
  stats["paraphrased_output"] += 1
149
  except Exception as e:
150
+ logger.warning(f"Failed to generate {strategy} answer variant: {e}")
151
  continue
152
 
153
+ # Enhanced question generation with different question types
154
  question_variants = []
155
+ question_strategies = [
156
+ ("original", user, ["original_question"]),
157
+ ("clarifying", None, ["clarifying_question"]),
158
+ ("follow_up", None, ["follow_up_question"]),
159
+ ("symptom_focused", None, ["symptom_focused_question"]),
160
+ ("treatment_focused", None, ["treatment_focused_question"])
161
+ ]
162
+
163
+ for strategy, original_text, tags in question_strategies:
164
+ if strategy == "original":
165
+ question_variants.append((original_text, tags))
166
  else:
 
 
167
  try:
168
+ # Generate different question styles
169
+ style_prompt = _get_question_style_prompt(strategy, user, out)
170
+ enhanced_user = paraphraser.paraphrase(user, difficulty="hard", custom_prompt=style_prompt)
171
+
172
+ if enhanced_user and not A.is_invalid_response(enhanced_user):
173
+ enhanced_user = A.ensure_terminal_punct(enhanced_user)
174
+ question_variants.append((enhanced_user, tags))
175
  stats["paraphrased_input"] += 1
176
  except Exception as e:
177
+ logger.warning(f"Failed to generate {strategy} question variant: {e}")
178
  continue
179
 
180
  # Create combinations: each question variant with each answer variant
 
186
  # Add Vietnamese variants if translator is available
187
  if translator and translator.is_loaded():
188
  vi_variants = []
189
+ for q_user, a_out, tags in variants[:5]: # Limit to first 5 to avoid too many variants
190
  try:
191
  # Translate question and answer
192
  vi_q = translator.translate_text(q_user)
 
204
 
205
  return variants
206
 
207
+ def _get_answer_style_prompt(strategy: str, question: str, original_answer: str) -> str:
208
+ """Generate style-specific prompts for answer enhancement"""
209
+ prompts = {
210
+ "concise": f"Rewrite this medical answer to be more concise while preserving all key medical information:\n\n{original_answer}",
211
+ "detailed": f"Expand this medical answer with more detailed explanations while maintaining accuracy:\n\n{original_answer}",
212
+ "clinical": f"Rewrite this answer using more formal clinical language and medical terminology:\n\n{original_answer}",
213
+ "patient_friendly": f"Rewrite this medical answer in simpler, more patient-friendly language while keeping it medically accurate:\n\n{original_answer}"
214
+ }
215
+ return prompts.get(strategy, f"Paraphrase this medical answer: {original_answer}")
216
+
217
+ def _get_question_style_prompt(strategy: str, original_question: str, answer: str) -> str:
218
+ """Generate style-specific prompts for question enhancement"""
219
+ prompts = {
220
+ "clarifying": f"Rewrite this medical question to ask for clarification or more specific information:\n\n{original_question}",
221
+ "follow_up": f"Create a follow-up question that a patient might ask after this medical question:\n\n{original_question}",
222
+ "symptom_focused": f"Rewrite this question to focus more on symptoms and their characteristics:\n\n{original_question}",
223
+ "treatment_focused": f"Rewrite this question to focus more on treatment options and management:\n\n{original_question}"
224
+ }
225
+ return prompts.get(strategy, f"Paraphrase this medical question: {original_question}")
226
+
227
  def _apply_aug(instr: str, user: str, out: str, source: str, opts: Dict, paraphraser, stats: Dict):
228
  # Base cleanup & caps (returns cleaned strings)
229
  user = A.base_cleanup(user, opts.get("max_chars", 5000), opts.get("deidentify", True))
 
312
  continue
313
 
314
  # 1) ALWAYS write the original (cleaned/style-standardised only)
315
+ # Enhanced medical accuracy validation
316
+ if not A.validate_medical_accuracy(user, out, paraphraser):
317
+ stats["medical_accuracy_failed"] = stats.get("medical_accuracy_failed", 0) + 1
318
+ applied.append("medical_accuracy_flag")
319
+
320
  # Optional consistency spot-check (cheap)
321
  if not A.consistency_ok(user, out, opts.get("consistency_check_ratio", 0.0), paraphraser):
322
  stats["consistency_failed"] += 1
 
333
  for (u_aug, o_aug, aug_tags) in enriched_variants:
334
  rid_aug = f"{rid}-enriched{random.randint(1000,9999)}"
335
  _commit_row(writer, source, rid_aug, "medical_dialogue", instr, u_aug, o_aug, opts, stats, aug_tags, dedupe_seen=dedupe_seen, translator=translator)
336
+
337
+ # Add clinical scenarios for enhanced diversity
338
+ if opts.get("clinical_scenarios", True):
339
+ clinical_scenarios = A.create_clinical_scenarios(user, out, paraphraser)
340
+ for (scenario_q, scenario_a, scenario_tag) in clinical_scenarios:
341
+ rid_scenario = f"{rid}-scenario{random.randint(1000,9999)}"
342
+ _commit_row(writer, source, rid_scenario, "medical_dialogue", instr, scenario_q, scenario_a, opts, stats, [scenario_tag], dedupe_seen=dedupe_seen, translator=translator)
343
+ stats["clinical_scenarios_created"] += 1
344
 
345
  # Increment count only on success
346
  count += 1