|
|
import numpy as np |
|
|
import torch |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
import torch.nn as nn |
|
|
import gradio as gr |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Character-to-index mapping produced at training time.
# NOTE(review): the "modydick" spelling is a typo baked into the artifact
# filenames on disk — do not "correct" it here or loading will fail.
char2int = torch.load('finish_story_modydick_char2int.pth')


# Inverse lookup: index -> character (presumably a numpy array of
# single-character strings — TODO confirm against the training script).
char_array = torch.load('finish_story_modydick_char_array.pth')
|
|
|
|
|
|
|
|
|
|
|
class RNN(nn.Module):
    """Character-level LSTM language model.

    Consumes one character index per call and emits logits over the
    vocabulary, threading the LSTM (hidden, cell) state between calls.
    """

    def __init__(self, vocab_size, embed_dim, rnn_hidden_size):
        super().__init__()
        self.rnn_hidden_size = rnn_hidden_size
        # Attribute names are load-bearing: the saved state_dict keys
        # (embedding.*, rnn.*, fc.*) are derived from them.
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.LSTM(embed_dim, rnn_hidden_size, batch_first=True)
        self.fc = nn.Linear(rnn_hidden_size, vocab_size)

    def forward(self, x, hidden, cell):
        """Run one time step.

        x is a (batch,) LongTensor of character indices. Returns a tuple
        (logits of shape (batch, vocab_size), hidden, cell).
        """
        embedded = self.embedding(x)
        # Insert a seq_len=1 axis so the batch-first LSTM sees
        # (batch, 1, embed_dim).
        step_input = embedded.unsqueeze(1)
        step_output, (hidden, cell) = self.rnn(step_input, (hidden, cell))
        logits = self.fc(step_output)
        # Collapse the singleton time axis: (batch, 1, vocab) -> (batch, vocab).
        logits = logits.reshape(logits.size(0), -1)
        return logits, hidden, cell

    def init_hidden(self, batch_size):
        """Return zero-initialized (hidden, cell) states for a batch."""
        state_shape = (1, batch_size, self.rnn_hidden_size)
        return torch.zeros(*state_shape), torch.zeros(*state_shape)
|
|
|
|
|
|
|
|
# Hyperparameters must match the configuration used at training time,
# otherwise load_state_dict below fails on shape mismatch.
vocab_size=len(char2int)


embed_dim = 256


rnn_hidden_size=512


# Seed the (discarded) random weight initialization for reproducibility
# before the checkpoint overwrites it.
torch.manual_seed(1)


model = RNN(vocab_size,embed_dim,rnn_hidden_size)


# Load trained weights onto CPU; map_location handles checkpoints that
# were saved from a GPU run.
# NOTE(review): weights file says "mobydick" while the vocab artifacts
# above say "modydick" — both match what exists on disk; leave as-is.
model.load_state_dict(torch.load('finish_story_mobydick_weights.pth',map_location=torch.device('cpu')))


# Inference mode: disables dropout/batch-norm training behavior.
model.eval()
|
|
|
|
|
|
|
|
# Categorical turns the model's logits into a sampling distribution
# over the next character.
from torch.distributions.categorical import Categorical

# Re-seed so the sampled output is reproducible across process restarts
# (only the first generation after startup is pinned; later calls
# continue the RNG stream).
torch.manual_seed(1)
|
|
|
|
|
def sample(input_str, len_generated_text=500, scale_factor=1.0):
    """Generate text that continues *input_str* using the trained model.

    On each iteration the next character is sampled from a Categorical
    distribution over the (scaled) logits the model produces for the
    text generated so far.

    Parameters
    ----------
    input_str : str
        Seed text. Characters the model never saw during training are
        dropped; if nothing usable remains (or the string is empty),
        the seed falls back to 'The whale'.
    len_generated_text : int
        Number of characters to generate. Cast to int defensively —
        UI sliders may deliver floats.
    scale_factor : float
        Temperature-like scaling of the logits: higher values make
        sampling more deterministic, lower values more random.

    Returns
    -------
    str
        The seed text followed by the generated continuation.
    """
    # Drop characters outside the training vocabulary; a raw
    # char2int[s] lookup would raise KeyError on them.
    seed = ''.join(ch for ch in input_str if ch in char2int)
    if not seed:
        seed = 'The whale'

    encoded_input = torch.tensor([char2int[ch] for ch in seed])
    encoded_input = torch.reshape(encoded_input, (1, -1))

    generated_str = seed

    # Pure inference: no_grad avoids building an autograd graph on
    # every step, which would otherwise grow memory for no benefit.
    with torch.no_grad():
        hidden, cell = model.init_hidden(1)

        # Warm up the recurrent state on every seed character except the
        # last; the last one primes the generation loop below.
        for c in range(encoded_input.size(1) - 1):
            _, hidden, cell = model(encoded_input[:, c].view(1), hidden, cell)

        last_char = encoded_input[:, -1]

        for _ in range(int(len_generated_text)):
            logits, hidden, cell = model(last_char.view(1), hidden, cell)
            logits = torch.squeeze(logits, 0)
            scaled_logits = logits * scale_factor

            m = Categorical(logits=scaled_logits)
            last_char = m.sample()
            generated_str += str(char_array[last_char])

    return generated_str
|
|
|
|
|
|
|
|
|
|
|
title = "Moby-Dick story generator"

# User-facing help text shown above the interface (typo "beginnning" fixed).
description = "Enter the beginning of a story. Your entry can be a word, a few words, a sentence, or more! This tool will attempt to pick up where you left off. Warning: the model was trained to imitate Moby-Dick by Herman Melville.\n Adjust the sliders to select the number of characters you would like in your output story, and the scaling factor.\n A higher scaling factor gives a more predictable (albeit boring) story. A lower scaling factor gives a less predictable (and certainly interesting) story. A scaling factor of 2 produces reasonable results."

gr.Interface(
    fn=sample,
    inputs=[
        gr.Textbox(label='Type the beginning of your story here.'),
        gr.Slider(0, 5000, value=1000, step=10, label='How many characters in your output?'),
        gr.Slider(0, 5, value=2, label='Which scaling factor?'),
    ],
    # "text" renders the generated story as plain selectable text;
    # "label" is a classification-output component and a poor fit for
    # a multi-paragraph string.
    outputs="text",
    title=title,
    description=description,
).launch()
|
|
|