File size: 3,720 Bytes
d1033d4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
"""
シグナル可視化
ノイズとlogitsの可視化画像を生成する
単一責任原則(SRP)に従い、可視化ロジックのみを担当
"""
import base64
import io
from typing import Optional
import matplotlib.pyplot as plt
import torch
class SignalVisualizer:
"""
シグナル可視化クラス
入力ノイズとlogitsをグレースケール画像として可視化する
"""
# デフォルトの可視化設定
DEFAULT_FIG_WIDTH = 6
DEFAULT_FIG_HEIGHT = 2
DEFAULT_DPI = 150
DEFAULT_BG_COLOR = "#0f0f0f"
# ノイズ表示の次元数
NOISE_DISPLAY_DIM = 64
# logitsサンプリング間隔
LOGITS_SAMPLE_STEP = 200
def __init__(
self,
fig_width: float = DEFAULT_FIG_WIDTH,
fig_height: float = DEFAULT_FIG_HEIGHT,
dpi: int = DEFAULT_DPI,
bg_color: str = DEFAULT_BG_COLOR,
):
"""
Args:
fig_width: 図の幅
fig_height: 図の高さ
dpi: 解像度
bg_color: 背景色
"""
self._fig_width = fig_width
self._fig_height = fig_height
self._dpi = dpi
self._bg_color = bg_color
def generate_image(
self,
noise: torch.Tensor,
logits: torch.Tensor,
) -> str:
"""
ノイズとlogitsの可視化画像をBase64エンコードで生成
Args:
noise: 入力ノイズテンソル [batch, seq_len, embedding_dim]
logits: logitsテンソル [batch, seq_len, vocab_size]
Returns:
Base64エンコードされたPNG画像文字列
"""
fig, axes = plt.subplots(
2,
1,
figsize=(self._fig_width, self._fig_height),
facecolor=self._bg_color,
)
plt.subplots_adjust(
hspace=0.15,
left=0.02,
right=0.98,
top=0.95,
bottom=0.05,
)
# 上段: 入力ノイズの可視化
self._render_noise(axes[0], noise)
# 下段: logitsの可視化
self._render_logits(axes[1], logits)
# PNG画像としてバッファに保存
buf = io.BytesIO()
plt.savefig(
buf,
format="png",
facecolor=self._bg_color,
edgecolor="none",
dpi=self._dpi,
bbox_inches="tight",
pad_inches=0.05,
)
plt.close(fig)
buf.seek(0)
return base64.b64encode(buf.read()).decode()
def _render_noise(self, ax: plt.Axes, noise: torch.Tensor) -> None:
"""入力ノイズを描画"""
# 最初のbatchから、embedding_dimの最初のNOISE_DISPLAY_DIM次元を抽出
noise_flat = noise[0, :, : self.NOISE_DISPLAY_DIM].numpy()
ax.imshow(
noise_flat.T,
aspect="auto",
cmap="gray",
interpolation="bilinear",
vmin=-2,
vmax=2,
)
self._style_axis(ax)
def _render_logits(self, ax: plt.Axes, logits: torch.Tensor) -> None:
"""logitsを描画"""
# vocab次元をサンプリングして表示
logits_sample = logits[0, :, :: self.LOGITS_SAMPLE_STEP].numpy()
ax.imshow(
logits_sample.T,
aspect="auto",
cmap="gray",
interpolation="bilinear",
)
self._style_axis(ax)
def _style_axis(self, ax: plt.Axes) -> None:
"""軸のスタイルを設定"""
ax.set_xticks([])
ax.set_yticks([])
ax.set_facecolor(self._bg_color)
for spine in ax.spines.values():
spine.set_visible(False)
|