File size: 18,231 Bytes
80cb919
 
 
 
 
e3a165a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80cb919
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3a165a
80cb919
 
e3a165a
 
 
 
80cb919
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e138b0e
80cb919
 
 
 
 
 
 
 
e138b0e
80cb919
 
 
e138b0e
 
 
80cb919
a062909
80cb919
 
 
a062909
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80cb919
 
 
 
a062909
 
 
 
 
 
80cb919
 
a062909
 
 
 
 
 
 
 
 
 
 
 
 
 
80cb919
 
b0a3faf
80cb919
 
b0a3faf
d668aec
b0a3faf
 
 
d668aec
 
 
88e7ced
 
d668aec
 
 
88e7ced
 
d668aec
 
 
 
 
b0a3faf
d668aec
 
e138b0e
 
 
d668aec
 
e138b0e
 
 
 
80cb919
 
19d62ff
80cb919
d668aec
 
 
 
88e7ced
 
d668aec
 
 
88e7ced
 
d668aec
 
80cb919
 
 
 
19d62ff
80cb919
 
 
d668aec
 
 
 
88e7ced
 
d668aec
 
 
88e7ced
 
d668aec
 
80cb919
 
 
 
 
 
 
d668aec
80cb919
88e7ced
d668aec
88e7ced
80cb919
d668aec
80cb919
d668aec
80cb919
d668aec
 
 
 
 
 
 
88e7ced
d668aec
88e7ced
d668aec
 
 
 
 
 
 
 
 
 
 
 
 
88e7ced
 
d668aec
 
 
 
 
 
 
 
 
 
 
 
 
 
88e7ced
d668aec
 
 
88e7ced
d668aec
 
 
88e7ced
d668aec
 
 
88e7ced
d668aec
 
 
 
 
 
88e7ced
d668aec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
# Round-robin rotator + paraphrasing + translation/backtranslation
import logging
import os
import re
from typing import Optional

import requests

# Dynamic import for Google GenAI (only when not in local mode)
def _import_google_genai():
    """Import and return the ``google.genai`` module on demand.

    The import happens at call time rather than at module import time, so
    this file can be loaded without the Google GenAI SDK installed.

    Returns:
        The imported ``genai`` module object.

    Raises:
        ImportError: If ``from google import genai`` fails. The original
            import error is attached as ``__cause__``.
    """
    try:
        from google import genai
    except ImportError as e:
        # Chain the original failure so the real reason stays in the traceback.
        raise ImportError(
            f"Google GenAI not available: {e}. Make sure IS_LOCAL=false and google-genai is installed."
        ) from e
    return genai

# Local mode is enabled by setting IS_LOCAL to "true" (any letter case).
IS_LOCAL = os.environ.get("IS_LOCAL", "false").lower() == "true"

# `genai` holds the SDK module when not in local mode, and None otherwise
# (local mode, or the SDK could not be imported).
genai = None
if not IS_LOCAL:
    try:
        genai = _import_google_genai()
    except ImportError:
        pass

# Module logger named "llm". It is configured only when it has no handlers
# yet, so running this setup again does not attach a second handler.
logger = logging.getLogger("llm")
if len(logger.handlers) == 0:
    logger.addHandler(logging.StreamHandler())
    logger.setLevel(logging.INFO)

# Shorten LLM text to a few words before writing it to the log.
def snip(s: str, n: int = 12) -> str:
    """Return the first *n* whitespace-separated words of *s*.

    A trailing " …" is appended when words were cut off. Anything that is
    not a string yields the placeholder "∅".
    """
    if not isinstance(s, str):
        return "∅"
    words = s.split()
    preview = " ".join(words[:n])
    if len(words) > n:
        preview += " …"
    return preview

class KeyRotator:
    """Round-robin selector over API keys read from numbered environment variables.

    Keys are read once, at construction, from ``{env_prefix}_1`` through
    ``{env_prefix}_{max_keys}``. A key passed to :meth:`mark_bad` is skipped
    by :meth:`next_key` for the rest of the instance's lifetime.
    """

    def __init__(self, env_prefix: str, max_keys: int = 5):
        """Collect the configured keys.

        Args:
            env_prefix: Variable name prefix; ``_1``, ``_2``, ... are appended.
            max_keys: Highest numeric suffix to look up.
        """
        keys = []
        for i in range(1, max_keys + 1):
            # Strip before testing, so a whitespace-only variable counts as
            # unset instead of being stored as an empty (falsy) key that
            # callers would mistake for "no key available".
            v = (os.getenv(f"{env_prefix}_{i}") or "").strip()
            if v:
                keys.append(v)
        if not keys:
            logger.warning(f"[LLM] No keys found for prefix {env_prefix}_*")
        self.keys = keys
        self.dead = set()  # keys quarantined through mark_bad()
        self.idx = 0       # running counter; reduced modulo len(keys) on use

    def next_key(self) -> Optional[str]:
        """Return the next non-quarantined key, or ``None`` if there is none."""
        if not self.keys:
            return None
        # At most one full lap: if every key is dead we fall through to None.
        for _ in range(len(self.keys)):
            k = self.keys[self.idx % len(self.keys)]
            self.idx += 1
            if k not in self.dead:
                return k
        return None

    def mark_bad(self, key: Optional[str]):
        """Quarantine *key* so that ``next_key`` no longer returns it.

        A falsy *key* is ignored. Only the first six characters of the key
        are written to the log.
        """
        if key:
            self.dead.add(key)
            logger.warning(f"[LLM] Quarantined key (prefix hidden): {key[:6]}***")

class GeminiClient:
    """Text-generation client for Gemini through the Google GenAI SDK.

    Every call takes one API key from the supplied rotator. A key whose call
    raises is handed back to the rotator through ``mark_bad``.
    """

    def __init__(self, rotator: "KeyRotator", default_model: str):
        """Store the key source and the default model.

        Args:
            rotator: Provides ``next_key()`` and receives ``mark_bad(key)``.
            default_model: Model name used when ``generate`` gets no ``model``.
        """
        self.rotator = rotator
        self.default_model = default_model
        # Usable only when the SDK module was imported and local mode is off.
        self.available = genai is not None and not IS_LOCAL

    def generate(self, prompt: str, model: Optional[str] = None, temperature: float = 0.2, max_output_tokens: int = 512) -> Optional[str]:
        """Generate text for *prompt*.

        Args:
            prompt: Text passed as the request ``contents``.
            model: Model name; falls back to ``default_model`` when falsy.
            temperature: Sampling temperature, forwarded in the request config.
            max_output_tokens: Accepted for interface compatibility; not
                forwarded (see the NOTE in the body).

        Returns:
            The ``text`` attribute of the response (``None`` if it has none),
            or ``None`` when the client is unavailable, the rotator has no
            usable key, or the call raised.
        """
        if not self.available:
            logger.warning("[LLM][Gemini] Google GenAI not available (local mode or import failed)")
            return None

        key = self.rotator.next_key()
        if not key:
            return None
        try:
            client = genai.Client(api_key=key)
            # The temperature argument used to be silently dropped; pass it on
            # so Gemini honours the same setting the caller gives NVIDIA.
            # NOTE(review): max_output_tokens is still not forwarded. Callers
            # in this file pass caps as small as 5; if the configured model
            # counts internal reasoning tokens against the output budget, such
            # a cap could yield empty replies. Confirm the model's behaviour
            # before adding "max_output_tokens" to this config.
            res = client.models.generate_content(
                model=model or self.default_model,
                contents=prompt,
                config={"temperature": temperature},
            )
            text = getattr(res, "text", None)
            if text:
                logger.info(f"[LLM][Gemini] out={snip(text)}")
            return text
        except Exception as e:
            # Any failure quarantines the key that was used for this call.
            logger.error(f"[LLM][Gemini] {e}")
            self.rotator.mark_bad(key)
            return None

class NvidiaClient:
    """Chat-completions HTTP client that draws its bearer token from a KeyRotator.

    The endpoint is read from the ``NVIDIA_API_URL`` environment variable and
    defaults to NVIDIA's hosted ``/v1/chat/completions`` URL.
    """

    # Boilerplate lead-ins removed from the start of a reply, applied in this
    # order. ``\b`` after single-word openers keeps words such as "Surely"
    # from being cut down to "ly".
    _BOILERPLATE_PATTERNS = tuple(
        re.compile(p, re.I)
        for p in (
            r"^Here is (a|the) .*?:\s*",
            r"^Paraphrased(?: version)?:\s*",
            r"^Sure\b[,.]?\s*",
            r"^Okay\b[,.]?\s*",
        )
    )

    def __init__(self, rotator: "KeyRotator", default_model: str):
        """Store the key source, the default model and the endpoint URL.

        Args:
            rotator: Provides ``next_key()`` and receives ``mark_bad(key)``.
            default_model: Model name used when ``generate`` gets no ``model``.
        """
        self.rotator = rotator
        self.default_model = default_model
        self.url = os.getenv("NVIDIA_API_URL", "https://integrate.api.nvidia.com/v1/chat/completions")

    def _clean_resp(self, resp: str) -> str:
        """Strip surrounding whitespace and leading boilerplate from *resp*.

        Falsy input (``None`` or ``""``) is returned unchanged.
        """
        if not resp:
            return resp
        txt = resp.strip()
        for pat in self._BOILERPLATE_PATTERNS:
            txt = pat.sub("", txt)
        return txt.strip()

    def generate(self, prompt: str, model: Optional[str] = None, temperature: float = 0.2, max_tokens: int = 512) -> Optional[str]:
        """Send *prompt* as a single user message and return the cleaned reply.

        Args:
            prompt: Content of the user message.
            model: Model name; falls back to ``default_model`` when falsy.
            temperature: Sent as ``temperature`` in the request payload.
            max_tokens: Sent as ``max_tokens`` in the request payload.

        Returns:
            The first choice's message content after ``_clean_resp``, or
            ``None`` when the rotator has no usable key or the request failed.
        """
        key = self.rotator.next_key()
        if not key:
            return None
        try:
            headers = {"Authorization": f"Bearer {key}", "Content-Type": "application/json"}
            payload = {
                "model": model or self.default_model,
                "messages": [{"role": "user", "content": prompt}],
                "temperature": temperature,
                "max_tokens": max_tokens
            }
            r = requests.post(self.url, headers=headers, json=payload, timeout=45)
            if r.status_code >= 400:
                # Routed through the except below so the key gets quarantined.
                raise RuntimeError(f"HTTP {r.status_code}: {r.text[:200]}")
            data = r.json()
            text = data["choices"][0]["message"]["content"]
            clean = self._clean_resp(text)
            logger.info(f"[LLM][NVIDIA] out={snip(clean)}")
            return clean
        except Exception as e:
            # Any failure (HTTP error, timeout, unexpected JSON shape)
            # quarantines the key that was used for this call.
            logger.error(f"[LLM][NVIDIA] {e}")
            self.rotator.mark_bad(key)
            return None

class Paraphraser:
    """Medical text augmentation helper built on two LLM back ends.

    NVIDIA is always tried first; the "easy" Gemini model is the only
    fallback (the hard Gemini model is disabled). The class offers
    paraphrasing, translation / back-translation, two LLM-based yes/no
    judges, terminology enhancement and clinical-scenario rewriting.
    """

    # Literal lead-ins removed from the very start of a reply. Matching is
    # case-insensitive and stops at the first hit.
    _PREFIXES = (
        "Here's a rewritten version of",
        "Here is a rewritten version of",
        "Here's the rewritten text:",
        "Here is the rewritten text:",
        "Here's the translation:",
        "Here is the translation:",
        "Here's the enhanced text:",
        "Here is the enhanced text:",
        "Here's the improved text:",
        "Here is the improved text:",
        "Here's the medical context:",
        "Here is the medical context:",
        "Here's the cleaned text:",
        "Here is the cleaned text:",
        "Here's the answer:",
        "Here is the answer:",
        "Here's a paraphrased version:",
        "Here is a paraphrased version:",
        "Paraphrased version:",
        "Paraphrased:",
        "Sure,",
        "Okay,",
        "Certainly,",
        "Of course,",
        "I can help you with that.",
        "I'll help you with that.",
        "Let me help you with that.",
        "I can rewrite that for you.",
        "I'll rewrite that for you.",
        "Let me rewrite that for you.",
        "I can translate that for you.",
        "I'll translate that for you.",
        "Let me translate that for you.",
    )

    # Regex lead-ins removed after the literal prefixes, applied in order.
    # ``\b`` after single-word openers keeps "Surely ..." from becoming "ly ...".
    _LEADIN_PATTERNS = tuple(
        re.compile(p, re.I)
        for p in (
            r"^Here is (a|the) .*?:\s*",
            r"^Paraphrased(?: version)?:\s*",
            r"^Sure\b[,.]?\s*",
            r"^Okay\b[,.]?\s*",
            r"^Certainly\b[,.]?\s*",
            r"^Of course\b[,.]?\s*",
            r"^I can .*?:\s*",
            r"^I'll .*?:\s*",
            r"^Let me .*?:\s*",
        )
    )

    # A line containing one of these phrases is treated as chatter and dropped.
    _CHATTER_PHRASES = (
        "here's", "here is", "let me", "i can", "i'll", "sure,", "okay,",
        "certainly,", "of course,", "i hope this helps", "hope this helps",
        "does this help", "is this what you", "let me know if",
    )
    # The phrase must start at a word boundary. Without that, "blood
    # pressure, ..." matched "sure," and "anti cancer" matched "i can",
    # which deleted genuine medical content.
    _CHATTER_RE = re.compile(r"(?<!\w)(?:" + "|".join(map(re.escape, _CHATTER_PHRASES)) + r")")

    # Markers of a refusal or failed generation, matched as whole words or phrases.
    _INVALID_MARKERS = (
        "fail", "invalid", "i couldn't", "i can't", "i cannot", "unable to",
        "sorry", "error", "not available", "no answer", "insufficient",
        "don't know", "do not know", "not sure", "cannot determine",
        "unable to provide", "not possible", "not applicable", "n/a",
    )
    # Word boundaries on both sides: a bare substring test for "fail"
    # rejected every text that mentioned "heart failure".
    _INVALID_RE = re.compile(r"(?<!\w)(?:" + "|".join(map(re.escape, _INVALID_MARKERS)) + r")(?!\w)")

    def __init__(self, nvidia_model: str, gemini_model_easy: str, gemini_model_hard: str):
        """Build the NVIDIA and Gemini clients from environment-provided keys.

        Args:
            nvidia_model: Default model for the NVIDIA client.
            gemini_model_easy: Default model for the Gemini fallback client.
            gemini_model_hard: Accepted but unused; the hard model is disabled.
        """
        self.nv = NvidiaClient(KeyRotator("NVIDIA_API"), nvidia_model)
        self.gm_easy = GeminiClient(KeyRotator("GEMINI_API"), gemini_model_easy)
        # Only use GEMINI_MODEL_EASY, ignore hard model completely
        self.gm_hard = None  # Disabled - only use easy model
        logger.info("Paraphraser initialized: NVIDIA -> GEMINI_EASY (GEMINI_HARD disabled)")

    def _ask(self, prompt: str, temperature: float, max_tokens: int) -> Optional[str]:
        """Send *prompt* to NVIDIA and, if that yields nothing, to Gemini."""
        out = self.nv.generate(prompt, temperature=temperature, max_tokens=max_tokens)
        if out:
            return out
        return self.gm_easy.generate(prompt, temperature=temperature, max_output_tokens=max_tokens)

    def _clean_resp(self, resp: str) -> str:
        """Remove conversational wrappers from a model reply.

        Steps: strip whitespace, drop one literal prefix, drop regex
        lead-ins, then drop blank lines and lines containing a chatter
        phrase. Falsy input is returned unchanged. The result may be an
        empty string when every line was classified as chatter.
        """
        if not resp:
            return resp
        txt = resp.strip()

        lowered = txt.lower()
        for prefix in self._PREFIXES:
            if lowered.startswith(prefix.lower()):
                txt = txt[len(prefix):].strip()
                break

        for pat in self._LEADIN_PATTERNS:
            txt = pat.sub("", txt)

        kept = []
        for line in txt.split("\n"):
            line = line.strip()
            if line and not self._CHATTER_RE.search(line.lower()):
                kept.append(line)
        return "\n".join(kept).strip()

    # ————— Paraphrase —————
    def paraphrase(self, text: str, difficulty: str = "easy", custom_prompt: Optional[str] = None) -> str:
        """Return a paraphrase of *text*, or *text* itself when none could be produced.

        Args:
            text: Source text. Falsy or shorter than 12 characters is
                returned unchanged without calling any model.
            difficulty: ``"easy"`` selects the plain rewrite prompt and
                temperature 0.1; any other value selects the "sophisticated"
                prompt and temperature 0.3.
            custom_prompt: Full prompt to send instead of the built-in ones.

        Returns:
            The cleaned model reply; *text* if both back ends returned
            nothing or cleaning left nothing.
        """
        if not text or len(text) < 12:
            return text

        if custom_prompt:
            prompt = custom_prompt
        elif difficulty == "easy":
            prompt = (
                "Rewrite the following medical text using different words while preserving all medical facts, clinical terms, and meaning. Keep the same level of detail and accuracy. Return only the rewritten text without any introduction or commentary.\n\n"
                f"{text}"
            )
        else:  # hard difficulty
            prompt = (
                "Rewrite the following medical text using more sophisticated medical language and different sentence structures while preserving all clinical facts, medical terminology, and diagnostic information. Maintain professional medical tone. Return only the rewritten text without any introduction or commentary.\n\n"
                f"{text}"
            )

        temperature = 0.1 if difficulty == "easy" else 0.3
        max_tokens = min(600, max(128, len(text)//2))

        # Always try NVIDIA first.
        out = self.nv.generate(prompt, temperature=temperature, max_tokens=max_tokens)
        cleaned = self._clean_resp(out) if out else None
        if cleaned:
            return cleaned

        # Fall back to Gemini, also when cleaning emptied the NVIDIA reply:
        # an empty string must never be returned as the "paraphrase".
        out = self.gm_easy.generate(prompt, temperature=temperature, max_output_tokens=max_tokens)
        cleaned = self._clean_resp(out) if out else None
        if cleaned:
            logger.info(f"[LLM][GEMINI] out={snip(cleaned)}")
            return cleaned
        return text

    # ————— Translate & Backtranslate —————
    def translate(self, text: str, target_lang: str = "vi") -> Optional[str]:
        """Translate *text* into *target_lang*.

        Returns:
            The stripped translation, *text* itself when it is falsy, or
            ``None`` when neither back end produced output.
        """
        if not text:
            return text

        if target_lang == "vi":
            prompt = (
                "Translate the following English medical text to Vietnamese while preserving all medical terminology, clinical facts, and professional medical language. Use appropriate Vietnamese medical terms. Return only the translation without any introduction or commentary.\n\n"
                f"{text}"
            )
        else:
            prompt = (
                f"Translate the following medical text to {target_lang} while preserving all medical terminology, clinical facts, and professional medical language. Return only the translation without any introduction or commentary.\n\n"
                f"{text}"
            )

        out = self._ask(prompt, temperature=0.0, max_tokens=min(800, len(text)+100))
        # Strip the result whichever back end produced it.
        return out.strip() if out else None

    def backtranslate(self, text: str, via_lang: str = "vi") -> Optional[str]:
        """Translate *text* to *via_lang* and back to English.

        Returns:
            The stripped round-trip text, *text* itself when it is falsy, or
            ``None`` when either leg produced no output.
        """
        if not text:
            return text
        mid = self.translate(text, target_lang=via_lang)
        if not mid:
            return None

        if via_lang == "vi":
            prompt = (
                "Translate the following Vietnamese medical text back to English while preserving all medical terminology, clinical facts, and professional medical language. Ensure the translation is medically accurate. Return only the translation without any introduction or commentary.\n\n"
                f"{mid}"
            )
        else:
            prompt = (
                f"Translate the following {via_lang} medical text back to English while preserving all medical terminology, clinical facts, and professional medical language. Return only the translation without any introduction or commentary.\n\n"
                f"{mid}"
            )

        out = self._ask(prompt, temperature=0.0, max_tokens=min(900, len(text)+150))
        return out.strip() if out else None

    # ————— Consistency Judge (cheap, LLM-based PASS/FAIL) —————
    def consistency_check(self, user: str, output: str) -> bool:
        """Return True if the model judges *output* consistent with *user*.

        The model is asked to reply 'PASS' or 'FAIL'. Any reply containing
        "PASS" (case-insensitive) counts as True; no reply counts as False.
        """
        prompt = (
            "Evaluate if the medical answer is consistent with the question/context and medically accurate. Consider medical accuracy, clinical appropriateness, consistency with the question, safety standards, and completeness of medical information. Reply with exactly 'PASS' if the answer is medically sound and consistent, otherwise 'FAIL'.\n\n"
            f"Question/Context: {user}\n\n"
            f"Medical Answer: {output}"
        )
        out = self._ask(prompt, temperature=0.0, max_tokens=5)
        return isinstance(out, str) and "PASS" in out.upper()

    def medical_accuracy_check(self, question: str, answer: str) -> bool:
        """Return True if the model judges *answer* medically accurate for *question*.

        Returns False when either argument is falsy, when no back end
        replied, or when the reply contains "INACCURATE".
        """
        if not question or not answer:
            return False

        prompt = (
            "Evaluate if the medical answer is accurate and appropriate for the question. Consider medical facts, clinical knowledge, appropriate medical terminology, clinical reasoning, logic, and safety considerations. Reply with exactly 'ACCURATE' if the answer is medically correct, otherwise 'INACCURATE'.\n\n"
            f"Medical Question: {question}\n\n"
            f"Medical Answer: {answer}"
        )

        out = self._ask(prompt, temperature=0.0, max_tokens=5)
        if not isinstance(out, str):
            return False
        verdict = out.upper()
        # "INACCURATE" contains "ACCURATE", so the negative verdict has to be
        # ruled out explicitly; a bare substring test accepted every reply.
        return "INACCURATE" not in verdict and "ACCURATE" in verdict

    def enhance_medical_terminology(self, text: str) -> str:
        """Return *text* rewritten with more precise medical terminology.

        Falsy input or input shorter than 20 characters is returned
        unchanged, as is *text* when neither back end produced output.
        """
        if not text or len(text) < 20:
            return text

        prompt = (
            "Improve the medical terminology in the following text while preserving all factual information and clinical accuracy. Use more precise medical terms where appropriate. Return only the improved text without any introduction or commentary.\n\n"
            f"{text}"
        )

        out = self._ask(prompt, temperature=0.1, max_tokens=min(800, len(text)+100))
        return out if out else text

    def create_clinical_scenarios(self, question: str, answer: str) -> list:
        """Rewrite *question* for several clinical settings.

        Returns:
            A list of ``(scenario_question, answer, scenario_type)`` tuples,
            one per setting for which a valid, actually rewritten question
            was obtained. May be empty.
        """
        scenarios = []

        # (prompt template, scenario label) for each clinical setting.
        context_prompts = [
            (
                "Rewrite this medical question as if asked by a patient in an emergency room setting. Return only the rewritten question without any introduction or commentary:\n\n{question}",
                "emergency_room"
            ),
            (
                "Rewrite this medical question as if asked by a patient during a routine checkup. Return only the rewritten question without any introduction or commentary:\n\n{question}",
                "routine_checkup"
            ),
            (
                "Rewrite this medical question as if asked by a patient with chronic conditions. Return only the rewritten question without any introduction or commentary:\n\n{question}",
                "chronic_care"
            ),
            (
                "Rewrite this medical question as if asked by a patient's family member. Return only the rewritten question without any introduction or commentary:\n\n{question}",
                "family_inquiry"
            )
        ]

        for prompt_template, scenario_type in context_prompts:
            try:
                prompt = prompt_template.format(question=question)
                scenario_question = self.paraphrase(question, difficulty="hard", custom_prompt=prompt)

                if not scenario_question or self._is_invalid_response(scenario_question):
                    continue
                # paraphrase() hands back its input when generation failed or
                # was skipped; that is not a new scenario, so do not record it.
                if scenario_question.strip() == question.strip():
                    continue
                scenarios.append((scenario_question, answer, scenario_type))
            except Exception as e:
                logger.warning(f"Failed to create clinical scenario {scenario_type}: {e}")
                continue

        return scenarios

    def _is_invalid_response(self, text: str) -> bool:
        """Return True if *text* is unusable as a generated response.

        Unusable means: not a non-empty string, fewer than three characters
        after stripping, or containing a refusal/failure marker as a whole
        word or phrase (case-insensitive).
        """
        if not text or not isinstance(text, str):
            return True

        text_lower = text.lower().strip()
        if len(text_lower) < 3:
            return True

        return self._INVALID_RE.search(text_lower) is not None