""" Circuit Transformer: Minimal transformer for semantic circuitry experiments. Follows patterns from shimmer/lira/gpt.py with extension hooks for future work. """ import torch import torch.nn as nn import torch.nn.functional as F import math from .config import CircuitConfig from .layers import RMSNorm, RotaryEmbedding, CausalAttention, SwiGLU, WordPositionRoPE class TransformerBlock(nn.Module): """Pre-norm transformer block with causal attention.""" def __init__( self, hidden_size: int, num_heads: int, num_kv_heads: int | None = None, max_seq_len: int = 2048, dropout: float = 0.0, window_size: int | None = None, word_rope_dims: int = 0, word_rope_base: float = 10.0, ): super().__init__() self.attn_norm = RMSNorm(hidden_size) self.attn = CausalAttention(hidden_size, num_heads, num_kv_heads, max_seq_len, dropout, window_size, word_rope_dims=word_rope_dims, word_rope_base=word_rope_base) self.ffn_norm = RMSNorm(hidden_size) self.ffn = SwiGLU(hidden_size) def forward( self, x: torch.Tensor, use_cache: bool = False, past_kv: tuple | None = None, word_positions: torch.Tensor | None = None, ) -> tuple[torch.Tensor, tuple | None]: # Attention with residual attn_out, new_kv = self.attn(self.attn_norm(x), use_cache, past_kv, word_positions=word_positions) x = x + attn_out # FFN with residual x = x + self.ffn(self.ffn_norm(x)) return x, new_kv class CircuitTransformer(nn.Module): """ Minimal transformer for semantic circuitry experiments. Features: - Standard GPT-style architecture (RMSNorm, RoPE, SwiGLU, causal attention) - Weight tying (embed = lm_head) - Extension hooks for future work: - freeze_layers() / unfreeze_layers() for progressive training - get_layer_outputs() for interpretability - window_size param for sliding window attention """ def __init__(self, config: CircuitConfig): super().__init__() self.config = config # Token embeddings (optionally factorized) embed_dim = getattr(config, 'embed_dim', 0) head_dim = getattr(config, 'head_dim', 0) # Auto-mirror factorization: head uses embed_dim for weight tying if embed_dim > 0 and head_dim == 0: head_dim = embed_dim if embed_dim > 0: self.embed = nn.Embedding(config.vocab_size, embed_dim) self.embed_proj = nn.Linear(embed_dim, config.hidden_size, bias=False) else: self.embed = nn.Embedding(config.vocab_size, config.hidden_size) self.embed_proj = None self.embed_scale = math.sqrt(config.hidden_size) # Transformer blocks self.layers = nn.ModuleList([ TransformerBlock( config.hidden_size, config.num_heads, getattr(config, 'num_kv_heads', None), config.max_seq_len, config.dropout, word_rope_dims=getattr(config, 'word_rope_dims', 0), word_rope_base=getattr(config, 'word_rope_base', 10.0), ) for _ in range(config.num_layers) ]) # Output (optionally MLP head) self.norm = RMSNorm(config.hidden_size) if head_dim > 0: self.head_down = nn.Linear(config.hidden_size, head_dim, bias=False) self.lm_head = nn.Linear(head_dim, config.vocab_size, bias=False) else: self.head_down = None self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Weight tying (when embed and lm_head dimensions match) _e = embed_dim if embed_dim > 0 else config.hidden_size _h = head_dim if head_dim > 0 else config.hidden_size if _e == _h: self.lm_head.weight = self.embed.weight # Auxiliary skip-ahead prediction head self.skip_head = None self.skip_head_down = None aux_skip_k = getattr(config, 'aux_skip_k', 0) if aux_skip_k > 0: if head_dim > 0: self.skip_head_down = nn.Linear(config.hidden_size, head_dim, bias=False) self.skip_head = nn.Linear(head_dim, config.vocab_size, bias=False) else: self.skip_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Track frozen layers self._frozen_layers: set[int] = set() # Initialize weights self.apply(self._init_weights) def _init_weights(self, module): if isinstance(module, nn.Linear): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) if module.bias is not None: torch.nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) def forward( self, input_ids: torch.Tensor, labels: torch.Tensor | None = None, use_cache: bool = False, past_kv: list | None = None, word_positions: torch.Tensor | None = None, ) -> dict: """ Forward pass. Args: input_ids: [B, L] token IDs labels: [B, L] target token IDs (for loss computation) use_cache: Whether to return KV cache for generation past_kv: Previous KV cache word_positions: [B, L] position within word (from compute_word_positions) Returns: dict with 'logits', optionally 'loss' and 'past_kv' """ B, L = input_ids.shape # Embed tokens (optionally factorized) x = self.embed(input_ids) if self.embed_proj is not None: x = F.silu(self.embed_proj(x)) x = x * self.embed_scale # Process through layers new_kv = [] if use_cache else None for i, layer in enumerate(self.layers): layer_past = past_kv[i] if past_kv is not None else None x, kv = layer(x, use_cache, layer_past, word_positions=word_positions) if use_cache: new_kv.append(kv) # Output (optionally MLP head) x = self.norm(x) if self.head_down is not None: logits = self.lm_head(F.silu(self.head_down(x))) else: logits = self.lm_head(x) result = {"logits": logits} if use_cache: result["past_kv"] = new_kv # Compute loss if labels provided if labels is not None: # Shift for next-token prediction shift_logits = logits[:, :-1, :].contiguous() shift_labels = labels[:, 1:].contiguous() loss = F.cross_entropy( shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1), ignore_index=-100, ) # Auxiliary skip-ahead prediction if self.skip_head is not None: skip_k = getattr(self.config, 'aux_skip_k', 0) skip_weight = getattr(self.config, 'aux_skip_weight', 0.1) if self.skip_head_down is not None: skip_logits = self.skip_head(F.silu(self.skip_head_down(x)))[:, :-skip_k, :].contiguous() else: skip_logits = self.skip_head(x)[:, :-skip_k, :].contiguous() skip_labels = labels[:, skip_k:].contiguous() aux_loss = F.cross_entropy( skip_logits.view(-1, self.config.vocab_size), skip_labels.view(-1), ignore_index=-100, ) result["aux_loss"] = aux_loss loss = loss + skip_weight * aux_loss result["loss"] = loss return result # === Extension hooks for future experiments === def freeze_layers(self, indices: list[int]) -> None: """Freeze specific layers (stop gradients).""" for idx in indices: if 0 <= idx < len(self.layers): for param in self.layers[idx].parameters(): param.requires_grad = False self._frozen_layers.add(idx) def unfreeze_layers(self, indices: list[int] | None = None) -> None: """Unfreeze specific layers (or all if indices=None).""" if indices is None: indices = list(self._frozen_layers) for idx in indices: if 0 <= idx < len(self.layers): for param in self.layers[idx].parameters(): param.requires_grad = True self._frozen_layers.discard(idx) def get_layer_outputs(self, input_ids: torch.Tensor) -> list[torch.Tensor]: """Get intermediate outputs from each layer for interpretability.""" outputs = [] x = self.embed(input_ids) if self.embed_proj is not None: x = F.silu(self.embed_proj(x)) x = x * self.embed_scale for layer in self.layers: x, _ = layer(x, use_cache=False, past_kv=None) outputs.append(x.clone()) return outputs @torch.no_grad() def generate( self, prompt_ids: torch.Tensor, max_new_tokens: int = 50, temperature: float = 0.8, top_k: int = 50, top_p: float = 0.9, use_cache: bool = True, word_start_table: torch.Tensor | None = None, ) -> torch.Tensor: """ Autoregressive generation with KV caching. Args: prompt_ids: [B, L] prompt token IDs max_new_tokens: Maximum tokens to generate temperature: Sampling temperature top_k: Top-k filtering top_p: Nucleus sampling threshold use_cache: Use KV cache for faster generation word_start_table: [vocab_size] bool tensor for word-position RoPE Returns: [B, L + max_new_tokens] generated token IDs """ from .layers import compute_word_positions self.eval() generated = prompt_ids.clone() past_kv = None word_pos_counter = 0 # Track word position during cached generation for _ in range(max_new_tokens): # Get input (full sequence or just last token with cache) if use_cache and past_kv is not None: input_ids = generated[:, -1:] # Compute word position for the single new token if word_start_table is not None: last_token = generated[0, -1].item() if word_start_table[last_token]: word_pos_counter = 0 else: word_pos_counter += 1 word_positions = torch.tensor([[float(word_pos_counter)]], device=input_ids.device) else: word_positions = None else: input_ids = generated # Compute word positions for full sequence if word_start_table is not None: word_positions = compute_word_positions(input_ids, word_start_table) else: word_positions = None # Forward pass output = self(input_ids, use_cache=use_cache, past_kv=past_kv, word_positions=word_positions) logits = output["logits"][:, -1, :] # Last position if use_cache: past_kv = output["past_kv"] # Apply temperature if temperature > 0: logits = logits / temperature # Top-k filtering if top_k > 0: top_k_vals, _ = torch.topk(logits, min(top_k, logits.size(-1))) min_top_k = top_k_vals[:, -1].unsqueeze(-1) logits = torch.where(logits < min_top_k, float("-inf"), logits) # Top-p (nucleus) filtering if top_p < 1.0: sorted_logits, sorted_indices = torch.sort(logits, descending=True) cumsum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) # Remove tokens with cumulative prob above threshold sorted_indices_to_remove = cumsum_probs > top_p sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone() sorted_indices_to_remove[:, 0] = False indices_to_remove = sorted_indices_to_remove.scatter( 1, sorted_indices, sorted_indices_to_remove ) logits = logits.masked_fill(indices_to_remove, float("-inf")) # Sample probs = F.softmax(logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) else: # Greedy next_token = logits.argmax(dim=-1, keepdim=True) generated = torch.cat([generated, next_token], dim=1) # Stop if max length reached if generated.size(1) >= self.config.max_seq_len: break return generated def count_parameters(model: CircuitTransformer) -> int: """Count trainable parameters.""" return sum(p.numel() for p in model.parameters() if p.requires_grad)