larrysim committed · Commit b86722f · verified · 1 Parent(s): e077248

Update app.py

Files changed (1):
  1. app.py +21 -158
app.py CHANGED
@@ -1,165 +1,28 @@
  import streamlit as st
- import tensorflow as tf
- from tensorflow.keras.models import load_model
- from tensorflow.keras.preprocessing.sequence import pad_sequences
- import numpy as np
- import pickle
- import re
- import os
-
- # Set Streamlit configuration (instead of using .streamlit/config.toml)
- st.set_page_config(
-     page_title="Next Word Predictor",
-     page_icon="🔮",
-     layout="centered"
- )
-
- # Set other Streamlit configurations
- st.set_option('server.headless', True)
- st.set_option('server.port', 8501)
- st.set_option('server.enableCORS', False)
- st.set_option('server.enableXsrfProtection', False)
-
- # Custom CSS for styling
- st.markdown("""
- <style>
- .main {
-     background-color: #f5f5f5;
- }
- .stTextInput>div>div>input {
-     background-color: #ffffff;
-     color: #000000;
- }
- .prediction-box {
-     background-color: #e6f7ff;
-     padding: 15px;
-     border-radius: 10px;
-     border-left: 5px solid #1890ff;
-     margin-top: 20px;
- }
- </style>
- """, unsafe_allow_html=True)
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import torch

+ # Load model & tokenizer
  @st.cache_resource
- def load_models():
-     """Load the model and tokenizer with caching"""
-     try:
-         # Check if files exist
-         if not os.path.exists('nextword_lstm_model.h5'):
-             st.error("Model file not found!")
-             return None, None
-
-         if not os.path.exists('tokenizer.pkl'):
-             st.error("Tokenizer file not found!")
-             return None, None
-
-         # Load model with custom objects if needed
-         model = load_model('nextword_lstm_model.h5', compile=False)
-         model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
-
-         # Load tokenizer
-         with open('tokenizer.pkl', 'rb') as f:
-             tokenizer = pickle.load(f)
-
-         st.success("Model and tokenizer loaded successfully!")
-         return model, tokenizer
-
-     except Exception as e:
-         st.error(f"Error loading model: {str(e)}")
-         return None, None
+ def load_model():
+     tokenizer = AutoTokenizer.from_pretrained(".")
+     model = AutoModelForCausalLM.from_pretrained(".")
+     return tokenizer, model
+
+ tokenizer, model = load_model()

- def predict_next_word(model, tokenizer, seed_text, max_seq_len):
-     """Predict the next word given a seed text"""
-     try:
-         # Clean and preprocess the input text
-         seed_text = re.sub(r'[^\w\s]', '', seed_text.lower()).strip()
-
-         if not seed_text:
-             return "Please enter some text"
-
-         # Convert text to sequence
-         token_list = tokenizer.texts_to_sequences([seed_text])
-
-         if not token_list or not token_list[0]:
-             return "Please enter more meaningful text"
-
-         token_list = token_list[0]
-
-         # Pad sequences
-         token_list = pad_sequences([token_list], maxlen=max_seq_len-1, padding='pre')
-
-         # Make prediction
-         predicted = model.predict(token_list, verbose=0)
-         predicted_word_index = np.argmax(predicted, axis=-1)[0]
-
-         # Find the word corresponding to the predicted index
-         for word, index in tokenizer.word_index.items():
-             if index == predicted_word_index:
-                 return word.capitalize()
-
-         return "No prediction available"
-     except Exception as e:
-         return f"Error in prediction: {str(e)}"
+ st.title("📝 Next Word Prediction App")
+ st.write("Type a sentence and let the model suggest the next word!")

- def main():
-     st.title("🔮 Next Word Predictor")
-     st.markdown("Enter some text and I'll predict the next word using an LSTM model trained on a large corpus.")
-
-     # Debug: Show files in directory
-     st.sidebar.write("Debug Info:")
-     st.sidebar.write("Files in directory:", os.listdir('.'))
-
-     # Load model and tokenizer
-     with st.spinner("Loading model..."):
-         model, tokenizer = load_models()
-
-     if model is None or tokenizer is None:
-         st.error("Failed to load the model. Please check if model files are available.")
-         return
-
-     # Calculate max sequence length (you might want to set this based on your training)
-     max_seq_len = 20  # Adjust based on your model's training parameters
-
-     # Input section
-     st.subheader("Enter your text")
-     seed_text = st.text_input(
-         "Start typing...",
-         placeholder="Type something like 'I am going to'",
-         key="text_input"
-     )
-
-     # Prediction button
-     if st.button("Predict Next Word", type="primary"):
-         if seed_text.strip():
-             with st.spinner("Predicting..."):
-                 next_word = predict_next_word(model, tokenizer, seed_text, max_seq_len)
-
-             # Display result
-             st.markdown(f"""
-             <div class="prediction-box">
-                 <h3>Prediction</h3>
-                 <p style="font-size: 20px; margin-bottom: 0;"><strong>{seed_text} <span style="color: #1890ff;">{next_word}</span></strong></p>
-             </div>
-             """, unsafe_allow_html=True)
-         else:
-             st.warning("Please enter some text first!")
+ # User input
+ text = st.text_input("Enter your sentence:", "")

- # Information section
-     st.markdown("---")
-     st.subheader("About")
-     st.markdown("""
-     This app uses an LSTM neural network trained on a large text corpus to predict the next word in a sequence.
-
-     **How it works:**
-     - The model was trained on 20,000 text samples
-     - Uses word embeddings and LSTM layers
-     - Predicts the most likely next word based on context
-
-     **Try phrases like:**
-     - "I am going to"
-     - "The weather is"
-     - "Machine learning is"
-     """)
+ if st.button("Predict Next Word") and text:
+     inputs = tokenizer(text, return_tensors="pt")
+     with torch.no_grad():
+         outputs = model.generate(**inputs, max_new_tokens=1)
+     prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)

- if __name__ == "__main__":
-     main()
+     # Extract only the new part
+     predicted_next = prediction[len(text):].strip()
+     st.success(f"**Predicted next word:** {predicted_next}")
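Note on the added prediction logic: `prediction[len(text):]` slices the decoded string by character offset, which assumes the decode step reproduces the prompt character-for-character; tokenizers that normalize whitespace or casing on the round trip can break that assumption, and a single generated token may be a subword rather than a whole word. A minimal token-level sketch, assuming (as app.py does) that the Space root "." holds a decoder-only causal LM checkpoint; the example prompt and max_new_tokens value are illustrative:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Assumption: "." contains the same local checkpoint app.py loads.
tokenizer = AutoTokenizer.from_pretrained(".")
model = AutoModelForCausalLM.from_pretrained(".")

text = "The weather is"
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
    # generate() returns the prompt ids followed by the new ids,
    # so a few extra tokens give the model room to finish a word.
    outputs = model.generate(**inputs, max_new_tokens=3)

# Slice off the prompt at the token level instead of the string level.
new_ids = outputs[0][inputs["input_ids"].shape[-1]:]
continuation = tokenizer.decode(new_ids, skip_special_tokens=True).strip()
# Keep only the first whitespace-delimited word of the continuation.
predicted_next = continuation.split()[0] if continuation else ""
print(predicted_next)

For many BPE vocabularies the character-offset slice in the commit happens to work; the id-level slice is robust regardless of the tokenizer's normalization.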