# Hugging Face Space — status at capture time: Runtime error
# File size: 914 Bytes
# Blame commits: f408130 079dbdb f408130 e9260a3 f408130 e9260a3 3615152 67d1220 b663fcc 588e461 e9260a3 b663fcc e9260a3
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
# Single source of truth for the checkpoint so model and tokenizer cannot drift apart.
_CHECKPOINT = 'facebook/bart-large-mnli'

# Load the facebook/bart-large-mnli model/tokenizer pair and wrap them in a
# zero-shot classification pipeline that classify() calls for every request.
model = AutoModelForSequenceClassification.from_pretrained(_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
classifier = pipeline(
    "zero-shot-classification",
    model=model,
    tokenizer=tokenizer,
)
import gradio as gr
def classify(input_query, input_classes, input_multi_class):
    """Score ``input_query`` against each candidate class with the zero-shot pipeline.

    Args:
        input_query: Text to classify.
        input_classes: Candidate labels, forwarded unchanged to the pipeline
            (from the UI this is a single text field; the pipeline accepts a
            comma-separated string or a list of labels).
        input_multi_class: When true, each label is scored independently
            (several labels may apply); when false, scores are normalized
            across all labels.

    Returns:
        A ``{label: score}`` dict, the shape the ``gr.Label`` output renders.
    """
    # ``multi_class`` is a deprecated alias in transformers' zero-shot pipeline;
    # ``multi_label`` is the supported keyword with the same meaning.
    res = classifier(input_query, input_classes, multi_label=input_multi_class)
    # The pipeline returns parallel ``labels`` / ``scores`` lists; pair them up.
    return dict(zip(res["labels"], res["scores"]))
# Input widgets, listed in the same order as classify()'s positional parameters.
_inputs = [
    gr.Text(label='Search Query'),
    gr.Text(label='Candidate Classes'),
    gr.Checkbox(label='Multi_class'),
]
# gr.Label displays the {label: score} dict that classify() returns.
_output = gr.Label(label='Prediction:')

demo = gr.Interface(fn=classify, inputs=_inputs, outputs=_output)
demo.launch()