| """ |
| Baseline models for comparison: |
| 1. FullTransformer — standard softmax attention (upper bound) |
| 2. PureLinearAttention — all linear attention (lower bound) |
| 3. UniformHybrid — every Nth layer is full attention (Jamba-style) |
| 4. DPA — our method (decision point routing) |
| """ |
|
|
| from .dpa_model import DPATransformer, LinearAttention, FullAttention |
| import torch |
| import torch.nn as nn |
|
|
|
|
def build_model(model_type, **kwargs):
    """Build one of the comparison models from its string identifier.

    Args:
        model_type: One of ``"full_transformer"``, ``"pure_linear"``,
            ``"uniform_hybrid"``, ``"dpa"`` (``router_type="learned"``) or
            ``"dpa_fixed"`` (``router_type="fixed"``).
        **kwargs: Overrides for the default size settings below. The merged
            settings are forwarded as keyword arguments to the selected
            constructor, so extra keys reach the model unchanged.

    Returns:
        The freshly constructed model.

    Raises:
        ValueError: If ``model_type`` is not one of the names above.
    """
    config = {
        "vocab_size": 32000,
        "hidden_size": 512,
        "num_layers": 6,
        "num_heads": 8,
        "max_seq_len": 2048,
        **kwargs,  # caller-supplied values win over the defaults
    }

    # Constructors are wrapped in lambdas so that only the selected model is
    # ever instantiated; the tuple order mirrors the lookup order.
    builders = (
        ("full_transformer", lambda: FullTransformerModel(**config)),
        ("pure_linear", lambda: PureLinearModel(**config)),
        ("uniform_hybrid", lambda: UniformHybridModel(**config)),
        ("dpa", lambda: DPATransformer(router_type="learned", **config)),
        ("dpa_fixed", lambda: DPATransformer(router_type="fixed", **config)),
    )
    for name, build in builders:
        if model_type == name:
            return build()

    raise ValueError(f"Unknown model type: {model_type}")
|
|
|
|
class FullTransformerModel(nn.Module):
    """Baseline in which every layer applies ``FullAttention``.

    The network is: token embedding + learned position embedding, then
    ``num_layers`` residual blocks of the form ``x + attn(LayerNorm(x))``,
    a final LayerNorm, and a bias-free linear projection to the vocabulary.
    """

    def __init__(self, vocab_size, hidden_size, num_layers, num_heads, max_seq_len, **kw):
        """Create the embeddings, attention blocks and output head.

        Args:
            vocab_size: Number of rows in the token embedding and size of the
                output projection.
            hidden_size: Width of embeddings, norms and attention modules.
            num_layers: Number of attention blocks.
            num_heads: Passed to each ``FullAttention``.
            max_seq_len: Number of rows in the position embedding table.
            **kw: Accepted and ignored, so a shared settings dict can be
                splatted into any of the baseline constructors.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.pos_embedding = nn.Embedding(max_seq_len, hidden_size)

        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            block = nn.ModuleDict()
            block["norm"] = nn.LayerNorm(hidden_size)
            block["attn"] = FullAttention(hidden_size, num_heads)
            self.layers.append(block)

        self.norm = nn.LayerNorm(hidden_size)
        self.output = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run the model and optionally compute a cross-entropy loss.

        Args:
            input_ids: 2-D tensor of token ids (its two dimensions are
                unpacked; the second is used as the sequence length).
            attention_mask: Forwarded unchanged to every attention module.
            labels: Optional targets. They are flattened and compared with
                the flattened logits position by position (no shifting is
                done here); entries equal to ``-100`` are ignored.

        Returns:
            Dict with ``"loss"`` (``None`` when ``labels`` is ``None``),
            ``"logits"`` and ``"avg_decision_ratio"`` (always ``1.0``).
        """
        _, seq_len = input_ids.shape
        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        hidden = self.embedding(input_ids) + self.pos_embedding(positions)

        # Pre-norm residual blocks: the un-normalised stream is carried
        # forward and the attention output is added to it.
        for block in self.layers:
            normed = block["norm"](hidden)
            hidden = hidden + block["attn"](normed, attention_mask)

        logits = self.output(self.norm(hidden))

        if labels is None:
            loss = None
        else:
            vocab = logits.size(-1)
            loss = nn.functional.cross_entropy(
                logits.view(-1, vocab), labels.view(-1), ignore_index=-100)

        return {"loss": loss, "logits": logits, "avg_decision_ratio": 1.0}
|
|
|
|
class PureLinearModel(nn.Module):
    """Baseline in which every layer applies ``LinearAttention``.

    Structure: token embedding + learned position embedding, ``num_layers``
    residual blocks computing ``x + attn(LayerNorm(x))``, a final LayerNorm
    and a bias-free linear projection to the vocabulary.
    """

    def __init__(self, vocab_size, hidden_size, num_layers, num_heads, max_seq_len, **kw):
        """Create the embeddings, linear-attention blocks and output head.

        Args:
            vocab_size: Number of rows in the token embedding and size of the
                output projection.
            hidden_size: Width of embeddings, norms and attention modules.
            num_layers: Number of attention blocks.
            num_heads: Passed to each ``LinearAttention``.
            max_seq_len: Number of rows in the position embedding table.
            **kw: Accepted and ignored, so a shared settings dict can be
                splatted into any of the baseline constructors.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.pos_embedding = nn.Embedding(max_seq_len, hidden_size)

        blocks = []
        for _ in range(num_layers):
            block = nn.ModuleDict()
            block["norm"] = nn.LayerNorm(hidden_size)
            block["attn"] = LinearAttention(hidden_size, num_heads)
            blocks.append(block)
        self.layers = nn.ModuleList(blocks)

        self.norm = nn.LayerNorm(hidden_size)
        self.output = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run the model and optionally compute a cross-entropy loss.

        Args:
            input_ids: 2-D tensor of token ids (its two dimensions are
                unpacked; the second is used as the sequence length).
            attention_mask: Forwarded unchanged to every attention module.
            labels: Optional targets. They are flattened and compared with
                the flattened logits position by position (no shifting is
                done here); entries equal to ``-100`` are ignored.

        Returns:
            Dict with ``"loss"`` (``None`` when ``labels`` is ``None``),
            ``"logits"`` and ``"avg_decision_ratio"`` (always ``0.0``).
        """
        seq_len = input_ids.shape[1]
        _batch, _ = input_ids.shape  # keeps the 2-D requirement of the input explicit
        position_ids = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        token_vectors = self.embedding(input_ids)
        stream = token_vectors + self.pos_embedding(position_ids)

        # Pre-norm residual blocks: normalise, attend, add back to the stream.
        for block in self.layers:
            attended = block["attn"](block["norm"](stream), attention_mask)
            stream = stream + attended

        logits = self.output(self.norm(stream))

        result = {"loss": None, "logits": logits, "avg_decision_ratio": 0.0}
        if labels is not None:
            flat_logits = logits.view(-1, logits.size(-1))
            flat_labels = labels.view(-1)
            result["loss"] = nn.functional.cross_entropy(
                flat_logits, flat_labels, ignore_index=-100)
        return result
|
|
|
|
class UniformHybridModel(nn.Module):
    """Hybrid baseline with full attention at a fixed layer interval (Jamba-style).

    Layer ``i`` uses ``FullAttention`` when ``i % full_attn_every == 0`` (so
    layer 0 is always a full-attention layer) and ``LinearAttention``
    otherwise. Every layer is a pre-LayerNorm residual attention block.
    """

    def __init__(self, vocab_size, hidden_size, num_layers, num_heads, max_seq_len,
                 full_attn_every=4, **kw):
        """Create the embeddings, mixed attention blocks and output head.

        Args:
            vocab_size: Number of rows in the token embedding and size of the
                output projection.
            hidden_size: Width of embeddings, norms and attention modules.
            num_layers: Number of attention blocks.
            num_heads: Passed to each attention module.
            max_seq_len: Number of rows in the position embedding table.
            full_attn_every: Interval between full-attention layers; must be
                at least 1.
            **kw: Accepted and ignored, so a shared settings dict can be
                splatted into any of the baseline constructors.

        Raises:
            ValueError: If ``full_attn_every`` is smaller than 1.
        """
        # Reject a zero/negative interval up front: zero would otherwise fail
        # with a ZeroDivisionError in the modulo below, and a negative value
        # would yield a negative reported ratio.
        if full_attn_every < 1:
            raise ValueError(
                f"full_attn_every must be >= 1, got {full_attn_every}")

        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.pos_embedding = nn.Embedding(max_seq_len, hidden_size)
        self.full_attn_every = full_attn_every

        self.layers = nn.ModuleList()
        for i in range(num_layers):
            use_full = (i % full_attn_every == 0)
            if use_full:
                attn = FullAttention(hidden_size, num_heads)
            else:
                attn = LinearAttention(hidden_size, num_heads)
            self.layers.append(nn.ModuleDict({
                "norm": nn.LayerNorm(hidden_size),
                "attn": attn,
                # Inert placeholder kept so the module tree stays unchanged;
                # it does not encode whether the layer is full.
                "is_full": nn.Identity(),
            }))

        self.norm = nn.LayerNorm(hidden_size)
        self.output = nn.Linear(hidden_size, vocab_size, bias=False)
        # Report the fraction of layers that really are full attention.
        # 1 / full_attn_every is only correct when num_layers is a multiple
        # of the interval (e.g. 6 layers, interval 4 -> layers 0 and 4 -> 2/6).
        self._ratio = self._full_attention_ratio(num_layers, full_attn_every)

    @staticmethod
    def _full_attention_ratio(num_layers, full_attn_every):
        """Return the fraction of layers in ``range(num_layers)`` whose index
        is a multiple of ``full_attn_every``; ``0.0`` for an empty stack."""
        if num_layers <= 0:
            return 0.0
        # Multiples of the interval in [0, num_layers) = ceil(num_layers / interval).
        num_full = (num_layers + full_attn_every - 1) // full_attn_every
        return num_full / num_layers

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run the model and optionally compute a cross-entropy loss.

        Args:
            input_ids: 2-D tensor of token ids (its two dimensions are
                unpacked; the second is used as the sequence length).
            attention_mask: Forwarded unchanged to every attention module.
            labels: Optional targets. They are flattened and compared with
                the flattened logits position by position (no shifting is
                done here); entries equal to ``-100`` are ignored.

        Returns:
            Dict with ``"loss"`` (``None`` when ``labels`` is ``None``),
            ``"logits"`` and ``"avg_decision_ratio"`` (the fraction of layers
            built with full attention).
        """
        _, seq_len = input_ids.shape
        pos = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        x = self.embedding(input_ids) + self.pos_embedding(pos)

        for layer in self.layers:
            residual = x
            x = layer["norm"](x)
            x = residual + layer["attn"](x, attention_mask)

        x = self.norm(x)
        logits = self.output(x)
        loss = None
        if labels is not None:
            loss = nn.functional.cross_entropy(
                logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-100)
        return {"loss": loss, "logits": logits, "avg_decision_ratio": self._ratio}
|
|