| import streamlit as st |
| from persona_vectors.extraction import MaskStrategy |
|
|
|
|
| def render_mask_strategy_select( |
| *, |
| key: str, |
| last_key: str, |
| help_text: str, |
| remember_key: str | None = None, |
| ) -> MaskStrategy: |
| last_strategy = st.session_state.get( |
| remember_key, |
| st.session_state.get(last_key, MaskStrategy.ANSWER_MEAN.value), |
| ) |
| strategies = list(MaskStrategy) |
| selected = st.selectbox( |
| "Mask strategy", |
| options=strategies, |
| index=next( |
| ( |
| idx |
| for idx, strategy in enumerate(strategies) |
| if strategy.value == last_strategy |
| ), |
| 0, |
| ), |
| format_func=lambda strategy: strategy.value.replace("_", " ").title(), |
| key=key, |
| help=help_text, |
| ) |
| st.session_state[last_key] = selected.value |
| if remember_key is not None: |
| st.session_state[remember_key] = selected.value |
| return selected |
|
|