import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import gradio as gr

# Character encoder/decoder produced at training time:
#   char2int   - dict mapping each unique character -> integer code
#   char_array - array whose k-th entry is the character for code k
char2int = torch.load('finish_story_modydick_char2int.pth')
char_array = torch.load('finish_story_modydick_char_array.pth')


class RNN(nn.Module):
    """Character-level language model: embedding -> single-layer LSTM -> linear head."""

    def __init__(self, vocab_size, embed_dim, rnn_hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn_hidden_size = rnn_hidden_size
        self.rnn = nn.LSTM(embed_dim, rnn_hidden_size, batch_first=True)
        self.fc = nn.Linear(rnn_hidden_size, vocab_size)

    def forward(self, x, hidden, cell):
        """Single step: x holds one character code per batch element.

        Returns (logits over vocab, new hidden state, new cell state).
        """
        out = self.embedding(x).unsqueeze(1)  # (batch, 1, embed_dim)
        out, (hidden, cell) = self.rnn(out, (hidden, cell))
        out = self.fc(out).reshape(out.size(0), -1)  # (batch, vocab_size)
        return out, hidden, cell

    def init_hidden(self, batch_size):
        """Return zero-initialized (hidden, cell) tensors for a fresh sequence."""
        hidden = torch.zeros(1, batch_size, self.rnn_hidden_size)
        cell = torch.zeros(1, batch_size, self.rnn_hidden_size)
        return hidden, cell


# Instantiate the model with the training-time hyperparameters, then load the
# trained weights into the blank instance (CPU-only deployment).
vocab_size = len(char2int)
embed_dim = 256
rnn_hidden_size = 512
torch.manual_seed(1)
model = RNN(vocab_size, embed_dim, rnn_hidden_size)
model.load_state_dict(
    torch.load('finish_story_mobydick_weights.pth', map_location=torch.device('cpu'))
)
model.eval()

# Sampling machinery used to generate text.
from torch.distributions.categorical import Categorical

torch.manual_seed(1)


def sample(input_str, len_generated_text=500, scale_factor=1.0):
    """Generate text continuing ``input_str``.

    On each iteration the next character is sampled from the categorical
    distribution over the (scaled) logits the model emits for the text
    generated so far.

    Args:
        input_str: Seed text; falls back to 'The whale' when empty.
        len_generated_text: Number of characters to append to the seed.
        scale_factor: Multiplier applied to the logits before sampling.
            Higher values sharpen the distribution (more predictable text);
            lower values flatten it (more surprising text).

    Returns:
        The seed string followed by the generated continuation.
    """
    if not input_str:
        input_str = 'The whale'

    # Encode the seed and shape it as a (1, seq_len) batch.
    encoded_input = torch.tensor([char2int[s] for s in input_str])
    encoded_input = torch.reshape(encoded_input, (1, -1))
    generated_str = input_str

    # Pure inference: disable autograd so no graph is built across the
    # (potentially thousands of) generation steps.
    with torch.no_grad():
        # Warm up the recurrent state on all but the last seed character.
        hidden, cell = model.init_hidden(1)
        for c in range(len(input_str) - 1):
            _, hidden, cell = model(encoded_input[:, c].view(1), hidden, cell)

        # The last seed character primes the autoregressive loop.
        last_char = encoded_input[:, -1]
        for i in range(len_generated_text):
            logits, hidden, cell = model(last_char.view(1), hidden, cell)
            logits = torch.squeeze(logits, 0)
            scaled_logits = logits * scale_factor
            # Sample the next character and append its decoded form.
            m = Categorical(logits=scaled_logits)
            last_char = m.sample()
            generated_str += str(char_array[last_char])

    return generated_str


# Gradio interface copy
title = "Moby-Dick story generator"
description = "Enter the beginning of a story. Your entry can be a word, a few words, a sentence, or more! This tool will attempt to pick up where you left off. Warning: the model was trained to imitate Moby-Dick by Herman Melville.\n Adjust the sliders to select the number of characters you would like in your output story, and the scaling factor.\n A higher scaling factor gives a more predictable (albeit boring) story. A lower scaling factor gives a less predictable (and certainly interesting) story. A scaling factor of 2 produces reasonable results."
# Build and launch the Gradio UI around the sampler.
gr.Interface(
    fn=sample,
    inputs=[
        gr.Textbox(label='Type the beginning of your story here.'),
        gr.Slider(0, 5000, value=1000, step=10, label='How many characters in your output?'),
        gr.Slider(0, 5, value=2, label='Which scaling factor?'),
    ],
    # "text" renders the generated story as plain text; the previous "label"
    # value selected Gradio's classification Label component, which is meant
    # for class/confidence outputs, not long free-form text.
    outputs="text",
    title=title,
    description=description,
).launch()