import streamlit as st
import tensorflow as tf
# Page heading shown above the input widgets.
_PAGE_HEADER = """
Next Character Prediction
"""
st.markdown(_PAGE_HEADER, unsafe_allow_html=True)
# Seed text typed by the user; generation continues from this string.
characters = st.text_input("Enter the characters to begin with")
# Checkpoint selector; each option is loaded from a directory named
# ``model_epoch<value>`` (path resolved relative to the working directory).
epoch = st.radio("Enter the epoch value",['1','5','10','25','50','100','200'])

# Number of characters appended to the seed text.
NUM_GENERATED_CHARS = 100

# Streamlit re-executes this whole script on every widget interaction, so
# without caching the SavedModel would be reloaded from disk on each rerun.
# ``st.cache_resource`` (newer Streamlit releases) keeps one loaded model per
# epoch value across reruns; if it is unavailable, fall back to no caching.
_cache_model = getattr(st, "cache_resource", lambda func: func)


@_cache_model
def _load_model(epoch_value):
    """Load and return the SavedModel stored in ``model_epoch<epoch_value>``."""
    return tf.saved_model.load(f'model_epoch{epoch_value}')


def _generate_text(model, seed, num_chars=NUM_GENERATED_CHARS):
    """Return *seed* followed by *num_chars* characters generated by *model*.

    *model* is expected to expose ``generate_one_step(inputs, states=...)``
    returning ``(next_char_tensor, states)``; the recurrent ``states`` are
    threaded through successive calls so each step continues the sequence.
    """
    next_char = tf.constant([seed])
    pieces = [next_char]
    states = None  # no recurrent state before the first step
    for _ in range(num_chars):
        next_char, states = model.generate_one_step(next_char, states=states)
        pieces.append(next_char)
    # Join the seed and every generated piece, then take the single batch item.
    return tf.strings.join(pieces)[0].numpy().decode("utf-8")


if characters and epoch:
    st.write(_generate_text(_load_model(epoch), characters))