# jiehou's picture
# Update app.py
# 60048fe
from keras.models import load_model
from keras.preprocessing.sequence import pad_sequences
import gradio as gr
import numpy as np
import sys
# Load the Keras model saved in "model_pretrain.h5" (path relative to the
# working directory). This runs once, at module import; generate_text() below
# calls model.predict() on this object for every generated character.
model = load_model("model_pretrain.h5")
# Character vocabulary: the space is index 0 and 'a'..'z' are indices 1..26,
# i.e. each character's index is its position in the string below.
char_to_int = {
    symbol: position
    for position, symbol in enumerate(' abcdefghijklmnopqrstuvwxyz')
}
# Reverse lookup (index -> character), used to decode predicted indices.
int_to_char = {position: symbol for symbol, position in char_to_int.items()}
def get_sequence_from_encoding(sequence_encoded, ind_to_word):
    """Decode a sequence of integer indices back into a string.

    Args:
        sequence_encoded: Iterable of indices to decode.
        ind_to_word: Mapping from index to the string it stands for.

    Returns:
        The mapped strings concatenated in order. Indices that are not
        present in ``ind_to_word`` contribute nothing to the result.
    """
    # Build the result in a single pass with str.join rather than growing a
    # string with repeated ``+=``; unknown indices are simply skipped.
    return ''.join(
        ind_to_word[index]
        for index in sequence_encoded
        if index in ind_to_word
    )
def get_encoding_from_sequence(sequence, word_to_ind):
    """Encode a string as a list of integer indices, one per character.

    The input is lower-cased before lookup.

    Args:
        sequence: Text to encode.
        word_to_ind: Mapping from a single character to its index.

    Returns:
        A list with one index per character of ``sequence``. Characters that
        are not present in ``word_to_ind`` are encoded as 0.
    """
    # Iterate the string directly (no intermediate list) and fall back to 0
    # for characters outside the vocabulary.
    return [
        word_to_ind[char] if char in word_to_ind else 0
        for char in sequence.lower()
    ]
def generate_text(start_text, text_length=100):
    """Generate text one character at a time, starting from a seed string.

    The seed is encoded with ``char_to_int``. At every step the current
    context is padded/truncated to the model window with ``pad_sequences``,
    ``model.predict`` is called, and the highest-scoring index (argmax) is
    decoded with ``int_to_char`` and appended. The text is also echoed to
    stdout, with a line break inserted before a space once the current
    output line is longer than 70 characters.

    Args:
        start_text: Seed text. ``None`` is treated as an empty string.
        text_length: Number of characters to generate. It is converted with
            ``int()`` so that a float coming from a UI slider is accepted.

    Returns:
        ``start_text`` followed by the generated characters.
    """
    maxlen = 20  # window length passed to pad_sequences for every prediction
    line_width = 70  # soft wrap width for the stdout echo

    start_text = start_text or ''
    # range() rejects floats, and slider widgets may deliver one.
    text_length = int(text_length)

    # get_encoding_from_sequence lower-cases its input itself.
    encoding = get_encoding_from_sequence(start_text, char_to_int)
    print("Input sequence: ", start_text)
    print("Start generating the paragraph: \n")
    sys.stdout.write(start_text)

    generated = []
    line_print = ''
    for _ in range(text_length):
        test_data = np.reshape(encoding, (1, len(encoding)))
        test_data_pad = pad_sequences(test_data, padding='pre', maxlen=maxlen)
        prediction = model.predict(test_data_pad, verbose=0)
        # Plain int so the context list keeps holding Python integers.
        index = int(np.argmax(prediction))
        result = int_to_char[index]

        # Wrap the console echo at a space once the line is long enough.
        if len(line_print) > line_width and result == ' ':
            sys.stdout.write("\n")
            line_print = ''
        sys.stdout.write(result)
        line_print += result
        generated.append(result)

        # Keep only the most recent `maxlen` indices. Dropping exactly one
        # element per step (the previous behaviour) froze the context at the
        # seed's length, so a short or empty seed never gained any context;
        # for seeds of `maxlen` or more characters the model input is the
        # same as before because pad_sequences truncates to `maxlen` anyway.
        encoding.append(index)
        encoding = encoding[-maxlen:]

    return start_text + ''.join(generated)
### Configure inputs/outputs
# Text box for the seed text; first input, so it is passed to generate_text
# as start_text.
set_input = gr.Textbox(label = 'Starting words')
# Slider with arguments (1, 1000) and step 5; second input, so it is passed to
# generate_text as text_length.
set_len = gr.Slider(1, 1000, step=5, label = 'Text Length')
# Text box that shows the string returned by generate_text.
set_output = gr.Textbox(label = 'Generated sentences')
### Configure Gradio; details can be found at https://www.gradio.app/docs/#i_slider
# NOTE(review): the description mentions examples, but no `examples=` argument
# is passed to gr.Interface here -- confirm whether examples were intended.
interface = gr.Interface(fn=generate_text,
inputs=[set_input,set_len],
outputs=set_output,
title="CSCI4750/5750 Demo 8: Web Application for Text Generation using RNN",
description= "Click examples below for a quick demo",
theme = 'huggingface'
)
# Launch the interface when this module is executed (debug=True is forwarded
# to launch()).
interface.launch(debug=True)