LiamKhoaLe committed on
Commit
a7fd3ba
·
1 Parent(s): 19d62ff

Enhance RAG conciseness and SFT aug

Browse files
Files changed (3) hide show
  1. utils/augment.py +16 -1
  2. utils/rag.py +12 -7
  3. vi/processing.py +34 -2
utils/augment.py CHANGED
@@ -1,5 +1,6 @@
1
  # augmentation utility agent
2
  import re
 
3
  import random
4
  from typing import Dict, Tuple
5
  import ftfy
@@ -94,7 +95,21 @@ def maybe_backtranslate(text: str, ratio: float, paraphraser) -> Tuple[str, bool
94
  if ratio <= 0 or not text: return text, False
95
  if random.random() < ratio:
96
  bt = paraphraser.backtranslate(text, via_lang="vi")
97
- return bt if bt else text, bool(bt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  return text, False
99
 
100
  def consistency_ok(user: str, out: str, ratio: float, paraphraser) -> bool:
 
1
  # augmentation utility agent
2
  import re
3
+ import difflib
4
  import random
5
  from typing import Dict, Tuple
6
  import ftfy
 
95
  if ratio <= 0 or not text: return text, False
96
  if random.random() < ratio:
97
  bt = paraphraser.backtranslate(text, via_lang="vi")
98
+ if not bt:
99
+ return text, False
100
+ # Guardrails: reject if too short/long or too dissimilar/similar
101
+ try:
102
+ orig_len = max(1, len(text))
103
+ len_delta = abs(len(bt) - len(text)) / orig_len
104
+ sim = difflib.SequenceMatcher(None, text, bt).ratio()
105
+ # Accept if moderate change and not excessive drift
106
+ if len_delta > 0.5:
107
+ return text, False
108
+ if sim < 0.45 or sim > 0.98:
109
+ return text, False
110
+ except Exception:
111
+ pass
112
+ return bt, True
113
  return text, False
114
 
115
  def consistency_ok(user: str, out: str, ratio: float, paraphraser) -> bool:
utils/rag.py CHANGED
@@ -44,7 +44,7 @@ class RAGProcessor:
44
  self.nvidia_client = NvidiaClient(KeyRotator("NVIDIA_API"), nvidia_model)
45
 
46
  def clean_conversational_content(self, text: str) -> str:
47
- """Remove conversational elements and non-medical information using NVIDIA model"""
48
  if not text or len(text.strip()) < 10:
49
  return text
50
 
@@ -55,7 +55,7 @@ class RAGProcessor:
55
  3. Keep only medically relevant information
56
  4. Preserve clinical facts, symptoms, diagnoses, treatments, and medical advice
57
  5. Maintain professional medical language
58
- 6. Return only cleaned medical content, only plain text, no special characters, or formatting.
59
 
60
  Text to clean:
61
  {text}
@@ -74,11 +74,11 @@ class RAGProcessor:
74
  return text
75
 
76
  def generate_context_from_qa(self, question: str, answer: str) -> str:
77
- """Generate synthetic context from question and answer using NVIDIA model"""
78
  if not question or not answer:
79
  return ""
80
 
81
- prompt = f"""You are a medical knowledge expert. Given a medical question and its answer, generate a brief relevant medical context that would help someone understand the answer better. Write about 2 sentences that provide relevant background information. Use only plain text without any formatting or symbols.
82
 
83
  Question: {question}
84
 
@@ -92,16 +92,20 @@ class RAGProcessor:
92
  temperature=0.2,
93
  max_tokens=200
94
  )
95
- return context.strip() if context else ""
 
96
  except Exception as e:
97
  logger.warning(f"[RAG] Error generating context: {e}")
98
  return ""
99
 
100
  def convert_to_qca_format(self, instruction: str, user_input: str, output: str) -> Tuple[str, str, str]:
101
- """Convert SFT format to QCA (Question, Context, Answer) format"""
102
  # Clean the content to remove conversational elements
103
  cleaned_input = self.clean_conversational_content(user_input)
104
  cleaned_output = self.clean_conversational_content(output)
 
 
 
105
 
106
  # Extract question from user input
107
  question = self.extract_question(cleaned_input)
@@ -110,7 +114,8 @@ class RAGProcessor:
110
  context = self.extract_context(cleaned_input, question, cleaned_output)
111
 
112
  # Clean answer
113
- answer = cleaned_output
 
114
 
115
  return question, context, answer
116
 
 
44
  self.nvidia_client = NvidiaClient(KeyRotator("NVIDIA_API"), nvidia_model)
45
 
46
  def clean_conversational_content(self, text: str) -> str:
47
+ """Remove conversational elements and non-medical information using NVIDIA model; keep concise for embeddings."""
48
  if not text or len(text.strip()) < 10:
49
  return text
50
 
 
55
  3. Keep only medically relevant information
56
  4. Preserve clinical facts, symptoms, diagnoses, treatments, and medical advice
57
  5. Maintain professional medical language
58
+ 6. Return only cleaned medical content in 1-2 concise sentences suitable for dense retrieval embeddings. No lists, no headers.
59
 
60
  Text to clean:
61
  {text}
 
74
  return text
75
 
76
  def generate_context_from_qa(self, question: str, answer: str) -> str:
77
+ """Generate synthetic, concise context (<=2 sentences) from question and answer, embedding-friendly."""
78
  if not question or not answer:
79
  return ""
80
 
81
+ prompt = f"""You are a medical knowledge expert. Given a medical question and its answer, generate a brief relevant medical context that helps retrieval. Limit to 1–2 sentences, concise, avoid boilerplate, no enumerations.
82
 
83
  Question: {question}
84
 
 
92
  temperature=0.2,
93
  max_tokens=200
94
  )
95
+ # Trim to a single short paragraph
96
+ return (context or "").strip().split("\n")[0][:600]
97
  except Exception as e:
98
  logger.warning(f"[RAG] Error generating context: {e}")
99
  return ""
100
 
101
  def convert_to_qca_format(self, instruction: str, user_input: str, output: str) -> Tuple[str, str, str]:
102
+ """Convert SFT format to QCA (Question, Context, Answer) format, compressing for embedding suitability."""
103
  # Clean the content to remove conversational elements
104
  cleaned_input = self.clean_conversational_content(user_input)
105
  cleaned_output = self.clean_conversational_content(output)
106
+ # Hard caps for embedding friendliness
107
+ cleaned_input = (cleaned_input or "")[:1200]
108
+ cleaned_output = (cleaned_output or "")[:1200]
109
 
110
  # Extract question from user input
111
  question = self.extract_question(cleaned_input)
 
114
  context = self.extract_context(cleaned_input, question, cleaned_output)
115
 
116
  # Clean answer
117
+ # Prefer short, direct answers
118
+ answer = cleaned_output[:800]
119
 
120
  return question, context, answer
121
 
vi/processing.py CHANGED
@@ -7,6 +7,30 @@ from typing import Dict, Any, List, Optional, Callable
7
 
8
  logger = logging.getLogger(__name__)
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  def translate_sft_row(row: Dict[str, Any], translator, text_fields: List[str] = None) -> Dict[str, Any]:
11
  """
12
  Translate specific text fields in an SFT row from English to Vietnamese.
@@ -29,6 +53,10 @@ def translate_sft_row(row: Dict[str, Any], translator, text_fields: List[str] =
29
 
30
  try:
31
  translated_row = translator.translate_dict(row, text_fields)
 
 
 
 
32
  logger.debug(f"Translated SFT row with fields: {text_fields}")
33
  return translated_row
34
  except Exception as e:
@@ -52,11 +80,15 @@ def translate_rag_row(row: Dict[str, Any], translator, text_fields: List[str] =
52
  return row
53
 
54
  if text_fields is None:
55
- # Default fields to translate in RAG format
56
- text_fields = ["instruction", "input", "output"]
57
 
58
  try:
59
  translated_row = translator.translate_dict(row, text_fields)
 
 
 
 
60
  logger.debug(f"Translated RAG row with fields: {text_fields}")
61
  return translated_row
62
  except Exception as e:
 
7
 
8
  logger = logging.getLogger(__name__)
9
 
10
+ def _vi_sanitize_text(s: str) -> str:
11
+ """Light Vietnamese sanitization for finetuning and RAG: strip extra spaces, limit repetition, preserve numbers/units."""
12
+ if not isinstance(s, str):
13
+ return s
14
+ t = s.strip()
15
+ # Collapse repeated punctuation and whitespace
16
+ import re
17
+ t = re.sub(r"\s+", " ", t)
18
+ t = re.sub(r"([.?!]){3,}", r"..", t)
19
+ # Remove obvious repetition chunks (very heuristic)
20
+ parts = t.split()
21
+ if len(parts) > 20:
22
+ window = 6
23
+ seen = set()
24
+ filtered = []
25
+ for i in range(len(parts)):
26
+ ngram = " ".join(parts[max(0, i-window):i+1])
27
+ if ngram in seen:
28
+ continue
29
+ seen.add(ngram)
30
+ filtered.append(parts[i])
31
+ t = " ".join(filtered)
32
+ return t
33
+
34
  def translate_sft_row(row: Dict[str, Any], translator, text_fields: List[str] = None) -> Dict[str, Any]:
35
  """
36
  Translate specific text fields in an SFT row from English to Vietnamese.
 
53
 
54
  try:
55
  translated_row = translator.translate_dict(row, text_fields)
56
+ # Sanitize translated fields
57
+ for f in text_fields:
58
+ if f in translated_row.get("sft", {}):
59
+ translated_row["sft"][f] = _vi_sanitize_text(translated_row["sft"][f])
60
  logger.debug(f"Translated SFT row with fields: {text_fields}")
61
  return translated_row
62
  except Exception as e:
 
80
  return row
81
 
82
  if text_fields is None:
83
+ # Default fields to translate in RAG format (Q, A, C)
84
+ text_fields = ["question", "answer", "context"]
85
 
86
  try:
87
  translated_row = translator.translate_dict(row, text_fields)
88
+ # Sanitize translated fields
89
+ for f in text_fields:
90
+ if f in translated_row:
91
+ translated_row[f] = _vi_sanitize_text(translated_row[f])
92
  logger.debug(f"Translated RAG row with fields: {text_fields}")
93
  return translated_row
94
  except Exception as e: