Prisma / mirrored.py
y3i12's picture
Initial commit
56e82ec
"""
Mirrored Transformer: Weight-sharing between expand and compress phases.
Based on the biconcave lens hypothesis from grafting research:
- Early layers expand from tokens to semantic space
- Late layers compress from semantic space back to tokens
- These phases share structural computation (W₁, W₂)
- Only the gate (semiotic filter) differs by direction
Architecture:
y = W₂ @ (W₁ @ x ⊙ swish((W₃ @ x) ⊙ swish(W₄ @ x)))
Both gates fire on every pass, with W₄'s gate nested inside W₃'s (G²LU). W₁ is computed once.
W₁, W₂ shared between mirror pairs. W₃, W₄ are dual gates.
~33% FFN parameter savings per mirrored pair vs standard SwiGLU.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass, fields
from .layers import RMSNorm, CausalAttention, SwiGLU
@dataclass
class MirroredConfig:
"""Configuration for Mirrored Transformer."""
vocab_size: int = 50257
hidden_size: int = 768
num_heads: int = 12
num_kv_heads: int | None = None # GQA: None = same as num_heads (MHA)
num_layers: int = 12 # effective depth (expand + middle + compress)
n_middle: int = 2 # unique middle layers (standard SwiGLU)
max_seq_len: int = 512
dropout: float = 0.0
aux_skip_k: int = 0 # skip-ahead prediction distance (0 = disabled)
aux_skip_weight: float = 0.1 # weight for auxiliary skip loss
use_g2lu: bool = True # G²LU nested gates (False = vanilla SwiGLU)
word_rope_dims: int = 0 # head dims for word-position RoPE (0 = disabled)
word_rope_base: float = 10.0 # frequency base for word-position RoPE
embed_dim: int = 0 # factorized embedding dim (0 = use hidden_size)
head_dim: int = 0 # MLP head intermediate dim (0 = linear head)
def __post_init__(self):
assert self.hidden_size % self.num_heads == 0, "hidden_size must be divisible by num_heads"
if self.num_kv_heads is not None:
assert self.num_heads % self.num_kv_heads == 0, \
f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads})"
n_mirror_layers = self.num_layers - self.n_middle
assert n_mirror_layers > 0, "num_layers must be greater than n_middle"
assert n_mirror_layers % 2 == 0, "num_layers - n_middle must be even"
self.n_mirror = n_mirror_layers // 2
def to_dict(self) -> dict:
"""Convert to dictionary for serialization."""
return {f.name: getattr(self, f.name) for f in fields(self) if f.name != "n_mirror"}
@classmethod
def from_dict(cls, d: dict) -> "MirroredConfig":
"""Create from dictionary."""
valid = {f.name for f in fields(cls)}
filtered = {k: v for k, v in d.items() if k in valid}
return cls(**filtered)
class MLP(nn.Module):
    """Gated feed-forward layer: ``dropout(down(silu(gate(x)) * up(x)))``.

    Args:
        dim: input and output width.
        intermediate_size: hidden width of the gated projection.
        dropout: dropout probability applied to the output.
    """
    def __init__(self, dim, intermediate_size, dropout):
        super().__init__()

        def _bias_free(in_features, out_features):
            return nn.Linear(in_features, out_features, bias=False)

        # Registration order (up, gate, down) fixes parameter order and seeded init.
        self.up_proj = _bias_free(dim, intermediate_size)
        self.gate_proj = _bias_free(dim, intermediate_size)
        self.down_proj = _bias_free(intermediate_size, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        activated_gate = F.silu(self.gate_proj(x))
        projected = self.up_proj(x)
        combined = activated_gate * projected
        return self.dropout(self.down_proj(combined))
class MirroredSwiGLU(nn.Module):
"""SwiGLU with shared base weights and dual gates.
Standard SwiGLU: y = W₂(silu(W₁x) ⊙ W₃x) — 3 matrices
Mirrored SwiGLU: y = W₂(W₁x ⊙ (silu(W₃ ⊙ silu(W₄x)))) — 2 shared + 2 gates
W₁ computed once, reused for both branches.
"""
def __init__(self, hidden_size: int, intermediate_size: int | None = None,
gate_mode: str = 'additive', use_g2lu: bool = True):
super().__init__()
self.gate_mode = gate_mode
self.use_g2lu = use_g2lu
self._current_step = 0
intermediate_size = intermediate_size or int(hidden_size * 8 / 3)
intermediate_size = ((intermediate_size + 63) // 64) * 64
# Shared structural transform
self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)
self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)
# Gate(s)
self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False)
if use_g2lu:
self.w4 = nn.Linear(hidden_size, intermediate_size, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
hidden = self.w1(x)
if self.use_g2lu:
g4 = F.silu(self.w4(x))
g3 = F.silu(self.w3(x) * g4)
else:
g3 = F.silu(self.w3(x))
return self.w2(hidden * g3)
class MirroredBlock(nn.Module):
"""Transformer block with shared weights for expand/compress phases.
Each MirroredBlock is used TWICE in the forward pass:
once during expand (building semantics) and once during compress (encoding output).
Shared: attention weights (optional), FFN W₁/W₂
Separate: norms (different residual stream statistics), FFN gate
"""
def __init__(self, hidden_size: int, num_heads: int, num_kv_heads: int | None = None,
max_seq_len: int = 2048,
dropout: float = 0.0,
window_size: int | None = None, gate_mode: str = 'additive',
word_rope_dims: int = 0, word_rope_base: float = 10.0,
use_g2lu: bool = True):
super().__init__()
self.attn = CausalAttention(hidden_size, num_heads, num_kv_heads, max_seq_len, dropout, window_size=window_size,
word_rope_dims=word_rope_dims, word_rope_base=word_rope_base)
# FFN with shared base + direction-specific gates
self.ffn = MirroredSwiGLU(hidden_size, gate_mode=gate_mode, use_g2lu=use_g2lu)
# Separate norms per direction (residual stream statistics differ)
self.expand_attn_norm = RMSNorm(hidden_size)
self.expand_ffn_norm = RMSNorm(hidden_size)
self.compress_attn_norm = RMSNorm(hidden_size)
self.compress_ffn_norm = RMSNorm(hidden_size)
def forward(self, x: torch.Tensor, use_cache: bool = False, past_kv: tuple = None,
word_positions: torch.Tensor | None = None) -> tuple:
attn_norm = self.compress_attn_norm
ffn_norm = self.compress_ffn_norm
attn = self.attn
attn_out, new_kv = attn(attn_norm(x), use_cache, past_kv, word_positions=word_positions)
x = x + attn_out
x = x + self.ffn(ffn_norm(x))
return x, new_kv
class MiddleBlock(nn.Module):
"""Standard transformer block for unique middle layers.
When gate_mode is provided, uses MirroredSwiGLU (dual-gate) instead of
single-gate SwiGLU — giving the middle the same rich gating geometry
as the mirror pairs.
"""
def __init__(self, hidden_size: int, num_heads: int, num_kv_heads: int | None = None,
max_seq_len: int = 2048,
dropout: float = 0.0,
word_rope_dims: int = 0, word_rope_base: float = 10.0,
use_g2lu: bool = True):
super().__init__()
self.attn_norm = RMSNorm(hidden_size)
self.attn = CausalAttention(hidden_size, num_heads, num_kv_heads, max_seq_len, dropout,
word_rope_dims=word_rope_dims, word_rope_base=word_rope_base)
self.ffn_norm = RMSNorm(hidden_size)
self.ffn = MirroredSwiGLU(hidden_size, use_g2lu=use_g2lu)
def forward(self, x: torch.Tensor, use_cache: bool = False, past_kv: tuple = None,
word_positions: torch.Tensor | None = None) -> tuple:
attn_out, new_kv = self.attn(self.attn_norm(x), use_cache, past_kv, word_positions=word_positions)
x = x + attn_out
x = x + self.ffn(self.ffn_norm(x))
return x, new_kv
class MirroredTransformer(nn.Module):
"""Transformer with mirrored expand/compress architecture.
Forward pass:
1. Embed tokens
2. Expand phase: mirror_blocks[0..N] with w3
3. Middle: unique standard blocks
4. Compress phase: mirror_blocks[N..0] (reversed) with w4
5. Norm + LM head
For a 12-layer model with n_middle=2:
- 5 mirror pairs (10 virtual layers) + 2 middle = 12 effective layers
- Expand: blocks[0] → blocks[4]
- Middle: middle[0] → middle[1]
- Compress: blocks[4] → blocks[0]
"""
def __init__(self, config: MirroredConfig):
super().__init__()
self.config = config
# Token embeddings (optionally factorized)
embed_dim = getattr(config, 'embed_dim', 0)
head_dim = getattr(config, 'head_dim', 0)
# Auto-mirror factorization: head uses embed_dim for weight tying
if embed_dim > 0 and head_dim == 0:
head_dim = embed_dim
# G²LU config (needed before projection setup)
use_g2lu = getattr(config, 'use_g2lu', True)
if embed_dim > 0:
self.embed = nn.Embedding(config.vocab_size, embed_dim)
self.embed_proj = nn.Linear(embed_dim, config.hidden_size, bias=False)
# G²LU gates for up-projection (consistent with mirror blocks)
if use_g2lu:
self.embed_g3 = nn.Linear(embed_dim, config.hidden_size, bias=False)
self.embed_g4 = nn.Linear(embed_dim, config.hidden_size, bias=False)
else:
self.embed_g3 = None
self.embed_g4 = None
else:
self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
self.embed_proj = None
self.embed_g3 = None
self.embed_g4 = None
self.embed_scale = math.sqrt(config.hidden_size)
self.window_sizes = [None] * config.n_mirror
# Word-position RoPE config
word_rope_dims = getattr(config, 'word_rope_dims', 0)
word_rope_base = getattr(config, 'word_rope_base', 10.0)
# Mirrored blocks (used in both expand and compress phases)
self.mirror_blocks = nn.ModuleList([
MirroredBlock(
config.hidden_size, config.num_heads, config.num_kv_heads,
config.max_seq_len,
config.dropout,
window_size=self.window_sizes[i],
word_rope_dims=word_rope_dims, word_rope_base=word_rope_base,
use_g2lu=use_g2lu,
)
for i in range(config.n_mirror)
])
# Unique middle blocks (standard transformer, optionally dual-gated)
self.middle_blocks = nn.ModuleList([
MiddleBlock(config.hidden_size, config.num_heads, config.num_kv_heads,
config.max_seq_len, config.dropout,
word_rope_dims=word_rope_dims, word_rope_base=word_rope_base,
use_g2lu=use_g2lu)
for _ in range(config.n_middle)
])
# Output (optionally MLP head)
self.norm = RMSNorm(config.hidden_size)
if head_dim > 0:
self.head_down = nn.Linear(config.hidden_size, head_dim, bias=False)
self.lm_head = nn.Linear(head_dim, config.vocab_size, bias=False)
# G²LU gates for down-projection
if use_g2lu:
self.head_g3 = nn.Linear(config.hidden_size, head_dim, bias=False)
self.head_g4 = nn.Linear(config.hidden_size, head_dim, bias=False)
else:
self.head_g3 = None
self.head_g4 = None
else:
self.head_down = None
self.head_g3 = None
self.head_g4 = None
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Weight tying (when embed and lm_head dimensions match)
_e = embed_dim if embed_dim > 0 else config.hidden_size
_h = head_dim if head_dim > 0 else config.hidden_size
if _e == _h:
self.lm_head.weight = self.embed.weight
# Auxiliary skip-ahead prediction head
self.skip_head = None
self.skip_head_down = None
self.skip_g3 = None
self.skip_g4 = None
if config.aux_skip_k > 0:
if head_dim > 0:
self.skip_head_down = nn.Linear(config.hidden_size, head_dim, bias=False)
self.skip_head = nn.Linear(head_dim, config.vocab_size, bias=False)
if use_g2lu:
self.skip_g3 = nn.Linear(config.hidden_size, head_dim, bias=False)
self.skip_g4 = nn.Linear(config.hidden_size, head_dim, bias=False)
else:
self.skip_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights
self.apply(self._init_weights)
def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
@property
def total_virtual_layers(self) -> int:
"""Total number of virtual layers in the forward pass."""
return self.config.n_mirror * 2 + self.config.n_middle
def forward(
self,
input_ids: torch.Tensor,
labels: torch.Tensor = None,
use_cache: bool = False,
past_kv: list = None,
word_positions: torch.Tensor | None = None,
) -> dict:
B, L = input_ids.shape
# Embed tokens (optionally factorized, with G²LU gating)
x = self.embed(input_ids)
if self.embed_proj is not None:
if self.embed_g3 is not None:
g4 = F.silu(self.embed_g4(x))
g3 = F.silu(self.embed_g3(x) * g4)
x = self.embed_proj(x) * g3
else:
x = F.silu(self.embed_proj(x))
x = x * self.embed_scale
new_kv = [] if use_cache else None
kv_idx = 0
# === Expand phase ===
for block in self.mirror_blocks:
layer_past = past_kv[kv_idx] if past_kv is not None else None
x, kv = block(x, use_cache=use_cache, past_kv=layer_past, word_positions=word_positions)
if use_cache:
new_kv.append(kv)
kv_idx += 1
# === Dual-path: save pre-middle state for alignment loss ===
for block in self.middle_blocks:
layer_past = past_kv[kv_idx] if past_kv is not None else None
x, kv = block(x, use_cache=use_cache, past_kv=layer_past, word_positions=word_positions)
if use_cache:
new_kv.append(kv)
kv_idx += 1
# === Compress phase (reversed order) ===
for i in reversed(range(len(self.mirror_blocks))):
layer_past = past_kv[kv_idx] if past_kv is not None else None
x, kv = self.mirror_blocks[i](x, use_cache=use_cache, past_kv=layer_past, word_positions=word_positions)
if use_cache:
new_kv.append(kv)
kv_idx += 1
# === Output (optionally MLP head with G²LU gating) ===
x = self.norm(x)
if self.head_down is not None:
if self.head_g3 is not None:
g4 = F.silu(self.head_g4(x))
g3 = F.silu(self.head_g3(x) * g4)
logits = self.lm_head(self.head_down(x) * g3)
else:
logits = self.lm_head(F.silu(self.head_down(x)))
else:
logits = self.lm_head(x)
result = {"logits": logits}
if use_cache:
result["past_kv"] = new_kv
if labels is not None:
shift_logits = logits[:, :-1, :].contiguous()
shift_labels = labels[:, 1:].contiguous()
loss = F.cross_entropy(
shift_logits.view(-1, self.config.vocab_size),
shift_labels.view(-1),
ignore_index=-100
)
if self.skip_head is not None:
skip_k = self.config.aux_skip_k
if self.skip_head_down is not None:
if self.skip_g3 is not None:
g4 = F.silu(self.skip_g4(x))
g3 = F.silu(self.skip_g3(x) * g4)
skip_logits = self.skip_head(self.skip_head_down(x) * g3)[:, :-skip_k, :].contiguous()
else:
skip_logits = self.skip_head(F.silu(self.skip_head_down(x)))[:, :-skip_k, :].contiguous()
else:
skip_logits = self.skip_head(x)[:, :-skip_k, :].contiguous()
skip_labels = labels[:, skip_k:].contiguous()
aux_loss = F.cross_entropy(
skip_logits.view(-1, self.config.vocab_size),
skip_labels.view(-1),
ignore_index=-100
)
result["aux_loss"] = aux_loss
loss = loss + self.config.aux_skip_weight * aux_loss
result["loss"] = loss
return result
@torch.no_grad()
def generate(
self,
prompt_ids: torch.Tensor,
max_new_tokens: int = 50,
temperature: float = 0.8,
top_k: int = 50,
top_p: float = 0.9,
use_cache: bool = True,
word_start_table: torch.Tensor | None = None,
) -> torch.Tensor:
"""Autoregressive generation with KV caching."""
from .layers import compute_word_positions
self.eval()
generated = prompt_ids.clone()
past_kv = None
word_pos_counter = 0
for _ in range(max_new_tokens):
if use_cache and past_kv is not None:
input_ids = generated[:, -1:]
if word_start_table is not None:
last_token = generated[0, -1].item()
if word_start_table[last_token]:
word_pos_counter = 0
else:
word_pos_counter += 1
word_positions = torch.tensor([[float(word_pos_counter)]], device=input_ids.device)
else:
word_positions = None
else:
input_ids = generated
if word_start_table is not None:
word_positions = compute_word_positions(input_ids, word_start_table)
else:
word_positions = None
output = self(input_ids, use_cache=use_cache, past_kv=past_kv, word_positions=word_positions)
logits = output["logits"][:, -1, :]
if use_cache:
past_kv = output["past_kv"]
if temperature > 0:
logits = logits / temperature
if top_k > 0:
top_k_vals, _ = torch.topk(logits, min(top_k, logits.size(-1)))
min_top_k = top_k_vals[:, -1].unsqueeze(-1)
logits = torch.where(logits < min_top_k, float("-inf"), logits)
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumsum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumsum_probs > top_p
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
sorted_indices_to_remove[:, 0] = False
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits = logits.masked_fill(indices_to_remove, float("-inf"))
probs = F.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
else:
next_token = logits.argmax(dim=-1, keepdim=True)
generated = torch.cat([generated, next_token], dim=1)
if generated.size(1) >= self.config.max_seq_len:
break
return generated
def count_mirrored_parameters(model: "MirroredTransformer") -> dict:
    """Count parameters with a breakdown by component.

    Args:
        model: a MirroredTransformer (or any module exposing the same attributes).

    Returns:
        dict with:
        - "total" / "unique": trainable parameters. ``nn.Module.parameters()``
          already yields each shared tensor once, so a tied lm_head/embedding
          weight is counted once and the two values are equal.
        - "mirror_blocks", "middle_blocks": all parameters of those block lists.
        - "embedding": token embedding plus the optional factorized
          up-projection and its G²LU gates (embed_g3/embed_g4).
        - "head": lm_head plus the optional down-projection and its G²LU gates
          (head_g3/head_g4). A tied lm_head weight is counted in both
          "embedding" and "head".
        - "shared_attention", "shared_ffn_base", "direction_gates", "norms":
          mirror-block breakdown (attention; FFN W₁+W₂; FFN W₃/W₄ gates;
          parameters whose name contains "norm").
    """
    def _weights(*modules) -> int:
        # Sum weight sizes of optional (possibly None) bias-free layers.
        return sum(m.weight.numel() for m in modules if m is not None)

    total = sum(p.numel() for p in model.parameters() if p.requires_grad)
    # Unique params (not double-counted from weight tying)
    unique = sum(p.numel() for p in set(p for p in model.parameters() if p.requires_grad))
    mirror_params = sum(p.numel() for p in model.mirror_blocks.parameters())
    middle_params = sum(p.numel() for p in model.middle_blocks.parameters())

    embed_params = _weights(
        model.embed,
        model.embed_proj,
        getattr(model, "embed_g3", None),
        getattr(model, "embed_g4", None),
    )
    head_params = _weights(
        model.head_down,
        getattr(model, "head_g3", None),
        getattr(model, "head_g4", None),
        model.lm_head,
    )

    # Break down mirror blocks into shared vs gate/norm parameters
    shared_attn = 0
    shared_ffn_base = 0
    gate_params = 0
    norm_params = 0
    for block in model.mirror_blocks:
        shared_attn += sum(p.numel() for p in block.attn.parameters())
        shared_ffn_base += block.ffn.w1.weight.numel() + block.ffn.w2.weight.numel()
        gate_params += block.ffn.w3.weight.numel()
        # w4 only exists when the FFN was built with use_g2lu=True
        if hasattr(block.ffn, 'w4'):
            gate_params += block.ffn.w4.weight.numel()
        norm_params += sum(p.numel() for n, p in block.named_parameters() if 'norm' in n)

    return {
        "total": total,
        "unique": unique,
        "mirror_blocks": mirror_params,
        "middle_blocks": middle_params,
        "embedding": embed_params,
        "head": head_params,
        "shared_attention": shared_attn,
        "shared_ffn_base": shared_ffn_base,
        "direction_gates": gate_params,
        "norms": norm_params,
    }