File size: 518 Bytes
9d21edd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from transformers import pipeline

# Build the NLI classifier a single time at import so every call reuses it
# (only the first load is slow).
nli_pipeline = pipeline(
    task="text-classification",
    model="roberta-large-mnli",
    device=-1,  # -1 selects the CPU
)

def nli_contradiction(text1, text2, threshold=0.8):
    """
    Return True if the NLI model strongly predicts contradiction.

    Args:
        text1: First text of the pair (presumably the premise in NLI
            terms -- confirm against callers).
        text2: Second text of the pair (presumably the hypothesis).
        threshold: Minimum score the top prediction must reach for the
            pair to count as a contradiction. Defaults to 0.8.

    Returns:
        True only when the top predicted label is "CONTRADICTION" and its
        score is at least ``threshold``; False otherwise.
    """
    # Hand the pair to the pipeline as text/text_pair so the tokenizer inserts
    # the model's own separator tokens, instead of splicing a hard-coded
    # "</s></s>" (with stray surrounding spaces) into a single string.
    prediction = nli_pipeline({"text": text1, "text_pair": text2})

    # A dict input may come back as a bare dict or as a one-element list
    # depending on the transformers version; normalise to the dict.
    if isinstance(prediction, list):
        prediction = prediction[0]

    return (
        prediction["label"] == "CONTRADICTION" and
        prediction["score"] >= threshold
    )