# Hugging Face page header captured with the file (not part of the code):
#   jasonfan's picture
#   Upload folder using huggingface_hub
#   09dd617 verified
"""
Baseline models for comparison:
1. FullTransformerModel — standard softmax attention in every layer (upper bound)
2. PureLinearModel — linear attention in every layer (lower bound)
3. UniformHybridModel — every Nth layer is full attention (Jamba-style)
4. DPATransformer — our method (decision point routing)
"""
from .dpa_model import DPATransformer, LinearAttention, FullAttention
import torch
import torch.nn as nn
def build_model(model_type, **kwargs):
    """Construct one of the comparison models by name.

    A shared default configuration (vocab_size=32000, hidden_size=512,
    num_layers=6, num_heads=8, max_seq_len=2048) is overlaid with ``kwargs``
    and forwarded to the selected model class.

    Args:
        model_type: One of ``"full_transformer"``, ``"pure_linear"``,
            ``"uniform_hybrid"``, ``"dpa"`` (DPATransformer with
            ``router_type="learned"``) or ``"dpa_fixed"`` (DPATransformer
            with ``router_type="fixed"``).
        **kwargs: Overrides / additions to the default configuration.

    Returns:
        The freshly constructed model.

    Raises:
        ValueError: If ``model_type`` is not one of the names above.
    """
    settings = {
        "vocab_size": 32000,
        "hidden_size": 512,
        "num_layers": 6,
        "num_heads": 8,
        "max_seq_len": 2048,
        **kwargs,
    }

    # Builders are lazy so only the selected class is ever looked up, and the
    # table is ordered so names are compared in a fixed sequence.
    builders = (
        ("full_transformer", lambda: FullTransformerModel(**settings)),
        ("pure_linear", lambda: PureLinearModel(**settings)),
        ("uniform_hybrid", lambda: UniformHybridModel(**settings)),
        ("dpa", lambda: DPATransformer(router_type="learned", **settings)),
        ("dpa_fixed", lambda: DPATransformer(router_type="fixed", **settings)),
    )
    for name, make in builders:
        if model_type == name:
            return make()

    raise ValueError(f"Unknown model type: {model_type}")
class FullTransformerModel(nn.Module):
    """Baseline language model in which every layer uses full softmax attention.

    Token embeddings and learned positional embeddings are summed, passed
    through ``num_layers`` pre-norm residual attention layers, normalised once
    more, and projected to vocabulary logits by a bias-free linear head.

    Args:
        vocab_size: Size of the token embedding table and of the logit axis.
        hidden_size: Width of the embeddings and of every layer.
        num_layers: Number of attention layers to stack.
        num_heads: Head count passed to each ``FullAttention``.
        max_seq_len: Size of the positional table; longest accepted sequence.
        **kw: Extra keyword arguments are accepted and ignored, so every
            baseline can be built from the same configuration dict.
    """

    def __init__(self, vocab_size, hidden_size, num_layers, num_heads, max_seq_len, **kw):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.pos_embedding = nn.Embedding(max_seq_len, hidden_size)
        self.layers = nn.ModuleList([
            nn.ModuleDict({
                "norm": nn.LayerNorm(hidden_size),
                "attn": FullAttention(hidden_size, num_heads),
            })
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(hidden_size)
        self.output = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run the model and optionally compute a cross-entropy loss.

        Args:
            input_ids: 2-D tensor of token ids, ``(batch, seq_len)``.
            attention_mask: Passed unchanged as the second positional
                argument of every attention layer.
            labels: Optional targets with one entry per logit position.
                They are compared position-for-position with the logits (no
                shift is applied here); entries equal to -100 are ignored.

        Returns:
            dict with ``"loss"`` (scalar tensor, or ``None`` without labels),
            ``"logits"`` and ``"avg_decision_ratio"`` (always ``1.0``).

        Raises:
            ValueError: If ``seq_len`` exceeds the positional table size.
        """
        _, seq_len = input_ids.shape

        # Fail with a readable message instead of an out-of-range embedding
        # lookup (which surfaces as a device-side assert on CUDA).
        max_len = self.pos_embedding.num_embeddings
        if seq_len > max_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_seq_len {max_len}")

        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        x = self.embedding(input_ids) + self.pos_embedding(positions)

        for layer in self.layers:
            # Pre-norm residual: x + attn(norm(x)).
            x = x + layer["attn"](layer["norm"](x), attention_mask)

        logits = self.output(self.norm(x))

        loss = None
        if labels is not None:
            # reshape (not view): labels are frequently a non-contiguous
            # slice such as tokens[:, 1:], which view(-1) rejects.
            loss = nn.functional.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
                ignore_index=-100,
            )
        return {"loss": loss, "logits": logits, "avg_decision_ratio": 1.0}
class PureLinearModel(nn.Module):
    """Baseline language model in which every layer uses linear attention.

    Token embeddings and learned positional embeddings are summed, passed
    through ``num_layers`` pre-norm residual ``LinearAttention`` layers,
    normalised once more, and projected to vocabulary logits by a bias-free
    linear head.

    Args:
        vocab_size: Size of the token embedding table and of the logit axis.
        hidden_size: Width of the embeddings and of every layer.
        num_layers: Number of attention layers to stack.
        num_heads: Head count passed to each ``LinearAttention``.
        max_seq_len: Size of the positional table; longest accepted sequence.
        **kw: Extra keyword arguments are accepted and ignored, so every
            baseline can be built from the same configuration dict.
    """

    def __init__(self, vocab_size, hidden_size, num_layers, num_heads, max_seq_len, **kw):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.pos_embedding = nn.Embedding(max_seq_len, hidden_size)
        self.layers = nn.ModuleList([
            nn.ModuleDict({
                "norm": nn.LayerNorm(hidden_size),
                "attn": LinearAttention(hidden_size, num_heads),
            })
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(hidden_size)
        self.output = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run the model and optionally compute a cross-entropy loss.

        Args:
            input_ids: 2-D tensor of token ids, ``(batch, seq_len)``.
            attention_mask: Passed unchanged as the second positional
                argument of every attention layer.
            labels: Optional targets with one entry per logit position.
                They are compared position-for-position with the logits (no
                shift is applied here); entries equal to -100 are ignored.

        Returns:
            dict with ``"loss"`` (scalar tensor, or ``None`` without labels),
            ``"logits"`` and ``"avg_decision_ratio"`` (always ``0.0``).

        Raises:
            ValueError: If ``seq_len`` exceeds the positional table size.
        """
        _, seq_len = input_ids.shape

        # Fail with a readable message instead of an out-of-range embedding
        # lookup (which surfaces as a device-side assert on CUDA).
        max_len = self.pos_embedding.num_embeddings
        if seq_len > max_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_seq_len {max_len}")

        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        x = self.embedding(input_ids) + self.pos_embedding(positions)

        for layer in self.layers:
            # Pre-norm residual: x + attn(norm(x)).
            x = x + layer["attn"](layer["norm"](x), attention_mask)

        logits = self.output(self.norm(x))

        loss = None
        if labels is not None:
            # reshape (not view): labels are frequently a non-contiguous
            # slice such as tokens[:, 1:], which view(-1) rejects.
            loss = nn.functional.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
                ignore_index=-100,
            )
        return {"loss": loss, "logits": logits, "avg_decision_ratio": 0.0}
class UniformHybridModel(nn.Module):
    """Hybrid baseline: every Nth layer is full attention, the rest linear.

    Layer ``i`` uses ``FullAttention`` when ``i % full_attn_every == 0`` (so
    layer 0 is always full) and ``LinearAttention`` otherwise. Around the
    layers the model sums token and learned positional embeddings, applies a
    final LayerNorm, and projects to vocabulary logits with a bias-free head.

    Args:
        vocab_size: Size of the token embedding table and of the logit axis.
        hidden_size: Width of the embeddings and of every layer.
        num_layers: Number of attention layers to stack.
        num_heads: Head count passed to each attention module.
        max_seq_len: Size of the positional table; longest accepted sequence.
        full_attn_every: Spacing N between full-attention layers (positive).
        **kw: Extra keyword arguments are accepted and ignored, so every
            baseline can be built from the same configuration dict.

    Raises:
        ValueError: If ``full_attn_every`` is not positive.
    """

    def __init__(self, vocab_size, hidden_size, num_layers, num_heads, max_seq_len,
                 full_attn_every=4, **kw):
        # Zero would divide by zero in the layer-selection modulo and a
        # negative spacing would yield a negative ratio; reject both early.
        if full_attn_every <= 0:
            raise ValueError(
                f"full_attn_every must be positive, got {full_attn_every}")
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.pos_embedding = nn.Embedding(max_seq_len, hidden_size)
        self.full_attn_every = full_attn_every
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            use_full = (i % full_attn_every == 0)
            if use_full:
                attn = FullAttention(hidden_size, num_heads)
            else:
                attn = LinearAttention(hidden_size, num_heads)
            self.layers.append(nn.ModuleDict({
                "norm": nn.LayerNorm(hidden_size),
                "attn": attn,
                # Parameter-free marker; note it is attached to EVERY layer,
                # so its presence does not distinguish full from linear ones.
                "is_full": nn.Identity(),
            }))
        self.norm = nn.LayerNorm(hidden_size)
        self.output = nn.Linear(hidden_size, vocab_size, bias=False)
        # Report the share of layers that really are full attention; the
        # nominal 1/N is wrong whenever num_layers is not a multiple of N
        # (e.g. 6 layers, N=4 -> layers 0 and 4 are full -> 2/6, not 1/4).
        self._ratio = self._full_layer_ratio(num_layers, full_attn_every)

    @staticmethod
    def _full_layer_ratio(num_layers, full_attn_every):
        """Return the fraction of ``num_layers`` layers built as full attention."""
        if num_layers <= 0:
            return 0.0
        num_full = sum(1 for i in range(num_layers) if i % full_attn_every == 0)
        return num_full / num_layers

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Run the model and optionally compute a cross-entropy loss.

        Args:
            input_ids: 2-D tensor of token ids, ``(batch, seq_len)``.
            attention_mask: Passed unchanged as the second positional
                argument of every attention layer.
            labels: Optional targets with one entry per logit position.
                They are compared position-for-position with the logits (no
                shift is applied here); entries equal to -100 are ignored.

        Returns:
            dict with ``"loss"`` (scalar tensor, or ``None`` without labels),
            ``"logits"`` and ``"avg_decision_ratio"`` (fraction of layers
            that use full attention).

        Raises:
            ValueError: If ``seq_len`` exceeds the positional table size.
        """
        _, seq_len = input_ids.shape

        # Fail with a readable message instead of an out-of-range embedding
        # lookup (which surfaces as a device-side assert on CUDA).
        max_len = self.pos_embedding.num_embeddings
        if seq_len > max_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_seq_len {max_len}")

        positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
        x = self.embedding(input_ids) + self.pos_embedding(positions)

        for layer in self.layers:
            # Pre-norm residual: x + attn(norm(x)).
            x = x + layer["attn"](layer["norm"](x), attention_mask)

        logits = self.output(self.norm(x))

        loss = None
        if labels is not None:
            # reshape (not view): labels are frequently a non-contiguous
            # slice such as tokens[:, 1:], which view(-1) rejects.
            loss = nn.functional.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
                ignore_index=-100,
            )
        return {"loss": loss, "logits": logits, "avg_decision_ratio": self._ratio}