"""M11 — SNMF-Diff: Symmetric NMF on binary-probed activations.
Novel contribution: applies Symmetric Non-negative Matrix Factorisation
at rank r=2 to the combined [H_pos; H_neg] matrix. The non-negativity
constraint filters image-content activations from the VLM multimodal
residual stream.
Steering vector: v = C_s - C_not_s (style component minus noise component)
"""
import logging
from typing import Tuple
import numpy as np
from src.methods.base import SteeringMethod
logger = logging.getLogger(__name__)
def _snmf(
    X: np.ndarray,
    rank: int = 2,
    max_iter: int = 500,
    tol: float = 1e-6,
    seed: int = 42,
) -> np.ndarray:
    """Symmetric Non-negative Matrix Factorisation.

    Factorises X ≈ H @ H^T where H ∈ R^(n × rank), H ≥ 0, using the
    multiplicative update rule:

        H_ij ← H_ij * sqrt((X @ H)_ij / (H @ H^T @ H)_ij)

    Args:
        X: (n, n) symmetric non-negative matrix (e.g. a gram matrix).
        rank: number of components.
        max_iter: maximum iterations.
        tol: convergence tolerance (relative change in the Frobenius
            reconstruction cost, evaluated every 20 iterations).
        seed: random seed for the initialisation of H.

    Returns:
        H: (n, rank) non-negative factor matrix.

    Raises:
        ValueError: if X is not a square 2-D matrix.
    """
    if X.ndim != 2 or X.shape[0] != X.shape[1]:
        raise ValueError(f"X must be a square 2-D matrix, got shape {X.shape}")
    rng = np.random.RandomState(seed)
    n = X.shape[0]
    # Small positive offset: exact zeros are fixed points of the
    # multiplicative update and would never move.
    H = rng.rand(n, rank).astype(np.float64) + 1e-6
    prev_cost = np.inf
    for iteration in range(max_iter):
        # Numerator: X @ H
        numerator = X @ H
        # Denominator parenthesised as H @ (H^T @ H): O(n·r²) instead of
        # the O(n²·r) of (H @ H^T) @ H. Epsilon guards divide-by-zero.
        denominator = H @ (H.T @ H) + 1e-12
        # Multiplicative update preserves non-negativity.
        H = H * np.sqrt(numerator / denominator)
        # Clamp away from zero so entries can still move in later updates.
        H = np.maximum(H, 1e-12)
        # Convergence check only every 20 iterations — the Frobenius cost
        # is comparatively expensive to evaluate.
        if iteration % 20 == 0:
            reconstruction = H @ H.T
            cost = np.linalg.norm(X - reconstruction, "fro")
            relative_change = abs(prev_cost - cost) / (prev_cost + 1e-12)
            if relative_change < tol:
                # Lazy %-args: formatting is skipped unless DEBUG is enabled.
                logger.debug(
                    "SNMF converged at iteration %d (cost=%.4f)", iteration, cost
                )
                break
            prev_cost = cost
    return H
class SNMFDiff(SteeringMethod):
    """SNMF-Diff — Novel training-free steering method.

    Applies Symmetric NMF to the gram matrix of the combined activation
    matrix [H_pos; H_neg] to extract a style component and a noise component.
    The steering vector is the difference between the centroids of these
    two components.
    """

    def __init__(self, rank: int = 2, max_iter: int = 500, **kwargs):
        """Initialise the method.

        Args:
            rank: SNMF rank (number of components to extract).
            max_iter: maximum SNMF iterations.
            **kwargs: ignored; accepted for signature compatibility with
                other SteeringMethod subclasses.
        """
        self.rank = rank
        self.max_iter = max_iter

    @property
    def name(self) -> str:
        """Human-readable method name."""
        return "SNMF-Diff"

    @property
    def method_id(self) -> str:
        """Short method identifier used in configs and result tables."""
        return "M11"

    @staticmethod
    def _shifted_gram(H: np.ndarray) -> np.ndarray:
        """Row-normalise H and return its gram matrix shifted to non-negative.

        Row-normalisation (per PROJECT.md §17) guards against near-zero
        activation rows; the global shift makes the gram matrix a valid
        SNMF input.
        """
        norms = np.linalg.norm(H, axis=1, keepdims=True)
        H_norm = H / np.maximum(norms, 1e-8)
        G = H_norm @ H_norm.T
        return G - G.min() + 1e-6

    def extract_vector(
        self,
        h_pos: np.ndarray,
        h_neg: np.ndarray,
        **kwargs,
    ) -> np.ndarray:
        """Extract steering vector via SNMF decomposition.

        Steps:
            1. Combine H = [H_pos; H_neg] and row-normalise
            2. Compute gram matrix G = H @ H^T
            3. SNMF: G ≈ W @ W^T with rank r
            4. Assign each sample to its dominant component
            5. C_s = centroid of style-dominant samples in original space
            6. C_not_s = centroid of noise-dominant samples
            7. v = C_s - C_not_s

        Args:
            h_pos: (N_pos, d) positive activations
            h_neg: (N_neg, d) negative activations

        Keyword Args:
            rank: override self.rank for this call.
            max_iter: override self.max_iter for this call.

        Returns:
            (d,) steering vector
        """
        rank = kwargs.get("rank", self.rank)
        max_iter = kwargs.get("max_iter", self.max_iter)

        # Step 1: combine; keep an un-normalised copy for the centroids.
        H = np.concatenate([h_pos, h_neg], axis=0).astype(np.float64)
        n_pos = len(h_pos)

        # Steps 2-3: non-negative gram matrix, then SNMF factorisation.
        G = self._shifted_gram(H)
        W = _snmf(G, rank=rank, max_iter=max_iter)

        # Step 4: hard-assign each sample to its dominant component.
        assignments = W.argmax(axis=1)  # (n,)

        # The "style" component is the one with the most positive samples
        # assigned; the "noise" component the one with the fewest.
        pos_mask = np.zeros(len(H), dtype=bool)
        pos_mask[:n_pos] = True
        component_pos_counts = [
            int(((assignments == c) & pos_mask).sum()) for c in range(rank)
        ]
        style_component = int(np.argmax(component_pos_counts))
        noise_component = int(np.argmin(component_pos_counts))
        # Degenerate tie: when every component holds equally many positives,
        # argmax == argmin, style and noise coincide, and v would collapse to
        # the exact zero vector. Pick a distinct noise component instead.
        if style_component == noise_component and rank > 1:
            noise_component = (style_component + 1) % rank

        # Steps 5-6: centroids in the original (un-normalised) space, with
        # fallbacks in case a component received no samples at all.
        style_mask = assignments == style_component
        noise_mask = assignments == noise_component
        C_s = H[style_mask].mean(axis=0) if style_mask.any() else H[:n_pos].mean(axis=0)
        C_not_s = H[noise_mask].mean(axis=0) if noise_mask.any() else H[n_pos:].mean(axis=0)

        # Step 7: steering vector.
        v = C_s - C_not_s

        # Diagnostics: high centroid similarity suggests poor separation.
        cos_sim = np.dot(C_s, C_not_s) / (
            np.linalg.norm(C_s) * np.linalg.norm(C_not_s) + 1e-8
        )
        logger.info(
            "SNMF-Diff (rank=%d): style_comp=%d, pos_in_style=%d/%d, "
            "cos_sim(C_s, C_not_s)=%.4f, |v|=%.4f",
            rank,
            style_component,
            component_pos_counts[style_component],
            n_pos,
            cos_sim,
            np.linalg.norm(v),
        )
        return v

    def get_component_similarity(
        self,
        h_pos: np.ndarray,
        h_neg: np.ndarray,
        rank: int = 2,
    ) -> float:
        """Compute cosine similarity between SNMF components.

        Used in the snmf_rank ablation (Section 11 of PROJECT.md).
        High similarity = components not well separated = rank too low.

        Args:
            h_pos: (N_pos, d) positive activations.
            h_neg: (N_neg, d) negative activations.
            rank: SNMF rank to probe.

        Returns:
            Cosine similarity between the first two non-empty component
            centroids, or 1.0 in the degenerate single-component case.
        """
        H = np.concatenate([h_pos, h_neg], axis=0).astype(np.float64)
        G = self._shifted_gram(H)
        W = _snmf(G, rank=rank, max_iter=self.max_iter)

        # Centroid (in original space) of each non-empty component.
        assignments = W.argmax(axis=1)
        centroids = []
        for c in range(rank):
            mask = assignments == c
            if mask.any():
                centroids.append(H[mask].mean(axis=0))
        if len(centroids) < 2:
            return 1.0  # degenerate case
        # NOTE(review): only the first two centroids are compared even when
        # rank > 2 — kept as-is since the ablation consumes this value.
        cos_sim = np.dot(centroids[0], centroids[1]) / (
            np.linalg.norm(centroids[0]) * np.linalg.norm(centroids[1]) + 1e-8
        )
        return float(cos_sim)