File size: 1,241 Bytes
b6e8b15
cfd3105
 
 
b6e8b15
 
 
44c0dde
cfd3105
 
 
b6e8b15
 
 
 
 
 
81ec61b
b6e8b15
cfd3105
 
 
94f8086
b6e8b15
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Third-party dependencies: the UI toolkit, the NLTK word-list reader, and the
# BART tokenizer/model classes.
import gradio as gr
import nltk
from nltk.corpus import words
from transformers import BartForConditionalGeneration, BartTokenizer

# Fetch the NLTK "words" corpus that the `words` reader serves.
nltk.download("words")

# The tokenizer and the model are loaded from the same pretrained checkpoint.
BART_CHECKPOINT = "facebook/bart-base"
tokenizer = BartTokenizer.from_pretrained(BART_CHECKPOINT)
model = BartForConditionalGeneration.from_pretrained(BART_CHECKPOINT)
# Punctuation marks accepted alongside dictionary words.
_PUNCTUATION = frozenset({"?", ".", ","})

# Lazily built set of dictionary words; see _vocabulary().
_vocabulary_cache = None


def _vocabulary():
  """Return the NLTK word list as a frozenset, building it on first use.

  Testing membership against the raw list is a linear scan, and this
  function is consulted once per candidate word, so the set is built a
  single time and reused for every later call.
  """
  global _vocabulary_cache
  if _vocabulary_cache is None:
    _vocabulary_cache = frozenset(words.words())
  return _vocabulary_cache


def _is_known_token(token):
  """Return True if *token* is accepted punctuation or a dictionary word."""
  return token in _PUNCTUATION or token in _vocabulary()


def predict(text, number_of_predictions):
  """Suggest next words for *text* by filling a trailing ``<mask>`` token.

  Args:
    text: The text typed so far. Its last whitespace-separated token must
      be a dictionary word or one of ``?``, ``.``, ``,``.
    number_of_predictions: How many top-scoring vocabulary entries to
      consider. Converted with ``int()``; may be ``None``.

  Returns:
    ``None`` when *text* is empty/blank or its last token is not
    recognised; an empty string when *number_of_predictions* is ``None``
    or less than 1; otherwise the accepted candidates joined by ``", "``
    (possibly fewer than requested, because candidates that are neither
    dictionary words nor accepted punctuation are dropped).
  """
  tokens = text.split() if text else []
  # Nothing typed yet (or only whitespace): there is no last word to check.
  if not tokens:
    return None
  if not _is_known_token(tokens[-1]):
    return None

  # The number field can be cleared by the user, which yields None.
  if number_of_predictions is None:
    return ""
  top_k = int(number_of_predictions)
  if top_k < 1:
    return ""

  # Function-scope import: only needed to disable gradient tracking below.
  import torch

  input_ids = tokenizer([text + "<mask>"], return_tensors="pt")["input_ids"]
  with torch.no_grad():  # inference only; no autograd graph required
    logits = model(input_ids).logits

  # The user's text may itself contain "<mask>"; the one appended above is
  # always the last occurrence, so pick that position.
  mask_positions = (input_ids[0] == tokenizer.mask_token_id).nonzero()
  masked_index = mask_positions[-1].item()

  probs = logits[0, masked_index].softmax(dim=0)
  # topk raises if k exceeds the number of entries, so cap it.
  top_k = min(top_k, probs.shape[0])
  _, predictions = probs.topk(top_k)

  # Decode each id on its own: decoding them as one sequence would glue a
  # token that has no leading space onto the previous candidate.
  candidates = (
    tokenizer.decode([token_id]).strip() for token_id in predictions.tolist()
  )
  # dict.fromkeys drops repeats (e.g. " the" and "the") but keeps rank order.
  suggestions = [c for c in dict.fromkeys(candidates) if _is_known_token(c)]
  return ", ".join(suggestions)
# Live demo wiring: a text box and a number field (default 5) are passed to
# `predict`, and its result is rendered as text.
intr = gr.Interface(
    fn=predict,
    inputs=["text", gr.Number(value=5)],
    outputs="text",
    live=True,
    title="Next word predictor",
    description="Press or delete an extra space so the model will compute for current input.",
)

# Start serving the interface with inline display turned off.
intr.launch(inline=False)