Spaces:
Running
Running
| from __future__ import annotations | |
| import logging | |
| import re | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| logger = logging.getLogger(__name__) | |
class NLIEvaluator:
    """Scores factual claims against evidence passages with an NLI cross-encoder.

    ``evaluate`` returns a hallucination-risk score in [0, 1] where LOWER means
    the claim is better supported by the evidence (entailed), and higher means
    neutral/contradicted.
    """

    def __init__(self, model_name: str = "pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-sst2") -> None:
        """Load the NLI tokenizer and sequence-classification model.

        Args:
            model_name: Hugging Face model id of an NLI cross-encoder whose
                labels contain 'entail' / 'neutral' / 'contradict'.
        """
        import os
        import warnings
        os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
        logger.info("Loading NLI model: %s", model_name)
        # No pipeline batch_size here: prevents a PyTorch DataLoader deadlock
        # on Windows CPU.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        # Inference only: make eval mode explicit (disables dropout).
        self.model.eval()
        # Use the module logger (not print) for consistency with the rest of
        # the module's diagnostics.
        logger.info("Loaded model: %s", self.model.config._name_or_path)
        logger.info("Labels: %s", self.model.config.id2label)

    def _chunk_evidence(self, text: str, window_size: int = 3, stride: int = 1) -> list[str]:
        """Split *text* into overlapping sentence windows.

        Sentences of 15 characters or fewer (after stripping) are dropped as
        noise; if no sentence survives, the whole text is returned as one chunk.

        Args:
            text: Raw evidence passage.
            window_size: Number of sentences per chunk.
            stride: Sentence step between consecutive chunk starts.

        Returns:
            Non-empty list of sentence-window strings.
        """
        sentences = [s.strip() for s in re.split(r'(?<=[.!?])\s+', text) if len(s.strip()) > 15]
        if not sentences:
            return [text]
        chunks = []
        for start in range(0, len(sentences), stride):
            chunks.append(" ".join(sentences[start:start + window_size]))
            # Stop once a window has covered the final sentence.
            if start + window_size >= len(sentences):
                break
        return chunks

    def evaluate(self, claim: str, evidence_texts: list[str]) -> float:
        """Score *claim* against *evidence_texts* with the NLI model.

        Each evidence text is chunked into sentence windows; every chunk is
        scored as ``0.5 * P(neutral) + P(contradiction)`` so that lower is
        better supported.

        Args:
            claim: The claim to verify.
            evidence_texts: Supporting passages; may be empty.

        Returns:
            Score in [0, 1]. Returns 1.0 (worst case) when there is no
            evidence or when every chunk fails to score.
        """
        if not evidence_texts:
            return 1.0

        chunked_evidences = []
        for text in evidence_texts:
            chunked_evidences.extend(self._chunk_evidence(text, window_size=3, stride=3))

        all_scores = []
        for evidence in chunked_evidences:
            try:
                inputs = self.tokenizer(
                    evidence,
                    claim,
                    return_tensors="pt",
                    truncation=True,
                    max_length=512
                )
                with torch.no_grad():
                    outputs = self.model(**inputs)
                probs = torch.softmax(outputs.logits, dim=-1)
                contradict_prob = 0.0
                neutral_prob = 0.0
                # Map by label NAME, not index: id2label order varies between
                # NLI checkpoints. Entailment probability is not needed — it
                # only lowers the score by being absent from the sum below.
                for idx, label in self.model.config.id2label.items():
                    prob = probs[0][idx].item()
                    label_lower = label.lower()
                    if 'contradict' in label_lower:
                        contradict_prob = prob
                    elif 'neutral' in label_lower:
                        neutral_prob = prob
                # Neutral counts half, contradiction counts fully toward risk.
                all_scores.append((0.5 * neutral_prob) + contradict_prob)
            except Exception as e:
                # Best-effort: a single bad chunk must not abort the whole
                # evaluation.
                logger.warning("NLI Evaluation failed for a pair: %s", e)

        if not all_scores:
            return 1.0

        all_scores.sort()
        min_score = all_scores[0]
        max_score = all_scores[-1]  # list is sorted; last element is the max
        if min_score < 0.05 and max_score < 0.7:
            # Strong support exists (< 0.05) and no strong contradiction
            # (>= 0.7): trust the best-supporting chunk.
            result = min_score
        else:
            # Mixed or weak evidence: fall back to the 25th-percentile score.
            result = all_scores[len(all_scores) // 4]
        # Unverified-claim handler: ambiguous mid-range scores are pushed to a
        # sentinel value flagging the claim as unverified.
        if 0.45 < result < 0.55:
            result = 0.8501
        return result