File size: 2,195 Bytes
ac2020e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""Form components for configuration in LegisQA"""

import streamlit as st
from legisqa_local.config.models import PROVIDER_MODELS, CONGRESS_NUMBERS, SPONSOR_PARTIES


def get_generative_config(key_prefix: str) -> dict:
    """Render the generative-model widgets and collect their current values.

    Widgets are created in this order: provider, model_name, temperature,
    max_output_tokens, should_escape_markdown, should_add_legis_urls. Each
    widget uses its setting name as the visible label and is registered
    under the Streamlit key ``"<key_prefix>|<setting name>"``.

    Args:
        key_prefix: Prefix for every widget key created here.

    Returns:
        A dict mapping each setting name listed above to the value returned
        by its widget.
    """

    def widget_key(name: str) -> str:
        # Every widget key has the form "<key_prefix>|<setting name>".
        return f"{key_prefix}|{name}"

    provider = st.selectbox(
        label="provider",
        options=PROVIDER_MODELS.keys(),
        key=widget_key("provider"),
    )

    # The model choices are looked up from the provider selected just above.
    model_name = st.selectbox(
        label="model_name",
        options=PROVIDER_MODELS[provider],
        key=widget_key("model_name"),
    )

    temperature = st.slider(
        label="temperature",
        min_value=0.0,
        max_value=2.0,
        value=0.0,
        key=widget_key("temperature"),
    )

    # No explicit initial value is passed for this slider.
    max_output_tokens = st.slider(
        label="max_output_tokens",
        min_value=8192,
        max_value=16_384,
        key=widget_key("max_output_tokens"),
    )

    should_escape_markdown = st.checkbox(
        label="should_escape_markdown",
        value=False,
        key=widget_key("should_escape_markdown"),
    )

    should_add_legis_urls = st.checkbox(
        label="should_add_legis_urls",
        value=True,
        key=widget_key("should_add_legis_urls"),
    )

    return {
        "provider": provider,
        "model_name": model_name,
        "temperature": temperature,
        "max_output_tokens": max_output_tokens,
        "should_escape_markdown": should_escape_markdown,
        "should_add_legis_urls": should_add_legis_urls,
    }


def get_retrieval_config(key_prefix: str) -> dict:
    """Render the retrieval widgets and collect their current values.

    Widgets are created in this order: a slider for the number of chunks
    (1 to 32, initial value 8), text inputs for a bill ID and a bioguide ID,
    and multiselects for congress numbers (initially the last two entries of
    ``CONGRESS_NUMBERS``) and sponsor parties (initially all of
    ``SPONSOR_PARTIES``). Each widget is registered under the Streamlit key
    ``"<key_prefix>|<setting name>"``.

    Args:
        key_prefix: Prefix for every widget key created here.

    Returns:
        A dict with the keys ``n_ret_docs``, ``filter_legis_id``,
        ``filter_bioguide_id``, ``filter_congress_nums`` and
        ``filter_sponsor_parties``, each holding its widget's value.
    """

    def widget_key(name: str) -> str:
        # Every widget key has the form "<key_prefix>|<setting name>".
        return f"{key_prefix}|{name}"

    n_ret_docs = st.slider(
        label="Number of chunks to retrieve",
        min_value=1,
        max_value=32,
        value=8,
        key=widget_key("n_ret_docs"),
    )

    legis_id = st.text_input(
        label="Bill ID (e.g. 118-s-2293)",
        key=widget_key("filter_legis_id"),
    )

    bioguide_id = st.text_input(
        label="Bioguide ID (e.g. R000595)",
        key=widget_key("filter_bioguide_id"),
    )

    congress_nums = st.multiselect(
        label="Congress Numbers",
        options=CONGRESS_NUMBERS,
        default=CONGRESS_NUMBERS[-2:],
        key=widget_key("filter_congress_nums"),
    )

    sponsor_parties = st.multiselect(
        label="Sponsor Party",
        options=SPONSOR_PARTIES,
        default=SPONSOR_PARTIES,
        key=widget_key("filter_sponsor_parties"),
    )

    return {
        "n_ret_docs": n_ret_docs,
        "filter_legis_id": legis_id,
        "filter_bioguide_id": bioguide_id,
        "filter_congress_nums": congress_nums,
        "filter_sponsor_parties": sponsor_parties,
    }