jiehou commited on
Commit
2f75b84
·
1 Parent(s): d8ce8e8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -0
app.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from keras.models import load_model
2
+ from keras.preprocessing.sequence import pad_sequences
3
+ import sys
4
+ # load the model from disk
5
+ model = load_model("model_pretrain.h5")
6
+
7
+ char_to_int = {' ': 0,
8
+ 'a': 1,
9
+ 'b': 2,
10
+ 'c': 3,
11
+ 'd': 4,
12
+ 'e': 5,
13
+ 'f': 6,
14
+ 'g': 7,
15
+ 'h': 8,
16
+ 'i': 9,
17
+ 'j': 10,
18
+ 'k': 11,
19
+ 'l': 12,
20
+ 'm': 13,
21
+ 'n': 14,
22
+ 'o': 15,
23
+ 'p': 16,
24
+ 'q': 17,
25
+ 'r': 18,
26
+ 's': 19,
27
+ 't': 20,
28
+ 'u': 21,
29
+ 'v': 22,
30
+ 'w': 23,
31
+ 'x': 24,
32
+ 'y': 25,
33
+ 'z': 26}
34
+
35
+ int_to_char = dict((i, c) for c, i in char_to_int.items())
36
+
37
+ def get_sequence_from_encoding(sequence_encoded, ind_to_word):
38
+ in_text = ''
39
+ for index in sequence_encoded:
40
+ if index in ind_to_word:
41
+ word = ind_to_word[index]
42
+ else:
43
+ word = ''
44
+ in_text += '' + word
45
+ return in_text
46
+
47
+ def get_encoding_from_sequence(sequence, word_to_ind):
48
+ out_encode = []
49
+ for word in list(sequence.lower()):
50
+ if word in word_to_ind:
51
+ index = word_to_ind[word]
52
+ else:
53
+ index = 0
54
+ out_encode.append(index)
55
+ return out_encode
56
+
57
+
58
+
59
+ def generate_text(start_text,text_length=100):
60
+ encoding = get_encoding_from_sequence(start_text.lower(),char_to_int)
61
+ decoding = get_sequence_from_encoding(encoding,int_to_char)
62
+ print("Input sequence: ", start_text)
63
+ print("Start generating the paragraph: \n")
64
+
65
+ line_print = ''
66
+ new_sequence = start_text
67
+ sys.stdout.write(start_text)
68
+ for repeat in range(text_length):
69
+ test_data = np.reshape(encoding, (1, len(encoding)))
70
+ maxlen = 20 # specify how long the sequences should be. This cuts sequences that exceed that number.
71
+ test_data_pad = pad_sequences(test_data, padding='pre', maxlen=maxlen)
72
+
73
+ prediction = model.predict(test_data_pad, verbose=0)
74
+ index = np.argmax(prediction)
75
+ result = int_to_char[index]
76
+ seq_in = [int_to_char[value] for value in encoding]
77
+ if len(line_print) > 70 and result == ' ':
78
+ sys.stdout.write("\n")
79
+ line_print = ''
80
+
81
+ sys.stdout.write(result)
82
+ line_print = line_print + result
83
+ new_sequence = new_sequence + result
84
+ encoding.append(index)
85
+ encoding = encoding[1:len(encoding)]
86
+
87
+
88
+
89
+ ### configure inputs/outputs
90
+
91
+
92
+ set_input = gr.Textbox.(label = 'Starting words')
93
+ set_output = gr.Textbox.(label = 'Generated sentences')
94
+
95
+ ### configure gradio, detailed can be found at https://www.gradio.app/docs/#i_slider
96
+ interface = gr.Interface(fn=predict_price,
97
+ inputs=set_input,
98
+ outputs=set_output,
99
+ title="CSCI4750/5750 Demo 8: Web Application for Text Generation using RNN",
100
+ description= "Click examples below for a quick demo",
101
+ theme = 'huggingface',
102
+ layout = 'vertical'
103
+ )
104
+ interface.launch(debug=True)