from __future__ import annotations # Auto-generated by scripts/build_hf_files_v3.py — do not edit manually. import torch import torch.nn as nn import torch.nn.functional as F import logging import os from functools import partial from typing import TYPE_CHECKING, Any from dataclasses import dataclass from torch import Tensor import math import warnings # ====================================================================== # steerling/models/layers/primitives.py # ====================================================================== class RMSNorm(nn.Module): """ Root Mean Square Layer Normalization. """ def __init__(self, config, size: int | None=None): super().__init__() self.eps = getattr(config, 'norm_eps', 1e-05) norm_size = size if size is not None else config.n_embd self.weight = nn.Parameter(torch.ones(norm_size)) def forward(self, x: torch.Tensor) -> torch.Tensor: og = x.dtype x = x.float() var = x.pow(2).mean(-1, keepdim=True) x = x * torch.rsqrt(var + self.eps) return (self.weight * x).to(og) class BufferCache: """Simple cache for storing tensors (used by RotaryEmbedding).""" def __init__(self): self._cache: dict[str, torch.Tensor] = {} def get(self, key: str) -> torch.Tensor | None: return self._cache.get(key) def __setitem__(self, key: str, value: torch.Tensor): self._cache[key] = value def __getitem__(self, key: str) -> torch.Tensor: return self._cache[key] class RotaryEmbedding(nn.Module): """ Rotary Position Embeddings (RoPE). Applies rotary embeddings to queries and keys for position information. 
Args: dim: Dimension of the rotary embeddings (typically head_dim) max_seq_len: Maximum sequence length to cache base: Base for inverse frequency computation (theta) rope_full_precision: Whether to compute RoPE in full precision """ def __init__(self, dim: int, max_seq_len: int=2048, base: float=10000.0, rope_full_precision: bool=True): super().__init__() self.dim = dim self.max_seq_len = max_seq_len self.rope_theta = base self.rope_full_precision = rope_full_precision self.__cache = BufferCache() self.get_rotary_embedding(max_seq_len, torch.device('cpu')) def get_rotary_embedding(self, seq_len: int, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]: """Get or compute rotary embeddings for given sequence length.""" pos_sin = self.__cache.get('rope_pos_sin') pos_cos = self.__cache.get('rope_pos_cos') if pos_sin is not None and pos_cos is not None and (pos_sin.shape[-2] >= seq_len) and (pos_cos.shape[-2] >= seq_len): if pos_sin.device != device: pos_sin = pos_sin.to(device) self.__cache['rope_pos_sin'] = pos_sin if pos_cos.device != device: pos_cos = pos_cos.to(device) self.__cache['rope_pos_cos'] = pos_cos return (pos_sin[:, :, :seq_len, :], pos_cos[:, :, :seq_len, :]) with torch.autocast(device.type, enabled=False): inv_freq = 1.0 / self.rope_theta ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float) / self.dim) seq = torch.arange(seq_len, device=device, dtype=torch.float) freqs = torch.outer(seq, inv_freq) positions = torch.cat((freqs, freqs), dim=-1) pos_sin = positions.sin()[None, None, :, :] pos_cos = positions.cos()[None, None, :, :] self.__cache['rope_pos_sin'] = pos_sin self.__cache['rope_pos_cos'] = pos_cos return (pos_sin, pos_cos) def rotate_half(self, x: torch.Tensor) -> torch.Tensor: """Rotate half the hidden dims of the input.""" B, nh, T, hs = x.size() x = x.view(B, nh, T, 2, hs // 2) x1, x2 = x.unbind(dim=-2) return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t: 
torch.Tensor) -> torch.Tensor: """Apply rotary position embeddings to input tensor.""" return (t * pos_cos + self.rotate_half(t) * pos_sin).to(t.dtype) def forward(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """Apply rotary embeddings to queries and keys.""" if self.rope_full_precision: q_, k_ = (q.float(), k.float()) else: q_, k_ = (q, k) with torch.autocast(q.device.type, enabled=False): query_len, key_len = (q_.shape[-2], k_.shape[-2]) pos_sin, pos_cos = self.get_rotary_embedding(key_len, q_.device) pos_sin = pos_sin.type_as(q_) pos_cos = pos_cos.type_as(q_) q_ = self.apply_rotary_pos_emb(pos_sin[:, :, key_len - query_len:key_len, :], pos_cos[:, :, key_len - query_len:key_len, :], q_) k_ = self.apply_rotary_pos_emb(pos_sin, pos_cos, k_) return (q_.type_as(q), k_.type_as(k)) class MLP(nn.Module): """ Multi-Layer Perceptron with SwiGLU or standard activation. Args: config: Model config with n_embd, mlp_ratio, use_bias, mlp_type, activation """ def __init__(self, config): super().__init__() if hasattr(config, 'intermediate_size') and config.intermediate_size is not None: intermediate_size = config.intermediate_size else: intermediate_size = getattr(config, 'mlp_ratio', 4) * config.n_embd use_bias = config.use_bias mlp_type = config.mlp_type if mlp_type == 'swiglu': self.c_fc = nn.Linear(config.n_embd, 2 * intermediate_size, bias=use_bias) self.c_proj = nn.Linear(intermediate_size, config.n_embd, bias=use_bias) self.activation = None else: self.c_fc = nn.Linear(config.n_embd, intermediate_size, bias=use_bias) self.c_proj = nn.Linear(intermediate_size, config.n_embd, bias=use_bias) act_map = {'gelu': nn.GELU(approximate='tanh'), 'relu': nn.ReLU(), 'silu': nn.SiLU()} self.activation = act_map[config.activation] self.c_proj.SCALE_INIT = 1 self.config = config def forward(self, x: torch.Tensor) -> torch.Tensor: mlp_type = getattr(self.config, 'mlp_type', 'swiglu') if mlp_type == 'swiglu': gate_up = self.c_fc(x) up, gate = 
gate_up.chunk(2, dim=-1) intermediate = F.silu(gate) * up else: intermediate = self.c_fc(x) intermediate = self.activation(intermediate) return self.c_proj(intermediate) # ====================================================================== # steerling/models/layers/causal_diffusion_layers.py # ====================================================================== logger = logging.getLogger(__name__) try: from torch.nn.attention.flex_attention import BlockMask, _dense_to_ordered, flex_attention _FLEX_ATTN_AVAILABLE = True except ImportError: _FLEX_ATTN_AVAILABLE = False BlockMask: Any = None flex_attention: Any = None _dense_to_ordered: Any = None if os.environ.get('STEERLING_USE_FLEX_ATTN', '0') != '1': _FLEX_ATTN_AVAILABLE = False if TYPE_CHECKING: from torch.nn.attention.flex_attention import BlockMask as BlockMaskType from steerling.configs.causal_diffusion import CausalDiffusionConfig if torch.cuda.is_available() and _FLEX_ATTN_AVAILABLE: compiled_flex_attention = torch.compile(flex_attention, fullgraph=True) else: compiled_flex_attention = flex_attention def block_causal_mask_mod(b: Any, h: Any, q_idx: torch.Tensor, kv_idx: torch.Tensor, *, block_size: int) -> torch.Tensor: """Block-causal mask: causal across blocks, bidirectional within blocks.""" return q_idx // block_size >= kv_idx // block_size def fast_create_block_causal_mask(attn_block_size: int, seq_length: int, mask_block_size: int, device: torch.device) -> BlockMaskType: """ Fast block-causal mask creation for flex_attention. Analytically computes the sparse block structure instead of evaluating the mask function at every position. 
""" if not _FLEX_ATTN_AVAILABLE or _dense_to_ordered is None or BlockMask is None: raise RuntimeError('flex_attention not available') num_mask_blocks = -(-seq_length // mask_block_size) attn_blocks_per_mask_block, rem = divmod(mask_block_size, attn_block_size) if rem != 0: raise ValueError(f'mask_block_size ({mask_block_size}) must be divisible by attn_block_size ({attn_block_size})') num_attn_blocks = num_mask_blocks * attn_blocks_per_mask_block lowres_attn_mask = torch.tril(torch.ones(num_attn_blocks, num_attn_blocks, dtype=torch.bool, device=device)) block_attn_count = lowres_attn_mask.reshape(num_mask_blocks, attn_blocks_per_mask_block, num_mask_blocks, attn_blocks_per_mask_block).permute(0, 2, 1, 3).sum(dim=[-2, -1]) max_count = attn_blocks_per_mask_block * attn_blocks_per_mask_block full_block_mask = block_attn_count == max_count if seq_length % mask_block_size > 0: full_block_mask[-1, :] = False normal_block_mask = (block_attn_count > 0) & ~full_block_mask kv_num_blocks, kv_indices = _dense_to_ordered(normal_block_mask) full_kv_num_blocks, full_kv_indices = _dense_to_ordered(full_block_mask) q_num_blocks, q_indices = _dense_to_ordered(normal_block_mask.transpose(-2, -1)) full_q_num_blocks, full_q_indices = _dense_to_ordered(full_block_mask.transpose(-2, -1)) return BlockMask(seq_lengths=(seq_length, seq_length), kv_num_blocks=kv_num_blocks[None, None, ...], kv_indices=kv_indices[None, None, ...], full_kv_num_blocks=full_kv_num_blocks[None, None, ...], full_kv_indices=full_kv_indices[None, None, ...], q_num_blocks=q_num_blocks[None, None, ...], q_indices=q_indices[None, None, ...], full_q_num_blocks=full_q_num_blocks[None, None, ...], full_q_indices=full_q_indices[None, None, ...], mask_mod=partial(block_causal_mask_mod, block_size=attn_block_size), BLOCK_SIZE=(mask_block_size, mask_block_size)) def sdpa_with_block_causal_mask(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, diff_block_size: int, mask_cache: dict[str, torch.Tensor], enable_gqa: bool=False) 
-> torch.Tensor: """Fallback using SDPA with dense mask when flex_attention unavailable.""" B, H, T, D = q.shape device = q.device dtype = q.dtype cache_key = f'sdpa_{T}_{device}_{dtype}' if cache_key not in mask_cache: q_idx = torch.arange(T, device=device).unsqueeze(1) kv_idx = torch.arange(T, device=device).unsqueeze(0) bool_mask = q_idx // diff_block_size >= kv_idx // diff_block_size attn_mask = torch.zeros(T, T, device=device, dtype=dtype) attn_mask.masked_fill_(~bool_mask, float('-inf')) mask_cache[cache_key] = attn_mask return F.scaled_dot_product_attention(q, k, v, attn_mask=mask_cache[cache_key], dropout_p=0.0, is_causal=False, enable_gqa=enable_gqa) class BlockCausalAttention(nn.Module): """Block-causal self-attention with FlexAttention and optional GQA.""" FLEX_MASK_BLOCK_SIZE = 128 def __init__(self, config: CausalDiffusionConfig) -> None: super().__init__() if not hasattr(config, 'diff_block_size'): raise ValueError("BlockCausalAttention requires 'diff_block_size' in config.") assert config.n_embd % config.n_head == 0 self.config = config self.n_head = config.n_head self.n_embd = config.n_embd self.head_dim = config.n_embd // config.n_head n_kv = getattr(config, 'n_kv_heads', None) self.n_kv_heads = self.n_head if n_kv is None else int(n_kv) if self.n_kv_heads <= 0: raise ValueError(f'n_kv_heads must be >= 1 (got {self.n_kv_heads})') if self.n_head % self.n_kv_heads != 0: raise ValueError(f'n_head ({self.n_head}) must be divisible by n_kv_heads ({self.n_kv_heads})') self.kv_repeat = self.n_head // self.n_kv_heads use_bias = getattr(config, 'use_bias', False) kv_out = self.n_kv_heads * self.head_dim attn_out = self.n_embd + 2 * kv_out self.c_attn = nn.Linear(config.n_embd, attn_out, bias=use_bias) self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=use_bias) self.c_proj.SCALE_INIT = 1 if getattr(config, 'use_qk_norm', False): if getattr(config, 'use_rms_norm', True): self.q_norm: nn.Module | None = RMSNorm(config, size=self.head_dim) self.k_norm: 
nn.Module | None = RMSNorm(config, size=self.head_dim) else: self.q_norm = nn.LayerNorm(self.head_dim) self.k_norm = nn.LayerNorm(self.head_dim) else: self.q_norm = None self.k_norm = None if getattr(config, 'use_rope', True): self.rope: RotaryEmbedding | None = RotaryEmbedding(dim=self.head_dim, max_seq_len=config.block_size, base=getattr(config, 'rope_base', 500000.0), rope_full_precision=getattr(config, 'rope_full_precision', True)) else: self.rope = None self._mask_cache: dict = {} self._sdpa_mask_cache: dict[str, torch.Tensor] = {} self._logged_attention_mode = False def _get_block_mask(self, T: int, device: torch.device): cache_key = f'flex_{T}_{device}' if cache_key not in self._mask_cache: diff_block_size = self.config.diff_block_size mask_block_size = self.FLEX_MASK_BLOCK_SIZE if mask_block_size % diff_block_size != 0: mask_block_size = diff_block_size * (mask_block_size // diff_block_size) if mask_block_size == 0: mask_block_size = diff_block_size self._mask_cache[cache_key] = fast_create_block_causal_mask(attn_block_size=diff_block_size, seq_length=T, mask_block_size=mask_block_size, device=device) return self._mask_cache[cache_key] def forward(self, x: torch.Tensor) -> torch.Tensor: B, T, C = x.size() device = x.device use_flex = _FLEX_ATTN_AVAILABLE and x.is_cuda and (flex_attention is not None) if not self._logged_attention_mode: self._logged_attention_mode = True mode = 'flex_attention' if use_flex else 'SDPA fallback' logger.debug(f'[CausalDiffusion] Using {mode} with GQA (n_head={self.n_head}, n_kv_heads={self.n_kv_heads})') qkv = self.c_attn(x) clip_qkv = getattr(self.config, 'clip_qkv', None) if clip_qkv is not None: qkv = qkv.clamp(min=-clip_qkv, max=clip_qkv) kv_dim = self.n_kv_heads * self.head_dim q, k, v = qkv.split([self.n_embd, kv_dim, kv_dim], dim=2) q = q.reshape(B, T, self.n_head, self.head_dim).transpose(1, 2) k = k.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) v = v.reshape(B, T, self.n_kv_heads, 
self.head_dim).transpose(1, 2) if self.q_norm is not None and self.k_norm is not None: q = self.q_norm(q) k = self.k_norm(k) if self.rope is not None: q, k = self.rope(q, k) if use_flex: block_mask = self._get_block_mask(T, device) assert flex_attention is not None and compiled_flex_attention is not None if q.is_cuda: y = compiled_flex_attention(q, k, v, block_mask=block_mask, enable_gqa=True) else: y = flex_attention(q, k, v, block_mask=block_mask, enable_gqa=True) else: y = sdpa_with_block_causal_mask(q, k, v, diff_block_size=self.config.diff_block_size, mask_cache=self._sdpa_mask_cache, enable_gqa=True) y = y.transpose(1, 2).reshape(B, T, C) y = self.c_proj(y) return y class CausalDiffusionBlock(nn.Module): """Transformer block for CausalDiffusionLM (block-causal attention + MLP).""" def __init__(self, config: CausalDiffusionConfig) -> None: super().__init__() use_rms_norm = getattr(config, 'use_rms_norm', True) if use_rms_norm: self.ln_1: nn.Module = RMSNorm(config) self.ln_2: nn.Module = RMSNorm(config) else: self.ln_1 = nn.LayerNorm(config.n_embd) self.ln_2 = nn.LayerNorm(config.n_embd) self.norm_order = getattr(config, 'norm_order', 'post') self.attn = BlockCausalAttention(config) self.mlp = MLP(config) def forward(self, x: torch.Tensor) -> torch.Tensor: if self.norm_order == 'pre': x = x + self.attn(self.ln_1(x)) x = x + self.mlp(self.ln_2(x)) else: x = x + self.ln_1(self.attn(x)) x = x + self.ln_2(self.mlp(x)) return x # ====================================================================== # steerling/models/causal_diffusion.py # ====================================================================== class CausalDiffusionLM(nn.Module): """ CausalDiffusionLM transformer backbone with block-causal attention. Pure compute graph — no training code, no loss logic. 
Args: config: CausalDiffusionConfig with model hyperparameters vocab_size: Vocabulary size (including special tokens) """ def __init__(self, config: CausalDiffusionConfig, vocab_size: int) -> None: super().__init__() self.config = config self.vocab_size = vocab_size self.tok_emb = nn.Embedding(vocab_size, config.n_embd) self.blocks = nn.ModuleList([CausalDiffusionBlock(config) for _ in range(config.n_layers)]) if config.use_rms_norm: self.ln_f: nn.Module = RMSNorm(config) else: self.ln_f = nn.LayerNorm(config.n_embd) self.lm_head = nn.Linear(config.n_embd, vocab_size, bias=False) if config.weight_sharing: self.tok_emb.weight = self.lm_head.weight def forward(self, input_ids: torch.Tensor, *, input_embeds: torch.Tensor | None=None, return_hidden: bool=False) -> torch.Tensor: """ Forward pass. Args: input_ids: Token indices [B, T] (may contain mask tokens) input_embeds: Pre-computed embeddings [B, T, D]. If provided, input_ids is ignored. return_hidden: If True, return hidden states before lm_head. 
Returns: logits [B, T, vocab_size] or hidden_states [B, T, n_embd] """ if input_embeds is not None: x = input_embeds elif input_ids is not None: x = self.tok_emb(input_ids) else: raise ValueError('Either input_ids or input_embeds must be provided') for block in self.blocks: x = block(x) x = self.ln_f(x) if return_hidden: return x return self.lm_head(x) def get_num_params(self, non_embedding: bool=True) -> int: """Return number of parameters.""" n_params = sum((p.numel() for p in self.parameters())) if non_embedding: n_params -= self.tok_emb.weight.numel() return n_params def _restore_weight_tying(self) -> None: """Re-establish weight tying after to_empty() or device transfer.""" if self.config.weight_sharing: self.tok_emb.weight = self.lm_head.weight def _init_weights(self, module: nn.Module) -> None: """Initialize model weights (used for fresh models, not loaded checkpoints).""" if isinstance(module, nn.Linear): std = 0.02 if hasattr(module, 'SCALE_INIT'): std *= (2 * self.config.n_layers) ** (-0.5) torch.nn.init.normal_(module.weight, mean=0.0, std=std) if module.bias is not None: torch.nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) elif isinstance(module, RMSNorm): torch.nn.init.ones_(module.weight) # ====================================================================== # steerling/models/interpretable/outputs.py # ====================================================================== @dataclass class InterpretableOutput: """ Full output from InterpretableCausalDiffusionLM; it contains all decomposition components for attribution and analysis. 
""" hidden: Tensor known_features: Tensor known_logits: Tensor | None known_gt_features: Tensor | None known_predicted: Tensor known_weights: Tensor | None known_topk_indices: Tensor | None known_topk_logits: Tensor | None unk: Tensor unk_hat: Tensor | None unk_for_lm: Tensor unknown_logits: Tensor | None unknown_weights: Tensor | None unknown_topk_indices: Tensor | None unknown_topk_logits: Tensor | None composed: Tensor epsilon: Tensor | None epsilon_true: Tensor | None # ====================================================================== # steerling/models/interpretable/concept_head.py # ====================================================================== logger = logging.getLogger(__name__) LARGE_CONCEPT_THRESHOLD = 50000 @dataclass class ConceptHeadOutput: """Output from ConceptHead forward pass. Attributes: features: Final concept features after teacher forcing/intervention (B, T, D) gt_features: Ground truth pooled features. None for unknown heads. (B, T, D) or None logits: Full concept logits (B, T, C). Only set if return_logits=True. Usually None. predicted: Predicted features before teacher forcing mixing (B, T, D) weights: Full concept weights (B, T, C). Only set if return_logits=True. Usually None. topk_indices: Top-k concept indices (B, T, k). Set when using streaming top-k. topk_logits: Logits for top-k concepts (B, T, k). Set when using streaming top-k. hidden: Hidden states passed to this head (B, T, D). Stored for attribution. """ features: Tensor gt_features: Tensor | None logits: Tensor | None predicted: Tensor weights: Tensor | None = None topk_indices: Tensor | None = None topk_logits: Tensor | None = None hidden: Tensor | None = None class ConceptHead(nn.Module): """ Concept decomposition head supporting both known and unknown concepts. Memory-efficient implementation that avoids (B, T, C) allocations by default. 
Modes: - Known (is_unknown=False): Supports GT, teacher forcing, top-k, interventions - Unknown (is_unknown=True): No GT, no teacher forcing Architectures: - use_attention=False: Linear predictor (n_embd -> n_concepts) - use_attention=True: Query projection + sigmoid attention over embeddings Factorization (for large unknown heads): - factorize=False: Dense embeddings (C, D) and predictor (D, C) - factorize=True: Factorized embeddings (C, r) @ (r, D) where r << D Reduces memory by ~10-20x for large C Memory Safety: - Unknown heads with n_concepts > 50k cannot use dense operations - Interventions are only supported for known heads - return_logits=True is forbidden for large unknown heads - All tensor indexing uses F.embedding for DTensor safety Args: n_concepts: Number of concepts (C) concept_dim: Dimension of concept embeddings (should equal n_embd) n_embd: Model hidden dimension is_unknown: If True, skip GT pooling and teacher forcing use_attention: If True, use attention; else use linear predictor topk: Top-k sparsity for concept weights. None = no sparsity. block_size: Block size for memory-efficient operations pad_multiple: Pad n_concepts to a multiple of this for efficiency store_unknown_weights: If True and use_attention & is_unknown, store logits/weights apply_topk_to_unknown: If True, also apply top-k to unknown concepts topk_on_logits: If True, apply top-k on logits (then sigmoid). If False, on weights. teacher_force_alpha: If None, hard TF. If in [0,1], soft mixing. factorize: If True, use low-rank factorized embeddings factorize_rank: Rank for factorization (r). Lower = less memory, less expressivity. """ class ConceptPooling(nn.Module): """Memory-efficient sum pooling using scatter-add.""" def __init__(self, concept_dim: int): super().__init__() self.concept_dim = concept_dim def forward(self, concept_ids: Tensor, concept_mask: Tensor, concept_embeddings: nn.Embedding) -> Tensor: """ Pool concept embeddings based on ground truth IDs. 
Uses scatter-add to avoid (B, T, K, D) allocation when K is sparse. Args: concept_ids: (B, T, K) concept indices, -1 for invalid concept_mask: (B, T, K) boolean mask for valid concepts concept_embeddings: Embedding layer to look up Returns: Pooled features (B, T, D) """ B, T, K = concept_ids.shape D = concept_embeddings.embedding_dim device = concept_ids.device valid_mask = concept_mask & (concept_ids != -1) pooled = torch.zeros(B, T, D, device=device, dtype=concept_embeddings.weight.dtype) if not valid_mask.any(): return pooled b_idx, t_idx, k_idx = torch.where(valid_mask) c_ids = concept_ids[b_idx, t_idx, k_idx].long() emb = concept_embeddings(c_ids) flat_idx = b_idx * T + t_idx flat_idx = flat_idx.unsqueeze(-1).expand(-1, D) pooled_flat = pooled.view(B * T, D) pooled_flat.scatter_add_(0, flat_idx, emb) return pooled.view(B, T, D) def __init__(self, n_concepts: int, concept_dim: int, n_embd: int, is_unknown: bool=False, use_attention: bool=False, topk: int | None=16, topk_features: int | None=None, block_size: int=8192, *, pad_multiple: int=16, store_unknown_weights: bool=False, apply_topk_to_unknown: bool=False, topk_on_logits: bool=False, factorize: bool=False, factorize_rank: int=256): super().__init__() self.n_concepts = n_concepts self.concept_dim = concept_dim self.n_embd = n_embd self.is_unknown = is_unknown self.use_attention = use_attention self.topk = topk self.topk_features = topk_features if topk_features is not None else topk self.block_size = block_size self.pad_multiple = pad_multiple self.store_unknown_weights = store_unknown_weights self.apply_topk_to_unknown = apply_topk_to_unknown self.topk_on_logits = topk_on_logits self.factorize = factorize self.factorize_rank = factorize_rank self._is_large = n_concepts > LARGE_CONCEPT_THRESHOLD self.n_concepts_padded = (n_concepts + pad_multiple - 1) // pad_multiple * pad_multiple if factorize: self.embedding_coef = nn.Embedding(self.n_concepts_padded, factorize_rank) self.embedding_basis = 
nn.Linear(factorize_rank, concept_dim, bias=False) self.concept_embedding = None if not use_attention: self.predictor_down = nn.Linear(n_embd, factorize_rank, bias=False) self.predictor_up = nn.Linear(factorize_rank, self.n_concepts_padded, bias=False) self.concept_predictor = None else: self.concept_query_projection = nn.Linear(n_embd, concept_dim, bias=False) self.predictor_down = None self.predictor_up = None self.concept_predictor = None dense_params = n_concepts * concept_dim * 2 factorized_params = n_concepts * factorize_rank + factorize_rank * concept_dim + (n_embd * factorize_rank + factorize_rank * n_concepts if not use_attention else 0) logger.info(f'[ConceptHead] Factorized mode: {n_concepts} concepts, rank={factorize_rank}') logger.info(f'[ConceptHead] Memory: {dense_params * 2 / 1000000000.0:.2f} GB (dense) -> {factorized_params * 2 / 1000000000.0:.2f} GB (factorized) = {(1 - factorized_params / dense_params) * 100:.1f}% reduction') else: self.concept_embedding = nn.Embedding(self.n_concepts_padded, concept_dim) self.embedding_coef = None self.embedding_basis = None if use_attention: self.concept_query_projection = nn.Linear(n_embd, concept_dim, bias=False) self.concept_predictor = None else: self.concept_predictor = nn.Linear(n_embd, self.n_concepts_padded, bias=False) self.predictor_down = None self.predictor_up = None self.concept_pooling = self.ConceptPooling(concept_dim) if self.topk_features != self.topk: logger.info(f"[ConceptHead] {('Unknown' if is_unknown else 'Known')} head: topk={self.topk} (loss), topk_features={self.topk_features} (features)") if is_unknown and apply_topk_to_unknown: logger.info(f'[ConceptHead] Unknown head: apply_topk_to_unknown=True, topk={self.topk}') self._init_weights() def _init_weights(self): """Initialize weights with small values.""" if self.factorize: nn.init.normal_(self.embedding_coef.weight, mean=0.0, std=0.02) nn.init.normal_(self.embedding_basis.weight, mean=0.0, std=0.02) if self.predictor_down is not None: 
nn.init.normal_(self.predictor_down.weight, mean=0.0, std=0.02) if self.predictor_up is not None: nn.init.normal_(self.predictor_up.weight, mean=0.0, std=0.02) else: if self.concept_embedding is not None: nn.init.normal_(self.concept_embedding.weight, mean=0.0, std=0.02) if self.concept_predictor is not None: nn.init.normal_(self.concept_predictor.weight, mean=0.0, std=0.02) if hasattr(self, 'concept_query_projection') and self.concept_query_projection is not None: nn.init.normal_(self.concept_query_projection.weight, mean=0.0, std=0.02) def _check_dense_allowed(self, operation: str) -> None: """Raise error if dense operations are requested for large unknown heads.""" if self.is_unknown and self._is_large: raise ValueError(f'{operation} requested for unknown head with {self.n_concepts} concepts. This would allocate multi-GB tensors. Use streaming mode instead. (Threshold: {LARGE_CONCEPT_THRESHOLD})') @staticmethod def _safe_index(weight: Tensor, indices: Tensor) -> Tensor: """ DTensor-safe indexing using F.embedding. Replaces weight[indices] which crashes under FSDP2/DTensor. Args: weight: (N, D) weight matrix indices: (...) indices to select Returns: (..., D) selected embeddings """ original_shape = indices.shape flat_indices = indices.reshape(-1) flat_result = F.embedding(flat_indices, weight) return flat_result.reshape(*original_shape, -1) def _get_embedding_weight(self) -> Tensor: """ Get full embedding matrix. For dense: returns concept_embedding.weight For factorized: computes coef @ basis (materializes full matrix) Returns: (C, D) embedding matrix """ if self.concept_embedding is not None: return self.concept_embedding.weight else: return self.embedding_basis(self.embedding_coef.weight) def _get_embedding(self, indices: Tensor) -> Tensor: """ Get embeddings for specific indices (DTensor-safe). For dense: uses F.embedding For factorized: looks up coef, then applies basis Args: indices: (...) 
concept indices Returns: (..., D) embeddings """ if self.concept_embedding is not None: return self.concept_embedding(indices) else: coef = self.embedding_coef(indices) return self.embedding_basis(coef) def _get_predictor_weight(self) -> Tensor | None: """ Get full predictor weight matrix (for linear path only). Returns: (C, D) predictor weight, or None if using attention """ if self.concept_predictor is not None: return self.concept_predictor.weight elif self.predictor_down is not None and self.predictor_up is not None: return self.predictor_up.weight @ self.predictor_down.weight else: return None @staticmethod def _merge_topk(topv: Tensor, topi: Tensor, v_blk: Tensor, i_blk: Tensor, k: int) -> tuple[Tensor, Tensor]: """Efficient merge of two top-k sets. Memory: O(BT × 2k).""" cand_v = torch.cat([topv, v_blk], dim=1) cand_i = torch.cat([topi, i_blk], dim=1) new_v, sel = torch.topk(cand_v, k, dim=1) new_i = torch.gather(cand_i, 1, sel) return (new_v, new_i) @staticmethod def linear_block_features(hidden: Tensor, predictor_weight: Tensor, embeddings: Tensor, block_size: int=4096) -> Tensor: """ Memory-efficient linear prediction without materializing (B, T, C). 
Args: hidden: (B, T, D) predictor_weight: (C, D) embeddings: (C, D) block_size: Concepts per block Returns: Features (B, T, D) """ B, T, D = hidden.shape C = predictor_weight.size(0) output = torch.zeros(B, T, D, dtype=hidden.dtype, device=hidden.device) flat_h = hidden.reshape(-1, D) W_t = predictor_weight.t().contiguous() for start in range(0, C, block_size): end = min(start + block_size, C) logits_block = (flat_h @ W_t[:, start:end]).to(torch.float32) logits_block = logits_block.clamp(-15, 15) weights_block = torch.sigmoid(logits_block) E_block = embeddings[start:end].to(weights_block.dtype) output.add_((weights_block @ E_block).reshape(B, T, D)) return output.to(hidden.dtype) @staticmethod def attention_block_features(query: Tensor, embeddings: Tensor, block_size: int=4096) -> Tensor: """Memory-efficient attention features without materializing (B, T, C).""" B, T, D = query.shape C = embeddings.shape[0] scale = 1.0 / math.sqrt(D) flat_q = query.reshape(-1, D) emb_T = embeddings.t().contiguous() output = torch.zeros(B * T, D, dtype=query.dtype, device=query.device) for start in range(0, C, block_size): end = min(start + block_size, C) scores = (flat_q @ emb_T[:, start:end]).to(torch.float32) * scale scores = scores.clamp(-15, 15) weights = torch.sigmoid(scores) output.add_(weights @ embeddings[start:end].to(weights.dtype)) return output.reshape(B, T, D).to(query.dtype) @staticmethod def linear_features_topk_streaming(hidden: Tensor, predictor_weight: Tensor, embeddings: Tensor, k: int, block_size: int=4096, topk_on_logits: bool=False) -> tuple[Tensor, Tensor, Tensor]: """ Memory-efficient linear prediction with streaming top-k. Uses merge-k-with-k to keep memory O(BT × k), not O(BT × block_size). 
Args:
            hidden: (B, T, D)
            predictor_weight: (C, D), one row per concept (C is taken from its first dim)
            embeddings: (C, D)
            k: Number of top concepts (clamped to C)
            block_size: Concepts per block
            topk_on_logits: If True, select top-k by logits; else by sigmoid

        Returns:
            features: (B, T, D) weighted concept features
            topk_indices: (B, T, k) indices of top-k concepts
            topk_logits: (B, T, k) logits for top-k concepts (float32, clamped to [-15, 15])
        """
        B, T, D = hidden.shape
        C = predictor_weight.size(0)
        BT = B * T
        device = hidden.device
        k = min(k, C)
        flat_h = hidden.reshape(BT, D)
        W_t = predictor_weight.t().contiguous()
        # Running top-k buffers; -inf start values so real scores replace them
        # (assumes ConceptHead._merge_topk, defined elsewhere, keeps the k largest).
        topv = torch.full((BT, k), float('-inf'), device=device, dtype=hidden.dtype)
        topi = torch.zeros((BT, k), device=device, dtype=torch.long)
        for start in range(0, C, block_size):
            end = min(start + block_size, C)
            logits_blk = (flat_h @ W_t[:, start:end]).to(torch.float32).clamp_(-15, 15)
            vals_blk = logits_blk if topk_on_logits else torch.sigmoid(logits_blk)
            blk_k = min(k, end - start)
            v_blk, idx_blk = torch.topk(vals_blk, blk_k, dim=1)
            # Block-local column index -> global concept id.
            i_blk = idx_blk + start
            if blk_k < k:
                # A block narrower than k yields fewer candidates; pad with -inf
                # scores (index 0) so every merge input has exactly k columns.
                pad_v = torch.full((BT, k - blk_k), float('-inf'), device=device, dtype=torch.float32)
                pad_i = torch.zeros((BT, k - blk_k), device=device, dtype=torch.long)
                v_blk = torch.cat([v_blk, pad_v], dim=1)
                i_blk = torch.cat([i_blk, pad_i], dim=1)
            topv, topi = ConceptHead._merge_topk(topv, topi, v_blk, i_blk, k)
        # Second pass: recompute the selected logits in float32 and mix only the
        # k selected embeddings.
        W_sel = ConceptHead._safe_index(predictor_weight, topi)
        logits_sel = torch.einsum('bd,bkd->bk', flat_h.to(torch.float32), W_sel.to(torch.float32))
        logits_sel = logits_sel.clamp(-15, 15)
        del W_sel
        weights_sel = torch.sigmoid(logits_sel)
        E_sel = ConceptHead._safe_index(embeddings, topi)
        features = torch.einsum('bk,bkd->bd', weights_sel.to(E_sel.dtype), E_sel)
        return (features.reshape(B, T, D).to(hidden.dtype), topi.reshape(B, T, k), logits_sel.reshape(B, T, k))

    @staticmethod
    def attention_features_topk_streaming(query: Tensor, embeddings: Tensor, k: int, block_size: int=4096, topk_on_logits: bool=False) -> tuple[Tensor, Tensor, Tensor]:
        """
        Memory-efficient attention with streaming top-k.

        Scores are ``query @ embeddings.T / sqrt(D)`` clamped to [-15, 15]. The
        top-k concepts are found block by block. Their logits are then recomputed
        in float32, and their sigmoids weight the selected embeddings.

        Args:
            query: (B, T, D) query vectors
            embeddings: (C, D) concept embeddings
            k: Number of top concepts (clamped to C)
            block_size: Concepts per block
            topk_on_logits: If True, select top-k by logits; else by sigmoid

        Returns:
            features: (B, T, D) weighted concept features
            topk_indices: (B, T, k) indices of top-k concepts
            topk_logits: (B, T, k) scaled, clamped logits for top-k concepts
        """
        B, T, D = query.shape
        C = embeddings.shape[0]
        BT = B * T
        device = query.device
        scale = 1.0 / math.sqrt(D)
        k = min(k, C)
        flat_q = query.reshape(BT, D)
        emb_T = embeddings.t().contiguous()
        # Running top-k buffers (see note in linear_features_topk_streaming).
        topv = torch.full((BT, k), float('-inf'), device=device, dtype=query.dtype)
        topi = torch.zeros((BT, k), device=device, dtype=torch.long)
        for start in range(0, C, block_size):
            end = min(start + block_size, C)
            logits_blk = (flat_q @ emb_T[:, start:end]).to(torch.float32) * scale
            logits_blk = logits_blk.clamp(-15, 15)
            vals_blk = logits_blk if topk_on_logits else torch.sigmoid(logits_blk)
            blk_k = min(k, end - start)
            v_blk, idx_blk = torch.topk(vals_blk, blk_k, dim=1)
            # Block-local column index -> global concept id.
            i_blk = idx_blk + start
            if blk_k < k:
                # Pad short blocks to k columns with -inf scores.
                pad_v = torch.full((BT, k - blk_k), float('-inf'), device=device, dtype=torch.float32)
                pad_i = torch.zeros((BT, k - blk_k), device=device, dtype=torch.long)
                v_blk = torch.cat([v_blk, pad_v], dim=1)
                i_blk = torch.cat([i_blk, pad_i], dim=1)
            topv, topi = ConceptHead._merge_topk(topv, topi, v_blk, i_blk, k)
        # Second pass over the selected concepts only, in float32.
        E_sel = ConceptHead._safe_index(embeddings, topi)
        logits_sel = torch.einsum('bd,bkd->bk', flat_q.to(torch.float32), E_sel.to(torch.float32)) * scale
        logits_sel = logits_sel.clamp(-15, 15)
        weights_sel = torch.sigmoid(logits_sel)
        features = torch.einsum('bk,bkd->bd', weights_sel.to(E_sel.dtype), E_sel)
        return (features.reshape(B, T, D).to(query.dtype), topi.reshape(B, T, k), logits_sel.reshape(B, T, k))

    def attention_block_features_factorized(self, query: Tensor, block_size: int=4096) -> Tensor:
        """
        Memory-efficient factorized attention over ALL concepts.
Uses factorized scoring and feature computation: - Scoring: (query @ basis.T) @ coef.T instead of query @ E.T - Features: (weights @ coef) @ basis instead of weights @ E FLOPs: O(BT * r * (D + C)) instead of O(BT * D * C) Args: query: (B, T, D) query vectors from concept_query_projection block_size: Concepts per block for chunked processing Returns: (B, T, D) weighted concept features """ assert self.factorize, 'Only valid for factorized head' B, T, D = query.shape BT = B * T C = self.n_concepts _ = self.factorize_rank device = query.device scale = 1.0 / math.sqrt(D) flat_q = query.reshape(BT, D) coef = self.embedding_coef.weight[:C] basis_weight = self.embedding_basis.weight q_compressed = flat_q @ basis_weight output = torch.zeros(BT, D, dtype=query.dtype, device=device) _ = (C + block_size - 1) // block_size for _block_idx, start in enumerate(range(0, C, block_size)): end = min(start + block_size, C) coef_chunk = coef[start:end] scores_chunk = (q_compressed @ coef_chunk.T).float() * scale scores_chunk = scores_chunk.clamp(-15, 15) weights_chunk = torch.sigmoid(scores_chunk) weighted_coef = weights_chunk @ coef_chunk.float() features_chunk = weighted_coef @ basis_weight.T.to(weighted_coef.dtype) output.add_(features_chunk) return output.reshape(B, T, D).to(query.dtype) def attention_features_topk_factorized(self, query: Tensor, k: int, block_size: int=4096) -> tuple[Tensor, Tensor, Tensor]: """ Memory-efficient factorized attention with streaming top-k. 
Pass 1: Find top-k concepts using factorized scoring Pass 2: Compute features using only top-k embeddings Args: query: (B, T, D) query vectors k: Number of top concepts per token block_size: Concepts per block Returns: features: (B, T, D) weighted concept features topk_indices: (B, T, k) top-k concept indices topk_logits: (B, T, k) logits for top-k concepts """ assert self.factorize, 'Only valid for factorized head' B, T, D = query.shape BT = B * T C = self.n_concepts _ = self.factorize_rank device = query.device scale = 1.0 / math.sqrt(D) k = min(k, C) flat_q = query.reshape(BT, D) coef = self.embedding_coef.weight[:C] basis_weight = self.embedding_basis.weight q_compressed = flat_q @ basis_weight topv = torch.full((BT, k), float('-inf'), device=device, dtype=query.dtype) topi = torch.zeros((BT, k), device=device, dtype=torch.long) for start in range(0, C, block_size): end = min(start + block_size, C) coef_chunk = coef[start:end] scores_chunk = q_compressed.float() @ coef_chunk.T.float() * scale scores_chunk = scores_chunk.clamp(-15, 15) blk_k = min(k, end - start) v_chunk, idx_chunk = torch.topk(scores_chunk, blk_k, dim=1) i_chunk = idx_chunk + start if blk_k < k: pad_v = torch.full((BT, k - blk_k), float('-inf'), device=device, dtype=torch.float32) pad_i = torch.zeros((BT, k - blk_k), device=device, dtype=torch.long) v_chunk = torch.cat([v_chunk, pad_v], dim=1) i_chunk = torch.cat([i_chunk, pad_i], dim=1) topv, topi = self._merge_topk(topv, topi, v_chunk, i_chunk, k) coef_sel = self.embedding_coef(topi) logits_sel = torch.einsum('br,bkr->bk', q_compressed.float(), coef_sel.float()) * scale logits_sel = logits_sel.clamp(-15, 15) weights_sel = torch.sigmoid(logits_sel) weighted_coef = torch.einsum('bk,bkr->br', weights_sel.to(coef_sel.dtype), coef_sel) features = weighted_coef @ basis_weight.T.to(weighted_coef.dtype) return (features.reshape(B, T, D).to(query.dtype), topi.reshape(B, T, k), logits_sel.reshape(B, T, k)) def linear_block_features_factorized(self, 
hidden: Tensor, block_size: int=4096) -> Tensor: """ Memory-efficient factorized linear prediction over ALL concepts. Uses factorized predictor: logits = hidden @ down @ up.T Uses factorized embeddings: features = weights @ coef @ basis Args: hidden: (B, T, D) hidden states block_size: Concepts per block Returns: (B, T, D) weighted concept features """ assert self.factorize, 'Only valid for factorized head' assert self.predictor_down is not None, 'Linear path requires predictor' B, T, D = hidden.shape BT = B * T C = self.n_concepts _ = self.factorize_rank device = hidden.device flat_h = hidden.reshape(BT, D) coef = self.embedding_coef.weight[:C] basis_weight = self.embedding_basis.weight down_weight = self.predictor_down.weight up_weight = self.predictor_up.weight[:C] h_compressed = flat_h @ down_weight.T output = torch.zeros(BT, D, dtype=hidden.dtype, device=device) for start in range(0, C, block_size): end = min(start + block_size, C) up_chunk = up_weight[start:end] coef_chunk = coef[start:end] logits_chunk = h_compressed.float() @ up_chunk.T.float() logits_chunk = logits_chunk.clamp(-15, 15) weights_chunk = torch.sigmoid(logits_chunk) weighted_coef = weights_chunk @ coef_chunk.float() features_chunk = weighted_coef @ basis_weight.T.to(weighted_coef.dtype) output.add_(features_chunk) return output.reshape(B, T, D).to(hidden.dtype) def linear_features_topk_factorized(self, hidden: Tensor, k: int, block_size: int=4096) -> tuple[Tensor, Tensor, Tensor]: """ Memory-efficient factorized linear with streaming top-k. 
Args: hidden: (B, T, D) hidden states k: Number of top concepts per token block_size: Concepts per block Returns: features: (B, T, D) weighted concept features topk_indices: (B, T, k) top-k concept indices topk_logits: (B, T, k) logits for top-k concepts """ assert self.factorize, 'Only valid for factorized head' assert self.predictor_down is not None, 'Linear path requires predictor' B, T, D = hidden.shape BT = B * T C = self.n_concepts _ = self.factorize_rank device = hidden.device k = min(k, C) flat_h = hidden.reshape(BT, D) down_weight = self.predictor_down.weight up_weight = self.predictor_up.weight[:C] basis_weight = self.embedding_basis.weight h_compressed = flat_h @ down_weight.T topv = torch.full((BT, k), float('-inf'), device=device, dtype=hidden.dtype) topi = torch.zeros((BT, k), device=device, dtype=torch.long) for start in range(0, C, block_size): end = min(start + block_size, C) up_chunk = up_weight[start:end] logits_chunk = h_compressed.float() @ up_chunk.T.float() logits_chunk = logits_chunk.clamp(-15, 15) blk_k = min(k, end - start) v_chunk, idx_chunk = torch.topk(logits_chunk, blk_k, dim=1) i_chunk = idx_chunk + start if blk_k < k: pad_v = torch.full((BT, k - blk_k), float('-inf'), device=device, dtype=torch.float32) pad_i = torch.zeros((BT, k - blk_k), device=device, dtype=torch.long) v_chunk = torch.cat([v_chunk, pad_v], dim=1) i_chunk = torch.cat([i_chunk, pad_i], dim=1) topv, topi = self._merge_topk(topv, topi, v_chunk, i_chunk, k) coef_sel = self.embedding_coef(topi) up_sel = self._safe_index(self.predictor_up.weight[:C], topi) logits_sel = torch.einsum('br,bkr->bk', h_compressed.float(), up_sel.float()) logits_sel = logits_sel.clamp(-15, 15) weights_sel = torch.sigmoid(logits_sel) weighted_coef = torch.einsum('bk,bkr->br', weights_sel.to(coef_sel.dtype), coef_sel) features = weighted_coef @ basis_weight.T.to(weighted_coef.dtype) return (features.reshape(B, T, D).to(hidden.dtype), topi.reshape(B, T, k), logits_sel.reshape(B, T, k)) def 
compute_logits_for_indices(self, hidden: Tensor, indices: Tensor) -> Tensor: """ Compute logits for specific concept indices only (sparse). Supports both dense and factorized heads. IMPORTANT: This function materializes (M, K, D) where M is the number of tokens in hidden. Only call this with small M (e.g., masked tokens only). Args: hidden: (M, D) or (B, T, D) hidden states indices: (M, K) or (B, T, K) concept indices Returns: logits: Same shape as indices """ if hidden.dim() == 2: M, D = hidden.shape K = indices.size(-1) flat_h = hidden flat_idx = indices output_shape = indices.shape else: B, T, D = hidden.shape K = indices.size(-1) M = B * T flat_h = hidden.reshape(M, D) flat_idx = indices.reshape(M, K) output_shape = indices.shape estimated_bytes = M * K * D * 2 if estimated_bytes > 1000000000.0: warnings.warn(f'compute_logits_for_indices will allocate ~{estimated_bytes / 1000000000.0:.1f} GB. Consider reducing M={M} (use masked tokens only) or K={K}.') n_valid = self.n_concepts indices_safe = flat_idx.clamp(0, n_valid - 1) if self.use_attention: query = self.concept_query_projection(flat_h.unsqueeze(0)).squeeze(0) scale = 1.0 / math.sqrt(self.concept_dim) E_sel = self._get_embedding(indices_safe) logits = torch.einsum('md,mkd->mk', query.float(), E_sel.float()) * scale else: if self.factorize: W = self._get_predictor_weight()[:n_valid] W_sel = self._safe_index(W, indices_safe) else: W = self.concept_predictor.weight[:n_valid] W_sel = self._safe_index(W, indices_safe) logits = torch.einsum('md,mkd->mk', flat_h.float(), W_sel.float()) return logits.clamp(-15, 15).reshape(output_shape) def get_concept_weights(self, hidden: Tensor, concept_ids: Tensor) -> Tensor: """ Get sigmoid weights for specific concepts (for attribution). 
Args: hidden: (B, T, D) or (M, D) hidden states concept_ids: (B, T, K) or (M, K) or (K,) concept indices Returns: weights: Same shape as concept_ids, values in [0, 1] """ if concept_ids.dim() == 1: if hidden.dim() == 2: M = hidden.size(0) concept_ids = concept_ids.unsqueeze(0).expand(M, -1) else: B, T, _ = hidden.shape concept_ids = concept_ids.unsqueeze(0).unsqueeze(0).expand(B, T, -1) logits = self.compute_logits_for_indices(hidden, concept_ids) return torch.sigmoid(logits) @staticmethod def blocked_logits(query: Tensor, embeddings: Tensor, block_size: int=8192, out_device: torch.device | None=None, out_dtype: torch.dtype=torch.float32) -> Tensor: """ Compute concept logits in column blocks for memory efficiency. logits = query @ embeddings.T / sqrt(D) """ B, T, D = query.shape C = embeddings.size(0) scale = 1.0 / math.sqrt(D) dev = query.device if out_device is None else out_device logits = torch.empty(B, T, C, device=dev, dtype=out_dtype) q = query.reshape(-1, D).to(torch.float32) Et = embeddings.t().contiguous().to(torch.float32) for s in range(0, C, block_size): e = min(s + block_size, C) scores = q @ Et[:, s:e] * scale scores = scores.clamp(-15, 15) logits[:, :, s:e] = scores.reshape(B, T, e - s).to(out_dtype) return logits @staticmethod def blocked_mix(weights: Tensor, embeddings: Tensor, block_size: int=8192) -> Tensor: """ Compute weighted sum of embeddings in column blocks. 
output = weights @ embeddings """ B, T, C = weights.shape D = embeddings.size(1) out = torch.zeros(B, T, D, device=weights.device, dtype=weights.dtype) for s in range(0, C, block_size): e = min(s + block_size, C) w_blk = weights[:, :, s:e].to(torch.float32) V_blk = embeddings[s:e].to(w_blk.dtype) out.add_(w_blk @ V_blk) return out.to(weights.dtype) @staticmethod def sigmoid_block_attention(query: Tensor, embeddings: Tensor, block_size: int=8192, return_logits: bool=False) -> Tensor | tuple[Tensor, Tensor]: """Memory-efficient sigmoid attention using block processing.""" B, T, D = query.shape C = embeddings.shape[0] scale = 1.0 / math.sqrt(D) flat_q = query.reshape(-1, D) emb_T = embeddings.t().contiguous() output = torch.zeros(B * T, D, dtype=query.dtype, device=query.device) logits: Tensor | None = None if return_logits: logits = torch.empty(B, T, C, dtype=torch.float32, device=query.device) for start in range(0, C, block_size): end = min(start + block_size, C) scores = (flat_q @ emb_T[:, start:end]).to(torch.float32) * scale scores = scores.clamp(-15, 15) if logits is not None: logits[:, :, start:end] = scores.reshape(B, T, end - start) weights = torch.sigmoid(scores) output.add_(weights @ embeddings[start:end].to(weights.dtype)) output = output.reshape(B, T, D).to(query.dtype) if return_logits: assert logits is not None return (output, logits) return output def _apply_sparse_interventions(self, features: Tensor, hidden: Tensor, intervene_ids: Tensor, intervene_vals: Tensor) -> Tensor: """ Apply sparse interventions matching original dense behavior. 
Original dense behavior: weights = sigmoid(logits) # (B, T, C) weights[..., c] = new_val # Override features = weights @ embeddings Sparse equivalent: features += (new_val - current_weight) * embedding[c] """ B, T, D = features.shape valid = intervene_ids != -1 if not valid.any(): return features ids_safe = intervene_ids.clamp(0, self.n_concepts - 1) current_logits = self.compute_logits_for_indices(hidden, ids_safe) current_weights = torch.sigmoid(current_logits) emb = self._get_embedding(ids_safe) delta = (intervene_vals - current_weights) * valid.float() correction = (delta.unsqueeze(-1) * emb).sum(dim=2) return features + correction def _apply_dense_interventions(self, concept_weight: Tensor, intervene_ids: Tensor, intervene_vals: Tensor) -> Tensor: """Apply interventions by overriding concept weights (dense path).""" n_valid = min(self.n_concepts, concept_weight.size(-1)) valid_edit = intervene_ids != -1 ids = intervene_ids.clamp(0, n_valid - 1).long() vals = intervene_vals.to(concept_weight.dtype) updates = torch.zeros_like(concept_weight) updates.scatter_add_(2, ids, torch.where(valid_edit, vals, torch.zeros_like(vals))) set_mask = torch.zeros_like(concept_weight, dtype=torch.bool) set_mask.scatter_(2, ids, valid_edit) return torch.where(set_mask, updates, concept_weight) def topk_with_cutoff(self, tensor: Tensor, dim: int=-1) -> Tensor: """ Apply top-k sparsity, zeroing out all but top-k values. 
Args: tensor: Input tensor, typically (B, T, C) dim: Dimension to apply top-k (default: last) Returns: Sparse tensor with only top-k values preserved """ assert dim == -1 or dim == tensor.dim() - 1 if self.topk is None: return tensor padded = tensor.size(dim) n_valid = min(self.n_concepts, padded) if n_valid <= 0: return torch.zeros_like(tensor) x = tensor.narrow(dim, 0, n_valid) kk = min(self.topk, n_valid) topv, topi = torch.topk(x, kk, dim=dim) out = torch.zeros_like(x) out.scatter_(dim, topi, topv) if n_valid < padded: pad_shape = list(out.shape) pad_shape[dim] = padded - n_valid pad_zeros = out.new_zeros(pad_shape) out = torch.cat([out, pad_zeros], dim=dim) return out def _compute_weights(self, concept_logits: Tensor, E: Tensor) -> Tensor: """Compute concept weights from logits, with optional top-k sparsity.""" apply_topk = self.topk is not None and (not self.is_unknown or self.apply_topk_to_unknown) if apply_topk and self.topk_on_logits: logits_for_weights = self.topk_with_cutoff(concept_logits) weights = torch.sigmoid(logits_for_weights).to(E.dtype) return weights weights = torch.sigmoid(concept_logits).to(E.dtype) if apply_topk and (not self.topk_on_logits): weights = self.topk_with_cutoff(weights) return weights @torch.compiler.disable def forward(self, hidden: Tensor, intervene_ids: Tensor | None=None, intervene_vals: Tensor | None=None, return_logits: bool=False, store_hidden: bool=False) -> ConceptHeadOutput: """ Forward pass for concept decomposition (inference only, no teacher forcing). Args: hidden: Transformer hidden states (B, T, n_embd) intervene_ids: Concept IDs to intervene on (B, T, K_int), -1 = skip intervene_vals: Intervention strength values (B, T, K_int) return_logits: If True, compute full (B, T, C) logits. Forbidden for large heads. store_hidden: If True, store hidden in output for later attribution. 
Returns:
            ConceptHeadOutput with features, predicted, topk_indices, topk_logits

        Routing (first match wins):
            - interventions given and the head is not ``_is_large``: dense logits
              over all concepts, weights overridden, then mixed with ``blocked_mix``;
            - factorized head: factorized top-k or all-concept kernels;
            - top-k enabled: streaming top-k kernels;
            - otherwise: all-concept block kernels.
        Interventions on the non-dense routes are applied afterwards as a sparse
        correction to the features.
        """
        B, T, _ = hidden.shape
        has_interventions = intervene_ids is not None and intervene_vals is not None
        if return_logits:
            self._check_dense_allowed('return_logits=True')
        n_valid = self.n_concepts
        concept_logits: Tensor | None = None
        concept_weight: Tensor | None = None
        predicted: Tensor
        topk_indices: Tensor | None = None
        topk_logits: Tensor | None = None
        apply_topk = self.topk is not None and (not self.is_unknown or self.apply_topk_to_unknown)
        # Kernels may select more candidates (topk_features) than are reported (topk).
        k_features = self.topk_features if self.topk_features is not None else self.topk
        use_dense_intervention = has_interventions and (not self._is_large)
        if use_dense_intervention:
            E = self._get_embedding_weight()[:n_valid]
            if self.use_attention:
                query = self.concept_query_projection(hidden)
                concept_logits = self.blocked_logits(query, E, block_size=self.block_size)
            else:
                if self.factorize:
                    W = self._get_predictor_weight()[:n_valid]
                    raw_logits = hidden @ W.T
                else:
                    raw_logits = self.concept_predictor(hidden)[..., :n_valid]
                concept_logits = raw_logits.float().clamp(-15, 15)
            concept_weight = self._compute_weights(concept_logits, E)
            assert intervene_ids is not None and intervene_vals is not None
            concept_weight = self._apply_dense_interventions(concept_weight, intervene_ids, intervene_vals)
            predicted = self.blocked_mix(concept_weight, E, block_size=self.block_size)
        elif self.factorize:
            if self.use_attention:
                query = self.concept_query_projection(hidden)
                if apply_topk:
                    predicted, topk_indices, topk_logits = self.attention_features_topk_factorized(query, k=k_features, block_size=self.block_size)
                else:
                    predicted = self.attention_block_features_factorized(query, block_size=self.block_size)
            elif apply_topk:
                predicted, topk_indices, topk_logits = self.linear_features_topk_factorized(hidden, k=k_features, block_size=self.block_size)
            else:
                predicted = self.linear_block_features_factorized(hidden, block_size=self.block_size)
        elif apply_topk:
            E = self._get_embedding_weight()[:n_valid]
            if self.use_attention:
                query = self.concept_query_projection(hidden)
                predicted, topk_indices, topk_logits = self.attention_features_topk_streaming(query, E, k=k_features, block_size=self.block_size, topk_on_logits=self.topk_on_logits)
            else:
                W = self.concept_predictor.weight[:n_valid]
                predicted, topk_indices, topk_logits = self.linear_features_topk_streaming(hidden, W, E, k=k_features, block_size=self.block_size, topk_on_logits=self.topk_on_logits)
        else:
            E = self._get_embedding_weight()[:n_valid]
            if self.use_attention:
                query = self.concept_query_projection(hidden)
                predicted = self.attention_block_features(query, E, block_size=self.block_size)
            else:
                W = self.concept_predictor.weight[:n_valid]
                predicted = self.linear_block_features(hidden, W, E, block_size=self.block_size)
        # When more candidates were selected (k_features) than are reported (topk),
        # keep the best `topk` of them by logit for the returned indices/logits.
        # `predicted` still contains the contribution of all k_features candidates.
        if topk_indices is not None and self.topk is not None and (self.topk_features is not None) and (self.topk_features > self.topk):
            _, rerank_idx = torch.topk(topk_logits, self.topk, dim=-1)
            topk_indices = torch.gather(topk_indices, -1, rerank_idx)
            topk_logits = torch.gather(topk_logits, -1, rerank_idx)
        # Dense logits/weights were requested but not already built by the
        # dense-intervention route; they are returned only, not used for `predicted`.
        if return_logits and (not use_dense_intervention):
            E = self._get_embedding_weight()[:n_valid]
            if self.use_attention:
                query = self.concept_query_projection(hidden)
                concept_logits = self.blocked_logits(query, E, block_size=self.block_size)
            else:
                if self.factorize:
                    W = self._get_predictor_weight()[:n_valid]
                    raw_logits = hidden @ W.T
                else:
                    raw_logits = self.concept_predictor(hidden)[..., :n_valid]
                concept_logits = raw_logits.float().clamp(-15, 15)
            concept_weight = self._compute_weights(concept_logits, E)
        # Log the chosen route once per module instance.
        if not hasattr(self, '_logged_forward_path'):
            self._logged_forward_path = True
            path = 'dense_intervention' if use_dense_intervention else 'factorized_topk' if self.factorize and apply_topk else 'factorized_all' if self.factorize else 'streaming_topk' if apply_topk else 'dense_all'
            logger.info(f"[ConceptHead] {('Unknown' if self.is_unknown else 'Known')} head: path={path}, topk={self.topk}, topk_features={self.topk_features}, n_concepts={self.n_concepts}, factorize={self.factorize}, apply_topk={apply_topk}")
        if topk_indices is not None and self.topk is not None and (self.topk_features is not None) and (self.topk_features > self.topk):
            if not hasattr(self, '_logged_topk_slice'):
                self._logged_topk_slice = True
                logger.info(f"[ConceptHead] {('Unknown' if self.is_unknown else 'Known')} head: Sliced topk: {self.topk_features} features -> {self.topk} for loss")
        # Interventions not handled by the dense route become a sparse feature correction.
        if has_interventions and (not use_dense_intervention):
            assert intervene_ids is not None and intervene_vals is not None
            predicted = self._apply_sparse_interventions(predicted, hidden, intervene_ids, intervene_vals)
        return ConceptHeadOutput(features=predicted, gt_features=None, logits=concept_logits, predicted=predicted, weights=concept_weight, topk_indices=topk_indices, topk_logits=topk_logits, hidden=hidden.detach() if store_hidden else None)


# ======================================================================
# steerling/models/interpretable/interpretable_causal_diffusion.py
# ======================================================================

logger = logging.getLogger(__name__)


class InterpretableCausalDiffusionLM(nn.Module):
    """
    Interpretable CausalDiffusionLM with concept decomposition heads.
Wraps a CausalDiffusionLM and adds: - Known concept head: predicts known concepts from hidden states - Unknown concept head: captures residual features (optional) - Steering via concept interventions Args: config: CausalDiffusionConfig (model architecture) concept_config: ConceptConfig (concept decomposition) vocab_size: Vocabulary size """ def __init__(self, config: CausalDiffusionConfig, concept_config: ConceptConfig, vocab_size: int): super().__init__() self.config = config self.concept_config = concept_config self.vocab_size = vocab_size self.transformer = CausalDiffusionLM(config, vocab_size) self.known_head = ConceptHead(n_concepts=concept_config.n_concepts, concept_dim=concept_config.concept_dim, n_embd=config.n_embd, is_unknown=False, use_attention=concept_config.use_attention_known, topk=concept_config.topk_known, topk_features=concept_config.topk_known_features, block_size=concept_config.block_size, pad_multiple=concept_config.pad_multiple, store_unknown_weights=False, apply_topk_to_unknown=False, topk_on_logits=concept_config.topk_on_logits) if concept_config.use_unknown: if concept_config.n_unknown_concepts is None: raise ValueError('n_unknown_concepts must be set when use_unknown=True') self.unknown_head: ConceptHead | None = ConceptHead(n_concepts=concept_config.n_unknown_concepts, concept_dim=concept_config.concept_dim, n_embd=config.n_embd, is_unknown=True, use_attention=concept_config.use_attention_unknown, topk=concept_config.unknown_topk, block_size=concept_config.block_size, pad_multiple=concept_config.pad_multiple, store_unknown_weights=False, apply_topk_to_unknown=concept_config.apply_topk_to_unknown, topk_on_logits=concept_config.topk_on_logits, factorize=concept_config.factorize_unknown, factorize_rank=concept_config.factorize_rank) else: self.unknown_head = None def forward(self, input_ids: Tensor, *, input_embeds: Tensor | None=None, intervene_known_ids: Tensor | None=None, intervene_known_vals: Tensor | None=None, intervene_unknown_ids: 
Tensor | None=None, intervene_unknown_vals: Tensor | None=None, minimal_output: bool=False, position_injection: Tensor | None=None, steering_inject_layer: int | None=None, steering_inject_alpha: float=1.0, unknown_topk: int=64) -> tuple[Tensor, InterpretableOutput]: """ Forward pass with concept decomposition. Args: input_ids: Token IDs (B, T). May contain mask tokens. input_embeds: Pre-computed embeddings (B, T, D). Overrides input_ids. intervene_known_ids: Known concept IDs to intervene (B, T, K_int) intervene_known_vals: Intervention values for known (B, T, K_int) intervene_unknown_ids: Unknown concept IDs to intervene (B, T, K_int) intervene_unknown_vals: Intervention values for unknown (B, T, K_int) minimal_output: If True, skip some expensive computations position_injection: Per-position steering injection (B, T, D) steering_inject_layer: Inject at layers >= this steering_inject_alpha: Injection strength unknown_topk: Top-k for unknown head attribution Returns: logits: LM logits (B, T, V) outputs: InterpretableOutput with all decomposition components """ need_dense_logits = not minimal_output if position_injection is not None and steering_inject_layer is not None: hidden = self._forward_with_injection(input_ids, input_embeds, position_injection, steering_inject_layer, steering_inject_alpha) else: hidden = self.transformer(input_ids, input_embeds=input_embeds, return_hidden=True) known_out: ConceptHeadOutput = self.known_head(hidden, intervene_ids=intervene_known_ids, intervene_vals=intervene_known_vals, return_logits=need_dense_logits) known_features = known_out.features.to(hidden.dtype) unk = hidden - known_features.detach() unk_for_lm: Tensor = unk unknown_out: ConceptHeadOutput | None = None unk_hat: Tensor | None = None if self.unknown_head is not None: unknown_out = self.unknown_head(hidden.detach(), intervene_ids=intervene_unknown_ids, intervene_vals=intervene_unknown_vals, return_logits=not minimal_output and (not self.unknown_head._is_large)) assert 
unknown_out is not None unk_hat = unknown_out.features.to(hidden.dtype) unk_for_lm = unk_hat.detach() epsilon_true = None if self.unknown_head is not None and unk_hat is not None: epsilon_true = hidden.detach() - (known_out.predicted + unk_hat) epsilon = None if self.concept_config.use_epsilon_correction and intervene_known_ids is None: epsilon = hidden - (unk_for_lm + known_features) unk_for_lm = unk_for_lm + epsilon composed = unk_for_lm + known_features logits = self.transformer.lm_head(composed) _unk_topk_indices = unknown_out.topk_indices if unknown_out else None _unk_topk_logits = unknown_out.topk_logits if unknown_out else None if not minimal_output and self.unknown_head is not None and (unknown_out is not None) and (_unk_topk_indices is None) and (unknown_topk > 0): with torch.no_grad(): _unk_topk_indices, _unk_topk_logits = self._compute_unknown_topk(hidden, unknown_topk) outputs = InterpretableOutput(hidden=hidden, known_features=known_features, known_logits=known_out.logits, known_gt_features=known_out.gt_features, known_predicted=known_out.predicted, known_weights=known_out.weights, known_topk_indices=known_out.topk_indices, known_topk_logits=known_out.topk_logits, unk=unk, unk_hat=unk_hat, unk_for_lm=unk_for_lm, unknown_logits=unknown_out.logits if unknown_out else None, unknown_weights=unknown_out.weights if unknown_out else None, unknown_topk_indices=_unk_topk_indices, unknown_topk_logits=_unk_topk_logits, composed=composed, epsilon=epsilon, epsilon_true=epsilon_true) return (logits, outputs) def _compute_unknown_topk(self, hidden: Tensor, unknown_topk: int) -> tuple[Tensor | None, Tensor | None]: """Compute unknown head top-k indices for attribution.""" assert self.unknown_head is not None if self.unknown_head.factorize: if self.unknown_head.use_attention: _query = self.unknown_head.concept_query_projection(hidden.detach()) _, indices, logits = self.unknown_head.attention_features_topk_factorized(_query, k=unknown_topk, 
block_size=self.unknown_head.block_size) else: _, indices, logits = self.unknown_head.linear_features_topk_factorized(hidden.detach(), k=unknown_topk, block_size=self.unknown_head.block_size) else: _E = self.unknown_head._get_embedding_weight()[:self.unknown_head.n_concepts] if self.unknown_head.use_attention: _query = self.unknown_head.concept_query_projection(hidden.detach()) _, indices, logits = self.unknown_head.attention_features_topk_streaming(_query, _E, k=unknown_topk, block_size=self.unknown_head.block_size) else: _W = self.unknown_head.concept_predictor.weight[:self.unknown_head.n_concepts] _, indices, logits = self.unknown_head.linear_features_topk_streaming(hidden.detach(), _W, _E, k=unknown_topk, block_size=self.unknown_head.block_size) return (indices, logits) def _forward_with_injection(self, input_ids: Tensor, input_embeds: Tensor | None, position_injection: Tensor, inject_layer: int, inject_alpha: float) -> Tensor: """Forward through transformer with steering injection at specified layers.""" x = input_embeds if input_embeds is not None else self.transformer.tok_emb(input_ids) for i, block in enumerate(self.transformer.blocks): x = block(x) if i + 1 >= inject_layer: x = x + inject_alpha * position_injection x = self.transformer.ln_f(x) return x @torch.no_grad() def intervene(self, input_ids: Tensor, known: dict[int, float] | None=None, unknown: dict[int, float] | None=None, positions: Tensor | None=None) -> tuple[Tensor, InterpretableOutput]: """ Run inference with concept interventions. Args: input_ids: Input token IDs (B, T) known: Dict mapping known concept IDs to intervention strengths unknown: Dict mapping unknown concept IDs to intervention strengths positions: Bool mask of positions to intervene (B, T). Default: all. 
Returns: logits: LM logits (B, T, V) outputs: InterpretableOutput """ B, T = input_ids.shape device = input_ids.device if positions is None: positions = torch.ones(B, T, dtype=torch.bool, device=device) int_known_ids, int_known_vals = (None, None) if known is not None and len(known) > 0: int_known_ids, int_known_vals = self._build_intervention_tensors(known, B, T, positions, device) int_unknown_ids, int_unknown_vals = (None, None) if unknown is not None and len(unknown) > 0: int_unknown_ids, int_unknown_vals = self._build_intervention_tensors(unknown, B, T, positions, device) return self(input_ids, intervene_known_ids=int_known_ids, intervene_known_vals=int_known_vals, intervene_unknown_ids=int_unknown_ids, intervene_unknown_vals=int_unknown_vals, minimal_output=False) @staticmethod def _build_intervention_tensors(interventions: dict[int, float], B: int, T: int, positions: Tensor, device: torch.device) -> tuple[Tensor, Tensor]: """Build intervention tensors for concept steering.""" K = len(interventions) concept_ids = list(interventions.keys()) directions = list(interventions.values()) ids = torch.full((B, T, K), -1, dtype=torch.long, device=device) vals = torch.zeros((B, T, K), dtype=torch.float32, device=device) concept_tensor = torch.tensor(concept_ids, device=device) direction_tensor = torch.tensor(directions, dtype=torch.float32, device=device) n_active = int(positions.sum().item()) ids[positions] = concept_tensor.unsqueeze(0).expand(n_active, -1) vals[positions] = direction_tensor.unsqueeze(0).expand(n_active, -1) return (ids, vals) def get_num_params(self, non_embedding: bool=True) -> int: n_params = sum((p.numel() for p in self.parameters())) if non_embedding and hasattr(self.transformer, 'tok_emb'): n_params -= self.transformer.tok_emb.weight.numel() return n_params from transformers import PreTrainedModel from .configuration_steerling import SteerlingConfig # CausalDiffusionLM is the backbone — alias to HF-friendly name SteerlingBackbone = 
CausalDiffusionLM class SteerlingForCausalLM(PreTrainedModel): config_class = SteerlingConfig supports_gradient_checkpointing = False _tied_weights_keys = ["transformer.lm_head.weight"] def __init__(self, config: SteerlingConfig): super().__init__(config) # SteerlingConfig has all fields from both arch and concept configs self.concept_config = config self.transformer = SteerlingBackbone(config, config.vocab_size) self.known_head = ConceptHead( n_concepts=config.n_concepts, concept_dim=config.concept_dim, n_embd=config.n_embd, is_unknown=False, use_attention=config.use_attention_known, topk=config.topk_known, topk_features=config.topk_known_features, block_size=config.concept_block_size, pad_multiple=config.pad_multiple, store_unknown_weights=False, apply_topk_to_unknown=False, topk_on_logits=config.topk_on_logits, factorize=False, ) if config.use_unknown: self.unknown_head = ConceptHead( n_concepts=config.n_unknown_concepts, concept_dim=config.concept_dim, n_embd=config.n_embd, is_unknown=True, use_attention=config.use_attention_unknown, topk=config.unknown_topk, block_size=config.concept_block_size, pad_multiple=config.pad_multiple, store_unknown_weights=config.store_unknown_weights, apply_topk_to_unknown=config.apply_topk_to_unknown, topk_on_logits=config.topk_on_logits, factorize=config.factorize_unknown, factorize_rank=config.factorize_rank, ) else: self.unknown_head = None self.post_init() def _init_weights(self, module): pass def _tie_weights(self): if self.config.weight_sharing: self.transformer.lm_head.weight = self.transformer.tok_emb.weight def forward(self, input_ids=None, **kwargs): if self.config.interpretable: return InterpretableCausalDiffusionLM.forward(self, input_ids, **kwargs) else: kwargs.pop('minimal_output', None) return CausalDiffusionLM.forward(self, input_ids, **kwargs)