Spaces:
Sleeping
Sleeping
Commit
·
b0a3faf
1
Parent(s):
e138b0e
Enhance data enrichment
Browse files- utils/augment.py +59 -0
- utils/llm.py +12 -6
- utils/processor.py +81 -28
utils/augment.py
CHANGED
|
@@ -167,3 +167,62 @@ def retry_invalid_response(text: str, paraphraser, max_retries: int = 3) -> str:
|
|
| 167 |
|
| 168 |
# If all retries failed, return empty string to indicate drop
|
| 169 |
return ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
|
| 168 |
# If all retries failed, return empty string to indicate drop
|
| 169 |
return ""
|
| 170 |
+
|
| 171 |
+
def validate_medical_accuracy(question: str, answer: str, paraphraser) -> bool:
|
| 172 |
+
"""Validate medical accuracy of Q&A pairs using LLM consistency check"""
|
| 173 |
+
if not question or not answer:
|
| 174 |
+
return False
|
| 175 |
+
|
| 176 |
+
try:
|
| 177 |
+
# Use the existing consistency check but with medical focus
|
| 178 |
+
return paraphraser.consistency_check(question, answer)
|
| 179 |
+
except Exception as e:
|
| 180 |
+
logger.warning(f"Medical accuracy validation failed: {e}")
|
| 181 |
+
return True # Default to accepting if validation fails
|
| 182 |
+
|
| 183 |
+
def enhance_medical_terminology(text: str, paraphraser) -> str:
|
| 184 |
+
"""Enhance medical terminology in text while preserving accuracy"""
|
| 185 |
+
if not text or len(text) < 20:
|
| 186 |
+
return text
|
| 187 |
+
|
| 188 |
+
try:
|
| 189 |
+
prompt = (
|
| 190 |
+
"Improve the medical terminology in this text while preserving all factual information:\n\n"
|
| 191 |
+
f"{text}\n\n"
|
| 192 |
+
"Return only the improved text with better medical terminology:"
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
enhanced = paraphraser.paraphrase(text, difficulty="hard", custom_prompt=prompt)
|
| 196 |
+
if enhanced and not is_invalid_response(enhanced):
|
| 197 |
+
return enhanced
|
| 198 |
+
except Exception as e:
|
| 199 |
+
logger.warning(f"Medical terminology enhancement failed: {e}")
|
| 200 |
+
|
| 201 |
+
return text
|
| 202 |
+
|
| 203 |
+
def create_clinical_scenarios(question: str, answer: str, paraphraser) -> list:
|
| 204 |
+
"""Create different clinical scenarios from a Q&A pair"""
|
| 205 |
+
scenarios = []
|
| 206 |
+
|
| 207 |
+
try:
|
| 208 |
+
# Generate different clinical contexts
|
| 209 |
+
context_prompts = [
|
| 210 |
+
f"Rewrite this medical question as if asked by a patient in an emergency room:\n\n{question}",
|
| 211 |
+
f"Rewrite this medical question as if asked by a patient in a routine checkup:\n\n{question}",
|
| 212 |
+
f"Rewrite this medical question as if asked by a patient with chronic conditions:\n\n{question}",
|
| 213 |
+
f"Rewrite this medical question as if asked by a patient's family member:\n\n{question}"
|
| 214 |
+
]
|
| 215 |
+
|
| 216 |
+
for i, prompt in enumerate(context_prompts):
|
| 217 |
+
try:
|
| 218 |
+
scenario_question = paraphraser.paraphrase(question, difficulty="hard", custom_prompt=prompt)
|
| 219 |
+
if scenario_question and not is_invalid_response(scenario_question):
|
| 220 |
+
scenarios.append((scenario_question, answer, f"clinical_scenario_{i+1}"))
|
| 221 |
+
except Exception as e:
|
| 222 |
+
logger.warning(f"Failed to create clinical scenario {i+1}: {e}")
|
| 223 |
+
continue
|
| 224 |
+
|
| 225 |
+
except Exception as e:
|
| 226 |
+
logger.warning(f"Clinical scenario creation failed: {e}")
|
| 227 |
+
|
| 228 |
+
return scenarios
|
utils/llm.py
CHANGED
|
@@ -142,14 +142,20 @@ class Paraphraser:
|
|
| 142 |
return txt.strip()
|
| 143 |
|
| 144 |
# ————— Paraphrase —————
|
| 145 |
-
def paraphrase(self, text: str, difficulty: str = "easy") -> str:
|
| 146 |
if not text or len(text) < 12:
|
| 147 |
return text
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
# Always try NVIDIA first
|
| 154 |
out = self.nv.generate(prompt, temperature=0.1, max_tokens=min(600, max(128, len(text)//2)))
|
| 155 |
if out:
|
|
|
|
| 142 |
return txt.strip()
|
| 143 |
|
| 144 |
# ————— Paraphrase —————
|
| 145 |
+
def paraphrase(self, text: str, difficulty: str = "easy", custom_prompt: str = None) -> str:
|
| 146 |
if not text or len(text) < 12:
|
| 147 |
return text
|
| 148 |
+
|
| 149 |
+
# Use custom prompt if provided, otherwise use default
|
| 150 |
+
if custom_prompt:
|
| 151 |
+
prompt = custom_prompt
|
| 152 |
+
else:
|
| 153 |
+
prompt = (
|
| 154 |
+
"Paraphrase the following medical text concisely, preserve meaning and clinical terms.\n"
|
| 155 |
+
"Do not fabricate or remove factual claims.\n"
|
| 156 |
+
"Return ONLY the rewritten text, without any introduction, commentary.\n"+ text
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
# Always try NVIDIA first
|
| 160 |
out = self.nv.generate(prompt, temperature=0.1, max_tokens=min(600, max(128, len(text)//2)))
|
| 161 |
if out:
|
utils/processor.py
CHANGED
|
@@ -52,7 +52,11 @@ def process_file_into_sft(
|
|
| 52 |
"backtranslated_input": 0,
|
| 53 |
"backtranslated_output": 0,
|
| 54 |
"dedup_skipped": 0,
|
| 55 |
-
"consistency_failed": 0
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
}
|
| 57 |
# Start processing SFT
|
| 58 |
key_summary = {k: augment_opts.get(k) for k in (
|
|
@@ -114,47 +118,63 @@ def _build_variants(user: str, out: str, paraphraser, opts: Dict, stats: Dict):
|
|
| 114 |
return variants
|
| 115 |
|
| 116 |
def _build_enriched_variants(user: str, out: str, paraphraser, opts: Dict, stats: Dict, translator=None):
|
| 117 |
-
"""Build multiple paraphrased variants for SFT enrichment
|
| 118 |
variants = []
|
| 119 |
|
| 120 |
-
#
|
| 121 |
answer_variants = []
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
else:
|
| 127 |
-
# Paraphrased answers with different difficulties
|
| 128 |
-
difficulty = "easy" if i == 1 else "hard"
|
| 129 |
try:
|
| 130 |
-
|
| 131 |
-
|
|
|
|
|
|
|
|
|
|
| 132 |
if opts.get("style_standardize", True):
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
answer_variants.append((
|
| 136 |
stats["paraphrased_output"] += 1
|
| 137 |
except Exception as e:
|
| 138 |
-
logger.warning(f"Failed to
|
| 139 |
continue
|
| 140 |
|
| 141 |
-
#
|
| 142 |
question_variants = []
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
else:
|
| 148 |
-
# Paraphrased questions with different difficulties
|
| 149 |
-
difficulty = "easy" if i == 1 else "hard"
|
| 150 |
try:
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
|
|
|
|
|
|
|
|
|
| 155 |
stats["paraphrased_input"] += 1
|
| 156 |
except Exception as e:
|
| 157 |
-
logger.warning(f"Failed to
|
| 158 |
continue
|
| 159 |
|
| 160 |
# Create combinations: each question variant with each answer variant
|
|
@@ -166,7 +186,7 @@ def _build_enriched_variants(user: str, out: str, paraphraser, opts: Dict, stats
|
|
| 166 |
# Add Vietnamese variants if translator is available
|
| 167 |
if translator and translator.is_loaded():
|
| 168 |
vi_variants = []
|
| 169 |
-
for q_user, a_out, tags in variants[:
|
| 170 |
try:
|
| 171 |
# Translate question and answer
|
| 172 |
vi_q = translator.translate_text(q_user)
|
|
@@ -184,6 +204,26 @@ def _build_enriched_variants(user: str, out: str, paraphraser, opts: Dict, stats
|
|
| 184 |
|
| 185 |
return variants
|
| 186 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
def _apply_aug(instr: str, user: str, out: str, source: str, opts: Dict, paraphraser, stats: Dict):
|
| 188 |
# Base cleanup & caps (returns cleaned strings)
|
| 189 |
user = A.base_cleanup(user, opts.get("max_chars", 5000), opts.get("deidentify", True))
|
|
@@ -272,6 +312,11 @@ def _proc_med_dialog(source, path, writer, paraphraser, opts, sample_limit, stat
|
|
| 272 |
continue
|
| 273 |
|
| 274 |
# 1) ALWAYS write the original (cleaned/style-standardised only)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 275 |
# Optional consistency spot-check (cheap)
|
| 276 |
if not A.consistency_ok(user, out, opts.get("consistency_check_ratio", 0.0), paraphraser):
|
| 277 |
stats["consistency_failed"] += 1
|
|
@@ -288,6 +333,14 @@ def _proc_med_dialog(source, path, writer, paraphraser, opts, sample_limit, stat
|
|
| 288 |
for (u_aug, o_aug, aug_tags) in enriched_variants:
|
| 289 |
rid_aug = f"{rid}-enriched{random.randint(1000,9999)}"
|
| 290 |
_commit_row(writer, source, rid_aug, "medical_dialogue", instr, u_aug, o_aug, opts, stats, aug_tags, dedupe_seen=dedupe_seen, translator=translator)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
|
| 292 |
# Increment count only on success
|
| 293 |
count += 1
|
|
|
|
| 52 |
"backtranslated_input": 0,
|
| 53 |
"backtranslated_output": 0,
|
| 54 |
"dedup_skipped": 0,
|
| 55 |
+
"consistency_failed": 0,
|
| 56 |
+
"medical_accuracy_failed": 0,
|
| 57 |
+
"clinical_scenarios_created": 0,
|
| 58 |
+
"enhanced_terminology": 0,
|
| 59 |
+
"vietnamese_variants": 0
|
| 60 |
}
|
| 61 |
# Start processing SFT
|
| 62 |
key_summary = {k: augment_opts.get(k) for k in (
|
|
|
|
| 118 |
return variants
|
| 119 |
|
| 120 |
def _build_enriched_variants(user: str, out: str, paraphraser, opts: Dict, stats: Dict, translator=None):
|
| 121 |
+
"""Build multiple paraphrased variants for SFT enrichment with enhanced diversity strategies"""
|
| 122 |
variants = []
|
| 123 |
|
| 124 |
+
# Enhanced answer generation with different perspectives
|
| 125 |
answer_variants = []
|
| 126 |
+
answer_strategies = [
|
| 127 |
+
("original", out, ["original_answer"]),
|
| 128 |
+
("concise", None, ["concise_answer"]),
|
| 129 |
+
("detailed", None, ["detailed_answer"]),
|
| 130 |
+
("clinical", None, ["clinical_answer"]),
|
| 131 |
+
("patient_friendly", None, ["patient_friendly_answer"])
|
| 132 |
+
]
|
| 133 |
+
|
| 134 |
+
for strategy, original_text, tags in answer_strategies:
|
| 135 |
+
if strategy == "original":
|
| 136 |
+
answer_variants.append((original_text, tags))
|
| 137 |
else:
|
|
|
|
|
|
|
| 138 |
try:
|
| 139 |
+
# Generate different answer styles
|
| 140 |
+
style_prompt = _get_answer_style_prompt(strategy, user, out)
|
| 141 |
+
enhanced_out = paraphraser.paraphrase(out, difficulty="hard", custom_prompt=style_prompt)
|
| 142 |
+
|
| 143 |
+
if enhanced_out and not A.is_invalid_response(enhanced_out):
|
| 144 |
if opts.get("style_standardize", True):
|
| 145 |
+
enhanced_out = A.style_standardize_answer(enhanced_out)
|
| 146 |
+
enhanced_out = A.ensure_terminal_punct(enhanced_out)
|
| 147 |
+
answer_variants.append((enhanced_out, tags))
|
| 148 |
stats["paraphrased_output"] += 1
|
| 149 |
except Exception as e:
|
| 150 |
+
logger.warning(f"Failed to generate {strategy} answer variant: {e}")
|
| 151 |
continue
|
| 152 |
|
| 153 |
+
# Enhanced question generation with different question types
|
| 154 |
question_variants = []
|
| 155 |
+
question_strategies = [
|
| 156 |
+
("original", user, ["original_question"]),
|
| 157 |
+
("clarifying", None, ["clarifying_question"]),
|
| 158 |
+
("follow_up", None, ["follow_up_question"]),
|
| 159 |
+
("symptom_focused", None, ["symptom_focused_question"]),
|
| 160 |
+
("treatment_focused", None, ["treatment_focused_question"])
|
| 161 |
+
]
|
| 162 |
+
|
| 163 |
+
for strategy, original_text, tags in question_strategies:
|
| 164 |
+
if strategy == "original":
|
| 165 |
+
question_variants.append((original_text, tags))
|
| 166 |
else:
|
|
|
|
|
|
|
| 167 |
try:
|
| 168 |
+
# Generate different question styles
|
| 169 |
+
style_prompt = _get_question_style_prompt(strategy, user, out)
|
| 170 |
+
enhanced_user = paraphraser.paraphrase(user, difficulty="hard", custom_prompt=style_prompt)
|
| 171 |
+
|
| 172 |
+
if enhanced_user and not A.is_invalid_response(enhanced_user):
|
| 173 |
+
enhanced_user = A.ensure_terminal_punct(enhanced_user)
|
| 174 |
+
question_variants.append((enhanced_user, tags))
|
| 175 |
stats["paraphrased_input"] += 1
|
| 176 |
except Exception as e:
|
| 177 |
+
logger.warning(f"Failed to generate {strategy} question variant: {e}")
|
| 178 |
continue
|
| 179 |
|
| 180 |
# Create combinations: each question variant with each answer variant
|
|
|
|
| 186 |
# Add Vietnamese variants if translator is available
|
| 187 |
if translator and translator.is_loaded():
|
| 188 |
vi_variants = []
|
| 189 |
+
for q_user, a_out, tags in variants[:5]: # Limit to first 5 to avoid too many variants
|
| 190 |
try:
|
| 191 |
# Translate question and answer
|
| 192 |
vi_q = translator.translate_text(q_user)
|
|
|
|
| 204 |
|
| 205 |
return variants
|
| 206 |
|
| 207 |
+
def _get_answer_style_prompt(strategy: str, question: str, original_answer: str) -> str:
|
| 208 |
+
"""Generate style-specific prompts for answer enhancement"""
|
| 209 |
+
prompts = {
|
| 210 |
+
"concise": f"Rewrite this medical answer to be more concise while preserving all key medical information:\n\n{original_answer}",
|
| 211 |
+
"detailed": f"Expand this medical answer with more detailed explanations while maintaining accuracy:\n\n{original_answer}",
|
| 212 |
+
"clinical": f"Rewrite this answer using more formal clinical language and medical terminology:\n\n{original_answer}",
|
| 213 |
+
"patient_friendly": f"Rewrite this medical answer in simpler, more patient-friendly language while keeping it medically accurate:\n\n{original_answer}"
|
| 214 |
+
}
|
| 215 |
+
return prompts.get(strategy, f"Paraphrase this medical answer: {original_answer}")
|
| 216 |
+
|
| 217 |
+
def _get_question_style_prompt(strategy: str, original_question: str, answer: str) -> str:
|
| 218 |
+
"""Generate style-specific prompts for question enhancement"""
|
| 219 |
+
prompts = {
|
| 220 |
+
"clarifying": f"Rewrite this medical question to ask for clarification or more specific information:\n\n{original_question}",
|
| 221 |
+
"follow_up": f"Create a follow-up question that a patient might ask after this medical question:\n\n{original_question}",
|
| 222 |
+
"symptom_focused": f"Rewrite this question to focus more on symptoms and their characteristics:\n\n{original_question}",
|
| 223 |
+
"treatment_focused": f"Rewrite this question to focus more on treatment options and management:\n\n{original_question}"
|
| 224 |
+
}
|
| 225 |
+
return prompts.get(strategy, f"Paraphrase this medical question: {original_question}")
|
| 226 |
+
|
| 227 |
def _apply_aug(instr: str, user: str, out: str, source: str, opts: Dict, paraphraser, stats: Dict):
|
| 228 |
# Base cleanup & caps (returns cleaned strings)
|
| 229 |
user = A.base_cleanup(user, opts.get("max_chars", 5000), opts.get("deidentify", True))
|
|
|
|
| 312 |
continue
|
| 313 |
|
| 314 |
# 1) ALWAYS write the original (cleaned/style-standardised only)
|
| 315 |
+
# Enhanced medical accuracy validation
|
| 316 |
+
if not A.validate_medical_accuracy(user, out, paraphraser):
|
| 317 |
+
stats["medical_accuracy_failed"] = stats.get("medical_accuracy_failed", 0) + 1
|
| 318 |
+
applied.append("medical_accuracy_flag")
|
| 319 |
+
|
| 320 |
# Optional consistency spot-check (cheap)
|
| 321 |
if not A.consistency_ok(user, out, opts.get("consistency_check_ratio", 0.0), paraphraser):
|
| 322 |
stats["consistency_failed"] += 1
|
|
|
|
| 333 |
for (u_aug, o_aug, aug_tags) in enriched_variants:
|
| 334 |
rid_aug = f"{rid}-enriched{random.randint(1000,9999)}"
|
| 335 |
_commit_row(writer, source, rid_aug, "medical_dialogue", instr, u_aug, o_aug, opts, stats, aug_tags, dedupe_seen=dedupe_seen, translator=translator)
|
| 336 |
+
|
| 337 |
+
# Add clinical scenarios for enhanced diversity
|
| 338 |
+
if opts.get("clinical_scenarios", True):
|
| 339 |
+
clinical_scenarios = A.create_clinical_scenarios(user, out, paraphraser)
|
| 340 |
+
for (scenario_q, scenario_a, scenario_tag) in clinical_scenarios:
|
| 341 |
+
rid_scenario = f"{rid}-scenario{random.randint(1000,9999)}"
|
| 342 |
+
_commit_row(writer, source, rid_scenario, "medical_dialogue", instr, scenario_q, scenario_a, opts, stats, [scenario_tag], dedupe_seen=dedupe_seen, translator=translator)
|
| 343 |
+
stats["clinical_scenarios_created"] += 1
|
| 344 |
|
| 345 |
# Increment count only on success
|
| 346 |
count += 1
|