import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification as ASC
from transformers import AutoTokenizer as AT

# Hugging Face Hub checkpoint used for both the tokenizer and the classifier.
MODEL_NAME = "rickxzo/albert-large-v2-s.a.m-nli"


@st.cache_resource
def _load_model_and_tokenizer(name=MODEL_NAME):
    """Load the sequence-classification model and its tokenizer for *name*.

    Streamlit re-executes this entire script on every widget interaction.
    ``st.cache_resource`` keeps a single loaded copy per *name* across
    reruns (and sessions) instead of re-reading the weights each time.
    The returned objects are therefore shared and should be treated as
    read-only.
    """
    return ASC.from_pretrained(name), AT.from_pretrained(name)


model, tokenizer = _load_model_and_tokenizer()
def infer(sentence1, sentence2, clf=None, tok=None):
    """Classify the relation between a premise and a hypothesis.

    Args:
        sentence1: Premise text (a str, or a list of str for a batch).
        sentence2: Hypothesis text, shaped like *sentence1*.
        clf: Sequence-classification model to run; defaults to the
            module-level ``model``.
        tok: Tokenizer to encode the pair(s); defaults to the module-level
            ``tokenizer``.

    Returns:
        The predicted class index as an ``int`` for a single pair, or a
        list of ``int`` (one per pair) for a batch. The UI in this script
        treats 0/1/2 as entailment/neutral/contradiction; presumably that
        matches the checkpoint's ``id2label`` mapping (TODO confirm).
    """
    clf = model if clf is None else clf
    tok = tokenizer if tok is None else tok
    inputs = tok(
        sentence1,
        sentence2,
        return_tensors="pt",
        truncation=True,
        padding=True,
    )
    with torch.no_grad():
        logits = clf(**inputs).logits
    # Softmax is monotonic, so argmax over the raw logits selects the same
    # class. Reduce along the label axis so a batch yields one label per
    # pair instead of an index into the flattened (batch * labels) tensor.
    preds = torch.argmax(logits, dim=-1)
    return preds.item() if preds.numel() == 1 else preds.tolist()
st.title("Contradiction Detector using AlBERT model")
premise = st.text_area("Enter the premise: ")
hypothesis = st.text_area("Enter the hypothesis: ")

# Markdown rendered for each class index returned by infer().
_VERDICTS = {
    2: "#### **Contradicting Statements Detected!**",
    1: "#### **Neutral Statements Detected.**",
    0: "#### **Entailing Statements Detected.**",
}

# Only run the model once both boxes contain non-whitespace text.
if premise.strip() and hypothesis.strip():
    k = infer(premise, hypothesis)
    verdict = _VERDICTS.get(k)
    if verdict is None:
        # Surface an unexpected index instead of rendering nothing.
        st.warning(f"Unexpected label index from the model: {k}")
    else:
        st.write(verdict)