Spaces:
Running
Running
| import torch | |
| import numpy as np | |
| from sentence_transformers import CrossEncoder | |
| from huggingface_hub import login | |
| class NLIVerifier: | |
| def __init__(self, model_name="cross-encoder/nli-distilroberta-base", hf_token=None): | |
| """ | |
| Initialize the NLI model using CrossEncoder. | |
| """ | |
| self.device = "cuda" if torch.cuda.is_available() else "cpu" | |
| print(f"Loading NLI Model ({self.device})...") | |
| if hf_token: | |
| try: | |
| login(token=hf_token) | |
| print("Logged in to Hugging Face.") | |
| except Exception as e: | |
| print(f"HF Login Warning: {e}") | |
| try: | |
| self.model = CrossEncoder(model_name, device=self.device) | |
| print("NLI Model Loaded Successfully.") | |
| except Exception as e: | |
| print(f"Error loading model: {e}") | |
| self.model = None | |
| # Label mapping for cross-encoder/nli-distilroberta-base | |
| # 0: Contradiction | |
| # 1: Entailment | |
| # 2: Neutral | |
| self.labels = ["Contradiction", "Entailment", "Neutral"] | |
| def predict(self, text1, text2): | |
| """ | |
| Verify if text1 and text2 contradict each other. | |
| Returns: (IsContradiction: bool, Confidence: float, Label: str) | |
| """ | |
| if not self.model: | |
| return False, 0.0, "Model Error" | |
| # CrossEncoder returns logits | |
| scores = self.model.predict([(text1, text2)])[0] | |
| # Apply softmax to get probabilities | |
| exp_scores = np.exp(scores) | |
| probs = exp_scores / np.sum(exp_scores) | |
| pred_label_idx = probs.argmax() | |
| confidence = probs[pred_label_idx] | |
| label = self.labels[pred_label_idx] | |
| # Check if Contradiction (Index 0) is the winner with high confidence | |
| is_contradiction = (pred_label_idx == 0 and confidence > 0.5) | |
| return is_contradiction, float(confidence), label | |