File size: 16,068 Bytes
80cb919
 
 
 
 
 
 
19d62ff
a89888b
e76f718
5dcfc82
 
80cb919
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e76f718
 
 
 
 
 
 
 
80cb919
 
e76f718
80cb919
 
 
88e7ced
80cb919
88e7ced
80cb919
 
e76f718
 
 
 
 
 
 
 
 
 
 
 
80cb919
 
 
 
 
 
a7fd3ba
80cb919
 
 
88e7ced
80cb919
 
 
88e7ced
80cb919
 
e76f718
 
 
 
 
 
 
 
 
 
 
 
a7fd3ba
 
80cb919
 
 
 
 
a7fd3ba
80cb919
 
 
a7fd3ba
 
 
80cb919
 
 
 
 
 
 
 
a7fd3ba
 
80cb919
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d46eb9
80cb919
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5dcfc82
 
 
 
 
 
 
 
 
80cb919
 
 
19d62ff
 
1d46eb9
80cb919
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d46eb9
80cb919
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5dcfc82
 
 
 
 
 
 
 
 
80cb919
 
 
 
 
19d62ff
 
1d46eb9
80cb919
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19d62ff
1d46eb9
19d62ff
80cb919
 
19d62ff
80cb919
 
 
 
19d62ff
 
 
 
1d46eb9
 
5dcfc82
f359dc2
 
 
 
1d46eb9
 
f359dc2
19d62ff
80cb919
 
 
 
 
 
 
 
 
 
 
1d46eb9
5dcfc82
e76f718
 
 
80cb919
 
 
 
 
 
 
 
 
 
 
e76f718
80cb919
 
 
1d46eb9
 
 
80cb919
 
 
 
1d46eb9
80cb919
 
 
 
 
1d46eb9
80cb919
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
# RAG-specific dataset processor
import json
import logging
import hashlib
import random
from typing import Dict, List, Tuple, Optional, Callable

from utils.schema import sft_row, rag_row
from utils.cloud_llm import NvidiaClient, KeyRotator
from utils.local_llm import MedAlpacaClient
from vi.processing import should_translate, translate_rag_row
from utils import augment as A

# Module-level logger for this processor. A StreamHandler is attached (and the
# level set to INFO) only when the logger has no handlers yet, so executing this
# module more than once does not stack duplicate handlers.
logger = logging.getLogger("rag_processor")
if len(logger.handlers) == 0:
    logger.addHandler(logging.StreamHandler())
    logger.setLevel(logging.INFO)

def _hash_id(*parts) -> str:
    """Generate a hash ID for RAG entries"""
    h = hashlib.sha256()
    for p in parts:
        h.update(str(p).encode("utf-8"))
    return h.hexdigest()[:16]

def _iter_json_or_jsonl(path: str):
    """Iterate over JSON or JSONL files"""
    with open(path, "r", encoding="utf-8") as f:
        first = f.read(1)
        f.seek(0)
        if first == "[":
            data = json.load(f)
            for obj in data:
                yield obj
        else:
            for line in f:
                line = line.strip()
                if line:
                    yield json.loads(line)

class RAGProcessor:
    """Processes medical datasets into RAG-specific QCA (Question, Context, Answer) format.

    Text cleanup and synthetic-context generation are delegated to one LLM
    client: a local ``MedAlpacaClient`` when ``is_local`` is true, otherwise an
    ``NvidiaClient`` built with ``KeyRotator("NVIDIA_API")``.
    """

    def __init__(self, nvidia_model: str, is_local: bool = False, hf_token: str = None):
        """Create the processor and the LLM client it will use.

        Args:
            nvidia_model: Model name passed to ``NvidiaClient`` (unused when ``is_local``).
            is_local: Use the local MedAlpaca client instead of the NVIDIA API.
            hf_token: Token forwarded to ``MedAlpacaClient`` when ``is_local``.
        """
        self.is_local = is_local
        if is_local:
            self.medalpaca_client = MedAlpacaClient(hf_token=hf_token)
            self.nvidia_client = None
        else:
            self.nvidia_client = NvidiaClient(KeyRotator("NVIDIA_API"), nvidia_model)
            self.medalpaca_client = None

    # ------------------------------------------------------------------ #
    # Small private helpers
    # ------------------------------------------------------------------ #
    def _generate(self, prompt: str, temperature: float, max_tokens: int):
        """Send ``prompt`` to the configured LLM client and return its raw reply."""
        if self.is_local and self.medalpaca_client:
            client = self.medalpaca_client
        else:
            client = self.nvidia_client
        return client.generate(prompt, temperature=temperature, max_tokens=max_tokens)

    @staticmethod
    def _resolve_paraphraser(paraphraser, opts):
        """Return ``paraphraser`` if given, else ``opts["paraphraser"]`` (or None)."""
        if paraphraser is not None:
            return paraphraser
        return (opts or {}).get("paraphraser")

    @staticmethod
    def _repair_answer(answer: str, paraphraser) -> str:
        """Return ``answer``, or a repaired version when ``A.is_invalid_response`` flags it.

        Flagged answers are retried through ``paraphraser`` (up to 3 times) when
        one is available, otherwise passed to ``A.clean_invalid_response``. A
        falsy result means the sample should be dropped.
        """
        if not A.is_invalid_response(answer):
            return answer
        if paraphraser:
            return A.retry_invalid_response(answer, paraphraser, max_retries=3)
        return A.clean_invalid_response(answer, "")

    @staticmethod
    def _dedupe_key(question: str, context: str, answer: str) -> str:
        """Content fingerprint used for de-duplication.

        A unit-separator character between fields prevents collisions such as
        ("ab", "c") vs ("a", "bc"). SHA-256 is used because md5 may be
        unavailable on FIPS-restricted builds.
        """
        payload = f"{question}\x1f{context}\x1f{answer}"
        return hashlib.sha256(payload.encode("utf-8")).hexdigest()

    # ------------------------------------------------------------------ #
    # Text transformation
    # ------------------------------------------------------------------ #
    def clean_conversational_content(self, text: str) -> str:
        """Remove conversational / non-medical content from ``text`` using the LLM.

        Inputs that are falsy or shorter than 10 characters after stripping are
        returned unchanged. The model is asked for 1-2 concise sentences suited
        to embeddings. If the call raises, or the model replies with nothing but
        whitespace, the original ``text`` is returned rather than an empty string.
        """
        if not text or len(text.strip()) < 10:
            return text
            
        prompt = f"""Clean the following text by removing conversational elements (greetings, pleasantries), non-medical small talk, and social interactions. Keep only medically relevant information while preserving clinical facts, symptoms, diagnoses, treatments, and medical advice. Maintain professional medical language. Return only cleaned medical content in 1-2 concise sentences suitable for dense retrieval embeddings. No lists, no headers, no introduction or commentary:

{text}"""

        try:
            cleaned = self._generate(
                prompt,
                temperature=0.1,
                max_tokens=min(1000, len(text) + 200)
            )
            # A blank reply must not wipe out the input.
            return (cleaned or "").strip() or text
        except Exception as e:
            logger.warning(f"[RAG] Error cleaning text: {e}")
            return text

    def generate_context_from_qa(self, question: str, answer: str) -> str:
        """Generate a short synthetic context for a question/answer pair.

        Returns the first line of the model's stripped reply, capped at 600
        characters, or ``""`` if either input is empty or the call raises.
        """
        if not question or not answer:
            return ""
            
        prompt = f"""Given a medical question and its answer, generate a brief relevant medical context that helps retrieval. Limit to 1–2 sentences, concise, avoid boilerplate, no enumerations. Return only the medical context without any introduction or commentary:

        Question: {question}

        Answer: {answer}"""

        try:
            context = self._generate(prompt, temperature=0.2, max_tokens=200)
            # Trim to a single short paragraph
            return (context or "").strip().split("\n")[0][:600]
        except Exception as e:
            logger.warning(f"[RAG] Error generating context: {e}")
            return ""

    def convert_to_qca_format(self, instruction: str, user_input: str, output: str) -> Tuple[str, str, str]:
        """Convert an SFT-style (instruction, input, output) triple to (question, context, answer).

        ``instruction`` is currently not used. Input and output are LLM-cleaned,
        hard-capped at 1200 characters, and the answer is further capped at 800.
        """
        # Clean the content to remove conversational elements
        cleaned_input = self.clean_conversational_content(user_input)
        cleaned_output = self.clean_conversational_content(output)
        # Hard caps for embedding friendliness
        cleaned_input = (cleaned_input or "")[:1200]
        cleaned_output = (cleaned_output or "")[:1200]

        question = self.extract_question(cleaned_input)
        context = self.extract_context(cleaned_input, question, cleaned_output)
        # Prefer short, direct answers
        answer = cleaned_output[:800]

        return question, context, answer

    def extract_question(self, user_input: str) -> str:
        """Pick the main question out of ``user_input``.

        Lookup order:
        1. The first line that starts with ``Question:`` or ``Q:``, with that
           prefix removed.
        2. The first line containing ``?`` that is longer than 10 characters.
        3. The first line longer than 10 characters.
        4. The whole input.

        Only the leading label is stripped; the same text later in the line is kept.
        """
        if not user_input:
            return ""

        lines = user_input.split('\n')
        for raw_line in lines:
            line = raw_line.strip()
            for prefix in ('Question:', 'Q:'):
                if line.startswith(prefix):
                    return line[len(prefix):].strip()
            if '?' in line and len(line) > 10:
                return line

        # If no clear question found, use the first meaningful line
        for raw_line in lines:
            line = raw_line.strip()
            if len(line) > 10:
                return line

        return user_input

    def extract_context(self, user_input: str, question: str, answer: str) -> str:
        """Extract context lines from ``user_input``, else synthesize one from Q/A.

        A line counts as context if it starts with ``Context:``, ``Background:``
        or ``Information:``. Lines over 50 characters that contain no ``?`` and
        do not start with ``Question:`` also count. Candidates are joined and
        LLM-cleaned, and kept if the result is longer than 20 characters.
        Otherwise ``generate_context_from_qa`` is tried. Returns ``""`` if
        neither yields anything.
        """
        context_candidates = []
        for raw_line in user_input.split('\n'):
            line = raw_line.strip()
            labelled = line.startswith(('Context:', 'Background:', 'Information:'))
            long_statement = len(line) > 50 and not line.startswith('Question:') and '?' not in line
            if labelled or long_statement:
                context_candidates.append(line)

        if context_candidates:
            context = self.clean_conversational_content(' '.join(context_candidates))
            if len(context) > 20:  # Ensure we have meaningful context
                return context

        # Generate synthetic context if none found
        if question and answer:
            synthetic_context = self.generate_context_from_qa(question, answer)
            if synthetic_context:
                return synthetic_context

        return ""

    # ------------------------------------------------------------------ #
    # Dataset processors
    # ------------------------------------------------------------------ #
    def _dialog_item_to_qca(self, source: str, index: int, obj: Dict, paraphraser) -> Optional[Tuple[str, str, str, str]]:
        """Map one dialogue record to ``(rid, question, context, answer)``, or None to skip it."""
        instr = str(obj.get("instruction") or "Answer the medical question based on the provided context.").strip()
        user = str(obj.get("input") or "").strip()
        out = str(obj.get("output") or "").strip()
        # Row id depends only on source, 1-based position and raw field lengths.
        rid = _hash_id(source, index, len(user), len(out))

        question, context, answer = self.convert_to_qca_format(instr, user, out)
        answer = self._repair_answer(answer, paraphraser)
        if not question or not answer:
            return None
        return rid, question, context, answer

    def process_medical_dialog(self, source: str, path: str, writer, sample_limit: Optional[int],
                               stats: Dict, progress_cb: Optional[Callable], dedupe_seen: set = None,
                               translator=None, opts=None, paraphraser=None) -> int:
        """Process medical dialogue datasets (``instruction``/``input``/``output`` records) into RAG format.

        Args:
            source: Dataset key, used in row ids and log/progress messages.
            path: JSON-array or JSONL file to read.
            writer: Object with a ``write(row)`` method.
            sample_limit: Stop after this many successfully converted records (falsy = no limit).
            stats: Mutable counters dict updated by ``_commit_rag_row``.
            progress_cb: Optional ``(fraction, message)`` callback.
            dedupe_seen: Optional set of content fingerprints shared across calls.
            translator: Optional translator passed to ``_commit_rag_row``.
            opts: Options dict. ``vietnamese_translation`` and ``paraphraser`` are read.
            paraphraser: Used to retry answers flagged as invalid. Defaults to
                ``opts["paraphraser"]`` when not given.

        Returns:
            Number of records converted (including ones dropped as duplicates).
        """
        paraphraser = self._resolve_paraphraser(paraphraser, opts)
        count = 0
        written = 0

        for i, obj in enumerate(_iter_json_or_jsonl(path), start=1):
            try:
                item = self._dialog_item_to_qca(source, i, obj, paraphraser)
                if item is not None:
                    rid, question, context, answer = item
                    # Commit the RAG-formatted row (QAC)
                    if self._commit_rag_row(writer, rid, question, context, answer,
                                            stats, dedupe_seen=dedupe_seen, translator=translator, opts=opts):
                        written += 1
                    count += 1
            except Exception as e:
                logger.warning(f"[RAG] {source} error processing item {i}: {e}")

            if sample_limit and count >= sample_limit:
                break
            # Progress is keyed on input rows, so report it even when this row was skipped.
            if progress_cb and i % 1000 == 0:
                progress_cb(min(0.9, 0.05 + i/200000), f"{source}: processed {i} rows for RAG")

        if progress_cb:
            progress_cb(0.92, f"{source} RAG processing done ({count})")

        logger.info(f"[RAG] {source} RAG processing done count={count} written={written}")
        return count

    def _pubmedqa_item_to_qca(self, v: Dict, paraphraser) -> Optional[Tuple[str, str, str]]:
        """Map one PubMedQA entry to ``(question, context, answer)``, or None to skip it."""
        q_raw = v.get("QUESTION") or ""
        ctx_list = v.get("CONTEXTS") or []
        long_ans_raw = v.get("LONG_ANSWER") or ""
        final_raw = v.get("final_decision") or ""

        question = str(q_raw).strip() if q_raw else ""
        if isinstance(ctx_list, list):
            context = "\n".join(str(ctx) for ctx in ctx_list).strip()
        else:
            context = str(ctx_list).strip()
        # Prefer the long answer; fall back to the short final decision.
        answer = str(long_ans_raw).strip() if long_ans_raw else str(final_raw).strip()

        if not question or not answer:
            return None

        question = self.clean_conversational_content(question)
        context = self.clean_conversational_content(context)
        answer = self.clean_conversational_content(answer)

        answer = self._repair_answer(answer, paraphraser)
        if not answer:  # repair failed (or nothing left): skip this sample
            return None

        # Generate context if missing
        if not context:
            context = self.generate_context_from_qa(question, answer)
        return question, context, answer

    def process_pubmedqa(self, source: str, path: str, writer, sample_limit: Optional[int],
                         stats: Dict, progress_cb: Optional[Callable], dedupe_seen: set = None,
                         translator=None, opts=None, paraphraser=None) -> int:
        """Process PubMedQA datasets (a JSON object keyed by id) into RAG format.

        Each entry's ``QUESTION``, ``CONTEXTS`` and ``LONG_ANSWER`` (falling back
        to ``final_decision``) fields are read. The entry key is used as the row id.

        Args:
            Same as ``process_medical_dialog``.

        Returns:
            Number of entries converted (including ones dropped as duplicates).
        """
        paraphraser = self._resolve_paraphraser(paraphraser, opts)
        with open(path, "r", encoding="utf-8-sig") as f:
            data = json.load(f)

        count = 0
        written = 0

        for k, v in data.items():
            try:
                item = self._pubmedqa_item_to_qca(v, paraphraser)
                if item is None:
                    continue
                question, context, answer = item
                rid = str(k)
                # Commit the RAG-formatted row (QAC)
                if self._commit_rag_row(writer, rid, question, context, answer,
                                        stats, dedupe_seen=dedupe_seen, translator=translator, opts=opts):
                    written += 1
                count += 1
            except Exception as e:
                logger.warning(f"[RAG] {source} error processing item {k}: {e}")
                continue

            if sample_limit and count >= sample_limit:
                break
            if progress_cb and count % 1000 == 0:
                progress_cb(min(0.9, 0.05 + count/60000), f"{source} RAG processed {count}")

        if progress_cb:
            progress_cb(0.93, f"{source} RAG processing done ({count})")

        logger.info(f"[RAG] {source} RAG processing done count={count} written={written}")
        return count

    def _commit_rag_row(self, writer, rid: str, question: str, context: str, answer: str,
                        stats: Dict, dedupe_seen: set = None, translator=None, opts=None) -> bool:
        """Write one QCA row, optionally de-duplicated and translated.

        Returns False (and bumps ``stats["dedup_skipped"]``) when the row's
        content fingerprint is already in ``dedupe_seen``. Otherwise the row
        is written, ``stats["written"]`` is incremented and True is returned.
        Translation failures are logged and the untranslated row is written.
        """
        if dedupe_seen is not None:
            content_hash = self._dedupe_key(question, context, answer)
            if content_hash in dedupe_seen:
                stats["dedup_skipped"] = stats.get("dedup_skipped", 0) + 1
                return False
            dedupe_seen.add(content_hash)

        row = rag_row(question=question, context=context, answer=answer, rid=rid)

        # Apply Vietnamese translation if requested (translate Q/A/C fields directly)
        if should_translate(opts.get("vietnamese_translation", False) if opts else False, translator):
            try:
                row = translate_rag_row(row, translator, ["question", "answer", "context"])
                if "meta" not in row:
                    row["meta"] = {}
                row["meta"]["vietnamese_translated"] = True
            except Exception as e:
                logger.error(f"Failed to translate RAG row: {e}")
                # Continue with original row if translation fails

        writer.write(row)
        stats["written"] = stats.get("written", 0) + 1
        return True

def process_file_into_rag(
    dataset_key: str,
    input_path: str,
    writer,
    nvidia_model: str,
    sample_limit: Optional[int],
    seed: int,
    progress_cb: Optional[Callable[[float, str], None]],
    translator=None,
    paraphraser=None,
    is_local: bool = False,
    hf_token: str = None
) -> Tuple[int, Dict]:
    """Main entry point for RAG processing of one dataset file.

    Args:
        dataset_key: Dataset name (case-insensitive). Supported values:
            ``healthcaremagic``, ``icliniq``, ``pubmedqa_l``, ``pubmedqa_u``,
            ``pubmedqa_map``.
        input_path: Path of the dataset file.
        writer: Object with a ``write(row)`` method receiving each RAG row.
        nvidia_model: Model name for the NVIDIA client (unused when ``is_local``).
        sample_limit: Maximum number of converted records (falsy = no limit).
        seed: Seed applied to the global ``random`` module.
        progress_cb: Optional ``(fraction, message)`` progress callback.
        translator: Optional translator. Its presence enables Vietnamese translation.
        paraphraser: Optional paraphraser for retrying invalid answers. It is
            passed to the processor through ``opts["paraphraser"]``.
        is_local: Use the local MedAlpaca model instead of the NVIDIA API.
        hf_token: Token forwarded to the local model client.

    Returns:
        ``(count, stats)``, where ``stats`` holds ``written`` and ``dedup_skipped`` counters.

    Raises:
        ValueError: If ``dataset_key`` is not a supported dataset. This is
            checked before any LLM client is constructed.
    """
    key = dataset_key.lower()
    if key in ("healthcaremagic", "icliniq"):
        handler_name = "process_medical_dialog"
    elif key in ("pubmedqa_l", "pubmedqa_u", "pubmedqa_map"):
        handler_name = "process_pubmedqa"
    else:
        raise ValueError(f"Unknown dataset for RAG processing: {dataset_key}")

    random.seed(seed)
    stats = {
        "written": 0,
        "dedup_skipped": 0
    }

    logger.info(f"[RAG] Begin RAG processing dataset={dataset_key} sample_limit={sample_limit}")

    # Initialize RAG processor
    rag_processor = RAGProcessor(nvidia_model, is_local=is_local, hf_token=hf_token)
    dedupe_seen = set()

    # Options for the processor. The paraphraser rides along in opts so that
    # processors which consult it can retry invalid answers. Others ignore it.
    opts = {"vietnamese_translation": translator is not None, "paraphraser": paraphraser}

    handler = getattr(rag_processor, handler_name)
    count = handler(
        source=key, path=input_path, writer=writer,
        sample_limit=sample_limit, stats=stats,
        progress_cb=progress_cb, dedupe_seen=dedupe_seen, translator=translator, opts=opts
    )

    logger.info(f"[RAG] End RAG processing dataset={dataset_key} stats={stats}")
    return count, stats