File size: 6,833 Bytes
e6f24ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
"""M11 — SNMF-Diff: Symmetric NMF on binary-probed activations.

Novel contribution: applies Symmetric Non-negative Matrix Factorisation
at rank r=2 to the combined [H_pos; H_neg] matrix. The non-negativity
constraint filters image-content activations from the VLM multimodal
residual stream.

Steering vector: v = C_s - C_not_s (style component minus noise component)
"""

import logging
from typing import Tuple

import numpy as np

from src.methods.base import SteeringMethod

logger = logging.getLogger(__name__)


def _snmf(
    X: np.ndarray,
    rank: int = 2,
    max_iter: int = 500,
    tol: float = 1e-6,
    seed: int = 42,
) -> np.ndarray:
    """Symmetric Non-negative Matrix Factorisation.

    Factorises X ≈ H @ H^T where H ∈ R^(n × rank), H ≥ 0.

    Uses multiplicative update rules:
        H_ij ← H_ij * sqrt((X @ H)_ij / (H @ H^T @ H)_ij)

    Args:
        X: (n, n) symmetric non-negative matrix (e.g. gram matrix)
        rank: number of components
        max_iter: maximum iterations
        tol: convergence tolerance (relative change in Frobenius norm)
        seed: random seed for initialisation

    Returns:
        H: (n, rank) non-negative factor matrix
    """
    rng = np.random.RandomState(seed)
    n = X.shape[0]

    # Initialise H randomly
    H = rng.rand(n, rank).astype(np.float64) + 1e-6

    prev_cost = np.inf
    for iteration in range(max_iter):
        # Numerator: X @ H
        numerator = X @ H

        # Denominator: H @ H^T @ H
        denominator = H @ (H.T @ H) + 1e-12

        # Multiplicative update
        H = H * np.sqrt(numerator / denominator)

        # Ensure non-negativity
        H = np.maximum(H, 1e-12)

        # Check convergence
        if iteration % 20 == 0:
            reconstruction = H @ H.T
            cost = np.linalg.norm(X - reconstruction, "fro")
            relative_change = abs(prev_cost - cost) / (prev_cost + 1e-12)
            if relative_change < tol:
                logger.debug(f"SNMF converged at iteration {iteration} (cost={cost:.4f})")
                break
            prev_cost = cost

    return H


class SNMFDiff(SteeringMethod):
    """SNMF-Diff — Novel training-free steering method.

    Applies Symmetric NMF to the gram matrix of the combined activation
    matrix [H_pos; H_neg] to extract a style component and a noise component.
    The steering vector is the difference between the centroids of these
    two components.
    """

    def __init__(self, rank: int = 2, max_iter: int = 500, **kwargs):
        # **kwargs absorbs unused config keys (e.g. from a generic method
        # factory); only rank and max_iter are consumed here.
        self.rank = rank
        self.max_iter = max_iter

    @property
    def name(self) -> str:
        return "SNMF-Diff"

    @property
    def method_id(self) -> str:
        return "M11"

    @staticmethod
    def _normalized_gram(h_pos: np.ndarray, h_neg: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Combine activations and build the shifted gram matrix for SNMF.

        Shared by extract_vector and get_component_similarity (previously
        duplicated inline in both).

        Args:
            h_pos: (N_pos, d) positive activations
            h_neg: (N_neg, d) negative activations

        Returns:
            (H, G): H is the (n, d) combined un-normalised matrix;
            G is the (n, n) row-normalised gram matrix, shifted to be
            strictly positive so SNMF's non-negativity assumption holds.
        """
        H = np.concatenate([h_pos, h_neg], axis=0).astype(np.float64)

        # Row-normalise to unit norm (per PROJECT.md §17 fix for near-zero
        # components); epsilon floor guards against zero rows.
        norms = np.maximum(np.linalg.norm(H, axis=1, keepdims=True), 1e-8)
        H_norm = H / norms

        # Gram matrix, shifted so every entry is non-negative
        G = H_norm @ H_norm.T
        G = G - G.min() + 1e-6
        return H, G

    def extract_vector(
        self,
        h_pos: np.ndarray,
        h_neg: np.ndarray,
        **kwargs,
    ) -> np.ndarray:
        """Extract steering vector via SNMF decomposition.

        Steps:
            1. Combine H = [H_pos; H_neg] and row-normalise
            2. Compute gram matrix G = H @ H^T
            3. SNMF: G ≈ W @ W^T with rank r
            4. Assign each sample to its dominant component
            5. C_s = centroid of style-dominant samples in original space
            6. C_not_s = centroid of noise-dominant samples
            7. v = C_s - C_not_s

        Args:
            h_pos: (N_pos, d) positive activations
            h_neg: (N_neg, d) negative activations

        Returns:
            (d,) steering vector
        """
        rank = kwargs.get("rank", self.rank)
        max_iter = kwargs.get("max_iter", self.max_iter)

        # Steps 1-2: Combine, normalise, build gram matrix
        H, G = self._normalized_gram(h_pos, h_neg)
        n_pos = len(h_pos)

        # Step 3: SNMF
        W = _snmf(G, rank=rank, max_iter=max_iter)

        # Step 4: Assign each sample to dominant component
        assignments = W.argmax(axis=1)  # (n,)

        # Determine which component is the "style" component:
        # the one with the most positive samples assigned to it
        pos_mask = np.zeros(len(H), dtype=bool)
        pos_mask[:n_pos] = True

        component_pos_counts = []
        for c in range(rank):
            c_mask = assignments == c
            n_pos_in_c = (c_mask & pos_mask).sum()
            component_pos_counts.append(n_pos_in_c)

        style_component = int(np.argmax(component_pos_counts))
        # Pick the noise component among the REMAINING components.
        # (A plain argmin can collide with argmax when counts are tied,
        # which would make C_s == C_not_s and yield a zero steering vector.)
        remaining = [c for c in range(rank) if c != style_component]
        if remaining:
            noise_component = min(remaining, key=lambda c: component_pos_counts[c])
        else:
            # Degenerate rank=1 case: fall back to the centroid-fallback
            # branches below via identical masks
            noise_component = style_component

        # Steps 5-6: Compute centroids in original (un-normalised) space;
        # fall back to the raw pos/neg centroids if a component is empty
        style_mask = assignments == style_component
        noise_mask = assignments == noise_component

        C_s = H[style_mask].mean(axis=0) if style_mask.any() else H[:n_pos].mean(axis=0)
        C_not_s = H[noise_mask].mean(axis=0) if noise_mask.any() else H[n_pos:].mean(axis=0)

        # Step 7: Steering vector
        v = C_s - C_not_s

        # Log diagnostics
        cos_sim = np.dot(C_s, C_not_s) / (np.linalg.norm(C_s) * np.linalg.norm(C_not_s) + 1e-8)
        logger.info(
            f"SNMF-Diff (rank={rank}): "
            f"style_comp={style_component}, "
            f"pos_in_style={component_pos_counts[style_component]}/{n_pos}, "
            f"cos_sim(C_s, C_not_s)={cos_sim:.4f}, "
            f"|v|={np.linalg.norm(v):.4f}"
        )

        return v

    def get_component_similarity(
        self,
        h_pos: np.ndarray,
        h_neg: np.ndarray,
        rank: int = 2,
    ) -> float:
        """Compute cosine similarity between SNMF components.

        Used in the snmf_rank ablation (Section 11 of PROJECT.md).
        High similarity = components not well separated = rank too low.

        Args:
            h_pos: (N_pos, d) positive activations
            h_neg: (N_neg, d) negative activations
            rank: SNMF rank to probe

        Returns:
            Cosine similarity between the first two non-empty component
            centroids, or 1.0 if fewer than two components received samples.
        """
        H, G = self._normalized_gram(h_pos, h_neg)

        W = _snmf(G, rank=rank, max_iter=self.max_iter)

        # Centroids (in original space) of each non-empty component
        assignments = W.argmax(axis=1)
        centroids = []
        for c in range(rank):
            mask = assignments == c
            if mask.any():
                centroids.append(H[mask].mean(axis=0))

        if len(centroids) < 2:
            return 1.0  # degenerate case: components not separable

        cos_sim = np.dot(centroids[0], centroids[1]) / (
            np.linalg.norm(centroids[0]) * np.linalg.norm(centroids[1]) + 1e-8
        )
        return float(cos_sim)