|
|
""" |
|
|
シグナル可視化 |
|
|
|
|
|
ノイズとlogitsの可視化画像を生成する |
|
|
単一責任原則(SRP)に従い、可視化ロジックのみを担当 |
|
|
""" |
|
|
import base64 |
|
|
import io |
|
|
from typing import Optional |
|
|
|
|
|
import matplotlib.pyplot as plt |
|
|
import torch |
|
|
|
|
|
|
|
|
class SignalVisualizer:
    """Signal visualizer.

    Renders the input noise and the logits as two stacked grayscale images
    and returns the figure as a Base64-encoded PNG string.
    """

    # Default figure settings (width/height in inches, as passed to figsize).
    DEFAULT_FIG_WIDTH = 6
    DEFAULT_FIG_HEIGHT = 2
    DEFAULT_DPI = 150
    DEFAULT_BG_COLOR = "#0f0f0f"

    # Number of leading entries along the last axis of the noise that are drawn.
    NOISE_DISPLAY_DIM = 64

    # Stride used to subsample the last axis of the logits so the image stays
    # small (every LOGITS_SAMPLE_STEP-th column is kept).
    LOGITS_SAMPLE_STEP = 200

    # Fixed color range for the noise image; imshow clips values outside it.
    NOISE_VMIN = -2
    NOISE_VMAX = 2

    def __init__(
        self,
        fig_width: float = DEFAULT_FIG_WIDTH,
        fig_height: float = DEFAULT_FIG_HEIGHT,
        dpi: int = DEFAULT_DPI,
        bg_color: str = DEFAULT_BG_COLOR,
    ):
        """
        Args:
            fig_width: Figure width (passed as the first element of figsize).
            fig_height: Figure height (passed as the second element of figsize).
            dpi: Resolution used when saving the PNG.
            bg_color: Background color for the figure, axes, and saved image.
        """
        self._fig_width = fig_width
        self._fig_height = fig_height
        self._dpi = dpi
        self._bg_color = bg_color

    def generate_image(
        self,
        noise: torch.Tensor,
        logits: torch.Tensor,
    ) -> str:
        """Render the noise and logits and return a Base64-encoded PNG.

        Only the first batch element of each tensor is drawn.

        Args:
            noise: Input noise tensor [batch, seq_len, embedding_dim].
            logits: Logits tensor [batch, seq_len, vocab_size].

        Returns:
            Base64-encoded PNG image string.
        """
        fig, axes = plt.subplots(
            2,
            1,
            figsize=(self._fig_width, self._fig_height),
            facecolor=self._bg_color,
        )
        try:
            # Operate on this figure explicitly instead of pyplot's global
            # "current figure", which another figure could have replaced.
            fig.subplots_adjust(
                hspace=0.15,
                left=0.02,
                right=0.98,
                top=0.95,
                bottom=0.05,
            )

            self._render_noise(axes[0], noise)
            self._render_logits(axes[1], logits)

            buf = io.BytesIO()
            fig.savefig(
                buf,
                format="png",
                facecolor=self._bg_color,
                edgecolor="none",
                dpi=self._dpi,
                bbox_inches="tight",
                pad_inches=0.05,
            )
        finally:
            # Always release the figure, even if rendering fails, so it is not
            # left registered in pyplot's figure manager.
            plt.close(fig)

        return base64.b64encode(buf.getvalue()).decode()

    @staticmethod
    def _to_numpy(tensor: torch.Tensor):
        """Convert a tensor to a NumPy array that matplotlib can draw.

        ``detach`` drops autograd tracking (``Tensor.numpy`` raises on tensors
        that require grad), ``cpu`` copies device tensors to host memory, and
        ``float`` upcasts dtypes NumPy cannot represent (e.g. ``bfloat16``).
        """
        return tensor.detach().cpu().float().numpy()

    def _render_noise(self, ax: plt.Axes, noise: torch.Tensor) -> None:
        """Draw the input noise into ``ax``."""
        # Slice before converting so only the displayed part is transferred.
        noise_flat = self._to_numpy(noise[0, :, : self.NOISE_DISPLAY_DIM])

        # Transpose so the second axis of the slice (seq_len) runs along x.
        ax.imshow(
            noise_flat.T,
            aspect="auto",
            cmap="gray",
            interpolation="bilinear",
            vmin=self.NOISE_VMIN,
            vmax=self.NOISE_VMAX,
        )
        self._style_axis(ax)

    def _render_logits(self, ax: plt.Axes, logits: torch.Tensor) -> None:
        """Draw the subsampled logits into ``ax``."""
        # Slice before converting so only the displayed part is transferred.
        logits_sample = self._to_numpy(logits[0, :, :: self.LOGITS_SAMPLE_STEP])

        # No vmin/vmax: imshow scales the colormap to this slice's data range.
        ax.imshow(
            logits_sample.T,
            aspect="auto",
            cmap="gray",
            interpolation="bilinear",
        )
        self._style_axis(ax)

    def _style_axis(self, ax: plt.Axes) -> None:
        """Hide ticks and spines and apply the background color to ``ax``."""
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_facecolor(self._bg_color)
        for spine in ax.spines.values():
            spine.set_visible(False)
|
|
|