# moby-dick / app.py -- Hugging Face Space upload (user: etweedy, commit 1ef46a0)
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import gradio as gr
# first we read in two items which are needed for our model
# char2int is character encoder - a dictionary whose keys are unique characters and values are corresponding integer codes
# char_array a character decoder - an array whose k^th entry is the character corresponding to code k
# NOTE(review): torch.load unpickles arbitrary objects -- only ever load these
# files from a trusted source (here they ship alongside this Space).
# NOTE: the filenames spell "modydick" (unlike "mobydick" for the weights file
# below); they must match the uploaded artifacts, so they are left as-is.
char2int = torch.load('finish_story_modydick_char2int.pth')
char_array = torch.load('finish_story_modydick_char_array.pth')
# Then we define our RNN class
class RNN(nn.Module):
    """Character-level LSTM language model: embedding -> LSTM -> linear head.

    The model is driven one character at a time: ``forward`` consumes a batch
    of single character codes together with the current LSTM state and returns
    vocabulary logits plus the updated state.
    """

    def __init__(self, vocab_size, embed_dim, rnn_hidden_size):
        super().__init__()
        # Learned lookup table mapping integer character codes to dense vectors.
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn_hidden_size = rnn_hidden_size
        # Single-layer LSTM over (batch, seq, feature) tensors.
        self.rnn = nn.LSTM(embed_dim, rnn_hidden_size, batch_first=True)
        # Projects each LSTM output back onto vocabulary logits.
        self.fc = nn.Linear(rnn_hidden_size, vocab_size)

    def forward(self, x, hidden, cell):
        """Advance one step; ``x`` is a 1-D batch of character codes."""
        embedded = self.embedding(x).unsqueeze(1)  # -> (batch, 1, embed_dim)
        rnn_out, (hidden, cell) = self.rnn(embedded, (hidden, cell))
        logits = self.fc(rnn_out).reshape(rnn_out.size(0), -1)
        return logits, hidden, cell

    def init_hidden(self, batch_size):
        """Return zero-initialized (hidden, cell) states for a fresh sequence."""
        state_shape = (1, batch_size, self.rnn_hidden_size)
        return torch.zeros(state_shape), torch.zeros(state_shape)
#Initialize our RNN instance
# Hyperparameters must match those used at training time, or the saved
# state_dict below will fail to load (shape mismatch).
vocab_size=len(char2int)
embed_dim = 256
rnn_hidden_size=512
# Seed so the (soon-to-be-overwritten) random init is reproducible.
torch.manual_seed(1)
model = RNN(vocab_size,embed_dim,rnn_hidden_size)
# Load the weights from the trained model into our blank instance
# map_location=CPU so the Space runs without a GPU.
model.load_state_dict(torch.load('finish_story_mobydick_weights.pth',map_location=torch.device('cpu')))
# Switch to inference mode (disables any train-time layer behavior).
model.eval()
# Our sampling function which will be used to generate text.
from torch.distributions.categorical import Categorical
# Fixed seed makes generated stories reproducible across app restarts.
torch.manual_seed(1)
def sample(input_str, len_generated_text=500, scale_factor=1.0):
    """Generate text continuing ``input_str``.

    On each iteration the next character is drawn from a Categorical
    distribution over the model's (scaled) logit outputs for the text
    generated so far.

    Args:
        input_str: seed text; falls back to 'The whale' when empty.
        len_generated_text: number of characters to append.
        scale_factor: logit multiplier -- higher values make sampling more
            greedy/predictable, lower values more random.

    Returns:
        The seed text followed by the generated continuation.
    """
    # If user input is empty, use a default input string.
    if len(input_str) == 0:
        input_str = 'The whale'
    # Encode the seed, skipping characters absent from the training
    # vocabulary -- char2int[s] would otherwise raise KeyError on, e.g.,
    # emoji or unusual punctuation typed by the user.
    codes = [char2int[s] for s in input_str if s in char2int]
    if not codes:
        # Nothing encodable in the seed: prime the model with the default
        # seed instead (the user's original text is still echoed back).
        codes = [char2int[s] for s in 'The whale']
    encoded_input = torch.reshape(torch.tensor(codes), (1, -1))
    generated_str = input_str
    # Initialize hidden/cell as zeros, then warm them up on all but the last
    # seed character. They are only read (not trained) from here on.
    hidden, cell = model.init_hidden(1)
    for c in range(encoded_input.size(1) - 1):
        _, hidden, cell = model(encoded_input[:, c].view(1), hidden, cell)
    # Grab the last character of the (encoded) input.
    last_char = encoded_input[:, -1]
    # Autoregressive loop producing new characters; int() guards against the
    # Gradio slider delivering a float.
    for _ in range(int(len_generated_text)):
        # Compute output logits for this step and scale by scale_factor.
        logits, hidden, cell = model(last_char.view(1), hidden, cell)
        logits = torch.squeeze(logits, 0)
        scaled_logits = logits * scale_factor
        # Sample the next character and append its decoded form.
        m = Categorical(logits=scaled_logits)
        last_char = m.sample()
        generated_str += str(char_array[last_char])
    return generated_str
# Gradio interface
title = "Moby-Dick story generator"
description = "Enter the beginning of a story. Your entry can be a word, a few words, a sentence, or more! This tool will attempt to pick up where you left off. Warning: the model was trained to imitate Moby-Dick by Herman Melville.\n Adjust the sliders to select the number of characters you would like in your output story, and the scaling factor.\n A higher scaling factor gives a more predictable (albeit boring) story. A lower scaling factor gives a less predictable (and certainly interesting) story. A scaling factor of 2 produces reasonable results."
gr.Interface(
    fn=sample,
    inputs=[
        gr.Textbox(label='Type the beginning of your story here.'),
        gr.Slider(0, 5000, value=1000, step=10, label='How many characters in your output?'),
        gr.Slider(0, 5, value=2, label='Which scaling factor?'),
    ],
    # A Textbox is the right component for long generated text; the previous
    # "label" output used gr.Label, which is meant for classification
    # confidences and renders a long story poorly.
    outputs=gr.Textbox(label='Your story'),
    title=title,
    description=description,
).launch()