Spaces:
Runtime error
Runtime error
Commit ·
d352f8e
1
Parent(s): d9b4b87
add file uploader & download functions
Browse files- app.py +70 -26
- requirements.txt +1 -1
app.py
CHANGED
|
@@ -10,8 +10,7 @@ import tensorflow as tf
|
|
| 10 |
|
| 11 |
def main():
|
| 12 |
st.title("Interactive demo: T5 Multitasking Demo")
|
| 13 |
-
st.
|
| 14 |
-
text summarization, document similarity, and grammatical correctness of sentences.**")
|
| 15 |
saved_model_path = load_model_cache()
|
| 16 |
|
| 17 |
# Model is loaded in st.session_state to remain stateless across reloading
|
|
@@ -33,43 +32,88 @@ def load_model_cache():
|
|
| 33 |
snapshot_download(repo_id="stevekola/T5", cache_dir=CACHE_DIR)
|
| 34 |
saved_model_path = os.path.join(CACHE_DIR, os.listdir(CACHE_DIR)[0])
|
| 35 |
return saved_model_path
|
| 36 |
-
|
| 37 |
|
| 38 |
def dashboard(model):
|
| 39 |
-
"""
|
| 40 |
params:
|
| 41 |
model stateless model to run inference from
|
| 42 |
"""
|
| 43 |
-
st.sidebar.write("**Select the Task Type over here**")
|
| 44 |
task_type = st.sidebar.radio("Task Type",
|
| 45 |
[
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
])
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
else:
|
| 64 |
-
|
| 65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
st.write("**Output Text**")
|
| 68 |
-
with st.spinner("
|
| 69 |
output_text = predict(task_type, sentence, model)
|
| 70 |
st.write(output_text)
|
| 71 |
-
|
| 72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
def predict(task_type, sentence, model):
|
| 75 |
"""Function to parse the user inputs, run the parsed text through the
|
|
|
|
| 10 |
|
| 11 |
def main():
|
| 12 |
st.title("Interactive demo: T5 Multitasking Demo")
|
| 13 |
+
st.sidebar.image("https://i.gzn.jp/img/2020/02/25/google-ai-t5/01.png")
|
|
|
|
| 14 |
saved_model_path = load_model_cache()
|
| 15 |
|
| 16 |
# Model is loaded in st.session_state to remain stateless across reloading
|
|
|
|
| 32 |
snapshot_download(repo_id="stevekola/T5", cache_dir=CACHE_DIR)
|
| 33 |
saved_model_path = os.path.join(CACHE_DIR, os.listdir(CACHE_DIR)[0])
|
| 34 |
return saved_model_path
|
| 35 |
+
|
| 36 |
|
| 37 |
def dashboard(model):
|
| 38 |
+
"""Function to display the inputs and results
|
| 39 |
params:
|
| 40 |
model stateless model to run inference from
|
| 41 |
"""
|
|
|
|
| 42 |
task_type = st.sidebar.radio("Task Type",
|
| 43 |
[
|
| 44 |
+
"Translate English to French",
|
| 45 |
+
"Translate English to German",
|
| 46 |
+
"Translate English to Romanian",
|
| 47 |
+
"Grammatical Correctness of Sentence",
|
| 48 |
+
"Text Summarization",
|
| 49 |
+
"Document Similarity Score"
|
| 50 |
])
|
| 51 |
+
|
| 52 |
+
default_sentence = "I am Steven and I live in Lagos, Nigeria."
|
| 53 |
+
text_summarization_sentence = "I don't care about those doing the comparison, but comparing \
|
| 54 |
+
the Ghanaian Jollof Rice to Nigerian Jollof Rice is an insult to Nigerians."
|
| 55 |
+
doc_similarity_sentence1 = "I reside in the commercial capital city of Nigeria, which is Lagos."
|
| 56 |
+
doc_similarity_sentence2 = "I live in Lagos."
|
| 57 |
+
help_msg = "You could either type in the sentences to run inferences on or use the upload button to \
|
| 58 |
+
upload text files containing those sentences. The input sentence box, by default, displays sample \
|
| 59 |
+
texts or the texts in the files that you've uploaded. Feel free to erase them and type in new sentences."
|
| 60 |
+
|
| 61 |
+
if task_type.startswith("Document Similarity"): # document similarity requires two documents
|
| 62 |
+
uploaded_file = upload_files(help_msg, text="Upload 2 documents for similarity check", accept_multiple_files=True)
|
| 63 |
+
if uploaded_file:
|
| 64 |
+
sentence1 = st.text_area("Enter first document/sentence", uploaded_file[0], help=help_msg)
|
| 65 |
+
sentence2 = st.text_area("Enter second document/sentence", uploaded_file[1], help=help_msg)
|
| 66 |
+
else:
|
| 67 |
+
sentence1 = st.text_area("Enter first document/sentence", doc_similarity_sentence1)
|
| 68 |
+
sentence2 = st.text_area("Enter second document/sentence", doc_similarity_sentence2)
|
| 69 |
+
sentence = sentence1 + "---" + sentence2 # to be processed like other tasks' single sentences
|
| 70 |
else:
|
| 71 |
+
uploaded_file = upload_files(help_msg)
|
| 72 |
+
if uploaded_file:
|
| 73 |
+
sentence = st.text_area("Enter sentence", uploaded_file, help=help_msg)
|
| 74 |
+
elif task_type.startswith("Text Summarization"): # text summarization's default input should be longer
|
| 75 |
+
sentence = st.text_area("Enter sentence", text_summarization_sentence, help=help_msg)
|
| 76 |
+
else:
|
| 77 |
+
sentence = st.text_area("Enter sentence", default_sentence, help=help_msg)
|
| 78 |
|
| 79 |
st.write("**Output Text**")
|
| 80 |
+
with st.spinner("Waiting for prediction..."): # spinner while model is running inferences
|
| 81 |
output_text = predict(task_type, sentence, model)
|
| 82 |
st.write(output_text)
|
| 83 |
+
try: # to workaround the environment's Streamlit version
|
| 84 |
+
st.download_button("Download output text", output_text)
|
| 85 |
+
except AttributeError:
|
| 86 |
+
st.text("File download not enabled for this Streamlit version \U0001F612")
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def upload_files(help_msg, text="Upload a text file here", accept_multiple_files=False):
|
| 90 |
+
"""Function to upload text files and return as string text
|
| 91 |
+
params:
|
| 92 |
+
text Display label for the upload button
|
| 93 |
+
accept_multiple_files params for the file_uploader function to accept more than a file
|
| 94 |
+
returns:
|
| 95 |
+
a string or a list of strings (in case of multiple files being uploaded)
|
| 96 |
+
"""
|
| 97 |
+
|
| 98 |
+
def upload():
|
| 99 |
+
uploaded_files = st.file_uploader(label="Upload text files only",
|
| 100 |
+
type="txt", help=help_msg,
|
| 101 |
+
accept_multiple_files=accept_multiple_files)
|
| 102 |
+
if st.button("Process"):
|
| 103 |
+
if not uploaded_files:
|
| 104 |
+
st.write("**No file uploaded!**")
|
| 105 |
+
return None
|
| 106 |
+
st.write("**Upload successful!**")
|
| 107 |
+
if type(uploaded_files) == list:
|
| 108 |
+
return [f.read().decode("utf-8") for f in uploaded_files]
|
| 109 |
+
return uploaded_files.read().decode("utf-8")
|
| 110 |
+
|
| 111 |
+
try: # to workaround the environment's Streamlit version
|
| 112 |
+
with st.expander(text):
|
| 113 |
+
return upload()
|
| 114 |
+
except AttributeError:
|
| 115 |
+
return upload()
|
| 116 |
+
|
| 117 |
|
| 118 |
def predict(task_type, sentence, model):
|
| 119 |
"""Function to parse the user inputs, run the parsed text through the
|
requirements.txt
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
t5
|
| 2 |
huggingface_hub
|
| 3 |
-
streamlit
|
|
|
|
| 1 |
t5
|
| 2 |
huggingface_hub
|
| 3 |
+
streamlit==1.0.0
|