krushna commited on
Commit
b6e8b15
·
1 Parent(s): e8b2309

create app file

Browse files
Files changed (1) hide show
  1. app.py +16 -0
app.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import BartTokenizer, BartForConditionalGeneration
3
+ tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
4
+ model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
5
+ def predict(text, number of predictions):
6
+ TXT = text + "<mask>"
7
+ input_ids = tokenizer([TXT], return_tensors="pt")["input_ids"]
8
+ logits = model(input_ids).logits
9
+
10
+ masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
11
+ probs = logits[0, masked_index].softmax(dim=0)
12
+ values, predictions = probs.topk(5)
13
+ list_of_words = tokenizer.decode(predictions).split()
14
+ return ', '.join([str(elem) for elem in list_of_words])
15
+ intr = gr.Interface(predict, ["text","number"], "text")
16
+ intr.launch(inline = False)