LiamKhoaLe commited on
Commit
d668aec
·
1 Parent(s): e76f718

Enh model prompting and sequential processing

Browse files
Files changed (3) hide show
  1. utils/cloud_llm.py +161 -19
  2. utils/local_llm.py +27 -15
  3. utils/processor.py +49 -12
utils/cloud_llm.py CHANGED
@@ -146,23 +146,35 @@ class Paraphraser:
146
  if not text or len(text) < 12:
147
  return text
148
 
149
- # Use custom prompt if provided, otherwise use default
150
  if custom_prompt:
151
  prompt = custom_prompt
152
  else:
153
- prompt = (
154
- "Paraphrase the following medical text concisely, preserve meaning and clinical terms.\n"
155
- "Do not fabricate or remove factual claims.\n"
156
- "Return ONLY the rewritten text, without any introduction, commentary.\n"+ text
157
- )
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
- # Always try NVIDIA first
160
- out = self.nv.generate(prompt, temperature=0.1, max_tokens=min(600, max(128, len(text)//2)))
161
  if out:
162
  return self._clean_resp(out)
163
 
164
- # Only fallback to GEMINI_MODEL_EASY (ignore difficulty parameter)
165
- out = self.gm_easy.generate(prompt, max_output_tokens=min(600, max(128, len(text)//2)))
166
  if out:
167
  logger.info(f"[LLM][GEMINI] out={snip(self._clean_resp(out))}")
168
  return self._clean_resp(out)
@@ -171,7 +183,21 @@ class Paraphraser:
171
  # ————— Translate & Backtranslate —————
172
  def translate(self, text: str, target_lang: str = "vi") -> Optional[str]:
173
  if not text: return text
174
- prompt = f"Translate to {target_lang}. Keep meaning exact, preserve medical terms:\n\n{text}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
  out = self.nv.generate(prompt, temperature=0.0, max_tokens=min(800, len(text)+100))
176
  if out: return out.strip()
177
  return self.gm_easy.generate(prompt, max_output_tokens=min(800, len(text)+100))
@@ -180,7 +206,21 @@ class Paraphraser:
180
  if not text: return text
181
  mid = self.translate(text, target_lang=via_lang)
182
  if not mid: return None
183
- prompt = f"Translate the following Vietnamese text back to English, preserving the exact meaning:\n\n{mid}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  out = self.nv.generate(prompt, temperature=0.0, max_tokens=min(900, len(text)+150))
185
  if out: return out.strip()
186
  res = self.gm_easy.generate(prompt, max_output_tokens=min(900, len(text)+150))
@@ -188,14 +228,116 @@ class Paraphraser:
188
 
189
  # ————— Consistency Judge (cheap, ratio-based) —————
190
  def consistency_check(self, user: str, output: str) -> bool:
191
- """Return True if 'output' appears supported by 'user' (context/question). Soft heuristic via LLM."""
192
  prompt = (
193
- "You are a strict medical QA validator. Given the USER input (question+context) "
194
- "and the MODEL ANSWER, reply with exactly 'PASS' if the answer is supported and safe, "
195
- "otherwise 'FAIL'. No extra text.\n\n"
196
- f"USER:\n{user}\n\nANSWER:\n{output}"
 
 
 
 
 
197
  )
198
- out = self.nv.generate(prompt, temperature=0.0, max_tokens=3)
199
  if not out:
200
- out = self.gm_easy.generate(prompt, max_output_tokens=3)
201
  return isinstance(out, str) and "PASS" in out.upper()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  if not text or len(text) < 12:
147
  return text
148
 
149
+ # Use custom prompt if provided, otherwise use optimized medical prompts
150
  if custom_prompt:
151
  prompt = custom_prompt
152
  else:
153
+ # Optimized medical paraphrasing prompts based on difficulty
154
+ if difficulty == "easy":
155
+ prompt = (
156
+ "You are a medical professional. Rewrite the following medical text using different words while preserving all medical facts, clinical terms, and meaning. Keep the same level of detail and accuracy.\n\n"
157
+ f"Original medical text: {text}\n\n"
158
+ "Rewritten medical text:"
159
+ )
160
+ else: # hard difficulty
161
+ prompt = (
162
+ "You are a medical expert. Rewrite the following medical text using more sophisticated medical language and different sentence structures while preserving all clinical facts, medical terminology, and diagnostic information. Maintain professional medical tone.\n\n"
163
+ f"Original medical text: {text}\n\n"
164
+ "Enhanced medical text:"
165
+ )
166
+
167
+ # Optimize temperature and token limits based on difficulty
168
+ temperature = 0.1 if difficulty == "easy" else 0.3
169
+ max_tokens = min(600, max(128, len(text)//2))
170
 
171
+ # Always try NVIDIA first (optimized for medical tasks)
172
+ out = self.nv.generate(prompt, temperature=temperature, max_tokens=max_tokens)
173
  if out:
174
  return self._clean_resp(out)
175
 
176
+ # Fallback to GEMINI with optimized parameters
177
+ out = self.gm_easy.generate(prompt, max_output_tokens=max_tokens)
178
  if out:
179
  logger.info(f"[LLM][GEMINI] out={snip(self._clean_resp(out))}")
180
  return self._clean_resp(out)
 
183
  # ————— Translate & Backtranslate —————
184
  def translate(self, text: str, target_lang: str = "vi") -> Optional[str]:
185
  if not text: return text
186
+
187
+ # Optimized medical translation prompts
188
+ if target_lang == "vi":
189
+ prompt = (
190
+ "You are a medical translator. Translate the following English medical text to Vietnamese while preserving all medical terminology, clinical facts, and professional medical language. Use appropriate Vietnamese medical terms.\n\n"
191
+ f"English medical text: {text}\n\n"
192
+ "Vietnamese medical translation:"
193
+ )
194
+ else:
195
+ prompt = (
196
+ f"You are a medical translator. Translate the following medical text to {target_lang} while preserving all medical terminology, clinical facts, and professional medical language.\n\n"
197
+ f"Original medical text: {text}\n\n"
198
+ f"{target_lang} medical translation:"
199
+ )
200
+
201
  out = self.nv.generate(prompt, temperature=0.0, max_tokens=min(800, len(text)+100))
202
  if out: return out.strip()
203
  return self.gm_easy.generate(prompt, max_output_tokens=min(800, len(text)+100))
 
206
  if not text: return text
207
  mid = self.translate(text, target_lang=via_lang)
208
  if not mid: return None
209
+
210
+ # Optimized backtranslation prompt with medical focus
211
+ if via_lang == "vi":
212
+ prompt = (
213
+ "You are a medical translator. Translate the following Vietnamese medical text back to English while preserving all medical terminology, clinical facts, and professional medical language. Ensure the translation is medically accurate.\n\n"
214
+ f"Vietnamese medical text: {mid}\n\n"
215
+ "English medical translation:"
216
+ )
217
+ else:
218
+ prompt = (
219
+ f"You are a medical translator. Translate the following {via_lang} medical text back to English while preserving all medical terminology, clinical facts, and professional medical language.\n\n"
220
+ f"{via_lang} medical text: {mid}\n\n"
221
+ "English medical translation:"
222
+ )
223
+
224
  out = self.nv.generate(prompt, temperature=0.0, max_tokens=min(900, len(text)+150))
225
  if out: return out.strip()
226
  res = self.gm_easy.generate(prompt, max_output_tokens=min(900, len(text)+150))
 
228
 
229
  # ————— Consistency Judge (cheap, ratio-based) —————
230
  def consistency_check(self, user: str, output: str) -> bool:
231
+ """Return True if 'output' appears supported by 'user' (context/question). Optimized medical validation."""
232
  prompt = (
233
+ "You are a medical quality assurance expert. Evaluate if the medical answer is consistent with the question/context and medically accurate. Consider:\n"
234
+ "1. Medical accuracy and clinical appropriateness\n"
235
+ "2. Consistency with the question asked\n"
236
+ "3. Safety and professional medical standards\n"
237
+ "4. Completeness of the medical information\n\n"
238
+ "Reply with exactly 'PASS' if the answer is medically sound and consistent, otherwise 'FAIL'.\n\n"
239
+ f"Question/Context: {user}\n\n"
240
+ f"Medical Answer: {output}\n\n"
241
+ "Evaluation:"
242
  )
243
+ out = self.nv.generate(prompt, temperature=0.0, max_tokens=5)
244
  if not out:
245
+ out = self.gm_easy.generate(prompt, max_output_tokens=5)
246
  return isinstance(out, str) and "PASS" in out.upper()
247
+
248
+ def medical_accuracy_check(self, question: str, answer: str) -> bool:
249
+ """Check medical accuracy of Q&A pairs using cloud APIs"""
250
+ if not question or not answer:
251
+ return False
252
+
253
+ prompt = (
254
+ "You are a medical accuracy validator. Evaluate if the medical answer is accurate and appropriate for the question. Consider:\n"
255
+ "1. Medical facts and clinical knowledge\n"
256
+ "2. Appropriate medical terminology\n"
257
+ "3. Clinical reasoning and logic\n"
258
+ "4. Safety considerations\n\n"
259
+ "Reply with exactly 'ACCURATE' if the answer is medically correct, otherwise 'INACCURATE'.\n\n"
260
+ f"Medical Question: {question}\n\n"
261
+ f"Medical Answer: {answer}\n\n"
262
+ "Medical Accuracy Assessment:"
263
+ )
264
+
265
+ out = self.nv.generate(prompt, temperature=0.0, max_tokens=5)
266
+ if not out:
267
+ out = self.gm_easy.generate(prompt, max_output_tokens=5)
268
+ return isinstance(out, str) and "ACCURATE" in out.upper()
269
+
270
+ def enhance_medical_terminology(self, text: str) -> str:
271
+ """Enhance medical terminology in text using cloud APIs"""
272
+ if not text or len(text) < 20:
273
+ return text
274
+
275
+ prompt = (
276
+ "You are a medical terminology expert. Improve the medical terminology in the following text while preserving all factual information and clinical accuracy. Use more precise medical terms where appropriate.\n\n"
277
+ f"Original text: {text}\n\n"
278
+ "Enhanced medical text:"
279
+ )
280
+
281
+ out = self.nv.generate(prompt, temperature=0.1, max_tokens=min(800, len(text)+100))
282
+ if not out:
283
+ out = self.gm_easy.generate(prompt, max_output_tokens=min(800, len(text)+100))
284
+ return out if out else text
285
+
286
+ def create_clinical_scenarios(self, question: str, answer: str) -> list:
287
+ """Create different clinical scenarios from Q&A pairs using cloud APIs"""
288
+ scenarios = []
289
+
290
+ # Different clinical context prompts
291
+ context_prompts = [
292
+ (
293
+ "Rewrite this medical question as if asked by a patient in an emergency room setting:",
294
+ "emergency_room"
295
+ ),
296
+ (
297
+ "Rewrite this medical question as if asked by a patient during a routine checkup:",
298
+ "routine_checkup"
299
+ ),
300
+ (
301
+ "Rewrite this medical question as if asked by a patient with chronic conditions:",
302
+ "chronic_care"
303
+ ),
304
+ (
305
+ "Rewrite this medical question as if asked by a patient's family member:",
306
+ "family_inquiry"
307
+ )
308
+ ]
309
+
310
+ for prompt_template, scenario_type in context_prompts:
311
+ try:
312
+ prompt = f"{prompt_template}\n\nOriginal question: {question}\n\nRewritten question:"
313
+ scenario_question = self.paraphrase(question, difficulty="hard", custom_prompt=prompt)
314
+
315
+ if scenario_question and not self._is_invalid_response(scenario_question):
316
+ scenarios.append((scenario_question, answer, scenario_type))
317
+ except Exception as e:
318
+ logger.warning(f"Failed to create clinical scenario {scenario_type}: {e}")
319
+ continue
320
+
321
+ return scenarios
322
+
323
+ def _is_invalid_response(self, text: str) -> bool:
324
+ """Check if response is invalid"""
325
+ if not text or not isinstance(text, str):
326
+ return True
327
+
328
+ text_lower = text.lower().strip()
329
+ invalid_patterns = [
330
+ "fail", "invalid", "i couldn't", "i can't", "i cannot", "unable to",
331
+ "sorry", "error", "not available", "no answer", "insufficient",
332
+ "don't know", "do not know", "not sure", "cannot determine",
333
+ "unable to provide", "not possible", "not applicable", "n/a"
334
+ ]
335
+
336
+ if len(text_lower) < 3:
337
+ return True
338
+
339
+ for pattern in invalid_patterns:
340
+ if pattern in text_lower:
341
+ return True
342
+
343
+ return False
utils/local_llm.py CHANGED
@@ -385,39 +385,51 @@ class LocalParaphraser:
385
  return result if result else text
386
 
387
  def create_clinical_scenarios(self, question: str, answer: str) -> list:
388
- """Create different clinical scenarios from Q&A pairs using MedAlpaca"""
389
  scenarios = []
390
 
391
  # Different clinical context prompts
392
  context_prompts = [
393
  (
394
- "Rewrite this medical question as if asked by a patient in an emergency room setting:",
395
  "emergency_room"
396
  ),
397
  (
398
- "Rewrite this medical question as if asked by a patient during a routine checkup:",
399
  "routine_checkup"
400
  ),
401
  (
402
- "Rewrite this medical question as if asked by a patient with chronic conditions:",
403
  "chronic_care"
404
  ),
405
  (
406
- "Rewrite this medical question as if asked by a patient's family member:",
407
  "family_inquiry"
408
  )
409
  ]
410
 
411
- for prompt_template, scenario_type in context_prompts:
412
- try:
413
- prompt = f"{prompt_template}\n\nOriginal question: {question}\n\nRewritten question:"
414
- scenario_question = self.client.generate(prompt, max_tokens=min(400, len(question)+50), temperature=0.2)
415
-
416
- if scenario_question and not self._is_invalid_response(scenario_question):
417
- scenarios.append((scenario_question, answer, scenario_type))
418
- except Exception as e:
419
- logger.warning(f"Failed to create clinical scenario {scenario_type}: {e}")
420
- continue
 
 
 
 
 
 
 
 
 
 
 
 
421
 
422
  return scenarios
423
 
 
385
  return result if result else text
386
 
387
  def create_clinical_scenarios(self, question: str, answer: str) -> list:
388
+ """Create different clinical scenarios from Q&A pairs using MedAlpaca with batch optimization"""
389
  scenarios = []
390
 
391
  # Different clinical context prompts
392
  context_prompts = [
393
  (
394
+ "You are a medical professional. Rewrite this medical question as if asked by a patient in an emergency room setting:\n\nOriginal question: {question}\n\nEmergency room question:",
395
  "emergency_room"
396
  ),
397
  (
398
+ "You are a medical professional. Rewrite this medical question as if asked by a patient during a routine checkup:\n\nOriginal question: {question}\n\nRoutine checkup question:",
399
  "routine_checkup"
400
  ),
401
  (
402
+ "You are a medical professional. Rewrite this medical question as if asked by a patient with chronic conditions:\n\nOriginal question: {question}\n\nChronic care question:",
403
  "chronic_care"
404
  ),
405
  (
406
+ "You are a medical professional. Rewrite this medical question as if asked by a patient's family member:\n\nOriginal question: {question}\n\nFamily inquiry question:",
407
  "family_inquiry"
408
  )
409
  ]
410
 
411
+ # Use batch processing for better efficiency
412
+ try:
413
+ prompts = [prompt_template.format(question=question) for prompt_template, _ in context_prompts]
414
+ results = self.client.generate_batch(prompts, max_tokens=min(400, len(question)+50), temperature=0.2)
415
+
416
+ for i, (result, (_, scenario_type)) in enumerate(zip(results, context_prompts)):
417
+ if result and not self._is_invalid_response(result):
418
+ scenarios.append((result, answer, scenario_type))
419
+
420
+ except Exception as e:
421
+ logger.warning(f"Batch clinical scenario creation failed, falling back to individual: {e}")
422
+ # Fallback to individual processing
423
+ for prompt_template, scenario_type in context_prompts:
424
+ try:
425
+ prompt = prompt_template.format(question=question)
426
+ scenario_question = self.client.generate(prompt, max_tokens=min(400, len(question)+50), temperature=0.2)
427
+
428
+ if scenario_question and not self._is_invalid_response(scenario_question):
429
+ scenarios.append((scenario_question, answer, scenario_type))
430
+ except Exception as e:
431
+ logger.warning(f"Failed to create clinical scenario {scenario_type}: {e}")
432
+ continue
433
 
434
  return scenarios
435
 
utils/processor.py CHANGED
@@ -209,22 +209,54 @@ def _build_enriched_variants(user: str, out: str, paraphraser, opts: Dict, stats
209
  return variants
210
 
211
  def _get_answer_style_prompt(strategy: str, question: str, original_answer: str) -> str:
212
- """Generate style-specific prompts for answer enhancement"""
213
  prompts = {
214
- "concise": f"Rewrite this medical answer to be more concise while preserving all key medical information:\n\n{original_answer}",
215
- "detailed": f"Expand this medical answer with more detailed explanations while maintaining accuracy:\n\n{original_answer}",
216
- "clinical": f"Rewrite this answer using more formal clinical language and medical terminology:\n\n{original_answer}",
217
- "patient_friendly": f"Rewrite this medical answer in simpler, more patient-friendly language while keeping it medically accurate:\n\n{original_answer}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
  }
219
  return prompts.get(strategy, f"Paraphrase this medical answer: {original_answer}")
220
 
221
  def _get_question_style_prompt(strategy: str, original_question: str, answer: str) -> str:
222
- """Generate style-specific prompts for question enhancement"""
223
  prompts = {
224
- "clarifying": f"Rewrite this medical question to ask for clarification or more specific information:\n\n{original_question}",
225
- "follow_up": f"Create a follow-up question that a patient might ask after this medical question:\n\n{original_question}",
226
- "symptom_focused": f"Rewrite this question to focus more on symptoms and their characteristics:\n\n{original_question}",
227
- "treatment_focused": f"Rewrite this question to focus more on treatment options and management:\n\n{original_question}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  }
229
  return prompts.get(strategy, f"Paraphrase this medical question: {original_question}")
230
 
@@ -321,7 +353,7 @@ def _proc_med_dialog(source, path, writer, paraphraser, opts, sample_limit, stat
321
  continue
322
 
323
  # 1) ALWAYS write the original (cleaned/style-standardised only)
324
- # Enhanced medical accuracy validation
325
  if not A.validate_medical_accuracy(user, out, paraphraser):
326
  stats["medical_accuracy_failed"] = stats.get("medical_accuracy_failed", 0) + 1
327
  applied.append("medical_accuracy_flag")
@@ -345,7 +377,12 @@ def _proc_med_dialog(source, path, writer, paraphraser, opts, sample_limit, stat
345
 
346
  # Add clinical scenarios for enhanced diversity
347
  if opts.get("clinical_scenarios", True):
348
- clinical_scenarios = A.create_clinical_scenarios(user, out, paraphraser)
 
 
 
 
 
349
  for (scenario_q, scenario_a, scenario_tag) in clinical_scenarios:
350
  rid_scenario = f"{rid}-scenario{random.randint(1000,9999)}"
351
  _commit_row(writer, source, rid_scenario, "medical_dialogue", instr, scenario_q, scenario_a, opts, stats, [scenario_tag], dedupe_seen=dedupe_seen, translator=translator)
 
209
  return variants
210
 
211
  def _get_answer_style_prompt(strategy: str, question: str, original_answer: str) -> str:
212
+ """Generate style-specific prompts for answer enhancement with medical focus"""
213
  prompts = {
214
+ "concise": (
215
+ "You are a medical professional. Rewrite this medical answer to be more concise while preserving all key medical information, clinical facts, and diagnostic details:\n\n"
216
+ f"Original answer: {original_answer}\n\n"
217
+ "Concise medical answer:"
218
+ ),
219
+ "detailed": (
220
+ "You are a medical expert. Expand this medical answer with more detailed explanations, clinical context, and additional medical information while maintaining accuracy:\n\n"
221
+ f"Original answer: {original_answer}\n\n"
222
+ "Detailed medical answer:"
223
+ ),
224
+ "clinical": (
225
+ "You are a clinical specialist. Rewrite this answer using more formal clinical language, precise medical terminology, and professional medical communication style:\n\n"
226
+ f"Original answer: {original_answer}\n\n"
227
+ "Clinical medical answer:"
228
+ ),
229
+ "patient_friendly": (
230
+ "You are a medical professional. Rewrite this medical answer in simpler, more patient-friendly language while keeping it medically accurate and informative:\n\n"
231
+ f"Original answer: {original_answer}\n\n"
232
+ "Patient-friendly medical answer:"
233
+ )
234
  }
235
  return prompts.get(strategy, f"Paraphrase this medical answer: {original_answer}")
236
 
237
  def _get_question_style_prompt(strategy: str, original_question: str, answer: str) -> str:
238
+ """Generate style-specific prompts for question enhancement with medical focus"""
239
  prompts = {
240
+ "clarifying": (
241
+ "You are a medical professional. Rewrite this medical question to ask for clarification or more specific medical information:\n\n"
242
+ f"Original question: {original_question}\n\n"
243
+ "Clarifying medical question:"
244
+ ),
245
+ "follow_up": (
246
+ "You are a medical professional. Create a follow-up question that a patient might ask after this medical question, focusing on related medical concerns:\n\n"
247
+ f"Original question: {original_question}\n\n"
248
+ "Follow-up medical question:"
249
+ ),
250
+ "symptom_focused": (
251
+ "You are a medical professional. Rewrite this question to focus more on symptoms, their characteristics, and clinical presentation:\n\n"
252
+ f"Original question: {original_question}\n\n"
253
+ "Symptom-focused medical question:"
254
+ ),
255
+ "treatment_focused": (
256
+ "You are a medical professional. Rewrite this question to focus more on treatment options, management strategies, and therapeutic approaches:\n\n"
257
+ f"Original question: {original_question}\n\n"
258
+ "Treatment-focused medical question:"
259
+ )
260
  }
261
  return prompts.get(strategy, f"Paraphrase this medical question: {original_question}")
262
 
 
353
  continue
354
 
355
  # 1) ALWAYS write the original (cleaned/style-standardised only)
356
+ # Enhanced medical accuracy validation (optimized for both cloud and local modes)
357
  if not A.validate_medical_accuracy(user, out, paraphraser):
358
  stats["medical_accuracy_failed"] = stats.get("medical_accuracy_failed", 0) + 1
359
  applied.append("medical_accuracy_flag")
 
377
 
378
  # Add clinical scenarios for enhanced diversity
379
  if opts.get("clinical_scenarios", True):
380
+ # Use dedicated method if available (both cloud and local modes now support this)
381
+ if hasattr(paraphraser, 'create_clinical_scenarios'):
382
+ clinical_scenarios = paraphraser.create_clinical_scenarios(user, out)
383
+ else:
384
+ clinical_scenarios = A.create_clinical_scenarios(user, out, paraphraser)
385
+
386
  for (scenario_q, scenario_a, scenario_tag) in clinical_scenarios:
387
  rid_scenario = f"{rid}-scenario{random.randint(1000,9999)}"
388
  _commit_row(writer, source, rid_scenario, "medical_dialogue", instr, scenario_q, scenario_a, opts, stats, [scenario_tag], dedupe_seen=dedupe_seen, translator=translator)