"""M11 — SNMF-Diff: Symmetric NMF on binary-probed activations. Novel contribution: applies Symmetric Non-negative Matrix Factorisation at rank r=2 to the combined [H_pos; H_neg] matrix. The non-negativity constraint filters image-content activations from the VLM multimodal residual stream. Steering vector: v = C_s - C_not_s (style component minus noise component) """ import logging from typing import Tuple import numpy as np from src.methods.base import SteeringMethod logger = logging.getLogger(__name__) def _snmf( X: np.ndarray, rank: int = 2, max_iter: int = 500, tol: float = 1e-6, seed: int = 42, ) -> np.ndarray: """Symmetric Non-negative Matrix Factorisation. Factorises X ≈ H @ H^T where H ∈ R^(n × rank), H ≥ 0. Uses multiplicative update rules: H_ij ← H_ij * sqrt((X @ H)_ij / (H @ H^T @ H)_ij) Args: X: (n, n) symmetric non-negative matrix (e.g. gram matrix) rank: number of components max_iter: maximum iterations tol: convergence tolerance (relative change in Frobenius norm) seed: random seed for initialisation Returns: H: (n, rank) non-negative factor matrix """ rng = np.random.RandomState(seed) n = X.shape[0] # Initialise H randomly H = rng.rand(n, rank).astype(np.float64) + 1e-6 prev_cost = np.inf for iteration in range(max_iter): # Numerator: X @ H numerator = X @ H # Denominator: H @ H^T @ H denominator = H @ (H.T @ H) + 1e-12 # Multiplicative update H = H * np.sqrt(numerator / denominator) # Ensure non-negativity H = np.maximum(H, 1e-12) # Check convergence if iteration % 20 == 0: reconstruction = H @ H.T cost = np.linalg.norm(X - reconstruction, "fro") relative_change = abs(prev_cost - cost) / (prev_cost + 1e-12) if relative_change < tol: logger.debug(f"SNMF converged at iteration {iteration} (cost={cost:.4f})") break prev_cost = cost return H class SNMFDiff(SteeringMethod): """SNMF-Diff — Novel training-free steering method. Applies Symmetric NMF to the gram matrix of the combined activation matrix [H_pos; H_neg] to extract a style component and a noise component. The steering vector is the difference between the centroids of these two components. """ def __init__(self, rank: int = 2, max_iter: int = 500, **kwargs): self.rank = rank self.max_iter = max_iter @property def name(self) -> str: return "SNMF-Diff" @property def method_id(self) -> str: return "M11" def extract_vector( self, h_pos: np.ndarray, h_neg: np.ndarray, **kwargs, ) -> np.ndarray: """Extract steering vector via SNMF decomposition. Steps: 1. Combine H = [H_pos; H_neg] and row-normalise 2. Compute gram matrix G = H @ H^T 3. SNMF: G ≈ W @ W^T with rank r 4. Assign each sample to its dominant component 5. C_s = centroid of style-dominant samples in original space 6. C_not_s = centroid of noise-dominant samples 7. v = C_s - C_not_s Args: h_pos: (N_pos, d) positive activations h_neg: (N_neg, d) negative activations Returns: (d,) steering vector """ rank = kwargs.get("rank", self.rank) max_iter = kwargs.get("max_iter", self.max_iter) # Step 1: Combine and normalise H = np.concatenate([h_pos, h_neg], axis=0).astype(np.float64) n_pos = len(h_pos) # Row-normalise to unit norm (per PROJECT.md §17 fix for near-zero components) norms = np.linalg.norm(H, axis=1, keepdims=True) norms = np.maximum(norms, 1e-8) H_norm = H / norms # Step 2: Gram matrix (shift to non-negative) G = H_norm @ H_norm.T G = G - G.min() + 1e-6 # Ensure non-negative # Step 3: SNMF W = _snmf(G, rank=rank, max_iter=max_iter) # Step 4: Assign each sample to dominant component assignments = W.argmax(axis=1) # (n,) # Determine which component is the "style" component: # The one that has more positive samples assigned to it pos_mask = np.zeros(len(H), dtype=bool) pos_mask[:n_pos] = True component_pos_counts = [] for c in range(rank): c_mask = assignments == c n_pos_in_c = (c_mask & pos_mask).sum() component_pos_counts.append(n_pos_in_c) style_component = np.argmax(component_pos_counts) noise_component = np.argmin(component_pos_counts) # Step 5-6: Compute centroids in original (un-normalised) space style_mask = assignments == style_component noise_mask = assignments == noise_component C_s = H[style_mask].mean(axis=0) if style_mask.any() else H[:n_pos].mean(axis=0) C_not_s = H[noise_mask].mean(axis=0) if noise_mask.any() else H[n_pos:].mean(axis=0) # Step 7: Steering vector v = C_s - C_not_s # Log diagnostics cos_sim = np.dot(C_s, C_not_s) / (np.linalg.norm(C_s) * np.linalg.norm(C_not_s) + 1e-8) logger.info( f"SNMF-Diff (rank={rank}): " f"style_comp={style_component}, " f"pos_in_style={component_pos_counts[style_component]}/{n_pos}, " f"cos_sim(C_s, C_not_s)={cos_sim:.4f}, " f"|v|={np.linalg.norm(v):.4f}" ) return v def get_component_similarity( self, h_pos: np.ndarray, h_neg: np.ndarray, rank: int = 2, ) -> float: """Compute cosine similarity between SNMF components. Used in the snmf_rank ablation (Section 11 of PROJECT.md). High similarity = components not well separated = rank too low. """ H = np.concatenate([h_pos, h_neg], axis=0).astype(np.float64) norms = np.linalg.norm(H, axis=1, keepdims=True) H_norm = H / np.maximum(norms, 1e-8) G = H_norm @ H_norm.T G = G - G.min() + 1e-6 W = _snmf(G, rank=rank, max_iter=self.max_iter) # Compute pairwise cosine similarity between component centroids assignments = W.argmax(axis=1) centroids = [] for c in range(rank): mask = assignments == c if mask.any(): centroids.append(H[mask].mean(axis=0)) if len(centroids) < 2: return 1.0 # degenerate case cos_sim = np.dot(centroids[0], centroids[1]) / ( np.linalg.norm(centroids[0]) * np.linalg.norm(centroids[1]) + 1e-8 ) return float(cos_sim)