File size: 1,897 Bytes
d1033d4
 
 
 
 
b4e822f
d1033d4
 
 
f94169f
d1033d4
 
b4e822f
 
 
 
 
 
 
 
 
d1033d4
 
 
 
 
 
 
b4e822f
 
 
d1033d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
"""
UIコンポーネント

再利用可能なUIコンポーネントを提供する
"""
from typing import Optional, Set

import streamlit as st

from ...models.registry import ModelRegistry, DEFAULT_MODEL_KEY


# GPU必須モデル(CPUでは動作しない)
GPU_REQUIRED_MODELS: Set[str] = {
    "gpt-oss-20b",      # 21B - 16GB VRAM必要
    "olmo-7b",          # 7B - 14GB VRAM必要
    "mistral-7b",       # 7B - 14GB VRAM必要
    "llama-3.2-3b",     # 3B - 6GB VRAM必要
}


def render_model_selector() -> str:
    """
    モデル選択UIをレンダリング

    Returns:
        選択されたモデルのキー
    """
    # 利用可能なモデル一覧を取得(GPU必須モデルを除外)
    all_model_keys = ModelRegistry.list_models()
    model_keys = [k for k in all_model_keys if k not in GPU_REQUIRED_MODELS]
    configs = ModelRegistry.get_all_configs()

    # 表示名とキーのマッピング
    display_names = {key: configs[key].name for key in model_keys}

    # セッション状態の初期化
    if "selected_model" not in st.session_state:
        st.session_state.selected_model = DEFAULT_MODEL_KEY

    # モデル選択ボックス
    selected_name = st.selectbox(
        "MODEL",
        options=[display_names[key] for key in model_keys],
        index=model_keys.index(st.session_state.selected_model),
        key="model_selectbox",
        label_visibility="collapsed",
    )

    # 選択された表示名からキーを逆引き
    selected_key = next(
        key for key, name in display_names.items() if name == selected_name
    )
    st.session_state.selected_model = selected_key

    # モデル情報を表示
    config = configs[selected_key]
    st.markdown(
        f'<p class="model-info">{config.embedding_dim} dim / {config.vocab_size:,} tokens</p>',
        unsafe_allow_html=True,
    )

    return selected_key