"""
Baseline models for comparison:
1. FullTransformer — standard softmax attention (upper bound)
2. PureLinearAttention — all linear attention (lower bound)
3. UniformHybrid — every Nth layer is full attention (Jamba-style)
4. DPA — our method (decision point routing)
"""
import torch
import torch.nn as nn

from .dpa_model import DPATransformer, LinearAttention, FullAttention


def build_model(model_type, **kwargs):
    """Factory function to build different model variants.

    Args:
        model_type: One of ``"full_transformer"``, ``"pure_linear"``,
            ``"uniform_hybrid"``, ``"dpa"`` (learned router) or
            ``"dpa_fixed"`` (fixed router).
        **kwargs: Constructor arguments. They override the defaults
            (vocab_size=32000, hidden_size=512, num_layers=6, num_heads=8,
            max_seq_len=2048) and are forwarded to the chosen model class.

    Returns:
        The constructed ``nn.Module``.

    Raises:
        ValueError: If ``model_type`` is not one of the names above.
    """
    defaults = dict(
        vocab_size=32000,
        hidden_size=512,
        num_layers=6,
        num_heads=8,
        max_seq_len=2048,
    )
    defaults.update(kwargs)
    if model_type == "full_transformer":
        return FullTransformerModel(**defaults)
    if model_type == "pure_linear":
        return PureLinearModel(**defaults)
    if model_type == "uniform_hybrid":
        return UniformHybridModel(**defaults)
    if model_type == "dpa":
        return DPATransformer(router_type="learned", **defaults)
    if model_type == "dpa_fixed":
        return DPATransformer(router_type="fixed", **defaults)
    raise ValueError(f"Unknown model type: {model_type}")


class _BaselineLM(nn.Module):
    """Shared scaffolding for the non-DPA baseline language models.

    Architecture: token embedding + learned absolute position embedding,
    a stack of pre-norm residual attention layers, a final LayerNorm and a
    bias-free output projection. Subclasses only choose which attention
    module sits at each depth and the ``avg_decision_ratio`` they report.
    """

    # Fraction of layers using full softmax attention, returned under the
    # ``avg_decision_ratio`` key (1.0 = all full, 0.0 = all linear).
    _ratio = 0.0

    def __init__(self, vocab_size, hidden_size, num_layers, max_seq_len,
                 make_attn, add_marker=False):
        """Build embeddings, the attention stack and the LM head.

        Args:
            vocab_size: Number of token ids (embedding rows / output size).
            hidden_size: Model width.
            num_layers: Number of attention layers in the stack.
            max_seq_len: Number of learned position embeddings; longer
                inputs are rejected in ``forward``.
            make_attn: Callable ``layer_index -> nn.Module`` that builds the
                attention sub-module for each layer. It is called once per
                layer, in order, and is not stored on the model.
            add_marker: If True, every layer also gets an ``"is_full"``
                ``nn.Identity`` entry. It holds no parameters and is kept
                so UniformHybridModel's module layout stays unchanged.
        """
        super().__init__()
        # Keep the creation order stable: embeddings -> layers -> norm -> head.
        # It fixes the parameter order and which random draws each module's
        # initialisation consumes under a fixed seed.
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.pos_embedding = nn.Embedding(max_seq_len, hidden_size)
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            block = {"norm": nn.LayerNorm(hidden_size), "attn": make_attn(i)}
            if add_marker:
                block["is_full"] = nn.Identity()
            self.layers.append(nn.ModuleDict(block))
        self.norm = nn.LayerNorm(hidden_size)
        self.output = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run the model and optionally compute the token-level loss.

        Args:
            input_ids: Integer token ids, unpacked as ``(batch, seq_len)``.
            attention_mask: Passed unchanged to every attention module.
            labels: Optional targets compared position-for-position with
                the logits. No next-token shift is applied here; any shift
                must be done by the caller. Entries equal to -100 are
                ignored.

        Returns:
            dict with ``"loss"`` (None when ``labels`` is None),
            ``"logits"`` of shape ``(batch, seq_len, vocab_size)`` and
            ``"avg_decision_ratio"``.

        Raises:
            ValueError: If ``seq_len`` exceeds the number of position
                embeddings (``max_seq_len``).
        """
        _, seq_len = input_ids.shape
        max_len = self.pos_embedding.num_embeddings
        if seq_len > max_len:
            # Without this check the position lookup fails with an opaque
            # IndexError (or a device-side assert on GPU).
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_seq_len={max_len}")
        pos = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        x = self.embedding(input_ids) + self.pos_embedding(pos)
        for layer in self.layers:
            # Pre-norm residual: x + attn(norm(x)).
            x = x + layer["attn"](layer["norm"](x), attention_mask)
        logits = self.output(self.norm(x))
        loss = None
        if labels is not None:
            # reshape (not view) so non-contiguous labels, such as a sliced
            # ids[:, 1:], are accepted.
            loss = nn.functional.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
                ignore_index=-100,
            )
        return {"loss": loss, "logits": logits, "avg_decision_ratio": self._ratio}


class FullTransformerModel(_BaselineLM):
    """All layers use full softmax attention."""

    _ratio = 1.0

    def __init__(self, vocab_size, hidden_size, num_layers, num_heads,
                 max_seq_len, **kw):
        # Extra keyword arguments are accepted and ignored so that a shared
        # kwargs dict (e.g. from build_model) can be passed to every baseline.
        super().__init__(
            vocab_size, hidden_size, num_layers, max_seq_len,
            make_attn=lambda _i: FullAttention(hidden_size, num_heads),
        )


class PureLinearModel(_BaselineLM):
    """All layers use linear attention only."""

    _ratio = 0.0

    def __init__(self, vocab_size, hidden_size, num_layers, num_heads,
                 max_seq_len, **kw):
        # Extra keyword arguments are accepted and ignored (see FullTransformerModel).
        super().__init__(
            vocab_size, hidden_size, num_layers, max_seq_len,
            make_attn=lambda _i: LinearAttention(hidden_size, num_heads),
        )


class UniformHybridModel(_BaselineLM):
    """Every Nth layer uses full attention, rest use linear (Jamba-style).

    Layers 0, N, 2N, ... (N = ``full_attn_every``) use full softmax
    attention; all others use linear attention.
    """

    def __init__(self, vocab_size, hidden_size, num_layers, num_heads,
                 max_seq_len, full_attn_every=4, **kw):
        """
        Raises:
            ValueError: If ``full_attn_every`` is smaller than 1.
        """
        if full_attn_every < 1:
            raise ValueError(
                f"full_attn_every must be >= 1, got {full_attn_every}")
        use_full = [i % full_attn_every == 0 for i in range(num_layers)]
        super().__init__(
            vocab_size, hidden_size, num_layers, max_seq_len,
            make_attn=lambda i: (
                FullAttention(hidden_size, num_heads) if use_full[i]
                else LinearAttention(hidden_size, num_heads)
            ),
            add_marker=True,
        )
        self.full_attn_every = full_attn_every
        # Report the actual fraction of full-attention layers. That is
        # ceil(num_layers / N) / num_layers, not 1 / N: e.g. 6 layers with
        # N=4 gives full attention at layers 0 and 4, i.e. 1/3.
        self._ratio = sum(use_full) / num_layers if num_layers else 0.0