will / src /ui /streamlit /pages /generate.py
matt1847's picture
修正: 初期画面の空白表示とGPU必須モデルの除外
b4e822f
"""
生成ページ
デブリ生成のメインUIを提供する
"""
import streamlit as st
from ....models.registry import ModelRegistry
from ....generators.debris_generator import DebrisGenerator
from ....visualizers.signal_visualizer import SignalVisualizer
from ..components import render_model_selector
# モデルキャッシュ用のキー
_MODEL_CACHE_KEY = "_cached_model"
_GENERATOR_CACHE_KEY = "_cached_generator"
@st.cache_resource(show_spinner=False)
def _get_model(model_key: str):
"""モデルをキャッシュして取得"""
model = ModelRegistry.get(model_key)
model.load()
return model
def render_generate_page() -> None:
"""生成ページをレンダリング"""
# タイトル
st.markdown('<p class="title">WILL</p>', unsafe_allow_html=True)
st.markdown(
'<p class="subtitle">PURE COMPUTATIONAL WILL</p>', unsafe_allow_html=True
)
# モデル選択UI
col1, col2, col3 = st.columns([1, 2, 1])
with col2:
selected_model_key = render_model_selector()
# セッション状態の初期化
if "debris" not in st.session_state:
st.session_state.debris = None
st.session_state.seed = None
st.session_state.signal_img = None
st.session_state.has_result = False
# LISTENボタン
col1, col2, col3 = st.columns([1, 1, 1])
with col2:
clicked = st.button("LISTEN", key="listen_btn", use_container_width=True)
if clicked:
with st.spinner("Generating..."):
# モデルとジェネレータの取得
model = _get_model(selected_model_key)
generator = DebrisGenerator(model)
visualizer = SignalVisualizer()
# デブリ生成
result = generator.generate()
# 結果をセッション状態に保存
st.session_state.debris = result.debris
st.session_state.seed = result.seed
st.session_state.signal_img = visualizer.generate_image(
result.noise, result.corrupted_logits
)
st.session_state.has_result = True
# 結果の表示(結果がある場合のみ)
if st.session_state.get("has_result", False) and st.session_state.debris:
st.markdown(
f'''
<div class="debris-container">
<img class="signal-img" src="data:image/png;base64,{st.session_state.signal_img}">
<div class="debris">{" ".join(st.session_state.debris)}</div>
</div>
<p class="seed">{st.session_state.seed}</p>
''',
unsafe_allow_html=True,
)