Update src/streamlit_app.py
Browse files- src/streamlit_app.py +6 -6
src/streamlit_app.py
CHANGED
|
@@ -7,8 +7,8 @@ import os
|
|
| 7 |
token = os.environ.get("token")
|
| 8 |
|
| 9 |
|
| 10 |
-
def instantiate_gpt2(max_length_ : int, num_return_sequences : int, text : str) -> dict:
|
| 11 |
-
pipe = pipeline(task='text-generation', model='Iscte-Sintra/
|
| 12 |
results = pipe(text, max_length=max_length_, num_return_sequences=num_return_sequences)
|
| 13 |
return results
|
| 14 |
|
|
@@ -21,15 +21,15 @@ def instantiate_albertina(top_k : int, text : str) -> dict:
|
|
| 21 |
return pipe(text, top_k=top_k)
|
| 22 |
|
| 23 |
|
| 24 |
-
def
|
| 25 |
try:
|
| 26 |
-
st.title("
|
| 27 |
max_length : int = st.sidebar.slider("Max Length", 10, 200)
|
| 28 |
num_return_sequences : int = st.sidebar.number_input('Number of Sequences to be Generated', min_value=1, max_value=10, value=1, step=1)
|
| 29 |
text : str = st.text_area("Text", "Katxor sta tr谩s di p贸rta.", height=75)
|
| 30 |
|
| 31 |
if st.button("Submit"):
|
| 32 |
-
results = instantiate_gpt2(max_length, num_return_sequences, text)
|
| 33 |
if results:
|
| 34 |
for result in results:
|
| 35 |
st.write(f"**Generated Text**: {result['generated_text']}")
|
|
@@ -114,4 +114,4 @@ if model_dict[selected_model] == 1:
|
|
| 114 |
elif model_dict[selected_model] == 2:
|
| 115 |
build_albertina_page()
|
| 116 |
else:
|
| 117 |
-
build_gpt2_page()
|
|
|
|
| 7 |
token = os.environ.get("token")
|
| 8 |
|
| 9 |
|
| 10 |
+
def instantiate_gpt2(model_name: str,max_length_ : int, num_return_sequences : int, text : str) -> dict:
|
| 11 |
+
pipe = pipeline(task='text-generation', model=f'Iscte-Sintra/{model_name}', tokenizer=f'Iscte-Sintra/{model_name}', token=token, truncation=True)
|
| 12 |
results = pipe(text, max_length=max_length_, num_return_sequences=num_return_sequences)
|
| 13 |
return results
|
| 14 |
|
|
|
|
| 21 |
return pipe(text, top_k=top_k)
|
| 22 |
|
| 23 |
|
| 24 |
+
def build_decoder_page(model_name):
|
| 25 |
try:
|
| 26 |
+
st.title(f"{model_name} : Decoder - Text Generation Task")
|
| 27 |
max_length : int = st.sidebar.slider("Max Length", 10, 200)
|
| 28 |
num_return_sequences : int = st.sidebar.number_input('Number of Sequences to be Generated', min_value=1, max_value=10, value=1, step=1)
|
| 29 |
text : str = st.text_area("Text", "Katxor sta tr谩s di p贸rta.", height=75)
|
| 30 |
|
| 31 |
if st.button("Submit"):
|
| 32 |
+
results = instantiate_gpt2(model_name ,max_length, num_return_sequences, text)
|
| 33 |
if results:
|
| 34 |
for result in results:
|
| 35 |
st.write(f"**Generated Text**: {result['generated_text']}")
|
|
|
|
| 114 |
elif model_dict[selected_model] == 2:
|
| 115 |
build_albertina_page()
|
| 116 |
else:
|
| 117 |
+
build_gpt2_page('GPT2-Kriolu')
|