Refactor: migrate to an src directory structure and update Docker support

- Split app.py into modules (src/models, src/generators, src/visualizers, src/ui)
- Switch the Dockerfile to requirements.txt and COPY the src/ directory
- Add an src/ volume mount to docker-compose.yml
- Add pytest to the dependencies
- Add test files (TDD)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Dockerfile +9 -7
- app.py +23 -317
- docker-compose.yml +1 -0
- requirements.txt +1 -0
- src/__init__.py +1 -0
- src/generators/__init__.py +4 -0
- src/generators/debris_generator.py +104 -0
- src/models/__init__.py +5 -0
- src/models/base.py +110 -0
- src/models/gpt2.py +87 -0
- src/models/gpt_neo.py +70 -0
- src/models/opt.py +70 -0
- src/models/registry.py +88 -0
- src/ui/__init__.py +5 -0
- src/ui/components.py +53 -0
- src/ui/pages/__init__.py +5 -0
- src/ui/pages/concept.py +151 -0
- src/ui/pages/generate.py +78 -0
- src/ui/styles.py +199 -0
- src/visualizers/__init__.py +4 -0
- src/visualizers/signal_visualizer.py +138 -0
- tests/__init__.py +1 -0
- tests/test_generators.py +76 -0
- tests/test_models.py +124 -0
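
The two test files listed above do not appear in the rendered diff below. As a rough illustration of the TDD setup, here is a minimal sketch of what a test in tests/test_generators.py could look like; the fake model and the assertions are hypothetical, not the committed tests:

# Hypothetical sketch of a TDD-style test; the committed tests/test_generators.py
# is not shown in this diff, so the names and assertions here are illustrative.
import torch

from src.generators.debris_generator import DebrisGenerator
from src.models.base import BaseLanguageModel, ModelConfig


class FakeModel(BaseLanguageModel):
    """In-memory stand-in so the test never downloads weights."""

    def load(self) -> None:
        self._is_loaded = True

    def forward_with_noise(self, noise):
        # Shape-compatible random logits are enough to exercise the pipeline
        logits = torch.randn(noise.shape[0], noise.shape[1], self._config.vocab_size)
        return logits, logits

    def decode_indices(self, indices):
        return [f"tok{i}" for i in indices]


def test_generate_returns_seq_len_tokens():
    config = ModelConfig(name="fake", model_id="fake", embedding_dim=8, vocab_size=16)
    generator = DebrisGenerator(FakeModel(config))
    result = generator.generate(seed=123, seq_len=32)
    assert len(result.debris) == 32
    assert result.seed == 123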
Dockerfile
CHANGED

@@ -2,16 +2,18 @@ FROM python:3.11-slim
 
 WORKDIR /app
 
-RUN pip install --no-cache-dir \
-    torch \
-    transformers \
-    streamlit \
-    matplotlib \
-    numpy
+# Install dependencies
+COPY requirements.txt .
+RUN pip install --no-cache-dir -r requirements.txt
 
+# Pre-download the model (cached at build time)
+RUN python -c "from transformers import GPT2LMHeadModel, GPT2Tokenizer; \
+    GPT2LMHeadModel.from_pretrained('gpt2'); \
+    GPT2Tokenizer.from_pretrained('gpt2')"
 
+# Copy the application code
 COPY app.py .
+COPY src/ ./src/
 
 EXPOSE 8501
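
Note that only the default gpt2 weights are baked into the image here; the other registered models (gpt2-medium, GPT-Neo 125M, OPT-125M) download on first selection at runtime. If those should be cached at build time as well, one option is an additional RUN step executing something like the following sketch (not part of this commit; it would grow the image considerably):

# Hypothetical build-time preload for the remaining registry models.
# Would run inside the image, e.g. via another "RUN python -c ..." step.
from transformers import AutoModelForCausalLM, AutoTokenizer

for model_id in ("gpt2-medium", "EleutherAI/gpt-neo-125M", "facebook/opt-125m"):
    AutoModelForCausalLM.from_pretrained(model_id)   # caches the weights
    AutoTokenizer.from_pretrained(model_id)          # caches the tokenizer files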
app.py
CHANGED

@@ -1,327 +1,33 @@
-import time
-import torch
-from transformers import GPT2LMHeadModel, GPT2Tokenizer
-import streamlit as st
-import matplotlib.pyplot as plt
-import numpy as np
-import io
-import base64
-
-st.set_page_config(page_title="will", page_icon="", layout="centered")
-
-st.markdown("""
-<style>
-... (custom CSS block, moved verbatim into CUSTOM_CSS in src/ui/styles.py below)
-</style>
-""", unsafe_allow_html=True)
-
-tab1, tab2 = st.tabs(["GENERATE", "CONCEPT"])
-
-with tab1:
-    st.markdown('<p class="title">WILL</p>', unsafe_allow_html=True)
-    st.markdown('<p class="subtitle">PURE COMPUTATIONAL WILL</p>', unsafe_allow_html=True)
-
-    @st.cache_resource(show_spinner=False)
-    def load_model():
-        model = GPT2LMHeadModel.from_pretrained("gpt2")
-        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
-        model.eval()
-        return model, tokenizer
-
-    model, tokenizer = load_model()
-
-    if "debris" not in st.session_state:
-        st.session_state.debris = None
-        st.session_state.seed = None
-        st.session_state.signal_img = None
-
-    def generate_signal_image(noise, logits):
-        ... (figure setup lines lost in extraction; the refactored equivalent is SignalVisualizer.generate_image below)
-        axes[0].imshow(noise_flat.T, aspect='auto', cmap='gray', interpolation='bilinear', vmin=-2, vmax=2)
-        axes[0].set_xticks([])
-        axes[0].set_yticks([])
-        axes[0].set_facecolor('#0f0f0f')
-        for spine in axes[0].spines.values():
-            spine.set_visible(False)
-
-        logits_sample = logits[0, :, ::200].numpy()
-        axes[1].imshow(logits_sample.T, aspect='auto', cmap='gray', interpolation='bilinear')
-        axes[1].set_xticks([])
-        axes[1].set_yticks([])
-        axes[1].set_facecolor('#0f0f0f')
-        for spine in axes[1].spines.values():
-            spine.set_visible(False)
-
-        buf = io.BytesIO()
-        plt.savefig(buf, format='png', facecolor='#0f0f0f', edgecolor='none', dpi=150, bbox_inches='tight', pad_inches=0.05)
-        plt.close(fig)
-        buf.seek(0)
-        return base64.b64encode(buf.read()).decode()
-
-    col1, col2, col3 = st.columns([1, 1, 1])
-    with col2:
-        clicked = st.button("LISTEN", key="listen_btn", use_container_width=True)
-    if clicked:
-        seed = time.time_ns()
-        torch.manual_seed(seed)
-        noise = torch.randn(1, 32, 768)
-
-        with torch.no_grad():
-            outputs = model(inputs_embeds=noise)
-            logits = outputs.logits
-        logits_noise = torch.randn_like(logits) * logits.std() * 10
-        corrupted_logits = logits + logits_noise
-
-        indices = corrupted_logits.argmax(dim=-1).squeeze().tolist()
-        st.session_state.debris = [tokenizer.decode([i]) for i in indices]
-        st.session_state.seed = seed
-        st.session_state.signal_img = generate_signal_image(noise, corrupted_logits)
-
-    if st.session_state.debris:
-        st.markdown(f'''
-        <div class="debris-container">
-            <img class="signal-img" src="data:image/png;base64,{st.session_state.signal_img}">
-            <div class="debris">{" ".join(st.session_state.debris)}</div>
-        </div>
-        <p class="seed">{st.session_state.seed}</p>
-        ''', unsafe_allow_html=True)
-
-with tab2:
-    st.markdown('<p class="title">CONCEPT</p>', unsafe_allow_html=True)
-    st.markdown('<p class="subtitle">DOCUMENTATION</p>', unsafe_allow_html=True)
-
-    ... (CONCEPT, PROCESS 01-04, and SPECIFICATION sections, moved with their text into src/ui/pages/concept.py below)
+"""
+WILL - Pure Computational Will
+
+Feed random noise into a language model and observe
+what the model's structure alone emits,
+without any human prompt.
+"""
+import streamlit as st
+
+from src.ui.styles import CUSTOM_CSS
+from src.ui.pages import render_generate_page, render_concept_page
+
+
+def main():
+    """Application entry point."""
+    # Page configuration
+    st.set_page_config(page_title="will", page_icon="", layout="centered")
+
+    # Apply the custom CSS
+    st.markdown(CUSTOM_CSS, unsafe_allow_html=True)
+
+    # Tab layout
+    tab1, tab2 = st.tabs(["GENERATE", "CONCEPT"])
+
+    with tab1:
+        render_generate_page()
+
+    with tab2:
+        render_concept_page()
+
+
+if __name__ == "__main__":
+    main()
docker-compose.yml
CHANGED

@@ -5,3 +5,4 @@ services:
       - "8501:8501"
     volumes:
       - ./app.py:/app/app.py
+      - ./src:/app/src
requirements.txt
CHANGED

@@ -3,3 +3,4 @@ transformers
 streamlit
 matplotlib
 numpy
+pytest
src/__init__.py
ADDED

"""WILL - Pure Computational Will"""
src/generators/__init__.py
ADDED

"""Debris generation logic for WILL."""
from .debris_generator import DebrisGenerator, DebrisResult

__all__ = ["DebrisGenerator", "DebrisResult"]
src/generators/debris_generator.py
ADDED

"""
Debris generator.

Feeds noise into a language model to generate debris (language fragments).
Follows the single-responsibility principle (SRP): generation logic only.
"""
import time
from dataclasses import dataclass
from typing import List, Optional

import torch

from ..models.base import BaseLanguageModel


@dataclass
class DebrisResult:
    """
    Immutable data class holding a debris generation result.

    Attributes:
        debris: List of generated token strings
        seed: Random seed used
        noise: Input noise tensor
        logits: Raw logits tensor
        corrupted_logits: Logits tensor after noise injection
    """
    debris: List[str]
    seed: int
    noise: torch.Tensor
    logits: torch.Tensor
    corrupted_logits: torch.Tensor


class DebrisGenerator:
    """
    Debris generator.

    Uses a language model to generate language fragments (debris)
    from random noise.

    Follows the dependency-inversion principle (DIP): depends on the
    BaseLanguageModel abstraction rather than on concrete classes.
    """

    # Default sequence length
    DEFAULT_SEQ_LEN = 32

    def __init__(self, model: BaseLanguageModel):
        """
        Args:
            model: Language model to use (implements BaseLanguageModel)
        """
        self._model = model

    @property
    def model(self) -> BaseLanguageModel:
        """Return the model in use."""
        return self._model

    def generate(
        self,
        seed: Optional[int] = None,
        seq_len: int = DEFAULT_SEQ_LEN,
    ) -> DebrisResult:
        """
        Generate debris.

        Args:
            seed: Random seed (nanosecond timestamp if None)
            seq_len: Sequence length to generate

        Returns:
            DebrisResult: Generation result

        Raises:
            RuntimeError: If the model is not loaded
        """
        # Set the seed
        if seed is None:
            seed = time.time_ns()
        torch.manual_seed(seed)

        # Load the model if it has not been loaded yet
        if not self._model.is_loaded:
            self._model.load()

        # Generate noise and run the forward pass
        noise = self._model.generate_noise(seq_len=seq_len)
        logits, corrupted_logits = self._model.forward_with_noise(noise)

        # Extract indices via argmax
        indices = corrupted_logits.argmax(dim=-1).squeeze().tolist()

        # Decode the indices into token strings
        debris = self._model.decode_indices(indices)

        return DebrisResult(
            debris=debris,
            seed=seed,
            noise=noise,
            logits=logits,
            corrupted_logits=corrupted_logits,
        )
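
A minimal usage sketch of the generator (assuming the src package is importable; the seed value is illustrative, and a fixed seed makes a run reproducible since all randomness flows through torch.manual_seed):

# Hypothetical usage sketch, not part of this commit.
from src.generators.debris_generator import DebrisGenerator
from src.models.registry import ModelRegistry

model = ModelRegistry.get("gpt2")        # unloaded instance behind the ABC
generator = DebrisGenerator(model)       # depends only on BaseLanguageModel
result = generator.generate(seed=42)     # generate() loads the model on demand
print(result.seed, " ".join(result.debris))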
src/models/__init__.py
ADDED

"""Model implementations for WILL."""
from .base import BaseLanguageModel, ModelConfig
from .registry import ModelRegistry

__all__ = ["BaseLanguageModel", "ModelConfig", "ModelRegistry"]
src/models/base.py
ADDED

"""
Abstract base class: the common interface for all language models.

Conforms to the Liskov substitution principle (LSP) so that any model
implementation is interchangeable behind the same interface.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Tuple

import torch


@dataclass(frozen=True)
class ModelConfig:
    """
    Immutable data class holding model settings.

    Attributes:
        name: Display name for the UI
        model_id: HuggingFace model ID
        embedding_dim: Embedding dimensionality
        vocab_size: Vocabulary size
    """
    name: str
    model_id: str
    embedding_dim: int
    vocab_size: int


class BaseLanguageModel(ABC):
    """
    Abstract base class for language models.

    Every model implementation must inherit from this class and
    implement the interface defined here.
    """

    def __init__(self, config: ModelConfig):
        """
        Args:
            config: Model settings
        """
        self._config = config
        self._model = None
        self._tokenizer = None
        self._is_loaded = False

    @property
    def config(self) -> ModelConfig:
        """Return the model settings."""
        return self._config

    @property
    def is_loaded(self) -> bool:
        """Whether the model has been loaded."""
        return self._is_loaded

    @abstractmethod
    def load(self) -> None:
        """
        Load the model and tokenizer.

        Raises:
            RuntimeError: If loading the model fails
        """
        pass

    @abstractmethod
    def forward_with_noise(
        self, noise: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run a forward pass with noise as input.

        Args:
            noise: Input noise tensor [batch, seq_len, embedding_dim]

        Returns:
            Tuple[logits, corrupted_logits]:
                - logits: Raw logits
                - corrupted_logits: Logits after noise injection
        """
        pass

    @abstractmethod
    def decode_indices(self, indices: List[int]) -> List[str]:
        """
        Decode token indices into a list of strings.

        Args:
            indices: List of token indices

        Returns:
            List of decoded strings
        """
        pass

    def generate_noise(self, seq_len: int = 32, batch_size: int = 1) -> torch.Tensor:
        """
        Generate random input noise.

        Args:
            seq_len: Sequence length
            batch_size: Batch size

        Returns:
            Noise tensor [batch_size, seq_len, embedding_dim]
        """
        return torch.randn(batch_size, seq_len, self._config.embedding_dim)
src/models/gpt2.py
ADDED

"""
GPT-2 model implementation.

Provides GPT-2 Small and Medium.
"""
from typing import List, Tuple

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

from .base import BaseLanguageModel, ModelConfig


# GPT-2 Small settings
GPT2_SMALL_CONFIG = ModelConfig(
    name="GPT-2 Small",
    model_id="gpt2",
    embedding_dim=768,
    vocab_size=50257,
)

# GPT-2 Medium settings
GPT2_MEDIUM_CONFIG = ModelConfig(
    name="GPT-2 Medium",
    model_id="gpt2-medium",
    embedding_dim=1024,
    vocab_size=50257,
)


class GPT2Model(BaseLanguageModel):
    """
    GPT-2 model implementation.

    Wraps HuggingFace Transformers GPT-2 and implements the
    BaseLanguageModel interface.
    """

    # Output-noise multiplier (used to destroy the learned bias)
    LOGITS_NOISE_SCALE = 10.0

    def load(self) -> None:
        """Load the model and tokenizer."""
        if self._is_loaded:
            return

        try:
            self._model = GPT2LMHeadModel.from_pretrained(self._config.model_id)
            self._tokenizer = GPT2Tokenizer.from_pretrained(self._config.model_id)
            self._model.eval()
            self._is_loaded = True
        except Exception as e:
            raise RuntimeError(f"Failed to load model {self._config.model_id}: {e}")

    def forward_with_noise(
        self, noise: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run a forward pass with noise as input, then add noise to the output.

        Args:
            noise: Input noise tensor [batch, seq_len, embedding_dim]

        Returns:
            Tuple[logits, corrupted_logits]
        """
        if not self._is_loaded:
            raise RuntimeError("Model not loaded. Call load() first.")

        with torch.no_grad():
            outputs = self._model(inputs_embeds=noise)
            logits = outputs.logits

        # Add noise to the output logits to destroy the learned bias
        logits_noise = (
            torch.randn_like(logits) * logits.std() * self.LOGITS_NOISE_SCALE
        )
        corrupted_logits = logits + logits_noise

        return logits, corrupted_logits

    def decode_indices(self, indices: List[int]) -> List[str]:
        """Decode token indices."""
        if not self._is_loaded:
            raise RuntimeError("Model not loaded. Call load() first.")

        return [self._tokenizer.decode([i]) for i in indices]
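
For context, the inputs_embeds path used above bypasses the token-embedding lookup entirely and feeds raw vectors into the transformer stack. A standalone sketch of the same trick against plain transformers, with the shapes from GPT2_SMALL_CONFIG (downloads the real gpt2 weights when run; not part of this commit):

# Standalone sketch of the inputs_embeds trick used by forward_with_noise.
import torch
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2").eval()
noise = torch.randn(1, 32, 768)            # [batch, seq_len, embedding_dim]
with torch.no_grad():
    logits = model(inputs_embeds=noise).logits
print(logits.shape)                        # torch.Size([1, 32, 50257])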
src/models/gpt_neo.py
ADDED

"""
GPT-Neo model implementation.

Provides EleutherAI GPT-Neo 125M.
"""
from typing import List, Tuple

import torch
from transformers import GPTNeoForCausalLM, GPT2Tokenizer

from .base import BaseLanguageModel, ModelConfig


# GPT-Neo 125M settings
GPT_NEO_125M_CONFIG = ModelConfig(
    name="GPT-Neo 125M",
    model_id="EleutherAI/gpt-neo-125M",
    embedding_dim=768,
    vocab_size=50257,
)


class GPTNeoModel(BaseLanguageModel):
    """
    GPT-Neo model implementation.

    Wraps EleutherAI GPT-Neo and implements the BaseLanguageModel interface.
    """

    # Output-noise multiplier
    LOGITS_NOISE_SCALE = 10.0

    def load(self) -> None:
        """Load the model and tokenizer."""
        if self._is_loaded:
            return

        try:
            self._model = GPTNeoForCausalLM.from_pretrained(self._config.model_id)
            # GPT-Neo uses a GPT-2-compatible tokenizer
            self._tokenizer = GPT2Tokenizer.from_pretrained(self._config.model_id)
            self._model.eval()
            self._is_loaded = True
        except Exception as e:
            raise RuntimeError(f"Failed to load model {self._config.model_id}: {e}")

    def forward_with_noise(
        self, noise: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run a forward pass with noise as input."""
        if not self._is_loaded:
            raise RuntimeError("Model not loaded. Call load() first.")

        with torch.no_grad():
            outputs = self._model(inputs_embeds=noise)
            logits = outputs.logits

        logits_noise = (
            torch.randn_like(logits) * logits.std() * self.LOGITS_NOISE_SCALE
        )
        corrupted_logits = logits + logits_noise

        return logits, corrupted_logits

    def decode_indices(self, indices: List[int]) -> List[str]:
        """Decode token indices."""
        if not self._is_loaded:
            raise RuntimeError("Model not loaded. Call load() first.")

        return [self._tokenizer.decode([i]) for i in indices]
src/models/opt.py
ADDED

"""
OPT model implementation.

Provides Meta OPT-125M.
"""
from typing import List, Tuple

import torch
from transformers import OPTForCausalLM, GPT2Tokenizer

from .base import BaseLanguageModel, ModelConfig


# OPT-125M settings
OPT_125M_CONFIG = ModelConfig(
    name="OPT-125M",
    model_id="facebook/opt-125m",
    embedding_dim=768,
    vocab_size=50272,
)


class OPTModel(BaseLanguageModel):
    """
    OPT model implementation.

    Wraps Meta OPT and implements the BaseLanguageModel interface.
    """

    # Output-noise multiplier
    LOGITS_NOISE_SCALE = 10.0

    def load(self) -> None:
        """Load the model and tokenizer."""
        if self._is_loaded:
            return

        try:
            self._model = OPTForCausalLM.from_pretrained(self._config.model_id)
            # OPT ships its own tokenizer, but a GPT-2 tokenizer is also usable
            self._tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
            self._model.eval()
            self._is_loaded = True
        except Exception as e:
            raise RuntimeError(f"Failed to load model {self._config.model_id}: {e}")

    def forward_with_noise(
        self, noise: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run a forward pass with noise as input."""
        if not self._is_loaded:
            raise RuntimeError("Model not loaded. Call load() first.")

        with torch.no_grad():
            outputs = self._model(inputs_embeds=noise)
            logits = outputs.logits

        logits_noise = (
            torch.randn_like(logits) * logits.std() * self.LOGITS_NOISE_SCALE
        )
        corrupted_logits = logits + logits_noise

        return logits, corrupted_logits

    def decode_indices(self, indices: List[int]) -> List[str]:
        """Decode token indices."""
        if not self._is_loaded:
            raise RuntimeError("Model not loaded. Call load() first.")

        return [self._tokenizer.decode([i]) for i in indices]
src/models/registry.py
ADDED

"""
Model registry.

Conforms to the open-closed principle (OCP): new models can be added
without modifying existing code.
"""
from typing import Dict, List, Optional, Type

from .base import BaseLanguageModel, ModelConfig
from .gpt2 import GPT2Model, GPT2_SMALL_CONFIG, GPT2_MEDIUM_CONFIG
from .gpt_neo import GPTNeoModel, GPT_NEO_125M_CONFIG
from .opt import OPTModel, OPT_125M_CONFIG


class ModelRegistry:
    """
    Model registry.

    Manages the available models and provides the right model
    instance for a given key.
    """

    _registry: Dict[str, tuple[Type[BaseLanguageModel], ModelConfig]] = {}

    @classmethod
    def register(
        cls,
        key: str,
        model_class: Type[BaseLanguageModel],
        config: ModelConfig,
    ) -> None:
        """
        Register a new model in the registry.

        Args:
            key: Key identifying the model
            model_class: Model class
            config: Model settings
        """
        cls._registry[key] = (model_class, config)

    @classmethod
    def get(cls, key: str) -> BaseLanguageModel:
        """
        Return a model instance for the given key.

        Args:
            key: Key identifying the model

        Returns:
            Model instance

        Raises:
            KeyError: If the given key is not registered
        """
        if key not in cls._registry:
            available = ", ".join(cls._registry.keys())
            raise KeyError(f"Model '{key}' not found. Available: {available}")

        model_class, config = cls._registry[key]
        return model_class(config)

    @classmethod
    def list_models(cls) -> List[str]:
        """Return the keys of all registered models."""
        return list(cls._registry.keys())

    @classmethod
    def get_config(cls, key: str) -> Optional[ModelConfig]:
        """Return the settings for the given model key."""
        if key not in cls._registry:
            return None
        return cls._registry[key][1]

    @classmethod
    def get_all_configs(cls) -> Dict[str, ModelConfig]:
        """Return all model settings."""
        return {key: config for key, (_, config) in cls._registry.items()}


# Register the default models
ModelRegistry.register("gpt2", GPT2Model, GPT2_SMALL_CONFIG)
ModelRegistry.register("gpt2-medium", GPT2Model, GPT2_MEDIUM_CONFIG)
ModelRegistry.register("gpt-neo-125m", GPTNeoModel, GPT_NEO_125M_CONFIG)
ModelRegistry.register("opt-125m", OPTModel, OPT_125M_CONFIG)

# Default model key
DEFAULT_MODEL_KEY = "gpt2"
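
To illustrate the OCP claim in the module docstring: a new model can be wired in from outside this file. A minimal sketch (the distilgpt2 entry here is hypothetical, not part of this commit; distilgpt2 shares GPT-2's architecture, so GPT2Model can host its config):

# Hypothetical sketch: adding a model without touching registry.py.
from src.models.base import ModelConfig
from src.models.gpt2 import GPT2Model
from src.models.registry import ModelRegistry

DISTILGPT2_CONFIG = ModelConfig(
    name="DistilGPT-2",
    model_id="distilgpt2",      # same architecture family as GPT-2
    embedding_dim=768,
    vocab_size=50257,
)

ModelRegistry.register("distilgpt2", GPT2Model, DISTILGPT2_CONFIG)
assert "distilgpt2" in ModelRegistry.list_models()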
src/ui/__init__.py
ADDED

"""UI components for WILL."""
from .styles import CUSTOM_CSS
from .components import render_model_selector

__all__ = ["CUSTOM_CSS", "render_model_selector"]
src/ui/components.py
ADDED

"""
UI components.

Provides reusable UI components.
"""
from typing import Optional

import streamlit as st

from ..models.registry import ModelRegistry, DEFAULT_MODEL_KEY


def render_model_selector() -> str:
    """
    Render the model selection UI.

    Returns:
        Key of the selected model
    """
    # Fetch the list of available models
    model_keys = ModelRegistry.list_models()
    configs = ModelRegistry.get_all_configs()

    # Map keys to display names
    display_names = {key: configs[key].name for key in model_keys}

    # Initialize session state
    if "selected_model" not in st.session_state:
        st.session_state.selected_model = DEFAULT_MODEL_KEY

    # Model select box
    selected_name = st.selectbox(
        "MODEL",
        options=[display_names[key] for key in model_keys],
        index=model_keys.index(st.session_state.selected_model),
        key="model_selectbox",
        label_visibility="collapsed",
    )

    # Reverse-look-up the key from the selected display name
    selected_key = next(
        key for key, name in display_names.items() if name == selected_name
    )
    st.session_state.selected_model = selected_key

    # Show the model information
    config = configs[selected_key]
    st.markdown(
        f'<p class="model-info">{config.embedding_dim} dim / {config.vocab_size:,} tokens</p>',
        unsafe_allow_html=True,
    )

    return selected_key
src/ui/pages/__init__.py
ADDED

"""UI pages for WILL."""
from .generate import render_generate_page
from .concept import render_concept_page

__all__ = ["render_generate_page", "render_concept_page"]
src/ui/pages/concept.py
ADDED

"""
Concept page.

Provides the conceptual documentation for the WILL project.
"""
import streamlit as st

from ...models.registry import ModelRegistry


def render_concept_page() -> None:
    """Render the concept page."""
    st.markdown('<p class="title">CONCEPT</p>', unsafe_allow_html=True)
    st.markdown('<p class="subtitle">DOCUMENTATION</p>', unsafe_allow_html=True)

    _render_concept_section()
    _render_process_section()
    _render_specification_section()


def _render_concept_section() -> None:
    """Concept description section."""
    st.markdown(
        '''
        <div class="section">
        <p class="section-title">CONCEPT</p>
        <p style="text-align: center; color: #666; line-height: 2.2;">
        GPT-2 was trained on human-written text<br>
        and holds language patterns in its weights<br><br>
        It normally generates responses to prompts,<br>
        but replacing the input with random noise<br>
        and adding noise to the output as well<br>
        destroys the learned statistical bias<br><br>
        Without any human prompt,<br>
        we observe what the model's structure alone emits
        </p>
        </div>
        <hr class="divider">
        ''',
        unsafe_allow_html=True,
    )


def _render_process_section() -> None:
    """Process description section."""
    st.markdown(
        '''
        <div class="section">
        <p class="section-title">PROCESS</p>
        </div>
        ''',
        unsafe_allow_html=True,
    )

    # Step 1: ENTROPY SEED
    st.markdown(
        '<p style="text-align: center; color: #333; font-size: 0.65rem; '
        'letter-spacing: 0.15em; margin-bottom: 0.5rem;">01 — ENTROPY SEED</p>',
        unsafe_allow_html=True,
    )
    st.code("seed = time.time_ns()\ntorch.manual_seed(seed)", language="python")
    st.markdown(
        '<p style="text-align: center; font-size: 0.7rem; color: #444;">'
        "The nanosecond timestamp at the moment of execution is harvested as the random seed</p>",
        unsafe_allow_html=True,
    )

    st.markdown("<br>", unsafe_allow_html=True)

    # Step 2: INPUT NOISE
    st.markdown(
        '<p style="text-align: center; color: #333; font-size: 0.65rem; '
        'letter-spacing: 0.15em; margin-bottom: 0.5rem;">02 — INPUT NOISE</p>',
        unsafe_allow_html=True,
    )
    st.code(
        "noise = torch.randn(1, 32, 768)\noutputs = model(inputs_embeds=noise)",
        language="python",
    )
    st.markdown(
        '<p style="text-align: center; font-size: 0.7rem; color: #444;">'
        "768-dimensional random noise is injected directly into the embedding layer</p>",
        unsafe_allow_html=True,
    )

    st.markdown("<br>", unsafe_allow_html=True)

    # Step 3: OUTPUT NOISE
    st.markdown(
        '<p style="text-align: center; color: #333; font-size: 0.65rem; '
        'letter-spacing: 0.15em; margin-bottom: 0.5rem;">03 — OUTPUT NOISE</p>',
        unsafe_allow_html=True,
    )
    st.code(
        "logits_noise = torch.randn_like(logits) * logits.std() * 10\n"
        "corrupted_logits = logits + logits_noise",
        language="python",
    )
    st.markdown(
        '<p style="text-align: center; font-size: 0.7rem; color: #444;">'
        "Noise is added to the output logits to destroy the learned bias</p>",
        unsafe_allow_html=True,
    )

    st.markdown("<br>", unsafe_allow_html=True)

    # Step 4: RAW DECODE
    st.markdown(
        '<p style="text-align: center; color: #333; font-size: 0.65rem; '
        'letter-spacing: 0.15em; margin-bottom: 0.5rem;">04 — RAW DECODE</p>',
        unsafe_allow_html=True,
    )
    st.code(
        "indices = corrupted_logits.argmax(dim=-1)\n"
        "debris = [tokenizer.decode([i]) for i in indices]",
        language="python",
    )
    st.markdown(
        '<p style="text-align: center; font-size: 0.7rem; color: #444;">'
        "Raw tokens are extracted with no softmax and no temperature</p>",
        unsafe_allow_html=True,
    )


def _render_specification_section() -> None:
    """Specification section."""
    # Fetch the available models and list them dynamically
    configs = ModelRegistry.get_all_configs()
    model_list = "<br>".join(
        [f"{cfg.name} ({cfg.embedding_dim} dim)" for cfg in configs.values()]
    )

    st.markdown(
        f'''
        <hr class="divider">
        <div class="section">
        <p class="section-title">SPECIFICATION</p>
        <table class="spec-table">
            <tr><td>Models</td><td>{model_list}</td></tr>
            <tr><td>Parameters</td><td>125M - 350M</td></tr>
            <tr><td>Embedding</td><td>768 - 1024 dim</td></tr>
            <tr><td>Vocabulary</td><td>50,257+ tokens</td></tr>
            <tr><td>Sequence</td><td>32 tokens</td></tr>
            <tr><td>Input Noise</td><td>N(0, 1)</td></tr>
            <tr><td>Logits Noise</td><td>N(0, σ×10)</td></tr>
            <tr><td>Decoding</td><td>argmax</td></tr>
        </table>
        </div>
        ''',
        unsafe_allow_html=True,
    )
src/ui/pages/generate.py
ADDED

"""
Generate page.

Provides the main debris-generation UI.
"""
import streamlit as st

from ...models.registry import ModelRegistry
from ...generators.debris_generator import DebrisGenerator
from ...visualizers.signal_visualizer import SignalVisualizer
from ..components import render_model_selector


# Keys for model caching
_MODEL_CACHE_KEY = "_cached_model"
_GENERATOR_CACHE_KEY = "_cached_generator"


@st.cache_resource(show_spinner=False)
def _get_model(model_key: str):
    """Fetch the model, cached per key."""
    model = ModelRegistry.get(model_key)
    model.load()
    return model


def render_generate_page() -> None:
    """Render the generate page."""
    # Title
    st.markdown('<p class="title">WILL</p>', unsafe_allow_html=True)
    st.markdown(
        '<p class="subtitle">PURE COMPUTATIONAL WILL</p>', unsafe_allow_html=True
    )

    # Model selection UI
    col1, col2, col3 = st.columns([1, 2, 1])
    with col2:
        selected_model_key = render_model_selector()

    # Initialize session state
    if "debris" not in st.session_state:
        st.session_state.debris = None
        st.session_state.seed = None
        st.session_state.signal_img = None

    # LISTEN button
    col1, col2, col3 = st.columns([1, 1, 1])
    with col2:
        clicked = st.button("LISTEN", key="listen_btn", use_container_width=True)

    if clicked:
        # Fetch the model and build the generator
        model = _get_model(selected_model_key)
        generator = DebrisGenerator(model)
        visualizer = SignalVisualizer()

        # Generate debris
        result = generator.generate()

        # Store the result in session state
        st.session_state.debris = result.debris
        st.session_state.seed = result.seed
        st.session_state.signal_img = visualizer.generate_image(
            result.noise, result.corrupted_logits
        )

    # Display the result
    if st.session_state.debris:
        st.markdown(
            f'''
            <div class="debris-container">
                <img class="signal-img" src="data:image/png;base64,{st.session_state.signal_img}">
                <div class="debris">{" ".join(st.session_state.debris)}</div>
            </div>
            <p class="seed">{st.session_state.seed}</p>
            ''',
            unsafe_allow_html=True,
        )
src/ui/styles.py
ADDED

"""
Custom CSS definitions.

Defines the styles for the Streamlit application.
"""

CUSTOM_CSS = """
<style>
@import url('https://fonts.googleapis.com/css2?family=IBM+Plex+Mono:wght@300;400&display=swap');

@keyframes emerge {
    from { opacity: 0; transform: translateY(8px); }
    to { opacity: 1; transform: translateY(0); }
}
@keyframes breathe {
    0%, 100% { opacity: 0.4; }
    50% { opacity: 0.7; }
}

html, body, [class*="css"] {
    font-family: 'IBM Plex Mono', monospace;
}
.stApp {
    background-color: #0a0a0a;
    color: #e0e0e0;
}
.block-container {
    padding-top: 4rem;
    padding-bottom: 4rem;
    max-width: 640px;
}
h1, h2, h3 {
    font-weight: 300;
    letter-spacing: 0.1em;
    text-align: center;
    color: #e0e0e0;
}
p, li {
    font-weight: 300;
    line-height: 1.8;
    color: #888;
}
.title {
    font-size: 2rem;
    font-weight: 300;
    letter-spacing: 0.3em;
    text-align: center;
    margin-bottom: 0.5rem;
    color: #e0e0e0;
}
.subtitle {
    font-size: 0.75rem;
    letter-spacing: 0.2em;
    text-align: center;
    color: #555;
    margin-bottom: 3rem;
}
.debris-container {
    background: linear-gradient(135deg, #0f0f0f 0%, #141414 100%);
    border: 1px solid #222;
    border-radius: 2px;
    padding: 2rem;
    margin: 2rem auto;
    max-width: 100%;
    text-align: center;
    animation: emerge 0.6s ease-out;
}
.signal-img {
    width: 100%;
    max-width: 480px;
    margin: 0 auto 1.5rem auto;
    display: block;
    opacity: 0.7;
}
.debris {
    font-family: 'IBM Plex Mono', monospace;
    font-size: 0.85rem;
    font-weight: 400;
    color: #e0e0e0;
    line-height: 2;
    word-spacing: 0.3em;
    letter-spacing: 0.01em;
}
.seed {
    font-size: 0.6rem;
    color: #333;
    text-align: center;
    margin-top: 1.5rem;
    letter-spacing: 0.15em;
    animation: emerge 0.8s ease-out;
}
[data-testid="stButton"] > button {
    background: transparent !important;
    border: 1px solid #333 !important;
    border-radius: 2px !important;
    color: #888 !important;
    font-family: 'IBM Plex Mono', monospace !important;
    font-size: 0.7rem !important;
    font-weight: 300 !important;
    letter-spacing: 0.25em !important;
    padding: 1rem 2rem !important;
    transition: all 0.4s ease !important;
    cursor: pointer !important;
}
[data-testid="stButton"] > button:hover {
    background: transparent !important;
    color: #e0e0e0 !important;
    border-color: #555 !important;
}
[data-testid="stButton"] > button:active {
    transform: scale(0.98) !important;
}
.stTabs [data-baseweb="tab-list"] {
    justify-content: center;
    gap: 2rem;
    border-bottom: 1px solid #1a1a1a;
    background: transparent;
}
.stTabs [data-baseweb="tab"] {
    font-family: 'IBM Plex Mono', monospace;
    font-size: 0.65rem;
    font-weight: 300;
    letter-spacing: 0.2em;
    color: #444;
    padding: 1rem 0;
    background: transparent;
    transition: color 0.3s ease;
}
.stTabs [aria-selected="true"] {
    color: #888;
    background: transparent;
}
.stTabs [data-baseweb="tab-highlight"] {
    background-color: #444;
}
.divider {
    border: none;
    border-top: 1px solid #1a1a1a;
    margin: 3rem 0;
}
.section {
    margin: 2.5rem 0;
}
.section-title {
    font-size: 0.65rem;
    letter-spacing: 0.25em;
    color: #444;
    text-align: center;
    margin-bottom: 1.5rem;
}
.spec-table {
    width: 100%;
    max-width: 320px;
    margin: 0 auto;
    font-size: 0.7rem;
    border-collapse: collapse;
    color: #777;
}
.spec-table td {
    padding: 0.75rem 1rem;
    border-bottom: 1px solid #151515;
}
.spec-table td:first-child {
    color: #444;
    text-align: right;
    padding-right: 2rem;
}
.spec-table td:last-child {
    text-align: left;
}
pre {
    background-color: #0f0f0f !important;
    border: 1px solid #1a1a1a !important;
    border-radius: 2px !important;
}
code {
    color: #666 !important;
    font-size: 0.7rem !important;
}
/* Model selector styling */
.stSelectbox > div > div {
    background-color: #0f0f0f !important;
    border: 1px solid #222 !important;
    border-radius: 2px !important;
    color: #888 !important;
    font-size: 0.7rem !important;
}
.stSelectbox > div > div:hover {
    border-color: #333 !important;
}
.model-info {
    font-size: 0.6rem;
    color: #444;
    text-align: center;
    margin-top: 0.5rem;
    letter-spacing: 0.1em;
}
</style>
"""
src/visualizers/__init__.py
ADDED

"""Signal visualization for WILL."""
from .signal_visualizer import SignalVisualizer

__all__ = ["SignalVisualizer"]
src/visualizers/signal_visualizer.py
ADDED
@@ -0,0 +1,138 @@
"""
Signal visualization.

Generates visualization images for the noise and the logits.
Follows the single responsibility principle (SRP): handles visualization
logic only.
"""
import base64
import io
from typing import Optional

import matplotlib.pyplot as plt
import torch


class SignalVisualizer:
    """
    Signal visualizer.

    Renders the input noise and the logits as grayscale images.
    """

    # Default visualization settings
    DEFAULT_FIG_WIDTH = 6
    DEFAULT_FIG_HEIGHT = 2
    DEFAULT_DPI = 150
    DEFAULT_BG_COLOR = "#0f0f0f"

    # Number of noise dimensions to display
    NOISE_DISPLAY_DIM = 64

    # Sampling stride over the logits vocab dimension
    LOGITS_SAMPLE_STEP = 200

    def __init__(
        self,
        fig_width: float = DEFAULT_FIG_WIDTH,
        fig_height: float = DEFAULT_FIG_HEIGHT,
        dpi: int = DEFAULT_DPI,
        bg_color: str = DEFAULT_BG_COLOR,
    ):
        """
        Args:
            fig_width: Figure width
            fig_height: Figure height
            dpi: Resolution
            bg_color: Background color
        """
        self._fig_width = fig_width
        self._fig_height = fig_height
        self._dpi = dpi
        self._bg_color = bg_color

    def generate_image(
        self,
        noise: torch.Tensor,
        logits: torch.Tensor,
    ) -> str:
        """
        Generate a Base64-encoded visualization of the noise and logits.

        Args:
            noise: Input noise tensor [batch, seq_len, embedding_dim]
            logits: Logits tensor [batch, seq_len, vocab_size]

        Returns:
            Base64-encoded PNG image string
        """
        fig, axes = plt.subplots(
            2,
            1,
            figsize=(self._fig_width, self._fig_height),
            facecolor=self._bg_color,
        )
        plt.subplots_adjust(
            hspace=0.15,
            left=0.02,
            right=0.98,
            top=0.95,
            bottom=0.05,
        )

        # Top panel: input noise
        self._render_noise(axes[0], noise)

        # Bottom panel: logits
        self._render_logits(axes[1], logits)

        # Save as PNG into an in-memory buffer
        buf = io.BytesIO()
        plt.savefig(
            buf,
            format="png",
            facecolor=self._bg_color,
            edgecolor="none",
            dpi=self._dpi,
            bbox_inches="tight",
            pad_inches=0.05,
        )
        plt.close(fig)

        buf.seek(0)
        return base64.b64encode(buf.read()).decode()

    def _render_noise(self, ax: plt.Axes, noise: torch.Tensor) -> None:
        """Render the input noise."""
        # From the first batch item, take the first NOISE_DISPLAY_DIM
        # embedding dimensions
        noise_flat = noise[0, :, : self.NOISE_DISPLAY_DIM].numpy()

        ax.imshow(
            noise_flat.T,
            aspect="auto",
            cmap="gray",
            interpolation="bilinear",
            vmin=-2,
            vmax=2,
        )
        self._style_axis(ax)

    def _render_logits(self, ax: plt.Axes, logits: torch.Tensor) -> None:
        """Render the logits."""
        # Subsample the vocab dimension for display
        logits_sample = logits[0, :, :: self.LOGITS_SAMPLE_STEP].numpy()

        ax.imshow(
            logits_sample.T,
            aspect="auto",
            cmap="gray",
            interpolation="bilinear",
        )
        self._style_axis(ax)

    def _style_axis(self, ax: plt.Axes) -> None:
        """Apply axis styling."""
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_facecolor(self._bg_color)
        for spine in ax.spines.values():
            spine.set_visible(False)
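For reference, a minimal usage sketch of the class above. The tensor shapes follow its docstrings; rendering the Base64 string through an <img> tag is one option, not necessarily what src/ui/pages/generate.py does:

    # Minimal usage sketch for SignalVisualizer (shapes per the docstrings).
    import torch
    import streamlit as st

    from src.visualizers.signal_visualizer import SignalVisualizer

    visualizer = SignalVisualizer()
    noise = torch.randn(1, 32, 768)      # [batch, seq_len, embedding_dim]
    logits = torch.randn(1, 32, 50257)   # [batch, seq_len, vocab_size]

    b64_png = visualizer.generate_image(noise, logits)
    # Embed the PNG directly in the page as a data URI.
    st.markdown(
        f'<img src="data:image/png;base64,{b64_png}">',
        unsafe_allow_html=True,
    )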
tests/__init__.py
ADDED
@@ -0,0 +1 @@
"""Tests for WILL."""
tests/test_generators.py
ADDED
@@ -0,0 +1,76 @@
"""
Tests for the generators.
"""
import pytest
import torch

from src.generators.debris_generator import DebrisGenerator, DebrisResult
from src.models.gpt2 import GPT2Model, GPT2_SMALL_CONFIG


class TestDebrisResult:
    """Tests for DebrisResult."""

    def test_result_attributes(self):
        """Result attributes are stored correctly."""
        result = DebrisResult(
            debris=["hello", "world"],
            seed=12345,
            noise=torch.randn(1, 32, 768),
            logits=torch.randn(1, 32, 50257),
            corrupted_logits=torch.randn(1, 32, 50257),
        )

        assert result.debris == ["hello", "world"]
        assert result.seed == 12345
        assert result.noise.shape == (1, 32, 768)


class TestDebrisGenerator:
    """Tests for DebrisGenerator."""

    @pytest.fixture
    def generator(self):
        """Provide a generator instance."""
        model = GPT2Model(GPT2_SMALL_CONFIG)
        return DebrisGenerator(model)

    def test_model_property(self, generator):
        """The model property is set correctly."""
        assert generator.model is not None
        assert generator.model.config == GPT2_SMALL_CONFIG


@pytest.mark.slow
class TestDebrisGeneratorIntegration:
    """Integration tests for DebrisGenerator."""

    @pytest.fixture
    def generator(self):
        """Provide a generator with a loaded model."""
        model = GPT2Model(GPT2_SMALL_CONFIG)
        model.load()
        return DebrisGenerator(model)

    def test_generate_with_seed(self, generator):
        """Generation works with an explicit seed."""
        result = generator.generate(seed=42, seq_len=8)

        assert isinstance(result, DebrisResult)
        assert result.seed == 42
        assert len(result.debris) == 8

    def test_generate_reproducible(self, generator):
        """The same seed yields the same result."""
        result1 = generator.generate(seed=12345, seq_len=8)
        result2 = generator.generate(seed=12345, seq_len=8)

        assert result1.debris == result2.debris

    def test_generate_different_seeds(self, generator):
        """Different seeds yield different results."""
        result1 = generator.generate(seed=11111, seq_len=8)
        result2 = generator.generate(seed=22222, seq_len=8)

        # The probability of an exact match is vanishingly small
        assert result1.debris != result2.debris
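Note: the @pytest.mark.slow marker used above must be registered, or pytest will emit PytestUnknownMarkWarning. A sketch of a conftest.py that registers it — this file is not part of the commit, so it is an assumption:

    # conftest.py (hypothetical, not in this commit): register the "slow"
    # marker so `pytest -m "not slow"` can skip the integration tests that
    # load real model weights.
    def pytest_configure(config):
        config.addinivalue_line(
            "markers", "slow: integration tests that load real model weights"
        )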
tests/test_models.py
ADDED
@@ -0,0 +1,124 @@
"""
Tests for the models.
"""
import pytest
import torch

from src.models.base import ModelConfig, BaseLanguageModel
from src.models.registry import ModelRegistry, DEFAULT_MODEL_KEY
from src.models.gpt2 import GPT2Model, GPT2_SMALL_CONFIG


class TestModelConfig:
    """Tests for ModelConfig."""

    def test_config_is_immutable(self):
        """The config is immutable."""
        config = ModelConfig(
            name="Test",
            model_id="test",
            embedding_dim=768,
            vocab_size=50000,
        )
        with pytest.raises(Exception):
            config.name = "Changed"

    def test_config_attributes(self):
        """Config attributes are stored correctly."""
        config = ModelConfig(
            name="Test Model",
            model_id="test-model",
            embedding_dim=1024,
            vocab_size=30000,
        )
        assert config.name == "Test Model"
        assert config.model_id == "test-model"
        assert config.embedding_dim == 1024
        assert config.vocab_size == 30000


class TestModelRegistry:
    """Tests for ModelRegistry."""

    def test_list_models(self):
        """The list of registered models can be retrieved."""
        models = ModelRegistry.list_models()
        assert len(models) > 0
        assert DEFAULT_MODEL_KEY in models

    def test_get_model(self):
        """A model instance can be retrieved."""
        model = ModelRegistry.get(DEFAULT_MODEL_KEY)
        assert isinstance(model, BaseLanguageModel)

    def test_get_nonexistent_model(self):
        """A nonexistent model raises KeyError."""
        with pytest.raises(KeyError):
            ModelRegistry.get("nonexistent-model")

    def test_get_config(self):
        """A model config can be retrieved."""
        config = ModelRegistry.get_config(DEFAULT_MODEL_KEY)
        assert config is not None
        assert isinstance(config, ModelConfig)

    def test_get_all_configs(self):
        """All model configs can be retrieved."""
        configs = ModelRegistry.get_all_configs()
        assert len(configs) > 0
        for key, config in configs.items():
            assert isinstance(config, ModelConfig)


class TestGPT2Model:
    """Tests for GPT2Model."""

    def test_config(self):
        """The config is set correctly."""
        model = GPT2Model(GPT2_SMALL_CONFIG)
        assert model.config == GPT2_SMALL_CONFIG
        assert model.config.embedding_dim == 768

    def test_is_loaded_initial(self):
        """The model is not loaded initially."""
        model = GPT2Model(GPT2_SMALL_CONFIG)
        assert not model.is_loaded

    def test_generate_noise(self):
        """Generated noise has the correct shape."""
        model = GPT2Model(GPT2_SMALL_CONFIG)
        noise = model.generate_noise(seq_len=16, batch_size=2)
        assert noise.shape == (2, 16, 768)


@pytest.mark.slow
class TestGPT2ModelIntegration:
    """Integration tests for GPT2Model (require loading the model)."""

    @pytest.fixture
    def loaded_model(self):
        """Provide a loaded model."""
        model = GPT2Model(GPT2_SMALL_CONFIG)
        model.load()
        return model

    def test_load(self, loaded_model):
        """The model can be loaded."""
        assert loaded_model.is_loaded

    def test_forward_with_noise(self, loaded_model):
        """The forward pass returns the correct shapes."""
        noise = loaded_model.generate_noise(seq_len=8)
        logits, corrupted_logits = loaded_model.forward_with_noise(noise)

        assert logits.shape[0] == 1
        assert logits.shape[1] == 8
        assert logits.shape[2] == loaded_model.config.vocab_size

    def test_decode_indices(self, loaded_model):
        """Decoding returns a list of strings."""
        indices = [100, 200, 300]
        decoded = loaded_model.decode_indices(indices)

        assert len(decoded) == 3
        assert all(isinstance(s, str) for s in decoded)
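The seed-based tests here and in tests/test_generators.py lean on PyTorch RNG determinism. A sketch of the contract they assume — where exactly the seeding happens inside src/ is not shown in this commit, but this is the standard torch behavior being relied on:

    # Seeding torch's RNG before sampling makes torch.randn reproducible,
    # so the same seed yields identical noise and identical debris downstream.
    import torch

    torch.manual_seed(12345)
    a = torch.randn(1, 8, 768)
    torch.manual_seed(12345)
    b = torch.randn(1, 8, 768)
    assert torch.equal(a, b)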