"""Mirrored Transformer: weight sharing between expand and compress phases.

Motivated by the "biconcave lens" hypothesis from grafting research:
- Early layers expand from tokens to semantic space
- Late layers compress from semantic space back to tokens
- Both phases reuse the same blocks (attention and FFN weights)

Layer schedule (see ``MirroredTransformer``):
    mirror_blocks[0..N-1] -> middle_blocks -> mirror_blocks[N-1..0]

FFN (``MirroredSwiGLU``):
    G²LU (use_g2lu=True):   y = W₂ @ (W₁ @ x ⊙ swish(W₃ @ x ⊙ swish(W₄ @ x)))
    Vanilla (use_g2lu=False): y = W₂ @ (W₁ @ x ⊙ swish(W₃ @ x))

A mirrored pair stores 4 FFN matrices (W₁..W₄) instead of the 6 that two
independent SwiGLU FFNs would need, i.e. ~33% fewer FFN parameters per pair.
"""

import math
from dataclasses import dataclass, fields

import torch
import torch.nn as nn
import torch.nn.functional as F

# SwiGLU is not referenced in this module; the import is kept because other
# code may rely on it being importable from here.
from .layers import RMSNorm, CausalAttention, SwiGLU


@dataclass
class MirroredConfig:
    """Configuration for Mirrored Transformer.

    ``n_mirror`` (number of shared mirror blocks) is derived in
    ``__post_init__`` as ``(num_layers - n_middle) // 2``. It is a plain
    attribute, not a dataclass field, so it is never serialized.
    """

    vocab_size: int = 50257
    hidden_size: int = 768
    num_heads: int = 12
    num_kv_heads: int | None = None  # GQA: None = same as num_heads (MHA)
    num_layers: int = 12  # effective depth (expand + middle + compress)
    n_middle: int = 2  # unique middle layers
    max_seq_len: int = 512
    dropout: float = 0.0
    aux_skip_k: int = 0  # skip-ahead prediction distance (0 = disabled)
    aux_skip_weight: float = 0.1  # weight for auxiliary skip loss
    use_g2lu: bool = True  # G²LU nested gates (False = vanilla SwiGLU)
    word_rope_dims: int = 0  # head dims for word-position RoPE (0 = disabled)
    word_rope_base: float = 10.0  # frequency base for word-position RoPE
    embed_dim: int = 0  # factorized embedding dim (0 = use hidden_size)
    head_dim: int = 0  # MLP head intermediate dim (0 = linear head)

    def __post_init__(self):
        """Validate head/layer counts and derive ``n_mirror``.

        Raises:
            AssertionError: if the sizes are inconsistent. NOTE: ``assert``
                statements are skipped when Python runs with ``-O``.
        """
        assert self.hidden_size % self.num_heads == 0, "hidden_size must be divisible by num_heads"
        if self.num_kv_heads is not None:
            assert self.num_heads % self.num_kv_heads == 0, \
                f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads})"

        n_mirror_layers = self.num_layers - self.n_middle
        assert n_mirror_layers > 0, "num_layers must be greater than n_middle"
        assert n_mirror_layers % 2 == 0, "num_layers - n_middle must be even"
        # Each mirror block is executed twice (expand + compress).
        self.n_mirror = n_mirror_layers // 2

    def to_dict(self) -> dict:
        """Convert to dictionary for serialization (declared fields only)."""
        return {
            f.name: getattr(self, f.name)
            for f in fields(self)
            if f.name != "n_mirror"
        }

    @classmethod
    def from_dict(cls, d: dict) -> "MirroredConfig":
        """Create from dictionary, silently ignoring unknown keys."""
        known = {f.name for f in fields(cls)}
        return cls(**{key: value for key, value in d.items() if key in known})


class MLP(nn.Module):
    """Gated feed-forward network: dropout(down(silu(gate(x)) * up(x)))."""

    def __init__(self, dim, intermediate_size, dropout):
        super().__init__()
        self.up_proj = nn.Linear(dim, intermediate_size, bias=False)
        self.gate_proj = nn.Linear(dim, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.dropout(self.down_proj(gated))


class MirroredSwiGLU(nn.Module):
    """SwiGLU-style FFN with an optional nested second gate (G²LU).

    use_g2lu=True:  y = W₂(W₁x ⊙ silu(W₃x ⊙ silu(W₄x)))   — 4 matrices
    use_g2lu=False: y = W₂(W₁x ⊙ silu(W₃x))               — 3 matrices

    Args:
        hidden_size: input/output feature size.
        intermediate_size: inner width; defaults to ``8/3 * hidden_size``.
            Always rounded up to a multiple of 64.
        gate_mode: stored on the module but not read by ``forward``.
        use_g2lu: create and use the nested gate ``w4``.
    """

    def __init__(self, hidden_size: int, intermediate_size: int | None = None,
                 gate_mode: str = 'additive', use_g2lu: bool = True):
        super().__init__()
        self.gate_mode = gate_mode
        self.use_g2lu = use_g2lu
        # Not read anywhere in this module; kept for external code.
        self._current_step = 0

        intermediate_size = intermediate_size or int(hidden_size * 8 / 3)
        intermediate_size = ((intermediate_size + 63) // 64) * 64

        # Structural transform (value path and output projection)
        self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)

        # Gate(s); w4 only exists in G²LU mode
        self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False)
        if use_g2lu:
            self.w4 = nn.Linear(hidden_size, intermediate_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.w1(x)
        if self.use_g2lu:
            inner_gate = F.silu(self.w4(x))
            gate = F.silu(self.w3(x) * inner_gate)
        else:
            gate = F.silu(self.w3(x))
        return self.w2(hidden * gate)


class MirroredBlock(nn.Module):
    """Pre-norm transformer block that is executed twice per forward pass.

    ``MirroredTransformer`` runs each MirroredBlock once in the expand phase
    and once in the compress phase, so attention and FFN weights are shared
    between the two passes.

    Four RMSNorms are declared (expand/compress x attention/FFN).

    NOTE(review): ``forward`` currently applies only the ``compress_*`` norms
    and has no way to select the ``expand_*`` pair, so the expand norms are
    unused parameters. Confirm whether a phase switch was intended before
    changing this; doing so alters the outputs of existing checkpoints.
    """

    def __init__(self, hidden_size: int, num_heads: int, num_kv_heads: int | None = None,
                 max_seq_len: int = 2048,
                 dropout: float = 0.0,
                 window_size: int | None = None, gate_mode: str = 'additive',
                 word_rope_dims: int = 0, word_rope_base: float = 10.0,
                 use_g2lu: bool = True):
        super().__init__()

        self.attn = CausalAttention(hidden_size, num_heads, num_kv_heads, max_seq_len,
                                    dropout, window_size=window_size,
                                    word_rope_dims=word_rope_dims,
                                    word_rope_base=word_rope_base)

        # FFN whose weights are reused by both passes through this block
        self.ffn = MirroredSwiGLU(hidden_size, gate_mode=gate_mode, use_g2lu=use_g2lu)

        # One norm pair per direction (see class NOTE: only compress_* is used)
        self.expand_attn_norm = RMSNorm(hidden_size)
        self.expand_ffn_norm = RMSNorm(hidden_size)
        self.compress_attn_norm = RMSNorm(hidden_size)
        self.compress_ffn_norm = RMSNorm(hidden_size)

    def forward(self, x: torch.Tensor, use_cache: bool = False,
                past_kv: tuple | None = None,
                word_positions: torch.Tensor | None = None) -> tuple:
        """Apply attention and FFN with residual connections.

        Returns:
            ``(x, new_kv)`` where ``new_kv`` is whatever the attention module
            returns as its second output.
        """
        attn_out, new_kv = self.attn(
            self.compress_attn_norm(x), use_cache, past_kv,
            word_positions=word_positions,
        )
        x = x + attn_out
        x = x + self.ffn(self.compress_ffn_norm(x))
        return x, new_kv


class MiddleBlock(nn.Module):
    """Pre-norm transformer block for the unique (non-shared) middle layers.

    Uses the same ``MirroredSwiGLU`` FFN as the mirror blocks, so with
    ``use_g2lu=True`` the middle layers get the nested dual gate as well.
    """

    def __init__(self, hidden_size: int, num_heads: int, num_kv_heads: int | None = None,
                 max_seq_len: int = 2048,
                 dropout: float = 0.0,
                 word_rope_dims: int = 0, word_rope_base: float = 10.0,
                 use_g2lu: bool = True):
        super().__init__()

        self.attn_norm = RMSNorm(hidden_size)
        self.attn = CausalAttention(hidden_size, num_heads, num_kv_heads, max_seq_len,
                                    dropout,
                                    word_rope_dims=word_rope_dims,
                                    word_rope_base=word_rope_base)
        self.ffn_norm = RMSNorm(hidden_size)
        self.ffn = MirroredSwiGLU(hidden_size, use_g2lu=use_g2lu)

    def forward(self, x: torch.Tensor, use_cache: bool = False,
                past_kv: tuple | None = None,
                word_positions: torch.Tensor | None = None) -> tuple:
        """Apply attention and FFN with residual connections; returns ``(x, new_kv)``."""
        attn_out, new_kv = self.attn(
            self.attn_norm(x), use_cache, past_kv,
            word_positions=word_positions,
        )
        x = x + attn_out
        x = x + self.ffn(self.ffn_norm(x))
        return x, new_kv


def _g2lu_project(x, proj, gate3, gate4):
    """Project ``x`` through ``proj`` with G²LU gating, or SiLU if ungated.

    With gates:    proj(x) * silu(gate3(x) * silu(gate4(x)))
    Without gates: silu(proj(x))
    """
    if gate3 is None:
        return F.silu(proj(x))
    inner_gate = F.silu(gate4(x))
    outer_gate = F.silu(gate3(x) * inner_gate)
    return proj(x) * outer_gate


class MirroredTransformer(nn.Module):
    """Transformer with mirrored expand/compress architecture.

    Forward pass:
    1. Embed tokens (optionally factorized and gated), scale by sqrt(hidden)
    2. Expand phase: mirror_blocks[0] → mirror_blocks[N-1]
    3. Middle: unique middle blocks
    4. Compress phase: the same mirror_blocks in reverse order
    5. Norm + LM head (optionally an MLP head)

    For a 12-layer model with n_middle=2:
    - 5 mirror blocks (10 virtual layers) + 2 middle = 12 effective layers
    - Expand: blocks[0] → blocks[4]
    - Middle: middle[0] → middle[1]
    - Compress: blocks[4] → blocks[0]

    The KV cache holds one entry per *virtual* layer, in execution order, so
    a mirror block owns two cache slots (one per phase).
    """

    def __init__(self, config: MirroredConfig):
        super().__init__()
        self.config = config

        # getattr defaults tolerate config objects that lack the newer fields.
        embed_dim = getattr(config, 'embed_dim', 0)
        head_dim = getattr(config, 'head_dim', 0)
        # Auto-mirror factorization: head reuses embed_dim so weights can tie
        if embed_dim > 0 and head_dim == 0:
            head_dim = embed_dim

        # G²LU config (needed before projection setup)
        use_g2lu = getattr(config, 'use_g2lu', True)

        # Token embeddings (optionally factorized)
        if embed_dim > 0:
            self.embed = nn.Embedding(config.vocab_size, embed_dim)
            self.embed_proj = nn.Linear(embed_dim, config.hidden_size, bias=False)
            # G²LU gates for the up-projection (consistent with mirror blocks)
            if use_g2lu:
                self.embed_g3 = nn.Linear(embed_dim, config.hidden_size, bias=False)
                self.embed_g4 = nn.Linear(embed_dim, config.hidden_size, bias=False)
            else:
                self.embed_g3 = None
                self.embed_g4 = None
        else:
            self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
            self.embed_proj = None
            self.embed_g3 = None
            self.embed_g4 = None
        self.embed_scale = math.sqrt(config.hidden_size)

        # No sliding window is configured for any mirror block.
        self.window_sizes = [None] * config.n_mirror

        # Word-position RoPE config
        word_rope_dims = getattr(config, 'word_rope_dims', 0)
        word_rope_base = getattr(config, 'word_rope_base', 10.0)

        # Mirrored blocks (used in both expand and compress phases)
        self.mirror_blocks = nn.ModuleList([
            MirroredBlock(
                config.hidden_size, config.num_heads, config.num_kv_heads,
                config.max_seq_len,
                config.dropout,
                window_size=self.window_sizes[i],
                word_rope_dims=word_rope_dims,
                word_rope_base=word_rope_base,
                use_g2lu=use_g2lu,
            )
            for i in range(config.n_mirror)
        ])

        # Unique middle blocks
        self.middle_blocks = nn.ModuleList([
            MiddleBlock(config.hidden_size, config.num_heads, config.num_kv_heads,
                        config.max_seq_len,
                        config.dropout,
                        word_rope_dims=word_rope_dims, word_rope_base=word_rope_base,
                        use_g2lu=use_g2lu)
            for _ in range(config.n_middle)
        ])

        # Output (optionally MLP head)
        self.norm = RMSNorm(config.hidden_size)
        if head_dim > 0:
            self.head_down = nn.Linear(config.hidden_size, head_dim, bias=False)
            self.lm_head = nn.Linear(head_dim, config.vocab_size, bias=False)
            # G²LU gates for the down-projection
            if use_g2lu:
                self.head_g3 = nn.Linear(config.hidden_size, head_dim, bias=False)
                self.head_g4 = nn.Linear(config.hidden_size, head_dim, bias=False)
            else:
                self.head_g3 = None
                self.head_g4 = None
        else:
            self.head_down = None
            self.head_g3 = None
            self.head_g4 = None
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Weight tying (only possible when embedding and head widths match)
        tied_embed_width = embed_dim if embed_dim > 0 else config.hidden_size
        tied_head_width = head_dim if head_dim > 0 else config.hidden_size
        if tied_embed_width == tied_head_width:
            self.lm_head.weight = self.embed.weight

        # Auxiliary skip-ahead prediction head (mirrors the main head layout)
        self.skip_head = None
        self.skip_head_down = None
        self.skip_g3 = None
        self.skip_g4 = None
        if config.aux_skip_k > 0:
            if head_dim > 0:
                self.skip_head_down = nn.Linear(config.hidden_size, head_dim, bias=False)
                self.skip_head = nn.Linear(head_dim, config.vocab_size, bias=False)
                if use_g2lu:
                    self.skip_g3 = nn.Linear(config.hidden_size, head_dim, bias=False)
                    self.skip_g4 = nn.Linear(config.hidden_size, head_dim, bias=False)
            else:
                self.skip_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Normal(0, 0.02) init for Linear/Embedding weights; zero biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    @property
    def total_virtual_layers(self) -> int:
        """Total number of virtual layers in the forward pass."""
        return self.config.n_mirror * 2 + self.config.n_middle

    def _layer_schedule(self) -> list:
        """Return the blocks in execution order: expand, middle, compress."""
        mirror = list(self.mirror_blocks)
        return mirror + list(self.middle_blocks) + mirror[::-1]

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: torch.Tensor | None = None,
        use_cache: bool = False,
        past_kv: list | None = None,
        word_positions: torch.Tensor | None = None,
    ) -> dict:
        """Run the model and optionally compute the training loss.

        Args:
            input_ids: token ids, unpacked as ``(batch, seq)``.
            labels: if given, next-token cross-entropy is computed against
                ``labels`` shifted by one; ``-100`` entries are ignored.
            use_cache: collect per-virtual-layer KV entries.
            past_kv: cache from a previous call, indexed by virtual layer.
            word_positions: forwarded unchanged to every block.

        Returns:
            dict with ``"logits"``; plus ``"past_kv"`` when ``use_cache``;
            plus ``"loss"`` when ``labels`` is given; plus ``"aux_loss"``
            when the skip-ahead head is enabled and ``labels`` is given.
        """
        # Embed tokens (optionally factorized, with G²LU gating)
        x = self.embed(input_ids)
        if self.embed_proj is not None:
            x = _g2lu_project(x, self.embed_proj, self.embed_g3, self.embed_g4)
        x = x * self.embed_scale

        # Expand → middle → compress; the schedule index is the cache slot.
        new_kv = [] if use_cache else None
        for kv_idx, block in enumerate(self._layer_schedule()):
            layer_past = None if past_kv is None else past_kv[kv_idx]
            x, kv = block(x, use_cache=use_cache, past_kv=layer_past,
                          word_positions=word_positions)
            if use_cache:
                new_kv.append(kv)

        # Output (optionally MLP head with G²LU gating)
        x = self.norm(x)
        if self.head_down is not None:
            logits = self.lm_head(
                _g2lu_project(x, self.head_down, self.head_g3, self.head_g4)
            )
        else:
            logits = self.lm_head(x)

        result = {"logits": logits}
        if use_cache:
            result["past_kv"] = new_kv

        if labels is None:
            return result

        vocab_size = self.config.vocab_size
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].contiguous()
        loss = F.cross_entropy(
            shift_logits.view(-1, vocab_size),
            shift_labels.view(-1),
            ignore_index=-100,
        )

        if self.skip_head is not None:
            skip_k = self.config.aux_skip_k
            if input_ids.size(1) > skip_k:
                if self.skip_head_down is not None:
                    skip_in = _g2lu_project(
                        x, self.skip_head_down, self.skip_g3, self.skip_g4
                    )
                else:
                    skip_in = x
                skip_logits = self.skip_head(skip_in)[:, :-skip_k, :].contiguous()
                skip_labels = labels[:, skip_k:].contiguous()
                aux_loss = F.cross_entropy(
                    skip_logits.view(-1, vocab_size),
                    skip_labels.view(-1),
                    ignore_index=-100,
                )
            else:
                # Sequence too short for any skip-ahead target: the slices
                # above would be empty and mean-reduced cross-entropy over
                # zero elements is NaN, which would poison the total loss.
                aux_loss = logits.new_zeros(())
            result["aux_loss"] = aux_loss
            loss = loss + self.config.aux_skip_weight * aux_loss

        result["loss"] = loss
        return result

    @torch.no_grad()
    def generate(
        self,
        prompt_ids: torch.Tensor,
        max_new_tokens: int = 50,
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.9,
        use_cache: bool = True,
        word_start_table: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Autoregressive generation with optional KV caching.

        Sampling: ``temperature <= 0`` selects greedy argmax; otherwise
        logits are temperature-scaled, filtered by top-k (if ``top_k > 0``)
        and nucleus top-p (if ``top_p < 1.0``), then sampled.

        Side effect: switches the module to eval mode and leaves it there.

        Generation stops after ``max_new_tokens`` tokens or as soon as the
        sequence reaches ``config.max_seq_len``.

        NOTE: incremental word-position tracking looks only at batch row 0,
        so with ``word_start_table`` and caching the positions are only
        exact for batch size 1.

        Returns:
            ``prompt_ids`` with the generated tokens appended on dim 1.
        """
        # Imported locally, as in the original implementation.
        from .layers import compute_word_positions

        self.eval()
        generated = prompt_ids.clone()
        past_kv = None
        word_pos_counter = 0

        for _ in range(max_new_tokens):
            incremental = use_cache and past_kv is not None
            input_ids = generated[:, -1:] if incremental else generated

            word_positions = None
            if word_start_table is not None:
                if incremental:
                    # Position of the single token fed this step: reset on a
                    # word-start token, otherwise continue the running count.
                    last_token = generated[0, -1].item()
                    if word_start_table[last_token]:
                        word_pos_counter = 0
                    else:
                        word_pos_counter += 1
                    word_positions = torch.tensor(
                        [[float(word_pos_counter)]], device=input_ids.device
                    )
                else:
                    word_positions = compute_word_positions(input_ids, word_start_table)
                    # Seed the running counter from the last context token so
                    # cached steps continue where the full pass ended rather
                    # than restarting at 0.
                    # Assumes compute_word_positions returns (batch, seq)
                    # positions using the same reset/increment convention as
                    # the branch above — TODO confirm against .layers.
                    word_pos_counter = word_positions[0, -1].item()

            output = self(input_ids, use_cache=use_cache, past_kv=past_kv,
                          word_positions=word_positions)
            logits = output["logits"][:, -1, :]
            if use_cache:
                past_kv = output["past_kv"]

            if temperature > 0:
                logits = logits / temperature

                if top_k > 0:
                    k = min(top_k, logits.size(-1))
                    kth_value = torch.topk(logits, k).values[:, -1].unsqueeze(-1)
                    logits = logits.masked_fill(logits < kth_value, float("-inf"))

                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    cumulative = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    drop_sorted = cumulative > top_p
                    # Shift right so the token that crosses top_p is kept,
                    # and never drop the most likely token.
                    drop_sorted[:, 1:] = drop_sorted[:, :-1].clone()
                    drop_sorted[:, 0] = False
                    # Map the mask from sorted order back to vocabulary order.
                    drop = drop_sorted.scatter(1, sorted_indices, drop_sorted)
                    logits = logits.masked_fill(drop, float("-inf"))

                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = logits.argmax(dim=-1, keepdim=True)

            generated = torch.cat([generated, next_token], dim=1)

            if generated.size(1) >= self.config.max_seq_len:
                break

        return generated


def count_mirrored_parameters(model: MirroredTransformer) -> dict:
    """Count parameters with breakdown by component.

    Returned keys:
        total / unique: trainable parameter counts. ``Module.parameters()``
            already yields each tensor once, so tied weights are not
            double-counted in either and the two values are equal.
        mirror_blocks / middle_blocks: all parameters of each ModuleList.
        embedding: embedding table plus ``embed_proj`` (the optional
            embedding G²LU gate matrices are not included).
        head: ``lm_head`` plus ``head_down`` (head G²LU gate matrices are not
            included; a tied ``lm_head`` is counted here *and* in
            ``embedding``).
        shared_attention, shared_ffn_base (w1 + w2), direction_gates
            (w3 + w4), norms: per-role totals across the mirror blocks.
    """
    trainable = [p for p in model.parameters() if p.requires_grad]
    total = sum(p.numel() for p in trainable)
    # Tensors hash by identity, so the set collapses aliases of one tensor.
    unique = sum(p.numel() for p in set(trainable))

    mirror_params = sum(p.numel() for p in model.mirror_blocks.parameters())
    middle_params = sum(p.numel() for p in model.middle_blocks.parameters())

    embed_params = model.embed.weight.numel()
    if model.embed_proj is not None:
        embed_params += model.embed_proj.weight.numel()

    head_params = 0
    if model.head_down is not None:
        head_params += model.head_down.weight.numel()
    head_params += model.lm_head.weight.numel()

    # Break the mirror blocks down by role
    shared_attn = 0
    shared_ffn_base = 0
    gate_params = 0
    norm_params = 0
    for block in model.mirror_blocks:
        ffn = block.ffn
        shared_attn += sum(p.numel() for p in block.attn.parameters())
        shared_ffn_base += ffn.w1.weight.numel() + ffn.w2.weight.numel()
        gate_params += ffn.w3.weight.numel()
        if hasattr(ffn, 'w4'):
            gate_params += ffn.w4.weight.numel()
        # Name match: any parameter whose qualified name contains "norm".
        norm_params += sum(
            p.numel() for name, p in block.named_parameters() if 'norm' in name
        )

    return {
        "total": total,
        "unique": unique,
        "mirror_blocks": mirror_params,
        "middle_blocks": middle_params,
        "embedding": embed_params,
        "head": head_params,
        "shared_attention": shared_attn,
        "shared_ffn_base": shared_ffn_base,
        "direction_gates": gate_params,
        "norms": norm_params,
    }