will / src /visualizers /signal_visualizer.py
matt1847's picture
リファクタ: srcディレクトリ構造への移行とDocker対応
d1033d4
"""
シグナル可視化
ノイズとlogitsの可視化画像を生成する
単一責任原則(SRP)に従い、可視化ロジックのみを担当
"""
import base64
import io
from typing import Optional
import matplotlib.pyplot as plt
import torch
class SignalVisualizer:
"""
シグナル可視化クラス
入力ノイズとlogitsをグレースケール画像として可視化する
"""
# デフォルトの可視化設定
DEFAULT_FIG_WIDTH = 6
DEFAULT_FIG_HEIGHT = 2
DEFAULT_DPI = 150
DEFAULT_BG_COLOR = "#0f0f0f"
# ノイズ表示の次元数
NOISE_DISPLAY_DIM = 64
# logitsサンプリング間隔
LOGITS_SAMPLE_STEP = 200
def __init__(
self,
fig_width: float = DEFAULT_FIG_WIDTH,
fig_height: float = DEFAULT_FIG_HEIGHT,
dpi: int = DEFAULT_DPI,
bg_color: str = DEFAULT_BG_COLOR,
):
"""
Args:
fig_width: 図の幅
fig_height: 図の高さ
dpi: 解像度
bg_color: 背景色
"""
self._fig_width = fig_width
self._fig_height = fig_height
self._dpi = dpi
self._bg_color = bg_color
def generate_image(
self,
noise: torch.Tensor,
logits: torch.Tensor,
) -> str:
"""
ノイズとlogitsの可視化画像をBase64エンコードで生成
Args:
noise: 入力ノイズテンソル [batch, seq_len, embedding_dim]
logits: logitsテンソル [batch, seq_len, vocab_size]
Returns:
Base64エンコードされたPNG画像文字列
"""
fig, axes = plt.subplots(
2,
1,
figsize=(self._fig_width, self._fig_height),
facecolor=self._bg_color,
)
plt.subplots_adjust(
hspace=0.15,
left=0.02,
right=0.98,
top=0.95,
bottom=0.05,
)
# 上段: 入力ノイズの可視化
self._render_noise(axes[0], noise)
# 下段: logitsの可視化
self._render_logits(axes[1], logits)
# PNG画像としてバッファに保存
buf = io.BytesIO()
plt.savefig(
buf,
format="png",
facecolor=self._bg_color,
edgecolor="none",
dpi=self._dpi,
bbox_inches="tight",
pad_inches=0.05,
)
plt.close(fig)
buf.seek(0)
return base64.b64encode(buf.read()).decode()
def _render_noise(self, ax: plt.Axes, noise: torch.Tensor) -> None:
"""入力ノイズを描画"""
# 最初のbatchから、embedding_dimの最初のNOISE_DISPLAY_DIM次元を抽出
noise_flat = noise[0, :, : self.NOISE_DISPLAY_DIM].numpy()
ax.imshow(
noise_flat.T,
aspect="auto",
cmap="gray",
interpolation="bilinear",
vmin=-2,
vmax=2,
)
self._style_axis(ax)
def _render_logits(self, ax: plt.Axes, logits: torch.Tensor) -> None:
"""logitsを描画"""
# vocab次元をサンプリングして表示
logits_sample = logits[0, :, :: self.LOGITS_SAMPLE_STEP].numpy()
ax.imshow(
logits_sample.T,
aspect="auto",
cmap="gray",
interpolation="bilinear",
)
self._style_axis(ax)
def _style_axis(self, ax: plt.Axes) -> None:
"""軸のスタイルを設定"""
ax.set_xticks([])
ax.set_yticks([])
ax.set_facecolor(self._bg_color)
for spine in ax.spines.values():
spine.set_visible(False)