File size: 3,720 Bytes
d1033d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
"""
シグナル可視化

ノイズとlogitsの可視化画像を生成する
単一責任原則(SRP)に従い、可視化ロジックのみを担当
"""
import base64
import io
from typing import Optional

import matplotlib.pyplot as plt
import torch


class SignalVisualizer:
    """
    シグナル可視化クラス

    入力ノイズとlogitsをグレースケール画像として可視化する
    """

    # デフォルトの可視化設定
    DEFAULT_FIG_WIDTH = 6
    DEFAULT_FIG_HEIGHT = 2
    DEFAULT_DPI = 150
    DEFAULT_BG_COLOR = "#0f0f0f"

    # ノイズ表示の次元数
    NOISE_DISPLAY_DIM = 64

    # logitsサンプリング間隔
    LOGITS_SAMPLE_STEP = 200

    def __init__(
        self,
        fig_width: float = DEFAULT_FIG_WIDTH,
        fig_height: float = DEFAULT_FIG_HEIGHT,
        dpi: int = DEFAULT_DPI,
        bg_color: str = DEFAULT_BG_COLOR,
    ):
        """
        Args:
            fig_width: 図の幅
            fig_height: 図の高さ
            dpi: 解像度
            bg_color: 背景色
        """
        self._fig_width = fig_width
        self._fig_height = fig_height
        self._dpi = dpi
        self._bg_color = bg_color

    def generate_image(
        self,
        noise: torch.Tensor,
        logits: torch.Tensor,
    ) -> str:
        """
        ノイズとlogitsの可視化画像をBase64エンコードで生成

        Args:
            noise: 入力ノイズテンソル [batch, seq_len, embedding_dim]
            logits: logitsテンソル [batch, seq_len, vocab_size]

        Returns:
            Base64エンコードされたPNG画像文字列
        """
        fig, axes = plt.subplots(
            2,
            1,
            figsize=(self._fig_width, self._fig_height),
            facecolor=self._bg_color,
        )
        plt.subplots_adjust(
            hspace=0.15,
            left=0.02,
            right=0.98,
            top=0.95,
            bottom=0.05,
        )

        # 上段: 入力ノイズの可視化
        self._render_noise(axes[0], noise)

        # 下段: logitsの可視化
        self._render_logits(axes[1], logits)

        # PNG画像としてバッファに保存
        buf = io.BytesIO()
        plt.savefig(
            buf,
            format="png",
            facecolor=self._bg_color,
            edgecolor="none",
            dpi=self._dpi,
            bbox_inches="tight",
            pad_inches=0.05,
        )
        plt.close(fig)

        buf.seek(0)
        return base64.b64encode(buf.read()).decode()

    def _render_noise(self, ax: plt.Axes, noise: torch.Tensor) -> None:
        """入力ノイズを描画"""
        # 最初のbatchから、embedding_dimの最初のNOISE_DISPLAY_DIM次元を抽出
        noise_flat = noise[0, :, : self.NOISE_DISPLAY_DIM].numpy()

        ax.imshow(
            noise_flat.T,
            aspect="auto",
            cmap="gray",
            interpolation="bilinear",
            vmin=-2,
            vmax=2,
        )
        self._style_axis(ax)

    def _render_logits(self, ax: plt.Axes, logits: torch.Tensor) -> None:
        """logitsを描画"""
        # vocab次元をサンプリングして表示
        logits_sample = logits[0, :, :: self.LOGITS_SAMPLE_STEP].numpy()

        ax.imshow(
            logits_sample.T,
            aspect="auto",
            cmap="gray",
            interpolation="bilinear",
        )
        self._style_axis(ax)

    def _style_axis(self, ax: plt.Axes) -> None:
        """軸のスタイルを設定"""
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_facecolor(self._bg_color)
        for spine in ax.spines.values():
            spine.set_visible(False)