# app.py by anisaallahdadi — last commit 3615152 ("Update app.py"), 914 bytes.
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
# Checkpoint id shared by the model and tokenizer loads below.
MODEL_NAME = 'facebook/bart-large-mnli'

# Load the sequence-classification model and its matching tokenizer once, at
# import time, and wrap them in the zero-shot-classification pipeline that
# classify() calls on every request.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
classifier = pipeline(
    "zero-shot-classification",
    model=model,
    tokenizer=tokenizer,
)
import gradio as gr
def classify(input_query, input_classes, input_multi_class):
    """Score a query against user-supplied candidate labels.

    Args:
        input_query: Text to classify (the "Search Query" textbox).
        input_classes: Comma-separated candidate labels, e.g.
            "sports, politics, tech".
        input_multi_class: Forwarded to the pipeline as ``multi_label``
            (whether several labels may apply at once).

    Returns:
        dict mapping each candidate label to its score, the format
        ``gr.Label`` renders. Empty when the query is blank or no
        non-blank label was given.
    """
    # Split the textbox value into labels, dropping empty entries such as
    # those produced by "a,,b", a trailing comma, or an all-blank field.
    labels = [
        label.strip()
        for label in (input_classes or "").split(",")
        if label.strip()
    ]
    if not (input_query or "").strip() or not labels:
        # The zero-shot pipeline rejects an empty label list; show an empty
        # prediction rather than an error for blank inputs.
        return {}
    # `multi_class` is the deprecated spelling of this pipeline argument;
    # `multi_label` is its current name.
    res = classifier(input_query, labels, multi_label=bool(input_multi_class))
    # The pipeline returns parallel `labels` / `scores` lists.
    return dict(zip(res["labels"], res["scores"]))
# Web UI: two text inputs and a checkbox, passed positionally to classify();
# the {label: score} dict it returns is displayed by a gr.Label component.
query_box = gr.Text(label='Search Query')
classes_box = gr.Text(label='Candidate Classes')
multi_box = gr.Checkbox(label='Multi_class')

demo = gr.Interface(
    fn=classify,
    inputs=[query_box, classes_box, multi_box],
    outputs=gr.Label(label='Prediction:'),
)
# Start the Gradio server (runs whenever this module is executed or imported).
demo.launch()