Spaces:
Sleeping
Sleeping
Yeb Havinga
commited on
Commit
·
5cf4ee2
1
Parent(s):
a9f2b23
Make seed configurable
Browse files
app.py
CHANGED
|
@@ -87,12 +87,6 @@ def instantiate_models():
|
|
| 87 |
p["pipeline"].load()
|
| 88 |
|
| 89 |
|
| 90 |
-
def set_new_seed():
|
| 91 |
-
seed = randint(0, 2**32 - 1)
|
| 92 |
-
set_seed(seed)
|
| 93 |
-
return seed
|
| 94 |
-
|
| 95 |
-
|
| 96 |
def main():
|
| 97 |
st.set_page_config( # Alternate names: setup_page, page, layout
|
| 98 |
page_title="Netherator", # String or None. Strings get appended with "• Streamlit".
|
|
@@ -122,9 +116,6 @@ def main():
|
|
| 122 |
|
| 123 |
st.session_state["text"] = st.text_area("Enter text", st.session_state.prompt_box)
|
| 124 |
|
| 125 |
-
# min_length = st.sidebar.number_input(
|
| 126 |
-
# "Min length", min_value=10, max_value=150, value=75
|
| 127 |
-
# )
|
| 128 |
max_length = st.sidebar.number_input(
|
| 129 |
"Lengte van de tekst",
|
| 130 |
value=200,
|
|
@@ -140,8 +131,28 @@ def main():
|
|
| 140 |
"Num return sequences", min_value=1, max_value=5, value=1
|
| 141 |
)
|
| 142 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
if sampling_mode := st.sidebar.selectbox(
|
| 144 |
-
|
| 145 |
):
|
| 146 |
if sampling_mode == "Beam Search":
|
| 147 |
num_beams = st.sidebar.number_input(
|
|
@@ -200,7 +211,7 @@ and the [Huggingface text generation interface doc](https://huggingface.co/trans
|
|
| 200 |
estimate = int(estimate)
|
| 201 |
|
| 202 |
with st.spinner(
|
| 203 |
-
|
| 204 |
):
|
| 205 |
memory = psutil.virtual_memory()
|
| 206 |
generator = next(
|
|
@@ -211,7 +222,7 @@ and the [Huggingface text generation interface doc](https://huggingface.co/trans
|
|
| 211 |
),
|
| 212 |
None,
|
| 213 |
)
|
| 214 |
-
seed
|
| 215 |
time_start = time.time()
|
| 216 |
result = generator.get_text(text=st.session_state.text, **params)
|
| 217 |
time_end = time.time()
|
|
|
|
| 87 |
p["pipeline"].load()
|
| 88 |
|
| 89 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
def main():
|
| 91 |
st.set_page_config( # Alternate names: setup_page, page, layout
|
| 92 |
page_title="Netherator", # String or None. Strings get appended with "• Streamlit".
|
|
|
|
| 116 |
|
| 117 |
st.session_state["text"] = st.text_area("Enter text", st.session_state.prompt_box)
|
| 118 |
|
|
|
|
|
|
|
|
|
|
| 119 |
max_length = st.sidebar.number_input(
|
| 120 |
"Lengte van de tekst",
|
| 121 |
value=200,
|
|
|
|
| 131 |
"Num return sequences", min_value=1, max_value=5, value=1
|
| 132 |
)
|
| 133 |
|
| 134 |
+
seed_placeholder = st.sidebar.empty()
|
| 135 |
+
if "seed" not in st.session_state:
|
| 136 |
+
print(f"Session state {st.session_state} does not contain seed")
|
| 137 |
+
st.session_state["seed"] = 4162549114
|
| 138 |
+
print(f"Seed is set to: {st.session_state['seed']}")
|
| 139 |
+
|
| 140 |
+
seed = seed_placeholder.number_input(
|
| 141 |
+
"Seed", min_value=0, max_value=2 ** 32 - 1, value=st.session_state["seed"]
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
def set_random_seed():
|
| 145 |
+
st.session_state["seed"] = randint(0, 2 ** 32 - 1)
|
| 146 |
+
seed = seed_placeholder.number_input(
|
| 147 |
+
"Seed", min_value=0, max_value=2 ** 32 - 1, value=st.session_state["seed"]
|
| 148 |
+
)
|
| 149 |
+
print(f"New random seed set to: {seed}")
|
| 150 |
+
|
| 151 |
+
if st.button("New random seed?"):
|
| 152 |
+
set_random_seed()
|
| 153 |
+
|
| 154 |
if sampling_mode := st.sidebar.selectbox(
|
| 155 |
+
"select a Mode", index=0, options=["Top-k Sampling", "Beam Search"]
|
| 156 |
):
|
| 157 |
if sampling_mode == "Beam Search":
|
| 158 |
num_beams = st.sidebar.number_input(
|
|
|
|
| 211 |
estimate = int(estimate)
|
| 212 |
|
| 213 |
with st.spinner(
|
| 214 |
+
text=f"Please wait ~ {estimate} second{'s' if estimate != 1 else ''} while getting results ..."
|
| 215 |
):
|
| 216 |
memory = psutil.virtual_memory()
|
| 217 |
generator = next(
|
|
|
|
| 222 |
),
|
| 223 |
None,
|
| 224 |
)
|
| 225 |
+
set_seed(seed)
|
| 226 |
time_start = time.time()
|
| 227 |
result = generator.get_text(text=st.session_state.text, **params)
|
| 228 |
time_end = time.time()
|