| """
|
| Mirrored Transformer: Weight-sharing between expand and compress phases.
|
|
|
| Based on the biconcave lens hypothesis from grafting research:
|
| - Early layers expand from tokens to semantic space
|
| - Late layers compress from semantic space back to tokens
|
| - These phases share structural computation (W₁, W₂)
|
| - Only the gate (semiotic filter) differs by direction
|
|
|
| Architecture:
|
| y = W₂ @ (W₁ @ x ⊙ swish(W₃ @ x ⊙ swish(W₄ @ x)))
|
|
|
| Both gates fire every pass (additive, OR-logic). W₁ computed once.
|
| W₁, W₂ shared between mirror pairs. W₃, W₄ are dual gates.
|
| ~33% FFN parameter savings per mirrored pair vs standard SwiGLU.
|
| """
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| import math
|
| from dataclasses import dataclass, fields
|
|
|
| from .layers import RMSNorm, CausalAttention, SwiGLU
|
|
|
|
|
| @dataclass
|
| class MirroredConfig:
|
| """Configuration for Mirrored Transformer."""
|
| vocab_size: int = 50257
|
| hidden_size: int = 768
|
| num_heads: int = 12
|
| num_kv_heads: int | None = None
|
| num_layers: int = 12
|
| n_middle: int = 2
|
| max_seq_len: int = 512
|
| dropout: float = 0.0
|
| aux_skip_k: int = 0
|
| aux_skip_weight: float = 0.1
|
| use_g2lu: bool = True
|
| word_rope_dims: int = 0
|
| word_rope_base: float = 10.0
|
| embed_dim: int = 0
|
| head_dim: int = 0
|
|
|
| def __post_init__(self):
|
| assert self.hidden_size % self.num_heads == 0, "hidden_size must be divisible by num_heads"
|
| if self.num_kv_heads is not None:
|
| assert self.num_heads % self.num_kv_heads == 0, \
|
| f"num_heads ({self.num_heads}) must be divisible by num_kv_heads ({self.num_kv_heads})"
|
| n_mirror_layers = self.num_layers - self.n_middle
|
| assert n_mirror_layers > 0, "num_layers must be greater than n_middle"
|
| assert n_mirror_layers % 2 == 0, "num_layers - n_middle must be even"
|
| self.n_mirror = n_mirror_layers // 2
|
|
|
| def to_dict(self) -> dict:
|
| """Convert to dictionary for serialization."""
|
| return {f.name: getattr(self, f.name) for f in fields(self) if f.name != "n_mirror"}
|
|
|
| @classmethod
|
| def from_dict(cls, d: dict) -> "MirroredConfig":
|
| """Create from dictionary."""
|
| valid = {f.name for f in fields(cls)}
|
| filtered = {k: v for k, v in d.items() if k in valid}
|
| return cls(**filtered)
|
|
|
| class MLP(nn.Module):
|
| """Feed-forward network with SiLU activation."""
|
|
|
| def __init__(self, dim, intermediate_size, dropout):
|
| super().__init__()
|
| self.up_proj = nn.Linear(dim, intermediate_size, bias=False)
|
| self.gate_proj = nn.Linear(dim, intermediate_size, bias=False)
|
| self.down_proj = nn.Linear(intermediate_size, dim, bias=False)
|
| self.dropout = nn.Dropout(dropout)
|
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| return self.dropout(self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)))
|
|
|
| class MirroredSwiGLU(nn.Module):
|
| """SwiGLU with shared base weights and dual gates.
|
|
|
| Standard SwiGLU: y = W₂(silu(W₁x) ⊙ W₃x) — 3 matrices
|
| Mirrored SwiGLU: y = W₂(W₁x ⊙ (silu(W₃ ⊙ silu(W₄x)))) — 2 shared + 2 gates
|
|
|
| W₁ computed once, reused for both branches.
|
| """
|
|
|
| def __init__(self, hidden_size: int, intermediate_size: int | None = None,
|
| gate_mode: str = 'additive', use_g2lu: bool = True):
|
| super().__init__()
|
| self.gate_mode = gate_mode
|
| self.use_g2lu = use_g2lu
|
| self._current_step = 0
|
| intermediate_size = intermediate_size or int(hidden_size * 8 / 3)
|
| intermediate_size = ((intermediate_size + 63) // 64) * 64
|
|
|
|
|
| self.w1 = nn.Linear(hidden_size, intermediate_size, bias=False)
|
| self.w2 = nn.Linear(intermediate_size, hidden_size, bias=False)
|
|
|
|
|
| self.w3 = nn.Linear(hidden_size, intermediate_size, bias=False)
|
| if use_g2lu:
|
| self.w4 = nn.Linear(hidden_size, intermediate_size, bias=False)
|
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| hidden = self.w1(x)
|
| if self.use_g2lu:
|
| g4 = F.silu(self.w4(x))
|
| g3 = F.silu(self.w3(x) * g4)
|
| else:
|
| g3 = F.silu(self.w3(x))
|
| return self.w2(hidden * g3)
|
|
|
|
|
class MirroredBlock(nn.Module):
    """Transformer block reused for both the expand and compress phases.

    Each MirroredBlock is applied TWICE per model forward pass: once while
    expanding (building semantics) and once while compressing (encoding
    output).

    Shared between both uses: attention weights and the FFN (W₁/W₂ + gates).
    Per-direction: the RMSNorm pairs, selected via ``direction`` — the two
    passes see different residual-stream statistics.

    NOTE(review): previously forward always used the 'compress' norms, so
    the ``expand_*`` norms were dead parameters. ``direction`` defaults to
    'compress' to preserve that behavior for existing callers; pass
    direction='expand' on the first traversal to actually use the expand
    norms.
    """

    def __init__(self, hidden_size: int, num_heads: int, num_kv_heads: int | None = None,
                 max_seq_len: int = 2048,
                 dropout: float = 0.0,
                 window_size: int | None = None, gate_mode: str = 'additive',
                 word_rope_dims: int = 0, word_rope_base: float = 10.0,
                 use_g2lu: bool = True):
        super().__init__()

        # Attention and FFN are shared between the two directional uses.
        self.attn = CausalAttention(hidden_size, num_heads, num_kv_heads, max_seq_len, dropout, window_size=window_size,
                                    word_rope_dims=word_rope_dims, word_rope_base=word_rope_base)

        self.ffn = MirroredSwiGLU(hidden_size, gate_mode=gate_mode, use_g2lu=use_g2lu)

        # Direction-specific norms: the residual stream has different
        # statistics during expand vs compress.
        self.expand_attn_norm = RMSNorm(hidden_size)
        self.expand_ffn_norm = RMSNorm(hidden_size)
        self.compress_attn_norm = RMSNorm(hidden_size)
        self.compress_ffn_norm = RMSNorm(hidden_size)

    def forward(self, x: torch.Tensor, use_cache: bool = False, past_kv: tuple = None,
                word_positions: torch.Tensor | None = None,
                direction: str = 'compress') -> tuple:
        """Apply the block in the given direction.

        Args:
            x: residual stream activations.
            use_cache: forwarded to attention; request a KV cache.
            past_kv: optional cached KV tuple for this application.
            word_positions: optional per-token word positions for word-RoPE.
            direction: 'expand' or 'compress' — selects the norm pair.
                Defaults to 'compress' (the historical behavior).

        Returns:
            Tuple of (updated residual stream, new KV cache or None).
        """
        if direction == 'expand':
            attn_norm, ffn_norm = self.expand_attn_norm, self.expand_ffn_norm
        else:
            attn_norm, ffn_norm = self.compress_attn_norm, self.compress_ffn_norm

        # Pre-norm residual attention, then pre-norm residual FFN.
        attn_out, new_kv = self.attn(attn_norm(x), use_cache, past_kv, word_positions=word_positions)
        x = x + attn_out
        x = x + self.ffn(ffn_norm(x))

        return x, new_kv
|
|
|
|
|
class MiddleBlock(nn.Module):
    """Standard pre-norm transformer block for the unique middle layers.

    The FFN is a MirroredSwiGLU (dual-gate when ``use_g2lu``), giving the
    middle layers the same gating geometry as the mirror pairs — but these
    blocks are used once per forward pass and share no weights.
    """

    def __init__(self, hidden_size: int, num_heads: int, num_kv_heads: int | None = None,
                 max_seq_len: int = 2048,
                 dropout: float = 0.0,
                 word_rope_dims: int = 0, word_rope_base: float = 10.0,
                 use_g2lu: bool = True):
        super().__init__()
        # Registration order kept stable (attn_norm, attn, ffn_norm, ffn)
        # so seeded initialization is unchanged.
        self.attn_norm = RMSNorm(hidden_size)
        self.attn = CausalAttention(hidden_size, num_heads, num_kv_heads, max_seq_len, dropout,
                                    word_rope_dims=word_rope_dims, word_rope_base=word_rope_base)
        self.ffn_norm = RMSNorm(hidden_size)
        self.ffn = MirroredSwiGLU(hidden_size, use_g2lu=use_g2lu)

    def forward(self, x: torch.Tensor, use_cache: bool = False, past_kv: tuple = None,
                word_positions: torch.Tensor | None = None) -> tuple:
        """Pre-norm residual attention followed by pre-norm residual FFN."""
        attn_out, kv = self.attn(self.attn_norm(x), use_cache, past_kv, word_positions=word_positions)
        h = x + attn_out
        h = h + self.ffn(self.ffn_norm(h))
        return h, kv
|
|
|
|
|
class MirroredTransformer(nn.Module):
    """Transformer with mirrored expand/compress architecture.

    Forward pass:
    1. Embed tokens (optionally factorized: small table + gated up-projection)
    2. Expand phase: mirror_blocks[0..N-1] in order
    3. Middle: unique standard blocks
    4. Compress phase: the SAME mirror_blocks, traversed in reverse
    5. Final norm + LM head (optionally factorized: gated down-projection)

    For a 12-layer model with n_middle=2:
    - 5 mirror pairs (10 virtual layers) + 2 middle = 12 effective layers
    - Expand: blocks[0] → blocks[4]
    - Middle: middle[0] → middle[1]
    - Compress: blocks[4] → blocks[0]

    NOTE(review): both traversals call the blocks identically — no direction
    flag is passed — so each MirroredBlock applies its 'compress' norms in
    both phases and its 'expand' norms go unused; confirm this is intended.
    """

    def __init__(self, config: MirroredConfig):
        super().__init__()
        self.config = config

        # Optional factorization widths; 0 disables the corresponding projection.
        embed_dim = getattr(config, 'embed_dim', 0)
        head_dim = getattr(config, 'head_dim', 0)

        # A factorized embedding implies a same-width factorized head unless
        # head_dim is set explicitly.
        if embed_dim > 0 and head_dim == 0:
            head_dim = embed_dim

        use_g2lu = getattr(config, 'use_g2lu', True)

        if embed_dim > 0:
            # Factorized input path: vocab → embed_dim → hidden_size, with an
            # optional dual-gate projection mirroring MirroredSwiGLU's gating.
            self.embed = nn.Embedding(config.vocab_size, embed_dim)
            self.embed_proj = nn.Linear(embed_dim, config.hidden_size, bias=False)

            if use_g2lu:
                self.embed_g3 = nn.Linear(embed_dim, config.hidden_size, bias=False)
                self.embed_g4 = nn.Linear(embed_dim, config.hidden_size, bias=False)
            else:
                self.embed_g3 = None
                self.embed_g4 = None
        else:
            # Plain embedding directly into the residual width.
            self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
            self.embed_proj = None
            self.embed_g3 = None
            self.embed_g4 = None
        self.embed_scale = math.sqrt(config.hidden_size)
        # Per-pair attention window slots; all None → global causal attention.
        self.window_sizes = [None] * config.n_mirror

        word_rope_dims = getattr(config, 'word_rope_dims', 0)
        word_rope_base = getattr(config, 'word_rope_base', 10.0)

        # Shared blocks: each is applied once in the expand phase and once
        # (in reverse order) in the compress phase.
        self.mirror_blocks = nn.ModuleList([
            MirroredBlock(
                config.hidden_size, config.num_heads, config.num_kv_heads,
                config.max_seq_len,
                config.dropout,
                window_size=self.window_sizes[i],
                word_rope_dims=word_rope_dims, word_rope_base=word_rope_base,
                use_g2lu=use_g2lu,
            )
            for i in range(config.n_mirror)
        ])

        # Unique, unshared middle blocks.
        self.middle_blocks = nn.ModuleList([
            MiddleBlock(config.hidden_size, config.num_heads, config.num_kv_heads,
                        config.max_seq_len, config.dropout,
                        word_rope_dims=word_rope_dims, word_rope_base=word_rope_base,
                        use_g2lu=use_g2lu)
            for _ in range(config.n_middle)
        ])

        self.norm = RMSNorm(config.hidden_size)
        if head_dim > 0:
            # Factorized output path: hidden_size → head_dim → vocab, with an
            # optional dual gate mirroring the input path.
            self.head_down = nn.Linear(config.hidden_size, head_dim, bias=False)
            self.lm_head = nn.Linear(head_dim, config.vocab_size, bias=False)

            if use_g2lu:
                self.head_g3 = nn.Linear(config.hidden_size, head_dim, bias=False)
                self.head_g4 = nn.Linear(config.hidden_size, head_dim, bias=False)
            else:
                self.head_g3 = None
                self.head_g4 = None
        else:
            self.head_down = None
            self.head_g3 = None
            self.head_g4 = None
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Tie LM head and embedding when their inner widths match
        # (embed_dim↔head_dim, or hidden_size↔hidden_size when both are 0).
        # The tie survives _init_weights below: both names refer to one tensor,
        # which is simply re-initialized in place.
        _e = embed_dim if embed_dim > 0 else config.hidden_size
        _h = head_dim if head_dim > 0 else config.hidden_size
        if _e == _h:
            self.lm_head.weight = self.embed.weight

        # Optional auxiliary "skip-ahead" head (predicts aux_skip_k tokens
        # beyond the current position); disabled when aux_skip_k == 0.
        self.skip_head = None
        self.skip_head_down = None
        self.skip_g3 = None
        self.skip_g4 = None
        if config.aux_skip_k > 0:
            if head_dim > 0:
                self.skip_head_down = nn.Linear(config.hidden_size, head_dim, bias=False)
                self.skip_head = nn.Linear(head_dim, config.vocab_size, bias=False)
                if use_g2lu:
                    self.skip_g3 = nn.Linear(config.hidden_size, head_dim, bias=False)
                    self.skip_g4 = nn.Linear(config.hidden_size, head_dim, bias=False)
            else:
                self.skip_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """GPT-style init: N(0, 0.02) weights, zero biases (Linear + Embedding)."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    @property
    def total_virtual_layers(self) -> int:
        """Total number of virtual layers in the forward pass (2*n_mirror + n_middle)."""
        return self.config.n_mirror * 2 + self.config.n_middle

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: torch.Tensor | None = None,
        use_cache: bool = False,
        past_kv: list | None = None,
        word_positions: torch.Tensor | None = None,
    ) -> dict:
        """Run the model.

        Args:
            input_ids: (batch, seq) token ids.
            labels: optional (batch, seq) targets; -100 positions are ignored.
            use_cache: if True, return the per-application KV caches.
            past_kv: list of KV tuples, one per block APPLICATION in traversal
                order (expand, middle, compress) — 2*n_mirror + n_middle slots,
                since each shared block gets a distinct cache per use.
            word_positions: optional per-token word positions for word-RoPE.

        Returns:
            dict with "logits"; plus "past_kv" when use_cache, and
            "loss" (and "aux_loss" if the skip head is enabled) when labels
            are provided.
        """
        B, L = input_ids.shape

        # Embed (optionally through the gated up-projection), then scale.
        x = self.embed(input_ids)
        if self.embed_proj is not None:
            if self.embed_g3 is not None:
                g4 = F.silu(self.embed_g4(x))
                g3 = F.silu(self.embed_g3(x) * g4)
                x = self.embed_proj(x) * g3
            else:
                x = F.silu(self.embed_proj(x))
        x = x * self.embed_scale

        new_kv = [] if use_cache else None
        kv_idx = 0  # index into past_kv: one slot per block application

        # Expand phase: mirror blocks in forward order.
        for block in self.mirror_blocks:
            layer_past = past_kv[kv_idx] if past_kv is not None else None
            x, kv = block(x, use_cache=use_cache, past_kv=layer_past, word_positions=word_positions)
            if use_cache:
                new_kv.append(kv)
            kv_idx += 1

        # Middle phase: unique blocks.
        for block in self.middle_blocks:
            layer_past = past_kv[kv_idx] if past_kv is not None else None
            x, kv = block(x, use_cache=use_cache, past_kv=layer_past, word_positions=word_positions)
            if use_cache:
                new_kv.append(kv)
            kv_idx += 1

        # Compress phase: the same mirror blocks, in reverse order
        # (separate KV slots from their expand-phase use).
        for i in reversed(range(len(self.mirror_blocks))):
            layer_past = past_kv[kv_idx] if past_kv is not None else None
            x, kv = self.mirror_blocks[i](x, use_cache=use_cache, past_kv=layer_past, word_positions=word_positions)
            if use_cache:
                new_kv.append(kv)
            kv_idx += 1

        # Final norm, then the (optionally factorized + gated) LM head.
        x = self.norm(x)
        if self.head_down is not None:
            if self.head_g3 is not None:
                g4 = F.silu(self.head_g4(x))
                g3 = F.silu(self.head_g3(x) * g4)
                logits = self.lm_head(self.head_down(x) * g3)
            else:
                logits = self.lm_head(F.silu(self.head_down(x)))
        else:
            logits = self.lm_head(x)

        result = {"logits": logits}

        if use_cache:
            result["past_kv"] = new_kv

        if labels is not None:
            # Standard next-token loss: logits at t predict labels at t+1.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
                ignore_index=-100
            )

            if self.skip_head is not None:
                # Auxiliary loss from the final (post-norm) hidden states:
                # position t predicts the label skip_k positions ahead.
                skip_k = self.config.aux_skip_k
                if self.skip_head_down is not None:
                    if self.skip_g3 is not None:
                        g4 = F.silu(self.skip_g4(x))
                        g3 = F.silu(self.skip_g3(x) * g4)
                        skip_logits = self.skip_head(self.skip_head_down(x) * g3)[:, :-skip_k, :].contiguous()
                    else:
                        skip_logits = self.skip_head(F.silu(self.skip_head_down(x)))[:, :-skip_k, :].contiguous()
                else:
                    skip_logits = self.skip_head(x)[:, :-skip_k, :].contiguous()
                skip_labels = labels[:, skip_k:].contiguous()
                aux_loss = F.cross_entropy(
                    skip_logits.view(-1, self.config.vocab_size),
                    skip_labels.view(-1),
                    ignore_index=-100
                )
                result["aux_loss"] = aux_loss
                loss = loss + self.config.aux_skip_weight * aux_loss

            result["loss"] = loss

        return result

    @torch.no_grad()
    def generate(
        self,
        prompt_ids: torch.Tensor,
        max_new_tokens: int = 50,
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.9,
        use_cache: bool = True,
        word_start_table: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Autoregressive generation with KV caching.

        Args:
            prompt_ids: (batch, seq) prompt token ids.
            max_new_tokens: maximum tokens to append.
            temperature: 0 → greedy argmax; >0 → temperature sampling with
                top-k and nucleus (top-p) filtering.
            top_k: keep only the k highest-probability logits (0 disables).
            top_p: nucleus threshold (1.0 disables).
            use_cache: reuse per-application KV caches between steps.
            word_start_table: optional boolean lookup enabling word-RoPE
                position tracking during decoding.

        Returns:
            (batch, seq + new) tensor of prompt plus generated tokens.
        """
        from .layers import compute_word_positions

        self.eval()
        generated = prompt_ids.clone()
        past_kv = None
        word_pos_counter = 0

        for _ in range(max_new_tokens):
            if use_cache and past_kv is not None:
                # Incremental step: feed only the last token.
                input_ids = generated[:, -1:]
                if word_start_table is not None:
                    # NOTE(review): reads generated[0, -1] — this cached
                    # word-position path assumes batch size 1; confirm before
                    # using batched generation with word_start_table.
                    last_token = generated[0, -1].item()
                    if word_start_table[last_token]:
                        word_pos_counter = 0
                    else:
                        word_pos_counter += 1
                    word_positions = torch.tensor([[float(word_pos_counter)]], device=input_ids.device)
                else:
                    word_positions = None
            else:
                # First step (or no cache): feed the full sequence.
                input_ids = generated
                if word_start_table is not None:
                    word_positions = compute_word_positions(input_ids, word_start_table)
                else:
                    word_positions = None

            output = self(input_ids, use_cache=use_cache, past_kv=past_kv, word_positions=word_positions)
            logits = output["logits"][:, -1, :]

            if use_cache:
                past_kv = output["past_kv"]

            if temperature > 0:
                logits = logits / temperature

                # Top-k: mask everything below the k-th largest logit.
                if top_k > 0:
                    top_k_vals, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    min_top_k = top_k_vals[:, -1].unsqueeze(-1)
                    logits = torch.where(logits < min_top_k, float("-inf"), logits)

                # Nucleus (top-p): drop the tail beyond cumulative prob top_p,
                # always keeping at least the most likely token.
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    cumsum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumsum_probs > top_p
                    # Shift right so the first token crossing the threshold is kept.
                    sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                    sorted_indices_to_remove[:, 0] = False
                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    logits = logits.masked_fill(indices_to_remove, float("-inf"))

                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                # Greedy decoding when temperature == 0.
                next_token = logits.argmax(dim=-1, keepdim=True)

            generated = torch.cat([generated, next_token], dim=1)

            # Stop once the context window is full.
            if generated.size(1) >= self.config.max_seq_len:
                break

        return generated
|
|
|
|
|
def count_mirrored_parameters(model: "MirroredTransformer") -> dict:
    """Count trainable parameters with a breakdown by component.

    Args:
        model: the model to inspect (annotation quoted as a forward reference).

    Returns:
        dict with:
          total            — virtual count: tensors registered under several
                             names (e.g. a tied lm_head/embedding weight)
                             are counted once per registration
          unique           — distinct tensors, each counted once
          mirror_blocks    — parameters in the shared mirror blocks
          middle_blocks    — parameters in the unique middle blocks
          embedding        — embedding table (+ projection if factorized)
          head             — lm_head (+ down-projection if factorized)
          shared_attention — attention weights inside mirror blocks
          shared_ffn_base  — shared FFN base (W₁/W₂) inside mirror blocks
          direction_gates  — FFN gate weights (W₃, and W₄ when present)
          norms            — norm parameters inside mirror blocks
    """
    # BUG FIX: Module.parameters() deduplicates shared tensors by default, so
    # the previous total/unique computations were always identical.
    # named_parameters(remove_duplicate=False) re-yields tensors registered
    # under multiple names, giving `total` its intended "virtual" meaning.
    total = sum(p.numel()
                for _, p in model.named_parameters(remove_duplicate=False)
                if p.requires_grad)

    # Distinct tensors only (tied weights counted once).
    unique = sum(p.numel() for p in {p for p in model.parameters() if p.requires_grad})

    mirror_params = sum(p.numel() for p in model.mirror_blocks.parameters())
    middle_params = sum(p.numel() for p in model.middle_blocks.parameters())
    embed_params = model.embed.weight.numel()
    if model.embed_proj is not None:
        embed_params += model.embed_proj.weight.numel()
    head_params = 0
    if model.head_down is not None:
        head_params += model.head_down.weight.numel()
    head_params += model.lm_head.weight.numel()

    # Per-mirror-block breakdown: shared attention, shared FFN base,
    # direction gates, and norms.
    shared_attn = 0
    shared_ffn_base = 0
    gate_params = 0
    norm_params = 0

    for block in model.mirror_blocks:
        shared_attn += sum(p.numel() for p in block.attn.parameters())
        shared_ffn_base += block.ffn.w1.weight.numel() + block.ffn.w2.weight.numel()
        gate_params += block.ffn.w3.weight.numel()
        # w4 only exists when the FFN was built with use_g2lu.
        if hasattr(block.ffn, 'w4'):
            gate_params += block.ffn.w4.weight.numel()
        norm_params += sum(p.numel() for n, p in block.named_parameters() if 'norm' in n)

    return {
        "total": total,
        "unique": unique,
        "mirror_blocks": mirror_params,
        "middle_blocks": middle_params,
        "embedding": embed_params,
        "head": head_params,
        "shared_attention": shared_attn,
        "shared_ffn_base": shared_ffn_base,
        "direction_gates": gate_params,
        "norms": norm_params,
    }
|
|
|