| from __future__ import annotations |
| |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import logging |
| import os |
| from functools import partial |
| from typing import TYPE_CHECKING, Any |
| from dataclasses import dataclass |
| from torch import Tensor |
| import math |
| import warnings |
|
|
|
|
| |
| |
| |
|
|
| class RMSNorm(nn.Module): |
| """ |
| Root Mean Square Layer Normalization. |
| """ |
|
|
| def __init__(self, config, size: int | None=None): |
| super().__init__() |
| self.eps = getattr(config, 'norm_eps', 1e-05) |
| norm_size = size if size is not None else config.n_embd |
| self.weight = nn.Parameter(torch.ones(norm_size)) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| og = x.dtype |
| x = x.float() |
| var = x.pow(2).mean(-1, keepdim=True) |
| x = x * torch.rsqrt(var + self.eps) |
| return (self.weight * x).to(og) |
|
|
| class BufferCache: |
| """Simple cache for storing tensors (used by RotaryEmbedding).""" |
|
|
| def __init__(self): |
| self._cache: dict[str, torch.Tensor] = {} |
|
|
| def get(self, key: str) -> torch.Tensor | None: |
| return self._cache.get(key) |
|
|
| def __setitem__(self, key: str, value: torch.Tensor): |
| self._cache[key] = value |
|
|
| def __getitem__(self, key: str) -> torch.Tensor: |
| return self._cache[key] |
|
|
class RotaryEmbedding(nn.Module):
    """
    Rotary Position Embeddings (RoPE).

    Applies rotary embeddings to queries and keys for position information.
    Sin/cos tables are computed once, cached, and sliced per call.

    Args:
        dim: Dimension of the rotary embeddings (typically head_dim)
        max_seq_len: Maximum sequence length to cache
        base: Base for inverse frequency computation (theta)
        rope_full_precision: Whether to compute RoPE in full precision
    """

    def __init__(self, dim: int, max_seq_len: int=2048, base: float=10000.0, rope_full_precision: bool=True):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.rope_theta = base
        self.rope_full_precision = rope_full_precision
        # Plain dict cache for the sin/cos tables (keys: 'rope_pos_sin'/'rope_pos_cos').
        self.__cache: dict = {}
        # Warm the cache on CPU so the first forward pass does no trig work.
        self.get_rotary_embedding(max_seq_len, torch.device('cpu'))

    def get_rotary_embedding(self, seq_len: int, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
        """Return cached (sin, cos) tables of shape (1, 1, seq_len, dim), rebuilding if too short."""
        sin_tab = self.__cache.get('rope_pos_sin')
        cos_tab = self.__cache.get('rope_pos_cos')
        cache_ok = (
            sin_tab is not None
            and cos_tab is not None
            and sin_tab.shape[-2] >= seq_len
            and cos_tab.shape[-2] >= seq_len
        )
        if cache_ok:
            # Migrate cached tables if the caller moved to a different device.
            if sin_tab.device != device:
                sin_tab = sin_tab.to(device)
                self.__cache['rope_pos_sin'] = sin_tab
            if cos_tab.device != device:
                cos_tab = cos_tab.to(device)
                self.__cache['rope_pos_cos'] = cos_tab
            return sin_tab[:, :, :seq_len, :], cos_tab[:, :, :seq_len, :]
        # (Re)build the tables in float32 with autocast disabled.
        with torch.autocast(device.type, enabled=False):
            exponents = torch.arange(0, self.dim, 2, device=device, dtype=torch.float) / self.dim
            inv_freq = 1.0 / self.rope_theta ** exponents
            position_ids = torch.arange(seq_len, device=device, dtype=torch.float)
            angles = torch.outer(position_ids, inv_freq)
            # Duplicate the angles so each table spans the full `dim` channels.
            angles = torch.cat((angles, angles), dim=-1)
            sin_tab = angles.sin()[None, None, :, :]
            cos_tab = angles.cos()[None, None, :, :]
            self.__cache['rope_pos_sin'] = sin_tab
            self.__cache['rope_pos_cos'] = cos_tab
        return sin_tab, cos_tab

    def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        """Map the two halves of the last dim from [a, b] to [-b, a]."""
        batch, heads, seq, width = x.size()
        halves = x.view(batch, heads, seq, 2, width // 2)
        first_half, second_half = halves.unbind(dim=-2)
        return torch.cat((-second_half, first_half), dim=-1)

    def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Rotate `t` by the given sin/cos tables."""
        rotated = self.rotate_half(t)
        return (t * pos_cos + rotated * pos_sin).to(t.dtype)

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply RoPE to queries and keys; queries are offset so that the last
        query aligns with the last key position."""
        if self.rope_full_precision:
            q_work, k_work = q.float(), k.float()
        else:
            q_work, k_work = q, k
        with torch.autocast(q.device.type, enabled=False):
            q_len = q_work.shape[-2]
            k_len = k_work.shape[-2]
            pos_sin, pos_cos = self.get_rotary_embedding(k_len, q_work.device)
            pos_sin = pos_sin.type_as(q_work)
            pos_cos = pos_cos.type_as(q_work)
            offset = k_len - q_len
            q_work = self.apply_rotary_pos_emb(pos_sin[:, :, offset:k_len, :], pos_cos[:, :, offset:k_len, :], q_work)
            k_work = self.apply_rotary_pos_emb(pos_sin, pos_cos, k_work)
        return q_work.type_as(q), k_work.type_as(k)
|
|
class MLP(nn.Module):
    """
    Multi-Layer Perceptron with SwiGLU or standard activation.

    Hidden width is `config.intermediate_size` when set, otherwise
    `mlp_ratio * n_embd` (mlp_ratio defaults to 4).

    Args:
        config: Model config with n_embd, mlp_ratio/intermediate_size,
            use_bias, mlp_type, activation
    """

    def __init__(self, config):
        super().__init__()
        if getattr(config, 'intermediate_size', None) is not None:
            intermediate_size = config.intermediate_size
        else:
            intermediate_size = getattr(config, 'mlp_ratio', 4) * config.n_embd
        # Consistent with BlockCausalAttention: tolerate configs without `use_bias`.
        use_bias = getattr(config, 'use_bias', False)
        # Resolve the MLP variant once. Previously __init__ required the
        # attribute while forward() defaulted it to 'swiglu' — use the same
        # default in both places so the two can never disagree.
        self.mlp_type = getattr(config, 'mlp_type', 'swiglu')
        if self.mlp_type == 'swiglu':
            # Fused value+gate projection; split into two halves in forward().
            self.c_fc = nn.Linear(config.n_embd, 2 * intermediate_size, bias=use_bias)
            self.c_proj = nn.Linear(intermediate_size, config.n_embd, bias=use_bias)
            self.activation = None
        else:
            self.c_fc = nn.Linear(config.n_embd, intermediate_size, bias=use_bias)
            self.c_proj = nn.Linear(intermediate_size, config.n_embd, bias=use_bias)
            act_map = {'gelu': nn.GELU(approximate='tanh'), 'relu': nn.ReLU(), 'silu': nn.SiLU()}
            # KeyError here means config.activation is unsupported.
            self.activation = act_map[config.activation]
        # Marker consumed by the model's weight init (scales residual projections).
        self.c_proj.SCALE_INIT = 1
        self.config = config

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to x (..., n_embd) -> (..., n_embd)."""
        if self.mlp_type == 'swiglu':
            gate_up = self.c_fc(x)
            # First half is the value path, second half feeds the SiLU gate.
            up, gate = gate_up.chunk(2, dim=-1)
            intermediate = F.silu(gate) * up
        else:
            intermediate = self.activation(self.c_fc(x))
        return self.c_proj(intermediate)
|
|
| |
| |
| |
|
|
logger = logging.getLogger(__name__)
# flex_attention is optional and only present in recent PyTorch builds.
# NOTE(review): `_dense_to_ordered` is a private torch helper (underscore-prefixed)
# and may change or move between releases — verify on torch upgrades.
try:
    from torch.nn.attention.flex_attention import BlockMask, _dense_to_ordered, flex_attention
    _FLEX_ATTN_AVAILABLE = True
except ImportError:
    _FLEX_ATTN_AVAILABLE = False
    BlockMask: Any = None
    flex_attention: Any = None
    _dense_to_ordered: Any = None
# flex_attention is opt-in: it stays disabled unless STEERLING_USE_FLEX_ATTN=1,
# even when the import above succeeded.
if os.environ.get('STEERLING_USE_FLEX_ATTN', '0') != '1':
    _FLEX_ATTN_AVAILABLE = False
if TYPE_CHECKING:
    from torch.nn.attention.flex_attention import BlockMask as BlockMaskType
    from steerling.configs.causal_diffusion import CausalDiffusionConfig
# Compile flex_attention once at import time when CUDA is present; otherwise keep
# the eager function (which is None when flex is unavailable).
if torch.cuda.is_available() and _FLEX_ATTN_AVAILABLE:
    compiled_flex_attention = torch.compile(flex_attention, fullgraph=True)
else:
    compiled_flex_attention = flex_attention
|
|
def block_causal_mask_mod(b: Any, h: Any, q_idx: torch.Tensor, kv_idx: torch.Tensor, *, block_size: int) -> torch.Tensor:
    """Block-causal mask: attend when the query's block index is at or after the
    key's block index (bidirectional within a block, causal across blocks)."""
    q_block = q_idx // block_size
    kv_block = kv_idx // block_size
    return q_block >= kv_block
|
|
def fast_create_block_causal_mask(attn_block_size: int, seq_length: int, mask_block_size: int, device: torch.device) -> BlockMaskType:
    """
    Fast block-causal mask creation for flex_attention.

    Analytically computes the sparse block structure instead of evaluating
    the mask function at every position.

    Args:
        attn_block_size: Diffusion block size (granularity of the causal pattern).
        seq_length: Sequence length (queries and keys are the same length).
        mask_block_size: BlockMask tile size; must be a multiple of attn_block_size.
        device: Device the mask tensors are built on.

    Returns:
        A flex_attention BlockMask encoding the block-causal pattern.

    Raises:
        RuntimeError: If flex_attention (or its private helpers) is unavailable.
        ValueError: If mask_block_size is not divisible by attn_block_size.
    """
    if not _FLEX_ATTN_AVAILABLE or _dense_to_ordered is None or BlockMask is None:
        raise RuntimeError('flex_attention not available')
    # Ceil-divide: number of BlockMask tiles needed to cover the sequence.
    num_mask_blocks = -(-seq_length // mask_block_size)
    attn_blocks_per_mask_block, rem = divmod(mask_block_size, attn_block_size)
    if rem != 0:
        raise ValueError(f'mask_block_size ({mask_block_size}) must be divisible by attn_block_size ({attn_block_size})')
    num_attn_blocks = num_mask_blocks * attn_blocks_per_mask_block
    # Block-causal pattern at attn-block granularity is simply lower-triangular.
    lowres_attn_mask = torch.tril(torch.ones(num_attn_blocks, num_attn_blocks, dtype=torch.bool, device=device))
    # Count visible attn-blocks inside each (mask-tile, mask-tile) cell to classify tiles.
    block_attn_count = lowres_attn_mask.reshape(num_mask_blocks, attn_blocks_per_mask_block, num_mask_blocks, attn_blocks_per_mask_block).permute(0, 2, 1, 3).sum(dim=[-2, -1])
    max_count = attn_blocks_per_mask_block * attn_blocks_per_mask_block
    # "Full" tiles are entirely visible and skip mask_mod; "normal" tiles are
    # partially visible and get mask_mod applied per element.
    full_block_mask = block_attn_count == max_count
    if seq_length % mask_block_size > 0:
        # The last tile row is ragged (sequence doesn't fill it); demote its
        # tiles from "full" so mask_mod bounds-checks them.
        full_block_mask[-1, :] = False
    normal_block_mask = (block_attn_count > 0) & ~full_block_mask
    # Convert dense tile masks to flex_attention's ordered sparse (count, index)
    # form, plus the q-major orderings derived from the transposed masks.
    kv_num_blocks, kv_indices = _dense_to_ordered(normal_block_mask)
    full_kv_num_blocks, full_kv_indices = _dense_to_ordered(full_block_mask)
    q_num_blocks, q_indices = _dense_to_ordered(normal_block_mask.transpose(-2, -1))
    full_q_num_blocks, full_q_indices = _dense_to_ordered(full_block_mask.transpose(-2, -1))
    return BlockMask(seq_lengths=(seq_length, seq_length), kv_num_blocks=kv_num_blocks[None, None, ...], kv_indices=kv_indices[None, None, ...], full_kv_num_blocks=full_kv_num_blocks[None, None, ...], full_kv_indices=full_kv_indices[None, None, ...], q_num_blocks=q_num_blocks[None, None, ...], q_indices=q_indices[None, None, ...], full_q_num_blocks=full_q_num_blocks[None, None, ...], full_q_indices=full_q_indices[None, None, ...], mask_mod=partial(block_causal_mask_mod, block_size=attn_block_size), BLOCK_SIZE=(mask_block_size, mask_block_size))
|
|
| def sdpa_with_block_causal_mask(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, diff_block_size: int, mask_cache: dict[str, torch.Tensor], enable_gqa: bool=False) -> torch.Tensor: |
| """Fallback using SDPA with dense mask when flex_attention unavailable.""" |
| B, H, T, D = q.shape |
| device = q.device |
| dtype = q.dtype |
| cache_key = f'sdpa_{T}_{device}_{dtype}' |
| if cache_key not in mask_cache: |
| q_idx = torch.arange(T, device=device).unsqueeze(1) |
| kv_idx = torch.arange(T, device=device).unsqueeze(0) |
| bool_mask = q_idx // diff_block_size >= kv_idx // diff_block_size |
| attn_mask = torch.zeros(T, T, device=device, dtype=dtype) |
| attn_mask.masked_fill_(~bool_mask, float('-inf')) |
| mask_cache[cache_key] = attn_mask |
| return F.scaled_dot_product_attention(q, k, v, attn_mask=mask_cache[cache_key], dropout_p=0.0, is_causal=False, enable_gqa=enable_gqa) |
|
|
class BlockCausalAttention(nn.Module):
    """Block-causal self-attention with FlexAttention and optional GQA.

    Tokens attend bidirectionally inside a diffusion block of
    ``config.diff_block_size`` tokens and causally across blocks. Uses
    flex_attention on CUDA when enabled, otherwise a dense-mask SDPA fallback.
    """
    # Tile size handed to flex_attention's BlockMask; adjusted in
    # _get_block_mask when it is not a multiple of diff_block_size.
    FLEX_MASK_BLOCK_SIZE = 128

    def __init__(self, config: CausalDiffusionConfig) -> None:
        super().__init__()
        if not hasattr(config, 'diff_block_size'):
            raise ValueError("BlockCausalAttention requires 'diff_block_size' in config.")
        assert config.n_embd % config.n_head == 0
        self.config = config
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head
        # Grouped-query attention: n_kv_heads may be smaller than n_head
        # (defaults to standard multi-head attention when unset).
        n_kv = getattr(config, 'n_kv_heads', None)
        self.n_kv_heads = self.n_head if n_kv is None else int(n_kv)
        if self.n_kv_heads <= 0:
            raise ValueError(f'n_kv_heads must be >= 1 (got {self.n_kv_heads})')
        if self.n_head % self.n_kv_heads != 0:
            raise ValueError(f'n_head ({self.n_head}) must be divisible by n_kv_heads ({self.n_kv_heads})')
        self.kv_repeat = self.n_head // self.n_kv_heads
        use_bias = getattr(config, 'use_bias', False)
        # Fused QKV projection: n_embd query dims plus 2 * (n_kv_heads * head_dim) for K/V.
        kv_out = self.n_kv_heads * self.head_dim
        attn_out = self.n_embd + 2 * kv_out
        self.c_attn = nn.Linear(config.n_embd, attn_out, bias=use_bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=use_bias)
        # Marker consumed by the model's weight init (scales residual projections).
        self.c_proj.SCALE_INIT = 1
        # Optional per-head QK normalization over head_dim (RMSNorm or LayerNorm).
        if getattr(config, 'use_qk_norm', False):
            if getattr(config, 'use_rms_norm', True):
                self.q_norm: nn.Module | None = RMSNorm(config, size=self.head_dim)
                self.k_norm: nn.Module | None = RMSNorm(config, size=self.head_dim)
            else:
                self.q_norm = nn.LayerNorm(self.head_dim)
                self.k_norm = nn.LayerNorm(self.head_dim)
        else:
            self.q_norm = None
            self.k_norm = None
        if getattr(config, 'use_rope', True):
            self.rope: RotaryEmbedding | None = RotaryEmbedding(dim=self.head_dim, max_seq_len=config.block_size, base=getattr(config, 'rope_base', 500000.0), rope_full_precision=getattr(config, 'rope_full_precision', True))
        else:
            self.rope = None
        # Per-(T, device) mask caches; entries are never evicted.
        # NOTE(review): fine for fixed-length training, could grow unboundedly
        # with many distinct sequence lengths.
        self._mask_cache: dict = {}
        self._sdpa_mask_cache: dict[str, torch.Tensor] = {}
        self._logged_attention_mode = False

    def _get_block_mask(self, T: int, device: torch.device):
        """Get (or build and cache) the flex_attention BlockMask for length T."""
        cache_key = f'flex_{T}_{device}'
        if cache_key not in self._mask_cache:
            diff_block_size = self.config.diff_block_size
            mask_block_size = self.FLEX_MASK_BLOCK_SIZE
            # Round the tile size down to a multiple of diff_block_size so tiles
            # align with diffusion blocks; fall back to a single diffusion block
            # when diff_block_size exceeds the default tile size.
            if mask_block_size % diff_block_size != 0:
                mask_block_size = diff_block_size * (mask_block_size // diff_block_size)
            if mask_block_size == 0:
                mask_block_size = diff_block_size
            self._mask_cache[cache_key] = fast_create_block_causal_mask(attn_block_size=diff_block_size, seq_length=T, mask_block_size=mask_block_size, device=device)
        return self._mask_cache[cache_key]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over x of shape (B, T, C); returns (B, T, C)."""
        B, T, C = x.size()
        device = x.device
        # Flex path requires CUDA, a successful import, and the env opt-in.
        use_flex = _FLEX_ATTN_AVAILABLE and x.is_cuda and (flex_attention is not None)
        # One-time debug log of which attention backend is in use.
        if not self._logged_attention_mode:
            self._logged_attention_mode = True
            mode = 'flex_attention' if use_flex else 'SDPA fallback'
            logger.debug(f'[CausalDiffusion] Using {mode} with GQA (n_head={self.n_head}, n_kv_heads={self.n_kv_heads})')
        qkv = self.c_attn(x)
        # Optional QKV activation clipping — NOTE(review): presumably a training
        # stability measure; confirm against the config docs.
        clip_qkv = getattr(self.config, 'clip_qkv', None)
        if clip_qkv is not None:
            qkv = qkv.clamp(min=-clip_qkv, max=clip_qkv)
        kv_dim = self.n_kv_heads * self.head_dim
        q, k, v = qkv.split([self.n_embd, kv_dim, kv_dim], dim=2)
        # (B, T, H, hd) -> (B, H, T, hd)
        q = q.reshape(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = v.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        # QK-norm is applied per head, before RoPE.
        if self.q_norm is not None and self.k_norm is not None:
            q = self.q_norm(q)
            k = self.k_norm(k)
        if self.rope is not None:
            q, k = self.rope(q, k)
        if use_flex:
            block_mask = self._get_block_mask(T, device)
            assert flex_attention is not None and compiled_flex_attention is not None
            # Compiled kernel on CUDA tensors; eager flex_attention otherwise.
            if q.is_cuda:
                y = compiled_flex_attention(q, k, v, block_mask=block_mask, enable_gqa=True)
            else:
                y = flex_attention(q, k, v, block_mask=block_mask, enable_gqa=True)
        else:
            y = sdpa_with_block_causal_mask(q, k, v, diff_block_size=self.config.diff_block_size, mask_cache=self._sdpa_mask_cache, enable_gqa=True)
        # (B, H, T, hd) -> (B, T, C), then output projection.
        y = y.transpose(1, 2).reshape(B, T, C)
        y = self.c_proj(y)
        return y
|
|
class CausalDiffusionBlock(nn.Module):
    """Transformer block for CausalDiffusionLM: block-causal attention followed
    by an MLP, with either pre-norm or post-norm residual wiring."""

    def __init__(self, config: CausalDiffusionConfig) -> None:
        super().__init__()
        # RMSNorm by default; plain LayerNorm when explicitly disabled.
        if getattr(config, 'use_rms_norm', True):
            self.ln_1: nn.Module = RMSNorm(config)
            self.ln_2: nn.Module = RMSNorm(config)
        else:
            self.ln_1 = nn.LayerNorm(config.n_embd)
            self.ln_2 = nn.LayerNorm(config.n_embd)
        self.norm_order = getattr(config, 'norm_order', 'post')
        self.attn = BlockCausalAttention(config)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.norm_order == 'pre':
            # Pre-norm: normalize the input of each residual branch.
            x = x + self.attn(self.ln_1(x))
            return x + self.mlp(self.ln_2(x))
        # Post-norm (default): normalize each sublayer's output before the
        # residual addition.
        x = x + self.ln_1(self.attn(x))
        return x + self.ln_2(self.mlp(x))
|
|
| |
| |
| |
|
|
class CausalDiffusionLM(nn.Module):
    """
    CausalDiffusionLM transformer backbone with block-causal attention.

    Pure compute graph — no training code, no loss logic.

    Args:
        config: CausalDiffusionConfig with model hyperparameters
        vocab_size: Vocabulary size (including special tokens)
    """

    def __init__(self, config: CausalDiffusionConfig, vocab_size: int) -> None:
        super().__init__()
        self.config = config
        self.vocab_size = vocab_size
        self.tok_emb = nn.Embedding(vocab_size, config.n_embd)
        self.blocks = nn.ModuleList([CausalDiffusionBlock(config) for _ in range(config.n_layers)])
        if config.use_rms_norm:
            self.ln_f: nn.Module = RMSNorm(config)
        else:
            self.ln_f = nn.LayerNorm(config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, vocab_size, bias=False)
        # Tie input and output embeddings when requested (shared Parameter object;
        # see _restore_weight_tying for re-tying after to_empty()/device moves).
        if config.weight_sharing:
            self.tok_emb.weight = self.lm_head.weight

    def forward(self, input_ids: torch.Tensor | None, *, input_embeds: torch.Tensor | None=None, return_hidden: bool=False) -> torch.Tensor:
        """
        Forward pass.

        Args:
            input_ids: Token indices [B, T] (may contain mask tokens). May be
                None when input_embeds is supplied.
            input_embeds: Pre-computed embeddings [B, T, D]. If provided, input_ids is ignored.
            return_hidden: If True, return hidden states before lm_head.

        Returns:
            logits [B, T, vocab_size] or hidden_states [B, T, n_embd]

        Raises:
            ValueError: If both input_ids and input_embeds are None.
        """
        if input_embeds is not None:
            x = input_embeds
        elif input_ids is not None:
            x = self.tok_emb(input_ids)
        else:
            raise ValueError('Either input_ids or input_embeds must be provided')
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        if return_hidden:
            return x
        return self.lm_head(x)

    def get_num_params(self, non_embedding: bool=True) -> int:
        """Return number of parameters.

        With non_embedding=True the token-embedding weight is subtracted; under
        weight sharing this also accounts for the tied lm_head weight, since
        `parameters()` yields the shared Parameter only once.
        """
        n_params = sum((p.numel() for p in self.parameters()))
        if non_embedding:
            n_params -= self.tok_emb.weight.numel()
        return n_params

    def _restore_weight_tying(self) -> None:
        """Re-establish weight tying after to_empty() or device transfer."""
        if self.config.weight_sharing:
            self.tok_emb.weight = self.lm_head.weight

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize model weights (used for fresh models, not loaded checkpoints).

        Not invoked in __init__; presumably applied via `self.apply(self._init_weights)`
        by the caller — TODO confirm.
        """
        if isinstance(module, nn.Linear):
            std = 0.02
            # Residual-branch projections (marked SCALE_INIT) are scaled down by
            # (2 * n_layers) ** -0.5.
            if hasattr(module, 'SCALE_INIT'):
                std *= (2 * self.config.n_layers) ** (-0.5)
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, RMSNorm):
            torch.nn.init.ones_(module.weight)
|
|
| |
| |
| |
|
|
@dataclass
class InterpretableOutput:
    """
    Full output from InterpretableCausalDiffusionLM; it contains all decomposition components for attribution and analysis.

    NOTE(review): field semantics below are inferred from names and from
    ConceptHeadOutput elsewhere in this file — confirm against the producer.
    """
    # Backbone hidden states.
    hidden: Tensor
    # Known-concept head outputs (presumably mirroring ConceptHeadOutput fields).
    known_features: Tensor
    known_logits: Tensor | None
    known_gt_features: Tensor | None
    known_predicted: Tensor
    known_weights: Tensor | None
    known_topk_indices: Tensor | None
    known_topk_logits: Tensor | None
    # Unknown-concept head outputs.
    unk: Tensor
    unk_hat: Tensor | None
    unk_for_lm: Tensor
    unknown_logits: Tensor | None
    unknown_weights: Tensor | None
    unknown_topk_indices: Tensor | None
    unknown_topk_logits: Tensor | None
    # Recomposed representation and residual terms.
    composed: Tensor
    epsilon: Tensor | None
    epsilon_true: Tensor | None
|
|
| |
| |
| |
|
|
# Module-level logger. NOTE(review): an identical logger is already bound
# earlier in this file — harmless (getLogger returns the same named logger),
# but suggests this file was concatenated from separate modules.
logger = logging.getLogger(__name__)
# Unknown-concept heads above this size may not materialize dense (B, T, C)
# tensors; enforced by ConceptHead._check_dense_allowed.
LARGE_CONCEPT_THRESHOLD = 50000
|
|
@dataclass
class ConceptHeadOutput:
    """Output from ConceptHead forward pass.

    Fields defaulting to None are only populated in the configurations
    described below (e.g. streaming top-k, return_logits=True).

    Attributes:
        features: Final concept features after teacher forcing/intervention (B, T, D)
        gt_features: Ground truth pooled features. None for unknown heads. (B, T, D) or None
        logits: Full concept logits (B, T, C). Only set if return_logits=True. Usually None.
        predicted: Predicted features before teacher forcing mixing (B, T, D)
        weights: Full concept weights (B, T, C). Only set if return_logits=True. Usually None.
        topk_indices: Top-k concept indices (B, T, k). Set when using streaming top-k.
        topk_logits: Logits for top-k concepts (B, T, k). Set when using streaming top-k.
        hidden: Hidden states passed to this head (B, T, D). Stored for attribution.
    """
    features: Tensor
    gt_features: Tensor | None
    logits: Tensor | None
    predicted: Tensor
    weights: Tensor | None = None
    topk_indices: Tensor | None = None
    topk_logits: Tensor | None = None
    hidden: Tensor | None = None
|
|
| class ConceptHead(nn.Module): |
| """ |
| Concept decomposition head supporting both known and unknown concepts. |
| Memory-efficient implementation that avoids (B, T, C) allocations by default. |
| |
| Modes: |
| - Known (is_unknown=False): Supports GT, teacher forcing, top-k, interventions |
| - Unknown (is_unknown=True): No GT, no teacher forcing |
| |
| Architectures: |
| - use_attention=False: Linear predictor (n_embd -> n_concepts) |
| - use_attention=True: Query projection + sigmoid attention over embeddings |
| |
| Factorization (for large unknown heads): |
| - factorize=False: Dense embeddings (C, D) and predictor (D, C) |
| - factorize=True: Factorized embeddings (C, r) @ (r, D) where r << D |
| Reduces memory by ~10-20x for large C |
| |
| Memory Safety: |
| - Unknown heads with n_concepts > 50k cannot use dense operations |
| - Interventions are only supported for known heads |
| - return_logits=True is forbidden for large unknown heads |
| - All tensor indexing uses F.embedding for DTensor safety |
| |
| Args: |
| n_concepts: Number of concepts (C) |
| concept_dim: Dimension of concept embeddings (should equal n_embd) |
| n_embd: Model hidden dimension |
| is_unknown: If True, skip GT pooling and teacher forcing |
| use_attention: If True, use attention; else use linear predictor |
| topk: Top-k sparsity for concept weights. None = no sparsity. |
| block_size: Block size for memory-efficient operations |
| pad_multiple: Pad n_concepts to a multiple of this for efficiency |
| store_unknown_weights: If True and use_attention & is_unknown, store logits/weights |
| apply_topk_to_unknown: If True, also apply top-k to unknown concepts |
| topk_on_logits: If True, apply top-k on logits (then sigmoid). If False, on weights. |
| teacher_force_alpha: If None, hard TF. If in [0,1], soft mixing. |
| factorize: If True, use low-rank factorized embeddings |
| factorize_rank: Rank for factorization (r). Lower = less memory, less expressivity. |
| """ |
|
|
    class ConceptPooling(nn.Module):
        """Memory-efficient sum pooling using scatter-add."""

        def __init__(self, concept_dim: int):
            super().__init__()
            # Stored for reference; forward() reads D from the embedding layer itself.
            self.concept_dim = concept_dim

        def forward(self, concept_ids: Tensor, concept_mask: Tensor, concept_embeddings: nn.Embedding) -> Tensor:
            """
            Pool concept embeddings based on ground truth IDs.
            Uses scatter-add to avoid (B, T, K, D) allocation when K is sparse.

            Args:
                concept_ids: (B, T, K) concept indices, -1 for invalid
                concept_mask: (B, T, K) boolean mask for valid concepts
                concept_embeddings: Embedding layer to look up

            Returns:
                Pooled features (B, T, D)
            """
            B, T, K = concept_ids.shape
            D = concept_embeddings.embedding_dim
            device = concept_ids.device
            # A slot contributes only when masked-in AND holding a real id (!= -1).
            valid_mask = concept_mask & (concept_ids != -1)
            pooled = torch.zeros(B, T, D, device=device, dtype=concept_embeddings.weight.dtype)
            if not valid_mask.any():
                return pooled
            # Sparse coordinates of every valid (batch, time, slot) entry.
            b_idx, t_idx, k_idx = torch.where(valid_mask)
            c_ids = concept_ids[b_idx, t_idx, k_idx].long()
            emb = concept_embeddings(c_ids)
            # Collapse (b, t) into one row index and sum-accumulate the embeddings.
            flat_idx = b_idx * T + t_idx
            flat_idx = flat_idx.unsqueeze(-1).expand(-1, D)
            pooled_flat = pooled.view(B * T, D)
            pooled_flat.scatter_add_(0, flat_idx, emb)
            # scatter_add_ wrote through the view; return the (B, T, D) shape.
            return pooled.view(B, T, D)
|
|
    def __init__(self, n_concepts: int, concept_dim: int, n_embd: int, is_unknown: bool=False, use_attention: bool=False, topk: int | None=16, topk_features: int | None=None, block_size: int=8192, *, pad_multiple: int=16, store_unknown_weights: bool=False, apply_topk_to_unknown: bool=False, topk_on_logits: bool=False, factorize: bool=False, factorize_rank: int=256):
        """Build the head; see the class docstring for full argument semantics."""
        super().__init__()
        self.n_concepts = n_concepts
        self.concept_dim = concept_dim
        self.n_embd = n_embd
        self.is_unknown = is_unknown
        self.use_attention = use_attention
        self.topk = topk
        # Feature top-k falls back to the loss top-k when not given separately.
        self.topk_features = topk_features if topk_features is not None else topk
        self.block_size = block_size
        self.pad_multiple = pad_multiple
        self.store_unknown_weights = store_unknown_weights
        self.apply_topk_to_unknown = apply_topk_to_unknown
        self.topk_on_logits = topk_on_logits
        self.factorize = factorize
        self.factorize_rank = factorize_rank
        # Large heads must use streaming paths (see _check_dense_allowed).
        self._is_large = n_concepts > LARGE_CONCEPT_THRESHOLD
        # Round n_concepts up to a multiple of pad_multiple for efficiency.
        self.n_concepts_padded = (n_concepts + pad_multiple - 1) // pad_multiple * pad_multiple
        if factorize:
            # Low-rank factorization: embedding = coef (C, r) projected by basis (r -> D).
            self.embedding_coef = nn.Embedding(self.n_concepts_padded, factorize_rank)
            self.embedding_basis = nn.Linear(factorize_rank, concept_dim, bias=False)
            self.concept_embedding = None
            if not use_attention:
                # Factorized linear predictor: n_embd -> r -> C.
                self.predictor_down = nn.Linear(n_embd, factorize_rank, bias=False)
                self.predictor_up = nn.Linear(factorize_rank, self.n_concepts_padded, bias=False)
                self.concept_predictor = None
            else:
                self.concept_query_projection = nn.Linear(n_embd, concept_dim, bias=False)
                self.predictor_down = None
                self.predictor_up = None
                self.concept_predictor = None
            # Log the estimated memory saved by factorization (2 bytes/param assumed).
            dense_params = n_concepts * concept_dim * 2
            factorized_params = n_concepts * factorize_rank + factorize_rank * concept_dim + (n_embd * factorize_rank + factorize_rank * n_concepts if not use_attention else 0)
            logger.info(f'[ConceptHead] Factorized mode: {n_concepts} concepts, rank={factorize_rank}')
            logger.info(f'[ConceptHead] Memory: {dense_params * 2 / 1000000000.0:.2f} GB (dense) -> {factorized_params * 2 / 1000000000.0:.2f} GB (factorized) = {(1 - factorized_params / dense_params) * 100:.1f}% reduction')
        else:
            # Dense embeddings; predictor is either attention (query projection)
            # or a plain linear layer over all (padded) concepts.
            self.concept_embedding = nn.Embedding(self.n_concepts_padded, concept_dim)
            self.embedding_coef = None
            self.embedding_basis = None
            if use_attention:
                self.concept_query_projection = nn.Linear(n_embd, concept_dim, bias=False)
                self.concept_predictor = None
            else:
                self.concept_predictor = nn.Linear(n_embd, self.n_concepts_padded, bias=False)
            self.predictor_down = None
            self.predictor_up = None
        self.concept_pooling = self.ConceptPooling(concept_dim)
        if self.topk_features != self.topk:
            logger.info(f"[ConceptHead] {('Unknown' if is_unknown else 'Known')} head: topk={self.topk} (loss), topk_features={self.topk_features} (features)")
        if is_unknown and apply_topk_to_unknown:
            logger.info(f'[ConceptHead] Unknown head: apply_topk_to_unknown=True, topk={self.topk}')
        self._init_weights()
|
|
| def _init_weights(self): |
| """Initialize weights with small values.""" |
| if self.factorize: |
| nn.init.normal_(self.embedding_coef.weight, mean=0.0, std=0.02) |
| nn.init.normal_(self.embedding_basis.weight, mean=0.0, std=0.02) |
| if self.predictor_down is not None: |
| nn.init.normal_(self.predictor_down.weight, mean=0.0, std=0.02) |
| if self.predictor_up is not None: |
| nn.init.normal_(self.predictor_up.weight, mean=0.0, std=0.02) |
| else: |
| if self.concept_embedding is not None: |
| nn.init.normal_(self.concept_embedding.weight, mean=0.0, std=0.02) |
| if self.concept_predictor is not None: |
| nn.init.normal_(self.concept_predictor.weight, mean=0.0, std=0.02) |
| if hasattr(self, 'concept_query_projection') and self.concept_query_projection is not None: |
| nn.init.normal_(self.concept_query_projection.weight, mean=0.0, std=0.02) |
|
|
| def _check_dense_allowed(self, operation: str) -> None: |
| """Raise error if dense operations are requested for large unknown heads.""" |
| if self.is_unknown and self._is_large: |
| raise ValueError(f'{operation} requested for unknown head with {self.n_concepts} concepts. This would allocate multi-GB tensors. Use streaming mode instead. (Threshold: {LARGE_CONCEPT_THRESHOLD})') |
|
|
| @staticmethod |
| def _safe_index(weight: Tensor, indices: Tensor) -> Tensor: |
| """ |
| DTensor-safe indexing using F.embedding. |
| |
| Replaces weight[indices] which crashes under FSDP2/DTensor. |
| |
| Args: |
| weight: (N, D) weight matrix |
| indices: (...) indices to select |
| |
| Returns: |
| (..., D) selected embeddings |
| """ |
| original_shape = indices.shape |
| flat_indices = indices.reshape(-1) |
| flat_result = F.embedding(flat_indices, weight) |
| return flat_result.reshape(*original_shape, -1) |
|
|
| def _get_embedding_weight(self) -> Tensor: |
| """ |
| Get full embedding matrix. |
| |
| For dense: returns concept_embedding.weight |
| For factorized: computes coef @ basis (materializes full matrix) |
| |
| Returns: |
| (C, D) embedding matrix |
| """ |
| if self.concept_embedding is not None: |
| return self.concept_embedding.weight |
| else: |
| return self.embedding_basis(self.embedding_coef.weight) |
|
|
| def _get_embedding(self, indices: Tensor) -> Tensor: |
| """ |
| Get embeddings for specific indices (DTensor-safe). |
| |
| For dense: uses F.embedding |
| For factorized: looks up coef, then applies basis |
| |
| Args: |
| indices: (...) concept indices |
| |
| Returns: |
| (..., D) embeddings |
| """ |
| if self.concept_embedding is not None: |
| return self.concept_embedding(indices) |
| else: |
| coef = self.embedding_coef(indices) |
| return self.embedding_basis(coef) |
|
|
| def _get_predictor_weight(self) -> Tensor | None: |
| """ |
| Get full predictor weight matrix (for linear path only). |
| |
| Returns: |
| (C, D) predictor weight, or None if using attention |
| """ |
| if self.concept_predictor is not None: |
| return self.concept_predictor.weight |
| elif self.predictor_down is not None and self.predictor_up is not None: |
| return self.predictor_up.weight @ self.predictor_down.weight |
| else: |
| return None |
|
|
| @staticmethod |
| def _merge_topk(topv: Tensor, topi: Tensor, v_blk: Tensor, i_blk: Tensor, k: int) -> tuple[Tensor, Tensor]: |
| """Efficient merge of two top-k sets. Memory: O(BT × 2k).""" |
| cand_v = torch.cat([topv, v_blk], dim=1) |
| cand_i = torch.cat([topi, i_blk], dim=1) |
| new_v, sel = torch.topk(cand_v, k, dim=1) |
| new_i = torch.gather(cand_i, 1, sel) |
| return (new_v, new_i) |
|
|
| @staticmethod |
| def linear_block_features(hidden: Tensor, predictor_weight: Tensor, embeddings: Tensor, block_size: int=4096) -> Tensor: |
| """ |
| Memory-efficient linear prediction without materializing (B, T, C). |
| |
| Args: |
| hidden: (B, T, D) |
| predictor_weight: (C, D) |
| embeddings: (C, D) |
| block_size: Concepts per block |
| |
| Returns: |
| Features (B, T, D) |
| """ |
| B, T, D = hidden.shape |
| C = predictor_weight.size(0) |
| output = torch.zeros(B, T, D, dtype=hidden.dtype, device=hidden.device) |
| flat_h = hidden.reshape(-1, D) |
| W_t = predictor_weight.t().contiguous() |
| for start in range(0, C, block_size): |
| end = min(start + block_size, C) |
| logits_block = (flat_h @ W_t[:, start:end]).to(torch.float32) |
| logits_block = logits_block.clamp(-15, 15) |
| weights_block = torch.sigmoid(logits_block) |
| E_block = embeddings[start:end].to(weights_block.dtype) |
| output.add_((weights_block @ E_block).reshape(B, T, D)) |
| return output.to(hidden.dtype) |
|
|
| @staticmethod |
| def attention_block_features(query: Tensor, embeddings: Tensor, block_size: int=4096) -> Tensor: |
| """Memory-efficient attention features without materializing (B, T, C).""" |
| B, T, D = query.shape |
| C = embeddings.shape[0] |
| scale = 1.0 / math.sqrt(D) |
| flat_q = query.reshape(-1, D) |
| emb_T = embeddings.t().contiguous() |
| output = torch.zeros(B * T, D, dtype=query.dtype, device=query.device) |
| for start in range(0, C, block_size): |
| end = min(start + block_size, C) |
| scores = (flat_q @ emb_T[:, start:end]).to(torch.float32) * scale |
| scores = scores.clamp(-15, 15) |
| weights = torch.sigmoid(scores) |
| output.add_(weights @ embeddings[start:end].to(weights.dtype)) |
| return output.reshape(B, T, D).to(query.dtype) |
|
|
    @staticmethod
    def linear_features_topk_streaming(hidden: Tensor, predictor_weight: Tensor, embeddings: Tensor, k: int, block_size: int=4096, topk_on_logits: bool=False) -> tuple[Tensor, Tensor, Tensor]:
        """
        Memory-efficient linear prediction with streaming top-k.

        Uses merge-k-with-k to keep memory O(BT × k), not O(BT × block_size):
        a first pass scans concept blocks maintaining a running per-token
        top-k, a second pass recomputes exact logits and features for the
        selected concepts only.

        Args:
            hidden: (B, T, D)
            predictor_weight: (C, D)
            embeddings: (C, D)
            k: Number of top concepts
            block_size: Concepts per block
            topk_on_logits: If True, select top-k by logits; else by sigmoid

        Returns:
            features: (B, T, D) weighted concept features
            topk_indices: (B, T, k) indices of top-k concepts
            topk_logits: (B, T, k) logits for top-k concepts
        """
        B, T, D = hidden.shape
        C = predictor_weight.size(0)
        BT = B * T
        device = hidden.device
        # Never request more concepts than exist.
        k = min(k, C)
        flat_h = hidden.reshape(BT, D)
        W_t = predictor_weight.t().contiguous()
        # Running top-k state; -inf sentinels are displaced by real scores.
        # NOTE(review): topv is allocated in hidden.dtype while block values
        # are float32 — presumably _merge_topk promotes; confirm against its
        # implementation.
        topv = torch.full((BT, k), float('-inf'), device=device, dtype=hidden.dtype)
        topi = torch.zeros((BT, k), device=device, dtype=torch.long)
        for start in range(0, C, block_size):
            end = min(start + block_size, C)
            # Block logits in float32, clamped for sigmoid stability.
            logits_blk = (flat_h @ W_t[:, start:end]).to(torch.float32).clamp_(-15, 15)
            vals_blk = logits_blk if topk_on_logits else torch.sigmoid(logits_blk)
            blk_k = min(k, end - start)
            v_blk, idx_blk = torch.topk(vals_blk, blk_k, dim=1)
            i_blk = idx_blk + start  # block-local -> global concept ids
            if blk_k < k:
                # Short (final) block: pad with -inf so merge widths match.
                pad_v = torch.full((BT, k - blk_k), float('-inf'), device=device, dtype=torch.float32)
                pad_i = torch.zeros((BT, k - blk_k), device=device, dtype=torch.long)
                v_blk = torch.cat([v_blk, pad_v], dim=1)
                i_blk = torch.cat([i_blk, pad_i], dim=1)
            topv, topi = ConceptHead._merge_topk(topv, topi, v_blk, i_blk, k)
        # Second pass: exact logits for the selected concepts only.
        W_sel = ConceptHead._safe_index(predictor_weight, topi)
        logits_sel = torch.einsum('bd,bkd->bk', flat_h.to(torch.float32), W_sel.to(torch.float32))
        logits_sel = logits_sel.clamp(-15, 15)
        del W_sel  # free the (BT, k, D) gather before allocating embeddings
        weights_sel = torch.sigmoid(logits_sel)
        E_sel = ConceptHead._safe_index(embeddings, topi)
        features = torch.einsum('bk,bkd->bd', weights_sel.to(E_sel.dtype), E_sel)
        return (features.reshape(B, T, D).to(hidden.dtype), topi.reshape(B, T, k), logits_sel.reshape(B, T, k))
|
|
    @staticmethod
    def attention_features_topk_streaming(query: Tensor, embeddings: Tensor, k: int, block_size: int=4096, topk_on_logits: bool=False) -> tuple[Tensor, Tensor, Tensor]:
        """
        Memory-efficient attention with streaming top-k.

        Scans the embedding table in blocks keeping a running per-token
        top-k (merge k-with-k), then recomputes exact logits and features for
        the selected concepts only, so memory stays O(BT × k).

        Args:
            query: (B, T, D) query vectors
            embeddings: (C, D) concept embeddings
            k: Number of top concepts per token
            block_size: Concepts per block
            topk_on_logits: If True, select top-k by logits; else by sigmoid

        Returns:
            features: (B, T, D) weighted concept features
            topk_indices: (B, T, k) top-k concept indices
            topk_logits: (B, T, k) logits for top-k concepts
        """
        B, T, D = query.shape
        C = embeddings.shape[0]
        BT = B * T
        device = query.device
        scale = 1.0 / math.sqrt(D)
        # Never request more concepts than exist.
        k = min(k, C)
        flat_q = query.reshape(BT, D)
        emb_T = embeddings.t().contiguous()
        # Running top-k state; -inf sentinels are displaced by real scores.
        # NOTE(review): topv uses query.dtype while block values are float32 —
        # presumably _merge_topk promotes; confirm against its implementation.
        topv = torch.full((BT, k), float('-inf'), device=device, dtype=query.dtype)
        topi = torch.zeros((BT, k), device=device, dtype=torch.long)
        for start in range(0, C, block_size):
            end = min(start + block_size, C)
            # Scaled scores in float32, clamped for a stable sigmoid.
            logits_blk = (flat_q @ emb_T[:, start:end]).to(torch.float32) * scale
            logits_blk = logits_blk.clamp(-15, 15)
            vals_blk = logits_blk if topk_on_logits else torch.sigmoid(logits_blk)
            blk_k = min(k, end - start)
            v_blk, idx_blk = torch.topk(vals_blk, blk_k, dim=1)
            i_blk = idx_blk + start  # block-local -> global concept ids
            if blk_k < k:
                # Short (final) block: pad with -inf so merge widths match.
                pad_v = torch.full((BT, k - blk_k), float('-inf'), device=device, dtype=torch.float32)
                pad_i = torch.zeros((BT, k - blk_k), device=device, dtype=torch.long)
                v_blk = torch.cat([v_blk, pad_v], dim=1)
                i_blk = torch.cat([i_blk, pad_i], dim=1)
            topv, topi = ConceptHead._merge_topk(topv, topi, v_blk, i_blk, k)
        # Second pass: exact logits/features for the selected concepts only.
        E_sel = ConceptHead._safe_index(embeddings, topi)
        logits_sel = torch.einsum('bd,bkd->bk', flat_q.to(torch.float32), E_sel.to(torch.float32)) * scale
        logits_sel = logits_sel.clamp(-15, 15)
        weights_sel = torch.sigmoid(logits_sel)
        features = torch.einsum('bk,bkd->bd', weights_sel.to(E_sel.dtype), E_sel)
        return (features.reshape(B, T, D).to(query.dtype), topi.reshape(B, T, k), logits_sel.reshape(B, T, k))
|
|
| def attention_block_features_factorized(self, query: Tensor, block_size: int=4096) -> Tensor: |
| """ |
| Memory-efficient factorized attention over ALL concepts. |
| |
| Uses factorized scoring and feature computation: |
| - Scoring: (query @ basis.T) @ coef.T instead of query @ E.T |
| - Features: (weights @ coef) @ basis instead of weights @ E |
| |
| FLOPs: O(BT * r * (D + C)) instead of O(BT * D * C) |
| |
| Args: |
| query: (B, T, D) query vectors from concept_query_projection |
| block_size: Concepts per block for chunked processing |
| |
| Returns: |
| (B, T, D) weighted concept features |
| """ |
| assert self.factorize, 'Only valid for factorized head' |
| B, T, D = query.shape |
| BT = B * T |
| C = self.n_concepts |
| _ = self.factorize_rank |
| device = query.device |
| scale = 1.0 / math.sqrt(D) |
| flat_q = query.reshape(BT, D) |
| coef = self.embedding_coef.weight[:C] |
| basis_weight = self.embedding_basis.weight |
| q_compressed = flat_q @ basis_weight |
| output = torch.zeros(BT, D, dtype=query.dtype, device=device) |
| _ = (C + block_size - 1) // block_size |
| for _block_idx, start in enumerate(range(0, C, block_size)): |
| end = min(start + block_size, C) |
| coef_chunk = coef[start:end] |
| scores_chunk = (q_compressed @ coef_chunk.T).float() * scale |
| scores_chunk = scores_chunk.clamp(-15, 15) |
| weights_chunk = torch.sigmoid(scores_chunk) |
| weighted_coef = weights_chunk @ coef_chunk.float() |
| features_chunk = weighted_coef @ basis_weight.T.to(weighted_coef.dtype) |
| output.add_(features_chunk) |
| return output.reshape(B, T, D).to(query.dtype) |
|
|
    def attention_features_topk_factorized(self, query: Tensor, k: int, block_size: int=4096) -> tuple[Tensor, Tensor, Tensor]:
        """
        Memory-efficient factorized attention with streaming top-k.

        Pass 1: Find top-k concepts using factorized scoring
        Pass 2: Compute features using only top-k embeddings

        Args:
            query: (B, T, D) query vectors
            k: Number of top concepts per token
            block_size: Concepts per block

        Returns:
            features: (B, T, D) weighted concept features
            topk_indices: (B, T, k) top-k concept indices
            topk_logits: (B, T, k) logits for top-k concepts
        """
        assert self.factorize, 'Only valid for factorized head'
        B, T, D = query.shape
        BT = B * T
        C = self.n_concepts
        _ = self.factorize_rank  # value unused here
        device = query.device
        scale = 1.0 / math.sqrt(D)
        # Never request more concepts than exist.
        k = min(k, C)
        flat_q = query.reshape(BT, D)
        coef = self.embedding_coef.weight[:C]
        basis_weight = self.embedding_basis.weight
        # Compress the query once: (BT, D) @ (D, r) -> (BT, r).
        q_compressed = flat_q @ basis_weight
        # Running top-k state; -inf sentinels are displaced by real scores.
        # NOTE(review): topv uses query.dtype while chunk scores are float32 —
        # presumably _merge_topk promotes; confirm against its implementation.
        topv = torch.full((BT, k), float('-inf'), device=device, dtype=query.dtype)
        topi = torch.zeros((BT, k), device=device, dtype=torch.long)
        for start in range(0, C, block_size):
            end = min(start + block_size, C)
            coef_chunk = coef[start:end]
            # Factorized scores in float32, clamped for stability.
            scores_chunk = q_compressed.float() @ coef_chunk.T.float() * scale
            scores_chunk = scores_chunk.clamp(-15, 15)
            blk_k = min(k, end - start)
            v_chunk, idx_chunk = torch.topk(scores_chunk, blk_k, dim=1)
            i_chunk = idx_chunk + start  # block-local -> global concept ids
            if blk_k < k:
                # Short (final) block: pad with -inf so merge widths match.
                pad_v = torch.full((BT, k - blk_k), float('-inf'), device=device, dtype=torch.float32)
                pad_i = torch.zeros((BT, k - blk_k), device=device, dtype=torch.long)
                v_chunk = torch.cat([v_chunk, pad_v], dim=1)
                i_chunk = torch.cat([i_chunk, pad_i], dim=1)
            topv, topi = self._merge_topk(topv, topi, v_chunk, i_chunk, k)
        # Pass 2: exact logits/features for the selected concepts only.
        coef_sel = self.embedding_coef(topi)
        logits_sel = torch.einsum('br,bkr->bk', q_compressed.float(), coef_sel.float()) * scale
        logits_sel = logits_sel.clamp(-15, 15)
        weights_sel = torch.sigmoid(logits_sel)
        # Expand factorized features: (BT, k) x (BT, k, r) -> (BT, r) -> (BT, D).
        weighted_coef = torch.einsum('bk,bkr->br', weights_sel.to(coef_sel.dtype), coef_sel)
        features = weighted_coef @ basis_weight.T.to(weighted_coef.dtype)
        return (features.reshape(B, T, D).to(query.dtype), topi.reshape(B, T, k), logits_sel.reshape(B, T, k))
|
|
| def linear_block_features_factorized(self, hidden: Tensor, block_size: int=4096) -> Tensor: |
| """ |
| Memory-efficient factorized linear prediction over ALL concepts. |
| |
| Uses factorized predictor: logits = hidden @ down @ up.T |
| Uses factorized embeddings: features = weights @ coef @ basis |
| |
| Args: |
| hidden: (B, T, D) hidden states |
| block_size: Concepts per block |
| |
| Returns: |
| (B, T, D) weighted concept features |
| """ |
| assert self.factorize, 'Only valid for factorized head' |
| assert self.predictor_down is not None, 'Linear path requires predictor' |
| B, T, D = hidden.shape |
| BT = B * T |
| C = self.n_concepts |
| _ = self.factorize_rank |
| device = hidden.device |
| flat_h = hidden.reshape(BT, D) |
| coef = self.embedding_coef.weight[:C] |
| basis_weight = self.embedding_basis.weight |
| down_weight = self.predictor_down.weight |
| up_weight = self.predictor_up.weight[:C] |
| h_compressed = flat_h @ down_weight.T |
| output = torch.zeros(BT, D, dtype=hidden.dtype, device=device) |
| for start in range(0, C, block_size): |
| end = min(start + block_size, C) |
| up_chunk = up_weight[start:end] |
| coef_chunk = coef[start:end] |
| logits_chunk = h_compressed.float() @ up_chunk.T.float() |
| logits_chunk = logits_chunk.clamp(-15, 15) |
| weights_chunk = torch.sigmoid(logits_chunk) |
| weighted_coef = weights_chunk @ coef_chunk.float() |
| features_chunk = weighted_coef @ basis_weight.T.to(weighted_coef.dtype) |
| output.add_(features_chunk) |
| return output.reshape(B, T, D).to(hidden.dtype) |
|
|
    def linear_features_topk_factorized(self, hidden: Tensor, k: int, block_size: int=4096) -> tuple[Tensor, Tensor, Tensor]:
        """
        Memory-efficient factorized linear with streaming top-k.

        Pass 1 scans concept blocks with factorized logits, keeping a running
        per-token top-k; pass 2 recomputes exact logits and features for the
        selected concepts only.

        Args:
            hidden: (B, T, D) hidden states
            k: Number of top concepts per token
            block_size: Concepts per block

        Returns:
            features: (B, T, D) weighted concept features
            topk_indices: (B, T, k) top-k concept indices
            topk_logits: (B, T, k) logits for top-k concepts
        """
        assert self.factorize, 'Only valid for factorized head'
        assert self.predictor_down is not None, 'Linear path requires predictor'
        B, T, D = hidden.shape
        BT = B * T
        C = self.n_concepts
        _ = self.factorize_rank  # value unused here
        device = hidden.device
        # Never request more concepts than exist.
        k = min(k, C)
        flat_h = hidden.reshape(BT, D)
        down_weight = self.predictor_down.weight
        up_weight = self.predictor_up.weight[:C]
        basis_weight = self.embedding_basis.weight
        # Compress once: (BT, D) @ (D, r) -> (BT, r).
        h_compressed = flat_h @ down_weight.T
        # Running top-k state; -inf sentinels are displaced by real scores.
        # NOTE(review): topv uses hidden.dtype while chunk logits are float32 —
        # presumably _merge_topk promotes; confirm against its implementation.
        topv = torch.full((BT, k), float('-inf'), device=device, dtype=hidden.dtype)
        topi = torch.zeros((BT, k), device=device, dtype=torch.long)
        for start in range(0, C, block_size):
            end = min(start + block_size, C)
            up_chunk = up_weight[start:end]
            # Chunk logits in float32, clamped for stability.
            logits_chunk = h_compressed.float() @ up_chunk.T.float()
            logits_chunk = logits_chunk.clamp(-15, 15)
            blk_k = min(k, end - start)
            v_chunk, idx_chunk = torch.topk(logits_chunk, blk_k, dim=1)
            i_chunk = idx_chunk + start  # block-local -> global concept ids
            if blk_k < k:
                # Short (final) block: pad with -inf so merge widths match.
                pad_v = torch.full((BT, k - blk_k), float('-inf'), device=device, dtype=torch.float32)
                pad_i = torch.zeros((BT, k - blk_k), device=device, dtype=torch.long)
                v_chunk = torch.cat([v_chunk, pad_v], dim=1)
                i_chunk = torch.cat([i_chunk, pad_i], dim=1)
            topv, topi = self._merge_topk(topv, topi, v_chunk, i_chunk, k)
        # Pass 2: exact logits/features for the selected concepts only.
        # Note: coef_sel uses the embedding module's own lookup while up_sel
        # goes through _safe_index on the sliced weight.
        coef_sel = self.embedding_coef(topi)
        up_sel = self._safe_index(self.predictor_up.weight[:C], topi)
        logits_sel = torch.einsum('br,bkr->bk', h_compressed.float(), up_sel.float())
        logits_sel = logits_sel.clamp(-15, 15)
        weights_sel = torch.sigmoid(logits_sel)
        # Expand factorized features: (BT, k) x (BT, k, r) -> (BT, r) -> (BT, D).
        weighted_coef = torch.einsum('bk,bkr->br', weights_sel.to(coef_sel.dtype), coef_sel)
        features = weighted_coef @ basis_weight.T.to(weighted_coef.dtype)
        return (features.reshape(B, T, D).to(hidden.dtype), topi.reshape(B, T, k), logits_sel.reshape(B, T, k))
|
|
    def compute_logits_for_indices(self, hidden: Tensor, indices: Tensor) -> Tensor:
        """
        Compute logits for specific concept indices only (sparse).

        Supports both dense and factorized heads, and both the attention
        (query @ E / sqrt(d)) and linear (hidden @ W) scoring paths.

        IMPORTANT: This function materializes (M, K, D) where M is the number of
        tokens in hidden. Only call this with small M (e.g., masked tokens only).

        Args:
            hidden: (M, D) or (B, T, D) hidden states
            indices: (M, K) or (B, T, K) concept indices

        Returns:
            logits: Same shape as indices, float32, clamped to [-15, 15]
        """
        # Normalize both input layouts to flat (M, D) / (M, K).
        if hidden.dim() == 2:
            M, D = hidden.shape
            K = indices.size(-1)
            flat_h = hidden
            flat_idx = indices
            output_shape = indices.shape
        else:
            B, T, D = hidden.shape
            K = indices.size(-1)
            M = B * T
            flat_h = hidden.reshape(M, D)
            flat_idx = indices.reshape(M, K)
            output_shape = indices.shape
        # Rough fp16-sized estimate of the (M, K, D) gather this will allocate.
        estimated_bytes = M * K * D * 2
        if estimated_bytes > 1000000000.0:
            warnings.warn(f'compute_logits_for_indices will allocate ~{estimated_bytes / 1000000000.0:.1f} GB. Consider reducing M={M} (use masked tokens only) or K={K}.')
        # Out-of-range indices are silently clamped into the valid range.
        n_valid = self.n_concepts
        indices_safe = flat_idx.clamp(0, n_valid - 1)
        if self.use_attention:
            # Attention path: project hidden to query space, score vs embeddings.
            query = self.concept_query_projection(flat_h.unsqueeze(0)).squeeze(0)
            scale = 1.0 / math.sqrt(self.concept_dim)
            E_sel = self._get_embedding(indices_safe)
            logits = torch.einsum('md,mkd->mk', query.float(), E_sel.float()) * scale
        else:
            # Linear path: gather predictor rows for the requested concepts.
            if self.factorize:
                W = self._get_predictor_weight()[:n_valid]
                W_sel = self._safe_index(W, indices_safe)
            else:
                W = self.concept_predictor.weight[:n_valid]
                W_sel = self._safe_index(W, indices_safe)
            logits = torch.einsum('md,mkd->mk', flat_h.float(), W_sel.float())
        return logits.clamp(-15, 15).reshape(output_shape)
|
|
| def get_concept_weights(self, hidden: Tensor, concept_ids: Tensor) -> Tensor: |
| """ |
| Get sigmoid weights for specific concepts (for attribution). |
| |
| Args: |
| hidden: (B, T, D) or (M, D) hidden states |
| concept_ids: (B, T, K) or (M, K) or (K,) concept indices |
| |
| Returns: |
| weights: Same shape as concept_ids, values in [0, 1] |
| """ |
| if concept_ids.dim() == 1: |
| if hidden.dim() == 2: |
| M = hidden.size(0) |
| concept_ids = concept_ids.unsqueeze(0).expand(M, -1) |
| else: |
| B, T, _ = hidden.shape |
| concept_ids = concept_ids.unsqueeze(0).unsqueeze(0).expand(B, T, -1) |
| logits = self.compute_logits_for_indices(hidden, concept_ids) |
| return torch.sigmoid(logits) |
|
|
| @staticmethod |
| def blocked_logits(query: Tensor, embeddings: Tensor, block_size: int=8192, out_device: torch.device | None=None, out_dtype: torch.dtype=torch.float32) -> Tensor: |
| """ |
| Compute concept logits in column blocks for memory efficiency. |
| |
| logits = query @ embeddings.T / sqrt(D) |
| """ |
| B, T, D = query.shape |
| C = embeddings.size(0) |
| scale = 1.0 / math.sqrt(D) |
| dev = query.device if out_device is None else out_device |
| logits = torch.empty(B, T, C, device=dev, dtype=out_dtype) |
| q = query.reshape(-1, D).to(torch.float32) |
| Et = embeddings.t().contiguous().to(torch.float32) |
| for s in range(0, C, block_size): |
| e = min(s + block_size, C) |
| scores = q @ Et[:, s:e] * scale |
| scores = scores.clamp(-15, 15) |
| logits[:, :, s:e] = scores.reshape(B, T, e - s).to(out_dtype) |
| return logits |
|
|
| @staticmethod |
| def blocked_mix(weights: Tensor, embeddings: Tensor, block_size: int=8192) -> Tensor: |
| """ |
| Compute weighted sum of embeddings in column blocks. |
| |
| output = weights @ embeddings |
| """ |
| B, T, C = weights.shape |
| D = embeddings.size(1) |
| out = torch.zeros(B, T, D, device=weights.device, dtype=weights.dtype) |
| for s in range(0, C, block_size): |
| e = min(s + block_size, C) |
| w_blk = weights[:, :, s:e].to(torch.float32) |
| V_blk = embeddings[s:e].to(w_blk.dtype) |
| out.add_(w_blk @ V_blk) |
| return out.to(weights.dtype) |
|
|
| @staticmethod |
| def sigmoid_block_attention(query: Tensor, embeddings: Tensor, block_size: int=8192, return_logits: bool=False) -> Tensor | tuple[Tensor, Tensor]: |
| """Memory-efficient sigmoid attention using block processing.""" |
| B, T, D = query.shape |
| C = embeddings.shape[0] |
| scale = 1.0 / math.sqrt(D) |
| flat_q = query.reshape(-1, D) |
| emb_T = embeddings.t().contiguous() |
| output = torch.zeros(B * T, D, dtype=query.dtype, device=query.device) |
| logits: Tensor | None = None |
| if return_logits: |
| logits = torch.empty(B, T, C, dtype=torch.float32, device=query.device) |
| for start in range(0, C, block_size): |
| end = min(start + block_size, C) |
| scores = (flat_q @ emb_T[:, start:end]).to(torch.float32) * scale |
| scores = scores.clamp(-15, 15) |
| if logits is not None: |
| logits[:, :, start:end] = scores.reshape(B, T, end - start) |
| weights = torch.sigmoid(scores) |
| output.add_(weights @ embeddings[start:end].to(weights.dtype)) |
| output = output.reshape(B, T, D).to(query.dtype) |
| if return_logits: |
| assert logits is not None |
| return (output, logits) |
| return output |
|
|
    def _apply_sparse_interventions(self, features: Tensor, hidden: Tensor, intervene_ids: Tensor, intervene_vals: Tensor) -> Tensor:
        """
        Apply sparse interventions matching original dense behavior.

        Original dense behavior:
            weights = sigmoid(logits)  # (B, T, C)
            weights[..., c] = new_val  # Override
            features = weights @ embeddings

        Sparse equivalent:
            features += (new_val - current_weight) * embedding[c]

        Args:
            features: (B, T, D) already-computed concept features
            hidden: (B, T, D) hidden states, used to recompute current weights
            intervene_ids: (B, T, K) concept indices; -1 marks "no edit"
            intervene_vals: (B, T, K) replacement weight values

        Returns:
            (B, T, D) features with the intervention correction added
        """
        B, T, D = features.shape
        valid = intervene_ids != -1
        if not valid.any():
            # Nothing to edit: skip the logit recomputation entirely.
            return features
        # -1 sentinels are clamped to valid indices but zeroed via `valid` below.
        ids_safe = intervene_ids.clamp(0, self.n_concepts - 1)
        current_logits = self.compute_logits_for_indices(hidden, ids_safe)
        current_weights = torch.sigmoid(current_logits)
        emb = self._get_embedding(ids_safe)
        # Per-edit weight delta; invalid slots contribute zero correction.
        delta = (intervene_vals - current_weights) * valid.float()
        correction = (delta.unsqueeze(-1) * emb).sum(dim=2)
        return features + correction
|
|
| def _apply_dense_interventions(self, concept_weight: Tensor, intervene_ids: Tensor, intervene_vals: Tensor) -> Tensor: |
| """Apply interventions by overriding concept weights (dense path).""" |
| n_valid = min(self.n_concepts, concept_weight.size(-1)) |
| valid_edit = intervene_ids != -1 |
| ids = intervene_ids.clamp(0, n_valid - 1).long() |
| vals = intervene_vals.to(concept_weight.dtype) |
| updates = torch.zeros_like(concept_weight) |
| updates.scatter_add_(2, ids, torch.where(valid_edit, vals, torch.zeros_like(vals))) |
| set_mask = torch.zeros_like(concept_weight, dtype=torch.bool) |
| set_mask.scatter_(2, ids, valid_edit) |
| return torch.where(set_mask, updates, concept_weight) |
|
|
| def topk_with_cutoff(self, tensor: Tensor, dim: int=-1) -> Tensor: |
| """ |
| Apply top-k sparsity, zeroing out all but top-k values. |
| |
| Args: |
| tensor: Input tensor, typically (B, T, C) |
| dim: Dimension to apply top-k (default: last) |
| |
| Returns: |
| Sparse tensor with only top-k values preserved |
| """ |
| assert dim == -1 or dim == tensor.dim() - 1 |
| if self.topk is None: |
| return tensor |
| padded = tensor.size(dim) |
| n_valid = min(self.n_concepts, padded) |
| if n_valid <= 0: |
| return torch.zeros_like(tensor) |
| x = tensor.narrow(dim, 0, n_valid) |
| kk = min(self.topk, n_valid) |
| topv, topi = torch.topk(x, kk, dim=dim) |
| out = torch.zeros_like(x) |
| out.scatter_(dim, topi, topv) |
| if n_valid < padded: |
| pad_shape = list(out.shape) |
| pad_shape[dim] = padded - n_valid |
| pad_zeros = out.new_zeros(pad_shape) |
| out = torch.cat([out, pad_zeros], dim=dim) |
| return out |
|
|
| def _compute_weights(self, concept_logits: Tensor, E: Tensor) -> Tensor: |
| """Compute concept weights from logits, with optional top-k sparsity.""" |
| apply_topk = self.topk is not None and (not self.is_unknown or self.apply_topk_to_unknown) |
| if apply_topk and self.topk_on_logits: |
| logits_for_weights = self.topk_with_cutoff(concept_logits) |
| weights = torch.sigmoid(logits_for_weights).to(E.dtype) |
| return weights |
| weights = torch.sigmoid(concept_logits).to(E.dtype) |
| if apply_topk and (not self.topk_on_logits): |
| weights = self.topk_with_cutoff(weights) |
| return weights |
|
|
    @torch.compiler.disable
    def forward(self, hidden: Tensor, intervene_ids: Tensor | None=None, intervene_vals: Tensor | None=None, return_logits: bool=False, store_hidden: bool=False) -> ConceptHeadOutput:
        """
        Forward pass for concept decomposition (inference only, no teacher forcing).

        Dispatches to one of five paths (first match wins):
          1. dense_intervention — interventions requested and the head is not
             "large", so dense (B, T, C) weights can be materialized and edited.
          2. factorized_topk / factorized_all — factorized heads.
          3. streaming_topk — dense head with top-k sparsity.
          4. dense_all — dense head over all concepts.

        Args:
            hidden: Transformer hidden states (B, T, n_embd)
            intervene_ids: Concept IDs to intervene on (B, T, K_int), -1 = skip
            intervene_vals: Intervention strength values (B, T, K_int)
            return_logits: If True, compute full (B, T, C) logits. Forbidden for large heads.
            store_hidden: If True, store hidden in output for later attribution.

        Returns:
            ConceptHeadOutput with features, predicted, topk_indices, topk_logits
        """
        B, T, _ = hidden.shape
        has_interventions = intervene_ids is not None and intervene_vals is not None
        if return_logits:
            # Guard: dense (B, T, C) logits are disallowed for large heads.
            self._check_dense_allowed('return_logits=True')
        n_valid = self.n_concepts
        concept_logits: Tensor | None = None
        concept_weight: Tensor | None = None
        predicted: Tensor
        topk_indices: Tensor | None = None
        topk_logits: Tensor | None = None
        # Top-k applies unless this is an unknown head that opted out of it.
        apply_topk = self.topk is not None and (not self.is_unknown or self.apply_topk_to_unknown)
        # Features may be gathered for a wider k and re-ranked down to self.topk below.
        k_features = self.topk_features if self.topk_features is not None else self.topk
        use_dense_intervention = has_interventions and (not self._is_large)
        if use_dense_intervention:
            # Dense path: materialize full weights so interventions can overwrite them.
            E = self._get_embedding_weight()[:n_valid]
            if self.use_attention:
                query = self.concept_query_projection(hidden)
                concept_logits = self.blocked_logits(query, E, block_size=self.block_size)
            else:
                if self.factorize:
                    W = self._get_predictor_weight()[:n_valid]
                    raw_logits = hidden @ W.T
                else:
                    raw_logits = self.concept_predictor(hidden)[..., :n_valid]
                concept_logits = raw_logits.float().clamp(-15, 15)
            concept_weight = self._compute_weights(concept_logits, E)
            assert intervene_ids is not None and intervene_vals is not None
            concept_weight = self._apply_dense_interventions(concept_weight, intervene_ids, intervene_vals)
            predicted = self.blocked_mix(concept_weight, E, block_size=self.block_size)
        elif self.factorize:
            # Factorized head: low-rank coef/basis scoring and mixing.
            if self.use_attention:
                query = self.concept_query_projection(hidden)
                if apply_topk:
                    predicted, topk_indices, topk_logits = self.attention_features_topk_factorized(query, k=k_features, block_size=self.block_size)
                else:
                    predicted = self.attention_block_features_factorized(query, block_size=self.block_size)
            elif apply_topk:
                predicted, topk_indices, topk_logits = self.linear_features_topk_factorized(hidden, k=k_features, block_size=self.block_size)
            else:
                predicted = self.linear_block_features_factorized(hidden, block_size=self.block_size)
        elif apply_topk:
            # Dense head with streaming top-k over blocks.
            E = self._get_embedding_weight()[:n_valid]
            if self.use_attention:
                query = self.concept_query_projection(hidden)
                predicted, topk_indices, topk_logits = self.attention_features_topk_streaming(query, E, k=k_features, block_size=self.block_size, topk_on_logits=self.topk_on_logits)
            else:
                W = self.concept_predictor.weight[:n_valid]
                predicted, topk_indices, topk_logits = self.linear_features_topk_streaming(hidden, W, E, k=k_features, block_size=self.block_size, topk_on_logits=self.topk_on_logits)
        else:
            # Dense head, all concepts, block-accumulated.
            E = self._get_embedding_weight()[:n_valid]
            if self.use_attention:
                query = self.concept_query_projection(hidden)
                predicted = self.attention_block_features(query, E, block_size=self.block_size)
            else:
                W = self.concept_predictor.weight[:n_valid]
                predicted = self.linear_block_features(hidden, W, E, block_size=self.block_size)
        # If more features than self.topk were gathered, re-rank by logits and
        # keep only the top self.topk entries for the loss.
        if topk_indices is not None and self.topk is not None and (self.topk_features is not None) and (self.topk_features > self.topk):
            _, rerank_idx = torch.topk(topk_logits, self.topk, dim=-1)
            topk_indices = torch.gather(topk_indices, -1, rerank_idx)
            topk_logits = torch.gather(topk_logits, -1, rerank_idx)
        # Dense logits were requested but not produced by the chosen path:
        # recompute them here (duplicates the dense_intervention scoring).
        if return_logits and (not use_dense_intervention):
            E = self._get_embedding_weight()[:n_valid]
            if self.use_attention:
                query = self.concept_query_projection(hidden)
                concept_logits = self.blocked_logits(query, E, block_size=self.block_size)
            else:
                if self.factorize:
                    W = self._get_predictor_weight()[:n_valid]
                    raw_logits = hidden @ W.T
                else:
                    raw_logits = self.concept_predictor(hidden)[..., :n_valid]
                concept_logits = raw_logits.float().clamp(-15, 15)
            concept_weight = self._compute_weights(concept_logits, E)
        # One-time log of which path this head takes.
        if not hasattr(self, '_logged_forward_path'):
            self._logged_forward_path = True
            path = 'dense_intervention' if use_dense_intervention else 'factorized_topk' if self.factorize and apply_topk else 'factorized_all' if self.factorize else 'streaming_topk' if apply_topk else 'dense_all'
            logger.info(f"[ConceptHead] {('Unknown' if self.is_unknown else 'Known')} head: path={path}, topk={self.topk}, topk_features={self.topk_features}, n_concepts={self.n_concepts}, factorize={self.factorize}, apply_topk={apply_topk}")
        # One-time log of the features->loss top-k slicing above.
        if topk_indices is not None and self.topk is not None and (self.topk_features is not None) and (self.topk_features > self.topk):
            if not hasattr(self, '_logged_topk_slice'):
                self._logged_topk_slice = True
                logger.info(f"[ConceptHead] {('Unknown' if self.is_unknown else 'Known')} head: Sliced topk: {self.topk_features} features -> {self.topk} for loss")
        # Large heads apply interventions sparsely, as a feature correction.
        if has_interventions and (not use_dense_intervention):
            assert intervene_ids is not None and intervene_vals is not None
            predicted = self._apply_sparse_interventions(predicted, hidden, intervene_ids, intervene_vals)
        return ConceptHeadOutput(features=predicted, gt_features=None, logits=concept_logits, predicted=predicted, weights=concept_weight, topk_indices=topk_indices, topk_logits=topk_logits, hidden=hidden.detach() if store_hidden else None)
|
|
| |
| |
| |
|
|
# Module-level logger; referenced by code defined above (resolved at call time).
logger = logging.getLogger(__name__)
|
|
| class InterpretableCausalDiffusionLM(nn.Module): |
| """ |
| Interpretable CausalDiffusionLM with concept decomposition heads. |
| |
| Wraps a CausalDiffusionLM and adds: |
| - Known concept head: predicts known concepts from hidden states |
| - Unknown concept head: captures residual features (optional) |
| - Steering via concept interventions |
| |
| Args: |
| config: CausalDiffusionConfig (model architecture) |
| concept_config: ConceptConfig (concept decomposition) |
| vocab_size: Vocabulary size |
| """ |
|
|
    def __init__(self, config: CausalDiffusionConfig, concept_config: ConceptConfig, vocab_size: int):
        """
        Build the wrapped transformer plus the known (and optional unknown)
        concept heads from the two configs.

        Args:
            config: CausalDiffusionConfig (model architecture)
            concept_config: ConceptConfig (concept decomposition)
            vocab_size: Vocabulary size

        Raises:
            ValueError: If use_unknown=True but n_unknown_concepts is unset.
        """
        super().__init__()
        self.config = config
        self.concept_config = concept_config
        self.vocab_size = vocab_size
        self.transformer = CausalDiffusionLM(config, vocab_size)
        # Known head never applies top-k-to-unknown options and never factorizes.
        self.known_head = ConceptHead(n_concepts=concept_config.n_concepts, concept_dim=concept_config.concept_dim, n_embd=config.n_embd, is_unknown=False, use_attention=concept_config.use_attention_known, topk=concept_config.topk_known, topk_features=concept_config.topk_known_features, block_size=concept_config.block_size, pad_multiple=concept_config.pad_multiple, store_unknown_weights=False, apply_topk_to_unknown=False, topk_on_logits=concept_config.topk_on_logits)
        if concept_config.use_unknown:
            if concept_config.n_unknown_concepts is None:
                raise ValueError('n_unknown_concepts must be set when use_unknown=True')
            # Unknown head may be factorized and has its own top-k policy.
            self.unknown_head: ConceptHead | None = ConceptHead(n_concepts=concept_config.n_unknown_concepts, concept_dim=concept_config.concept_dim, n_embd=config.n_embd, is_unknown=True, use_attention=concept_config.use_attention_unknown, topk=concept_config.unknown_topk, block_size=concept_config.block_size, pad_multiple=concept_config.pad_multiple, store_unknown_weights=False, apply_topk_to_unknown=concept_config.apply_topk_to_unknown, topk_on_logits=concept_config.topk_on_logits, factorize=concept_config.factorize_unknown, factorize_rank=concept_config.factorize_rank)
        else:
            self.unknown_head = None
|
|
    def forward(self, input_ids: Tensor, *, input_embeds: Tensor | None=None, intervene_known_ids: Tensor | None=None, intervene_known_vals: Tensor | None=None, intervene_unknown_ids: Tensor | None=None, intervene_unknown_vals: Tensor | None=None, minimal_output: bool=False, position_injection: Tensor | None=None, steering_inject_layer: int | None=None, steering_inject_alpha: float=1.0, unknown_topk: int=64) -> tuple[Tensor, InterpretableOutput]:
        """
        Forward pass with concept decomposition.

        Pipeline: transformer hidden states -> known head -> residual
        (hidden - known) -> unknown head reconstruction -> optional epsilon
        correction -> recomposed features -> LM head.

        Args:
            input_ids: Token IDs (B, T). May contain mask tokens.
            input_embeds: Pre-computed embeddings (B, T, D). Overrides input_ids.
            intervene_known_ids: Known concept IDs to intervene (B, T, K_int)
            intervene_known_vals: Intervention values for known (B, T, K_int)
            intervene_unknown_ids: Unknown concept IDs to intervene (B, T, K_int)
            intervene_unknown_vals: Intervention values for unknown (B, T, K_int)
            minimal_output: If True, skip some expensive computations
            position_injection: Per-position steering injection (B, T, D)
            steering_inject_layer: Inject at layers >= this
            steering_inject_alpha: Injection strength
            unknown_topk: Top-k for unknown head attribution

        Returns:
            logits: LM logits (B, T, V)
            outputs: InterpretableOutput with all decomposition components
        """
        need_dense_logits = not minimal_output
        # Steering injection requires both a vector and a starting layer.
        if position_injection is not None and steering_inject_layer is not None:
            hidden = self._forward_with_injection(input_ids, input_embeds, position_injection, steering_inject_layer, steering_inject_alpha)
        else:
            hidden = self.transformer(input_ids, input_embeds=input_embeds, return_hidden=True)
        known_out: ConceptHeadOutput = self.known_head(hidden, intervene_ids=intervene_known_ids, intervene_vals=intervene_known_vals, return_logits=need_dense_logits)
        known_features = known_out.features.to(hidden.dtype)
        # Residual not explained by known concepts; detached so the known
        # head's gradients do not flow through the residual target.
        unk = hidden - known_features.detach()
        unk_for_lm: Tensor = unk
        unknown_out: ConceptHeadOutput | None = None
        unk_hat: Tensor | None = None
        if self.unknown_head is not None:
            # Unknown head sees detached hidden; dense logits only when cheap.
            unknown_out = self.unknown_head(hidden.detach(), intervene_ids=intervene_unknown_ids, intervene_vals=intervene_unknown_vals, return_logits=not minimal_output and (not self.unknown_head._is_large))
            assert unknown_out is not None
            unk_hat = unknown_out.features.to(hidden.dtype)
            unk_for_lm = unk_hat.detach()
        # True reconstruction error of known + unknown heads (diagnostic).
        epsilon_true = None
        if self.unknown_head is not None and unk_hat is not None:
            epsilon_true = hidden.detach() - (known_out.predicted + unk_hat)
        # Optional correction so composed == hidden; skipped under known
        # interventions, which are meant to change the composition.
        epsilon = None
        if self.concept_config.use_epsilon_correction and intervene_known_ids is None:
            epsilon = hidden - (unk_for_lm + known_features)
            unk_for_lm = unk_for_lm + epsilon
        composed = unk_for_lm + known_features
        logits = self.transformer.lm_head(composed)
        _unk_topk_indices = unknown_out.topk_indices if unknown_out else None
        _unk_topk_logits = unknown_out.topk_logits if unknown_out else None
        # If the unknown head did not produce top-k itself, compute it here
        # (no_grad: attribution only).
        if not minimal_output and self.unknown_head is not None and (unknown_out is not None) and (_unk_topk_indices is None) and (unknown_topk > 0):
            with torch.no_grad():
                _unk_topk_indices, _unk_topk_logits = self._compute_unknown_topk(hidden, unknown_topk)
        outputs = InterpretableOutput(hidden=hidden, known_features=known_features, known_logits=known_out.logits, known_gt_features=known_out.gt_features, known_predicted=known_out.predicted, known_weights=known_out.weights, known_topk_indices=known_out.topk_indices, known_topk_logits=known_out.topk_logits, unk=unk, unk_hat=unk_hat, unk_for_lm=unk_for_lm, unknown_logits=unknown_out.logits if unknown_out else None, unknown_weights=unknown_out.weights if unknown_out else None, unknown_topk_indices=_unk_topk_indices, unknown_topk_logits=_unk_topk_logits, composed=composed, epsilon=epsilon, epsilon_true=epsilon_true)
        return (logits, outputs)
|
|
def _compute_unknown_topk(self, hidden: Tensor, unknown_topk: int) -> tuple[Tensor | None, Tensor | None]:
    """Return the unknown head's top-k concept (indices, logits) for attribution.

    Dispatches to the factorized/streaming and attention/linear variants of the
    head's top-k routines based on the head's own configuration flags.
    Gradients are blocked by detaching ``hidden`` before every head call.
    """
    head = self.unknown_head
    assert head is not None
    detached = hidden.detach()
    blk = head.block_size
    if head.factorize:
        if head.use_attention:
            query = head.concept_query_projection(detached)
            _, idx, lgt = head.attention_features_topk_factorized(query, k=unknown_topk, block_size=blk)
        else:
            _, idx, lgt = head.linear_features_topk_factorized(detached, k=unknown_topk, block_size=blk)
    else:
        # Non-factorized path streams over the (truncated) embedding table.
        emb = head._get_embedding_weight()[:head.n_concepts]
        if head.use_attention:
            query = head.concept_query_projection(detached)
            _, idx, lgt = head.attention_features_topk_streaming(query, emb, k=unknown_topk, block_size=blk)
        else:
            pred_w = head.concept_predictor.weight[:head.n_concepts]
            _, idx, lgt = head.linear_features_topk_streaming(detached, pred_w, emb, k=unknown_topk, block_size=blk)
    return (idx, lgt)
|
|
| def _forward_with_injection(self, input_ids: Tensor, input_embeds: Tensor | None, position_injection: Tensor, inject_layer: int, inject_alpha: float) -> Tensor: |
| """Forward through transformer with steering injection at specified layers.""" |
| x = input_embeds if input_embeds is not None else self.transformer.tok_emb(input_ids) |
| for i, block in enumerate(self.transformer.blocks): |
| x = block(x) |
| if i + 1 >= inject_layer: |
| x = x + inject_alpha * position_injection |
| x = self.transformer.ln_f(x) |
| return x |
|
|
@torch.no_grad()
def intervene(self, input_ids: Tensor, known: dict[int, float] | None=None, unknown: dict[int, float] | None=None, positions: Tensor | None=None) -> tuple[Tensor, InterpretableOutput]:
    """
    Run inference with concept interventions (no gradients).

    Args:
        input_ids: Input token IDs (B, T)
        known: Dict mapping known concept IDs to intervention strengths
        unknown: Dict mapping unknown concept IDs to intervention strengths
        positions: Bool mask of positions to intervene (B, T). Default: all.

    Returns:
        logits: LM logits (B, T, V)
        outputs: InterpretableOutput
    """
    batch, seq_len = input_ids.shape
    device = input_ids.device
    if positions is None:
        # No mask supplied: intervene at every position.
        positions = input_ids.new_ones((batch, seq_len), dtype=torch.bool)

    def build(spec: dict[int, float] | None) -> tuple[Tensor | None, Tensor | None]:
        # Empty/missing spec means "no intervention" for that head.
        if not spec:
            return (None, None)
        return self._build_intervention_tensors(spec, batch, seq_len, positions, device)

    known_ids, known_vals = build(known)
    unknown_ids, unknown_vals = build(unknown)
    return self(
        input_ids,
        intervene_known_ids=known_ids,
        intervene_known_vals=known_vals,
        intervene_unknown_ids=unknown_ids,
        intervene_unknown_vals=unknown_vals,
        minimal_output=False,
    )
|
|
| @staticmethod |
| def _build_intervention_tensors(interventions: dict[int, float], B: int, T: int, positions: Tensor, device: torch.device) -> tuple[Tensor, Tensor]: |
| """Build intervention tensors for concept steering.""" |
| K = len(interventions) |
| concept_ids = list(interventions.keys()) |
| directions = list(interventions.values()) |
| ids = torch.full((B, T, K), -1, dtype=torch.long, device=device) |
| vals = torch.zeros((B, T, K), dtype=torch.float32, device=device) |
| concept_tensor = torch.tensor(concept_ids, device=device) |
| direction_tensor = torch.tensor(directions, dtype=torch.float32, device=device) |
| n_active = int(positions.sum().item()) |
| ids[positions] = concept_tensor.unsqueeze(0).expand(n_active, -1) |
| vals[positions] = direction_tensor.unsqueeze(0).expand(n_active, -1) |
| return (ids, vals) |
|
|
def get_num_params(self, non_embedding: bool=True) -> int:
    """Total parameter count; token-embedding weights are excluded by default."""
    total = 0
    for param in self.parameters():
        total += param.numel()
    if non_embedding and hasattr(self.transformer, 'tok_emb'):
        total -= self.transformer.tok_emb.weight.numel()
    return total
| from transformers import PreTrainedModel |
| from .configuration_steerling import SteerlingConfig |
|
|
|
|
| |
# Alias: SteerlingForCausalLM (below) instantiates its `transformer` backbone
# through this name; it resolves to CausalDiffusionLM.
SteerlingBackbone = CausalDiffusionLM
|
|
|
|
class SteerlingForCausalLM(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the concept-head LM.

    Builds the ``SteerlingBackbone`` transformer plus a known-concept head and
    an optional unknown-concept head, and dispatches ``forward`` to either the
    interpretable or the plain LM path based on ``config.interpretable``.
    """

    config_class = SteerlingConfig
    supports_gradient_checkpointing = False
    _tied_weights_keys = ["transformer.lm_head.weight"]

    def __init__(self, config: SteerlingConfig):
        super().__init__(config)

        self.concept_config = config
        self.transformer = SteerlingBackbone(config, config.vocab_size)

        # Head over the labeled ("known") concepts; never factorized and never
        # stores unknown weights.
        self.known_head = ConceptHead(
            n_concepts=config.n_concepts,
            concept_dim=config.concept_dim,
            n_embd=config.n_embd,
            is_unknown=False,
            use_attention=config.use_attention_known,
            topk=config.topk_known,
            topk_features=config.topk_known_features,
            block_size=config.concept_block_size,
            pad_multiple=config.pad_multiple,
            store_unknown_weights=False,
            apply_topk_to_unknown=False,
            topk_on_logits=config.topk_on_logits,
            factorize=False,
        )

        # Optional head over the unlabeled ("unknown") concepts.
        self.unknown_head = None
        if config.use_unknown:
            self.unknown_head = ConceptHead(
                n_concepts=config.n_unknown_concepts,
                concept_dim=config.concept_dim,
                n_embd=config.n_embd,
                is_unknown=True,
                use_attention=config.use_attention_unknown,
                topk=config.unknown_topk,
                block_size=config.concept_block_size,
                pad_multiple=config.pad_multiple,
                store_unknown_weights=config.store_unknown_weights,
                apply_topk_to_unknown=config.apply_topk_to_unknown,
                topk_on_logits=config.topk_on_logits,
                factorize=config.factorize_unknown,
                factorize_rank=config.factorize_rank,
            )

        self.post_init()

    def _init_weights(self, module):
        # Submodules initialize themselves (or come from a checkpoint);
        # deliberately a no-op for the HF init machinery.
        pass

    def _tie_weights(self):
        # Share the LM head matrix with the token embedding when configured.
        if not self.config.weight_sharing:
            return
        self.transformer.lm_head.weight = self.transformer.tok_emb.weight

    def forward(self, input_ids=None, **kwargs):
        if not self.config.interpretable:
            # The plain LM forward does not accept `minimal_output`; drop it.
            kwargs.pop('minimal_output', None)
            return CausalDiffusionLM.forward(self, input_ids, **kwargs)
        return InterpretableCausalDiffusionLM.forward(self, input_ids, **kwargs)
|
|