# LSTM_Project / app.py — next-word prediction demo (Streamlit + Keras LSTM)
# Author: MogulojuSai (commit 2eaa97d, verified)
import streamlit as st
import numpy as np
import pickle
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.utils import get_custom_objects
from tensorflow.keras.initializers import Orthogonal
# Register the Orthogonal initializer for compatibility
# NOTE(review): older .h5 checkpoints can reference the initializer by the
# bare name "Orthogonal"; adding it to Keras' custom-object registry lets
# load_model resolve that name — confirm this is still needed with the
# installed TensorFlow version.
get_custom_objects()["Orthogonal"] = Orthogonal
# Load the LSTM model
@st.cache_resource
def load_lstm_model():
    """Load the trained next-word LSTM model from disk.

    Cached with ``st.cache_resource`` so the model is deserialized once per
    server process instead of on every Streamlit rerun (every widget
    interaction re-executes the script). Note the success toast therefore
    only appears on the first, uncached call.

    Returns:
        The loaded Keras model, or ``None`` if loading failed.
    """
    try:
        model = load_model('next_word_lstm.h5')
        st.success("LSTM model loaded successfully!")
        return model
    except Exception as e:
        # Surface the failure in the UI instead of crashing the app.
        st.error(f"Error loading the model: {e}")
        return None
# Load the tokenizer
@st.cache_resource
def load_tokenizer():
    """Load the fitted Keras tokenizer pickled at training time.

    Cached with ``st.cache_resource`` so the pickle is read once per server
    process rather than on every Streamlit rerun.

    Returns:
        The unpickled tokenizer, or ``None`` if loading failed.
    """
    try:
        # NOTE(review): pickle.load is only safe because tokenizer.pickle is
        # a trusted artifact shipped with the app — never load untrusted pickles.
        with open('tokenizer.pickle', 'rb') as handle:
            tokenizer = pickle.load(handle)
        st.success("Tokenizer loaded successfully!")
        return tokenizer
    except Exception as e:
        # Surface the failure in the UI instead of crashing the app.
        st.error(f"Error loading the tokenizer: {e}")
        return None
# Function to predict the next word
def predict_next_word(model, tokenizer, text, max_sequence_len):
    """Predict the most likely next word for *text*.

    Args:
        model: Trained Keras LSTM whose output is a softmax over the vocabulary.
        tokenizer: Fitted Keras ``Tokenizer`` used at training time.
        text: Raw input sentence typed by the user.
        max_sequence_len: Training sequence length (model input length + 1).

    Returns:
        The predicted word as a string, or ``None`` when no prediction can be
        made (all input words out of vocabulary, index not in the vocab) or
        when an error occurred.
    """
    try:
        # Convert the input text into a sequence of known-vocabulary token ids.
        token_list = tokenizer.texts_to_sequences([text])[0]
        if not token_list:
            # Every word was out of vocabulary; predicting from pure padding
            # would be meaningless, so bail out explicitly.
            return None
        # Keep only the last (max_sequence_len - 1) tokens so the window
        # matches what the model was trained on, then left-pad to full width.
        if len(token_list) >= max_sequence_len:
            token_list = token_list[-(max_sequence_len - 1):]
        padded = pad_sequences([token_list], maxlen=max_sequence_len - 1, padding='pre')
        # Pick the highest-probability vocabulary index.
        probs = model.predict(padded, verbose=0)
        predicted_index = int(np.argmax(probs, axis=1)[0])
        # O(1) reverse lookup via Tokenizer.index_word; fall back to scanning
        # word_index for tokenizers pickled before index_word existed.
        index_word = getattr(tokenizer, 'index_word', None)
        if index_word:
            return index_word.get(predicted_index)
        for word, index in tokenizer.word_index.items():
            if index == predicted_index:
                return word
        return None
    except Exception as e:
        st.error(f"Error during prediction: {e}")
        return None
# Streamlit App
def main():
    """Streamlit entry point: load artifacts and wire up the prediction UI."""
    st.title("Next Word Prediction with LSTM")
    st.write("This app predicts the next word in a sentence using an LSTM model trained on text data.")
    # Load model and tokenizer (errors are reported inside the loaders).
    model = load_lstm_model()
    tokenizer = load_tokenizer()
    # Input text box
    input_text = st.text_input("Enter a sentence:")
    # Predict button — guard clauses keep the happy path unindented.
    if st.button("Predict Next Word"):
        if model is None or tokenizer is None:
            st.error("Model or tokenizer not loaded properly.")
        elif not input_text.strip():
            # Robustness: predicting from an empty sentence is meaningless.
            st.warning("Please enter a sentence first.")
        else:
            # The model consumes max_sequence_len - 1 tokens; the final slot
            # of the training sequences was the target word.
            max_sequence_len = model.input_shape[1] + 1
            next_word = predict_next_word(model, tokenizer, input_text, max_sequence_len)
            if next_word:
                st.write(f"Predicted next word: **{next_word}**")
            else:
                st.warning("Could not predict the next word. Please try a different input.")


if __name__ == "__main__":
    main()