Spaces:
Build error
Build error
| # from transformers import pipeline | |
| from transformers import BartForSequenceClassification, BartTokenizer | |
| import gradio as grad | |
| # zero_shot_classifier = pipeline("zero-shot-classification") | |
# Single source of truth for the checkpoint so the tokenizer and the model
# are always loaded from the same pretrained weights.
_CHECKPOINT = 'facebook/bart-large-mnli'

# Loaded once at import time and shared by every call to classify().
bart_tkn = BartTokenizer.from_pretrained(_CHECKPOINT)
mdl = BartForSequenceClassification.from_pretrained(_CHECKPOINT)
| # def classify(text, labels): | |
def classify(text: str, label: str) -> float:
    """Score how strongly ``text`` entails ``label`` with the BART-MNLI model.

    ``text`` is encoded as the premise and ``label`` as the hypothesis of a
    single sequence pair. The whole ``label`` string is used as one
    hypothesis; it is not split on commas.

    Args:
        text: The passage to classify.
        label: A single candidate label.

    Returns:
        The entailment probability as a percentage in the range 0-100,
        computed from a softmax over the contradiction and entailment
        logits only (the neutral logit is discarded).
    """
    # Function-scope import: the file has no top-level ``import torch``.
    import torch

    # Truncate only the premise so an over-long text cannot exceed the
    # model's maximum input length; the label is always kept intact.
    tkn_ids = bart_tkn.encode(
        text, label, return_tensors="pt", truncation="only_first"
    )

    # Pure inference: skip autograd bookkeeping to save memory per request.
    with torch.no_grad():
        tkn_lgts = mdl(tkn_ids)[0]

    # bart-large-mnli orders its logits (contradiction, neutral, entailment);
    # keep columns 0 and 2 and drop neutral.
    entail_contra_tkn_lgts = tkn_lgts[:, [0, 2]]
    probab = entail_contra_tkn_lgts.softmax(dim=1)

    # Column 1 of the reduced pair is entailment; report it as a percentage.
    return probab[:, 1].item() * 100
# Gradio UI: one input for the text, one for the candidate label, and one
# output box that shows the value returned by classify().
txt = grad.Textbox(lines=1, label="English", placeholder="text to be classified")
# classify() passes this entire string to the model as ONE hypothesis and
# does not split on commas, so the prompt asks for a single label.
labels = grad.Textbox(lines=1, label="Label", placeholder="a single candidate label")
out = grad.Textbox(lines=1, label="Classification")

grad.Interface(
    classify,
    inputs=[txt, labels],
    outputs=out,
).launch()