# Prisma / model.py
# Author: y3i12 — initial commit 56e82ec
"""
Circuit Transformer: Minimal transformer for semantic circuitry experiments.
Follows patterns from shimmer/lira/gpt.py with extension hooks for future work.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from .config import CircuitConfig
from .layers import RMSNorm, RotaryEmbedding, CausalAttention, SwiGLU, WordPositionRoPE
class TransformerBlock(nn.Module):
    """
    One pre-norm transformer layer.

    Computes ``h = x + Attn(RMSNorm(x))`` followed by
    ``out = h + SwiGLU(RMSNorm(h))`` and hands back the attention module's
    KV cache alongside the output.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int | None = None,
        max_seq_len: int = 2048,
        dropout: float = 0.0,
        window_size: int | None = None,
        word_rope_dims: int = 0,
        word_rope_base: float = 10.0,
    ):
        super().__init__()
        # Attention half: normalization, then causal self-attention.
        # The constructor arguments go positionally in CausalAttention's order.
        self.attn_norm = RMSNorm(hidden_size)
        self.attn = CausalAttention(
            hidden_size,
            num_heads,
            num_kv_heads,
            max_seq_len,
            dropout,
            window_size,
            word_rope_dims=word_rope_dims,
            word_rope_base=word_rope_base,
        )
        # Feed-forward half: normalization, then the SwiGLU MLP.
        self.ffn_norm = RMSNorm(hidden_size)
        self.ffn = SwiGLU(hidden_size)

    def forward(
        self, x: torch.Tensor, use_cache: bool = False, past_kv: tuple | None = None,
        word_positions: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, tuple | None]:
        """
        Run the block on hidden states ``x``.

        Args:
            x: hidden states (presumably [B, L, hidden_size] — TODO confirm)
            use_cache: forwarded to the attention layer
            past_kv: cached keys/values for this layer, forwarded to attention
            word_positions: optional word-position indices for word-level RoPE

        Returns:
            (output hidden states, KV cache object returned by the attention layer)
        """
        attn_delta, kv_cache = self.attn(
            self.attn_norm(x), use_cache, past_kv, word_positions=word_positions
        )
        hidden = x + attn_delta  # residual around attention

        ffn_delta = self.ffn(self.ffn_norm(hidden))
        return hidden + ffn_delta, kv_cache  # residual around the MLP
class CircuitTransformer(nn.Module):
    """
    Minimal transformer for semantic circuitry experiments.

    Features:
    - Standard GPT-style architecture (RMSNorm, RoPE, SwiGLU, causal attention)
    - Weight tying (embed = lm_head) whenever embedding and head widths match
    - Optional factorized embedding (config.embed_dim) and bottleneck output
      head (config.head_dim)
    - Optional auxiliary skip-ahead loss (config.aux_skip_k, config.aux_skip_weight)
    - Extension hooks for future work:
      - freeze_layers() / unfreeze_layers() for progressive training
      - get_layer_outputs() for interpretability
      - window_size param (on TransformerBlock) for sliding window attention
    """

    def __init__(self, config: CircuitConfig):
        super().__init__()
        self.config = config

        # Token embeddings (optionally factorized: vocab -> embed_dim -> hidden_size)
        embed_dim = getattr(config, 'embed_dim', 0)
        head_dim = getattr(config, 'head_dim', 0)
        # Auto-mirror factorization: head uses embed_dim so weights can be tied
        if embed_dim > 0 and head_dim == 0:
            head_dim = embed_dim
        if embed_dim > 0:
            self.embed = nn.Embedding(config.vocab_size, embed_dim)
            self.embed_proj = nn.Linear(embed_dim, config.hidden_size, bias=False)
        else:
            self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
            self.embed_proj = None
        self.embed_scale = math.sqrt(config.hidden_size)

        # Transformer blocks.
        # NOTE(review): window_size is not passed, so each block keeps
        # TransformerBlock's default (None). Forward a config value here to
        # enable sliding-window attention.
        self.layers = nn.ModuleList([
            TransformerBlock(
                config.hidden_size,
                config.num_heads,
                getattr(config, 'num_kv_heads', None),
                config.max_seq_len,
                config.dropout,
                word_rope_dims=getattr(config, 'word_rope_dims', 0),
                word_rope_base=getattr(config, 'word_rope_base', 10.0),
            )
            for _ in range(config.num_layers)
        ])

        # Output (optionally a hidden_size -> head_dim -> vocab bottleneck head)
        self.norm = RMSNorm(config.hidden_size)
        if head_dim > 0:
            self.head_down = nn.Linear(config.hidden_size, head_dim, bias=False)
            self.lm_head = nn.Linear(head_dim, config.vocab_size, bias=False)
        else:
            self.head_down = None
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Weight tying (only when embed and lm_head dimensions match)
        _e = embed_dim if embed_dim > 0 else config.hidden_size
        _h = head_dim if head_dim > 0 else config.hidden_size
        if _e == _h:
            self.lm_head.weight = self.embed.weight

        # Auxiliary skip-ahead prediction head (never tied to the embedding)
        self.skip_head = None
        self.skip_head_down = None
        aux_skip_k = getattr(config, 'aux_skip_k', 0)
        if aux_skip_k > 0:
            if head_dim > 0:
                self.skip_head_down = nn.Linear(config.hidden_size, head_dim, bias=False)
                self.skip_head = nn.Linear(head_dim, config.vocab_size, bias=False)
            else:
                self.skip_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Track frozen layers
        self._frozen_layers: set[int] = set()

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights from N(0, 0.02) and zero any Linear bias."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _embed_tokens(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed token ids through the optional factorized projection, then scale by sqrt(hidden_size)."""
        x = self.embed(input_ids)
        if self.embed_proj is not None:
            x = F.silu(self.embed_proj(x))
        return x * self.embed_scale

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: torch.Tensor | None = None,
        use_cache: bool = False,
        past_kv: list | None = None,
        word_positions: torch.Tensor | None = None,
    ) -> dict:
        """
        Forward pass.

        Args:
            input_ids: [B, L] token IDs
            labels: [B, L] target token IDs (for loss computation); entries
                equal to -100 are ignored by the loss
            use_cache: Whether to return KV cache for generation
            past_kv: Previous KV cache (one entry per layer)
            word_positions: [B, L] position within word (from compute_word_positions)

        Returns:
            dict with 'logits', plus 'past_kv' when use_cache is set, and
            'loss' (and 'aux_loss' when the skip head exists) when labels are given
        """
        x = self._embed_tokens(input_ids)

        # Process through layers, collecting each layer's cache when requested
        new_kv = [] if use_cache else None
        for i, layer in enumerate(self.layers):
            layer_past = past_kv[i] if past_kv is not None else None
            x, kv = layer(x, use_cache, layer_past, word_positions=word_positions)
            if use_cache:
                new_kv.append(kv)

        # Output (optionally through the bottleneck head)
        x = self.norm(x)
        if self.head_down is not None:
            logits = self.lm_head(F.silu(self.head_down(x)))
        else:
            logits = self.lm_head(x)

        result = {"logits": logits}
        if use_cache:
            result["past_kv"] = new_kv

        # Compute loss if labels provided
        if labels is not None:
            # Shift for next-token prediction: position t is scored against labels[t + 1]
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
                ignore_index=-100,
            )

            # Auxiliary skip-ahead prediction: position t is scored against labels[t + skip_k]
            if self.skip_head is not None:
                skip_k = getattr(self.config, 'aux_skip_k', 0)
                skip_weight = getattr(self.config, 'aux_skip_weight', 0.1)
                if self.skip_head_down is not None:
                    skip_logits = self.skip_head(F.silu(self.skip_head_down(x)))[:, :-skip_k, :].contiguous()
                else:
                    skip_logits = self.skip_head(x)[:, :-skip_k, :].contiguous()
                skip_labels = labels[:, skip_k:].contiguous()
                aux_loss = F.cross_entropy(
                    skip_logits.view(-1, self.config.vocab_size),
                    skip_labels.view(-1),
                    ignore_index=-100,
                )
                result["aux_loss"] = aux_loss
                loss = loss + skip_weight * aux_loss

            result["loss"] = loss

        return result

    # === Extension hooks for future experiments ===

    def freeze_layers(self, indices: list[int]) -> None:
        """Freeze specific layers (stop gradients); out-of-range indices are ignored."""
        for idx in indices:
            if 0 <= idx < len(self.layers):
                for param in self.layers[idx].parameters():
                    param.requires_grad = False
                self._frozen_layers.add(idx)

    def unfreeze_layers(self, indices: list[int] | None = None) -> None:
        """Unfreeze specific layers (or every tracked frozen layer if indices=None)."""
        if indices is None:
            # Copy: the set is mutated inside the loop below
            indices = list(self._frozen_layers)
        for idx in indices:
            if 0 <= idx < len(self.layers):
                for param in self.layers[idx].parameters():
                    param.requires_grad = True
                self._frozen_layers.discard(idx)

    def get_layer_outputs(
        self,
        input_ids: torch.Tensor,
        word_positions: torch.Tensor | None = None,
    ) -> list[torch.Tensor]:
        """
        Get intermediate outputs from each layer for interpretability.

        Args:
            input_ids: [B, L] token IDs
            word_positions: optional [B, L] word positions. Pass the same value
                given to forward() so word-position RoPE matches the real forward pass.

        Returns:
            One cloned hidden-state tensor per transformer layer (before the final norm).
        """
        outputs = []
        x = self._embed_tokens(input_ids)
        for layer in self.layers:
            x, _ = layer(x, use_cache=False, past_kv=None, word_positions=word_positions)
            outputs.append(x.clone())
        return outputs

    @torch.no_grad()
    def generate(
        self,
        prompt_ids: torch.Tensor,
        max_new_tokens: int = 50,
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.9,
        use_cache: bool = True,
        word_start_table: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Autoregressive generation with KV caching.

        Args:
            prompt_ids: [B, L] prompt token IDs
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature; a value <= 0 selects greedy
                decoding (top_k / top_p are then ignored)
            top_k: Top-k filtering (disabled when <= 0)
            top_p: Nucleus sampling threshold (disabled when >= 1.0)
            use_cache: Use KV cache for faster generation
            word_start_table: [vocab_size] tensor, truthy for tokens that start
                a word; enables word-position RoPE

        Returns:
            [B, L + n] generated token IDs with n <= max_new_tokens. Generation
            stops early once the length reaches config.max_seq_len.

        The train/eval flags of all submodules are restored on exit.
        """
        from .layers import compute_word_positions

        # Remember every submodule's train/eval flag so sampling mid-training
        # does not leave the model stuck in eval mode.
        prev_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            generated = prompt_ids.clone()
            past_kv = None
            # [B, 1] word position of the most recent token fed to the model,
            # tracked separately for each batch row during cached generation.
            last_word_pos = None
            start_table = (
                word_start_table.to(generated.device) if word_start_table is not None else None
            )

            for _ in range(max_new_tokens):
                # Get input (full sequence, or just the last token with a cache)
                if use_cache and past_kv is not None:
                    input_ids = generated[:, -1:]
                    if start_table is not None:
                        # Reset to 0 at a word start, otherwise continue the current word
                        is_start = start_table[input_ids].bool()
                        last_word_pos = torch.where(
                            is_start, torch.zeros_like(last_word_pos), last_word_pos + 1
                        )
                        word_positions = last_word_pos
                    else:
                        word_positions = None
                else:
                    input_ids = generated
                    if word_start_table is not None:
                        word_positions = compute_word_positions(input_ids, word_start_table)
                        # Seed the incremental counter from the prompt's final token.
                        # Otherwise a prompt ending mid-word restarts counting at 0.
                        last_word_pos = word_positions[:, -1:].to(
                            device=generated.device, dtype=torch.get_default_dtype()
                        )
                    else:
                        word_positions = None

                # Forward pass
                output = self(input_ids, use_cache=use_cache, past_kv=past_kv, word_positions=word_positions)
                logits = output["logits"][:, -1, :]  # Last position
                if use_cache:
                    past_kv = output["past_kv"]

                if temperature > 0:
                    logits = logits / temperature

                    # Top-k filtering
                    if top_k > 0:
                        top_k_vals, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                        min_top_k = top_k_vals[:, -1].unsqueeze(-1)
                        logits = torch.where(logits < min_top_k, float("-inf"), logits)

                    # Top-p (nucleus) filtering
                    if top_p < 1.0:
                        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                        cumsum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                        # Remove tokens once the cumulative prob exceeds the threshold,
                        # shifted right so the token that crosses it is kept
                        sorted_indices_to_remove = cumsum_probs > top_p
                        sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                        sorted_indices_to_remove[:, 0] = False
                        indices_to_remove = sorted_indices_to_remove.scatter(
                            1, sorted_indices, sorted_indices_to_remove
                        )
                        logits = logits.masked_fill(indices_to_remove, float("-inf"))

                    # Sample
                    probs = F.softmax(logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                else:
                    # Greedy
                    next_token = logits.argmax(dim=-1, keepdim=True)

                generated = torch.cat([generated, next_token], dim=1)

                # Stop if max length reached
                if generated.size(1) >= self.config.max_seq_len:
                    break

            return generated
        finally:
            for module, mode in prev_modes:
                module.training = mode
def count_parameters(model: CircuitTransformer) -> int:
    """
    Return the number of trainable scalar parameters in ``model``.

    Only parameters with ``requires_grad`` set are counted, so layers frozen
    via ``freeze_layers`` are excluded. ``Module.parameters()`` yields each
    shared tensor once, so tied weights are counted a single time.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total