# Next Word Prediction with LSTM — Streamlit app
import pickle

import numpy as np
import streamlit as st
from tensorflow.keras.initializers import Orthogonal
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.utils import get_custom_objects

# Register the Orthogonal initializer under its plain name so a model
# saved by a different Keras version that references "Orthogonal" can
# still be deserialized by load_model().
get_custom_objects()["Orthogonal"] = Orthogonal
# Load the LSTM model
def load_lstm_model(model_path='next_word_lstm.h5'):
    """Load the trained next-word LSTM model from disk.

    Args:
        model_path: Path to the saved Keras HDF5 model file. Defaults to
            the original hard-coded 'next_word_lstm.h5' for compatibility.

    Returns:
        The loaded Keras model, or None if loading failed (the error is
        shown in the Streamlit UI instead of crashing the app).
    """
    try:
        model = load_model(model_path)
        st.success("LSTM model loaded successfully!")
        return model
    except Exception as e:
        # Surface the failure to the user; callers check for None.
        st.error(f"Error loading the model: {e}")
        return None
# Load the tokenizer
def load_tokenizer(tokenizer_path='tokenizer.pickle'):
    """Load the fitted Keras Tokenizer used at training time.

    Args:
        tokenizer_path: Path to the pickled tokenizer. Defaults to the
            original hard-coded 'tokenizer.pickle' for compatibility.

    Returns:
        The unpickled Tokenizer, or None if loading failed (the error is
        shown in the Streamlit UI instead of crashing the app).
    """
    try:
        # NOTE(review): pickle.load executes arbitrary code from the file;
        # only load tokenizer files shipped with this app, never user uploads.
        with open(tokenizer_path, 'rb') as handle:
            tokenizer = pickle.load(handle)
        st.success("Tokenizer loaded successfully!")
        return tokenizer
    except Exception as e:
        # Surface the failure to the user; callers check for None.
        st.error(f"Error loading the tokenizer: {e}")
        return None
# Function to predict the next word
def predict_next_word(model, tokenizer, text, max_sequence_len):
    """Predict the next word for ``text`` with the trained LSTM model.

    Args:
        model: Trained Keras model producing a softmax over the vocabulary.
        tokenizer: Fitted Keras Tokenizer matching the model's vocabulary.
        text: Input sentence to continue.
        max_sequence_len: Sequence length used at training time; the model
            consumes windows of ``max_sequence_len - 1`` tokens.

    Returns:
        The predicted next word as a string, or None if the input contains
        no known words, the predicted index has no word, or an error occurs.
    """
    try:
        # Convert the input text into a sequence of token indices.
        token_list = tokenizer.texts_to_sequences([text])[0]
        # Empty/all-OOV input: predicting from a zero-padded sequence would
        # be meaningless, so bail out early.
        if not token_list:
            return None
        # Keep only the most recent max_sequence_len - 1 tokens.
        if len(token_list) >= max_sequence_len:
            token_list = token_list[-(max_sequence_len - 1):]
        # Left-pad to the fixed input length the model expects.
        token_list = pad_sequences([token_list], maxlen=max_sequence_len - 1, padding='pre')
        # Predict the next-word probability distribution.
        predicted = model.predict(token_list, verbose=0)
        predicted_index = int(np.argmax(predicted, axis=1)[0])
        # O(1) reverse lookup: Keras Tokenizer maintains index_word as the
        # inverse of word_index — avoid scanning the whole vocabulary.
        index_word = getattr(tokenizer, 'index_word', None)
        if index_word:
            return index_word.get(predicted_index)
        # Fallback for tokenizer-like objects without index_word.
        for word, index in tokenizer.word_index.items():
            if index == predicted_index:
                return word
        return None
    except Exception as e:
        st.error(f"Error during prediction: {e}")
        return None
# Streamlit App
def main():
    """Streamlit entry point: render the next-word-prediction UI."""
    st.title("Next Word Prediction with LSTM")
    st.write("This app predicts the next word in a sentence using an LSTM model trained on text data.")
    # Load model and tokenizer (each returns None and reports to the UI on failure).
    model = load_lstm_model()
    tokenizer = load_tokenizer()
    # Input text box
    input_text = st.text_input("Enter a sentence:")
    # Predict button
    if st.button("Predict Next Word"):
        if model is None or tokenizer is None:
            st.error("Model or tokenizer not loaded properly.")
        elif not input_text.strip():
            # Guard against empty input before invoking the model.
            st.warning("Please enter a sentence first.")
        else:
            # The model's input length is max_sequence_len - 1, so add 1 back
            # to recover the training-time sequence length.
            max_sequence_len = model.input_shape[1] + 1
            next_word = predict_next_word(model, tokenizer, input_text, max_sequence_len)
            if next_word:
                st.write(f"Predicted next word: **{next_word}**")
            else:
                st.warning("Could not predict the next word. Please try a different input.")


if __name__ == "__main__":
    main()