File size: 974 Bytes
2bf3d21 b279884 2bf3d21 b279884 2bf3d21 b279884 2bf3d21 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 | import streamlit as st
from persona_vectors.extraction import MaskStrategy
def render_mask_strategy_select(
    *,
    key: str,
    last_key: str,
    help_text: str,
    remember_key: str | None = None,
) -> MaskStrategy:
    """Render a "Mask strategy" selectbox and persist the chosen value.

    The initially selected option is restored from ``st.session_state``:
    ``remember_key`` (when given) takes precedence over ``last_key``. If
    neither holds a recognised strategy, ``MaskStrategy.ANSWER_MEAN`` is
    preselected. After rendering, the selected strategy's ``.value`` is
    written to ``last_key`` and, if provided, to ``remember_key``.

    Args:
        key: Streamlit widget key for the selectbox.
        last_key: Session-state key holding the most recently selected
            strategy value.
        help_text: Tooltip text shown next to the widget.
        remember_key: Optional extra session-state key that overrides
            ``last_key`` when restoring and is updated alongside it.

    Returns:
        The ``MaskStrategy`` member chosen in the selectbox.
    """
    state = st.session_state
    default_strategy = MaskStrategy.ANSWER_MEAN

    # Resolve the stored choice. Only consult remember_key when one was
    # supplied; otherwise Streamlit would coerce None to the key "None".
    last_strategy = state.get(last_key, default_strategy.value)
    if remember_key is not None:
        last_strategy = state.get(remember_key, last_strategy)

    # Tolerate an enum member having been stored instead of its value.
    if isinstance(last_strategy, MaskStrategy):
        last_strategy = last_strategy.value

    strategies = list(MaskStrategy)
    values = [strategy.value for strategy in strategies]
    # A stale/unknown stored value falls back to the same default used when
    # nothing is stored, rather than to whichever member happens to be first.
    if last_strategy in values:
        initial_index = values.index(last_strategy)
    else:
        initial_index = strategies.index(default_strategy)

    selected = st.selectbox(
        "Mask strategy",
        options=strategies,
        index=initial_index,
        format_func=lambda strategy: strategy.value.replace("_", " ").title(),
        key=key,
        help=help_text,
    )

    # Persist the plain value (not the enum) so it survives reruns uniformly.
    state[last_key] = selected.value
    if remember_key is not None:
        state[remember_key] = selected.value
    return selected
|