# steerling-8b / modeling_steerling.py
# Uploaded via Hugging Face (commit 32dc098, verified upload by AyaGL).
# NOTE: web-viewer header converted to comments so the file parses as Python.
from __future__ import annotations
# Auto-generated by scripts/build_hf_files_v3.py — do not edit manually.
import torch
import torch.nn as nn
import torch.nn.functional as F
import logging
import os
from functools import partial
from typing import TYPE_CHECKING, Any
from dataclasses import dataclass
from torch import Tensor
import math
import warnings
# ======================================================================
# steerling/models/layers/primitives.py
# ======================================================================
class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization.

    Normalizes the last dimension by its root-mean-square (computed in
    float32 for numerical stability) and applies a learned per-channel gain,
    then casts back to the input dtype.
    """

    def __init__(self, config, size: int | None=None):
        super().__init__()
        # Fall back to a standard epsilon when the config does not provide one.
        self.eps = getattr(config, 'norm_eps', 1e-05)
        dim = config.n_embd if size is None else size
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        x32 = x.float()
        # Inverse RMS over the last dimension.
        inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        normed = x32 * inv_rms
        return (self.weight * normed).to(input_dtype)
class BufferCache:
    """Simple cache for storing tensors (used by RotaryEmbedding)."""

    def __init__(self):
        # Backing store: string key -> cached tensor.
        self._cache: dict[str, torch.Tensor] = dict()

    def get(self, key: str) -> torch.Tensor | None:
        """Return the tensor stored under ``key``, or ``None`` when absent."""
        try:
            return self._cache[key]
        except KeyError:
            return None

    def __setitem__(self, key: str, value: torch.Tensor):
        self._cache[key] = value

    def __getitem__(self, key: str) -> torch.Tensor:
        return self._cache[key]
class RotaryEmbedding(nn.Module):
    """
    Rotary Position Embeddings (RoPE).

    Applies rotary embeddings to queries and keys for position information.

    Args:
        dim: Dimension of the rotary embeddings (typically head_dim)
        max_seq_len: Maximum sequence length to cache
        base: Base for inverse frequency computation (theta)
        rope_full_precision: Whether to compute RoPE in full precision
    """

    def __init__(self, dim: int, max_seq_len: int=2048, base: float=10000.0, rope_full_precision: bool=True):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.rope_theta = base
        self.rope_full_precision = rope_full_precision
        self.__cache = BufferCache()
        # Warm the cache on CPU so the first forward is cheap.
        self.get_rotary_embedding(max_seq_len, torch.device('cpu'))

    def get_rotary_embedding(self, seq_len: int, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
        """Get or compute rotary embeddings for given sequence length."""
        sin = self.__cache.get('rope_pos_sin')
        cos = self.__cache.get('rope_pos_cos')
        cached_ok = (
            sin is not None
            and cos is not None
            and sin.shape[-2] >= seq_len
            and cos.shape[-2] >= seq_len
        )
        if cached_ok:
            # Migrate cached tables to the requested device if needed.
            if sin.device != device:
                sin = sin.to(device)
                self.__cache['rope_pos_sin'] = sin
            if cos.device != device:
                cos = cos.to(device)
                self.__cache['rope_pos_cos'] = cos
            return (sin[:, :, :seq_len, :], cos[:, :, :seq_len, :])
        # (Re)build the tables in float32, immune to any active autocast.
        with torch.autocast(device.type, enabled=False):
            inv_freq = 1.0 / self.rope_theta ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float) / self.dim)
            positions = torch.arange(seq_len, device=device, dtype=torch.float)
            angles = torch.outer(positions, inv_freq)
            angles = torch.cat((angles, angles), dim=-1)
            sin = angles.sin()[None, None, :, :]
            cos = angles.cos()[None, None, :, :]
            self.__cache['rope_pos_sin'] = sin
            self.__cache['rope_pos_cos'] = cos
            return (sin, cos)

    def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        """Rotate half the hidden dims of the input."""
        B, n_heads, T, head_size = x.size()
        # Split the head dim into two contiguous halves and swap/negate.
        first, second = x.view(B, n_heads, T, 2, head_size // 2).unbind(dim=-2)
        return torch.cat((-second, first), dim=-1)

    def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Apply rotary position embeddings to input tensor."""
        rotated = self.rotate_half(t)
        return (t * pos_cos + rotated * pos_sin).to(t.dtype)

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary embeddings to queries and keys."""
        if self.rope_full_precision:
            q_work, k_work = (q.float(), k.float())
        else:
            q_work, k_work = (q, k)
        with torch.autocast(q.device.type, enabled=False):
            q_len = q_work.shape[-2]
            k_len = k_work.shape[-2]
            pos_sin, pos_cos = self.get_rotary_embedding(k_len, q_work.device)
            pos_sin = pos_sin.type_as(q_work)
            pos_cos = pos_cos.type_as(q_work)
            # Queries occupy the trailing q_len positions of the key range.
            q_work = self.apply_rotary_pos_emb(
                pos_sin[:, :, k_len - q_len:k_len, :],
                pos_cos[:, :, k_len - q_len:k_len, :],
                q_work,
            )
            k_work = self.apply_rotary_pos_emb(pos_sin, pos_cos, k_work)
        return (q_work.type_as(q), k_work.type_as(k))
class MLP(nn.Module):
    """
    Multi-Layer Perceptron with SwiGLU or standard activation.

    Args:
        config: Model config with n_embd, mlp_ratio, use_bias, mlp_type, activation
    """

    def __init__(self, config):
        super().__init__()
        # Explicit intermediate_size wins; otherwise derive it from mlp_ratio.
        if getattr(config, 'intermediate_size', None) is not None:
            hidden_size = config.intermediate_size
        else:
            hidden_size = getattr(config, 'mlp_ratio', 4) * config.n_embd
        bias = config.use_bias
        if config.mlp_type == 'swiglu':
            # Fused up+gate projection; split back apart in forward().
            self.c_fc = nn.Linear(config.n_embd, 2 * hidden_size, bias=bias)
            self.c_proj = nn.Linear(hidden_size, config.n_embd, bias=bias)
            self.activation = None
        else:
            self.c_fc = nn.Linear(config.n_embd, hidden_size, bias=bias)
            self.c_proj = nn.Linear(hidden_size, config.n_embd, bias=bias)
            self.activation = {
                'gelu': nn.GELU(approximate='tanh'),
                'relu': nn.ReLU(),
                'silu': nn.SiLU(),
            }[config.activation]
        # Marker consumed by the owner's weight init (scaled residual init).
        self.c_proj.SCALE_INIT = 1
        self.config = config

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if getattr(self.config, 'mlp_type', 'swiglu') == 'swiglu':
            # First half of the fused projection is "up", second half "gate".
            up, gate = self.c_fc(x).chunk(2, dim=-1)
            hidden = F.silu(gate) * up
        else:
            hidden = self.activation(self.c_fc(x))
        return self.c_proj(hidden)
# ======================================================================
# steerling/models/layers/causal_diffusion_layers.py
# ======================================================================
logger = logging.getLogger(__name__)
try:
    # flex_attention (and the private _dense_to_ordered helper) only exist in
    # recent PyTorch builds; fall back gracefully when the import fails.
    from torch.nn.attention.flex_attention import BlockMask, _dense_to_ordered, flex_attention
    _FLEX_ATTN_AVAILABLE = True
except ImportError:
    _FLEX_ATTN_AVAILABLE = False
    BlockMask: Any = None
    flex_attention: Any = None
    _dense_to_ordered: Any = None
# flex_attention is opt-in: it stays disabled unless STEERLING_USE_FLEX_ATTN=1,
# even when the import above succeeded.
if os.environ.get('STEERLING_USE_FLEX_ATTN', '0') != '1':
    _FLEX_ATTN_AVAILABLE = False
if TYPE_CHECKING:
    # Import-cycle/typing-only imports; never executed at runtime.
    from torch.nn.attention.flex_attention import BlockMask as BlockMaskType
    from steerling.configs.causal_diffusion import CausalDiffusionConfig
# Compile the kernel once at import time when CUDA is usable; otherwise keep
# the eager function (which may be None when the import failed).
if torch.cuda.is_available() and _FLEX_ATTN_AVAILABLE:
    compiled_flex_attention = torch.compile(flex_attention, fullgraph=True)
else:
    compiled_flex_attention = flex_attention
def block_causal_mask_mod(b: Any, h: Any, q_idx: torch.Tensor, kv_idx: torch.Tensor, *, block_size: int) -> torch.Tensor:
    """Block-causal mask: causal across blocks, bidirectional within blocks."""
    # A query may attend to any key whose block index is not later than its own.
    q_block = q_idx // block_size
    kv_block = kv_idx // block_size
    return q_block >= kv_block
def fast_create_block_causal_mask(attn_block_size: int, seq_length: int, mask_block_size: int, device: torch.device) -> BlockMaskType:
    """
    Fast block-causal mask creation for flex_attention.

    Analytically computes the sparse block structure instead of evaluating
    the mask function at every position.

    Args:
        attn_block_size: Size (in tokens) of the block-causal attention blocks.
        seq_length: Total sequence length in tokens.
        mask_block_size: Tile size of the BlockMask grid; must be a multiple
            of attn_block_size.
        device: Device on which to build the mask tensors.

    Returns:
        A flex_attention BlockMask describing the block-causal pattern.

    Raises:
        RuntimeError: If flex_attention (or its private helpers) is unavailable.
        ValueError: If mask_block_size is not divisible by attn_block_size.
    """
    if not _FLEX_ATTN_AVAILABLE or _dense_to_ordered is None or BlockMask is None:
        raise RuntimeError('flex_attention not available')
    # Ceil-divide: number of mask tiles needed to cover the sequence.
    num_mask_blocks = -(-seq_length // mask_block_size)
    attn_blocks_per_mask_block, rem = divmod(mask_block_size, attn_block_size)
    if rem != 0:
        raise ValueError(f'mask_block_size ({mask_block_size}) must be divisible by attn_block_size ({attn_block_size})')
    num_attn_blocks = num_mask_blocks * attn_blocks_per_mask_block
    # Dense lower-triangular visibility at attention-block resolution.
    lowres_attn_mask = torch.tril(torch.ones(num_attn_blocks, num_attn_blocks, dtype=torch.bool, device=device))
    # Count visible attention blocks inside each (mask tile, mask tile) cell.
    block_attn_count = lowres_attn_mask.reshape(num_mask_blocks, attn_blocks_per_mask_block, num_mask_blocks, attn_blocks_per_mask_block).permute(0, 2, 1, 3).sum(dim=[-2, -1])
    max_count = attn_blocks_per_mask_block * attn_blocks_per_mask_block
    # Tiles where every attention block is visible can skip mask_mod entirely.
    full_block_mask = block_attn_count == max_count
    if seq_length % mask_block_size > 0:
        # The last tile row is ragged; force it through the mask_mod path.
        full_block_mask[-1, :] = False
    # Partially visible tiles still require per-position mask_mod evaluation.
    normal_block_mask = (block_attn_count > 0) & ~full_block_mask
    kv_num_blocks, kv_indices = _dense_to_ordered(normal_block_mask)
    full_kv_num_blocks, full_kv_indices = _dense_to_ordered(full_block_mask)
    # Transposed views give the query-major ordering BlockMask also expects.
    q_num_blocks, q_indices = _dense_to_ordered(normal_block_mask.transpose(-2, -1))
    full_q_num_blocks, full_q_indices = _dense_to_ordered(full_block_mask.transpose(-2, -1))
    return BlockMask(seq_lengths=(seq_length, seq_length), kv_num_blocks=kv_num_blocks[None, None, ...], kv_indices=kv_indices[None, None, ...], full_kv_num_blocks=full_kv_num_blocks[None, None, ...], full_kv_indices=full_kv_indices[None, None, ...], q_num_blocks=q_num_blocks[None, None, ...], q_indices=q_indices[None, None, ...], full_q_num_blocks=full_q_num_blocks[None, None, ...], full_q_indices=full_q_indices[None, None, ...], mask_mod=partial(block_causal_mask_mod, block_size=attn_block_size), BLOCK_SIZE=(mask_block_size, mask_block_size))
def sdpa_with_block_causal_mask(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, diff_block_size: int, mask_cache: dict[str, torch.Tensor], enable_gqa: bool=False) -> torch.Tensor:
"""Fallback using SDPA with dense mask when flex_attention unavailable."""
B, H, T, D = q.shape
device = q.device
dtype = q.dtype
cache_key = f'sdpa_{T}_{device}_{dtype}'
if cache_key not in mask_cache:
q_idx = torch.arange(T, device=device).unsqueeze(1)
kv_idx = torch.arange(T, device=device).unsqueeze(0)
bool_mask = q_idx // diff_block_size >= kv_idx // diff_block_size
attn_mask = torch.zeros(T, T, device=device, dtype=dtype)
attn_mask.masked_fill_(~bool_mask, float('-inf'))
mask_cache[cache_key] = attn_mask
return F.scaled_dot_product_attention(q, k, v, attn_mask=mask_cache[cache_key], dropout_p=0.0, is_causal=False, enable_gqa=enable_gqa)
class BlockCausalAttention(nn.Module):
    """Block-causal self-attention with FlexAttention and optional GQA."""

    # Tile size used when building flex_attention BlockMasks.
    FLEX_MASK_BLOCK_SIZE = 128

    def __init__(self, config: CausalDiffusionConfig) -> None:
        super().__init__()
        if not hasattr(config, 'diff_block_size'):
            raise ValueError("BlockCausalAttention requires 'diff_block_size' in config.")
        assert config.n_embd % config.n_head == 0
        self.config = config
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head
        # GQA: fewer key/value heads than query heads (defaults to full MHA).
        n_kv = getattr(config, 'n_kv_heads', None)
        self.n_kv_heads = self.n_head if n_kv is None else int(n_kv)
        if self.n_kv_heads <= 0:
            raise ValueError(f'n_kv_heads must be >= 1 (got {self.n_kv_heads})')
        if self.n_head % self.n_kv_heads != 0:
            raise ValueError(f'n_head ({self.n_head}) must be divisible by n_kv_heads ({self.n_kv_heads})')
        self.kv_repeat = self.n_head // self.n_kv_heads
        use_bias = getattr(config, 'use_bias', False)
        # Fused QKV projection: n_embd columns for Q plus two kv_out chunks
        # for the (possibly smaller) K and V head groups.
        kv_out = self.n_kv_heads * self.head_dim
        attn_out = self.n_embd + 2 * kv_out
        self.c_attn = nn.Linear(config.n_embd, attn_out, bias=use_bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=use_bias)
        # Marker consumed by the owning model's _init_weights (scaled init).
        self.c_proj.SCALE_INIT = 1
        if getattr(config, 'use_qk_norm', False):
            # Per-head-dim normalization of queries/keys before RoPE.
            if getattr(config, 'use_rms_norm', True):
                self.q_norm: nn.Module | None = RMSNorm(config, size=self.head_dim)
                self.k_norm: nn.Module | None = RMSNorm(config, size=self.head_dim)
            else:
                self.q_norm = nn.LayerNorm(self.head_dim)
                self.k_norm = nn.LayerNorm(self.head_dim)
        else:
            self.q_norm = None
            self.k_norm = None
        if getattr(config, 'use_rope', True):
            self.rope: RotaryEmbedding | None = RotaryEmbedding(dim=self.head_dim, max_seq_len=config.block_size, base=getattr(config, 'rope_base', 500000.0), rope_full_precision=getattr(config, 'rope_full_precision', True))
        else:
            self.rope = None
        # Masks depend only on (seq_len, device[, dtype]), so they are cached.
        self._mask_cache: dict = {}
        self._sdpa_mask_cache: dict[str, torch.Tensor] = {}
        self._logged_attention_mode = False

    def _get_block_mask(self, T: int, device: torch.device):
        """Return (building and caching if needed) the flex_attention BlockMask for length T."""
        cache_key = f'flex_{T}_{device}'
        if cache_key not in self._mask_cache:
            diff_block_size = self.config.diff_block_size
            mask_block_size = self.FLEX_MASK_BLOCK_SIZE
            # Round the tile size down to a multiple of diff_block_size so
            # fast_create_block_causal_mask's divisibility check passes.
            if mask_block_size % diff_block_size != 0:
                mask_block_size = diff_block_size * (mask_block_size // diff_block_size)
            # If diff_block_size exceeded the tile size, the rounding above
            # yields 0; use one diffusion block per tile instead.
            if mask_block_size == 0:
                mask_block_size = diff_block_size
            self._mask_cache[cache_key] = fast_create_block_causal_mask(attn_block_size=diff_block_size, seq_length=T, mask_block_size=mask_block_size, device=device)
        return self._mask_cache[cache_key]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Hidden states (B, T, n_embd).

        Returns:
            Attention output (B, T, n_embd).
        """
        B, T, C = x.size()
        device = x.device
        # flex_attention path requires CUDA tensors and a successful import.
        use_flex = _FLEX_ATTN_AVAILABLE and x.is_cuda and (flex_attention is not None)
        if not self._logged_attention_mode:
            # Log the selected backend once per module instance.
            self._logged_attention_mode = True
            mode = 'flex_attention' if use_flex else 'SDPA fallback'
            logger.debug(f'[CausalDiffusion] Using {mode} with GQA (n_head={self.n_head}, n_kv_heads={self.n_kv_heads})')
        qkv = self.c_attn(x)
        # Optional clamp on the fused QKV activations.
        clip_qkv = getattr(self.config, 'clip_qkv', None)
        if clip_qkv is not None:
            qkv = qkv.clamp(min=-clip_qkv, max=clip_qkv)
        kv_dim = self.n_kv_heads * self.head_dim
        q, k, v = qkv.split([self.n_embd, kv_dim, kv_dim], dim=2)
        # (B, T, H, hd) -> (B, H, T, hd) for the attention kernels.
        q = q.reshape(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = v.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        if self.q_norm is not None and self.k_norm is not None:
            q = self.q_norm(q)
            k = self.k_norm(k)
        if self.rope is not None:
            q, k = self.rope(q, k)
        if use_flex:
            block_mask = self._get_block_mask(T, device)
            assert flex_attention is not None and compiled_flex_attention is not None
            # Use the torch.compile'd kernel on CUDA, eager otherwise.
            if q.is_cuda:
                y = compiled_flex_attention(q, k, v, block_mask=block_mask, enable_gqa=True)
            else:
                y = flex_attention(q, k, v, block_mask=block_mask, enable_gqa=True)
        else:
            # Dense-mask fallback via scaled_dot_product_attention.
            y = sdpa_with_block_causal_mask(q, k, v, diff_block_size=self.config.diff_block_size, mask_cache=self._sdpa_mask_cache, enable_gqa=True)
        # (B, H, T, hd) -> (B, T, C), then output projection.
        y = y.transpose(1, 2).reshape(B, T, C)
        y = self.c_proj(y)
        return y
class CausalDiffusionBlock(nn.Module):
    """Transformer block for CausalDiffusionLM (block-causal attention + MLP)."""

    def __init__(self, config: CausalDiffusionConfig) -> None:
        super().__init__()
        # Pick the normalization flavor once and build both layers from it.
        if getattr(config, 'use_rms_norm', True):
            make_norm = lambda: RMSNorm(config)
        else:
            make_norm = lambda: nn.LayerNorm(config.n_embd)
        self.ln_1: nn.Module = make_norm()
        self.ln_2: nn.Module = make_norm()
        self.norm_order = getattr(config, 'norm_order', 'post')
        self.attn = BlockCausalAttention(config)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.norm_order == 'pre':
            # Pre-norm: normalize the residual input before each sublayer.
            x = x + self.attn(self.ln_1(x))
            x = x + self.mlp(self.ln_2(x))
        else:
            # Post-norm variant: normalize each sublayer's output instead.
            x = x + self.ln_1(self.attn(x))
            x = x + self.ln_2(self.mlp(x))
        return x
# ======================================================================
# steerling/models/causal_diffusion.py
# ======================================================================
class CausalDiffusionLM(nn.Module):
    """
    CausalDiffusionLM transformer backbone with block-causal attention.

    Pure compute graph — no training code, no loss logic.

    Args:
        config: CausalDiffusionConfig with model hyperparameters
        vocab_size: Vocabulary size (including special tokens)
    """

    def __init__(self, config: CausalDiffusionConfig, vocab_size: int) -> None:
        super().__init__()
        self.config = config
        self.vocab_size = vocab_size
        self.tok_emb = nn.Embedding(vocab_size, config.n_embd)
        self.blocks = nn.ModuleList(CausalDiffusionBlock(config) for _ in range(config.n_layers))
        self.ln_f: nn.Module = RMSNorm(config) if config.use_rms_norm else nn.LayerNorm(config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, vocab_size, bias=False)
        if config.weight_sharing:
            # Tie the input embedding to the output projection.
            self.tok_emb.weight = self.lm_head.weight

    def forward(self, input_ids: torch.Tensor, *, input_embeds: torch.Tensor | None=None, return_hidden: bool=False) -> torch.Tensor:
        """
        Forward pass.

        Args:
            input_ids: Token indices [B, T] (may contain mask tokens)
            input_embeds: Pre-computed embeddings [B, T, D]. If provided, input_ids is ignored.
            return_hidden: If True, return hidden states before lm_head.

        Returns:
            logits [B, T, vocab_size] or hidden_states [B, T, n_embd]
        """
        if input_embeds is not None:
            hidden = input_embeds
        elif input_ids is not None:
            hidden = self.tok_emb(input_ids)
        else:
            raise ValueError('Either input_ids or input_embeds must be provided')
        for layer in self.blocks:
            hidden = layer(hidden)
        hidden = self.ln_f(hidden)
        return hidden if return_hidden else self.lm_head(hidden)

    def get_num_params(self, non_embedding: bool=True) -> int:
        """Return number of parameters."""
        total = sum(p.numel() for p in self.parameters())
        if non_embedding:
            # Exclude the token-embedding matrix from the count.
            total -= self.tok_emb.weight.numel()
        return total

    def _restore_weight_tying(self) -> None:
        """Re-establish weight tying after to_empty() or device transfer."""
        if not self.config.weight_sharing:
            return
        self.tok_emb.weight = self.lm_head.weight

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize model weights (used for fresh models, not loaded checkpoints)."""
        if isinstance(module, nn.Linear):
            # Residual projections marked SCALE_INIT get depth-scaled std.
            std = 0.02
            if hasattr(module, 'SCALE_INIT'):
                std = std * (2 * self.config.n_layers) ** (-0.5)
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, RMSNorm):
            nn.init.ones_(module.weight)
# ======================================================================
# steerling/models/interpretable/outputs.py
# ======================================================================
@dataclass
class InterpretableOutput:
    """
    Full output from InterpretableCausalDiffusionLM; it contains all decomposition components for attribution and analysis.
    """
    # Hidden states fed into the decomposition heads.
    hidden: Tensor
    # Known-concept head outputs (field names mirror ConceptHeadOutput).
    known_features: Tensor
    known_logits: Tensor | None
    known_gt_features: Tensor | None
    known_predicted: Tensor
    known_weights: Tensor | None
    known_topk_indices: Tensor | None
    known_topk_logits: Tensor | None
    # Unknown-concept components. NOTE(review): the exact semantics of
    # unk vs unk_hat vs unk_for_lm are defined by the producing model —
    # confirm against InterpretableCausalDiffusionLM.forward.
    unk: Tensor
    unk_hat: Tensor | None
    unk_for_lm: Tensor
    unknown_logits: Tensor | None
    unknown_weights: Tensor | None
    unknown_topk_indices: Tensor | None
    unknown_topk_logits: Tensor | None
    # Recomposed representation and residual terms (epsilon*).
    composed: Tensor
    epsilon: Tensor | None
    epsilon_true: Tensor | None
# ======================================================================
# steerling/models/interpretable/concept_head.py
# ======================================================================
logger = logging.getLogger(__name__)
# Unknown-concept heads larger than this may not use dense (B, T, C)
# operations (enforced by ConceptHead._check_dense_allowed).
LARGE_CONCEPT_THRESHOLD = 50000
@dataclass
class ConceptHeadOutput:
    """Output from ConceptHead forward pass.

    Attributes:
        features: Final concept features after teacher forcing/intervention (B, T, D)
        gt_features: Ground truth pooled features. None for unknown heads. (B, T, D) or None
        logits: Full concept logits (B, T, C). Only set if return_logits=True. Usually None.
        predicted: Predicted features before teacher forcing mixing (B, T, D)
        weights: Full concept weights (B, T, C). Only set if return_logits=True. Usually None.
        topk_indices: Top-k concept indices (B, T, k). Set when using streaming top-k.
        topk_logits: Logits for top-k concepts (B, T, k). Set when using streaming top-k.
        hidden: Hidden states passed to this head (B, T, D). Stored for attribution.
    """
    # Always-populated fields.
    features: Tensor
    gt_features: Tensor | None
    logits: Tensor | None
    predicted: Tensor
    # Optional fields default to None and are only set on specific paths.
    weights: Tensor | None = None
    topk_indices: Tensor | None = None
    topk_logits: Tensor | None = None
    hidden: Tensor | None = None
class ConceptHead(nn.Module):
"""
Concept decomposition head supporting both known and unknown concepts.
Memory-efficient implementation that avoids (B, T, C) allocations by default.
Modes:
- Known (is_unknown=False): Supports GT, teacher forcing, top-k, interventions
- Unknown (is_unknown=True): No GT, no teacher forcing
Architectures:
- use_attention=False: Linear predictor (n_embd -> n_concepts)
- use_attention=True: Query projection + sigmoid attention over embeddings
Factorization (for large unknown heads):
- factorize=False: Dense embeddings (C, D) and predictor (D, C)
- factorize=True: Factorized embeddings (C, r) @ (r, D) where r << D
Reduces memory by ~10-20x for large C
Memory Safety:
- Unknown heads with n_concepts > 50k cannot use dense operations
- Interventions are only supported for known heads
- return_logits=True is forbidden for large unknown heads
- All tensor indexing uses F.embedding for DTensor safety
Args:
n_concepts: Number of concepts (C)
concept_dim: Dimension of concept embeddings (should equal n_embd)
n_embd: Model hidden dimension
is_unknown: If True, skip GT pooling and teacher forcing
use_attention: If True, use attention; else use linear predictor
topk: Top-k sparsity for concept weights. None = no sparsity.
block_size: Block size for memory-efficient operations
pad_multiple: Pad n_concepts to a multiple of this for efficiency
store_unknown_weights: If True and use_attention & is_unknown, store logits/weights
apply_topk_to_unknown: If True, also apply top-k to unknown concepts
topk_on_logits: If True, apply top-k on logits (then sigmoid). If False, on weights.
teacher_force_alpha: If None, hard TF. If in [0,1], soft mixing.
factorize: If True, use low-rank factorized embeddings
factorize_rank: Rank for factorization (r). Lower = less memory, less expressivity.
"""
class ConceptPooling(nn.Module):
    """Memory-efficient sum pooling using scatter-add."""

    def __init__(self, concept_dim: int):
        super().__init__()
        self.concept_dim = concept_dim

    def forward(self, concept_ids: Tensor, concept_mask: Tensor, concept_embeddings: nn.Embedding) -> Tensor:
        """
        Pool concept embeddings based on ground truth IDs.

        Uses scatter-add to avoid (B, T, K, D) allocation when K is sparse.

        Args:
            concept_ids: (B, T, K) concept indices, -1 for invalid
            concept_mask: (B, T, K) boolean mask for valid concepts
            concept_embeddings: Embedding layer to look up
        Returns:
            Pooled features (B, T, D)
        """
        B, T, K = concept_ids.shape
        D = concept_embeddings.embedding_dim
        device = concept_ids.device
        # A slot contributes only when it is masked in AND holds a real id.
        valid_mask = concept_mask & (concept_ids != -1)
        pooled = torch.zeros(B, T, D, device=device, dtype=concept_embeddings.weight.dtype)
        if not valid_mask.any():
            # Nothing to pool — return the all-zero features.
            return pooled
        # Coordinates of every valid (batch, time, slot) entry.
        b_idx, t_idx, k_idx = torch.where(valid_mask)
        c_ids = concept_ids[b_idx, t_idx, k_idx].long()
        emb = concept_embeddings(c_ids)
        # Scatter each embedding into its flattened (b*T + t) row and sum.
        flat_idx = b_idx * T + t_idx
        flat_idx = flat_idx.unsqueeze(-1).expand(-1, D)
        pooled_flat = pooled.view(B * T, D)
        pooled_flat.scatter_add_(0, flat_idx, emb)
        return pooled.view(B, T, D)
def __init__(self, n_concepts: int, concept_dim: int, n_embd: int, is_unknown: bool=False, use_attention: bool=False, topk: int | None=16, topk_features: int | None=None, block_size: int=8192, *, pad_multiple: int=16, store_unknown_weights: bool=False, apply_topk_to_unknown: bool=False, topk_on_logits: bool=False, factorize: bool=False, factorize_rank: int=256):
    # See the ConceptHead class docstring for the meaning of each argument.
    super().__init__()
    self.n_concepts = n_concepts
    self.concept_dim = concept_dim
    self.n_embd = n_embd
    self.is_unknown = is_unknown
    self.use_attention = use_attention
    self.topk = topk
    # topk_features falls back to topk when not given separately.
    self.topk_features = topk_features if topk_features is not None else topk
    self.block_size = block_size
    self.pad_multiple = pad_multiple
    self.store_unknown_weights = store_unknown_weights
    self.apply_topk_to_unknown = apply_topk_to_unknown
    self.topk_on_logits = topk_on_logits
    self.factorize = factorize
    self.factorize_rank = factorize_rank
    # Heads above the threshold must use streaming paths
    # (enforced by _check_dense_allowed).
    self._is_large = n_concepts > LARGE_CONCEPT_THRESHOLD
    # Round n_concepts up to a multiple of pad_multiple for efficiency.
    self.n_concepts_padded = (n_concepts + pad_multiple - 1) // pad_multiple * pad_multiple
    if factorize:
        # Low-rank factorization: embeddings = coef (C, r) @ basis (r, D).
        self.embedding_coef = nn.Embedding(self.n_concepts_padded, factorize_rank)
        self.embedding_basis = nn.Linear(factorize_rank, concept_dim, bias=False)
        self.concept_embedding = None
        if not use_attention:
            # Factorized linear predictor: n_embd -> r -> C.
            self.predictor_down = nn.Linear(n_embd, factorize_rank, bias=False)
            self.predictor_up = nn.Linear(factorize_rank, self.n_concepts_padded, bias=False)
            self.concept_predictor = None
        else:
            self.concept_query_projection = nn.Linear(n_embd, concept_dim, bias=False)
            self.predictor_down = None
            self.predictor_up = None
            self.concept_predictor = None
        # Log estimated memory savings (×2 bytes/param — presumably
        # bf16/fp16; confirm against the training dtype).
        dense_params = n_concepts * concept_dim * 2
        factorized_params = n_concepts * factorize_rank + factorize_rank * concept_dim + (n_embd * factorize_rank + factorize_rank * n_concepts if not use_attention else 0)
        logger.info(f'[ConceptHead] Factorized mode: {n_concepts} concepts, rank={factorize_rank}')
        logger.info(f'[ConceptHead] Memory: {dense_params * 2 / 1000000000.0:.2f} GB (dense) -> {factorized_params * 2 / 1000000000.0:.2f} GB (factorized) = {(1 - factorized_params / dense_params) * 100:.1f}% reduction')
    else:
        # Dense embeddings and (optionally) a dense linear predictor.
        self.concept_embedding = nn.Embedding(self.n_concepts_padded, concept_dim)
        self.embedding_coef = None
        self.embedding_basis = None
        if use_attention:
            self.concept_query_projection = nn.Linear(n_embd, concept_dim, bias=False)
            self.concept_predictor = None
        else:
            self.concept_predictor = nn.Linear(n_embd, self.n_concepts_padded, bias=False)
        self.predictor_down = None
        self.predictor_up = None
    self.concept_pooling = self.ConceptPooling(concept_dim)
    if self.topk_features != self.topk:
        logger.info(f"[ConceptHead] {('Unknown' if is_unknown else 'Known')} head: topk={self.topk} (loss), topk_features={self.topk_features} (features)")
    if is_unknown and apply_topk_to_unknown:
        logger.info(f'[ConceptHead] Unknown head: apply_topk_to_unknown=True, topk={self.topk}')
    self._init_weights()
def _init_weights(self):
    """Initialize weights with small values (normal, std=0.02)."""
    def _small_normal(param):
        # Shared initializer for every trainable matrix in this head.
        nn.init.normal_(param, mean=0.0, std=0.02)

    if self.factorize:
        _small_normal(self.embedding_coef.weight)
        _small_normal(self.embedding_basis.weight)
        if self.predictor_down is not None:
            _small_normal(self.predictor_down.weight)
        if self.predictor_up is not None:
            _small_normal(self.predictor_up.weight)
    else:
        if self.concept_embedding is not None:
            _small_normal(self.concept_embedding.weight)
        if self.concept_predictor is not None:
            _small_normal(self.concept_predictor.weight)
    # The query projection exists only for attention-style heads.
    if getattr(self, 'concept_query_projection', None) is not None:
        _small_normal(self.concept_query_projection.weight)
def _check_dense_allowed(self, operation: str) -> None:
    """Raise error if dense operations are requested for large unknown heads."""
    # Dense (B, T, C) paths are fine for known heads and small unknown heads.
    if not (self.is_unknown and self._is_large):
        return
    raise ValueError(f'{operation} requested for unknown head with {self.n_concepts} concepts. This would allocate multi-GB tensors. Use streaming mode instead. (Threshold: {LARGE_CONCEPT_THRESHOLD})')
@staticmethod
def _safe_index(weight: Tensor, indices: Tensor) -> Tensor:
    """
    DTensor-safe indexing using F.embedding.

    Replaces weight[indices] which crashes under FSDP2/DTensor.

    Args:
        weight: (N, D) weight matrix
        indices: (...) indices to select
    Returns:
        (..., D) selected embeddings
    """
    lead_shape = indices.shape
    # Flatten, gather via the embedding kernel, then restore the lead dims.
    gathered = F.embedding(indices.reshape(-1), weight)
    return gathered.reshape(*lead_shape, -1)
def _get_embedding_weight(self) -> Tensor:
    """
    Get full embedding matrix.

    For dense: returns concept_embedding.weight
    For factorized: computes coef @ basis (materializes full matrix)

    Returns:
        (C, D) embedding matrix
    """
    if self.concept_embedding is None:
        # Factorized path: materialize (C, r) @ (r, D).
        return self.embedding_basis(self.embedding_coef.weight)
    return self.concept_embedding.weight
def _get_embedding(self, indices: Tensor) -> Tensor:
    """
    Get embeddings for specific indices (DTensor-safe).

    For dense: uses F.embedding
    For factorized: looks up coef, then applies basis

    Args:
        indices: (...) concept indices
    Returns:
        (..., D) embeddings
    """
    if self.concept_embedding is None:
        # Factorized: project the (…, r) coefficients through the basis.
        return self.embedding_basis(self.embedding_coef(indices))
    return self.concept_embedding(indices)
def _get_predictor_weight(self) -> Tensor | None:
    """
    Get full predictor weight matrix (for linear path only).

    Returns:
        (C, D) predictor weight, or None if using attention
    """
    if self.concept_predictor is not None:
        return self.concept_predictor.weight
    if self.predictor_down is not None and self.predictor_up is not None:
        # Factorized: compose (C, r) @ (r, D) into a dense predictor.
        return self.predictor_up.weight @ self.predictor_down.weight
    return None
@staticmethod
def _merge_topk(topv: Tensor, topi: Tensor, v_blk: Tensor, i_blk: Tensor, k: int) -> tuple[Tensor, Tensor]:
    """Efficient merge of two top-k sets. Memory: O(BT × 2k)."""
    # Concatenate both candidate sets, then keep the k best values and
    # gather their matching indices.
    merged_vals = torch.cat([topv, v_blk], dim=1)
    merged_idx = torch.cat([topi, i_blk], dim=1)
    best_vals, positions = torch.topk(merged_vals, k, dim=1)
    best_idx = torch.gather(merged_idx, 1, positions)
    return (best_vals, best_idx)
@staticmethod
def linear_block_features(hidden: Tensor, predictor_weight: Tensor, embeddings: Tensor, block_size: int=4096) -> Tensor:
"""
Memory-efficient linear prediction without materializing (B, T, C).
Args:
hidden: (B, T, D)
predictor_weight: (C, D)
embeddings: (C, D)
block_size: Concepts per block
Returns:
Features (B, T, D)
"""
B, T, D = hidden.shape
C = predictor_weight.size(0)
output = torch.zeros(B, T, D, dtype=hidden.dtype, device=hidden.device)
flat_h = hidden.reshape(-1, D)
W_t = predictor_weight.t().contiguous()
for start in range(0, C, block_size):
end = min(start + block_size, C)
logits_block = (flat_h @ W_t[:, start:end]).to(torch.float32)
logits_block = logits_block.clamp(-15, 15)
weights_block = torch.sigmoid(logits_block)
E_block = embeddings[start:end].to(weights_block.dtype)
output.add_((weights_block @ E_block).reshape(B, T, D))
return output.to(hidden.dtype)
@staticmethod
def attention_block_features(query: Tensor, embeddings: Tensor, block_size: int=4096) -> Tensor:
"""Memory-efficient attention features without materializing (B, T, C)."""
B, T, D = query.shape
C = embeddings.shape[0]
scale = 1.0 / math.sqrt(D)
flat_q = query.reshape(-1, D)
emb_T = embeddings.t().contiguous()
output = torch.zeros(B * T, D, dtype=query.dtype, device=query.device)
for start in range(0, C, block_size):
end = min(start + block_size, C)
scores = (flat_q @ emb_T[:, start:end]).to(torch.float32) * scale
scores = scores.clamp(-15, 15)
weights = torch.sigmoid(scores)
output.add_(weights @ embeddings[start:end].to(weights.dtype))
return output.reshape(B, T, D).to(query.dtype)
@staticmethod
def linear_features_topk_streaming(hidden: Tensor, predictor_weight: Tensor, embeddings: Tensor, k: int, block_size: int=4096, topk_on_logits: bool=False) -> tuple[Tensor, Tensor, Tensor]:
    """
    Memory-efficient linear prediction with streaming top-k.

    Uses merge-k-with-k to keep memory O(BT × k), not O(BT × block_size).

    Args:
        hidden: (B, T, D)
        predictor_weight: (C, D)
        embeddings: (C, D)
        k: Number of top concepts
        block_size: Concepts per block
        topk_on_logits: If True, select top-k by logits; else by sigmoid
    Returns:
        features: (B, T, D) weighted concept features
        topk_indices: (B, T, k) indices of top-k concepts
        topk_logits: (B, T, k) logits for top-k concepts
    """
    B, T, D = hidden.shape
    C = predictor_weight.size(0)
    BT = B * T
    device = hidden.device
    k = min(k, C)
    flat_h = hidden.reshape(BT, D)
    W_t = predictor_weight.t().contiguous()
    # Running top-k state: -inf sentinels are displaced by any real score.
    topv = torch.full((BT, k), float('-inf'), device=device, dtype=hidden.dtype)
    topi = torch.zeros((BT, k), device=device, dtype=torch.long)
    for start in range(0, C, block_size):
        end = min(start + block_size, C)
        logits_blk = (flat_h @ W_t[:, start:end]).to(torch.float32).clamp_(-15, 15)
        # Candidate ranking values for this block (raw logits or sigmoid).
        vals_blk = logits_blk if topk_on_logits else torch.sigmoid(logits_blk)
        blk_k = min(k, end - start)
        v_blk, idx_blk = torch.topk(vals_blk, blk_k, dim=1)
        # Block-local indices -> global concept ids.
        i_blk = idx_blk + start
        if blk_k < k:
            # Pad short blocks with -inf so the merge always sees k candidates.
            pad_v = torch.full((BT, k - blk_k), float('-inf'), device=device, dtype=torch.float32)
            pad_i = torch.zeros((BT, k - blk_k), device=device, dtype=torch.long)
            v_blk = torch.cat([v_blk, pad_v], dim=1)
            i_blk = torch.cat([i_blk, pad_i], dim=1)
        topv, topi = ConceptHead._merge_topk(topv, topi, v_blk, i_blk, k)
    # Recompute exact logits for only the selected concepts (O(BT × k × D)).
    W_sel = ConceptHead._safe_index(predictor_weight, topi)
    logits_sel = torch.einsum('bd,bkd->bk', flat_h.to(torch.float32), W_sel.to(torch.float32))
    logits_sel = logits_sel.clamp(-15, 15)
    del W_sel
    # Weight the selected embeddings by their sigmoid scores and sum.
    weights_sel = torch.sigmoid(logits_sel)
    E_sel = ConceptHead._safe_index(embeddings, topi)
    features = torch.einsum('bk,bkd->bd', weights_sel.to(E_sel.dtype), E_sel)
    return (features.reshape(B, T, D).to(hidden.dtype), topi.reshape(B, T, k), logits_sel.reshape(B, T, k))
@staticmethod
def attention_features_topk_streaming(query: Tensor, embeddings: Tensor, k: int, block_size: int=4096, topk_on_logits: bool=False) -> tuple[Tensor, Tensor, Tensor]:
    """Streaming top-k sigmoid attention over the concept embedding table.

    Scans `embeddings` in column blocks, keeps a running per-token top-k,
    then recomputes exact logits for only the k winners and mixes their
    embeddings. Returns (features (B,T,D), indices (B,T,k), logits (B,T,k)).
    """
    batch, seq, dim = query.shape
    n_concepts = embeddings.shape[0]
    n_tokens = batch * seq
    dev = query.device
    scale = 1.0 / math.sqrt(dim)
    k = min(k, n_concepts)
    tokens = query.reshape(n_tokens, dim)
    emb_cols = embeddings.t().contiguous()
    best_vals = torch.full((n_tokens, k), float('-inf'), device=dev, dtype=query.dtype)
    best_idx = torch.zeros((n_tokens, k), device=dev, dtype=torch.long)
    for lo in range(0, n_concepts, block_size):
        hi = min(lo + block_size, n_concepts)
        blk_logits = (tokens @ emb_cols[:, lo:hi]).to(torch.float32) * scale
        blk_logits = blk_logits.clamp(-15, 15)
        blk_scores = blk_logits if topk_on_logits else torch.sigmoid(blk_logits)
        take = min(k, hi - lo)
        cand_vals, local_idx = torch.topk(blk_scores, take, dim=1)
        cand_idx = local_idx + lo
        if take < k:
            # Pad short blocks so the merge always compares k candidates.
            cand_vals = torch.cat([cand_vals, torch.full((n_tokens, k - take), float('-inf'), device=dev, dtype=torch.float32)], dim=1)
            cand_idx = torch.cat([cand_idx, torch.zeros((n_tokens, k - take), device=dev, dtype=torch.long)], dim=1)
        best_vals, best_idx = ConceptHead._merge_topk(best_vals, best_idx, cand_vals, cand_idx, k)
    E_top = ConceptHead._safe_index(embeddings, best_idx)
    sel_logits = torch.einsum('bd,bkd->bk', tokens.to(torch.float32), E_top.to(torch.float32)) * scale
    sel_logits = sel_logits.clamp(-15, 15)
    gate = torch.sigmoid(sel_logits)
    mixed = torch.einsum('bk,bkd->bd', gate.to(E_top.dtype), E_top)
    return (mixed.reshape(batch, seq, dim).to(query.dtype), best_idx.reshape(batch, seq, k), sel_logits.reshape(batch, seq, k))
def attention_block_features_factorized(self, query: Tensor, block_size: int=4096) -> Tensor:
    """
    Memory-efficient factorized attention over ALL concepts.
    Uses factorized scoring and feature computation:
    - Scoring: (query @ basis.T) @ coef.T instead of query @ E.T
    - Features: (weights @ coef) @ basis instead of weights @ E
    FLOPs: O(BT * r * (D + C)) instead of O(BT * D * C)
    Args:
        query: (B, T, D) query vectors from concept_query_projection
        block_size: Concepts per block for chunked processing
    Returns:
        (B, T, D) weighted concept features, in query's dtype
    """
    assert self.factorize, 'Only valid for factorized head'
    B, T, D = query.shape
    BT = B * T
    C = self.n_concepts
    device = query.device
    scale = 1.0 / math.sqrt(D)
    flat_q = query.reshape(BT, D)
    coef = self.embedding_coef.weight[:C]
    basis_weight = self.embedding_basis.weight
    # Compress once: (BT, D) @ (D, r) -> (BT, r)
    q_compressed = flat_q @ basis_weight
    # Accumulate in float32: adding fp32 chunk results into a half/bf16
    # buffer would round once per block and degrade the sum. Cast once at end.
    output = torch.zeros(BT, D, dtype=torch.float32, device=device)
    for start in range(0, C, block_size):
        end = min(start + block_size, C)
        coef_chunk = coef[start:end]
        scores_chunk = (q_compressed @ coef_chunk.T).float() * scale
        scores_chunk = scores_chunk.clamp(-15, 15)
        weights_chunk = torch.sigmoid(scores_chunk)
        weighted_coef = weights_chunk @ coef_chunk.float()
        features_chunk = weighted_coef @ basis_weight.T.to(weighted_coef.dtype)
        output.add_(features_chunk)
    return output.reshape(B, T, D).to(query.dtype)
def attention_features_topk_factorized(self, query: Tensor, k: int, block_size: int=4096) -> tuple[Tensor, Tensor, Tensor]:
    """
    Memory-efficient factorized attention with a streaming top-k scan.

    Pass 1 streams factorized scores `(q @ basis) @ coef.T` block by block
    to find the per-token top-k concepts; pass 2 rebuilds features from
    only the winners' factorized embeddings.
    Args:
        query: (B, T, D) query vectors
        k: number of top concepts per token
        block_size: concepts per block
    Returns:
        features (B, T, D), topk_indices (B, T, k), topk_logits (B, T, k)
    """
    assert self.factorize, 'Only valid for factorized head'
    batch, seq, dim = query.shape
    n_tokens = batch * seq
    n_concepts = self.n_concepts
    _ = self.factorize_rank  # attribute access preserved from the original
    dev = query.device
    scale = 1.0 / math.sqrt(dim)
    k = min(k, n_concepts)
    tokens = query.reshape(n_tokens, dim)
    coef = self.embedding_coef.weight[:n_concepts]
    basis_weight = self.embedding_basis.weight
    q_compressed = tokens @ basis_weight
    best_vals = torch.full((n_tokens, k), float('-inf'), device=dev, dtype=query.dtype)
    best_idx = torch.zeros((n_tokens, k), device=dev, dtype=torch.long)
    for lo in range(0, n_concepts, block_size):
        hi = min(lo + block_size, n_concepts)
        blk_scores = q_compressed.float() @ coef[lo:hi].T.float() * scale
        blk_scores = blk_scores.clamp(-15, 15)
        take = min(k, hi - lo)
        cand_vals, local_idx = torch.topk(blk_scores, take, dim=1)
        cand_idx = local_idx + lo
        if take < k:
            # Pad short blocks so every merge compares exactly k candidates.
            fill_v = torch.full((n_tokens, k - take), float('-inf'), device=dev, dtype=torch.float32)
            fill_i = torch.zeros((n_tokens, k - take), device=dev, dtype=torch.long)
            cand_vals = torch.cat([cand_vals, fill_v], dim=1)
            cand_idx = torch.cat([cand_idx, fill_i], dim=1)
        best_vals, best_idx = self._merge_topk(best_vals, best_idx, cand_vals, cand_idx, k)
    coef_sel = self.embedding_coef(best_idx)
    sel_logits = torch.einsum('br,bkr->bk', q_compressed.float(), coef_sel.float()) * scale
    sel_logits = sel_logits.clamp(-15, 15)
    gate = torch.sigmoid(sel_logits)
    mixed_coef = torch.einsum('bk,bkr->br', gate.to(coef_sel.dtype), coef_sel)
    features = mixed_coef @ basis_weight.T.to(mixed_coef.dtype)
    return (features.reshape(batch, seq, dim).to(query.dtype), best_idx.reshape(batch, seq, k), sel_logits.reshape(batch, seq, k))
def linear_block_features_factorized(self, hidden: Tensor, block_size: int=4096) -> Tensor:
    """
    Memory-efficient factorized linear prediction over ALL concepts.
    Uses factorized predictor: logits = hidden @ down.T @ up.T
    Uses factorized embeddings: features = weights @ coef @ basis.T
    Args:
        hidden: (B, T, D) hidden states
        block_size: Concepts per block
    Returns:
        (B, T, D) weighted concept features, in hidden's dtype
    """
    assert self.factorize, 'Only valid for factorized head'
    assert self.predictor_down is not None, 'Linear path requires predictor'
    B, T, D = hidden.shape
    BT = B * T
    C = self.n_concepts
    device = hidden.device
    flat_h = hidden.reshape(BT, D)
    coef = self.embedding_coef.weight[:C]
    basis_weight = self.embedding_basis.weight
    down_weight = self.predictor_down.weight
    up_weight = self.predictor_up.weight[:C]
    # Compress once: (BT, D) @ (D, r) -> (BT, r)
    h_compressed = flat_h @ down_weight.T
    # Accumulate in float32: adding fp32 chunk results into a half/bf16
    # buffer would round once per block and degrade the sum. Cast once at end.
    output = torch.zeros(BT, D, dtype=torch.float32, device=device)
    for start in range(0, C, block_size):
        end = min(start + block_size, C)
        up_chunk = up_weight[start:end]
        coef_chunk = coef[start:end]
        logits_chunk = h_compressed.float() @ up_chunk.T.float()
        logits_chunk = logits_chunk.clamp(-15, 15)
        weights_chunk = torch.sigmoid(logits_chunk)
        weighted_coef = weights_chunk @ coef_chunk.float()
        features_chunk = weighted_coef @ basis_weight.T.to(weighted_coef.dtype)
        output.add_(features_chunk)
    return output.reshape(B, T, D).to(hidden.dtype)
def linear_features_topk_factorized(self, hidden: Tensor, k: int, block_size: int=4096) -> tuple[Tensor, Tensor, Tensor]:
    """
    Memory-efficient factorized linear prediction with a streaming top-k scan.

    Streams the factorized logits `hidden @ down.T @ up.T` block by block,
    keeps a running per-token top-k, then rebuilds features from only the
    selected concepts' factorized embeddings.
    Args:
        hidden: (B, T, D) hidden states
        k: number of top concepts per token
        block_size: concepts per block
    Returns:
        features (B, T, D), topk_indices (B, T, k), topk_logits (B, T, k)
    """
    assert self.factorize, 'Only valid for factorized head'
    assert self.predictor_down is not None, 'Linear path requires predictor'
    batch, seq, dim = hidden.shape
    n_tokens = batch * seq
    n_concepts = self.n_concepts
    _ = self.factorize_rank  # attribute access preserved from the original
    dev = hidden.device
    k = min(k, n_concepts)
    tokens = hidden.reshape(n_tokens, dim)
    down_weight = self.predictor_down.weight
    up_weight = self.predictor_up.weight[:n_concepts]
    basis_weight = self.embedding_basis.weight
    h_compressed = tokens @ down_weight.T
    best_vals = torch.full((n_tokens, k), float('-inf'), device=dev, dtype=hidden.dtype)
    best_idx = torch.zeros((n_tokens, k), device=dev, dtype=torch.long)
    for lo in range(0, n_concepts, block_size):
        hi = min(lo + block_size, n_concepts)
        blk_logits = h_compressed.float() @ up_weight[lo:hi].T.float()
        blk_logits = blk_logits.clamp(-15, 15)
        take = min(k, hi - lo)
        cand_vals, local_idx = torch.topk(blk_logits, take, dim=1)
        cand_idx = local_idx + lo
        if take < k:
            # Pad short blocks so the merge always compares k candidates.
            fill_v = torch.full((n_tokens, k - take), float('-inf'), device=dev, dtype=torch.float32)
            fill_i = torch.zeros((n_tokens, k - take), device=dev, dtype=torch.long)
            cand_vals = torch.cat([cand_vals, fill_v], dim=1)
            cand_idx = torch.cat([cand_idx, fill_i], dim=1)
        best_vals, best_idx = self._merge_topk(best_vals, best_idx, cand_vals, cand_idx, k)
    coef_sel = self.embedding_coef(best_idx)
    up_sel = self._safe_index(self.predictor_up.weight[:n_concepts], best_idx)
    sel_logits = torch.einsum('br,bkr->bk', h_compressed.float(), up_sel.float())
    sel_logits = sel_logits.clamp(-15, 15)
    gate = torch.sigmoid(sel_logits)
    mixed_coef = torch.einsum('bk,bkr->br', gate.to(coef_sel.dtype), coef_sel)
    features = mixed_coef @ basis_weight.T.to(mixed_coef.dtype)
    return (features.reshape(batch, seq, dim).to(hidden.dtype), best_idx.reshape(batch, seq, k), sel_logits.reshape(batch, seq, k))
def compute_logits_for_indices(self, hidden: Tensor, indices: Tensor) -> Tensor:
    """
    Compute logits for specific concept indices only (sparse gather).
    Supports both dense and factorized heads, and attention or linear scoring.
    IMPORTANT: materializes an (M, K, D) gather of predictor/embedding rows,
    where M is the number of tokens — only call with small M (e.g. masked
    tokens only).
    Args:
        hidden: (M, D) or (B, T, D) hidden states
        indices: (M, K) or (B, T, K) concept indices
    Returns:
        logits: float32, clamped to [-15, 15], same shape as `indices`
    """
    output_shape = indices.shape
    K = indices.size(-1)
    if hidden.dim() == 2:
        M, D = hidden.shape
        flat_h = hidden
        flat_idx = indices
    else:
        B, T, D = hidden.shape
        M = B * T
        flat_h = hidden.reshape(M, D)
        flat_idx = indices.reshape(M, K)
    # Warn before a huge (M, K, D) half-precision gather is allocated.
    estimated_bytes = M * K * D * 2
    if estimated_bytes > 1000000000.0:
        warnings.warn(f'compute_logits_for_indices will allocate ~{estimated_bytes / 1000000000.0:.1f} GB. Consider reducing M={M} (use masked tokens only) or K={K}.')
    n_valid = self.n_concepts
    indices_safe = flat_idx.clamp(0, n_valid - 1)
    if self.use_attention:
        query = self.concept_query_projection(flat_h.unsqueeze(0)).squeeze(0)
        scale = 1.0 / math.sqrt(self.concept_dim)
        E_sel = self._get_embedding(indices_safe)
        logits = torch.einsum('md,mkd->mk', query.float(), E_sel.float()) * scale
    else:
        W = (self._get_predictor_weight() if self.factorize else self.concept_predictor.weight)[:n_valid]
        W_sel = self._safe_index(W, indices_safe)
        logits = torch.einsum('md,mkd->mk', flat_h.float(), W_sel.float())
    return logits.clamp(-15, 15).reshape(output_shape)
def get_concept_weights(self, hidden: Tensor, concept_ids: Tensor) -> Tensor:
    """
    Return sigmoid weights for specific concepts (for attribution).
    Args:
        hidden: (B, T, D) or (M, D) hidden states
        concept_ids: (B, T, K) or (M, K) or (K,) concept indices; a flat
            (K,) list is broadcast to every token position.
    Returns:
        weights: Same shape as the (possibly broadcast) ids, values in [0, 1]
    """
    if concept_ids.dim() == 1:
        # Broadcast a flat id list across every token position.
        if hidden.dim() == 2:
            concept_ids = concept_ids.unsqueeze(0).expand(hidden.size(0), -1)
        else:
            batch, seq, _ = hidden.shape
            concept_ids = concept_ids.unsqueeze(0).unsqueeze(0).expand(batch, seq, -1)
    return torch.sigmoid(self.compute_logits_for_indices(hidden, concept_ids))
@staticmethod
def blocked_logits(query: Tensor, embeddings: Tensor, block_size: int=8192, out_device: torch.device | None=None, out_dtype: torch.dtype=torch.float32) -> Tensor:
"""
Compute concept logits in column blocks for memory efficiency.
logits = query @ embeddings.T / sqrt(D)
"""
B, T, D = query.shape
C = embeddings.size(0)
scale = 1.0 / math.sqrt(D)
dev = query.device if out_device is None else out_device
logits = torch.empty(B, T, C, device=dev, dtype=out_dtype)
q = query.reshape(-1, D).to(torch.float32)
Et = embeddings.t().contiguous().to(torch.float32)
for s in range(0, C, block_size):
e = min(s + block_size, C)
scores = q @ Et[:, s:e] * scale
scores = scores.clamp(-15, 15)
logits[:, :, s:e] = scores.reshape(B, T, e - s).to(out_dtype)
return logits
@staticmethod
def blocked_mix(weights: Tensor, embeddings: Tensor, block_size: int=8192) -> Tensor:
"""
Compute weighted sum of embeddings in column blocks.
output = weights @ embeddings
"""
B, T, C = weights.shape
D = embeddings.size(1)
out = torch.zeros(B, T, D, device=weights.device, dtype=weights.dtype)
for s in range(0, C, block_size):
e = min(s + block_size, C)
w_blk = weights[:, :, s:e].to(torch.float32)
V_blk = embeddings[s:e].to(w_blk.dtype)
out.add_(w_blk @ V_blk)
return out.to(weights.dtype)
@staticmethod
def sigmoid_block_attention(query: Tensor, embeddings: Tensor, block_size: int=8192, return_logits: bool=False) -> Tensor | tuple[Tensor, Tensor]:
"""Memory-efficient sigmoid attention using block processing."""
B, T, D = query.shape
C = embeddings.shape[0]
scale = 1.0 / math.sqrt(D)
flat_q = query.reshape(-1, D)
emb_T = embeddings.t().contiguous()
output = torch.zeros(B * T, D, dtype=query.dtype, device=query.device)
logits: Tensor | None = None
if return_logits:
logits = torch.empty(B, T, C, dtype=torch.float32, device=query.device)
for start in range(0, C, block_size):
end = min(start + block_size, C)
scores = (flat_q @ emb_T[:, start:end]).to(torch.float32) * scale
scores = scores.clamp(-15, 15)
if logits is not None:
logits[:, :, start:end] = scores.reshape(B, T, end - start)
weights = torch.sigmoid(scores)
output.add_(weights @ embeddings[start:end].to(weights.dtype))
output = output.reshape(B, T, D).to(query.dtype)
if return_logits:
assert logits is not None
return (output, logits)
return output
def _apply_sparse_interventions(self, features: Tensor, hidden: Tensor, intervene_ids: Tensor, intervene_vals: Tensor) -> Tensor:
"""
Apply sparse interventions matching original dense behavior.
Original dense behavior:
weights = sigmoid(logits) # (B, T, C)
weights[..., c] = new_val # Override
features = weights @ embeddings
Sparse equivalent:
features += (new_val - current_weight) * embedding[c]
"""
B, T, D = features.shape
valid = intervene_ids != -1
if not valid.any():
return features
ids_safe = intervene_ids.clamp(0, self.n_concepts - 1)
current_logits = self.compute_logits_for_indices(hidden, ids_safe)
current_weights = torch.sigmoid(current_logits)
emb = self._get_embedding(ids_safe)
delta = (intervene_vals - current_weights) * valid.float()
correction = (delta.unsqueeze(-1) * emb).sum(dim=2)
return features + correction
def _apply_dense_interventions(self, concept_weight: Tensor, intervene_ids: Tensor, intervene_vals: Tensor) -> Tensor:
"""Apply interventions by overriding concept weights (dense path)."""
n_valid = min(self.n_concepts, concept_weight.size(-1))
valid_edit = intervene_ids != -1
ids = intervene_ids.clamp(0, n_valid - 1).long()
vals = intervene_vals.to(concept_weight.dtype)
updates = torch.zeros_like(concept_weight)
updates.scatter_add_(2, ids, torch.where(valid_edit, vals, torch.zeros_like(vals)))
set_mask = torch.zeros_like(concept_weight, dtype=torch.bool)
set_mask.scatter_(2, ids, valid_edit)
return torch.where(set_mask, updates, concept_weight)
def topk_with_cutoff(self, tensor: Tensor, dim: int=-1) -> Tensor:
    """
    Apply top-k sparsity, zeroing out all but the top-k values.
    Only the first `n_concepts` entries along `dim` are real concepts; any
    padding beyond that is always zeroed in the result.
    Args:
        tensor: Input tensor, typically (B, T, C_padded)
        dim: Dimension to apply top-k (must be the last dimension)
    Returns:
        Sparse tensor with only the top-k values preserved
    """
    assert dim == -1 or dim == tensor.dim() - 1
    if self.topk is None:
        return tensor
    padded = tensor.size(dim)
    n_valid = min(self.n_concepts, padded)
    if n_valid <= 0:
        return torch.zeros_like(tensor)
    valid_part = tensor.narrow(dim, 0, n_valid)
    keep = min(self.topk, n_valid)
    vals, idx = torch.topk(valid_part, keep, dim=dim)
    sparse = torch.zeros_like(valid_part)
    sparse.scatter_(dim, idx, vals)
    if n_valid == padded:
        return sparse
    # Re-append an all-zero tail for the padded (non-concept) columns.
    tail_shape = list(sparse.shape)
    tail_shape[dim] = padded - n_valid
    return torch.cat([sparse, sparse.new_zeros(tail_shape)], dim=dim)
def _compute_weights(self, concept_logits: Tensor, E: Tensor) -> Tensor:
"""Compute concept weights from logits, with optional top-k sparsity."""
apply_topk = self.topk is not None and (not self.is_unknown or self.apply_topk_to_unknown)
if apply_topk and self.topk_on_logits:
logits_for_weights = self.topk_with_cutoff(concept_logits)
weights = torch.sigmoid(logits_for_weights).to(E.dtype)
return weights
weights = torch.sigmoid(concept_logits).to(E.dtype)
if apply_topk and (not self.topk_on_logits):
weights = self.topk_with_cutoff(weights)
return weights
@torch.compiler.disable
def forward(self, hidden: Tensor, intervene_ids: Tensor | None=None, intervene_vals: Tensor | None=None, return_logits: bool=False, store_hidden: bool=False) -> ConceptHeadOutput:
    """
    Forward pass for concept decomposition (inference only, no teacher forcing).

    Selects one of several execution paths depending on head configuration:
    dense-with-interventions, factorized (top-k or all-concepts), streaming
    top-k, or fully dense mixing.
    Args:
        hidden: Transformer hidden states (B, T, n_embd)
        intervene_ids: Concept IDs to intervene on (B, T, K_int), -1 = skip
        intervene_vals: Intervention strength values (B, T, K_int)
        return_logits: If True, compute full (B, T, C) logits. Forbidden for large heads.
        store_hidden: If True, store hidden in output for later attribution.
    Returns:
        ConceptHeadOutput with features, predicted, topk_indices, topk_logits
    """
    B, T, _ = hidden.shape
    has_interventions = intervene_ids is not None and intervene_vals is not None
    if return_logits:
        # Dense (B, T, C) logits are only permitted for heads small enough to materialize.
        self._check_dense_allowed('return_logits=True')
    n_valid = self.n_concepts
    concept_logits: Tensor | None = None
    concept_weight: Tensor | None = None
    predicted: Tensor
    topk_indices: Tensor | None = None
    topk_logits: Tensor | None = None
    # Top-k applies to known heads always; to unknown heads only when configured.
    apply_topk = self.topk is not None and (not self.is_unknown or self.apply_topk_to_unknown)
    # Feature mixing may keep more concepts (topk_features) than the loss uses (topk).
    k_features = self.topk_features if self.topk_features is not None else self.topk
    # Dense interventions need the full (B, T, C) weight matrix; only viable for non-large heads.
    use_dense_intervention = has_interventions and (not self._is_large)
    if use_dense_intervention:
        E = self._get_embedding_weight()[:n_valid]
        if self.use_attention:
            query = self.concept_query_projection(hidden)
            concept_logits = self.blocked_logits(query, E, block_size=self.block_size)
        else:
            if self.factorize:
                W = self._get_predictor_weight()[:n_valid]
                raw_logits = hidden @ W.T
            else:
                raw_logits = self.concept_predictor(hidden)[..., :n_valid]
            concept_logits = raw_logits.float().clamp(-15, 15)
        concept_weight = self._compute_weights(concept_logits, E)
        assert intervene_ids is not None and intervene_vals is not None
        # Override the dense weights in place, then mix all concepts.
        concept_weight = self._apply_dense_interventions(concept_weight, intervene_ids, intervene_vals)
        predicted = self.blocked_mix(concept_weight, E, block_size=self.block_size)
    elif self.factorize:
        if self.use_attention:
            query = self.concept_query_projection(hidden)
            if apply_topk:
                predicted, topk_indices, topk_logits = self.attention_features_topk_factorized(query, k=k_features, block_size=self.block_size)
            else:
                predicted = self.attention_block_features_factorized(query, block_size=self.block_size)
        elif apply_topk:
            predicted, topk_indices, topk_logits = self.linear_features_topk_factorized(hidden, k=k_features, block_size=self.block_size)
        else:
            predicted = self.linear_block_features_factorized(hidden, block_size=self.block_size)
    elif apply_topk:
        # Non-factorized head with top-k: stream blocks of the dense tables.
        E = self._get_embedding_weight()[:n_valid]
        if self.use_attention:
            query = self.concept_query_projection(hidden)
            predicted, topk_indices, topk_logits = self.attention_features_topk_streaming(query, E, k=k_features, block_size=self.block_size, topk_on_logits=self.topk_on_logits)
        else:
            W = self.concept_predictor.weight[:n_valid]
            predicted, topk_indices, topk_logits = self.linear_features_topk_streaming(hidden, W, E, k=k_features, block_size=self.block_size, topk_on_logits=self.topk_on_logits)
    else:
        # Dense mixing over every concept, still block-chunked for memory.
        E = self._get_embedding_weight()[:n_valid]
        if self.use_attention:
            query = self.concept_query_projection(hidden)
            predicted = self.attention_block_features(query, E, block_size=self.block_size)
        else:
            W = self.concept_predictor.weight[:n_valid]
            predicted = self.linear_block_features(hidden, W, E, block_size=self.block_size)
    # When features kept more concepts than the loss needs, re-rank down to topk.
    if topk_indices is not None and self.topk is not None and (self.topk_features is not None) and (self.topk_features > self.topk):
        _, rerank_idx = torch.topk(topk_logits, self.topk, dim=-1)
        topk_indices = torch.gather(topk_indices, -1, rerank_idx)
        topk_logits = torch.gather(topk_logits, -1, rerank_idx)
    if return_logits and (not use_dense_intervention):
        # Dense logits were not produced above; compute them now for the output.
        E = self._get_embedding_weight()[:n_valid]
        if self.use_attention:
            query = self.concept_query_projection(hidden)
            concept_logits = self.blocked_logits(query, E, block_size=self.block_size)
        else:
            if self.factorize:
                W = self._get_predictor_weight()[:n_valid]
                raw_logits = hidden @ W.T
            else:
                raw_logits = self.concept_predictor(hidden)[..., :n_valid]
            concept_logits = raw_logits.float().clamp(-15, 15)
        concept_weight = self._compute_weights(concept_logits, E)
    # One-time logging of which execution path this head instance took.
    if not hasattr(self, '_logged_forward_path'):
        self._logged_forward_path = True
        path = 'dense_intervention' if use_dense_intervention else 'factorized_topk' if self.factorize and apply_topk else 'factorized_all' if self.factorize else 'streaming_topk' if apply_topk else 'dense_all'
        logger.info(f"[ConceptHead] {('Unknown' if self.is_unknown else 'Known')} head: path={path}, topk={self.topk}, topk_features={self.topk_features}, n_concepts={self.n_concepts}, factorize={self.factorize}, apply_topk={apply_topk}")
    if topk_indices is not None and self.topk is not None and (self.topk_features is not None) and (self.topk_features > self.topk):
        if not hasattr(self, '_logged_topk_slice'):
            self._logged_topk_slice = True
            logger.info(f"[ConceptHead] {('Unknown' if self.is_unknown else 'Known')} head: Sliced topk: {self.topk_features} features -> {self.topk} for loss")
    if has_interventions and (not use_dense_intervention):
        assert intervene_ids is not None and intervene_vals is not None
        # Large heads: apply interventions as sparse per-concept corrections.
        predicted = self._apply_sparse_interventions(predicted, hidden, intervene_ids, intervene_vals)
    return ConceptHeadOutput(features=predicted, gt_features=None, logits=concept_logits, predicted=predicted, weights=concept_weight, topk_indices=topk_indices, topk_logits=topk_logits, hidden=hidden.detach() if store_hidden else None)
# ======================================================================
# steerling/models/interpretable/interpretable_causal_diffusion.py
# ======================================================================
# Module-level logger; ConceptHead.forward above resolves this name at call
# time, so it only needs to exist before the first forward pass runs.
logger = logging.getLogger(__name__)
class InterpretableCausalDiffusionLM(nn.Module):
    """
    Interpretable CausalDiffusionLM with concept decomposition heads.
    Wraps a CausalDiffusionLM and adds:
    - Known concept head: predicts known concepts from hidden states
    - Unknown concept head: captures residual features (optional)
    - Steering via concept interventions
    Args:
        config: CausalDiffusionConfig (model architecture)
        concept_config: ConceptConfig (concept decomposition)
        vocab_size: Vocabulary size
    """

    def __init__(self, config: CausalDiffusionConfig, concept_config: ConceptConfig, vocab_size: int):
        super().__init__()
        self.config = config
        self.concept_config = concept_config
        self.vocab_size = vocab_size
        self.transformer = CausalDiffusionLM(config, vocab_size)
        # Known head decomposes hidden states into the curated concept set.
        self.known_head = ConceptHead(n_concepts=concept_config.n_concepts, concept_dim=concept_config.concept_dim, n_embd=config.n_embd, is_unknown=False, use_attention=concept_config.use_attention_known, topk=concept_config.topk_known, topk_features=concept_config.topk_known_features, block_size=concept_config.block_size, pad_multiple=concept_config.pad_multiple, store_unknown_weights=False, apply_topk_to_unknown=False, topk_on_logits=concept_config.topk_on_logits)
        if concept_config.use_unknown:
            if concept_config.n_unknown_concepts is None:
                raise ValueError('n_unknown_concepts must be set when use_unknown=True')
            # Unknown head models the residual (hidden - known features).
            self.unknown_head: ConceptHead | None = ConceptHead(n_concepts=concept_config.n_unknown_concepts, concept_dim=concept_config.concept_dim, n_embd=config.n_embd, is_unknown=True, use_attention=concept_config.use_attention_unknown, topk=concept_config.unknown_topk, block_size=concept_config.block_size, pad_multiple=concept_config.pad_multiple, store_unknown_weights=False, apply_topk_to_unknown=concept_config.apply_topk_to_unknown, topk_on_logits=concept_config.topk_on_logits, factorize=concept_config.factorize_unknown, factorize_rank=concept_config.factorize_rank)
        else:
            self.unknown_head = None

    def forward(self, input_ids: Tensor, *, input_embeds: Tensor | None=None, intervene_known_ids: Tensor | None=None, intervene_known_vals: Tensor | None=None, intervene_unknown_ids: Tensor | None=None, intervene_unknown_vals: Tensor | None=None, minimal_output: bool=False, position_injection: Tensor | None=None, steering_inject_layer: int | None=None, steering_inject_alpha: float=1.0, unknown_topk: int=64) -> tuple[Tensor, InterpretableOutput]:
        """
        Forward pass with concept decomposition.
        Args:
            input_ids: Token IDs (B, T). May contain mask tokens.
            input_embeds: Pre-computed embeddings (B, T, D). Overrides input_ids.
            intervene_known_ids: Known concept IDs to intervene (B, T, K_int)
            intervene_known_vals: Intervention values for known (B, T, K_int)
            intervene_unknown_ids: Unknown concept IDs to intervene (B, T, K_int)
            intervene_unknown_vals: Intervention values for unknown (B, T, K_int)
            minimal_output: If True, skip some expensive computations
            position_injection: Per-position steering injection (B, T, D)
            steering_inject_layer: Inject at layers >= this
            steering_inject_alpha: Injection strength
            unknown_topk: Top-k for unknown head attribution
        Returns:
            logits: LM logits (B, T, V)
            outputs: InterpretableOutput with all decomposition components
        """
        need_dense_logits = not minimal_output
        if position_injection is not None and steering_inject_layer is not None:
            hidden = self._forward_with_injection(input_ids, input_embeds, position_injection, steering_inject_layer, steering_inject_alpha)
        else:
            hidden = self.transformer(input_ids, input_embeds=input_embeds, return_hidden=True)
        known_out: ConceptHeadOutput = self.known_head(hidden, intervene_ids=intervene_known_ids, intervene_vals=intervene_known_vals, return_logits=need_dense_logits)
        known_features = known_out.features.to(hidden.dtype)
        # Residual the known head could not explain; detached so the unknown
        # head does not backprop through the known decomposition.
        unk = hidden - known_features.detach()
        unk_for_lm: Tensor = unk
        unknown_out: ConceptHeadOutput | None = None
        unk_hat: Tensor | None = None
        if self.unknown_head is not None:
            unknown_out = self.unknown_head(hidden.detach(), intervene_ids=intervene_unknown_ids, intervene_vals=intervene_unknown_vals, return_logits=not minimal_output and (not self.unknown_head._is_large))
            assert unknown_out is not None
            unk_hat = unknown_out.features.to(hidden.dtype)
            unk_for_lm = unk_hat.detach()
        epsilon_true = None
        if self.unknown_head is not None and unk_hat is not None:
            # True reconstruction error of the known + unknown decomposition.
            epsilon_true = hidden.detach() - (known_out.predicted + unk_hat)
        epsilon = None
        if self.concept_config.use_epsilon_correction and intervene_known_ids is None:
            # Epsilon correction restores exact reconstruction, but only when no
            # known intervention is active (it would otherwise undo the edit).
            epsilon = hidden - (unk_for_lm + known_features)
            unk_for_lm = unk_for_lm + epsilon
        composed = unk_for_lm + known_features
        logits = self.transformer.lm_head(composed)
        _unk_topk_indices = unknown_out.topk_indices if unknown_out else None
        _unk_topk_logits = unknown_out.topk_logits if unknown_out else None
        # Fall back to an explicit top-k pass when the head did not produce one.
        if not minimal_output and self.unknown_head is not None and (unknown_out is not None) and (_unk_topk_indices is None) and (unknown_topk > 0):
            with torch.no_grad():
                _unk_topk_indices, _unk_topk_logits = self._compute_unknown_topk(hidden, unknown_topk)
        outputs = InterpretableOutput(hidden=hidden, known_features=known_features, known_logits=known_out.logits, known_gt_features=known_out.gt_features, known_predicted=known_out.predicted, known_weights=known_out.weights, known_topk_indices=known_out.topk_indices, known_topk_logits=known_out.topk_logits, unk=unk, unk_hat=unk_hat, unk_for_lm=unk_for_lm, unknown_logits=unknown_out.logits if unknown_out else None, unknown_weights=unknown_out.weights if unknown_out else None, unknown_topk_indices=_unk_topk_indices, unknown_topk_logits=_unk_topk_logits, composed=composed, epsilon=epsilon, epsilon_true=epsilon_true)
        return (logits, outputs)

    def _compute_unknown_topk(self, hidden: Tensor, unknown_topk: int) -> tuple[Tensor | None, Tensor | None]:
        """Compute unknown head top-k indices for attribution."""
        assert self.unknown_head is not None
        # Mirror the head's own path selection: factorized vs dense tables,
        # attention vs linear scoring.
        if self.unknown_head.factorize:
            if self.unknown_head.use_attention:
                _query = self.unknown_head.concept_query_projection(hidden.detach())
                _, indices, logits = self.unknown_head.attention_features_topk_factorized(_query, k=unknown_topk, block_size=self.unknown_head.block_size)
            else:
                _, indices, logits = self.unknown_head.linear_features_topk_factorized(hidden.detach(), k=unknown_topk, block_size=self.unknown_head.block_size)
        else:
            _E = self.unknown_head._get_embedding_weight()[:self.unknown_head.n_concepts]
            if self.unknown_head.use_attention:
                _query = self.unknown_head.concept_query_projection(hidden.detach())
                _, indices, logits = self.unknown_head.attention_features_topk_streaming(_query, _E, k=unknown_topk, block_size=self.unknown_head.block_size)
            else:
                _W = self.unknown_head.concept_predictor.weight[:self.unknown_head.n_concepts]
                _, indices, logits = self.unknown_head.linear_features_topk_streaming(hidden.detach(), _W, _E, k=unknown_topk, block_size=self.unknown_head.block_size)
        return (indices, logits)

    def _forward_with_injection(self, input_ids: Tensor, input_embeds: Tensor | None, position_injection: Tensor, inject_layer: int, inject_alpha: float) -> Tensor:
        """Forward through transformer with steering injection at specified layers."""
        x = input_embeds if input_embeds is not None else self.transformer.tok_emb(input_ids)
        for i, block in enumerate(self.transformer.blocks):
            x = block(x)
            # Add the steering vector after every block whose 1-based index
            # is at or beyond inject_layer.
            if i + 1 >= inject_layer:
                x = x + inject_alpha * position_injection
        x = self.transformer.ln_f(x)
        return x

    @torch.no_grad()
    def intervene(self, input_ids: Tensor, known: dict[int, float] | None=None, unknown: dict[int, float] | None=None, positions: Tensor | None=None) -> tuple[Tensor, InterpretableOutput]:
        """
        Run inference with concept interventions.
        Args:
            input_ids: Input token IDs (B, T)
            known: Dict mapping known concept IDs to intervention strengths
            unknown: Dict mapping unknown concept IDs to intervention strengths
            positions: Bool mask of positions to intervene (B, T). Default: all.
        Returns:
            logits: LM logits (B, T, V)
            outputs: InterpretableOutput
        """
        B, T = input_ids.shape
        device = input_ids.device
        if positions is None:
            positions = torch.ones(B, T, dtype=torch.bool, device=device)
        int_known_ids, int_known_vals = (None, None)
        if known is not None and len(known) > 0:
            int_known_ids, int_known_vals = self._build_intervention_tensors(known, B, T, positions, device)
        int_unknown_ids, int_unknown_vals = (None, None)
        if unknown is not None and len(unknown) > 0:
            int_unknown_ids, int_unknown_vals = self._build_intervention_tensors(unknown, B, T, positions, device)
        return self(input_ids, intervene_known_ids=int_known_ids, intervene_known_vals=int_known_vals, intervene_unknown_ids=int_unknown_ids, intervene_unknown_vals=int_unknown_vals, minimal_output=False)

    @staticmethod
    def _build_intervention_tensors(interventions: dict[int, float], B: int, T: int, positions: Tensor, device: torch.device) -> tuple[Tensor, Tensor]:
        """Build intervention tensors for concept steering."""
        K = len(interventions)
        concept_ids = list(interventions.keys())
        directions = list(interventions.values())
        # -1 marks "no edit"; only masked positions receive real ids/values.
        ids = torch.full((B, T, K), -1, dtype=torch.long, device=device)
        vals = torch.zeros((B, T, K), dtype=torch.float32, device=device)
        concept_tensor = torch.tensor(concept_ids, device=device)
        direction_tensor = torch.tensor(directions, dtype=torch.float32, device=device)
        n_active = int(positions.sum().item())
        ids[positions] = concept_tensor.unsqueeze(0).expand(n_active, -1)
        vals[positions] = direction_tensor.unsqueeze(0).expand(n_active, -1)
        return (ids, vals)

    def get_num_params(self, non_embedding: bool=True) -> int:
        # Total parameter count; optionally excludes the token embedding table.
        n_params = sum((p.numel() for p in self.parameters()))
        if non_embedding and hasattr(self.transformer, 'tok_emb'):
            n_params -= self.transformer.tok_emb.weight.numel()
        return n_params
from transformers import PreTrainedModel
from .configuration_steerling import SteerlingConfig
# CausalDiffusionLM is the backbone — alias to HF-friendly name
# (CausalDiffusionLM is presumably defined earlier in this generated file)
SteerlingBackbone = CausalDiffusionLM
class SteerlingForCausalLM(PreTrainedModel):
    """HF `PreTrainedModel` wrapper around the Steerling backbone plus concept heads."""
    config_class = SteerlingConfig
    supports_gradient_checkpointing = False
    _tied_weights_keys = ["transformer.lm_head.weight"]

    def __init__(self, config: SteerlingConfig):
        super().__init__(config)
        # SteerlingConfig has all fields from both arch and concept configs
        self.concept_config = config
        self.transformer = SteerlingBackbone(config, config.vocab_size)
        # Dense (non-factorized) known-concept head.
        self.known_head = ConceptHead(
            n_concepts=config.n_concepts,
            concept_dim=config.concept_dim,
            n_embd=config.n_embd,
            is_unknown=False,
            use_attention=config.use_attention_known,
            topk=config.topk_known,
            topk_features=config.topk_known_features,
            block_size=config.concept_block_size,
            pad_multiple=config.pad_multiple,
            store_unknown_weights=False,
            apply_topk_to_unknown=False,
            topk_on_logits=config.topk_on_logits,
            factorize=False,
        )
        if config.use_unknown:
            # Optional unknown head for the residual, possibly factorized.
            self.unknown_head = ConceptHead(
                n_concepts=config.n_unknown_concepts,
                concept_dim=config.concept_dim,
                n_embd=config.n_embd,
                is_unknown=True,
                use_attention=config.use_attention_unknown,
                topk=config.unknown_topk,
                block_size=config.concept_block_size,
                pad_multiple=config.pad_multiple,
                store_unknown_weights=config.store_unknown_weights,
                apply_topk_to_unknown=config.apply_topk_to_unknown,
                topk_on_logits=config.topk_on_logits,
                factorize=config.factorize_unknown,
                factorize_rank=config.factorize_rank,
            )
        else:
            self.unknown_head = None
        self.post_init()

    def _init_weights(self, module):
        # Intentionally a no-op: weights are always loaded from a checkpoint.
        pass

    def _tie_weights(self):
        # Share the LM head with the token embedding when configured.
        if self.config.weight_sharing:
            self.transformer.lm_head.weight = self.transformer.tok_emb.weight

    def forward(self, input_ids=None, **kwargs):
        # Dispatch: reuse InterpretableCausalDiffusionLM.forward as an unbound
        # function with this instance as `self` (it relies on the same
        # transformer/known_head/unknown_head attributes set in __init__),
        # or fall through to the plain backbone forward.
        if self.config.interpretable:
            return InterpretableCausalDiffusionLM.forward(self, input_ids, **kwargs)
        else:
            # minimal_output is only understood by the interpretable path.
            kwargs.pop('minimal_output', None)
            return CausalDiffusionLM.forward(self, input_ids, **kwargs)