larrysim commited on
Commit
e728a36
·
verified ·
1 Parent(s): 6e015c4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +203 -0
app.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import tensorflow as tf
3
+ import numpy as np
4
+ import pickle
5
+ import time
6
+ import os
7
+ from tensorflow.keras.models import load_model
8
+ from tensorflow.keras.preprocessing.sequence import pad_sequences
9
+
10
+ # Set page config
11
+ st.set_page_config(
12
+ page_title="Next Word Prediction",
13
+ page_icon="🔮",
14
+ layout="wide",
15
+ initial_sidebar_state="expanded"
16
+ )
17
+
18
+ # Custom CSS
19
+ st.markdown("""
20
+ <style>
21
+ .main-header {
22
+ font-size: 3rem;
23
+ color: #1f77b4;
24
+ text-align: center;
25
+ margin-bottom: 2rem;
26
+ }
27
+ .prediction-box {
28
+ background-color: #f0f2f6;
29
+ padding: 20px;
30
+ border-radius: 10px;
31
+ border-left: 5px solid #1f77b4;
32
+ margin-top: 20px;
33
+ }
34
+ .stButton button {
35
+ width: 100%;
36
+ background-color: #1f77b4;
37
+ color: white;
38
+ }
39
+ </style>
40
+ """, unsafe_allow_html=True)
41
+
42
+ # Load model and tokenizer
43
+ @st.cache_resource
44
+ def load_components():
45
+ try:
46
+ # Check if model file exists
47
+ if not os.path.exists('model.h5'):
48
+ st.error("Model file (model.h5) not found!")
49
+ return None, None
50
+
51
+ model = load_model('model.h5')
52
+
53
+ # Try to load .pkl tokenizer file
54
+ if os.path.exists('tokenizer.pkl'):
55
+ with open('tokenizer.pkl', 'rb') as handle:
56
+ tokenizer = pickle.load(handle)
57
+ st.success("Successfully loaded tokenizer.pkl")
58
+ # Try alternative file names if needed
59
+ elif os.path.exists('tokenizer.pickle'):
60
+ with open('tokenizer.pickle', 'rb') as handle:
61
+ tokenizer = pickle.load(handle)
62
+ st.success("Successfully loaded tokenizer.pickle")
63
+ else:
64
+ st.error("Tokenizer file not found. Please ensure tokenizer.pkl exists.")
65
+ return model, None
66
+
67
+ return model, tokenizer
68
+
69
+ except Exception as e:
70
+ st.error(f"Error loading model: {e}")
71
+ return None, None
72
+
73
+ # Prediction function
74
+ def predict_next_words(text, num_words=3, temperature=1.0):
75
+ if not text.strip():
76
+ return "Please enter some text first"
77
+
78
+ try:
79
+ # Tokenize input text
80
+ sequence = tokenizer.texts_to_sequences([text])
81
+
82
+ if not sequence or not sequence[0]:
83
+ return "No recognizable words in input"
84
+
85
+ sequence = sequence[0]
86
+
87
+ # Predict next words
88
+ predictions = []
89
+ for _ in range(num_words):
90
+ # Pad sequence
91
+ padded_sequence = pad_sequences([sequence], maxlen=model.input_shape[1], padding='pre')
92
+
93
+ # Predict
94
+ predicted_probs = model.predict(padded_sequence, verbose=0)[0]
95
+
96
+ # Apply temperature
97
+ predicted_probs = np.log(predicted_probs) / temperature
98
+ exp_preds = np.exp(predicted_probs)
99
+ predicted_probs = exp_preds / np.sum(exp_preds)
100
+
101
+ # Sample from distribution
102
+ predicted_index = np.random.choice(len(predicted_probs), p=predicted_probs)
103
+
104
+ # Convert index to word
105
+ predicted_word = ""
106
+ for word, index in tokenizer.word_index.items():
107
+ if index == predicted_index:
108
+ predicted_word = word
109
+ break
110
+
111
+ predictions.append(predicted_word)
112
+ sequence.append(predicted_index)
113
+
114
+ return " ".join(predictions)
115
+
116
+ except Exception as e:
117
+ return f"Prediction error: {str(e)}"
118
+
119
+ # Main app
120
+ def main():
121
+ st.markdown('<h1 class="main-header">🔮 Next Word Prediction</h1>', unsafe_allow_html=True)
122
+
123
+ # Load model
124
+ model, tokenizer = load_components()
125
+
126
+ if model is None:
127
+ st.error("Failed to load model. Please check if model.h5 is in the correct directory.")
128
+ return
129
+
130
+ if tokenizer is None:
131
+ st.error("Failed to load tokenizer. Please check if tokenizer.pkl is in the correct directory.")
132
+ return
133
+
134
+ # Layout
135
+ col1, col2 = st.columns([2, 1])
136
+
137
+ with col1:
138
+ input_text = st.text_area(
139
+ "Input Text",
140
+ "The weather today is",
141
+ height=150,
142
+ help="Enter some text to start the prediction"
143
+ )
144
+
145
+ # Prediction parameters
146
+ col_a, col_b = st.columns(2)
147
+ with col_a:
148
+ num_words = st.slider(
149
+ "Words to predict",
150
+ min_value=1,
151
+ max_value=10,
152
+ value=3,
153
+ help="Number of words to generate"
154
+ )
155
+ with col_b:
156
+ temperature = st.slider(
157
+ "Temperature",
158
+ min_value=0.1,
159
+ max_value=2.0,
160
+ value=1.0,
161
+ step=0.1,
162
+ help="Higher values = more creative, Lower values = more predictable"
163
+ )
164
+
165
+ # Predict button
166
+ if st.button("Predict Next Words", type="primary"):
167
+ with st.spinner("Generating prediction..."):
168
+ time.sleep(0.5) # Simulate processing
169
+ prediction = predict_next_words(input_text, num_words, temperature)
170
+
171
+ st.markdown('<div class="prediction-box">', unsafe_allow_html=True)
172
+ st.subheader("Prediction Result")
173
+ st.success(f"**{input_text} {prediction}**")
174
+ st.markdown('</div>', unsafe_allow_html=True)
175
+
176
+ with col2:
177
+ st.subheader("Examples to try")
178
+ examples = [
179
+ "I want to eat",
180
+ "Machine learning is",
181
+ "The future of AI",
182
+ "In the beginning",
183
+ "She went to the",
184
+ "The best way to",
185
+ "Artificial intelligence will"
186
+ ]
187
+
188
+ for example in examples:
189
+ if st.button(example, key=example):
190
+ st.session_state.input_text = example
191
+
192
+ st.markdown("---")
193
+ st.info("💡 **Tip**: Adjust the temperature slider to control the creativity of predictions.")
194
+
195
+ # Model info
196
+ with st.expander("Model Information"):
197
+ st.write(f"**Model Architecture**: {model.name}")
198
+ st.write(f"**Input Shape**: {model.input_shape}")
199
+ st.write(f"**Output Shape**: {model.output_shape}")
200
+ st.write(f"**Vocabulary Size**: {len(tokenizer.word_index)}")
201
+
202
+ if __name__ == "__main__":
203
+ main()