# BioRAG — src/bio_rag/nli_evaluator.py
# NLI-based faithfulness/hallucination evaluator (deployed at commit 2a2c039).
from __future__ import annotations
import logging
import re
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
logger = logging.getLogger(__name__)
class NLIEvaluator:
    """Score how well retrieved evidence supports a claim, via a biomedical NLI model.

    Scores are "risk" probabilities in [0, 1]: values near 0 mean some evidence
    chunk strongly entails the claim; values near 1 mean contradiction or lack
    of verification.
    """

    def __init__(self, model_name="pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-sst2"):
        """Load the NLI tokenizer and sequence-classification model.

        Args:
            model_name: Hugging Face model id of an NLI checkpoint whose labels
                include entailment / neutral / contradiction.
        """
        import os
        import warnings
        # Silence the noisy HF cache symlink warning (relevant on Windows).
        os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
        logger.info("Loading NLI model: %s", model_name)
        # NOTE: no pipeline/batch_size here — batched DataLoader workers were
        # observed to deadlock on Windows CPU.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        # Inference only: disable dropout so scores are deterministic.
        self.model.eval()
        logger.info("Loaded NLI model: %s", self.model.config._name_or_path)
        logger.info("NLI labels: %s", self.model.config.id2label)

    def _chunk_evidence(self, text: str, window_size: int = 3, stride: int = 1) -> list[str]:
        """Split *text* into overlapping windows of sentences.

        Sentences of 15 characters or fewer are discarded as noise; if nothing
        survives the filter, the whole text is returned as a single chunk.

        Args:
            text: Raw evidence passage.
            window_size: Number of sentences per chunk.
            stride: Step between window start positions (stride < window_size
                yields overlapping chunks).

        Returns:
            A non-empty list of chunk strings.
        """
        sentences = [s.strip() for s in re.split(r'(?<=[.!?])\s+', text) if len(s.strip()) > 15]
        if not sentences:
            return [text]
        chunks = []
        for start in range(0, len(sentences), stride):
            chunks.append(" ".join(sentences[start:start + window_size]))
            # Stop once the window has covered the final sentence.
            if start + window_size >= len(sentences):
                break
        return chunks

    def evaluate(self, claim: str, evidence_texts: list[str]) -> float:
        """Return a hallucination-risk score in [0, 1] for *claim*.

        Each evidence text is chunked into non-overlapping 3-sentence windows,
        each (evidence, claim) pair is scored by the NLI model, and the
        per-chunk risk is ``0.5 * P(neutral) + P(contradiction)``.

        Args:
            claim: The generated statement to verify.
            evidence_texts: Retrieved passages; may be empty.

        Returns:
            1.0 (fully unverified) when there is no evidence or every NLI pass
            fails; otherwise an aggregated risk score (lower is better).
        """
        if not evidence_texts:
            return 1.0
        # Non-overlapping windows (stride == window_size) keep the pass cheap.
        chunked_evidences = []
        for text in evidence_texts:
            chunked_evidences.extend(self._chunk_evidence(text, window_size=3, stride=3))
        all_scores = []
        for evidence in chunked_evidences:
            try:
                inputs = self.tokenizer(
                    evidence,
                    claim,
                    return_tensors="pt",
                    truncation=True,
                    max_length=512,
                )
                with torch.no_grad():
                    outputs = self.model(**inputs)
                probs = torch.softmax(outputs.logits, dim=-1)
                # Map label names -> probabilities without assuming a fixed
                # label order across checkpoints.
                entail_prob = 0.0
                contradict_prob = 0.0
                neutral_prob = 0.0
                for idx, label in self.model.config.id2label.items():
                    prob = probs[0][idx].item()
                    label_lower = label.lower()
                    if 'entail' in label_lower:
                        entail_prob = prob
                    elif 'contradict' in label_lower:
                        contradict_prob = prob
                    elif 'neutral' in label_lower:
                        neutral_prob = prob
                # Risk: contradiction counts fully, neutrality counts half.
                all_scores.append((0.5 * neutral_prob) + contradict_prob)
            except Exception as e:
                # Best-effort: skip chunks the tokenizer/model chokes on.
                logger.warning(f"NLI Evaluation failed for a pair: {e}")
        if not all_scores:
            return 1.0
        all_scores.sort()
        min_score = all_scores[0]
        max_score = all_scores[-1]  # list is sorted; last element is the max
        if min_score < 0.05 and max_score < 0.7:
            # Strong support exists and nothing strongly contradicts:
            # trust the best-supporting chunk.
            result = min_score
        else:
            # Mixed or weak evidence: fall back to the 25th-percentile score.
            # (The original code had identical elif/else branches; merged.)
            result = all_scores[len(all_scores) // 4]
        # Scores the model is on the fence about (~0.5) are treated as
        # unverified and pushed to a high-risk sentinel value.
        if 0.45 < result < 0.55:
            result = 0.8501
        return result