Spaces:
Runtime error
Runtime error
# Third-party dependencies: Gradio for the UI, NLTK for an English word list,
# and Hugging Face Transformers for the BART model.
import gradio as gr
import nltk
from nltk.corpus import words
from transformers import BartForConditionalGeneration, BartTokenizer

# Make sure the NLTK "words" corpus is available locally before it is read.
nltk.download('words')

# One pretrained checkpoint supplies both the tokenizer and the model weights.
_BART_CHECKPOINT = "facebook/bart-base"
tokenizer = BartTokenizer.from_pretrained(_BART_CHECKPOINT)
model = BartForConditionalGeneration.from_pretrained(_BART_CHECKPOINT)
# Built once at import time: `words.words()` returns a freshly built list on
# every call and `in` scans it linearly, so testing membership against it per
# keystroke (and per candidate word) is needlessly expensive.
_ENGLISH_WORDS = frozenset(words.words())
# Punctuation accepted both as a "finished" last input token and as a suggestion.
_ALLOWED_PUNCTUATION = frozenset({"?", ".", ","})


def _is_known_token(token):
    """Return True if *token* is an NLTK English word or an allowed punctuation mark."""
    return token in _ENGLISH_WORDS or token in _ALLOWED_PUNCTUATION


def predict(text, number_of_predictions):
    """Suggest likely next words for *text* using BART's mask-filling head.

    Appends ``<mask>`` to the input and asks the model for the top
    ``number_of_predictions`` candidate tokens at that position. Only
    candidates that are NLTK English words or one of ``? . ,`` are kept.

    Args:
        text: The user's input so far (may be empty).
        number_of_predictions: How many candidate tokens to request from the
            model. Comes from a Gradio Number box, so it may be a float or None.
            It is clamped to the vocabulary size.

    Returns:
        A ``", "``-joined string of suggestions (possibly empty), or ``None``
        in three cases: the input has no tokens, its last whitespace-separated
        token is not a known word or allowed punctuation, or
        ``number_of_predictions`` is not convertible to an int.
    """
    import torch  # local import keeps this block self-contained

    # Live mode calls this on every edit, including an empty box; indexing
    # `text.split()[-1]` directly raised IndexError there.
    tokens = text.split() if text else []
    if not tokens or not _is_known_token(tokens[-1]):
        return None

    try:
        k = int(number_of_predictions)
    except (TypeError, ValueError, OverflowError):  # None / NaN / inf from the widget
        return None

    input_ids = tokenizer([text + "<mask>"], return_tensors="pt")["input_ids"]
    with torch.no_grad():  # inference only: don't build an autograd graph
        logits = model(input_ids).logits

    # Take the LAST mask position: that is the one appended above, even if the
    # user's own text happens to contain "<mask>".
    masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero()[-1].item()
    probs = logits[0, masked_index].softmax(dim=0)

    # topk raises for k < 0 or k > number of classes, so clamp it.
    k = max(0, min(k, probs.numel()))
    _, predictions = probs.topk(k)

    # Decode each token on its own. Decoding the whole id tensor at once glues
    # tokens that lack a leading-space marker (e.g. "," or ".") onto the
    # previous word, so they were split wrongly and then filtered out.
    candidates = (tokenizer.decode(token_id).strip() for token_id in predictions.tolist())
    return ", ".join(word for word in candidates if _is_known_token(word))
# Wire `predict` into a Gradio UI with two inputs and one output:
# - a text box plus a numeric field (default 5) for how many candidate tokens
#   to request;
# - the suggestions, shown as text.
# live=True asks Gradio to re-run `predict` as the inputs change, without a
# submit button.
intr = gr.Interface(
    fn=predict,
    inputs=["text", gr.Number(value=5)],
    outputs="text",
    live=True,
    title="Next word predictor",
    description=(
        "Press or delete an extra space so the model will compute for current input."
    ),
)
# inline=False: don't embed the UI inline (only matters when run from a notebook).
intr.launch(inline=False)