File size: 914 Bytes
f408130
079dbdb
f408130
 
 
e9260a3
f408130
e9260a3
3615152
 
67d1220
b663fcc
 
 
 
 
 
 
588e461
e9260a3
 
 
b663fcc
 
e9260a3
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27

from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
# Hugging Face Hub checkpoint used for both the model and the tokenizer.
# Keeping it in one constant guarantees the two are always loaded from the
# same checkpoint (a mismatched tokenizer would silently produce bad scores).
MODEL_NAME = 'facebook/bart-large-mnli'

model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Shared zero-shot classification pipeline built from the model/tokenizer
# above; called by `classify` for every request.
classifier = pipeline("zero-shot-classification",
                      model=model, tokenizer=tokenizer)

import gradio as gr

def classify(input_query, input_classes, input_multi_class):
    """Score a query against user-supplied candidate labels.

    Args:
        input_query: Text to classify (the "Search Query" textbox).
        input_classes: Candidate labels as one comma-separated string
            (the "Candidate Classes" textbox), e.g. "sports, politics".
        input_multi_class: Value of the "Multi_class" checkbox; forwarded to
            the pipeline as ``multi_label`` (score each label independently
            instead of across all labels).

    Returns:
        A ``{label: score}`` dict for ``gr.Label``. The dict is empty when the
        query is blank or no non-blank label is given, instead of letting the
        pipeline raise on empty input.
    """
    # Parse the textbox value here so blank entries (e.g. from a trailing
    # comma) are dropped before they reach the pipeline.
    candidate_labels = [
        label.strip()
        for label in (input_classes or "").split(",")
        if label.strip()
    ]
    if not (input_query and input_query.strip()) or not candidate_labels:
        return {}

    # `multi_class` is the deprecated name of this option in the transformers
    # zero-shot pipeline; `multi_label` is the current keyword.
    res = classifier(input_query, candidate_labels,
                     multi_label=input_multi_class)
    return dict(zip(res['labels'], res['scores']))

# The three inputs map positionally onto
# classify(input_query, input_classes, input_multi_class); the returned
# {label: score} dict is rendered by gr.Label.
demo = gr.Interface(
    fn=classify,
    inputs=[
        gr.Text(label='Search Query'),
        gr.Text(label='Candidate Classes',
                placeholder='Comma-separated, e.g. sports, politics, technology'),
        gr.Checkbox(label='Multi_class'),
    ],
    outputs=gr.Label(label='Prediction:'),
)

# Start the web server only when this file is run as a script, so importing
# the module (e.g. from tests or another app) does not launch the UI.
if __name__ == "__main__":
    demo.launch()