""" シグナル可視化 ノイズとlogitsの可視化画像を生成する 単一責任原則(SRP)に従い、可視化ロジックのみを担当 """ import base64 import io from typing import Optional import matplotlib.pyplot as plt import torch class SignalVisualizer: """ シグナル可視化クラス 入力ノイズとlogitsをグレースケール画像として可視化する """ # デフォルトの可視化設定 DEFAULT_FIG_WIDTH = 6 DEFAULT_FIG_HEIGHT = 2 DEFAULT_DPI = 150 DEFAULT_BG_COLOR = "#0f0f0f" # ノイズ表示の次元数 NOISE_DISPLAY_DIM = 64 # logitsサンプリング間隔 LOGITS_SAMPLE_STEP = 200 def __init__( self, fig_width: float = DEFAULT_FIG_WIDTH, fig_height: float = DEFAULT_FIG_HEIGHT, dpi: int = DEFAULT_DPI, bg_color: str = DEFAULT_BG_COLOR, ): """ Args: fig_width: 図の幅 fig_height: 図の高さ dpi: 解像度 bg_color: 背景色 """ self._fig_width = fig_width self._fig_height = fig_height self._dpi = dpi self._bg_color = bg_color def generate_image( self, noise: torch.Tensor, logits: torch.Tensor, ) -> str: """ ノイズとlogitsの可視化画像をBase64エンコードで生成 Args: noise: 入力ノイズテンソル [batch, seq_len, embedding_dim] logits: logitsテンソル [batch, seq_len, vocab_size] Returns: Base64エンコードされたPNG画像文字列 """ fig, axes = plt.subplots( 2, 1, figsize=(self._fig_width, self._fig_height), facecolor=self._bg_color, ) plt.subplots_adjust( hspace=0.15, left=0.02, right=0.98, top=0.95, bottom=0.05, ) # 上段: 入力ノイズの可視化 self._render_noise(axes[0], noise) # 下段: logitsの可視化 self._render_logits(axes[1], logits) # PNG画像としてバッファに保存 buf = io.BytesIO() plt.savefig( buf, format="png", facecolor=self._bg_color, edgecolor="none", dpi=self._dpi, bbox_inches="tight", pad_inches=0.05, ) plt.close(fig) buf.seek(0) return base64.b64encode(buf.read()).decode() def _render_noise(self, ax: plt.Axes, noise: torch.Tensor) -> None: """入力ノイズを描画""" # 最初のbatchから、embedding_dimの最初のNOISE_DISPLAY_DIM次元を抽出 noise_flat = noise[0, :, : self.NOISE_DISPLAY_DIM].numpy() ax.imshow( noise_flat.T, aspect="auto", cmap="gray", interpolation="bilinear", vmin=-2, vmax=2, ) self._style_axis(ax) def _render_logits(self, ax: plt.Axes, logits: torch.Tensor) -> None: """logitsを描画""" # vocab次元をサンプリングして表示 logits_sample = logits[0, :, :: self.LOGITS_SAMPLE_STEP].numpy() ax.imshow( logits_sample.T, aspect="auto", cmap="gray", interpolation="bilinear", ) self._style_axis(ax) def _style_axis(self, ax: plt.Axes) -> None: """軸のスタイルを設定""" ax.set_xticks([]) ax.set_yticks([]) ax.set_facecolor(self._bg_color) for spine in ax.spines.values(): spine.set_visible(False)