File size: 15,295 Bytes
80cb919
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e138b0e
80cb919
 
 
 
 
 
 
 
e138b0e
80cb919
 
 
e138b0e
 
 
80cb919
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b0a3faf
80cb919
 
b0a3faf
d668aec
b0a3faf
 
 
d668aec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b0a3faf
d668aec
 
e138b0e
 
 
d668aec
 
e138b0e
 
 
 
80cb919
 
19d62ff
80cb919
d668aec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80cb919
 
 
 
19d62ff
80cb919
 
 
d668aec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80cb919
 
 
 
 
 
 
d668aec
80cb919
d668aec
 
 
 
 
 
 
 
 
80cb919
d668aec
80cb919
d668aec
80cb919
d668aec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
# Round-robin rotator + paraphrasing + translation/backtranslation
import logging
import os
import re
from typing import Optional

import requests
from google import genai

logger = logging.getLogger("llm")
# Attach a stderr StreamHandler (and INFO level) only when the logger has no
# handler yet, so running this setup more than once does not duplicate output.
if not logger.handlers:
    handler = logging.StreamHandler()
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)

# LLM parser limit text to log-out
def snip(s: str, n: int = 12) -> str:
    if not isinstance(s, str): return "∅"
    parts = s.strip().split()
    return " ".join(parts[:n]) + (" …" if len(parts) > n else "")

class KeyRotator:
    """Round-robin pool of API keys read from ``{env_prefix}_1`` .. ``{env_prefix}_{max_keys}``.

    Keys passed to :meth:`mark_bad` are skipped by :meth:`next_key` from then
    on; nothing in this class ever removes a key from the quarantine set.
    """

    def __init__(self, env_prefix: str, max_keys: int = 5):
        """Load up to *max_keys* non-blank keys from the environment.

        Args:
            env_prefix: Variable-name prefix; ``{env_prefix}_1`` is read first.
            max_keys: Highest numeric suffix to look up (inclusive).
        """
        keys = []
        for i in range(1, max_keys + 1):
            # Strip *before* the emptiness test: a whitespace-only variable used to
            # become "" and next_key() would then hand callers a falsy key, making
            # them give up even though other valid keys were available.
            v = (os.getenv(f"{env_prefix}_{i}") or "").strip()
            if v:
                keys.append(v)
        if not keys:
            logger.warning(f"[LLM] No keys found for prefix {env_prefix}_*")
        self.keys = keys
        self.dead = set()  # quarantined keys, never revived
        self.idx = 0  # monotonically increasing cursor; used modulo len(keys)

    def next_key(self) -> Optional[str]:
        """Return the next non-quarantined key in rotation, or None if none remain."""
        if not self.keys:
            return None
        # At most one full lap: every key is inspected once before giving up.
        for _ in range(len(self.keys)):
            k = self.keys[self.idx % len(self.keys)]
            self.idx += 1
            if k not in self.dead:
                return k
        return None

    def mark_bad(self, key: Optional[str]):
        """Quarantine *key* so :meth:`next_key` skips it; falsy keys are ignored."""
        if key:
            self.dead.add(key)
            # Only the first 6 characters are logged.
            logger.warning(f"[LLM] Quarantined key (prefix hidden): {key[:6]}***")

class GeminiClient:
    """Gemini wrapper that takes a fresh key from a KeyRotator on every call.

    If anything in the request raises, the key that was in use is handed back
    to the rotator via ``mark_bad``. ``temperature`` and ``max_output_tokens``
    are accepted so the signature lines up with ``NvidiaClient.generate``, but
    they are NOT forwarded: only ``model`` and ``contents`` reach the SDK.
    """

    def __init__(self, rotator: KeyRotator, default_model: str):
        self.default_model = default_model
        self.rotator = rotator

    def generate(self, prompt: str, model: Optional[str] = None, temperature: float = 0.2, max_output_tokens: int = 512) -> Optional[str]:
        """Send *prompt* to Gemini and return the reply text.

        Returns None when the rotator has no usable key or the request fails;
        otherwise returns the response's ``text`` attribute (which may itself
        be None/empty).
        """
        api_key = self.rotator.next_key()
        if not api_key:
            return None
        try:
            # NOTE: call shape kept as required (model + contents only).
            response = genai.Client(api_key=api_key).models.generate_content(
                model=model or self.default_model,
                contents=prompt,
            )
            reply = getattr(response, "text", None)
            if reply:
                logger.info(f"[LLM][Gemini] out={snip(reply)}")
        except Exception as e:
            logger.error(f"[LLM][Gemini] {e}")
            self.rotator.mark_bad(api_key)
            return None
        return reply

class NvidiaClient:
    """Chat-completions client for the NVIDIA API that draws keys from a KeyRotator.

    A key is quarantined only when the API rejects the key itself (HTTP 401,
    402 or 403). Rate limits, server errors, network failures and malformed
    payloads are logged and reported as None without quarantining the key,
    because KeyRotator never brings a quarantined key back.
    """

    # HTTP statuses that mean the key itself is unusable, not a transient fault.
    _KEY_REJECTED_STATUSES = frozenset({401, 402, 403})

    # Boilerplate lead-ins models tend to prepend, matched case-insensitively.
    # The \b after Sure/Okay keeps real words such as "Surely" intact.
    _BOILERPLATE_PREFIXES = tuple(
        re.compile(pat, re.I)
        for pat in (
            r"^Here is (a|the) .*?:\s*",
            r"^Paraphrased(?: version)?:\s*",
            r"^Sure\b[,.!]?\s*",
            r"^Okay\b[,.!]?\s*",
        )
    )

    def __init__(self, rotator: "KeyRotator", default_model: str):
        self.rotator = rotator
        self.default_model = default_model
        # Endpoint is overridable through the NVIDIA_API_URL environment variable.
        self.url = os.getenv("NVIDIA_API_URL", "https://integrate.api.nvidia.com/v1/chat/completions")

    def _clean_resp(self, resp: str) -> str:
        """Strip boilerplate lead-ins such as "Sure, here is the text:" from *resp*.

        Patterns are re-applied until none match, so stacked lead-ins
        ("Sure, here is the paraphrase: ...") are fully removed. Falsy input is
        returned unchanged.
        """
        if not resp:
            return resp
        txt = resp.strip()
        previous = None
        # Every pattern needs at least one character to match, so each change
        # shortens txt and the loop terminates.
        while txt != previous:
            previous = txt
            for pat in self._BOILERPLATE_PREFIXES:
                txt = pat.sub("", txt)
        return txt.strip()

    def generate(self, prompt: str, model: Optional[str] = None, temperature: float = 0.2, max_tokens: int = 512) -> Optional[str]:
        """Send *prompt* as a single user message and return the cleaned reply.

        Returns None when no key is available, the request fails, or the
        response does not contain a string message content.
        """
        key = self.rotator.next_key()
        if not key:
            return None
        headers = {"Authorization": f"Bearer {key}", "Content-Type": "application/json"}
        payload = {
            "model": model or self.default_model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        try:
            r = requests.post(self.url, headers=headers, json=payload, timeout=45)
        except requests.RequestException as e:
            # Timeouts / connection errors say nothing about the key's validity.
            logger.error(f"[LLM][NVIDIA] {e}")
            return None

        if r.status_code >= 400:
            logger.error(f"[LLM][NVIDIA] HTTP {r.status_code}: {r.text[:200]}")
            if r.status_code in self._KEY_REJECTED_STATUSES:
                self.rotator.mark_bad(key)
            return None

        try:
            text = r.json()["choices"][0]["message"]["content"]
        except (ValueError, KeyError, IndexError, TypeError) as e:
            logger.error(f"[LLM][NVIDIA] unexpected response payload: {e!r}")
            return None

        # Content can be null (or non-text) for some models; treat it as no reply.
        clean = self._clean_resp(text) if isinstance(text, str) else None
        logger.info(f"[LLM][NVIDIA] out={snip(clean)}")
        return clean

class Paraphraser:
    """Medical-text paraphraser, translator and LLM-based judge.

    Every request goes to NVIDIA first and falls back to the Gemini "easy"
    model when NVIDIA yields nothing usable. The "hard" Gemini model name is
    accepted by the constructor but intentionally unused.
    """

    # Boilerplate lead-ins models tend to prepend, matched case-insensitively.
    # The \b after Sure/Okay keeps real words such as "Surely" intact.
    _BOILERPLATE_PREFIXES = tuple(
        re.compile(pat, re.I)
        for pat in (
            r"^Here is (a|the) .*?:\s*",
            r"^Paraphrased(?: version)?:\s*",
            r"^Sure\b[,.!]?\s*",
            r"^Okay\b[,.!]?\s*",
        )
    )

    # Refusal / failure phrases that disqualify generated text. They are matched
    # as whole words so clinical vocabulary such as "heart failure" or
    # "inborn errors of metabolism" is not mistaken for a refusal.
    _INVALID_RESPONSE_RE = re.compile(
        r"\b(?:"
        + "|".join(
            re.escape(phrase)
            for phrase in (
                "fail", "invalid", "i couldn't", "i can't", "i cannot", "unable to",
                "sorry", "error", "not available", "no answer", "insufficient",
                "don't know", "do not know", "not sure", "cannot determine",
                "unable to provide", "not possible", "not applicable", "n/a",
            )
        )
        + r")\b"
    )

    def __init__(self, nvidia_model: str, gemini_model_easy: str, gemini_model_hard: str):
        self.nv = NvidiaClient(KeyRotator("NVIDIA_API"), nvidia_model)
        self.gm_easy = GeminiClient(KeyRotator("GEMINI_API"), gemini_model_easy)
        # Only GEMINI_MODEL_EASY is used; the hard model is deliberately disabled.
        self.gm_hard = None
        logger.info("Paraphraser initialized: NVIDIA -> GEMINI_EASY (GEMINI_HARD disabled)")

    # ————— Helpers —————
    def _clean_resp(self, resp: str) -> str:
        """Strip boilerplate lead-ins such as "Sure, here is the text:" from *resp*.

        Patterns are re-applied until none match, so stacked lead-ins are fully
        removed. Falsy input is returned unchanged.
        """
        if not resp:
            return resp
        txt = resp.strip()
        previous = None
        # Every pattern consumes at least one character, so the loop terminates.
        while txt != previous:
            previous = txt
            for pat in self._BOILERPLATE_PREFIXES:
                txt = pat.sub("", txt)
        return txt.strip()

    def _generate(self, prompt: str, temperature: float, max_tokens: int) -> Optional[str]:
        """Ask NVIDIA, then Gemini-easy; return the first non-empty cleaned reply, else None."""
        out = self.nv.generate(prompt, temperature=temperature, max_tokens=max_tokens)
        cleaned = self._clean_resp(out) if out else None
        if cleaned:
            return cleaned
        # Only the token budget is forwarded to the Gemini fallback.
        out = self.gm_easy.generate(prompt, max_output_tokens=max_tokens)
        cleaned = self._clean_resp(out) if out else None
        return cleaned or None

    @staticmethod
    def _parse_verdict(out: Optional[str], positive: str, negative: str) -> bool:
        """Return True iff *out* contains *positive* and not *negative* (case-insensitive).

        The negative check matters because e.g. "INACCURATE" contains "ACCURATE".
        Non-string *out* (no reply) counts as a negative verdict.
        """
        if not isinstance(out, str):
            return False
        verdict = out.upper()
        return positive in verdict and negative not in verdict

    # ————— Paraphrase —————
    def paraphrase(self, text: str, difficulty: str = "easy", custom_prompt: Optional[str] = None) -> str:
        """Rewrite *text* with an LLM, returning *text* unchanged on failure.

        Args:
            text: Source text; inputs shorter than 12 characters are returned as-is.
            difficulty: "easy" uses the plain rewrite prompt and temperature 0.1;
                any other value uses the "enhanced" prompt and temperature 0.3.
            custom_prompt: If given, sent verbatim instead of the built-in prompt
                (temperature still follows *difficulty*).
        """
        if not text or len(text) < 12:
            return text

        if custom_prompt:
            prompt = custom_prompt
        elif difficulty == "easy":
            prompt = (
                "You are a medical professional. Rewrite the following medical text using different words while preserving all medical facts, clinical terms, and meaning. Keep the same level of detail and accuracy.\n\n"
                f"Original medical text: {text}\n\n"
                "Rewritten medical text:"
            )
        else:  # hard difficulty
            prompt = (
                "You are a medical expert. Rewrite the following medical text using more sophisticated medical language and different sentence structures while preserving all clinical facts, medical terminology, and diagnostic information. Maintain professional medical tone.\n\n"
                f"Original medical text: {text}\n\n"
                "Enhanced medical text:"
            )

        temperature = 0.1 if difficulty == "easy" else 0.3
        # Token budget scales with input length, clamped to [128, 600].
        max_tokens = min(600, max(128, len(text) // 2))

        out = self._generate(prompt, temperature, max_tokens)
        return out if out else text

    # ————— Translate & Backtranslate —————
    def translate(self, text: str, target_lang: str = "vi") -> Optional[str]:
        """Translate *text* into *target_lang*; returns None if no provider succeeds."""
        if not text:
            return text

        if target_lang == "vi":
            prompt = (
                "You are a medical translator. Translate the following English medical text to Vietnamese while preserving all medical terminology, clinical facts, and professional medical language. Use appropriate Vietnamese medical terms.\n\n"
                f"English medical text: {text}\n\n"
                "Vietnamese medical translation:"
            )
        else:
            prompt = (
                f"You are a medical translator. Translate the following medical text to {target_lang} while preserving all medical terminology, clinical facts, and professional medical language.\n\n"
                f"Original medical text: {text}\n\n"
                f"{target_lang} medical translation:"
            )

        return self._generate(prompt, 0.0, min(800, len(text) + 100))

    def backtranslate(self, text: str, via_lang: str = "vi") -> Optional[str]:
        """Translate *text* to *via_lang* and back to English; None if either leg fails."""
        if not text:
            return text
        mid = self.translate(text, target_lang=via_lang)
        if not mid:
            return None

        if via_lang == "vi":
            prompt = (
                "You are a medical translator. Translate the following Vietnamese medical text back to English while preserving all medical terminology, clinical facts, and professional medical language. Ensure the translation is medically accurate.\n\n"
                f"Vietnamese medical text: {mid}\n\n"
                "English medical translation:"
            )
        else:
            prompt = (
                f"You are a medical translator. Translate the following {via_lang} medical text back to English while preserving all medical terminology, clinical facts, and professional medical language.\n\n"
                f"{via_lang} medical text: {mid}\n\n"
                "English medical translation:"
            )

        return self._generate(prompt, 0.0, min(900, len(text) + 150))

    # ————— Consistency Judge (cheap, ratio-based) —————
    def consistency_check(self, user: str, output: str) -> bool:
        """Return True if an LLM judge answers PASS (and not FAIL) for *output* given *user*.

        No reply from either provider counts as a failure.
        """
        prompt = (
            "You are a medical quality assurance expert. Evaluate if the medical answer is consistent with the question/context and medically accurate. Consider:\n"
            "1. Medical accuracy and clinical appropriateness\n"
            "2. Consistency with the question asked\n"
            "3. Safety and professional medical standards\n"
            "4. Completeness of the medical information\n\n"
            "Reply with exactly 'PASS' if the answer is medically sound and consistent, otherwise 'FAIL'.\n\n"
            f"Question/Context: {user}\n\n"
            f"Medical Answer: {output}\n\n"
            "Evaluation:"
        )
        out = self._generate(prompt, 0.0, 5)
        return self._parse_verdict(out, "PASS", "FAIL")

    def medical_accuracy_check(self, question: str, answer: str) -> bool:
        """Return True if an LLM judge labels *answer* ACCURATE for *question*.

        An "INACCURATE" verdict is rejected even though it contains the
        substring "ACCURATE". Empty inputs or no reply return False.
        """
        if not question or not answer:
            return False

        prompt = (
            "You are a medical accuracy validator. Evaluate if the medical answer is accurate and appropriate for the question. Consider:\n"
            "1. Medical facts and clinical knowledge\n"
            "2. Appropriate medical terminology\n"
            "3. Clinical reasoning and logic\n"
            "4. Safety considerations\n\n"
            "Reply with exactly 'ACCURATE' if the answer is medically correct, otherwise 'INACCURATE'.\n\n"
            f"Medical Question: {question}\n\n"
            f"Medical Answer: {answer}\n\n"
            "Medical Accuracy Assessment:"
        )
        out = self._generate(prompt, 0.0, 5)
        return self._parse_verdict(out, "ACCURATE", "INACCURATE")

    def enhance_medical_terminology(self, text: str) -> str:
        """Ask an LLM to sharpen the medical terminology of *text*; return *text* on failure.

        Inputs shorter than 20 characters are returned unchanged.
        """
        if not text or len(text) < 20:
            return text

        prompt = (
            "You are a medical terminology expert. Improve the medical terminology in the following text while preserving all factual information and clinical accuracy. Use more precise medical terms where appropriate.\n\n"
            f"Original text: {text}\n\n"
            "Enhanced medical text:"
        )
        out = self._generate(prompt, 0.1, min(800, len(text) + 100))
        return out if out else text

    def create_clinical_scenarios(self, question: str, answer: str) -> list:
        """Rephrase *question* in four clinical contexts.

        Returns a list of ``(scenario_question, answer, scenario_type)`` tuples.
        A context is skipped when the rewrite fails, is judged invalid, or is
        just the original question echoed back (paraphrase's failure fallback).
        """
        scenarios = []

        context_prompts = [
            (
                "Rewrite this medical question as if asked by a patient in an emergency room setting:",
                "emergency_room"
            ),
            (
                "Rewrite this medical question as if asked by a patient during a routine checkup:",
                "routine_checkup"
            ),
            (
                "Rewrite this medical question as if asked by a patient with chronic conditions:",
                "chronic_care"
            ),
            (
                "Rewrite this medical question as if asked by a patient's family member:",
                "family_inquiry"
            )
        ]

        for prompt_template, scenario_type in context_prompts:
            try:
                prompt = f"{prompt_template}\n\nOriginal question: {question}\n\nRewritten question:"
                scenario_question = self.paraphrase(question, difficulty="hard", custom_prompt=prompt)

                if not scenario_question or self._is_invalid_response(scenario_question):
                    continue
                # paraphrase() returns its input unchanged when every provider
                # fails; that is not a new scenario.
                if scenario_question.strip() == question.strip():
                    continue
                scenarios.append((scenario_question, answer, scenario_type))
            except Exception as e:
                logger.warning(f"Failed to create clinical scenario {scenario_type}: {e}")
                continue

        return scenarios

    def _is_invalid_response(self, text: str) -> bool:
        """Return True if *text* is missing, very short, or contains a refusal/failure phrase.

        Phrases are matched as whole words (case-insensitive), so e.g.
        "heart failure" does not trigger the "fail" phrase.
        """
        if not text or not isinstance(text, str):
            return True
        lowered = text.lower().strip()
        if len(lowered) < 3:
            return True
        return self._INVALID_RESPONSE_RE.search(lowered) is not None