# BartTextGenV1 / app.py
# Author: krushna
# Last commit message: Added title and shortened description
# Commit: 94f8086
import gradio as gr
import nltk
from nltk.corpus import words
nltk.download('words')
from transformers import BartTokenizer, BartForConditionalGeneration
# Load the BART base checkpoint once at start-up; the tokenizer and the
# conditional-generation model are module-level so every predict() call
# reuses them instead of reloading.
_CHECKPOINT = "facebook/bart-base"
tokenizer = BartTokenizer.from_pretrained(_CHECKPOINT)
model = BartForConditionalGeneration.from_pretrained(_CHECKPOINT)
# Tokens accepted as a valid last/predicted token even though they are not
# dictionary words.
_PUNCTUATION = frozenset({"?", ".", ","})

# Set view of the NLTK English word list, built on first use by _english_words().
_ENGLISH_WORDS = None


def _english_words():
    """Return the NLTK ``words`` corpus as a set, building it once on first use.

    ``words.words()`` returns a list, so testing membership against it directly
    is a linear scan of the whole corpus; a set makes each lookup O(1).
    """
    global _ENGLISH_WORDS
    if _ENGLISH_WORDS is None:
        _ENGLISH_WORDS = set(words.words())
    return _ENGLISH_WORDS


def _is_known_token(token):
    """Return True if *token* is allowed punctuation or an English word (case-insensitive fallback)."""
    if token in _PUNCTUATION:
        return True
    vocab = _english_words()
    # Fall back to the lowercase form so sentence-initial words like "The" pass.
    return token in vocab or token.lower() in vocab


def predict(text, number_of_predictions):
    """Suggest likely next words for *text* using BART mask filling.

    Appends a ``<mask>`` token to the input and runs the model once. It then
    takes the ``number_of_predictions`` highest-probability tokens at the mask
    position and keeps only those that are English words or ``?``/``.``/``,``.

    Args:
        text: The user's input so far.
        number_of_predictions: How many top-scoring tokens to consider. Gradio's
            ``gr.Number`` may pass a float, which is truncated with ``int``.
            ``None`` is treated as 0. The value is clamped to
            ``[0, vocabulary size]``, and 0 yields an empty string.

    Returns:
        A comma-separated string of suggestions. Returns ``None`` when the
        input is empty/whitespace, or when its last whitespace-separated token
        is not a known word or punctuation mark (e.g. the user is mid-word).
    """
    import torch  # local import: only needed for no_grad; transformers already requires it

    tokens = text.split() if text else []
    if not tokens or not _is_known_token(tokens[-1]):
        return None

    input_ids = tokenizer([text + "<mask>"], return_tensors="pt")["input_ids"]
    with torch.no_grad():  # pure inference: don't build an autograd graph
        logits = model(input_ids).logits

    # Take the LAST mask position, i.e. the one appended above, so a literal
    # "<mask>" typed by the user cannot make .item() fail on multiple hits.
    masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero()[-1].item()
    probs = logits[0, masked_index].softmax(dim=0)

    # topk raises on negative k or k > number of classes, so clamp first.
    k = min(max(int(number_of_predictions or 0), 0), probs.numel())
    _, predictions = probs.topk(k)

    # Decode each id on its own: decoding them together glues sub-word pieces
    # (tokens without a leading space) onto the neighbouring prediction.
    candidates = (tokenizer.decode([token_id]).strip() for token_id in predictions.tolist())
    return ", ".join(c for c in candidates if c and _is_known_token(c))
# Web UI: a free-text box plus a numeric field for how many candidate tokens
# to score. live=True re-runs predict() whenever an input changes.
_interface_inputs = ["text", gr.Number(value=5)]
intr = gr.Interface(
    predict,
    _interface_inputs,
    "text",
    live=True,
    title="Next word predictor",
    description="Press or delete an extra space so the model will compute for current input.",
)
intr.launch(inline=False)