# stylesteer-vlm / src/methods/snmf_diff.py
# (Hugging Face Spaces deploy header: abka03, "Deploy StyleSteer-VLM demo",
#  commit e6f24ae, verified — converted to comments so the module parses.)
"""M11 — SNMF-Diff: Symmetric NMF on binary-probed activations.
Novel contribution: applies Symmetric Non-negative Matrix Factorisation
at rank r=2 to the combined [H_pos; H_neg] matrix. The non-negativity
constraint filters image-content activations from the VLM multimodal
residual stream.
Steering vector: v = C_s - C_not_s (style component minus noise component)
"""
import logging
from typing import Tuple
import numpy as np
from src.methods.base import SteeringMethod
logger = logging.getLogger(__name__)
def _snmf(
X: np.ndarray,
rank: int = 2,
max_iter: int = 500,
tol: float = 1e-6,
seed: int = 42,
) -> np.ndarray:
"""Symmetric Non-negative Matrix Factorisation.
Factorises X ≈ H @ H^T where H ∈ R^(n × rank), H ≥ 0.
Uses multiplicative update rules:
H_ij ← H_ij * sqrt((X @ H)_ij / (H @ H^T @ H)_ij)
Args:
X: (n, n) symmetric non-negative matrix (e.g. gram matrix)
rank: number of components
max_iter: maximum iterations
tol: convergence tolerance (relative change in Frobenius norm)
seed: random seed for initialisation
Returns:
H: (n, rank) non-negative factor matrix
"""
rng = np.random.RandomState(seed)
n = X.shape[0]
# Initialise H randomly
H = rng.rand(n, rank).astype(np.float64) + 1e-6
prev_cost = np.inf
for iteration in range(max_iter):
# Numerator: X @ H
numerator = X @ H
# Denominator: H @ H^T @ H
denominator = H @ (H.T @ H) + 1e-12
# Multiplicative update
H = H * np.sqrt(numerator / denominator)
# Ensure non-negativity
H = np.maximum(H, 1e-12)
# Check convergence
if iteration % 20 == 0:
reconstruction = H @ H.T
cost = np.linalg.norm(X - reconstruction, "fro")
relative_change = abs(prev_cost - cost) / (prev_cost + 1e-12)
if relative_change < tol:
logger.debug(f"SNMF converged at iteration {iteration} (cost={cost:.4f})")
break
prev_cost = cost
return H
class SNMFDiff(SteeringMethod):
    """SNMF-Diff — Novel training-free steering method.

    Applies Symmetric NMF to the gram matrix of the combined activation
    matrix [H_pos; H_neg] to extract a style component and a noise component.
    The steering vector is the difference between the centroids of these
    two components.
    """

    def __init__(self, rank: int = 2, max_iter: int = 500, **kwargs):
        # **kwargs is accepted (and ignored) so a generic factory can pass a
        # shared config dict to every method — presumably; verify against caller.
        self.rank = rank
        self.max_iter = max_iter

    @property
    def name(self) -> str:
        return "SNMF-Diff"

    @property
    def method_id(self) -> str:
        return "M11"

    def extract_vector(
        self,
        h_pos: np.ndarray,
        h_neg: np.ndarray,
        **kwargs,
    ) -> np.ndarray:
        """Extract steering vector via SNMF decomposition.

        Steps:
        1. Combine H = [H_pos; H_neg] and row-normalise
        2. Compute gram matrix G = H @ H^T
        3. SNMF: G ≈ W @ W^T with rank r
        4. Assign each sample to its dominant component
        5. C_s = centroid of style-dominant samples in original space
        6. C_not_s = centroid of noise-dominant samples
        7. v = C_s - C_not_s

        Args:
            h_pos: (N_pos, d) positive activations
            h_neg: (N_neg, d) negative activations

        Returns:
            (d,) steering vector
        """
        rank = kwargs.get("rank", self.rank)
        max_iter = kwargs.get("max_iter", self.max_iter)
        # Step 1: Combine and normalise
        H = np.concatenate([h_pos, h_neg], axis=0).astype(np.float64)
        n_pos = len(h_pos)
        # Row-normalise to unit norm (per PROJECT.md §17 fix for near-zero components)
        norms = np.linalg.norm(H, axis=1, keepdims=True)
        norms = np.maximum(norms, 1e-8)
        H_norm = H / norms
        # Step 2: Gram matrix, shifted so every entry is positive (SNMF requires
        # a non-negative input matrix)
        G = H_norm @ H_norm.T
        G = G - G.min() + 1e-6
        # Step 3: SNMF
        W = _snmf(G, rank=rank, max_iter=max_iter)
        # Step 4: Assign each sample to its dominant component
        assignments = W.argmax(axis=1)  # (n,)
        # The "style" component is the one with the most positive samples.
        pos_mask = np.zeros(len(H), dtype=bool)
        pos_mask[:n_pos] = True
        component_pos_counts = [
            int(((assignments == c) & pos_mask).sum()) for c in range(rank)
        ]
        style_component = int(np.argmax(component_pos_counts))
        # BUGFIX: pick the noise component among the *remaining* components.
        # Previously np.argmin over all counts could return style_component
        # itself on a tie (e.g. counts [5, 5]), yielding a zero vector.
        remaining = [c for c in range(rank) if c != style_component]
        if remaining:
            noise_component = min(remaining, key=lambda c: component_pos_counts[c])
        else:
            noise_component = style_component  # degenerate rank == 1 case
        # Steps 5-6: centroids in original (un-normalised) space; fall back to
        # the raw pos/neg means if a component captured no samples.
        style_mask = assignments == style_component
        noise_mask = assignments == noise_component
        C_s = H[style_mask].mean(axis=0) if style_mask.any() else H[:n_pos].mean(axis=0)
        C_not_s = H[noise_mask].mean(axis=0) if noise_mask.any() else H[n_pos:].mean(axis=0)
        # Step 7: Steering vector
        v = C_s - C_not_s
        # Diagnostics: high centroid similarity means the components are not
        # well separated. Lazy %-args avoid formatting when INFO is disabled.
        cos_sim = np.dot(C_s, C_not_s) / (np.linalg.norm(C_s) * np.linalg.norm(C_not_s) + 1e-8)
        logger.info(
            "SNMF-Diff (rank=%d): style_comp=%d, pos_in_style=%d/%d, "
            "cos_sim(C_s, C_not_s)=%.4f, |v|=%.4f",
            rank,
            style_component,
            component_pos_counts[style_component],
            n_pos,
            cos_sim,
            np.linalg.norm(v),
        )
        return v

    def get_component_similarity(
        self,
        h_pos: np.ndarray,
        h_neg: np.ndarray,
        rank: int = 2,
    ) -> float:
        """Compute cosine similarity between SNMF components.

        Used in the snmf_rank ablation (Section 11 of PROJECT.md).
        High similarity = components not well separated = rank too low.
        Note: only the first two non-empty component centroids are compared.
        """
        # Same preprocessing as extract_vector: normalise rows, shift gram
        # matrix to be non-negative.
        H = np.concatenate([h_pos, h_neg], axis=0).astype(np.float64)
        norms = np.linalg.norm(H, axis=1, keepdims=True)
        H_norm = H / np.maximum(norms, 1e-8)
        G = H_norm @ H_norm.T
        G = G - G.min() + 1e-6
        W = _snmf(G, rank=rank, max_iter=self.max_iter)
        # Centroid (in original space) of each non-empty component
        assignments = W.argmax(axis=1)
        centroids = []
        for c in range(rank):
            mask = assignments == c
            if mask.any():
                centroids.append(H[mask].mean(axis=0))
        if len(centroids) < 2:
            return 1.0  # degenerate case: everything collapsed to one component
        cos_sim = np.dot(centroids[0], centroids[1]) / (
            np.linalg.norm(centroids[0]) * np.linalg.norm(centroids[1]) + 1e-8
        )
        return float(cos_sim)