Spaces:
Running
Running
| """M11 — SNMF-Diff: Symmetric NMF on binary-probed activations. | |
| Novel contribution: applies Symmetric Non-negative Matrix Factorisation | |
| at rank r=2 to the combined [H_pos; H_neg] matrix. The non-negativity | |
| constraint filters image-content activations from the VLM multimodal | |
| residual stream. | |
| Steering vector: v = C_s - C_not_s (style component minus noise component) | |
| """ | |
# Standard library
import logging
from typing import Tuple

# Third-party
import numpy as np

# Local
from src.methods.base import SteeringMethod

logger = logging.getLogger(__name__)
| def _snmf( | |
| X: np.ndarray, | |
| rank: int = 2, | |
| max_iter: int = 500, | |
| tol: float = 1e-6, | |
| seed: int = 42, | |
| ) -> np.ndarray: | |
| """Symmetric Non-negative Matrix Factorisation. | |
| Factorises X ≈ H @ H^T where H ∈ R^(n × rank), H ≥ 0. | |
| Uses multiplicative update rules: | |
| H_ij ← H_ij * sqrt((X @ H)_ij / (H @ H^T @ H)_ij) | |
| Args: | |
| X: (n, n) symmetric non-negative matrix (e.g. gram matrix) | |
| rank: number of components | |
| max_iter: maximum iterations | |
| tol: convergence tolerance (relative change in Frobenius norm) | |
| seed: random seed for initialisation | |
| Returns: | |
| H: (n, rank) non-negative factor matrix | |
| """ | |
| rng = np.random.RandomState(seed) | |
| n = X.shape[0] | |
| # Initialise H randomly | |
| H = rng.rand(n, rank).astype(np.float64) + 1e-6 | |
| prev_cost = np.inf | |
| for iteration in range(max_iter): | |
| # Numerator: X @ H | |
| numerator = X @ H | |
| # Denominator: H @ H^T @ H | |
| denominator = H @ (H.T @ H) + 1e-12 | |
| # Multiplicative update | |
| H = H * np.sqrt(numerator / denominator) | |
| # Ensure non-negativity | |
| H = np.maximum(H, 1e-12) | |
| # Check convergence | |
| if iteration % 20 == 0: | |
| reconstruction = H @ H.T | |
| cost = np.linalg.norm(X - reconstruction, "fro") | |
| relative_change = abs(prev_cost - cost) / (prev_cost + 1e-12) | |
| if relative_change < tol: | |
| logger.debug(f"SNMF converged at iteration {iteration} (cost={cost:.4f})") | |
| break | |
| prev_cost = cost | |
| return H | |
class SNMFDiff(SteeringMethod):
    """SNMF-Diff — Novel training-free steering method (M11).

    Applies Symmetric NMF to the gram matrix of the combined activation
    matrix [H_pos; H_neg] to extract a style component and a noise component.
    The steering vector is the difference between the centroids of these
    two components.
    """

    def __init__(self, rank: int = 2, max_iter: int = 500, **kwargs):
        # **kwargs absorbs extra config keys a generic factory may pass.
        self.rank = rank
        self.max_iter = max_iter

    def name(self) -> str:
        """Human-readable method name."""
        return "SNMF-Diff"

    def method_id(self) -> str:
        """Short experiment identifier."""
        return "M11"

    @staticmethod
    def _combined_gram(
        h_pos: np.ndarray, h_neg: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, int]:
        """Build the shifted non-negative gram matrix of [H_pos; H_neg].

        Shared by extract_vector and get_component_similarity so both
        always operate on an identically constructed matrix.

        Returns:
            H: (n, d) combined un-normalised activations (float64)
            G: (n, n) non-negative gram matrix of the row-normalised H
            n_pos: number of positive rows (the first n_pos rows of H)
        """
        H = np.concatenate([h_pos, h_neg], axis=0).astype(np.float64)
        n_pos = len(h_pos)
        # Row-normalise to unit norm (per PROJECT.md §17 fix for near-zero components)
        norms = np.linalg.norm(H, axis=1, keepdims=True)
        H_norm = H / np.maximum(norms, 1e-8)
        G = H_norm @ H_norm.T
        # Shift so the SNMF non-negativity assumption holds.
        G = G - G.min() + 1e-6
        return H, G, n_pos

    def extract_vector(
        self,
        h_pos: np.ndarray,
        h_neg: np.ndarray,
        **kwargs,
    ) -> np.ndarray:
        """Extract steering vector via SNMF decomposition.

        Steps:
            1. Combine H = [H_pos; H_neg] and row-normalise
            2. Compute gram matrix G = H @ H^T
            3. SNMF: G ≈ W @ W^T with rank r
            4. Assign each sample to its dominant component
            5. C_s = centroid of style-dominant samples in original space
            6. C_not_s = centroid of noise-dominant samples
            7. v = C_s - C_not_s

        Args:
            h_pos: (N_pos, d) positive activations
            h_neg: (N_neg, d) negative activations
            **kwargs: optional per-call "rank" / "max_iter" overrides

        Returns:
            (d,) steering vector
        """
        rank = kwargs.get("rank", self.rank)
        max_iter = kwargs.get("max_iter", self.max_iter)

        # Steps 1-2: shared construction (see _combined_gram).
        H, G, n_pos = self._combined_gram(h_pos, h_neg)

        # Step 3: SNMF
        W = _snmf(G, rank=rank, max_iter=max_iter)

        # Step 4: each sample belongs to its dominant component.
        assignments = W.argmax(axis=1)  # (n,)

        # The "style" component is the one with the most positive samples.
        pos_mask = np.zeros(len(H), dtype=bool)
        pos_mask[:n_pos] = True
        component_pos_counts = [
            int(((assignments == c) & pos_mask).sum()) for c in range(rank)
        ]
        style_component = np.argmax(component_pos_counts)
        noise_component = np.argmin(component_pos_counts)

        # Steps 5-6: centroids in the original (un-normalised) space.
        # Fall back to the raw pos/neg centroids if a component is empty.
        style_mask = assignments == style_component
        noise_mask = assignments == noise_component
        C_s = H[style_mask].mean(axis=0) if style_mask.any() else H[:n_pos].mean(axis=0)
        C_not_s = H[noise_mask].mean(axis=0) if noise_mask.any() else H[n_pos:].mean(axis=0)

        # Step 7: steering vector
        v = C_s - C_not_s

        # Diagnostics (lazy %-args: formatting is skipped if INFO is disabled).
        cos_sim = np.dot(C_s, C_not_s) / (
            np.linalg.norm(C_s) * np.linalg.norm(C_not_s) + 1e-8
        )
        logger.info(
            "SNMF-Diff (rank=%d): style_comp=%d, pos_in_style=%d/%d, "
            "cos_sim(C_s, C_not_s)=%.4f, |v|=%.4f",
            rank,
            style_component,
            component_pos_counts[style_component],
            n_pos,
            cos_sim,
            np.linalg.norm(v),
        )
        return v

    def get_component_similarity(
        self,
        h_pos: np.ndarray,
        h_neg: np.ndarray,
        rank: int = 2,
    ) -> float:
        """Compute cosine similarity between SNMF component centroids.

        Used in the snmf_rank ablation (Section 11 of PROJECT.md).
        High similarity = components not well separated = rank too low.
        """
        # Identical gram construction to extract_vector (see _combined_gram).
        H, G, _ = self._combined_gram(h_pos, h_neg)
        W = _snmf(G, rank=rank, max_iter=self.max_iter)

        # Centroids (original space) of each non-empty component.
        assignments = W.argmax(axis=1)
        centroids = [
            H[assignments == c].mean(axis=0)
            for c in range(rank)
            if (assignments == c).any()
        ]
        if len(centroids) < 2:
            return 1.0  # degenerate case: everything collapsed into one component
        cos_sim = np.dot(centroids[0], centroids[1]) / (
            np.linalg.norm(centroids[0]) * np.linalg.norm(centroids[1]) + 1e-8
        )
        return float(cos_sim)