File size: 974 Bytes
2bf3d21
 
 
 
 
 
 
 
 
b279884
2bf3d21
b279884
 
 
 
2bf3d21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b279884
 
2bf3d21
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import streamlit as st
from persona_vectors.extraction import MaskStrategy


def render_mask_strategy_select(
    *,
    key: str,
    last_key: str,
    help_text: str,
    remember_key: str | None = None,
) -> MaskStrategy:
    """Render a "Mask strategy" selectbox and persist the chosen strategy.

    The initially selected option is restored from ``st.session_state``:
    the value stored under ``remember_key`` (when one is given and present)
    takes precedence, then the value stored under ``last_key``, and finally
    ``MaskStrategy.ANSWER_MEAN``. If the restored value matches no strategy,
    the first member of ``MaskStrategy`` is selected.

    Args:
        key: Streamlit widget key for the selectbox.
        last_key: Session-state key that receives the selected strategy's
            ``.value`` on every call.
        help_text: Tooltip text shown for the selectbox.
        remember_key: Optional extra session-state key. When provided, it is
            preferred over ``last_key`` as the source of the initial selection
            and also receives the selected strategy's ``.value``.

    Returns:
        The ``MaskStrategy`` member returned by the selectbox.
    """
    state = st.session_state
    last_value = state.get(last_key, MaskStrategy.ANSWER_MEAN.value)
    # Only look up remember_key when one was given; querying session state
    # with ``None`` as the key depends on Streamlit's key coercion rules.
    if remember_key is not None:
        last_value = state.get(remember_key, last_value)
    # Values are normally stored as ``.value``; also accept an enum member.
    if isinstance(last_value, MaskStrategy):
        last_value = last_value.value

    strategies = list(MaskStrategy)
    default_index = next(
        (
            idx
            for idx, strategy in enumerate(strategies)
            if strategy.value == last_value
        ),
        0,
    )
    selected = st.selectbox(
        "Mask strategy",
        options=strategies,
        index=default_index,
        format_func=lambda strategy: strategy.value.replace("_", " ").title(),
        key=key,
        help=help_text,
    )
    # Persist the plain value (not the enum member) so it can be restored
    # on the next rerun regardless of enum class identity.
    st.session_state[last_key] = selected.value
    if remember_key is not None:
        st.session_state[remember_key] = selected.value
    return selected