"""Gradio demo that suggests next words for a prompt using BART mask filling."""

from typing import Optional

import gradio as gr
import nltk
from nltk.corpus import words
from transformers import BartForConditionalGeneration, BartTokenizer

nltk.download('words')

# Build the lookup sets once at import time. Calling words.words() inside
# predict() re-reads the corpus and does a linear scan for every membership
# test, which is far too slow for a live-updating interface.
ENGLISH_WORDS = frozenset(words.words())
ALLOWED_PUNCTUATION = frozenset(["?", ".", ","])

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")


def predict(text: str, number_of_predictions: float) -> Optional[str]:
    """Suggest candidate next words for ``text``.

    A mask token is appended to ``text`` and the model's highest-probability
    tokens at that position are decoded. Candidates are kept only when they
    appear in the NLTK word list or are one of ``?``, ``.``, ``,``.

    Args:
        text: The prompt typed so far.
        number_of_predictions: How many top-ranked tokens to consider before
            filtering. Converted with ``int()``.

    Returns:
        The surviving candidates joined by ``", "`` (possibly an empty
        string), or ``None`` when no prediction is attempted: the prompt is
        empty, its last whitespace-separated token is neither a known word
        nor allowed punctuation, or ``number_of_predictions`` is ``None``.
    """
    tokens = text.split() if text else []
    if not tokens:
        # Nothing typed yet (or whitespace only): there is no last word to check.
        return None

    last_word = tokens[-1]
    print(last_word)

    # Only predict once the user has finished typing a complete word.
    if last_word not in ENGLISH_WORDS and last_word not in ALLOWED_PUNCTUATION:
        return None

    if number_of_predictions is None:
        # The number field was cleared in the UI.
        return None
    top_k = int(number_of_predictions)
    if top_k < 1:
        return ""

    # Append the tokenizer's mask token so the model has a position to fill.
    masked_text = text + tokenizer.mask_token
    input_ids = tokenizer([masked_text], return_tensors="pt")["input_ids"]
    logits = model(input_ids).logits

    # Use the last mask position: it is the one appended above, even if the
    # user's own text happens to contain the mask token as well.
    mask_positions = (input_ids[0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]
    masked_index = mask_positions[-1].item()

    probs = logits[0, masked_index].softmax(dim=0)
    # topk raises if asked for more entries than the vocabulary holds.
    top_k = min(top_k, probs.shape[0])
    _, predictions = probs.topk(top_k)

    candidates = tokenizer.decode(predictions).split()
    kept = [
        candidate
        for candidate in candidates
        if candidate in ENGLISH_WORDS or candidate in ALLOWED_PUNCTUATION
    ]
    return ', '.join(kept)


intr = gr.Interface(
    predict,
    ["text", gr.Number(value=5)],
    "text",
    live=True,
    title="Next word predictor",
    description="Press or delete an extra space so the model will compute for current input.",
)
intr.launch(inline=False)