|
|
""" |
|
|
UIコンポーネント |
|
|
|
|
|
再利用可能なUIコンポーネントを提供する |
|
|
""" |
|
|
from typing import Optional, Set |
|
|
|
|
|
import streamlit as st |
|
|
|
|
|
from ...models.registry import ModelRegistry, DEFAULT_MODEL_KEY |
|
|
|
|
|
|
|
|
|
|
|
# Registry keys of models treated as GPU-only. render_model_selector filters
# these out so they are never offered in the model dropdown.
GPU_REQUIRED_MODELS: Set[str] = set(
    (
        "gpt-oss-20b",
        "olmo-7b",
        "mistral-7b",
        "llama-3.2-3b",
    )
)
|
|
|
|
|
|
|
|
def render_model_selector() -> str:
    """Render the model selection UI and return the chosen model key.

    Only registry models whose keys are not listed in ``GPU_REQUIRED_MODELS``
    are offered. The choice is persisted in ``st.session_state.selected_model``
    across reruns. A short summary (embedding dimension and vocabulary size)
    of the chosen model is rendered below the selector.

    Returns:
        The registry key of the selected model.

    Raises:
        ValueError: If no selectable (non-GPU) models are registered.
    """
    all_model_keys = ModelRegistry.list_models()
    # GPU-only models are hidden from the selector.
    model_keys = [k for k in all_model_keys if k not in GPU_REQUIRED_MODELS]
    if not model_keys:
        # Previously this surfaced as an opaque ValueError from list.index.
        raise ValueError("No CPU-compatible models are registered")

    configs = ModelRegistry.get_all_configs()
    display_names = {key: configs[key].name for key in model_keys}

    if "selected_model" not in st.session_state:
        st.session_state.selected_model = DEFAULT_MODEL_KEY

    # The stored selection (or DEFAULT_MODEL_KEY itself) may not be among the
    # offered models, e.g. when it is GPU-only. In that case fall back to the
    # first option instead of letting list.index raise ValueError.
    current_key = st.session_state.selected_model
    default_index = model_keys.index(current_key) if current_key in model_keys else 0

    selected_name = st.selectbox(
        "MODEL",
        options=[display_names[key] for key in model_keys],
        index=default_index,
        key="model_selectbox",
        label_visibility="collapsed",
    )

    # Map the chosen display name back to its registry key. If two models
    # share a display name, the first one in model_keys order wins. The
    # default guards against StopIteration if no name matches.
    selected_key = next(
        (key for key, name in display_names.items() if name == selected_name),
        model_keys[default_index],
    )
    st.session_state.selected_model = selected_key

    config = configs[selected_key]
    st.markdown(
        f'<p class="model-info">{config.embedding_dim} dim / {config.vocab_size:,} tokens</p>',
        unsafe_allow_html=True,
    )

    return selected_key
|
|
|