Spaces:
Runtime error
Runtime error
create app file
Browse files
app.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from transformers import BartTokenizer, BartForConditionalGeneration
|
| 3 |
+
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
|
| 4 |
+
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
|
| 5 |
+
def predict(text, number of predictions):
|
| 6 |
+
TXT = text + "<mask>"
|
| 7 |
+
input_ids = tokenizer([TXT], return_tensors="pt")["input_ids"]
|
| 8 |
+
logits = model(input_ids).logits
|
| 9 |
+
|
| 10 |
+
masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
|
| 11 |
+
probs = logits[0, masked_index].softmax(dim=0)
|
| 12 |
+
values, predictions = probs.topk(5)
|
| 13 |
+
list_of_words = tokenizer.decode(predictions).split()
|
| 14 |
+
return ', '.join([str(elem) for elem in list_of_words])
|
| 15 |
+
intr = gr.Interface(predict, ["text","number"], "text")
|
| 16 |
+
intr.launch(inline = False)
|