"""
Baseline implementations for fair comparison.

Baselines:
    1. Standard Transformer: Dense MLP FFN, no TT, no quantum.
    2. Distilled: Smaller transformer trained with KD.
    3. Pruned: Magnitude-based structured pruning.
    4. TT-Only: Tensor network FFN without quantum or adaptive rank.
"""

import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class _DecoderLM(nn.Module):
    """Shared implementation of the dense decoder-only language model.

    ``StandardTransformer`` and ``DistilledTransformer`` differ only in their
    default hyper-parameters, so the layers and the forward pass live here.

    Args:
        vocab_size: Number of rows in the token embedding / output projection.
        d_model: Width of the residual stream.
        n_heads: Number of attention heads; must divide ``d_model``.
        n_layers: Number of decoder blocks.
        ff_mult: The FFN hidden width is ``d_model * ff_mult``.
        max_seq_len: Number of positions in the positional-encoding table.
        dropout: Dropout probability used throughout the model.
    """

    def __init__(self, vocab_size: int, d_model: int, n_heads: int,
                 n_layers: int, ff_mult: int, max_seq_len: int,
                 dropout: float):
        super().__init__()
        self.d_model = d_model
        # Anonymous attribute bag exposing the hyper-parameters as
        # ``model.config.<name>``; the key names are part of the interface.
        self.config = type("config", (), {
            "d_model": d_model,
            "n_heads": n_heads,
            "n_layers": n_layers,
            "ff_multiplier": ff_mult,
            "max_seq_len": max_seq_len,
            "vocab_size": vocab_size,
            "dropout": dropout,
        })()

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = _PositionalEncoding(d_model, max_seq_len, dropout)
        self.blocks = nn.ModuleList([
            _StandardBlock(d_model, n_heads, ff_mult, dropout, max_seq_len)
            for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        # Weight tying: the output projection shares the embedding matrix.
        self.lm_head.weight = self.embedding.weight

    def forward(self, input_ids, attention_mask=None, return_stats=False):
        """Compute next-token logits.

        Args:
            input_ids: Integer token ids, expected as ``(batch, seq)``.
            attention_mask: Optional ``(batch, seq)`` mask where non-zero
                marks a real token and zero marks padding. It is forwarded to
                every block.
            return_stats: When true, return ``(logits, [])``; the empty list
                is a placeholder so the return shape is a 2-tuple.

        Returns:
            The logits tensor, or ``(logits, [])`` when ``return_stats`` is
            true.
        """
        x = self.pos_encoding(self.embedding(input_ids))
        for block in self.blocks:
            x = block(x, mask=attention_mask)
        logits = self.lm_head(self.ln_f(x))
        if return_stats:
            return logits, []
        return logits

    @property
    def total_params(self) -> int:
        """Number of parameter elements (tied weights are counted once)."""
        return sum(p.numel() for p in self.parameters())


class StandardTransformer(_DecoderLM):
    """
    Basic transformer decoder (GPT-style) with dense MLP FFN.

    Reference baseline — intended to match the Q-TensorFormer architecture
    except for TT decomposition and quantum layers.
    """

    def __init__(self, vocab_size: int = 10000, d_model: int = 128,
                 n_heads: int = 4, n_layers: int = 2, ff_mult: int = 4,
                 max_seq_len: int = 128, dropout: float = 0.1):
        super().__init__(
            vocab_size=vocab_size,
            d_model=d_model,
            n_heads=n_heads,
            n_layers=n_layers,
            ff_mult=ff_mult,
            max_seq_len=max_seq_len,
            dropout=dropout,
        )


class DistilledTransformer(_DecoderLM):
    """
    Smaller transformer trained via knowledge distillation.

    Same architecture as ``StandardTransformer`` with smaller defaults
    (``d_model=96``, ``ff_mult=3``), designed to match Q-TensorFormer
    parameter counts. The distillation loss itself is not part of this class.
    """

    def __init__(self, vocab_size: int = 10000, d_model: int = 96,
                 n_heads: int = 4, n_layers: int = 2, ff_mult: int = 3,
                 max_seq_len: int = 128, dropout: float = 0.1):
        super().__init__(
            vocab_size=vocab_size,
            d_model=d_model,
            n_heads=n_heads,
            n_layers=n_layers,
            ff_mult=ff_mult,
            max_seq_len=max_seq_len,
            dropout=dropout,
        )


class PrunedTransformer(nn.Module):
    """
    Magnitude-pruned standard transformer.

    Zeroes FFN weights of ``base_model`` in place using one global magnitude
    threshold, then zeroes whole rows that kept fewer than 10% of their
    entries. The weight tensors keep their shapes, so ``total_params`` is
    unchanged by pruning; use ``nonzero_params`` for the effective count.

    Args:
        base_model: Model whose ``blocks[i].ffn[0]`` and ``blocks[i].ffn[2]``
            linear layers are pruned. It is modified in place and wrapped.
        prune_ratio: Fraction of FFN weights (by count, smallest magnitude
            first) that fall at or below the global threshold.

    Raises:
        ValueError: If ``prune_ratio`` is outside ``[0, 1]``.
    """

    def __init__(self, base_model: StandardTransformer, prune_ratio: float = 0.5):
        super().__init__()
        self.base = base_model
        self.prune_ratio = prune_ratio
        self.config = base_model.config
        self._prune()

    def _prune(self):
        """Apply structured magnitude pruning to FFN layers."""
        if not 0.0 <= self.prune_ratio <= 1.0:
            raise ValueError(
                f"prune_ratio must be in [0, 1], got {self.prune_ratio}"
            )

        layers = [
            layer
            for block in self.base.blocks
            for layer in (block.ffn[0], block.ffn[2])
        ]
        if not layers:
            return

        magnitudes = torch.cat(
            [layer.weight.detach().abs().flatten() for layer in layers]
        )
        k = int(magnitudes.numel() * self.prune_ratio)
        if k == 0:
            # Nothing to prune; indexing the k-th smallest value would fail.
            return

        # Global threshold: the k-th smallest magnitude over all FFN weights.
        threshold = torch.kthvalue(magnitudes, k).values

        with torch.no_grad():
            for layer in layers:
                weight = layer.weight
                keep = (weight.abs() > threshold).to(weight.dtype)
                # Structured step: a row that kept fewer than 10% of its
                # entries is zeroed entirely.
                survivors_per_row = keep.sum(dim=1)
                keep[survivors_per_row < weight.size(1) * 0.1] = 0
                weight.mul_(keep)

    def forward(self, *args, **kwargs):
        """Delegate to the wrapped (pruned) base model."""
        return self.base(*args, **kwargs)

    @property
    def total_params(self) -> int:
        """Number of parameter elements, including weights zeroed by pruning."""
        return sum(p.numel() for p in self.parameters())

    @property
    def nonzero_params(self) -> int:
        """Number of parameter elements whose value is non-zero."""
        return sum(int(torch.count_nonzero(p)) for p in self.parameters())


class _StandardBlock(nn.Module):
    """Standard pre-LayerNorm transformer decoder block.

    Computes ``x + dropout(attn(ln1(x)))`` followed by ``x + ffn(ln2(x))``,
    where the FFN is Linear -> GELU -> Linear -> Dropout.
    """

    def __init__(self, d_model, n_heads, ff_mult, dropout, max_seq_len):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = _CausalAttention(d_model, n_heads, dropout, max_seq_len)
        self.ln2 = nn.LayerNorm(d_model)
        # Indices 0 and 2 of this Sequential are the linear layers that
        # PrunedTransformer addresses as ffn[0] / ffn[2]; keep the order.
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * ff_mult),
            nn.GELU(),
            nn.Linear(d_model * ff_mult, d_model),
            nn.Dropout(dropout),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Apply the attention and feed-forward residual sub-layers."""
        x = x + self.dropout(self.attn(self.ln1(x), mask=mask))
        x = x + self.ffn(self.ln2(x))
        return x


class _CausalAttention(nn.Module):
    """Causal multi-head self-attention.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model, n_heads, dropout, max_seq_len):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.scale = math.sqrt(self.head_dim)
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.max_seq_len = max_seq_len

    def forward(self, x, mask: Optional[torch.Tensor] = None):
        """Attend over ``x`` with a causal (and optional padding) mask.

        Args:
            x: Input of shape ``(B, T, C)``.
            mask: Optional ``(B, T)`` key mask; non-zero (or ``True``) marks
                a position that may be attended to, zero marks padding.

        Returns:
            Tensor of shape ``(B, T, C)``.
        """
        B, T, C = x.shape
        qkv = self.qkv(x).reshape(B, T, 3, self.n_heads, self.head_dim)
        # Each of q, k, v becomes (B, n_heads, T, head_dim).
        q, k, v = (t.transpose(1, 2) for t in qkv.unbind(dim=2))

        scores = (q @ k.transpose(-2, -1)) / self.scale

        # True where attention is NOT allowed. Start with the strict upper
        # triangle (future positions) ...
        blocked = torch.ones(
            T, T, dtype=torch.bool, device=x.device
        ).triu(diagonal=1)
        if mask is not None:
            # ... and add padded keys, broadcast to (B, 1, T, T). A boolean
            # fill is used because multiplying a 0/1 mask by -inf yields NaN
            # (0 * -inf) and would poison every output.
            blocked = blocked | (mask == 0)[:, None, None, :]

        # The most negative finite value gives exactly-zero weights after
        # softmax yet keeps a fully masked row finite instead of NaN.
        scores = scores.masked_fill(blocked, torch.finfo(scores.dtype).min)

        weights = self.dropout(F.softmax(scores, dim=-1))
        out = (weights @ v).transpose(1, 2).reshape(B, T, C)
        return self.out_proj(out)


class _PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding followed by dropout.

    The table has shape ``(1, max_len, d_model)`` and is registered as the
    buffer ``pe``. Even channels hold sines and odd channels hold cosines.
    """

    def __init__(self, d_model, max_len, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(max_len).unsqueeze(1).float()
        div = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(pos * div)
        # With an odd d_model there is one fewer odd channel than even
        # channels, so trim the frequencies to fit.
        pe[:, 1::2] = torch.cos(pos * div[: d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Add the encodings for the first ``x.size(1)`` positions.

        Raises:
            ValueError: If the sequence is longer than the encoding table.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(1)}"
            )
        return self.dropout(x + self.pe[:, :seq_len])