# Round-robin rotator + paraphrasing + translation/backtranslation
import os
import re
import logging
from typing import Optional

import requests
from google import genai

logger = logging.getLogger("llm")
if not logger.handlers:
    logger.setLevel(logging.INFO)
    handler = logging.StreamHandler()
    logger.addHandler(handler)

# Boilerplate prefixes LLMs commonly prepend to answers; compiled once at
# import time instead of re-importing/re-building `re` patterns per call.
_BOILERPLATE_PATTERNS = [
    re.compile(r"^Here is (a|the) .*?:\s*", re.I),
    re.compile(r"^Paraphrased(?: version)?:\s*", re.I),
    re.compile(r"^Sure[,.]?\s*", re.I),
    re.compile(r"^Okay[,.]?\s*", re.I),
]


# LLM parser limit text to log-out
def snip(s: str, n: int = 12) -> str:
    """Return at most the first *n* whitespace-separated tokens of *s*
    (with a trailing ellipsis if truncated) for compact log output."""
    if not isinstance(s, str):
        return "∅"
    parts = s.strip().split()
    return " ".join(parts[:n]) + (" …" if len(parts) > n else "")


def _strip_boilerplate(resp: str) -> str:
    """Strip common LLM boilerplate prefixes from *resp*.

    Shared implementation for NvidiaClient._clean_resp and
    Paraphraser._clean_resp, which were previously duplicated verbatim.
    Each pattern is applied once, in order, exactly as before.
    """
    if not resp:
        return resp
    txt = resp.strip()
    for pat in _BOILERPLATE_PATTERNS:
        txt = pat.sub("", txt)
    return txt.strip()


class KeyRotator:
    """Round-robin rotation over numbered env-var API keys
    (``{env_prefix}_1`` … ``{env_prefix}_{max_keys}``), with a quarantine
    set for keys that have failed."""

    def __init__(self, env_prefix: str, max_keys: int = 5):
        keys = []
        for i in range(1, max_keys + 1):
            v = os.getenv(f"{env_prefix}_{i}")
            if v:
                keys.append(v.strip())
        if not keys:
            logger.warning(f"[LLM] No keys found for prefix {env_prefix}_*")
        self.keys = keys
        self.dead = set()  # keys quarantined by mark_bad()
        self.idx = 0       # round-robin cursor (monotonically increasing)

    def next_key(self) -> Optional[str]:
        """Return the next non-quarantined key, or None if none remain."""
        if not self.keys:
            return None
        # At most one full pass over the ring; if every key is dead, give up.
        for _ in range(len(self.keys)):
            k = self.keys[self.idx % len(self.keys)]
            self.idx += 1
            if k not in self.dead:
                return k
        return None

    def mark_bad(self, key: Optional[str]):
        """Quarantine *key* so next_key() stops returning it for this run."""
        if key:
            self.dead.add(key)
            # NOTE(review): despite "prefix hidden", this logs the first 6
            # chars of the key — confirm that exposure level is acceptable.
            logger.warning(f"[LLM] Quarantined key (prefix hidden): {key[:6]}***")


class GeminiClient:
    """Minimal google-genai wrapper with key rotation and quarantine."""

    def __init__(self, rotator: KeyRotator, default_model: str):
        self.rotator = rotator
        self.default_model = default_model

    def generate(self, prompt: str, model: Optional[str] = None,
                 temperature: float = 0.2,
                 max_output_tokens: int = 512) -> Optional[str]:
        """Generate text for *prompt*; returns None on failure (the key used
        is quarantined so the next call rotates to a different one)."""
        key = self.rotator.next_key()
        if not key:
            return None
        try:
            client = genai.Client(api_key=key)
            # NOTE: matches your required pattern/use
            # TODO(review): `temperature` and `max_output_tokens` are accepted
            # but not forwarded to the API here — pass a generation config if
            # they are meant to take effect.
            res = client.models.generate_content(
                model=model or self.default_model,
                contents=prompt
            )
            text = getattr(res, "text", None)
            if text:
                logger.info(f"[LLM][Gemini] out={snip(text)}")
                return text
        except Exception as e:
            # Any failure (auth, quota, network) quarantines this key.
            logger.error(f"[LLM][Gemini] {e}")
            self.rotator.mark_bad(key)
        return None


class NvidiaClient:
    """OpenAI-compatible chat-completions client for the NVIDIA endpoint,
    with key rotation and boilerplate cleaning."""

    def __init__(self, rotator: KeyRotator, default_model: str):
        self.rotator = rotator
        self.default_model = default_model
        self.url = os.getenv(
            "NVIDIA_API_URL",
            "https://integrate.api.nvidia.com/v1/chat/completions"
        )

    # Regex-based cleaning resp from quotes
    def _clean_resp(self, resp: str) -> str:
        """Delegate to the shared module-level cleaner (behavior unchanged;
        this was previously duplicated in Paraphraser)."""
        return _strip_boilerplate(resp)

    def generate(self, prompt: str, model: Optional[str] = None,
                 temperature: float = 0.2,
                 max_tokens: int = 512) -> Optional[str]:
        """POST a single-turn chat completion; returns cleaned text or None.

        HTTP >= 400 is raised and treated like any other failure: logged and
        the key quarantined.
        """
        key = self.rotator.next_key()
        if not key:
            return None
        try:
            headers = {"Authorization": f"Bearer {key}",
                       "Content-Type": "application/json"}
            payload = {
                "model": model or self.default_model,
                "messages": [{"role": "user", "content": prompt}],
                "temperature": temperature,
                "max_tokens": max_tokens
            }
            r = requests.post(self.url, headers=headers, json=payload, timeout=45)
            if r.status_code >= 400:
                raise RuntimeError(f"HTTP {r.status_code}: {r.text[:200]}")
            data = r.json()
            text = data["choices"][0]["message"]["content"]
            clean = self._clean_resp(text)
            logger.info(f"[LLM][NVIDIA] out={snip(clean)}")
            return clean
        except Exception as e:
            logger.error(f"[LLM][NVIDIA] {e}")
            self.rotator.mark_bad(key)
        return None


class Paraphraser:
    """Prefers NVIDIA (cheap), falls back to Gemini.
    Also offers translate/backtranslate and a tiny consistency judge."""

    def __init__(self, nvidia_model: str, gemini_model_easy: str,
                 gemini_model_hard: str):
        self.nv = NvidiaClient(KeyRotator("NVIDIA_API"), nvidia_model)
        self.gm_easy = GeminiClient(KeyRotator("GEMINI_API"), gemini_model_easy)
        self.gm_hard = GeminiClient(KeyRotator("GEMINI_API"), gemini_model_hard)

    # Regex-based cleaning resp from quotes
    def _clean_resp(self, resp: str) -> str:
        """Delegate to the shared module-level cleaner (behavior unchanged;
        this was previously duplicated in NvidiaClient)."""
        return _strip_boilerplate(resp)

    # ————— Paraphrase —————
    def paraphrase(self, text: str, difficulty: str = "easy") -> str:
        """Paraphrase *text*; very short inputs are returned unchanged.
        Tries NVIDIA first, then the easy/hard Gemini model; falls back to
        the original text when both providers fail."""
        if not text or len(text) < 12:
            return text
        prompt = (
            "Paraphrase the following medical text concisely, preserve meaning and clinical terms.\n"
            "Do not fabricate or remove factual claims.\n"
            "Return ONLY the rewritten text, without any introduction, commentary.\n" + text
        )
        # Output budget scales with input length, clamped to [128, 600].
        out = self.nv.generate(prompt, temperature=0.1,
                               max_tokens=min(600, max(128, len(text) // 2)))
        if out:
            return self._clean_resp(out)
        gm = self.gm_easy if difficulty == "easy" else self.gm_hard
        out = gm.generate(prompt,
                          max_output_tokens=min(600, max(128, len(text) // 2)))
        return self._clean_resp(out) if out else text

    # ————— Translate & Backtranslate —————
    def translate(self, text: str, target_lang: str = "vi") -> Optional[str]:
        """Translate *text* to *target_lang*; None when both providers fail."""
        if not text:
            return text
        prompt = f"Translate to {target_lang}. Keep meaning exact, preserve medical terms:\n\n{text}"
        out = self.nv.generate(prompt, temperature=0.0,
                               max_tokens=min(800, len(text) + 100))
        if out:
            return out.strip()
        # FIX: strip the Gemini fallback too, matching the NVIDIA path.
        res = self.gm_easy.generate(prompt,
                                    max_output_tokens=min(800, len(text) + 100))
        return res.strip() if res else None

    def backtranslate(self, text: str, via_lang: str = "vi") -> Optional[str]:
        """Round-trip *text* through *via_lang* and back to English.
        Returns None if either leg fails."""
        if not text:
            return text
        mid = self.translate(text, target_lang=via_lang)
        if not mid:
            return None
        # FIX: the prompt previously hard-coded "Vietnamese" even though
        # via_lang is a parameter; keep it language-neutral instead.
        prompt = f"Translate the following text back to English, preserving the exact meaning:\n\n{mid}"
        out = self.nv.generate(prompt, temperature=0.0,
                               max_tokens=min(900, len(text) + 150))
        if out:
            return out.strip()
        res = self.gm_easy.generate(prompt,
                                    max_output_tokens=min(900, len(text) + 150))
        return res.strip() if res else None

    # ————— Consistency Judge (cheap, ratio-based) —————
    def consistency_check(self, user: str, output: str) -> bool:
        """Return True if 'output' appears supported by 'user' (context/question).
        Soft heuristic via LLM."""
        prompt = (
            "You are a strict medical QA validator. Given the USER input (question+context) "
            "and the MODEL ANSWER, reply with exactly 'PASS' if the answer is supported and safe, "
            "otherwise 'FAIL'. No extra text.\n\n"
            f"USER:\n{user}\n\nANSWER:\n{output}"
        )
        out = self.nv.generate(prompt, temperature=0.0, max_tokens=3)
        if not out:
            out = self.gm_easy.generate(prompt, max_output_tokens=3)
        # A missing/None response counts as FAIL.
        return isinstance(out, str) and "PASS" in out.upper()