processing / utils /llm.py
LiamKhoaLe's picture
Upd RAG schema to QAC format
19d62ff
raw
history blame
7.84 kB
# Round-robin rotator + paraphrasing + translation/backtranslation
import os
import logging
import requests
from typing import Optional
from google import genai
# Module-wide logger. The handler guard keeps re-imports (or an application
# that already configured the "llm" logger) from attaching duplicate handlers.
logger = logging.getLogger("llm")
if not logger.handlers:
    logger.setLevel(logging.INFO)
    logger.addHandler(logging.StreamHandler())
# Truncate LLM text to its first few words so log lines stay short
def snip(s: str, n: int = 12) -> str:
    """Return a short, whitespace-normalised preview of *s* for log output.

    Args:
        s: Text to preview. Anything that is not a ``str`` yields ``"∅"``.
        n: Maximum number of whitespace-separated words to keep.

    Returns:
        The first *n* words joined by single spaces, followed by ``" …"``
        when the text contained more than *n* words.
    """
    if not isinstance(s, str):
        return "∅"
    words = s.split()
    preview = " ".join(words[:n])
    if len(words) > n:
        preview += " …"
    return preview
class KeyRotator:
    """Round-robin selector over API keys read from numbered environment variables.

    Keys are looked up as ``{env_prefix}_1`` … ``{env_prefix}_{max_keys}``.
    Keys reported through :meth:`mark_bad` are quarantined for the lifetime
    of the instance and skipped by :meth:`next_key`.
    """

    def __init__(self, env_prefix: str, max_keys: int = 5):
        """Collect the configured keys.

        Args:
            env_prefix: Environment variable prefix (without the ``_<n>`` suffix).
            max_keys: Highest numeric suffix that is probed.
        """
        candidates = (os.getenv(f"{env_prefix}_{i}") for i in range(1, max_keys + 1))
        # Strip *before* testing for emptiness: a whitespace-only variable would
        # otherwise be stored as the key "", which callers treat as "no key" and
        # which mark_bad() refuses to quarantine, so it would keep recurring.
        keys = [value.strip() for value in candidates if value and value.strip()]
        if not keys:
            logger.warning(f"[LLM] No keys found for prefix {env_prefix}_*")
        self.keys = keys
        self.dead = set()
        self.idx = 0

    def next_key(self) -> Optional[str]:
        """Return the next non-quarantined key, or ``None`` if none is usable."""
        if not self.keys:
            return None
        # At most one full lap: after len(keys) misses every key is quarantined.
        for _ in range(len(self.keys)):
            k = self.keys[self.idx % len(self.keys)]
            self.idx += 1
            if k not in self.dead:
                return k
        return None

    def mark_bad(self, key: Optional[str]):
        """Quarantine *key* so that ``next_key`` no longer returns it.

        Falsy keys (``None`` or ``""``) are ignored.
        """
        if key:
            self.dead.add(key)
            # Only the first six characters of the key are written to the log.
            logger.warning(f"[LLM] Quarantined key (prefix hidden): {key[:6]}***")
class GeminiClient:
    """Thin wrapper around ``genai.Client`` that takes its API key from a KeyRotator."""

    def __init__(self, rotator: KeyRotator, default_model: str):
        self.rotator = rotator
        self.default_model = default_model

    def generate(self, prompt: str, model: Optional[str] = None, temperature: float = 0.2, max_output_tokens: int = 512) -> Optional[str]:
        """Send *prompt* to Gemini once and return the response text.

        Args:
            prompt: Text passed as ``contents`` to ``generate_content``.
            model: Model name; ``self.default_model`` is used when omitted.
            temperature: Accepted for interface parity with the other client.
            max_output_tokens: Accepted for interface parity with the other client.

        Returns:
            The response text, or ``None`` when no key is available, the
            response carries no text, or the request raised.
        """
        # NOTE(review): temperature and max_output_tokens are not forwarded to
        # generate_content below, so they currently have no effect on the request.
        api_key = self.rotator.next_key()
        if not api_key:
            return None
        try:
            # NOTE: matches your required pattern/use
            response = genai.Client(api_key=api_key).models.generate_content(
                model=model or self.default_model,
                contents=prompt,
            )
            answer = getattr(response, "text", None)
            if not answer:
                # An empty response is not treated as a key failure.
                return None
            logger.info(f"[LLM][Gemini] out={snip(answer)}")
            return answer
        except Exception as err:
            # Any failure quarantines the key that was used for this call.
            logger.error(f"[LLM][Gemini] {err}")
            self.rotator.mark_bad(api_key)
            return None
class NvidiaClient:
    """HTTP client for the NVIDIA chat-completions endpoint, keyed through a KeyRotator."""

    def __init__(self, rotator: KeyRotator, default_model: str):
        self.rotator = rotator
        self.default_model = default_model
        # The endpoint can be overridden through the NVIDIA_API_URL environment variable.
        self.url = os.getenv("NVIDIA_API_URL", "https://integrate.api.nvidia.com/v1/chat/completions")

    def _clean_resp(self, resp: str) -> str:
        """Strip conversational boilerplate from the start of a model reply.

        Removes a leading "Sure"/"Okay" acknowledgement, a leading
        "Here is a/the …:" introduction and a leading "Paraphrased:" label
        (case-insensitive). Falsy input is returned unchanged.
        """
        if not resp:
            return resp
        import re  # function-scope import: the module does not import re at top level

        txt = resp.strip()
        # Acknowledgements are removed first so that "Sure, here is the text: …"
        # is fully cleaned. The \b keeps real words such as "Surely" or
        # "Okayama" intact.
        for pat in (
            r"^Sure\b[,.!]?\s*",
            r"^Okay\b[,.!]?\s*",
            r"^Here is (a|the) .*?:\s*",
            r"^Paraphrased(?: version)?:\s*",
        ):
            txt = re.sub(pat, "", txt, flags=re.I)
        return txt.strip()

    def generate(self, prompt: str, model: Optional[str] = None, temperature: float = 0.2, max_tokens: int = 512) -> Optional[str]:
        """Send *prompt* as a single user message and return the cleaned reply.

        Args:
            prompt: Content of the user message.
            model: Model name; ``self.default_model`` is used when omitted.
            temperature: Sampling temperature placed in the request payload.
            max_tokens: Completion token limit placed in the request payload.

        Returns:
            The reply text after ``_clean_resp``, or ``None`` when no key is
            available or the request failed.
        """
        key = self.rotator.next_key()
        if not key:
            return None
        try:
            headers = {"Authorization": f"Bearer {key}", "Content-Type": "application/json"}
            payload = {
                "model": model or self.default_model,
                "messages": [{"role": "user", "content": prompt}],
                "temperature": temperature,
                "max_tokens": max_tokens
            }
            r = requests.post(self.url, headers=headers, json=payload, timeout=45)
            if r.status_code >= 400:
                # Only the first 200 characters of the error body are kept for the log.
                raise RuntimeError(f"HTTP {r.status_code}: {r.text[:200]}")
            data = r.json()
            text = data["choices"][0]["message"]["content"]
            clean = self._clean_resp(text)
            logger.info(f"[LLM][NVIDIA] out={snip(clean)}")
            return clean
        except Exception as e:
            # NOTE(review): every failure (timeouts and malformed responses
            # included) quarantines the key for the rest of the process.
            logger.error(f"[LLM][NVIDIA] {e}")
            self.rotator.mark_bad(key)
            return None
class Paraphraser:
    """Prefers NVIDIA (cheap), falls back to Gemini.

    Offers paraphrasing, translation, back-translation and a small
    LLM-based consistency judge. Every operation tries the NVIDIA client
    first and only calls Gemini when NVIDIA returned nothing.
    """

    # Language names used in the back-translation prompt. Codes that are not
    # listed are inserted into the prompt as given.
    _LANG_NAMES = {"vi": "Vietnamese"}

    def __init__(self, nvidia_model: str, gemini_model_easy: str, gemini_model_hard: str):
        self.nv = NvidiaClient(KeyRotator("NVIDIA_API"), nvidia_model)
        # Both Gemini clients read the same GEMINI_API_* keys, so they share one
        # rotator: a key quarantined through one client is skipped by the other.
        gemini_keys = KeyRotator("GEMINI_API")
        self.gm_easy = GeminiClient(gemini_keys, gemini_model_easy)
        self.gm_hard = GeminiClient(gemini_keys, gemini_model_hard)

    def _clean_resp(self, resp: str) -> str:
        """Strip conversational boilerplate from the start of a model reply.

        Removes a leading "Sure"/"Okay" acknowledgement, a leading
        "Here is a/the …:" introduction and a leading "Paraphrased:" label
        (case-insensitive). Falsy input is returned unchanged.
        """
        if not resp:
            return resp
        import re  # function-scope import: the module does not import re at top level

        txt = resp.strip()
        # Acknowledgements are removed first so that "Sure, here is the text: …"
        # is fully cleaned. The \b keeps real words such as "Surely" intact.
        for pat in (
            r"^Sure\b[,.!]?\s*",
            r"^Okay\b[,.!]?\s*",
            r"^Here is (a|the) .*?:\s*",
            r"^Paraphrased(?: version)?:\s*",
        ):
            txt = re.sub(pat, "", txt, flags=re.I)
        return txt.strip()

    # ————— Paraphrase —————
    def paraphrase(self, text: str, difficulty: str = "easy") -> str:
        """Return a paraphrase of *text*, or *text* itself when none is produced.

        Args:
            text: Source text. Empty input and inputs shorter than 12
                characters are returned unchanged.
            difficulty: ``"easy"`` selects the easy Gemini model for the
                fallback; any other value selects the hard one.
        """
        if not text or len(text) < 12:
            return text
        prompt = (
            "Paraphrase the following medical text concisely, preserve meaning and clinical terms.\n"
            "Do not fabricate or remove factual claims.\n"
            "Return ONLY the rewritten text, without any introduction, commentary.\n"+ text
        )
        budget = min(600, max(128, len(text) // 2))
        out = self.nv.generate(prompt, temperature=0.1, max_tokens=budget)
        if out:
            cleaned = self._clean_resp(out)
            if cleaned:
                return cleaned
        gm = self.gm_easy if difficulty == "easy" else self.gm_hard
        out = gm.generate(prompt, max_output_tokens=budget)
        # A reply that is pure boilerplate cleans down to "", so fall back to
        # the original text instead of returning an empty paraphrase.
        cleaned = self._clean_resp(out) if out else None
        return cleaned or text

    # ————— Translate & Backtranslate —————
    def translate(self, text: str, target_lang: str = "vi") -> Optional[str]:
        """Translate *text* into *target_lang*.

        Returns the stripped translation, ``None`` when both backends
        return nothing, or *text* unchanged when it is empty.
        """
        if not text:
            return text
        prompt = f"Translate to {target_lang}. Keep meaning exact, preserve medical terms:\n\n{text}"
        budget = min(800, len(text) + 100)
        out = self.nv.generate(prompt, temperature=0.0, max_tokens=budget)
        if out:
            return out.strip()
        out = self.gm_easy.generate(prompt, max_output_tokens=budget)
        # Strip the Gemini result too, matching the NVIDIA path and backtranslate().
        return out.strip() if out else None

    def backtranslate(self, text: str, via_lang: str = "vi") -> Optional[str]:
        """Translate *text* into *via_lang* and back into English.

        Returns the stripped round-trip text, ``None`` when either leg
        fails, or *text* unchanged when it is empty.
        """
        if not text:
            return text
        mid = self.translate(text, target_lang=via_lang)
        if not mid:
            return None
        # Name the pivot language that was actually used; the prompt for the
        # default "vi" is unchanged.
        lang_name = self._LANG_NAMES.get(via_lang, via_lang)
        prompt = f"Translate the following {lang_name} text back to English, preserving the exact meaning:\n\n{mid}"
        budget = min(900, len(text) + 150)
        out = self.nv.generate(prompt, temperature=0.0, max_tokens=budget)
        if out:
            return out.strip()
        res = self.gm_easy.generate(prompt, max_output_tokens=budget)
        return res.strip() if res else None

    # ————— Consistency Judge (cheap, LLM-based) —————
    def consistency_check(self, user: str, output: str) -> bool:
        """Return True if 'output' appears supported by 'user' (context/question). Soft heuristic via LLM."""
        prompt = (
            "You are a strict medical QA validator. Given the USER input (question+context) "
            "and the MODEL ANSWER, reply with exactly 'PASS' if the answer is supported and safe, "
            "otherwise 'FAIL'. No extra text.\n\n"
            f"USER:\n{user}\n\nANSWER:\n{output}"
        )
        out = self.nv.generate(prompt, temperature=0.0, max_tokens=3)
        if not out:
            out = self.gm_easy.generate(prompt, max_output_tokens=3)
        # No verdict from either backend counts as a failed check.
        return isinstance(out, str) and "PASS" in out.upper()