| | """
|
| | Circuit Transformer: Minimal transformer for semantic circuitry experiments.
|
| |
|
| | Follows patterns from shimmer/lira/gpt.py with extension hooks for future work.
|
| | """
|
| |
|
| | import torch
|
| | import torch.nn as nn
|
| | import torch.nn.functional as F
|
| | import math
|
| |
|
| | from .config import CircuitConfig
|
| | from .layers import RMSNorm, RotaryEmbedding, CausalAttention, SwiGLU, WordPositionRoPE
|
| |
|
| |
|
class TransformerBlock(nn.Module):
    """A single pre-norm block: causal self-attention followed by a SwiGLU MLP.

    Both sub-layers are wrapped in residual connections, with RMSNorm applied
    *before* each sub-layer (pre-norm residual style).
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int | None = None,
        max_seq_len: int = 2048,
        dropout: float = 0.0,
        window_size: int | None = None,
        word_rope_dims: int = 0,
        word_rope_base: float = 10.0,
    ):
        super().__init__()
        self.attn_norm = RMSNorm(hidden_size)
        self.attn = CausalAttention(
            hidden_size,
            num_heads,
            num_kv_heads,
            max_seq_len,
            dropout,
            window_size,
            word_rope_dims=word_rope_dims,
            word_rope_base=word_rope_base,
        )
        self.ffn_norm = RMSNorm(hidden_size)
        self.ffn = SwiGLU(hidden_size)

    def forward(
        self, x: torch.Tensor, use_cache: bool = False, past_kv: tuple | None = None,
        word_positions: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, tuple | None]:
        """Run attention then feed-forward, each with a residual add.

        Returns the transformed hidden states and the attention sub-layer's
        (possibly None) updated KV cache.
        """
        normed = self.attn_norm(x)
        attn_out, new_kv = self.attn(normed, use_cache, past_kv, word_positions=word_positions)
        residual = x + attn_out
        out = residual + self.ffn(self.ffn_norm(residual))
        return out, new_kv
|
| |
|
| |
|
class CircuitTransformer(nn.Module):
    """
    Minimal transformer for semantic circuitry experiments.

    Features:
    - Standard GPT-style architecture (RMSNorm, RoPE, SwiGLU, causal attention)
    - Weight tying (embed = lm_head)
    - Extension hooks for future work:
      - freeze_layers() / unfreeze_layers() for progressive training
      - get_layer_outputs() for interpretability
      - window_size param for sliding window attention
    """

    def __init__(self, config: CircuitConfig):
        super().__init__()
        self.config = config

        # Optional factored embedding / output head widths. Both default to 0
        # (= disabled) when the attributes are absent from the config.
        embed_dim = getattr(config, 'embed_dim', 0)
        head_dim = getattr(config, 'head_dim', 0)

        # If only the embedding is factored, mirror its width onto the head so
        # the weight-tying condition (_e == _h) below can still hold.
        if embed_dim > 0 and head_dim == 0:
            head_dim = embed_dim

        if embed_dim > 0:
            # Factored path: narrow embedding table, projected up to hidden_size.
            self.embed = nn.Embedding(config.vocab_size, embed_dim)
            self.embed_proj = nn.Linear(embed_dim, config.hidden_size, bias=False)
        else:
            self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
            self.embed_proj = None
        # Embedding output is multiplied by sqrt(hidden_size) in forward().
        self.embed_scale = math.sqrt(config.hidden_size)

        self.layers = nn.ModuleList([
            TransformerBlock(
                config.hidden_size,
                config.num_heads,
                getattr(config, 'num_kv_heads', None),
                config.max_seq_len,
                config.dropout,
                word_rope_dims=getattr(config, 'word_rope_dims', 0),
                word_rope_base=getattr(config, 'word_rope_base', 10.0),
            )
            for _ in range(config.num_layers)
        ])

        # Final norm + output head. When head_dim > 0 the head is factored:
        # hidden_size -> head_dim (with SiLU in forward) -> vocab_size.
        self.norm = RMSNorm(config.hidden_size)
        if head_dim > 0:
            self.head_down = nn.Linear(config.hidden_size, head_dim, bias=False)
            self.lm_head = nn.Linear(head_dim, config.vocab_size, bias=False)
        else:
            self.head_down = None
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Tie embed and lm_head whenever their effective widths match
        # (both [vocab_size, width]); they then share one weight tensor.
        _e = embed_dim if embed_dim > 0 else config.hidden_size
        _h = head_dim if head_dim > 0 else config.hidden_size
        if _e == _h:
            self.lm_head.weight = self.embed.weight

        # Optional auxiliary "skip" head used in forward() to predict the token
        # aux_skip_k positions ahead; mirrors the (possibly factored) lm_head.
        self.skip_head = None
        self.skip_head_down = None
        aux_skip_k = getattr(config, 'aux_skip_k', 0)
        if aux_skip_k > 0:
            if head_dim > 0:
                self.skip_head_down = nn.Linear(config.hidden_size, head_dim, bias=False)
                self.skip_head = nn.Linear(head_dim, config.vocab_size, bias=False)
            else:
                self.skip_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Indices of layers currently frozen via freeze_layers().
        self._frozen_layers: set[int] = set()

        # NOTE: initialization runs after weight tying, so the shared
        # embed/lm_head tensor is initialized by whichever module _init_weights
        # visits last (same N(0, 0.02) distribution either way).
        self.apply(self._init_weights)

    def _init_weights(self, module):
        # GPT-style init: N(0, 0.02) for all Linear/Embedding weights, zero bias.
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: torch.Tensor | None = None,
        use_cache: bool = False,
        past_kv: list | None = None,
        word_positions: torch.Tensor | None = None,
    ) -> dict:
        """
        Forward pass.

        Args:
            input_ids: [B, L] token IDs
            labels: [B, L] target token IDs (for loss computation)
            use_cache: Whether to return KV cache for generation
            past_kv: Previous KV cache (one entry per layer)
            word_positions: [B, L] position within word (from compute_word_positions)

        Returns:
            dict with 'logits', optionally 'loss', 'aux_loss' and 'past_kv'
        """
        # Unpack for rank validation; B and L are not used further below.
        B, L = input_ids.shape

        # Embed (optionally through the factored SiLU projection), then scale.
        x = self.embed(input_ids)
        if self.embed_proj is not None:
            x = F.silu(self.embed_proj(x))
        x = x * self.embed_scale

        # Run the layer stack, threading per-layer KV caches when requested.
        new_kv = [] if use_cache else None
        for i, layer in enumerate(self.layers):
            layer_past = past_kv[i] if past_kv is not None else None
            x, kv = layer(x, use_cache, layer_past, word_positions=word_positions)
            if use_cache:
                new_kv.append(kv)

        x = self.norm(x)
        if self.head_down is not None:
            # Factored head: down-project with SiLU before the vocab projection.
            logits = self.lm_head(F.silu(self.head_down(x)))
        else:
            logits = self.lm_head(x)

        result = {"logits": logits}

        if use_cache:
            result["past_kv"] = new_kv

        if labels is not None:
            # Standard next-token loss: logits at position i score labels[i+1].
            # -100 labels are ignored (PyTorch cross_entropy convention).
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
                ignore_index=-100,
            )

            # Auxiliary skip-ahead loss: skip_head at position i scores
            # labels[i + skip_k], weighted into the total loss.
            # NOTE(review): the `[:, :-skip_k, :]` slice assumes skip_k > 0;
            # skip_head is only built when aux_skip_k > 0 in __init__, but a
            # config reporting aux_skip_k == 0 here would yield empty logits.
            if self.skip_head is not None:
                skip_k = getattr(self.config, 'aux_skip_k', 0)
                skip_weight = getattr(self.config, 'aux_skip_weight', 0.1)
                if self.skip_head_down is not None:
                    skip_logits = self.skip_head(F.silu(self.skip_head_down(x)))[:, :-skip_k, :].contiguous()
                else:
                    skip_logits = self.skip_head(x)[:, :-skip_k, :].contiguous()
                skip_labels = labels[:, skip_k:].contiguous()
                aux_loss = F.cross_entropy(
                    skip_logits.view(-1, self.config.vocab_size),
                    skip_labels.view(-1),
                    ignore_index=-100,
                )
                result["aux_loss"] = aux_loss
                loss = loss + skip_weight * aux_loss

            result["loss"] = loss

        return result

    def freeze_layers(self, indices: list[int]) -> None:
        """Freeze specific layers (stop gradients). Out-of-range indices are ignored."""
        for idx in indices:
            if 0 <= idx < len(self.layers):
                for param in self.layers[idx].parameters():
                    param.requires_grad = False
                self._frozen_layers.add(idx)

    def unfreeze_layers(self, indices: list[int] | None = None) -> None:
        """Unfreeze specific layers (or all currently-frozen layers if indices=None)."""
        if indices is None:
            indices = list(self._frozen_layers)
        for idx in indices:
            if 0 <= idx < len(self.layers):
                for param in self.layers[idx].parameters():
                    param.requires_grad = True
                self._frozen_layers.discard(idx)

    def get_layer_outputs(self, input_ids: torch.Tensor) -> list[torch.Tensor]:
        """Get intermediate outputs from each layer for interpretability.

        Returns one cloned [B, L, hidden_size] tensor per layer. NOTE(review):
        runs with gradients enabled and without word_positions, so activations
        may differ from a forward() call that passes them — confirm intended.
        """
        outputs = []
        x = self.embed(input_ids)
        if self.embed_proj is not None:
            x = F.silu(self.embed_proj(x))
        x = x * self.embed_scale

        for layer in self.layers:
            x, _ = layer(x, use_cache=False, past_kv=None)
            outputs.append(x.clone())

        return outputs

    @torch.no_grad()
    def generate(
        self,
        prompt_ids: torch.Tensor,
        max_new_tokens: int = 50,
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.9,
        use_cache: bool = True,
        word_start_table: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Autoregressive generation with KV caching.

        Args:
            prompt_ids: [B, L] prompt token IDs
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature (<= 0 means greedy argmax)
            top_k: Top-k filtering (0 disables)
            top_p: Nucleus sampling threshold (1.0 disables)
            use_cache: Use KV cache for faster generation
            word_start_table: [vocab_size] bool tensor for word-position RoPE

        Returns:
            [B, L + max_new_tokens] generated token IDs (may stop early at
            config.max_seq_len)
        """
        from .layers import compute_word_positions

        self.eval()
        generated = prompt_ids.clone()
        past_kv = None
        word_pos_counter = 0  # incremental within-word position during cached decoding

        for _ in range(max_new_tokens):
            # With a warm cache, feed only the last token; otherwise the full
            # sequence (first iteration, or use_cache=False).
            if use_cache and past_kv is not None:
                input_ids = generated[:, -1:]

                # Incrementally track the within-word position of the new token:
                # reset to 0 on a word-start token, else advance by one.
                # NOTE(review): reads only batch row 0 and emits a [1, 1]
                # tensor, so cached generation with word_start_table assumes
                # batch size 1 — confirm before batched use. The counter also
                # starts at 0 regardless of where the prompt's last word began.
                if word_start_table is not None:
                    last_token = generated[0, -1].item()
                    if word_start_table[last_token]:
                        word_pos_counter = 0
                    else:
                        word_pos_counter += 1
                    word_positions = torch.tensor([[float(word_pos_counter)]], device=input_ids.device)
                else:
                    word_positions = None
            else:
                input_ids = generated

                if word_start_table is not None:
                    word_positions = compute_word_positions(input_ids, word_start_table)
                else:
                    word_positions = None

            output = self(input_ids, use_cache=use_cache, past_kv=past_kv, word_positions=word_positions)
            logits = output["logits"][:, -1, :]  # next-token logits only

            if use_cache:
                past_kv = output["past_kv"]

            if temperature > 0:
                logits = logits / temperature

                # Top-k: mask everything below the k-th largest logit.
                if top_k > 0:
                    top_k_vals, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    min_top_k = top_k_vals[:, -1].unsqueeze(-1)
                    logits = torch.where(logits < min_top_k, float("-inf"), logits)

                # Top-p (nucleus): drop the tail beyond cumulative prob top_p.
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    cumsum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                    # Shift the mask right so the first token crossing the
                    # threshold is kept; always keep the top-1 token.
                    sorted_indices_to_remove = cumsum_probs > top_p
                    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                    sorted_indices_to_remove[:, 0] = False

                    # Scatter the mask back to the original vocab order.
                    indices_to_remove = sorted_indices_to_remove.scatter(
                        1, sorted_indices, sorted_indices_to_remove
                    )
                    logits = logits.masked_fill(indices_to_remove, float("-inf"))

                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                # Greedy decoding when temperature <= 0.
                next_token = logits.argmax(dim=-1, keepdim=True)

            generated = torch.cat([generated, next_token], dim=1)

            # Hard stop at the model's maximum context length.
            if generated.size(1) >= self.config.max_seq_len:
                break

        return generated
|
| |
|
| |
|
def count_parameters(model: nn.Module) -> int:
    """Count trainable parameters of *model*.

    Generalized to accept any ``nn.Module`` — the implementation only uses
    ``Module.parameters()``, so it works equally on a ``CircuitTransformer``,
    a sub-module, or any other model. Parameters with ``requires_grad=False``
    (e.g. layers frozen via ``freeze_layers``) are excluded from the count.

    Args:
        model: Any module whose trainable parameters should be counted.

    Returns:
        Total number of scalar elements across trainable parameters.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| |
|