LiamKhoaLe committed
Commit e76f718 · 1 Parent(s): dca816b

Upd local llm infer

Files changed (4)
  1. app.py +3 -1
  2. utils/augment.py +40 -27
  3. utils/local_llm.py +234 -29
  4. utils/rag.py +38 -15
app.py CHANGED
@@ -456,7 +456,9 @@ def _run_job(dataset_key: str, params: ProcessParams):
             seed=params.seed,
             progress_cb=lambda p, msg=None: set_state(progress=p, message=msg or STATE["message"]),
             translator=translator,
-            paraphraser=paraphraser
+            paraphraser=paraphraser,
+            is_local=IS_LOCAL,
+            hf_token=os.getenv("HF_TOKEN")
         )
     else:
         # Standard SFT processing mode
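Editor's note: a minimal, hypothetical sketch of how the two new arguments are presumably derived in app.py. The name IS_LOCAL appears in the diff; the "LOCAL_MODE" environment variable that feeds it is an assumption for illustration only.

import os

IS_LOCAL = os.getenv("LOCAL_MODE", "0") == "1"   # assumed env-var convention, not confirmed by the diff
HF_TOKEN = os.getenv("HF_TOKEN")                 # token for gated model access, as read at the call site

print(f"local inference: {IS_LOCAL}, token present: {HF_TOKEN is not None}")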
utils/augment.py CHANGED
@@ -252,8 +252,11 @@ def validate_medical_accuracy(question: str, answer: str, paraphraser) -> bool:
         return False
 
     try:
-        # Use the existing consistency check but with medical focus
-        return paraphraser.consistency_check(question, answer)
+        # Use medical accuracy check if available (local mode), otherwise fallback to consistency check
+        if hasattr(paraphraser, 'medical_accuracy_check'):
+            return paraphraser.medical_accuracy_check(question, answer)
+        else:
+            return paraphraser.consistency_check(question, answer)
     except Exception as e:
         logger.warning(f"Medical accuracy validation failed: {e}")
         return True  # Default to accepting if validation fails
@@ -264,15 +267,21 @@ def enhance_medical_terminology(text: str, paraphraser) -> str:
         return text
 
     try:
-        prompt = (
-            "Improve the medical terminology in this text while preserving all factual information:\n\n"
-            f"{text}\n\n"
-            "Return only the improved text with better medical terminology:"
-        )
-
-        enhanced = paraphraser.paraphrase(text, difficulty="hard", custom_prompt=prompt)
-        if enhanced and not is_invalid_response(enhanced):
-            return enhanced
+        # Use dedicated method if available (local mode), otherwise use paraphrase with custom prompt
+        if hasattr(paraphraser, 'enhance_medical_terminology'):
+            enhanced = paraphraser.enhance_medical_terminology(text)
+            if enhanced and not is_invalid_response(enhanced):
+                return enhanced
+        else:
+            prompt = (
+                "Improve the medical terminology in this text while preserving all factual information:\n\n"
+                f"{text}\n\n"
+                "Return only the improved text with better medical terminology:"
+            )
+
+            enhanced = paraphraser.paraphrase(text, difficulty="hard", custom_prompt=prompt)
+            if enhanced and not is_invalid_response(enhanced):
+                return enhanced
     except Exception as e:
         logger.warning(f"Medical terminology enhancement failed: {e}")
 
@@ -283,22 +292,26 @@ def create_clinical_scenarios(question: str, answer: str, paraphraser) -> list:
     scenarios = []
 
     try:
-        # Generate different clinical contexts
-        context_prompts = [
-            f"Rewrite this medical question as if asked by a patient in an emergency room:\n\n{question}",
-            f"Rewrite this medical question as if asked by a patient in a routine checkup:\n\n{question}",
-            f"Rewrite this medical question as if asked by a patient with chronic conditions:\n\n{question}",
-            f"Rewrite this medical question as if asked by a patient's family member:\n\n{question}"
-        ]
-
-        for i, prompt in enumerate(context_prompts):
-            try:
-                scenario_question = paraphraser.paraphrase(question, difficulty="hard", custom_prompt=prompt)
-                if scenario_question and not is_invalid_response(scenario_question):
-                    scenarios.append((scenario_question, answer, f"clinical_scenario_{i+1}"))
-            except Exception as e:
-                logger.warning(f"Failed to create clinical scenario {i+1}: {e}")
-                continue
+        # Use dedicated method if available (local mode), otherwise use paraphrase with custom prompts
+        if hasattr(paraphraser, 'create_clinical_scenarios'):
+            scenarios = paraphraser.create_clinical_scenarios(question, answer)
+        else:
+            # Fallback to original implementation
+            context_prompts = [
+                f"Rewrite this medical question as if asked by a patient in an emergency room:\n\n{question}",
+                f"Rewrite this medical question as if asked by a patient in a routine checkup:\n\n{question}",
+                f"Rewrite this medical question as if asked by a patient with chronic conditions:\n\n{question}",
+                f"Rewrite this medical question as if asked by a patient's family member:\n\n{question}"
+            ]
+
+            for i, prompt in enumerate(context_prompts):
+                try:
+                    scenario_question = paraphraser.paraphrase(question, difficulty="hard", custom_prompt=prompt)
+                    if scenario_question and not is_invalid_response(scenario_question):
+                        scenarios.append((scenario_question, answer, f"clinical_scenario_{i+1}"))
+                except Exception as e:
+                    logger.warning(f"Failed to create clinical scenario {i+1}: {e}")
+                    continue
 
     except Exception as e:
         logger.warning(f"Clinical scenario creation failed: {e}")
utils/local_llm.py CHANGED
@@ -94,16 +94,20 @@ class MedAlpacaClient:
                 max_length=2048
             ).to(self.device)
 
-            # Generate
+            # Generate with optimized parameters for MedAlpaca
             with torch.no_grad():
                 outputs = self.model.generate(
                     **inputs,
                     max_new_tokens=max_tokens,
                     temperature=temperature,
-                    do_sample=True,
+                    do_sample=True if temperature > 0 else False,
                     pad_token_id=self.tokenizer.eos_token_id,
                     eos_token_id=self.tokenizer.eos_token_id,
-                    repetition_penalty=1.1
+                    repetition_penalty=1.1,
+                    top_p=0.9 if temperature > 0 else 1.0,
+                    top_k=50 if temperature > 0 else 0,
+                    num_beams=1 if temperature > 0 else 4,
+                    early_stopping=True
                 )
 
             # Decode output
@@ -123,28 +127,36 @@ class MedAlpacaClient:
             return None
 
     def _format_prompt(self, prompt: str) -> str:
-        """Format prompt for MedAlpaca model"""
-        # MedAlpaca uses a specific format for medical Q&A
+        """Format prompt for MedAlpaca model with medical-specific formatting"""
+        # MedAlpaca was trained on medical Q&A pairs, so we use its expected format
         if "Question:" in prompt and "Answer:" in prompt:
             return prompt
         elif "Context:" in prompt and "Question:" in prompt:
             return prompt
+        elif "You are a" in prompt or "medical" in prompt.lower():
+            # For medical instructions, use Alpaca format
+            return f"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{prompt}\n\n### Response:"
         else:
-            # Simple medical Q&A format
-            return f"Question: {prompt}\n\nAnswer:"
+            # Default medical Q&A format for MedAlpaca
+            return f"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nAnswer the following medical question accurately and professionally.\n\n### Input:\n{prompt}\n\n### Response:"
 
     def _clean_response(self, text: str) -> str:
-        """Clean generated response"""
+        """Clean generated response with medical-specific cleaning"""
         if not text:
             return text
 
-        # Remove common prefixes
+        # Remove common prefixes and Alpaca format artifacts
        prefixes_to_remove = [
             "Answer:",
             "The answer is:",
             "Based on the information provided:",
             "Here's the answer:",
             "Here is the answer:",
+            "### Response:",
+            "Response:",
+            "Below is an instruction",
+            "### Instruction:",
+            "Instruction:",
         ]
 
         text = text.strip()
@@ -152,7 +164,13 @@ class MedAlpacaClient:
             if text.startswith(prefix):
                 text = text[len(prefix):].strip()
                 break
-
+
+        # Remove any remaining Alpaca format artifacts
+        if "### Response:" in text:
+            text = text.split("### Response:")[-1].strip()
+        if "### Input:" in text:
+            text = text.split("### Input:")[0].strip()
+
         return text
 
     def _snip(self, text: str, max_words: int = 12) -> str:
@@ -162,6 +180,61 @@ class MedAlpacaClient:
         words = text.strip().split()
         return " ".join(words[:max_words]) + (" …" if len(words) > max_words else "")
 
+    def generate_batch(self, prompts: list, max_tokens: int = 512, temperature: float = 0.2) -> list:
+        """Generate text for multiple prompts in batch for better efficiency"""
+        if not self.is_loaded:
+            self.load_model()
+
+        if not prompts:
+            return []
+
+        try:
+            # Format all prompts
+            formatted_prompts = [self._format_prompt(prompt) for prompt in prompts]
+
+            # Tokenize all inputs
+            inputs = self.tokenizer(
+                formatted_prompts,
+                return_tensors="pt",
+                padding=True,
+                truncation=True,
+                max_length=2048
+            ).to(self.device)
+
+            # Generate for all prompts
+            with torch.no_grad():
+                outputs = self.model.generate(
+                    **inputs,
+                    max_new_tokens=max_tokens,
+                    temperature=temperature,
+                    do_sample=True if temperature > 0 else False,
+                    pad_token_id=self.tokenizer.eos_token_id,
+                    eos_token_id=self.tokenizer.eos_token_id,
+                    repetition_penalty=1.1,
+                    top_p=0.9 if temperature > 0 else 1.0,
+                    top_k=50 if temperature > 0 else 0,
+                    num_beams=1 if temperature > 0 else 4,
+                    early_stopping=True
+                )
+
+            # Decode all outputs
+            results = []
+            input_length = inputs['input_ids'].shape[1]
+            for i, output in enumerate(outputs):
+                generated_text = self.tokenizer.decode(
+                    output[input_length:],
+                    skip_special_tokens=True
+                ).strip()
+                cleaned_text = self._clean_response(generated_text)
+                results.append(cleaned_text)
+
+            logger.info(f"[LOCAL_LLM] Generated batch of {len(prompts)} texts")
+            return results
+
+        except Exception as e:
+            logger.error(f"[LOCAL_LLM] Batch generation failed: {e}")
+            return [None] * len(prompts)
+
     def unload_model(self):
         """Unload model to free memory"""
         if self.model is not None:
@@ -185,34 +258,56 @@ class LocalParaphraser:
         self.client = MedAlpacaClient(model_name, hf_token)
 
     def paraphrase(self, text: str, difficulty: str = "easy", custom_prompt: str = None) -> str:
-        """Paraphrase text using MedAlpaca"""
+        """Paraphrase text using MedAlpaca with medical-specific optimization"""
         if not text or len(text) < 12:
             return text
 
         if custom_prompt:
             prompt = custom_prompt
         else:
-            prompt = (
-                "Paraphrase the following medical text concisely, preserve meaning and clinical terms.\n"
-                "Do not fabricate or remove factual claims.\n"
-                "Return ONLY the rewritten text, without any introduction, commentary.\n\n"
-                f"Original text: {text}"
-            )
+            # Medical-specific paraphrasing prompts based on difficulty
+            if difficulty == "easy":
+                prompt = (
+                    "You are a medical professional. Rewrite the following medical text using different words while preserving all medical facts, clinical terms, and meaning. Keep the same level of detail and accuracy.\n\n"
+                    f"Original medical text: {text}\n\n"
+                    "Rewritten medical text:"
+                )
+            else: # hard difficulty
+                prompt = (
+                    "You are a medical expert. Rewrite the following medical text using more sophisticated medical language and different sentence structures while preserving all clinical facts, medical terminology, and diagnostic information. Maintain professional medical tone.\n\n"
+                    f"Original medical text: {text}\n\n"
+                    "Enhanced medical text:"
+                )
 
-        result = self.client.generate(prompt, max_tokens=min(600, max(128, len(text)//2)), temperature=0.1)
+        # Adjust temperature based on difficulty
+        temperature = 0.1 if difficulty == "easy" else 0.3
+        result = self.client.generate(prompt, max_tokens=min(600, max(128, len(text)//2)), temperature=temperature)
         return result if result else text
 
     def translate(self, text: str, target_lang: str = "vi") -> Optional[str]:
-        """Translate text using MedAlpaca"""
+        """Translate text using MedAlpaca with medical terminology preservation"""
         if not text:
             return text
 
-        prompt = f"Translate the following medical text to {target_lang}. Keep meaning exact, preserve medical terms:\n\n{text}"
+        # Medical-specific translation prompt
+        if target_lang == "vi":
+            prompt = (
+                "You are a medical translator. Translate the following English medical text to Vietnamese while preserving all medical terminology, clinical facts, and professional medical language. Use appropriate Vietnamese medical terms.\n\n"
+                f"English medical text: {text}\n\n"
+                "Vietnamese medical translation:"
+            )
+        else:
+            prompt = (
+                f"You are a medical translator. Translate the following medical text to {target_lang} while preserving all medical terminology, clinical facts, and professional medical language.\n\n"
+                f"Original medical text: {text}\n\n"
+                f"{target_lang} medical translation:"
+            )
+
         result = self.client.generate(prompt, max_tokens=min(800, len(text)+100), temperature=0.0)
         return result.strip() if result else None
 
     def backtranslate(self, text: str, via_lang: str = "vi") -> Optional[str]:
-        """Backtranslate text using MedAlpaca"""
+        """Backtranslate text using MedAlpaca with medical accuracy"""
         if not text:
             return text
 
@@ -221,23 +316,133 @@ class LocalParaphraser:
         if not translated:
             return None
 
-        # Then translate back to English
-        prompt = f"Translate the following {via_lang} text back to English, preserving the exact meaning:\n\n{translated}"
+        # Then translate back to English with medical focus
+        if via_lang == "vi":
+            prompt = (
+                "You are a medical translator. Translate the following Vietnamese medical text back to English while preserving all medical terminology, clinical facts, and professional medical language. Ensure the translation is medically accurate.\n\n"
+                f"Vietnamese medical text: {translated}\n\n"
+                "English medical translation:"
+            )
+        else:
+            prompt = (
+                f"You are a medical translator. Translate the following {via_lang} medical text back to English while preserving all medical terminology, clinical facts, and professional medical language.\n\n"
+                f"{via_lang} medical text: {translated}\n\n"
+                "English medical translation:"
+            )
+
         result = self.client.generate(prompt, max_tokens=min(900, len(text)+150), temperature=0.0)
         return result.strip() if result else None
 
     def consistency_check(self, user: str, output: str) -> bool:
-        """Check consistency using MedAlpaca"""
+        """Check consistency using MedAlpaca with medical validation focus"""
         prompt = (
-            "You are a strict medical QA validator. Given the USER input (question+context) "
-            "and the MODEL ANSWER, reply with exactly 'PASS' if the answer is supported and safe, "
-            "otherwise 'FAIL'. No extra text.\n\n"
-            f"USER:\n{user}\n\nANSWER:\n{output}"
+            "You are a medical quality assurance expert. Evaluate if the medical answer is consistent with the question/context and medically accurate. Consider:\n"
+            "1. Medical accuracy and clinical appropriateness\n"
+            "2. Consistency with the question asked\n"
+            "3. Safety and professional medical standards\n"
+            "4. Completeness of the medical information\n\n"
+            "Reply with exactly 'PASS' if the answer is medically sound and consistent, otherwise 'FAIL'.\n\n"
+            f"Question/Context: {user}\n\n"
+            f"Medical Answer: {output}\n\n"
+            "Evaluation:"
         )
 
-        result = self.client.generate(prompt, max_tokens=3, temperature=0.0)
+        result = self.client.generate(prompt, max_tokens=5, temperature=0.0)
         return isinstance(result, str) and "PASS" in result.upper()
 
+    def medical_accuracy_check(self, question: str, answer: str) -> bool:
+        """Check medical accuracy of Q&A pairs using MedAlpaca"""
+        if not question or not answer:
+            return False
+
+        prompt = (
+            "You are a medical accuracy validator. Evaluate if the medical answer is accurate and appropriate for the question. Consider:\n"
+            "1. Medical facts and clinical knowledge\n"
+            "2. Appropriate medical terminology\n"
+            "3. Clinical reasoning and logic\n"
+            "4. Safety considerations\n\n"
+            "Reply with exactly 'ACCURATE' if the answer is medically correct, otherwise 'INACCURATE'.\n\n"
+            f"Medical Question: {question}\n\n"
+            f"Medical Answer: {answer}\n\n"
+            "Medical Accuracy Assessment:"
+        )
+
+        result = self.client.generate(prompt, max_tokens=5, temperature=0.0)
+        return isinstance(result, str) and "ACCURATE" in result.upper()
+
+    def enhance_medical_terminology(self, text: str) -> str:
+        """Enhance medical terminology in text using MedAlpaca"""
+        if not text or len(text) < 20:
+            return text
+
+        prompt = (
+            "You are a medical terminology expert. Improve the medical terminology in the following text while preserving all factual information and clinical accuracy. Use more precise medical terms where appropriate.\n\n"
+            f"Original text: {text}\n\n"
+            "Enhanced medical text:"
+        )
+
+        result = self.client.generate(prompt, max_tokens=min(800, len(text)+100), temperature=0.1)
+        return result if result else text
+
+    def create_clinical_scenarios(self, question: str, answer: str) -> list:
+        """Create different clinical scenarios from Q&A pairs using MedAlpaca"""
+        scenarios = []
+
+        # Different clinical context prompts
+        context_prompts = [
+            (
+                "Rewrite this medical question as if asked by a patient in an emergency room setting:",
+                "emergency_room"
+            ),
+            (
+                "Rewrite this medical question as if asked by a patient during a routine checkup:",
+                "routine_checkup"
+            ),
+            (
+                "Rewrite this medical question as if asked by a patient with chronic conditions:",
+                "chronic_care"
+            ),
+            (
+                "Rewrite this medical question as if asked by a patient's family member:",
+                "family_inquiry"
+            )
+        ]
+
+        for prompt_template, scenario_type in context_prompts:
+            try:
+                prompt = f"{prompt_template}\n\nOriginal question: {question}\n\nRewritten question:"
+                scenario_question = self.client.generate(prompt, max_tokens=min(400, len(question)+50), temperature=0.2)
+
+                if scenario_question and not self._is_invalid_response(scenario_question):
+                    scenarios.append((scenario_question, answer, scenario_type))
+            except Exception as e:
+                logger.warning(f"Failed to create clinical scenario {scenario_type}: {e}")
+                continue
+
+        return scenarios
+
+    def _is_invalid_response(self, text: str) -> bool:
+        """Check if response is invalid (similar to augment.py)"""
+        if not text or not isinstance(text, str):
+            return True
+
+        text_lower = text.lower().strip()
+        invalid_patterns = [
+            "fail", "invalid", "i couldn't", "i can't", "i cannot", "unable to",
+            "sorry", "error", "not available", "no answer", "insufficient",
+            "don't know", "do not know", "not sure", "cannot determine",
+            "unable to provide", "not possible", "not applicable", "n/a"
+        ]
+
+        if len(text_lower) < 3:
+            return True
+
+        for pattern in invalid_patterns:
+            if pattern in text_lower:
+                return True
+
+        return False
+
     def unload(self):
         """Unload the model"""
         self.client.unload_model()
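Editor's note: a hedged usage sketch for the new generate_batch API. Running this actually loads the MedAlpaca weights, so it assumes a host with enough memory and an HF_TOKEN that can access the model; the prompts are illustrative, and the constructor keyword follows the MedAlpacaClient(hf_token=...) call shown in the rag.py diff below.

import os
from utils.local_llm import MedAlpacaClient

client = MedAlpacaClient(hf_token=os.getenv("HF_TOKEN"))
prompts = [
    "What are common symptoms of iron-deficiency anemia?",
    "How is type 2 diabetes usually diagnosed?",
]
# Per the generation parameters in the diff above, temperature=0.0 disables
# sampling and switches to 4-beam search; temperature>0 samples with
# top_p=0.9 and top_k=50 instead.
answers = client.generate_batch(prompts, max_tokens=256, temperature=0.0)
for q, a in zip(prompts, answers):
    print(q, "->", a)
client.unload_model()  # free GPU/CPU memory when done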
utils/rag.py CHANGED
@@ -7,6 +7,7 @@ from typing import Dict, List, Tuple, Optional, Callable
 
 from utils.schema import sft_row, rag_row
 from utils.cloud_llm import NvidiaClient, KeyRotator
+from utils.local_llm import MedAlpacaClient
 from vi.processing import should_translate, translate_rag_row
 from utils import augment as A
 
@@ -41,11 +42,17 @@ def _iter_json_or_jsonl(path: str):
 class RAGProcessor:
     """Processes medical datasets into RAG-specific QCA (Question, Context, Answer) format"""
 
-    def __init__(self, nvidia_model: str):
-        self.nvidia_client = NvidiaClient(KeyRotator("NVIDIA_API"), nvidia_model)
+    def __init__(self, nvidia_model: str, is_local: bool = False, hf_token: str = None):
+        self.is_local = is_local
+        if is_local:
+            self.medalpaca_client = MedAlpacaClient(hf_token=hf_token)
+            self.nvidia_client = None
+        else:
+            self.nvidia_client = NvidiaClient(KeyRotator("NVIDIA_API"), nvidia_model)
+            self.medalpaca_client = None
 
     def clean_conversational_content(self, text: str) -> str:
-        """Remove conversational elements and non-medical information using NVIDIA model; keep concise for embeddings."""
+        """Remove conversational elements and non-medical information using MedAlpaca or NVIDIA model; keep concise for embeddings."""
         if not text or len(text.strip()) < 10:
             return text
 
@@ -64,11 +71,18 @@ class RAGProcessor:
         Cleaned medical content:"""
 
         try:
-            cleaned = self.nvidia_client.generate(
-                prompt,
-                temperature=0.1,
-                max_tokens=min(1000, len(text) + 200)
-            )
+            if self.is_local and self.medalpaca_client:
+                cleaned = self.medalpaca_client.generate(
+                    prompt,
+                    temperature=0.1,
+                    max_tokens=min(1000, len(text) + 200)
+                )
+            else:
+                cleaned = self.nvidia_client.generate(
+                    prompt,
+                    temperature=0.1,
+                    max_tokens=min(1000, len(text) + 200)
+                )
             return cleaned.strip() if cleaned else text
         except Exception as e:
             logger.warning(f"[RAG] Error cleaning text: {e}")
@@ -88,11 +102,18 @@ class RAGProcessor:
         Generate a concise medical context:"""
 
         try:
-            context = self.nvidia_client.generate(
-                prompt,
-                temperature=0.2,
-                max_tokens=200
-            )
+            if self.is_local and self.medalpaca_client:
+                context = self.medalpaca_client.generate(
+                    prompt,
+                    temperature=0.2,
+                    max_tokens=200
+                )
+            else:
+                context = self.nvidia_client.generate(
+                    prompt,
+                    temperature=0.2,
+                    max_tokens=200
+                )
             # Trim to a single short paragraph
             return (context or "").strip().split("\n")[0][:600]
         except Exception as e:
@@ -330,7 +351,9 @@ def process_file_into_rag(
     seed: int,
     progress_cb: Optional[Callable[[float, str], None]],
     translator=None,
-    paraphraser=None
+    paraphraser=None,
+    is_local: bool = False,
+    hf_token: str = None
 ) -> Tuple[int, Dict]:
     """Main entry point for RAG processing"""
     random.seed(seed)
@@ -342,7 +365,7 @@ def process_file_into_rag(
     logger.info(f"[RAG] Begin RAG processing dataset={dataset_key} sample_limit={sample_limit}")
 
     # Initialize RAG processor
-    rag_processor = RAGProcessor(nvidia_model)
+    rag_processor = RAGProcessor(nvidia_model, is_local=is_local, hf_token=hf_token)
     dedupe_seen = set()
 
     key = dataset_key.lower()
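Editor's note: a minimal sketch of selecting the backend when constructing RAGProcessor, assuming the flag convention from app.py. The cloud model identifier is hypothetical; in local mode the diff shows that nvidia_model is ignored and a MedAlpacaClient is built instead.

import os
from utils.rag import RAGProcessor

if os.getenv("IS_LOCAL") == "1":  # assumed flag convention, mirrors app.py's IS_LOCAL
    processor = RAGProcessor("unused-model-name", is_local=True,
                             hf_token=os.getenv("HF_TOKEN"))
else:
    processor = RAGProcessor("meta/llama3-70b-instruct")  # hypothetical NVIDIA model id

noisy = "hi doc!! so um, my blood sugar was 180 after lunch, is that bad??"
print(processor.clean_conversational_content(noisy))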