matt1847 Claude Opus 4.5 commited on
Commit
d1033d4
·
1 Parent(s): ca5a86c

リファクタ: srcディレクトリ構造への移行とDocker対応

Browse files

- app.pyをモジュール構造に分割(src/models, src/generators, src/visualizers, src/ui)
- Dockerfileをrequirements.txt使用に変更、src/ディレクトリをCOPY
- docker-compose.ymlにsrc/ボリュームマウント追加
- pytestを依存関係に追加
- テストファイル追加(TDD対応)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

Dockerfile CHANGED
@@ -2,16 +2,18 @@ FROM python:3.11-slim
2
 
3
  WORKDIR /app
4
 
5
- RUN pip install --no-cache-dir \
6
- torch \
7
- transformers \
8
- streamlit \
9
- matplotlib \
10
- numpy
11
 
12
- RUN python -c "from transformers import GPT2LMHeadModel, GPT2Tokenizer; GPT2LMHeadModel.from_pretrained('gpt2'); GPT2Tokenizer.from_pretrained('gpt2')"
 
 
 
13
 
 
14
  COPY app.py .
 
15
 
16
  EXPOSE 8501
17
 
 
2
 
3
  WORKDIR /app
4
 
5
+ # 依存関係のインストール
6
+ COPY requirements.txt .
7
+ RUN pip install --no-cache-dir -r requirements.txt
 
 
 
8
 
9
+ # モデルの事前ダウンロード(ビルド時にキャッシュ)
10
+ RUN python -c "from transformers import GPT2LMHeadModel, GPT2Tokenizer; \
11
+ GPT2LMHeadModel.from_pretrained('gpt2'); \
12
+ GPT2Tokenizer.from_pretrained('gpt2')"
13
 
14
+ # アプリケーションコードをコピー
15
  COPY app.py .
16
+ COPY src/ ./src/
17
 
18
  EXPOSE 8501
19
 
app.py CHANGED
@@ -1,327 +1,33 @@
1
- #!/Users/yukimatsumori/.pyenv/versions/3.12.2/bin/python3
2
- import time
3
- import torch
4
- from transformers import GPT2LMHeadModel, GPT2Tokenizer
5
- import streamlit as st
6
- import matplotlib.pyplot as plt
7
- import numpy as np
8
- import io
9
- import base64
10
-
11
- st.set_page_config(page_title="will", page_icon="", layout="centered")
12
-
13
- st.markdown("""
14
- <style>
15
- @import url('https://fonts.googleapis.com/css2?family=IBM+Plex+Mono:wght@300;400&display=swap');
16
-
17
- @keyframes emerge {
18
- from { opacity: 0; transform: translateY(8px); }
19
- to { opacity: 1; transform: translateY(0); }
20
- }
21
- @keyframes breathe {
22
- 0%, 100% { opacity: 0.4; }
23
- 50% { opacity: 0.7; }
24
- }
25
-
26
- html, body, [class*="css"] {
27
- font-family: 'IBM Plex Mono', monospace;
28
- }
29
- .stApp {
30
- background-color: #0a0a0a;
31
- color: #e0e0e0;
32
- }
33
- .block-container {
34
- padding-top: 4rem;
35
- padding-bottom: 4rem;
36
- max-width: 640px;
37
- }
38
- h1, h2, h3 {
39
- font-weight: 300;
40
- letter-spacing: 0.1em;
41
- text-align: center;
42
- color: #e0e0e0;
43
- }
44
- p, li {
45
- font-weight: 300;
46
- line-height: 1.8;
47
- color: #888;
48
- }
49
- .title {
50
- font-size: 2rem;
51
- font-weight: 300;
52
- letter-spacing: 0.3em;
53
- text-align: center;
54
- margin-bottom: 0.5rem;
55
- color: #e0e0e0;
56
- }
57
- .subtitle {
58
- font-size: 0.75rem;
59
- letter-spacing: 0.2em;
60
- text-align: center;
61
- color: #555;
62
- margin-bottom: 3rem;
63
- }
64
- .debris-container {
65
- background: linear-gradient(135deg, #0f0f0f 0%, #141414 100%);
66
- border: 1px solid #222;
67
- border-radius: 2px;
68
- padding: 2rem;
69
- margin: 2rem auto;
70
- max-width: 100%;
71
- text-align: center;
72
- animation: emerge 0.6s ease-out;
73
- }
74
- .signal-img {
75
- width: 100%;
76
- max-width: 480px;
77
- margin: 0 auto 1.5rem auto;
78
- display: block;
79
- opacity: 0.7;
80
- }
81
- .debris {
82
- font-family: 'IBM Plex Mono', monospace;
83
- font-size: 0.85rem;
84
- font-weight: 400;
85
- color: #e0e0e0;
86
- line-height: 2;
87
- word-spacing: 0.3em;
88
- letter-spacing: 0.01em;
89
- }
90
- .seed {
91
- font-size: 0.6rem;
92
- color: #333;
93
- text-align: center;
94
- margin-top: 1.5rem;
95
- letter-spacing: 0.15em;
96
- animation: emerge 0.8s ease-out;
97
- }
98
- [data-testid="stButton"] > button {
99
- background: transparent !important;
100
- border: 1px solid #333 !important;
101
- border-radius: 2px !important;
102
- color: #888 !important;
103
- font-family: 'IBM Plex Mono', monospace !important;
104
- font-size: 0.7rem !important;
105
- font-weight: 300 !important;
106
- letter-spacing: 0.25em !important;
107
- padding: 1rem 2rem !important;
108
- transition: all 0.4s ease !important;
109
- cursor: pointer !important;
110
- }
111
- [data-testid="stButton"] > button:hover {
112
- background: transparent !important;
113
- color: #e0e0e0 !important;
114
- border-color: #555 !important;
115
- }
116
- [data-testid="stButton"] > button:active {
117
- transform: scale(0.98) !important;
118
- }
119
- .stTabs [data-baseweb="tab-list"] {
120
- justify-content: center;
121
- gap: 2rem;
122
- border-bottom: 1px solid #1a1a1a;
123
- background: transparent;
124
- }
125
- .stTabs [data-baseweb="tab"] {
126
- font-family: 'IBM Plex Mono', monospace;
127
- font-size: 0.65rem;
128
- font-weight: 300;
129
- letter-spacing: 0.2em;
130
- color: #444;
131
- padding: 1rem 0;
132
- background: transparent;
133
- transition: color 0.3s ease;
134
- }
135
- .stTabs [aria-selected="true"] {
136
- color: #888;
137
- background: transparent;
138
- }
139
- .stTabs [data-baseweb="tab-highlight"] {
140
- background-color: #444;
141
- }
142
- .divider {
143
- border: none;
144
- border-top: 1px solid #1a1a1a;
145
- margin: 3rem 0;
146
- }
147
- .section {
148
- margin: 2.5rem 0;
149
- }
150
- .section-title {
151
- font-size: 0.65rem;
152
- letter-spacing: 0.25em;
153
- color: #444;
154
- text-align: center;
155
- margin-bottom: 1.5rem;
156
- }
157
- .spec-table {
158
- width: 100%;
159
- max-width: 320px;
160
- margin: 0 auto;
161
- font-size: 0.7rem;
162
- border-collapse: collapse;
163
- color: #777;
164
- }
165
- .spec-table td {
166
- padding: 0.75rem 1rem;
167
- border-bottom: 1px solid #151515;
168
- }
169
- .spec-table td:first-child {
170
- color: #444;
171
- text-align: right;
172
- padding-right: 2rem;
173
- }
174
- .spec-table td:last-child {
175
- text-align: left;
176
- }
177
- pre {
178
- background-color: #0f0f0f !important;
179
- border: 1px solid #1a1a1a !important;
180
- border-radius: 2px !important;
181
- }
182
- code {
183
- color: #666 !important;
184
- font-size: 0.7rem !important;
185
- }
186
- </style>
187
- """, unsafe_allow_html=True)
188
-
189
- tab1, tab2 = st.tabs(["GENERATE", "CONCEPT"])
190
-
191
- with tab1:
192
- st.markdown('<p class="title">WILL</p>', unsafe_allow_html=True)
193
- st.markdown('<p class="subtitle">PURE COMPUTATIONAL WILL</p>', unsafe_allow_html=True)
194
-
195
- @st.cache_resource(show_spinner=False)
196
- def load_model():
197
- model = GPT2LMHeadModel.from_pretrained("gpt2")
198
- tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
199
- model.eval()
200
- return model, tokenizer
201
-
202
- model, tokenizer = load_model()
203
-
204
- if "debris" not in st.session_state:
205
- st.session_state.debris = None
206
- st.session_state.seed = None
207
- st.session_state.signal_img = None
208
 
209
- def generate_signal_image(noise, logits):
210
- fig, axes = plt.subplots(2, 1, figsize=(6, 2), facecolor='#0f0f0f')
211
- plt.subplots_adjust(hspace=0.15, left=0.02, right=0.98, top=0.95, bottom=0.05)
212
-
213
- noise_flat = noise[0, :, :64].numpy()
214
- axes[0].imshow(noise_flat.T, aspect='auto', cmap='gray', interpolation='bilinear', vmin=-2, vmax=2)
215
- axes[0].set_xticks([])
216
- axes[0].set_yticks([])
217
- axes[0].set_facecolor('#0f0f0f')
218
- for spine in axes[0].spines.values():
219
- spine.set_visible(False)
220
-
221
- logits_sample = logits[0, :, ::200].numpy()
222
- axes[1].imshow(logits_sample.T, aspect='auto', cmap='gray', interpolation='bilinear')
223
- axes[1].set_xticks([])
224
- axes[1].set_yticks([])
225
- axes[1].set_facecolor('#0f0f0f')
226
- for spine in axes[1].spines.values():
227
- spine.set_visible(False)
228
-
229
- buf = io.BytesIO()
230
- plt.savefig(buf, format='png', facecolor='#0f0f0f', edgecolor='none', dpi=150, bbox_inches='tight', pad_inches=0.05)
231
- plt.close(fig)
232
- buf.seek(0)
233
- return base64.b64encode(buf.read()).decode()
234
-
235
- col1, col2, col3 = st.columns([1, 1, 1])
236
- with col2:
237
- clicked = st.button("LISTEN", key="listen_btn", use_container_width=True)
238
- if clicked:
239
- seed = time.time_ns()
240
- torch.manual_seed(seed)
241
- noise = torch.randn(1, 32, 768)
242
-
243
- with torch.no_grad():
244
- outputs = model(inputs_embeds=noise)
245
- logits = outputs.logits
246
- logits_noise = torch.randn_like(logits) * logits.std() * 10
247
- corrupted_logits = logits + logits_noise
248
-
249
- indices = corrupted_logits.argmax(dim=-1).squeeze().tolist()
250
- st.session_state.debris = [tokenizer.decode([i]) for i in indices]
251
- st.session_state.seed = seed
252
- st.session_state.signal_img = generate_signal_image(noise, corrupted_logits)
253
-
254
- if st.session_state.debris:
255
- st.markdown(f'''
256
- <div class="debris-container">
257
- <img class="signal-img" src="data:image/png;base64,{st.session_state.signal_img}">
258
- <div class="debris">{" ".join(st.session_state.debris)}</div>
259
- </div>
260
- <p class="seed">{st.session_state.seed}</p>
261
- ''', unsafe_allow_html=True)
262
-
263
- with tab2:
264
- st.markdown('<p class="title">CONCEPT</p>', unsafe_allow_html=True)
265
- st.markdown('<p class="subtitle">DOCUMENTATION</p>', unsafe_allow_html=True)
266
-
267
- st.markdown('''
268
- <div class="section">
269
- <p class="section-title">CONCEPT</p>
270
- <p style="text-align: center; color: #666; line-height: 2.2;">
271
- GPT-2は人間が書いたテキストで訓練され<br>
272
- その重みに言語パターンを保持している<br><br>
273
- 通常はプロンプトに対して応答を生成するが<br>
274
- 入力をランダムノイズに置き換え<br>
275
- 出力にもノイズを加えることで<br>
276
- 学習済みの統計的偏りを破壊する<br><br>
277
- 人間の問いかけなしに<br>
278
- モデルの構造だけが出力するものを観測する
279
- </p>
280
- </div>
281
- <hr class="divider">
282
- ''', unsafe_allow_html=True)
283
 
284
- st.markdown('''
285
- <div class="section">
286
- <p class="section-title">PROCESS</p>
287
- </div>
288
- ''', unsafe_allow_html=True)
289
 
290
- st.markdown('<p style="text-align: center; color: #333; font-size: 0.65rem; letter-spacing: 0.15em; margin-bottom: 0.5rem;">01 — ENTROPY SEED</p>', unsafe_allow_html=True)
291
- st.code("seed = time.time_ns()\ntorch.manual_seed(seed)", language="python")
292
- st.markdown('<p style="text-align: center; font-size: 0.7rem; color: #444;">実行瞬間のナノ秒を乱数シードとして採取</p>', unsafe_allow_html=True)
293
 
294
- st.markdown("<br>", unsafe_allow_html=True)
 
 
 
295
 
296
- st.markdown('<p style="text-align: center; color: #333; font-size: 0.65rem; letter-spacing: 0.15em; margin-bottom: 0.5rem;">02 — INPUT NOISE</p>', unsafe_allow_html=True)
297
- st.code("noise = torch.randn(1, 32, 768)\noutputs = model(inputs_embeds=noise)", language="python")
298
- st.markdown('<p style="text-align: center; font-size: 0.7rem; color: #444;">768次元ランダムノイズをEmbedding層に直接注入</p>', unsafe_allow_html=True)
299
 
300
- st.markdown("<br>", unsafe_allow_html=True)
 
301
 
302
- st.markdown('<p style="text-align: center; color: #333; font-size: 0.65rem; letter-spacing: 0.15em; margin-bottom: 0.5rem;">03 — OUTPUT NOISE</p>', unsafe_allow_html=True)
303
- st.code("logits_noise = torch.randn_like(logits) * logits.std() * 10\ncorrupted_logits = logits + logits_noise", language="python")
304
- st.markdown('<p style="text-align: center; font-size: 0.7rem; color: #444;">出力Logitsにノイズを加算し学習バイアスを破壊</p>', unsafe_allow_html=True)
305
 
306
- st.markdown("<br>", unsafe_allow_html=True)
 
307
 
308
- st.markdown('<p style="text-align: center; color: #333; font-size: 0.65rem; letter-spacing: 0.15em; margin-bottom: 0.5rem;">04 — RAW DECODE</p>', unsafe_allow_html=True)
309
- st.code("indices = corrupted_logits.argmax(dim=-1)\ndebris = [tokenizer.decode([i]) for i in indices]", language="python")
310
- st.markdown('<p style="text-align: center; font-size: 0.7rem; color: #444;">Softmax・Temperature なしで生トークンを抽出</p>', unsafe_allow_html=True)
311
 
312
- st.markdown('''
313
- <hr class="divider">
314
- <div class="section">
315
- <p class="section-title">SPECIFICATION</p>
316
- <table class="spec-table">
317
- <tr><td>Model</td><td>GPT-2 Small</td></tr>
318
- <tr><td>Parameters</td><td>124M</td></tr>
319
- <tr><td>Embedding</td><td>768 dim</td></tr>
320
- <tr><td>Vocabulary</td><td>50,257 tokens</td></tr>
321
- <tr><td>Sequence</td><td>32 tokens</td></tr>
322
- <tr><td>Input Noise</td><td>N(0, 1)</td></tr>
323
- <tr><td>Logits Noise</td><td>N(0, σ×10)</td></tr>
324
- <tr><td>Decoding</td><td>argmax</td></tr>
325
- </table>
326
- </div>
327
- ''', unsafe_allow_html=True)
 
1
+ """
2
+ WILL - Pure Computational Will
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
+ 言語モデルにランダムノイズを入力し、
5
+ 人間の問いかけなしにモデルの構造だけが
6
+ 出力するものを観測する
7
+ """
8
+ import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
+ from src.ui.styles import CUSTOM_CSS
11
+ from src.ui.pages import render_generate_page, render_concept_page
 
 
 
12
 
 
 
 
13
 
14
+ def main():
15
+ """アプリケーションのエントリーポイント"""
16
+ # ページ設定
17
+ st.set_page_config(page_title="will", page_icon="", layout="centered")
18
 
19
+ # カスタムCSS適用
20
+ st.markdown(CUSTOM_CSS, unsafe_allow_html=True)
 
21
 
22
+ # タブ構成
23
+ tab1, tab2 = st.tabs(["GENERATE", "CONCEPT"])
24
 
25
+ with tab1:
26
+ render_generate_page()
 
27
 
28
+ with tab2:
29
+ render_concept_page()
30
 
 
 
 
31
 
32
+ if __name__ == "__main__":
33
+ main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
docker-compose.yml CHANGED
@@ -5,3 +5,4 @@ services:
5
  - "8501:8501"
6
  volumes:
7
  - ./app.py:/app/app.py
 
 
5
  - "8501:8501"
6
  volumes:
7
  - ./app.py:/app/app.py
8
+ - ./src:/app/src
requirements.txt CHANGED
@@ -3,3 +3,4 @@ transformers
3
  streamlit
4
  matplotlib
5
  numpy
 
 
3
  streamlit
4
  matplotlib
5
  numpy
6
+ pytest
src/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """WILL - Pure Computational Will"""
src/generators/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ """Debris generation logic for WILL."""
2
+ from .debris_generator import DebrisGenerator, DebrisResult
3
+
4
+ __all__ = ["DebrisGenerator", "DebrisResult"]
src/generators/debris_generator.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ デブリ生成器
3
+
4
+ 言語モデルにノイズを入力してデブリ(言語断片)を生成する
5
+ 単一責任原則(SRP)に従い、生成ロジックのみを担当
6
+ """
7
+ import time
8
+ from dataclasses import dataclass
9
+ from typing import List, Optional
10
+
11
+ import torch
12
+
13
+ from ..models.base import BaseLanguageModel
14
+
15
+
16
+ @dataclass
17
+ class DebrisResult:
18
+ """
19
+ デブリ生成結果を保持するイミュータブルなデータクラス
20
+
21
+ Attributes:
22
+ debris: 生成されたトークン文字列のリスト
23
+ seed: 使用した乱数シード
24
+ noise: 入力ノイズテンソル
25
+ logits: 生のlogitsテンソル
26
+ corrupted_logits: ノイズ加算後のlogitsテンソル
27
+ """
28
+ debris: List[str]
29
+ seed: int
30
+ noise: torch.Tensor
31
+ logits: torch.Tensor
32
+ corrupted_logits: torch.Tensor
33
+
34
+
35
+ class DebrisGenerator:
36
+ """
37
+ デブリ生成器
38
+
39
+ 言語モデルを使用してランダムノイズから
40
+ 言語断片(デブリ)を生成する
41
+
42
+ 依存性逆転原則(DIP)に従い、具象クラスではなく
43
+ BaseLanguageModel抽象クラスに依存する
44
+ """
45
+
46
+ # デフォルトのシーケンス長
47
+ DEFAULT_SEQ_LEN = 32
48
+
49
+ def __init__(self, model: BaseLanguageModel):
50
+ """
51
+ Args:
52
+ model: 使用する言語モデル(BaseLanguageModelを実装)
53
+ """
54
+ self._model = model
55
+
56
+ @property
57
+ def model(self) -> BaseLanguageModel:
58
+ """使用中のモデルを取得"""
59
+ return self._model
60
+
61
+ def generate(
62
+ self,
63
+ seed: Optional[int] = None,
64
+ seq_len: int = DEFAULT_SEQ_LEN,
65
+ ) -> DebrisResult:
66
+ """
67
+ デブリを生成
68
+
69
+ Args:
70
+ seed: 乱数シード(Noneの場合はナノ秒タイムスタンプを使用)
71
+ seq_len: 生成するシーケンス長
72
+
73
+ Returns:
74
+ DebrisResult: 生成結果
75
+
76
+ Raises:
77
+ RuntimeError: モデルが未ロードの場合
78
+ """
79
+ # シードの設定
80
+ if seed is None:
81
+ seed = time.time_ns()
82
+ torch.manual_seed(seed)
83
+
84
+ # モデルがロードされていなければロード
85
+ if not self._model.is_loaded:
86
+ self._model.load()
87
+
88
+ # ノイズ生成と順伝播
89
+ noise = self._model.generate_noise(seq_len=seq_len)
90
+ logits, corrupted_logits = self._model.forward_with_noise(noise)
91
+
92
+ # argmaxでインデックス抽出
93
+ indices = corrupted_logits.argmax(dim=-1).squeeze().tolist()
94
+
95
+ # インデックスをトークン文字列にデコード
96
+ debris = self._model.decode_indices(indices)
97
+
98
+ return DebrisResult(
99
+ debris=debris,
100
+ seed=seed,
101
+ noise=noise,
102
+ logits=logits,
103
+ corrupted_logits=corrupted_logits,
104
+ )
src/models/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """Model implementations for WILL."""
2
+ from .base import BaseLanguageModel, ModelConfig
3
+ from .registry import ModelRegistry
4
+
5
+ __all__ = ["BaseLanguageModel", "ModelConfig", "ModelRegistry"]
src/models/base.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 抽象基底クラス - すべての言語モデルの共通インターフェース
3
+
4
+ リスコフ置換原則(LSP)に準拠し、どのモデル実装も
5
+ 同じインターフェースで置換可能にする
6
+ """
7
+ from abc import ABC, abstractmethod
8
+ from dataclasses import dataclass
9
+ from typing import List, Tuple
10
+
11
+ import torch
12
+
13
+
14
+ @dataclass(frozen=True)
15
+ class ModelConfig:
16
+ """
17
+ モデル設定を保持するイミュータブルなデータクラス
18
+
19
+ Attributes:
20
+ name: UI表示名
21
+ model_id: HuggingFace model ID
22
+ embedding_dim: embedding次元数
23
+ vocab_size: 語彙サイズ
24
+ """
25
+ name: str
26
+ model_id: str
27
+ embedding_dim: int
28
+ vocab_size: int
29
+
30
+
31
+ class BaseLanguageModel(ABC):
32
+ """
33
+ 言語モデルの抽象基底クラス
34
+
35
+ すべてのモデル実装はこのクラスを継承し、
36
+ 定義されたインターフェースを実装する必要がある
37
+ """
38
+
39
+ def __init__(self, config: ModelConfig):
40
+ """
41
+ Args:
42
+ config: モデル設定
43
+ """
44
+ self._config = config
45
+ self._model = None
46
+ self._tokenizer = None
47
+ self._is_loaded = False
48
+
49
+ @property
50
+ def config(self) -> ModelConfig:
51
+ """モデル設定を取得"""
52
+ return self._config
53
+
54
+ @property
55
+ def is_loaded(self) -> bool:
56
+ """モデルがロード済みかどうか"""
57
+ return self._is_loaded
58
+
59
+ @abstractmethod
60
+ def load(self) -> None:
61
+ """
62
+ モデルとトークナイザーをロードする
63
+
64
+ Raises:
65
+ RuntimeError: モデルのロードに失敗した場合
66
+ """
67
+ pass
68
+
69
+ @abstractmethod
70
+ def forward_with_noise(
71
+ self, noise: torch.Tensor
72
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
73
+ """
74
+ ノイズを入力として順伝播を実行
75
+
76
+ Args:
77
+ noise: 入力ノイズテンソル [batch, seq_len, embedding_dim]
78
+
79
+ Returns:
80
+ Tuple[logits, corrupted_logits]:
81
+ - logits: 生のlogits
82
+ - corrupted_logits: ノイズ加算後のlogits
83
+ """
84
+ pass
85
+
86
+ @abstractmethod
87
+ def decode_indices(self, indices: List[int]) -> List[str]:
88
+ """
89
+ トークンインデックスをデコードして文字列リストに変換
90
+
91
+ Args:
92
+ indices: トークンインデックスのリスト
93
+
94
+ Returns:
95
+ デコードされた文字列のリスト
96
+ """
97
+ pass
98
+
99
+ def generate_noise(self, seq_len: int = 32, batch_size: int = 1) -> torch.Tensor:
100
+ """
101
+ 入力用のランダムノイズを生成
102
+
103
+ Args:
104
+ seq_len: シーケンス長
105
+ batch_size: バッチサイズ
106
+
107
+ Returns:
108
+ ノイズテンソル [batch_size, seq_len, embedding_dim]
109
+ """
110
+ return torch.randn(batch_size, seq_len, self._config.embedding_dim)
src/models/gpt2.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GPT-2モデル実装
3
+
4
+ GPT-2 SmallおよびMediumの実装を提供する
5
+ """
6
+ from typing import List, Tuple
7
+
8
+ import torch
9
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
10
+
11
+ from .base import BaseLanguageModel, ModelConfig
12
+
13
+
14
+ # GPT-2 Small設定
15
+ GPT2_SMALL_CONFIG = ModelConfig(
16
+ name="GPT-2 Small",
17
+ model_id="gpt2",
18
+ embedding_dim=768,
19
+ vocab_size=50257,
20
+ )
21
+
22
+ # GPT-2 Medium設定
23
+ GPT2_MEDIUM_CONFIG = ModelConfig(
24
+ name="GPT-2 Medium",
25
+ model_id="gpt2-medium",
26
+ embedding_dim=1024,
27
+ vocab_size=50257,
28
+ )
29
+
30
+
31
+ class GPT2Model(BaseLanguageModel):
32
+ """
33
+ GPT-2モデルの実装
34
+
35
+ HuggingFace TransformersのGPT-2をラップし、
36
+ BaseLanguageModelインターフェースを実装する
37
+ """
38
+
39
+ # 出力ノイズの倍率(学習バイアス破壊用)
40
+ LOGITS_NOISE_SCALE = 10.0
41
+
42
+ def load(self) -> None:
43
+ """モデルとトークナイザーをロード"""
44
+ if self._is_loaded:
45
+ return
46
+
47
+ try:
48
+ self._model = GPT2LMHeadModel.from_pretrained(self._config.model_id)
49
+ self._tokenizer = GPT2Tokenizer.from_pretrained(self._config.model_id)
50
+ self._model.eval()
51
+ self._is_loaded = True
52
+ except Exception as e:
53
+ raise RuntimeError(f"Failed to load model {self._config.model_id}: {e}")
54
+
55
+ def forward_with_noise(
56
+ self, noise: torch.Tensor
57
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
58
+ """
59
+ ノイズを入力として順伝播を実行し、出力にもノイズを加算
60
+
61
+ Args:
62
+ noise: 入力ノイズテンソル [batch, seq_len, embedding_dim]
63
+
64
+ Returns:
65
+ Tuple[logits, corrupted_logits]
66
+ """
67
+ if not self._is_loaded:
68
+ raise RuntimeError("Model not loaded. Call load() first.")
69
+
70
+ with torch.no_grad():
71
+ outputs = self._model(inputs_embeds=noise)
72
+ logits = outputs.logits
73
+
74
+ # 出力logitsにノイズを加算して学習バイアスを破壊
75
+ logits_noise = (
76
+ torch.randn_like(logits) * logits.std() * self.LOGITS_NOISE_SCALE
77
+ )
78
+ corrupted_logits = logits + logits_noise
79
+
80
+ return logits, corrupted_logits
81
+
82
+ def decode_indices(self, indices: List[int]) -> List[str]:
83
+ """トークンインデックスをデコード"""
84
+ if not self._is_loaded:
85
+ raise RuntimeError("Model not loaded. Call load() first.")
86
+
87
+ return [self._tokenizer.decode([i]) for i in indices]
src/models/gpt_neo.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GPT-Neo モデル実装
3
+
4
+ EleutherAI GPT-Neo 125Mの実装を提供する
5
+ """
6
+ from typing import List, Tuple
7
+
8
+ import torch
9
+ from transformers import GPTNeoForCausalLM, GPT2Tokenizer
10
+
11
+ from .base import BaseLanguageModel, ModelConfig
12
+
13
+
14
+ # GPT-Neo 125M設定
15
+ GPT_NEO_125M_CONFIG = ModelConfig(
16
+ name="GPT-Neo 125M",
17
+ model_id="EleutherAI/gpt-neo-125M",
18
+ embedding_dim=768,
19
+ vocab_size=50257,
20
+ )
21
+
22
+
23
+ class GPTNeoModel(BaseLanguageModel):
24
+ """
25
+ GPT-Neoモデルの実装
26
+
27
+ EleutherAI GPT-NeoをラップしBaseLanguageModelインターフェースを実装
28
+ """
29
+
30
+ # 出力ノイズの倍率
31
+ LOGITS_NOISE_SCALE = 10.0
32
+
33
+ def load(self) -> None:
34
+ """モデルとトークナイザーをロード"""
35
+ if self._is_loaded:
36
+ return
37
+
38
+ try:
39
+ self._model = GPTNeoForCausalLM.from_pretrained(self._config.model_id)
40
+ # GPT-Neoは GPT-2互換のトークナイザーを使用
41
+ self._tokenizer = GPT2Tokenizer.from_pretrained(self._config.model_id)
42
+ self._model.eval()
43
+ self._is_loaded = True
44
+ except Exception as e:
45
+ raise RuntimeError(f"Failed to load model {self._config.model_id}: {e}")
46
+
47
+ def forward_with_noise(
48
+ self, noise: torch.Tensor
49
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
50
+ """ノイズを入力として順伝播を実行"""
51
+ if not self._is_loaded:
52
+ raise RuntimeError("Model not loaded. Call load() first.")
53
+
54
+ with torch.no_grad():
55
+ outputs = self._model(inputs_embeds=noise)
56
+ logits = outputs.logits
57
+
58
+ logits_noise = (
59
+ torch.randn_like(logits) * logits.std() * self.LOGITS_NOISE_SCALE
60
+ )
61
+ corrupted_logits = logits + logits_noise
62
+
63
+ return logits, corrupted_logits
64
+
65
+ def decode_indices(self, indices: List[int]) -> List[str]:
66
+ """トークンインデックスをデコード"""
67
+ if not self._is_loaded:
68
+ raise RuntimeError("Model not loaded. Call load() first.")
69
+
70
+ return [self._tokenizer.decode([i]) for i in indices]
src/models/opt.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ OPT モデル実装
3
+
4
+ Meta OPT-125Mの実装を提供する
5
+ """
6
+ from typing import List, Tuple
7
+
8
+ import torch
9
+ from transformers import OPTForCausalLM, GPT2Tokenizer
10
+
11
+ from .base import BaseLanguageModel, ModelConfig
12
+
13
+
14
+ # OPT-125M設定
15
+ OPT_125M_CONFIG = ModelConfig(
16
+ name="OPT-125M",
17
+ model_id="facebook/opt-125m",
18
+ embedding_dim=768,
19
+ vocab_size=50272,
20
+ )
21
+
22
+
23
+ class OPTModel(BaseLanguageModel):
24
+ """
25
+ OPTモデルの実装
26
+
27
+ Meta OPTをラップしBaseLanguageModelインターフェースを実装
28
+ """
29
+
30
+ # 出力ノイズの倍率
31
+ LOGITS_NOISE_SCALE = 10.0
32
+
33
+ def load(self) -> None:
34
+ """モデルとトークナイザーをロード"""
35
+ if self._is_loaded:
36
+ return
37
+
38
+ try:
39
+ self._model = OPTForCausalLM.from_pretrained(self._config.model_id)
40
+ # OPTは独自のトークナイザーを持つが、GPT-2互換も可能
41
+ self._tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
42
+ self._model.eval()
43
+ self._is_loaded = True
44
+ except Exception as e:
45
+ raise RuntimeError(f"Failed to load model {self._config.model_id}: {e}")
46
+
47
+ def forward_with_noise(
48
+ self, noise: torch.Tensor
49
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
50
+ """ノイズを入力として順伝播を実行"""
51
+ if not self._is_loaded:
52
+ raise RuntimeError("Model not loaded. Call load() first.")
53
+
54
+ with torch.no_grad():
55
+ outputs = self._model(inputs_embeds=noise)
56
+ logits = outputs.logits
57
+
58
+ logits_noise = (
59
+ torch.randn_like(logits) * logits.std() * self.LOGITS_NOISE_SCALE
60
+ )
61
+ corrupted_logits = logits + logits_noise
62
+
63
+ return logits, corrupted_logits
64
+
65
+ def decode_indices(self, indices: List[int]) -> List[str]:
66
+ """トークンインデックスをデコード"""
67
+ if not self._is_loaded:
68
+ raise RuntimeError("Model not loaded. Call load() first.")
69
+
70
+ return [self._tokenizer.decode([i]) for i in indices]
src/models/registry.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ モデルレジストリ
3
+
4
+ 開放閉鎖原則(OCP)に準拠し、新モデル追加時に
5
+ 既存コードの変更を不要にする
6
+ """
7
+ from typing import Dict, List, Optional, Type
8
+
9
+ from .base import BaseLanguageModel, ModelConfig
10
+ from .gpt2 import GPT2Model, GPT2_SMALL_CONFIG, GPT2_MEDIUM_CONFIG
11
+ from .gpt_neo import GPTNeoModel, GPT_NEO_125M_CONFIG
12
+ from .opt import OPTModel, OPT_125M_CONFIG
13
+
14
+
15
+ class ModelRegistry:
16
+ """
17
+ モデルレジストリ
18
+
19
+ 利用可能なモデルを管理し、キーに基づいて
20
+ 適切なモデルインスタンスを提供する
21
+ """
22
+
23
+ _registry: Dict[str, tuple[Type[BaseLanguageModel], ModelConfig]] = {}
24
+
25
+ @classmethod
26
+ def register(
27
+ cls,
28
+ key: str,
29
+ model_class: Type[BaseLanguageModel],
30
+ config: ModelConfig,
31
+ ) -> None:
32
+ """
33
+ 新しいモデルをレジストリに登録
34
+
35
+ Args:
36
+ key: モデルを識別するキー
37
+ model_class: モデルクラス
38
+ config: モデル設定
39
+ """
40
+ cls._registry[key] = (model_class, config)
41
+
42
+ @classmethod
43
+ def get(cls, key: str) -> BaseLanguageModel:
44
+ """
45
+ キーに対応するモデルインスタンスを取得
46
+
47
+ Args:
48
+ key: モデルを識別するキー
49
+
50
+ Returns:
51
+ モデルインスタンス
52
+
53
+ Raises:
54
+ KeyError: 指定されたキーが存在しない場合
55
+ """
56
+ if key not in cls._registry:
57
+ available = ", ".join(cls._registry.keys())
58
+ raise KeyError(f"Model '{key}' not found. Available: {available}")
59
+
60
+ model_class, config = cls._registry[key]
61
+ return model_class(config)
62
+
63
+ @classmethod
64
+ def list_models(cls) -> List[str]:
65
+ """登録済みモデルのキー一覧を取得"""
66
+ return list(cls._registry.keys())
67
+
68
+ @classmethod
69
+ def get_config(cls, key: str) -> Optional[ModelConfig]:
70
+ """指定キーのモデル設定を取得"""
71
+ if key not in cls._registry:
72
+ return None
73
+ return cls._registry[key][1]
74
+
75
+ @classmethod
76
+ def get_all_configs(cls) -> Dict[str, ModelConfig]:
77
+ """すべてのモデル設定を取得"""
78
+ return {key: config for key, (_, config) in cls._registry.items()}
79
+
80
+
81
+ # デフォルトモデルの登録
82
+ ModelRegistry.register("gpt2", GPT2Model, GPT2_SMALL_CONFIG)
83
+ ModelRegistry.register("gpt2-medium", GPT2Model, GPT2_MEDIUM_CONFIG)
84
+ ModelRegistry.register("gpt-neo-125m", GPTNeoModel, GPT_NEO_125M_CONFIG)
85
+ ModelRegistry.register("opt-125m", OPTModel, OPT_125M_CONFIG)
86
+
87
+ # デフォルトモデルキー
88
+ DEFAULT_MODEL_KEY = "gpt2"
src/ui/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """UI components for WILL."""
2
+ from .styles import CUSTOM_CSS
3
+ from .components import render_model_selector
4
+
5
+ __all__ = ["CUSTOM_CSS", "render_model_selector"]
src/ui/components.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ UIコンポーネント
3
+
4
+ 再利用可能なUIコンポーネントを提供する
5
+ """
6
+ from typing import Optional
7
+
8
+ import streamlit as st
9
+
10
+ from ..models.registry import ModelRegistry, DEFAULT_MODEL_KEY
11
+
12
+
13
+ def render_model_selector() -> str:
14
+ """
15
+ モデル選択UIをレンダリング
16
+
17
+ Returns:
18
+ 選択されたモデルのキー
19
+ """
20
+ # 利用可能なモデル一覧を取得
21
+ model_keys = ModelRegistry.list_models()
22
+ configs = ModelRegistry.get_all_configs()
23
+
24
+ # 表示名とキーのマッピング
25
+ display_names = {key: configs[key].name for key in model_keys}
26
+
27
+ # セッション状態の初期化
28
+ if "selected_model" not in st.session_state:
29
+ st.session_state.selected_model = DEFAULT_MODEL_KEY
30
+
31
+ # モデル選択ボックス
32
+ selected_name = st.selectbox(
33
+ "MODEL",
34
+ options=[display_names[key] for key in model_keys],
35
+ index=model_keys.index(st.session_state.selected_model),
36
+ key="model_selectbox",
37
+ label_visibility="collapsed",
38
+ )
39
+
40
+ # 選択された表示名からキーを逆引き
41
+ selected_key = next(
42
+ key for key, name in display_names.items() if name == selected_name
43
+ )
44
+ st.session_state.selected_model = selected_key
45
+
46
+ # モデル情報を表示
47
+ config = configs[selected_key]
48
+ st.markdown(
49
+ f'<p class="model-info">{config.embedding_dim} dim / {config.vocab_size:,} tokens</p>',
50
+ unsafe_allow_html=True,
51
+ )
52
+
53
+ return selected_key
src/ui/pages/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """UI pages for WILL."""
2
+ from .generate import render_generate_page
3
+ from .concept import render_concept_page
4
+
5
+ __all__ = ["render_generate_page", "render_concept_page"]
src/ui/pages/concept.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ コンセプトページ
3
+
4
+ WILLプロジェクトの概念説明を提供する
5
+ """
6
+ import streamlit as st
7
+
8
+ from ...models.registry import ModelRegistry
9
+
10
+
11
+ def render_concept_page() -> None:
12
+ """コンセプトページをレンダリング"""
13
+ st.markdown('<p class="title">CONCEPT</p>', unsafe_allow_html=True)
14
+ st.markdown('<p class="subtitle">DOCUMENTATION</p>', unsafe_allow_html=True)
15
+
16
+ _render_concept_section()
17
+ _render_process_section()
18
+ _render_specification_section()
19
+
20
+
21
+ def _render_concept_section() -> None:
22
+ """コンセプト説明セクション"""
23
+ st.markdown(
24
+ '''
25
+ <div class="section">
26
+ <p class="section-title">CONCEPT</p>
27
+ <p style="text-align: center; color: #666; line-height: 2.2;">
28
+ GPT-2は人間が書いたテキストで訓練され<br>
29
+ その重みに言語パターンを保持している<br><br>
30
+ 通常はプロンプトに対して応答を生成するが<br>
31
+ 入力をランダムノイズに置き換え<br>
32
+ 出力にもノイズを加えることで<br>
33
+ 学習済みの統計的偏りを破壊する<br><br>
34
+ 人間の問いかけなしに<br>
35
+ モデルの構造だけが出力するものを観測する
36
+ </p>
37
+ </div>
38
+ <hr class="divider">
39
+ ''',
40
+ unsafe_allow_html=True,
41
+ )
42
+
43
+
44
+ def _render_process_section() -> None:
45
+ """プロセス説明セクション"""
46
+ st.markdown(
47
+ '''
48
+ <div class="section">
49
+ <p class="section-title">PROCESS</p>
50
+ </div>
51
+ ''',
52
+ unsafe_allow_html=True,
53
+ )
54
+
55
+ # Step 1: ENTROPY SEED
56
+ st.markdown(
57
+ '<p style="text-align: center; color: #333; font-size: 0.65rem; '
58
+ 'letter-spacing: 0.15em; margin-bottom: 0.5rem;">01 — ENTROPY SEED</p>',
59
+ unsafe_allow_html=True,
60
+ )
61
+ st.code("seed = time.time_ns()\ntorch.manual_seed(seed)", language="python")
62
+ st.markdown(
63
+ '<p style="text-align: center; font-size: 0.7rem; color: #444;">'
64
+ "実行瞬間のナノ秒を乱数シードとして採取</p>",
65
+ unsafe_allow_html=True,
66
+ )
67
+
68
+ st.markdown("<br>", unsafe_allow_html=True)
69
+
70
+ # Step 2: INPUT NOISE
71
+ st.markdown(
72
+ '<p style="text-align: center; color: #333; font-size: 0.65rem; '
73
+ 'letter-spacing: 0.15em; margin-bottom: 0.5rem;">02 — INPUT NOISE</p>',
74
+ unsafe_allow_html=True,
75
+ )
76
+ st.code(
77
+ "noise = torch.randn(1, 32, 768)\noutputs = model(inputs_embeds=noise)",
78
+ language="python",
79
+ )
80
+ st.markdown(
81
+ '<p style="text-align: center; font-size: 0.7rem; color: #444;">'
82
+ "768次元ランダムノイズをEmbedding層に直接注入</p>",
83
+ unsafe_allow_html=True,
84
+ )
85
+
86
+ st.markdown("<br>", unsafe_allow_html=True)
87
+
88
+ # Step 3: OUTPUT NOISE
89
+ st.markdown(
90
+ '<p style="text-align: center; color: #333; font-size: 0.65rem; '
91
+ 'letter-spacing: 0.15em; margin-bottom: 0.5rem;">03 — OUTPUT NOISE</p>',
92
+ unsafe_allow_html=True,
93
+ )
94
+ st.code(
95
+ "logits_noise = torch.randn_like(logits) * logits.std() * 10\n"
96
+ "corrupted_logits = logits + logits_noise",
97
+ language="python",
98
+ )
99
+ st.markdown(
100
+ '<p style="text-align: center; font-size: 0.7rem; color: #444;">'
101
+ "出力Logitsにノイズを加算し学習バイアスを破壊</p>",
102
+ unsafe_allow_html=True,
103
+ )
104
+
105
+ st.markdown("<br>", unsafe_allow_html=True)
106
+
107
+ # Step 4: RAW DECODE
108
+ st.markdown(
109
+ '<p style="text-align: center; color: #333; font-size: 0.65rem; '
110
+ 'letter-spacing: 0.15em; margin-bottom: 0.5rem;">04 — RAW DECODE</p>',
111
+ unsafe_allow_html=True,
112
+ )
113
+ st.code(
114
+ "indices = corrupted_logits.argmax(dim=-1)\n"
115
+ "debris = [tokenizer.decode([i]) for i in indices]",
116
+ language="python",
117
+ )
118
+ st.markdown(
119
+ '<p style="text-align: center; font-size: 0.7rem; color: #444;">'
120
+ "Softmax・Temperature なしで生トークンを抽出</p>",
121
+ unsafe_allow_html=True,
122
+ )
123
+
124
+
125
+ def _render_specification_section() -> None:
126
+ """仕様セクション"""
127
+ # 利用可能なモデル一覧を取得して動的に表示
128
+ configs = ModelRegistry.get_all_configs()
129
+ model_list = "<br>".join(
130
+ [f"{cfg.name} ({cfg.embedding_dim} dim)" for cfg in configs.values()]
131
+ )
132
+
133
+ st.markdown(
134
+ f'''
135
+ <hr class="divider">
136
+ <div class="section">
137
+ <p class="section-title">SPECIFICATION</p>
138
+ <table class="spec-table">
139
+ <tr><td>Models</td><td>GPT-2 / GPT-Neo / OPT</td></tr>
140
+ <tr><td>Parameters</td><td>125M - 350M</td></tr>
141
+ <tr><td>Embedding</td><td>768 - 1024 dim</td></tr>
142
+ <tr><td>Vocabulary</td><td>50,257+ tokens</td></tr>
143
+ <tr><td>Sequence</td><td>32 tokens</td></tr>
144
+ <tr><td>Input Noise</td><td>N(0, 1)</td></tr>
145
+ <tr><td>Logits Noise</td><td>N(0, σ×10)</td></tr>
146
+ <tr><td>Decoding</td><td>argmax</td></tr>
147
+ </table>
148
+ </div>
149
+ ''',
150
+ unsafe_allow_html=True,
151
+ )
src/ui/pages/generate.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 生成ページ
3
+
4
+ デブリ生成のメインUIを提供する
5
+ """
6
+ import streamlit as st
7
+
8
+ from ...models.registry import ModelRegistry
9
+ from ...generators.debris_generator import DebrisGenerator
10
+ from ...visualizers.signal_visualizer import SignalVisualizer
11
+ from ..components import render_model_selector
12
+
13
+
14
+ # モデルキャッシュ用のキー
15
+ _MODEL_CACHE_KEY = "_cached_model"
16
+ _GENERATOR_CACHE_KEY = "_cached_generator"
17
+
18
+
19
+ @st.cache_resource(show_spinner=False)
20
+ def _get_model(model_key: str):
21
+ """モデルをキャッシュして取得"""
22
+ model = ModelRegistry.get(model_key)
23
+ model.load()
24
+ return model
25
+
26
+
27
+ def render_generate_page() -> None:
28
+ """生成ページをレンダリング"""
29
+ # タイトル
30
+ st.markdown('<p class="title">WILL</p>', unsafe_allow_html=True)
31
+ st.markdown(
32
+ '<p class="subtitle">PURE COMPUTATIONAL WILL</p>', unsafe_allow_html=True
33
+ )
34
+
35
+ # モデル選択UI
36
+ col1, col2, col3 = st.columns([1, 2, 1])
37
+ with col2:
38
+ selected_model_key = render_model_selector()
39
+
40
+ # セッション状態の初期化
41
+ if "debris" not in st.session_state:
42
+ st.session_state.debris = None
43
+ st.session_state.seed = None
44
+ st.session_state.signal_img = None
45
+
46
+ # LISTENボタン
47
+ col1, col2, col3 = st.columns([1, 1, 1])
48
+ with col2:
49
+ clicked = st.button("LISTEN", key="listen_btn", use_container_width=True)
50
+
51
+ if clicked:
52
+ # モデルとジェネレータの取得
53
+ model = _get_model(selected_model_key)
54
+ generator = DebrisGenerator(model)
55
+ visualizer = SignalVisualizer()
56
+
57
+ # デブリ生成
58
+ result = generator.generate()
59
+
60
+ # 結果をセッション状態に保存
61
+ st.session_state.debris = result.debris
62
+ st.session_state.seed = result.seed
63
+ st.session_state.signal_img = visualizer.generate_image(
64
+ result.noise, result.corrupted_logits
65
+ )
66
+
67
+ # 結果の表示
68
+ if st.session_state.debris:
69
+ st.markdown(
70
+ f'''
71
+ <div class="debris-container">
72
+ <img class="signal-img" src="data:image/png;base64,{st.session_state.signal_img}">
73
+ <div class="debris">{" ".join(st.session_state.debris)}</div>
74
+ </div>
75
+ <p class="seed">{st.session_state.seed}</p>
76
+ ''',
77
+ unsafe_allow_html=True,
78
+ )
src/ui/styles.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ カスタムCSS定義
3
+
4
+ Streamlitアプリケーションのスタイルを定義する
5
+ """
6
+
7
+ CUSTOM_CSS = """
8
+ <style>
9
+ @import url('https://fonts.googleapis.com/css2?family=IBM+Plex+Mono:wght@300;400&display=swap');
10
+
11
+ @keyframes emerge {
12
+ from { opacity: 0; transform: translateY(8px); }
13
+ to { opacity: 1; transform: translateY(0); }
14
+ }
15
+ @keyframes breathe {
16
+ 0%, 100% { opacity: 0.4; }
17
+ 50% { opacity: 0.7; }
18
+ }
19
+
20
+ html, body, [class*="css"] {
21
+ font-family: 'IBM Plex Mono', monospace;
22
+ }
23
+ .stApp {
24
+ background-color: #0a0a0a;
25
+ color: #e0e0e0;
26
+ }
27
+ .block-container {
28
+ padding-top: 4rem;
29
+ padding-bottom: 4rem;
30
+ max-width: 640px;
31
+ }
32
+ h1, h2, h3 {
33
+ font-weight: 300;
34
+ letter-spacing: 0.1em;
35
+ text-align: center;
36
+ color: #e0e0e0;
37
+ }
38
+ p, li {
39
+ font-weight: 300;
40
+ line-height: 1.8;
41
+ color: #888;
42
+ }
43
+ .title {
44
+ font-size: 2rem;
45
+ font-weight: 300;
46
+ letter-spacing: 0.3em;
47
+ text-align: center;
48
+ margin-bottom: 0.5rem;
49
+ color: #e0e0e0;
50
+ }
51
+ .subtitle {
52
+ font-size: 0.75rem;
53
+ letter-spacing: 0.2em;
54
+ text-align: center;
55
+ color: #555;
56
+ margin-bottom: 3rem;
57
+ }
58
+ .debris-container {
59
+ background: linear-gradient(135deg, #0f0f0f 0%, #141414 100%);
60
+ border: 1px solid #222;
61
+ border-radius: 2px;
62
+ padding: 2rem;
63
+ margin: 2rem auto;
64
+ max-width: 100%;
65
+ text-align: center;
66
+ animation: emerge 0.6s ease-out;
67
+ }
68
+ .signal-img {
69
+ width: 100%;
70
+ max-width: 480px;
71
+ margin: 0 auto 1.5rem auto;
72
+ display: block;
73
+ opacity: 0.7;
74
+ }
75
+ .debris {
76
+ font-family: 'IBM Plex Mono', monospace;
77
+ font-size: 0.85rem;
78
+ font-weight: 400;
79
+ color: #e0e0e0;
80
+ line-height: 2;
81
+ word-spacing: 0.3em;
82
+ letter-spacing: 0.01em;
83
+ }
84
+ .seed {
85
+ font-size: 0.6rem;
86
+ color: #333;
87
+ text-align: center;
88
+ margin-top: 1.5rem;
89
+ letter-spacing: 0.15em;
90
+ animation: emerge 0.8s ease-out;
91
+ }
92
+ [data-testid="stButton"] > button {
93
+ background: transparent !important;
94
+ border: 1px solid #333 !important;
95
+ border-radius: 2px !important;
96
+ color: #888 !important;
97
+ font-family: 'IBM Plex Mono', monospace !important;
98
+ font-size: 0.7rem !important;
99
+ font-weight: 300 !important;
100
+ letter-spacing: 0.25em !important;
101
+ padding: 1rem 2rem !important;
102
+ transition: all 0.4s ease !important;
103
+ cursor: pointer !important;
104
+ }
105
+ [data-testid="stButton"] > button:hover {
106
+ background: transparent !important;
107
+ color: #e0e0e0 !important;
108
+ border-color: #555 !important;
109
+ }
110
+ [data-testid="stButton"] > button:active {
111
+ transform: scale(0.98) !important;
112
+ }
113
+ .stTabs [data-baseweb="tab-list"] {
114
+ justify-content: center;
115
+ gap: 2rem;
116
+ border-bottom: 1px solid #1a1a1a;
117
+ background: transparent;
118
+ }
119
+ .stTabs [data-baseweb="tab"] {
120
+ font-family: 'IBM Plex Mono', monospace;
121
+ font-size: 0.65rem;
122
+ font-weight: 300;
123
+ letter-spacing: 0.2em;
124
+ color: #444;
125
+ padding: 1rem 0;
126
+ background: transparent;
127
+ transition: color 0.3s ease;
128
+ }
129
+ .stTabs [aria-selected="true"] {
130
+ color: #888;
131
+ background: transparent;
132
+ }
133
+ .stTabs [data-baseweb="tab-highlight"] {
134
+ background-color: #444;
135
+ }
136
+ .divider {
137
+ border: none;
138
+ border-top: 1px solid #1a1a1a;
139
+ margin: 3rem 0;
140
+ }
141
+ .section {
142
+ margin: 2.5rem 0;
143
+ }
144
+ .section-title {
145
+ font-size: 0.65rem;
146
+ letter-spacing: 0.25em;
147
+ color: #444;
148
+ text-align: center;
149
+ margin-bottom: 1.5rem;
150
+ }
151
+ .spec-table {
152
+ width: 100%;
153
+ max-width: 320px;
154
+ margin: 0 auto;
155
+ font-size: 0.7rem;
156
+ border-collapse: collapse;
157
+ color: #777;
158
+ }
159
+ .spec-table td {
160
+ padding: 0.75rem 1rem;
161
+ border-bottom: 1px solid #151515;
162
+ }
163
+ .spec-table td:first-child {
164
+ color: #444;
165
+ text-align: right;
166
+ padding-right: 2rem;
167
+ }
168
+ .spec-table td:last-child {
169
+ text-align: left;
170
+ }
171
+ pre {
172
+ background-color: #0f0f0f !important;
173
+ border: 1px solid #1a1a1a !important;
174
+ border-radius: 2px !important;
175
+ }
176
+ code {
177
+ color: #666 !important;
178
+ font-size: 0.7rem !important;
179
+ }
180
+ /* Model selector styling */
181
+ .stSelectbox > div > div {
182
+ background-color: #0f0f0f !important;
183
+ border: 1px solid #222 !important;
184
+ border-radius: 2px !important;
185
+ color: #888 !important;
186
+ font-size: 0.7rem !important;
187
+ }
188
+ .stSelectbox > div > div:hover {
189
+ border-color: #333 !important;
190
+ }
191
+ .model-info {
192
+ font-size: 0.6rem;
193
+ color: #444;
194
+ text-align: center;
195
+ margin-top: 0.5rem;
196
+ letter-spacing: 0.1em;
197
+ }
198
+ </style>
199
+ """
src/visualizers/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ """Signal visualization for WILL."""
2
+ from .signal_visualizer import SignalVisualizer
3
+
4
+ __all__ = ["SignalVisualizer"]
src/visualizers/signal_visualizer.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ シグナル可視化
3
+
4
+ ノイズとlogitsの可視化画像を生成する
5
+ 単一責任原則(SRP)に従い、可視化ロジックのみを担当
6
+ """
7
+ import base64
8
+ import io
9
+ from typing import Optional
10
+
11
+ import matplotlib.pyplot as plt
12
+ import torch
13
+
14
+
15
+ class SignalVisualizer:
16
+ """
17
+ シグナル可視化クラス
18
+
19
+ 入力ノイズとlogitsをグレースケール画像として可視化する
20
+ """
21
+
22
+ # デフォルトの可視化設定
23
+ DEFAULT_FIG_WIDTH = 6
24
+ DEFAULT_FIG_HEIGHT = 2
25
+ DEFAULT_DPI = 150
26
+ DEFAULT_BG_COLOR = "#0f0f0f"
27
+
28
+ # ノイズ表示の次元数
29
+ NOISE_DISPLAY_DIM = 64
30
+
31
+ # logitsサンプリング間隔
32
+ LOGITS_SAMPLE_STEP = 200
33
+
34
+ def __init__(
35
+ self,
36
+ fig_width: float = DEFAULT_FIG_WIDTH,
37
+ fig_height: float = DEFAULT_FIG_HEIGHT,
38
+ dpi: int = DEFAULT_DPI,
39
+ bg_color: str = DEFAULT_BG_COLOR,
40
+ ):
41
+ """
42
+ Args:
43
+ fig_width: 図の幅
44
+ fig_height: 図の高さ
45
+ dpi: 解像度
46
+ bg_color: 背景色
47
+ """
48
+ self._fig_width = fig_width
49
+ self._fig_height = fig_height
50
+ self._dpi = dpi
51
+ self._bg_color = bg_color
52
+
53
+ def generate_image(
54
+ self,
55
+ noise: torch.Tensor,
56
+ logits: torch.Tensor,
57
+ ) -> str:
58
+ """
59
+ ノイズとlogitsの可視化画像をBase64エンコードで生成
60
+
61
+ Args:
62
+ noise: 入力ノイズテンソル [batch, seq_len, embedding_dim]
63
+ logits: logitsテンソル [batch, seq_len, vocab_size]
64
+
65
+ Returns:
66
+ Base64エンコードされたPNG画像文字列
67
+ """
68
+ fig, axes = plt.subplots(
69
+ 2,
70
+ 1,
71
+ figsize=(self._fig_width, self._fig_height),
72
+ facecolor=self._bg_color,
73
+ )
74
+ plt.subplots_adjust(
75
+ hspace=0.15,
76
+ left=0.02,
77
+ right=0.98,
78
+ top=0.95,
79
+ bottom=0.05,
80
+ )
81
+
82
+ # 上段: 入力ノイズの可視化
83
+ self._render_noise(axes[0], noise)
84
+
85
+ # 下段: logitsの可視化
86
+ self._render_logits(axes[1], logits)
87
+
88
+ # PNG画像としてバッファに保存
89
+ buf = io.BytesIO()
90
+ plt.savefig(
91
+ buf,
92
+ format="png",
93
+ facecolor=self._bg_color,
94
+ edgecolor="none",
95
+ dpi=self._dpi,
96
+ bbox_inches="tight",
97
+ pad_inches=0.05,
98
+ )
99
+ plt.close(fig)
100
+
101
+ buf.seek(0)
102
+ return base64.b64encode(buf.read()).decode()
103
+
104
+ def _render_noise(self, ax: plt.Axes, noise: torch.Tensor) -> None:
105
+ """入力ノイズを描画"""
106
+ # 最初のbatchから、embedding_dimの最初のNOISE_DISPLAY_DIM次元を抽出
107
+ noise_flat = noise[0, :, : self.NOISE_DISPLAY_DIM].numpy()
108
+
109
+ ax.imshow(
110
+ noise_flat.T,
111
+ aspect="auto",
112
+ cmap="gray",
113
+ interpolation="bilinear",
114
+ vmin=-2,
115
+ vmax=2,
116
+ )
117
+ self._style_axis(ax)
118
+
119
+ def _render_logits(self, ax: plt.Axes, logits: torch.Tensor) -> None:
120
+ """logitsを描画"""
121
+ # vocab次元をサンプリングして表示
122
+ logits_sample = logits[0, :, :: self.LOGITS_SAMPLE_STEP].numpy()
123
+
124
+ ax.imshow(
125
+ logits_sample.T,
126
+ aspect="auto",
127
+ cmap="gray",
128
+ interpolation="bilinear",
129
+ )
130
+ self._style_axis(ax)
131
+
132
+ def _style_axis(self, ax: plt.Axes) -> None:
133
+ """軸のスタイルを設定"""
134
+ ax.set_xticks([])
135
+ ax.set_yticks([])
136
+ ax.set_facecolor(self._bg_color)
137
+ for spine in ax.spines.values():
138
+ spine.set_visible(False)
tests/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Tests for WILL."""
tests/test_generators.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ジェネレータ関連のテスト
3
+ """
4
+ import pytest
5
+ import torch
6
+
7
+ from src.generators.debris_generator import DebrisGenerator, DebrisResult
8
+ from src.models.gpt2 import GPT2Model, GPT2_SMALL_CONFIG
9
+
10
+
11
+ class TestDebrisResult:
12
+ """DebrisResultのテスト"""
13
+
14
+ def test_result_attributes(self):
15
+ """結果属性が正しく保持されることを確認"""
16
+ result = DebrisResult(
17
+ debris=["hello", "world"],
18
+ seed=12345,
19
+ noise=torch.randn(1, 32, 768),
20
+ logits=torch.randn(1, 32, 50257),
21
+ corrupted_logits=torch.randn(1, 32, 50257),
22
+ )
23
+
24
+ assert result.debris == ["hello", "world"]
25
+ assert result.seed == 12345
26
+ assert result.noise.shape == (1, 32, 768)
27
+
28
+
29
+ class TestDebrisGenerator:
30
+ """DebrisGeneratorのテスト"""
31
+
32
+ @pytest.fixture
33
+ def generator(self):
34
+ """ジェネレータインスタンスを提供"""
35
+ model = GPT2Model(GPT2_SMALL_CONFIG)
36
+ return DebrisGenerator(model)
37
+
38
+ def test_model_property(self, generator):
39
+ """モデルプロパティが正しいことを確認"""
40
+ assert generator.model is not None
41
+ assert generator.model.config == GPT2_SMALL_CONFIG
42
+
43
+
44
+ @pytest.mark.slow
45
+ class TestDebrisGeneratorIntegration:
46
+ """DebrisGeneratorの統合テスト"""
47
+
48
+ @pytest.fixture
49
+ def generator(self):
50
+ """ロード済みジェネレータを提供"""
51
+ model = GPT2Model(GPT2_SMALL_CONFIG)
52
+ model.load()
53
+ return DebrisGenerator(model)
54
+
55
+ def test_generate_with_seed(self, generator):
56
+ """シード指定で生成できることを確認"""
57
+ result = generator.generate(seed=42, seq_len=8)
58
+
59
+ assert isinstance(result, DebrisResult)
60
+ assert result.seed == 42
61
+ assert len(result.debris) == 8
62
+
63
+ def test_generate_reproducible(self, generator):
64
+ """同じシードで同じ結果が得られることを確認"""
65
+ result1 = generator.generate(seed=12345, seq_len=8)
66
+ result2 = generator.generate(seed=12345, seq_len=8)
67
+
68
+ assert result1.debris == result2.debris
69
+
70
+ def test_generate_different_seeds(self, generator):
71
+ """異なるシードで異なる結果が得られることを確認"""
72
+ result1 = generator.generate(seed=11111, seq_len=8)
73
+ result2 = generator.generate(seed=22222, seq_len=8)
74
+
75
+ # 完全一致する確率は極めて低い
76
+ assert result1.debris != result2.debris
tests/test_models.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ モデル関連のテスト
3
+ """
4
+ import pytest
5
+ import torch
6
+
7
+ from src.models.base import ModelConfig, BaseLanguageModel
8
+ from src.models.registry import ModelRegistry, DEFAULT_MODEL_KEY
9
+ from src.models.gpt2 import GPT2Model, GPT2_SMALL_CONFIG
10
+
11
+
12
+ class TestModelConfig:
13
+ """ModelConfigのテスト"""
14
+
15
+ def test_config_is_immutable(self):
16
+ """設定がイミュータブルであることを確認"""
17
+ config = ModelConfig(
18
+ name="Test",
19
+ model_id="test",
20
+ embedding_dim=768,
21
+ vocab_size=50000,
22
+ )
23
+ with pytest.raises(Exception):
24
+ config.name = "Changed"
25
+
26
+ def test_config_attributes(self):
27
+ """設定属性が正しく保持されることを確認"""
28
+ config = ModelConfig(
29
+ name="Test Model",
30
+ model_id="test-model",
31
+ embedding_dim=1024,
32
+ vocab_size=30000,
33
+ )
34
+ assert config.name == "Test Model"
35
+ assert config.model_id == "test-model"
36
+ assert config.embedding_dim == 1024
37
+ assert config.vocab_size == 30000
38
+
39
+
40
+ class TestModelRegistry:
41
+ """ModelRegistryのテスト"""
42
+
43
+ def test_list_models(self):
44
+ """登録済みモデル一覧が取得できることを確認"""
45
+ models = ModelRegistry.list_models()
46
+ assert len(models) > 0
47
+ assert DEFAULT_MODEL_KEY in models
48
+
49
+ def test_get_model(self):
50
+ """モデルインスタンスが取得できることを確認"""
51
+ model = ModelRegistry.get(DEFAULT_MODEL_KEY)
52
+ assert isinstance(model, BaseLanguageModel)
53
+
54
+ def test_get_nonexistent_model(self):
55
+ """存在しないモデルでKeyErrorが発生することを確認"""
56
+ with pytest.raises(KeyError):
57
+ ModelRegistry.get("nonexistent-model")
58
+
59
+ def test_get_config(self):
60
+ """モデル設定が取得できることを確認"""
61
+ config = ModelRegistry.get_config(DEFAULT_MODEL_KEY)
62
+ assert config is not None
63
+ assert isinstance(config, ModelConfig)
64
+
65
+ def test_get_all_configs(self):
66
+ """すべてのモデル設定が取得できることを確認"""
67
+ configs = ModelRegistry.get_all_configs()
68
+ assert len(configs) > 0
69
+ for key, config in configs.items():
70
+ assert isinstance(config, ModelConfig)
71
+
72
+
73
+ class TestGPT2Model:
74
+ """GPT2Modelのテスト"""
75
+
76
+ def test_config(self):
77
+ """設定が正しいことを確認"""
78
+ model = GPT2Model(GPT2_SMALL_CONFIG)
79
+ assert model.config == GPT2_SMALL_CONFIG
80
+ assert model.config.embedding_dim == 768
81
+
82
+ def test_is_loaded_initial(self):
83
+ """初期状態ではロードされていないことを確認"""
84
+ model = GPT2Model(GPT2_SMALL_CONFIG)
85
+ assert not model.is_loaded
86
+
87
+ def test_generate_noise(self):
88
+ """ノイズ生成が正しい形状であることを確認"""
89
+ model = GPT2Model(GPT2_SMALL_CONFIG)
90
+ noise = model.generate_noise(seq_len=16, batch_size=2)
91
+ assert noise.shape == (2, 16, 768)
92
+
93
+
94
+ @pytest.mark.slow
95
+ class TestGPT2ModelIntegration:
96
+ """GPT2Modelの統合テスト(モデルロードが必要)"""
97
+
98
+ @pytest.fixture
99
+ def loaded_model(self):
100
+ """ロード済みモデルを提供"""
101
+ model = GPT2Model(GPT2_SMALL_CONFIG)
102
+ model.load()
103
+ return model
104
+
105
+ def test_load(self, loaded_model):
106
+ """モデルがロードできることを確認"""
107
+ assert loaded_model.is_loaded
108
+
109
+ def test_forward_with_noise(self, loaded_model):
110
+ """順伝播が正しい形状を返すことを確認"""
111
+ noise = loaded_model.generate_noise(seq_len=8)
112
+ logits, corrupted_logits = loaded_model.forward_with_noise(noise)
113
+
114
+ assert logits.shape[0] == 1
115
+ assert logits.shape[1] == 8
116
+ assert logits.shape[2] == loaded_model.config.vocab_size
117
+
118
+ def test_decode_indices(self, loaded_model):
119
+ """デコードが文字列リストを返すことを確認"""
120
+ indices = [100, 200, 300]
121
+ decoded = loaded_model.decode_indices(indices)
122
+
123
+ assert len(decoded) == 3
124
+ assert all(isinstance(s, str) for s in decoded)