Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| from huggingface_hub import snapshot_download | |
| import os # utility library | |
| # libraries to load the model and serve inference | |
| import tensorflow_text | |
| import tensorflow as tf | |
def main():
    """Render the page header, load the T5 model once per session and show the dashboard.

    Streamlit re-executes this script on every user interaction, so the loaded
    SavedModel is kept in ``st.session_state`` and reused on later reruns.
    """
    st.title("Interactive demo: T5 Multitasking Demo")
    st.sidebar.image("https://i.gzn.jp/img/2020/02/25/google-ai-t5/01.png")
    # Resolve the snapshot path and load the model only when this session has no
    # model yet; reruns then skip both the Hub lookup and the (slow) model load.
    if 'model' not in st.session_state:
        saved_model_path = load_model_cache()
        st.session_state.model = tf.saved_model.load(saved_model_path, ["serve"])
    dashboard(st.session_state.model)
def load_model_cache():
    """Function to download the model files from HuggingFace Hub into a local cache directory

    returns:
        path of the local snapshot folder, as reported by snapshot_download
    """
    CACHE_DIR = "hfhub_cache"  # where the downloaded repo snapshot is stored
    os.makedirs(CACHE_DIR, exist_ok=True)
    # snapshot_download returns the folder holding the downloaded files. Use that
    # instead of os.listdir(CACHE_DIR)[0]: listing order is arbitrary and the
    # cache directory may contain entries other than the snapshot itself.
    return snapshot_download(repo_id="stevekola/T5", cache_dir=CACHE_DIR)
def dashboard(model):
    """Function to display the inputs and results

    Shows the task selector in the sidebar, the upload/text-input widgets for
    the selected task, then runs the prediction and displays its output.

    params:
        model   loaded model to run inference from
    """
    task_type = st.sidebar.radio("Task Type",
                                 [
                                     "Translate English to French",
                                     "Translate English to German",
                                     "Translate English to Romanian",
                                     "Grammatical Correctness of Sentence",
                                     "Text Summarization",
                                     "Document Similarity Score"
                                 ])
    default_sentence = "I am Steven and I live in Lagos, Nigeria."
    # Adjacent literals are used instead of backslash continuations so that the
    # indentation of the continuation lines never leaks into the text.
    text_summarization_sentence = (
        "I don't care about those doing the comparison, but comparing "
        "the Ghanaian Jollof Rice to Nigerian Jollof Rice is an insult to Nigerians."
    )
    doc_similarity_sentence1 = "I reside in the commercial capital city of Nigeria, which is Lagos."
    doc_similarity_sentence2 = "I live in Lagos."
    help_msg = (
        "You could either type in the sentences to run inferences on or use the upload button to "
        "upload text files containing those sentences. The input sentence box, by default, displays sample "
        "texts or the texts in the files that you've uploaded. Feel free to erase them and type in new sentences."
    )

    if task_type.startswith("Document Similarity"):  # document similarity requires two documents
        uploaded_file = upload_files(help_msg, text="Upload 2 documents for similarity check",
                                     accept_multiple_files=True)
        first_text, second_text = doc_similarity_sentence1, doc_similarity_sentence2
        if uploaded_file:
            if len(uploaded_file) >= 2:
                # only the first two uploaded documents are compared
                first_text, second_text = uploaded_file[0], uploaded_file[1]
            else:
                # a single upload used to raise IndexError on uploaded_file[1]
                st.warning("Please upload 2 documents for the similarity check; "
                           "showing the sample texts instead.")
        sentence1 = st.text_area("Enter first document/sentence", first_text, help=help_msg)
        sentence2 = st.text_area("Enter second document/sentence", second_text, help=help_msg)
        sentence = sentence1 + "---" + sentence2  # to be processed like other tasks' single sentences
    else:
        uploaded_file = upload_files(help_msg)
        if uploaded_file:
            initial_text = uploaded_file
        elif task_type.startswith("Text Summarization"):  # text summarization's default input should be longer
            initial_text = text_summarization_sentence
        else:
            initial_text = default_sentence
        sentence = st.text_area("Enter sentence", initial_text, help=help_msg)

    st.write("**Output Text**")
    with st.spinner("Waiting for prediction..."):  # spinner while model is running inferences
        output_text = predict(task_type, sentence, model)
    st.write(output_text)
    # Older Streamlit versions have no download_button; check for the feature
    # explicitly so an unrelated AttributeError is not silently swallowed.
    if hasattr(st, "download_button"):
        st.download_button("Download output text", output_text)
    else:
        st.text("File download not enabled for this Streamlit version \U0001F612")
def upload_files(help_msg, text="Upload a text file here", accept_multiple_files=False):
    """Function to upload text files and return as string text

    params:
        help_msg                help text shown next to the file uploader
        text                    label of the expander that wraps the uploader
                                (unused when the Streamlit version has no expander)
        accept_multiple_files   params for the file_uploader function to accept more than a file
    returns:
        a string, or a list of strings (in case of multiple files being uploaded);
        None when "Process" was not clicked on this run or no file was uploaded
    """
    def decode(uploaded):
        # getvalue() returns the whole buffer regardless of the current read
        # position; undecodable bytes are replaced instead of crashing the app.
        return uploaded.getvalue().decode("utf-8", errors="replace")

    def upload():
        uploaded_files = st.file_uploader(label="Upload text files only",
                                          type="txt", help=help_msg,
                                          accept_multiple_files=accept_multiple_files)
        if not st.button("Process"):
            return None
        if not uploaded_files:
            st.write("**No file uploaded!**")
            return None
        st.write("**Upload successful!**")
        if isinstance(uploaded_files, list):
            return [decode(f) for f in uploaded_files]
        return decode(uploaded_files)

    # Check for the expander feature up front. Wrapping the whole upload() call
    # in try/except AttributeError would re-run upload() (and re-create its
    # widgets) whenever anything inside it raised an AttributeError.
    if not hasattr(st, "expander"):
        return upload()
    with st.expander(text):
        return upload()
def predict(task_type, sentence, model):
    """Function to parse the user inputs, run the parsed text through the
    model and return output in a readable format.

    params:
        task_type   sentence representing the type of task to run on T5 model
        sentence    sentence to get inference on; for the document similarity
                    task, the two documents joined by '---'
        model       model to get inferences from
    returns:
        text decoded into a human-readable format.
    raises:
        KeyError if task_type is not one of the supported task names
    """
    # NOTE(review): the T5 paper writes the translation prefixes with a
    # lowercase "translate" - confirm which form this checkpoint expects.
    task_dict = {
        "Translate English to French": "Translate English to French",
        "Translate English to German": "Translate English to German",
        "Translate English to Romanian": "Translate English to Romanian",
        "Grammatical Correctness of Sentence": "cola sentence",
        "Text Summarization": "summarize",
        "Document Similarity Score": "stsb",
    }
    prefix = task_dict[task_type]
    # Document Similarity takes in two sentences so it has to be parsed in a separate manner
    if task_type.startswith("Document Similarity"):
        # Split on the first separator only: text after a later '---' stays in
        # the second document, and a missing separator yields an empty second
        # document instead of an IndexError.
        first, _, second = sentence.partition('---')
        question = f"{prefix} sentence1: {first} sentence2: {second}"
    else:
        question = f"{prefix}: {sentence}"  # parsing the user inputs into a format recognized by T5
    return predict_fn([question], model)[0].decode('utf-8')
def predict_fn(x, model):
    """Function to run the model's serving signature on a batch of input texts.

    params:
        x       input texts to get outputs for
        model   model to run inferences from
    returns:
        a numpy array representing the output
    """
    serving_fn = model.signatures['serving_default']
    result = serving_fn(tf.constant(x))
    return result['outputs'].numpy()
# Script entry point: build the app when this file is executed as the main script.
if __name__ == "__main__":
    main()