Spaces:
Sleeping
Sleeping
Commit
·
a7fd3ba
1
Parent(s):
19d62ff
Enhance RAG conciseness and SFT aug
Browse files- utils/augment.py +16 -1
- utils/rag.py +12 -7
- vi/processing.py +34 -2
utils/augment.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
# augmentation utility agent
|
| 2 |
import re
|
|
|
|
| 3 |
import random
|
| 4 |
from typing import Dict, Tuple
|
| 5 |
import ftfy
|
|
@@ -94,7 +95,21 @@ def maybe_backtranslate(text: str, ratio: float, paraphraser) -> Tuple[str, bool
|
|
| 94 |
if ratio <= 0 or not text: return text, False
|
| 95 |
if random.random() < ratio:
|
| 96 |
bt = paraphraser.backtranslate(text, via_lang="vi")
|
| 97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
return text, False
|
| 99 |
|
| 100 |
def consistency_ok(user: str, out: str, ratio: float, paraphraser) -> bool:
|
|
|
|
| 1 |
# augmentation utility agent
|
| 2 |
import re
|
| 3 |
+
import difflib
|
| 4 |
import random
|
| 5 |
from typing import Dict, Tuple
|
| 6 |
import ftfy
|
|
|
|
| 95 |
if ratio <= 0 or not text: return text, False
|
| 96 |
if random.random() < ratio:
|
| 97 |
bt = paraphraser.backtranslate(text, via_lang="vi")
|
| 98 |
+
if not bt:
|
| 99 |
+
return text, False
|
| 100 |
+
# Guardrails: reject if too short/long or too dissimilar/similar
|
| 101 |
+
try:
|
| 102 |
+
orig_len = max(1, len(text))
|
| 103 |
+
len_delta = abs(len(bt) - len(text)) / orig_len
|
| 104 |
+
sim = difflib.SequenceMatcher(None, text, bt).ratio()
|
| 105 |
+
# Accept if moderate change and not excessive drift
|
| 106 |
+
if len_delta > 0.5:
|
| 107 |
+
return text, False
|
| 108 |
+
if sim < 0.45 or sim > 0.98:
|
| 109 |
+
return text, False
|
| 110 |
+
except Exception:
|
| 111 |
+
pass
|
| 112 |
+
return bt, True
|
| 113 |
return text, False
|
| 114 |
|
| 115 |
def consistency_ok(user: str, out: str, ratio: float, paraphraser) -> bool:
|
utils/rag.py
CHANGED
|
@@ -44,7 +44,7 @@ class RAGProcessor:
|
|
| 44 |
self.nvidia_client = NvidiaClient(KeyRotator("NVIDIA_API"), nvidia_model)
|
| 45 |
|
| 46 |
def clean_conversational_content(self, text: str) -> str:
|
| 47 |
-
"""Remove conversational elements and non-medical information using NVIDIA model"""
|
| 48 |
if not text or len(text.strip()) < 10:
|
| 49 |
return text
|
| 50 |
|
|
@@ -55,7 +55,7 @@ class RAGProcessor:
|
|
| 55 |
3. Keep only medically relevant information
|
| 56 |
4. Preserve clinical facts, symptoms, diagnoses, treatments, and medical advice
|
| 57 |
5. Maintain professional medical language
|
| 58 |
-
6. Return only cleaned medical content
|
| 59 |
|
| 60 |
Text to clean:
|
| 61 |
{text}
|
|
@@ -74,11 +74,11 @@ class RAGProcessor:
|
|
| 74 |
return text
|
| 75 |
|
| 76 |
def generate_context_from_qa(self, question: str, answer: str) -> str:
|
| 77 |
-
"""Generate synthetic context from question and answer
|
| 78 |
if not question or not answer:
|
| 79 |
return ""
|
| 80 |
|
| 81 |
-
prompt = f"""You are a medical knowledge expert. Given a medical question and its answer, generate a brief relevant medical context that
|
| 82 |
|
| 83 |
Question: {question}
|
| 84 |
|
|
@@ -92,16 +92,20 @@ class RAGProcessor:
|
|
| 92 |
temperature=0.2,
|
| 93 |
max_tokens=200
|
| 94 |
)
|
| 95 |
-
|
|
|
|
| 96 |
except Exception as e:
|
| 97 |
logger.warning(f"[RAG] Error generating context: {e}")
|
| 98 |
return ""
|
| 99 |
|
| 100 |
def convert_to_qca_format(self, instruction: str, user_input: str, output: str) -> Tuple[str, str, str]:
|
| 101 |
-
"""Convert SFT format to QCA (Question, Context, Answer) format"""
|
| 102 |
# Clean the content to remove conversational elements
|
| 103 |
cleaned_input = self.clean_conversational_content(user_input)
|
| 104 |
cleaned_output = self.clean_conversational_content(output)
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
# Extract question from user input
|
| 107 |
question = self.extract_question(cleaned_input)
|
|
@@ -110,7 +114,8 @@ class RAGProcessor:
|
|
| 110 |
context = self.extract_context(cleaned_input, question, cleaned_output)
|
| 111 |
|
| 112 |
# Clean answer
|
| 113 |
-
|
|
|
|
| 114 |
|
| 115 |
return question, context, answer
|
| 116 |
|
|
|
|
| 44 |
self.nvidia_client = NvidiaClient(KeyRotator("NVIDIA_API"), nvidia_model)
|
| 45 |
|
| 46 |
def clean_conversational_content(self, text: str) -> str:
|
| 47 |
+
"""Remove conversational elements and non-medical information using NVIDIA model; keep concise for embeddings."""
|
| 48 |
if not text or len(text.strip()) < 10:
|
| 49 |
return text
|
| 50 |
|
|
|
|
| 55 |
3. Keep only medically relevant information
|
| 56 |
4. Preserve clinical facts, symptoms, diagnoses, treatments, and medical advice
|
| 57 |
5. Maintain professional medical language
|
| 58 |
+
6. Return only cleaned medical content in 1-2 concise sentences suitable for dense retrieval embeddings. No lists, no headers.
|
| 59 |
|
| 60 |
Text to clean:
|
| 61 |
{text}
|
|
|
|
| 74 |
return text
|
| 75 |
|
| 76 |
def generate_context_from_qa(self, question: str, answer: str) -> str:
|
| 77 |
+
"""Generate synthetic, concise context (<=2 sentences) from question and answer, embedding-friendly."""
|
| 78 |
if not question or not answer:
|
| 79 |
return ""
|
| 80 |
|
| 81 |
+
prompt = f"""You are a medical knowledge expert. Given a medical question and its answer, generate a brief relevant medical context that helps retrieval. Limit to 1–2 sentences, concise, avoid boilerplate, no enumerations.
|
| 82 |
|
| 83 |
Question: {question}
|
| 84 |
|
|
|
|
| 92 |
temperature=0.2,
|
| 93 |
max_tokens=200
|
| 94 |
)
|
| 95 |
+
# Trim to a single short paragraph
|
| 96 |
+
return (context or "").strip().split("\n")[0][:600]
|
| 97 |
except Exception as e:
|
| 98 |
logger.warning(f"[RAG] Error generating context: {e}")
|
| 99 |
return ""
|
| 100 |
|
| 101 |
def convert_to_qca_format(self, instruction: str, user_input: str, output: str) -> Tuple[str, str, str]:
|
| 102 |
+
"""Convert SFT format to QCA (Question, Context, Answer) format, compressing for embedding suitability."""
|
| 103 |
# Clean the content to remove conversational elements
|
| 104 |
cleaned_input = self.clean_conversational_content(user_input)
|
| 105 |
cleaned_output = self.clean_conversational_content(output)
|
| 106 |
+
# Hard caps for embedding friendliness
|
| 107 |
+
cleaned_input = (cleaned_input or "")[:1200]
|
| 108 |
+
cleaned_output = (cleaned_output or "")[:1200]
|
| 109 |
|
| 110 |
# Extract question from user input
|
| 111 |
question = self.extract_question(cleaned_input)
|
|
|
|
| 114 |
context = self.extract_context(cleaned_input, question, cleaned_output)
|
| 115 |
|
| 116 |
# Clean answer
|
| 117 |
+
# Prefer short, direct answers
|
| 118 |
+
answer = cleaned_output[:800]
|
| 119 |
|
| 120 |
return question, context, answer
|
| 121 |
|
vi/processing.py
CHANGED
|
@@ -7,6 +7,30 @@ from typing import Dict, Any, List, Optional, Callable
|
|
| 7 |
|
| 8 |
logger = logging.getLogger(__name__)
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
def translate_sft_row(row: Dict[str, Any], translator, text_fields: List[str] = None) -> Dict[str, Any]:
|
| 11 |
"""
|
| 12 |
Translate specific text fields in an SFT row from English to Vietnamese.
|
|
@@ -29,6 +53,10 @@ def translate_sft_row(row: Dict[str, Any], translator, text_fields: List[str] =
|
|
| 29 |
|
| 30 |
try:
|
| 31 |
translated_row = translator.translate_dict(row, text_fields)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
logger.debug(f"Translated SFT row with fields: {text_fields}")
|
| 33 |
return translated_row
|
| 34 |
except Exception as e:
|
|
@@ -52,11 +80,15 @@ def translate_rag_row(row: Dict[str, Any], translator, text_fields: List[str] =
|
|
| 52 |
return row
|
| 53 |
|
| 54 |
if text_fields is None:
|
| 55 |
-
# Default fields to translate in RAG format
|
| 56 |
-
text_fields = ["
|
| 57 |
|
| 58 |
try:
|
| 59 |
translated_row = translator.translate_dict(row, text_fields)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
logger.debug(f"Translated RAG row with fields: {text_fields}")
|
| 61 |
return translated_row
|
| 62 |
except Exception as e:
|
|
|
|
| 7 |
|
| 8 |
logger = logging.getLogger(__name__)
|
| 9 |
|
| 10 |
+
def _vi_sanitize_text(s: str) -> str:
|
| 11 |
+
"""Light Vietnamese sanitization for finetuning and RAG: strip extra spaces, limit repetition, preserve numbers/units."""
|
| 12 |
+
if not isinstance(s, str):
|
| 13 |
+
return s
|
| 14 |
+
t = s.strip()
|
| 15 |
+
# Collapse repeated punctuation and whitespace
|
| 16 |
+
import re
|
| 17 |
+
t = re.sub(r"\s+", " ", t)
|
| 18 |
+
t = re.sub(r"([.?!]){3,}", r"..", t)
|
| 19 |
+
# Remove obvious repetition chunks (very heuristic)
|
| 20 |
+
parts = t.split()
|
| 21 |
+
if len(parts) > 20:
|
| 22 |
+
window = 6
|
| 23 |
+
seen = set()
|
| 24 |
+
filtered = []
|
| 25 |
+
for i in range(len(parts)):
|
| 26 |
+
ngram = " ".join(parts[max(0, i-window):i+1])
|
| 27 |
+
if ngram in seen:
|
| 28 |
+
continue
|
| 29 |
+
seen.add(ngram)
|
| 30 |
+
filtered.append(parts[i])
|
| 31 |
+
t = " ".join(filtered)
|
| 32 |
+
return t
|
| 33 |
+
|
| 34 |
def translate_sft_row(row: Dict[str, Any], translator, text_fields: List[str] = None) -> Dict[str, Any]:
|
| 35 |
"""
|
| 36 |
Translate specific text fields in an SFT row from English to Vietnamese.
|
|
|
|
| 53 |
|
| 54 |
try:
|
| 55 |
translated_row = translator.translate_dict(row, text_fields)
|
| 56 |
+
# Sanitize translated fields
|
| 57 |
+
for f in text_fields:
|
| 58 |
+
if f in translated_row.get("sft", {}):
|
| 59 |
+
translated_row["sft"][f] = _vi_sanitize_text(translated_row["sft"][f])
|
| 60 |
logger.debug(f"Translated SFT row with fields: {text_fields}")
|
| 61 |
return translated_row
|
| 62 |
except Exception as e:
|
|
|
|
| 80 |
return row
|
| 81 |
|
| 82 |
if text_fields is None:
|
| 83 |
+
# Default fields to translate in RAG format (Q, A, C)
|
| 84 |
+
text_fields = ["question", "answer", "context"]
|
| 85 |
|
| 86 |
try:
|
| 87 |
translated_row = translator.translate_dict(row, text_fields)
|
| 88 |
+
# Sanitize translated fields
|
| 89 |
+
for f in text_fields:
|
| 90 |
+
if f in translated_row:
|
| 91 |
+
translated_row[f] = _vi_sanitize_text(translated_row[f])
|
| 92 |
logger.debug(f"Translated RAG row with fields: {text_fields}")
|
| 93 |
return translated_row
|
| 94 |
except Exception as e:
|