MicroMixer-2 / src /model.py
llaa33219's picture
Upload 6 files
c047c1e verified
Raw
History Blame Contribute Delete
17.8 kB
"""MicroMixer-2 V4: MLP-Mixer architecture optimized for language models.
V4 innovations based on research:
- DropPath (stochastic depth): Regularization via random residual skipping
- FourierMixing: Parameter-free FFT token mixing (FNet-inspired)
- Padding-aware loss: Ignore padding tokens in cross-entropy
- Label smoothing: Regularize overconfident predictions
- Increased depth: 6-12 layers for larger models
- HyperMixing (ACL 2023): O(S) token mixing via hypernetwork
- RoPE: Rotary position embedding for length generalization
- Standard MLP: Better knowledge capacity than GatedMLP
"""
import math
from dataclasses import dataclass, field
from enum import Enum, auto
import torch
import torch.nn as nn
import torch.nn.functional as F
class TokenMixerType(Enum):
    """Selects which token-mixing module a mixer layer instantiates."""

    # Values are written out explicitly; they are the same integers that
    # ``auto()`` assigns to the first and second member of an Enum.
    HYPER = 1  # HyperMixing: mixing weights produced by a hypernetwork
    FOURIER = 2  # FourierMixing: Fourier-transform-based token mixing
class DropPath(nn.Module):
    """Stochastic depth (DropPath), applied independently per sample.

    During training each sample's whole branch output is zeroed with
    probability ``drop_prob``; surviving samples are scaled by
    ``1 / (1 - drop_prob)`` so the expected value is unchanged. In eval mode,
    or when ``drop_prob`` is 0, the input is returned untouched.

    Args:
        drop_prob: Probability of dropping a sample's branch, in ``[0, 1]``.

    Raises:
        ValueError: If ``drop_prob`` lies outside ``[0, 1]``.
    """

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        if not 0.0 <= drop_prob <= 1.0:
            raise ValueError(
                f"drop_prob must be in [0, 1], got {drop_prob}"
            )
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with whole samples randomly zeroed (training only)."""
        if not self.training or self.drop_prob == 0.0:
            return x
        keep_prob = 1.0 - self.drop_prob
        if keep_prob <= 0.0:
            # Every sample is dropped. Dividing by keep_prob == 0 would give
            # inf * 0 = NaN, so return the exact result instead.
            return torch.zeros_like(x)
        # One Bernoulli draw per sample, broadcast over all remaining dims.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        noise = torch.rand(mask_shape, dtype=x.dtype, device=x.device)
        keep_mask = torch.floor(noise + keep_prob)  # 1 with prob keep_prob
        return x / keep_prob * keep_mask
class MlpBlock(nn.Module):
    """Two-layer feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        in_dim: Size of the last input dimension; the output has the same size.
        hidden_dim: Width of the intermediate projection.
        dropout: Dropout probability used after the activation and after the
            output projection (one shared ``nn.Dropout`` module).
    """

    def __init__(self, in_dim: int, hidden_dim: int, dropout: float = 0.1):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, in_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Expand to ``hidden_dim``, apply GELU, and project back to ``in_dim``."""
        hidden = self.dropout(F.gelu(self.fc1(x)))
        return self.dropout(self.fc2(hidden))
class RotaryPositionEmbedding(nn.Module):
    """Rotary position embedding (RoPE) applied to the leading ``dim`` channels.

    Channels are treated as interleaved pairs ``(2i, 2i + 1)``. At position
    ``p`` pair ``i`` is rotated by the angle ``p * inv_freq[i]``, where
    ``inv_freq[i] = 10000 ** (-2i / dim)``. Channels beyond ``dim`` are passed
    through unchanged.

    Both channels of a pair must share one angle for the operation to be a
    rotation, so the angle table is laid out interleaved
    (``[f0, f0, f1, f1, ...]``) to match the pairing used below.

    Args:
        dim: Number of leading channels to rotate; must be even.
        max_seq_len: Accepted for interface compatibility; this module does
            not use it (angles are computed on the fly for any ``seq_len``).

    Raises:
        ValueError: If ``dim`` is odd.
    """

    def __init__(self, dim: int, max_seq_len: int = 512):
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(f"RoPE dim must be even, got {dim}")
        self.dim = dim
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x: torch.Tensor, seq_len: int) -> torch.Tensor:
        """Rotate ``x`` by position.

        Args:
            x: Tensor whose last dimension has at least ``dim`` channels and
                whose second-to-last dimension has length ``seq_len``.
            seq_len: Number of positions to generate angles for.

        Returns:
            Tensor of the same shape as ``x``.
        """
        positions = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
        freqs = torch.einsum("i,j->ij", positions, self.inv_freq)  # (seq_len, dim / 2)
        # Interleave so channels 2i and 2i + 1 share the angle of pair i.
        angles = freqs.repeat_interleave(2, dim=-1)  # (seq_len, dim)
        cos = angles.cos().unsqueeze(0)
        sin = angles.sin().unsqueeze(0)

        x_rot = x[..., : self.dim]
        x_even, x_odd = x_rot[..., ::2], x_rot[..., 1::2]
        # For a pair (a, b) this yields (-b, a), i.e. the pair rotated by 90 degrees.
        rotated = torch.stack([-x_odd, x_even], dim=-1).flatten(-2)
        x_rotated = x_rot * cos + rotated * sin

        if x.shape[-1] > self.dim:
            return torch.cat([x_rotated, x[..., self.dim :]], dim=-1)
        return x_rotated
class HyperMixing(nn.Module):
    """Token mixing driven by a hypernetwork over a running (causal) mean.

    For each position ``t`` the mean of the inputs at positions ``0..t`` is fed
    to a small MLP that emits a per-channel scale and shift, and the output is
    ``x[t] * scale[t] + shift[t]``. Because only the cumulative mean is used,
    position ``t`` never depends on later positions.

    Args:
        hidden_dim: Channel dimension of the input and output.
        hyper_hidden_dim: Hidden width of the hypernetwork MLP.
        dropout: Dropout probability applied to the mixed output.
    """

    def __init__(self, hidden_dim: int, hyper_hidden_dim: int, dropout: float = 0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.hyper = nn.Sequential(
            nn.Linear(hidden_dim, hyper_hidden_dim),
            nn.GELU(),
            nn.Linear(hyper_hidden_dim, hidden_dim * 2),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Mix a ``(batch, seq, hidden)`` tensor; the output has the same shape."""
        _, seq_len, _ = x.shape
        # Running mean over the prefix ending at each position.
        cumsum = torch.cumsum(x, dim=1)
        counts = torch.arange(1, seq_len + 1, device=x.device).view(1, seq_len, 1).float()
        # The float32 ``counts`` promotes half/bfloat16 inputs to float32; cast
        # back so the hypernetwork's Linear layers see their own dtype.
        pooled = (cumsum / counts).to(x.dtype)
        # First half of the hypernetwork output is the scale, second half the shift.
        scale, shift = self.hyper(pooled).chunk(2, dim=-1)
        return self.dropout(x * scale + shift)
class FourierMixing(nn.Module):
    """Causal Fourier token mixing (FNet-inspired) with no learned mixing weights.

    A plain FFT over the whole sequence mixes every position with every other
    one, which leaks future tokens into each output and breaks autoregressive
    use. This module instead gives position ``t`` the value that the
    real-part DFT of the *visible prefix* ``x[0..t]`` has at index ``t``:

        y[t] = sum_{n <= t} x[n] * cos(2 * pi * n * t / (t + 1))
             = sum_{n <= t} x[n] * cos(2 * pi * n / (t + 1))

    so ``y[t]`` depends only on positions ``<= t`` and does not change when
    the sequence is extended. For the last position of a sequence this equals
    ``torch.fft.fft(x, dim=1).real[:, -1]``.

    The result is blended with its own running mean through the single
    learnable scalar ``scale`` (initialised to 0.1).

    Cost: the cosine basis is an ``(S, S)`` matrix, so time is O(S^2 * H) and
    the extra memory is O(S^2) per call.

    Args:
        hidden_dim: Accepted for interface symmetry with the other mixer; the
            module has no parameter whose shape depends on it.
        dropout: Dropout probability applied to the mixed output.
    """

    def __init__(self, hidden_dim: int, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # Learnable blend weight between the Fourier term and its running mean.
        self.scale = nn.Parameter(torch.ones(1) * 0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Mix a ``(batch, seq, hidden)`` tensor; the output has the same shape."""
        _, seq_len, _ = x.shape
        # Build the basis in at least float32 so low-precision inputs do not
        # degrade the angles; keep float64 when the input already is.
        compute_dtype = x.dtype if x.dtype == torch.float64 else torch.float32
        positions = torch.arange(seq_len, device=x.device, dtype=compute_dtype)

        # basis[t, n] = cos(2*pi*n / (t + 1)) for n <= t, and 0 for n > t.
        # The reduced angle n / (t + 1) stays in [0, 1), which is numerically
        # safer than forming n * t / (t + 1) for long sequences.
        angles = (2.0 * math.pi) * positions.view(1, seq_len) / (
            positions.view(seq_len, 1) + 1.0
        )
        basis = torch.tril(torch.cos(angles)).to(x.dtype)
        x_fourier = torch.einsum("tn,bnh->bth", basis, x)

        # Running mean of the (already causal) Fourier term.
        counts = (positions + 1.0).view(1, seq_len, 1)
        x_mean = (torch.cumsum(x_fourier, dim=1) / counts).to(x.dtype)

        mixed = x_fourier * (1 - self.scale) + x_mean * self.scale
        return self.dropout(mixed)
class MicroMixerLayer(nn.Module):
    """One pre-norm mixer layer made of two residual sub-blocks.

    Sub-block 1: LayerNorm -> token mixer -> DropPath, added to the input.
    Sub-block 2: LayerNorm -> channel MLP -> DropPath, added to the result.

    Args:
        hidden_dim: Channel dimension of the layer.
        hyper_hidden_dim: Hidden width handed to ``HyperMixing`` (only used
            when ``mixer_type`` is ``HYPER``).
        channel_mlp_dim: Inner width of the channel-mixing MLP.
        dropout: Dropout probability for the token mixer and the MLP.
        drop_path: DropPath probability; values ``<= 0`` install an identity.
        mixer_type: Which token mixer to build.

    Raises:
        ValueError: If ``mixer_type`` is not a known ``TokenMixerType``.
    """

    def __init__(
        self,
        hidden_dim: int,
        hyper_hidden_dim: int,
        channel_mlp_dim: int,
        dropout: float = 0.1,
        drop_path: float = 0.0,
        mixer_type: TokenMixerType = TokenMixerType.HYPER,
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)

        # Reject unknown mixer types before building any mixer.
        if mixer_type != TokenMixerType.HYPER and mixer_type != TokenMixerType.FOURIER:
            raise ValueError(f"Unknown mixer type: {mixer_type}")
        if mixer_type == TokenMixerType.HYPER:
            token_mixer = HyperMixing(hidden_dim, hyper_hidden_dim, dropout)
        else:
            token_mixer = FourierMixing(hidden_dim, dropout)
        self.token_mixer = token_mixer

        self.channel_mlp = MlpBlock(hidden_dim, channel_mlp_dim, dropout)

        # The same DropPath module is shared by both residual branches.
        if drop_path > 0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply token mixing, then channel mixing, each as a residual update."""
        x = x + self.drop_path(self.token_mixer(self.norm1(x)))
        x = x + self.drop_path(self.channel_mlp(self.norm2(x)))
        return x
@dataclass
class MicroMixerConfig:
    """Hyperparameter bundle for a MicroMixer-2 V4 model.

    Every field has a default, so ``MicroMixerConfig()`` is itself a valid
    (small) configuration. Field order is part of the interface because the
    dataclass-generated ``__init__`` accepts the fields positionally.
    """

    # Number of entries in the token vocabulary (256 suits byte-level input).
    vocab_size: int = 256
    # Longest sequence, in tokens, the model is configured for.
    max_seq_len: int = 128
    # Width of the token embeddings and of every mixer layer.
    hidden_dim: int = 128
    # Hidden width of the HyperMixing hypernetwork.
    hyper_hidden_dim: int = 64
    # Inner width of the channel-mixing MLP.
    channel_mlp_dim: int = 256
    # How many mixer layers are stacked.
    num_layers: int = 2
    # Dropout probability.
    dropout: float = 0.1
    # DropPath (stochastic depth) probability; 0 disables it.
    drop_path: float = 0.0
    # Label-smoothing amount for the cross-entropy loss; 0 disables it.
    label_smoothing: float = 0.0
    # Whether the output projection shares its weight with the input embedding.
    tie_weights: bool = True
    # Token-mixing strategy: HYPER or FOURIER.
    mixer_type: TokenMixerType = TokenMixerType.HYPER
    # Token id treated as padding when computing the loss.
    pad_token_id: int = 0
class MicroMixer(nn.Module):
    """MicroMixer-2 V4: an MLP-Mixer style language model.

    Pipeline: token embedding -> rotary position encoding -> dropout ->
    ``config.num_layers`` mixer layers -> final LayerNorm -> vocabulary
    projection (sharing its weight with the embedding when
    ``config.tie_weights`` is set).

    Per-layer DropPath probabilities are linearly spaced from 0 up to
    ``config.drop_path``. The training loss is a cross-entropy with optional
    label smoothing that skips padding positions.

    Args:
        config: Hyperparameters for the model.
    """

    # Target value cross-entropy skips; used so ignored positions never
    # collide with a real vocabulary id.
    _IGNORE_INDEX = -100

    def __init__(self, config: MicroMixerConfig):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.max_seq_len = config.max_seq_len
        self.hidden_dim = config.hidden_dim
        self.pad_token_id = config.pad_token_id

        self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_dim)
        self.rope = RotaryPositionEmbedding(config.hidden_dim, config.max_seq_len)
        self.dropout = nn.Dropout(config.dropout)

        # DropPath: linear schedule (increases with depth).
        drop_rates = [
            rate.item()
            for rate in torch.linspace(0, config.drop_path, config.num_layers)
        ]
        self.mixer_layers = nn.ModuleList([
            MicroMixerLayer(
                config.hidden_dim,
                config.hyper_hidden_dim,
                config.channel_mlp_dim,
                config.dropout,
                drop_path=drop_rates[i],
                mixer_type=config.mixer_type,
            )
            for i in range(config.num_layers)
        ])

        self.layer_norm = nn.LayerNorm(config.hidden_dim)
        self.lm_head = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)
        if config.tie_weights:
            self.lm_head.weight = self.token_embedding.weight

        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and reset LayerNorm."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if getattr(module, "bias", None) is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.ones_(module.weight)
            torch.nn.init.zeros_(module.bias)

    def _compute_loss(
        self,
        logits: torch.Tensor,
        targets: torch.Tensor,
        attention_mask: torch.Tensor | None,
    ) -> torch.Tensor:
        """Cross-entropy over non-ignored positions.

        A position is ignored when its target equals ``pad_token_id`` or,
        if ``attention_mask`` is given, when the shifted mask is 0 there.
        Ignored targets are rewritten to ``_IGNORE_INDEX`` so that the mask
        is actually applied, whatever token id sits at that position.
        """
        logits_flat = logits.reshape(-1, self.vocab_size)
        # reshape, not view: after truncation ``targets`` is a non-contiguous
        # slice, which ``view(-1)`` rejects for batch sizes above 1.
        targets_flat = targets.reshape(-1)

        ignore = targets_flat == self.pad_token_id
        if attention_mask is not None:
            # Position 0 is always kept; position j > 0 is kept only when the
            # mask marks token j - 1 as real.
            # NOTE(review): this keeps the original shift direction; confirm it
            # matches how the training pipeline aligns inputs and targets.
            shifted_mask = torch.ones_like(attention_mask)
            shifted_mask[:, 1:] = attention_mask[:, :-1]
            ignore = ignore | (shifted_mask.reshape(-1) == 0)

        targets_flat = targets_flat.masked_fill(ignore, self._IGNORE_INDEX)
        return F.cross_entropy(
            logits_flat,
            targets_flat,
            ignore_index=self._IGNORE_INDEX,
            label_smoothing=self.config.label_smoothing,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        targets: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Compute next-token logits and, when ``targets`` is given, the loss.

        Inputs longer than ``max_seq_len`` are truncated to their last
        ``max_seq_len`` positions (``targets`` and ``attention_mask`` are
        truncated the same way).

        Args:
            input_ids: ``(batch, seq)`` token ids.
            targets: Optional ``(batch, seq)`` target ids aligned with the logits.
            attention_mask: Optional ``(batch, seq)`` mask, nonzero for real
                tokens. It is used only to exclude positions from the loss.

        Returns:
            ``logits`` of shape ``(batch, seq, vocab_size)`` when ``targets``
            is ``None``; otherwise the tuple ``(logits, loss)`` with a scalar loss.
        """
        _, seq_len = input_ids.shape
        if seq_len > self.max_seq_len:
            input_ids = input_ids[:, -self.max_seq_len :]
            seq_len = self.max_seq_len
            if targets is not None:
                targets = targets[:, -self.max_seq_len :]
            if attention_mask is not None:
                attention_mask = attention_mask[:, -self.max_seq_len :]

        token_emb = self.token_embedding(input_ids)
        x = self.rope(token_emb, seq_len)
        x = self.dropout(x)
        for layer in self.mixer_layers:
            x = layer(x)
        x = self.layer_norm(x)
        logits = self.lm_head(x)

        if targets is None:
            return logits
        loss = self._compute_loss(logits, targets, attention_mask)
        return logits, loss

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        top_k: int | None = None,
    ) -> torch.Tensor:
        """Autoregressive text generation.

        Switches the module to eval mode and leaves it there. Each step
        re-runs the full forward pass and appends one token per sequence.

        Args:
            input_ids: ``(batch, seq)`` prompt token ids.
            max_new_tokens: Number of tokens to append.
            temperature: ``0.0`` selects greedy argmax decoding; otherwise
                logits are divided by this value before sampling.
            top_k: If given, sample only among the ``top_k`` highest logits.

        Returns:
            ``(batch, seq + max_new_tokens)`` tensor: the prompt followed by
            the generated tokens.
        """
        self.eval()
        device = next(self.parameters()).device
        input_ids = input_ids.to(device)
        for _ in range(max_new_tokens):
            logits = self(input_ids)
            logits = logits[:, -1, :]  # only the last position predicts the next token
            if temperature == 0.0:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                if top_k is not None:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    # Anything below the k-th largest logit becomes impossible.
                    logits[logits < v[:, [-1]]] = -float("Inf")
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat([input_ids, next_token], dim=1)
        return input_ids
def count_parameters(model: nn.Module) -> int:
    """Return the number of scalar entries in parameters that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
def micromixer_100k() -> MicroMixerConfig:
    """Return the "100k" preset (nominally ~100K parameters) for quick tests.

    Three layers of width 84 over a 64-token context, with DropPath and label
    smoothing both disabled.
    """
    settings = dict(
        max_seq_len=64,
        hidden_dim=84,
        hyper_hidden_dim=48,
        channel_mlp_dim=128,
        num_layers=3,
        dropout=0.1,
        drop_path=0.0,
        label_smoothing=0.0,
    )
    return MicroMixerConfig(**settings)
def micromixer_300k() -> MicroMixerConfig:
    """Return the "300k" preset (nominally ~300K parameters) for small experiments.

    Four layers of width 128 over a 128-token context, with DropPath 0.05 and
    label smoothing 0.05.
    """
    settings = dict(
        max_seq_len=128,
        hidden_dim=128,
        hyper_hidden_dim=64,
        channel_mlp_dim=288,
        num_layers=4,
        dropout=0.1,
        drop_path=0.05,
        label_smoothing=0.05,
    )
    return MicroMixerConfig(**settings)
def micromixer_500k() -> MicroMixerConfig:
    """Return the "500k" preset (nominally ~500K parameters) for medium experiments.

    Four layers of width 176 over a 128-token context, with DropPath 0.1 and
    label smoothing 0.05.
    """
    settings = dict(
        max_seq_len=128,
        hidden_dim=176,
        hyper_hidden_dim=88,
        channel_mlp_dim=384,
        num_layers=4,
        dropout=0.1,
        drop_path=0.1,
        label_smoothing=0.05,
    )
    return MicroMixerConfig(**settings)
def micromixer_1m() -> MicroMixerConfig:
    """Return the "1M" preset (nominally ~1M parameters) for standard experiments.

    Five layers of width 168 over a 256-token context, with DropPath 0.1 and
    label smoothing 0.1.
    """
    settings = dict(
        max_seq_len=256,
        hidden_dim=168,
        hyper_hidden_dim=84,
        channel_mlp_dim=448,
        num_layers=5,
        dropout=0.1,
        drop_path=0.1,
        label_smoothing=0.1,
    )
    return MicroMixerConfig(**settings)
def micromixer_1m_long(max_seq_len: int = 4096) -> MicroMixerConfig:
    """Return the "1M" preset sizes with a caller-chosen context length.

    Args:
        max_seq_len: Maximum sequence length to configure (default 4096).

    Returns:
        A config with five layers of width 168, DropPath 0.1 and label
        smoothing 0.1.
    """
    settings = dict(
        max_seq_len=max_seq_len,
        hidden_dim=168,
        hyper_hidden_dim=84,
        channel_mlp_dim=448,
        num_layers=5,
        dropout=0.1,
        drop_path=0.1,
        label_smoothing=0.1,
    )
    return MicroMixerConfig(**settings)
def micromixer_1m_fourier() -> MicroMixerConfig:
    """Return the "1M" preset sizes with FourierMixing as the token mixer.

    Five layers of width 168 over a 256-token context, with DropPath 0.1 and
    label smoothing 0.1; ``mixer_type`` is set to ``TokenMixerType.FOURIER``.
    """
    settings = dict(
        max_seq_len=256,
        hidden_dim=168,
        hyper_hidden_dim=84,
        channel_mlp_dim=448,
        num_layers=5,
        dropout=0.1,
        drop_path=0.1,
        label_smoothing=0.1,
        mixer_type=TokenMixerType.FOURIER,
    )
    return MicroMixerConfig(**settings)
if __name__ == "__main__":
    # Smoke test: build every preset, then check output shapes, the loss,
    # generation length, and that logits at a position do not depend on
    # tokens that come after it.
    print("Testing MicroMixer-2 V4...")
    for name, config_fn in [
        ("100k", micromixer_100k),
        ("300k", micromixer_300k),
        ("500k", micromixer_500k),
        ("1M", micromixer_1m),
        ("1M-fourier", micromixer_1m_fourier),
    ]:
        config = config_fn()
        model = MicroMixer(config)
        params = count_parameters(model)
        print(f" {name}: {params:,} parameters, {config.num_layers} layers")

        batch_size = 2
        seq_len = min(32, config.max_seq_len)
        input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))

        # Logits only.
        logits = model(input_ids)
        assert logits.shape == (batch_size, seq_len, config.vocab_size), (
            f"{name}: unexpected logits shape {tuple(logits.shape)}"
        )

        # Logits and scalar loss.
        targets = torch.randint(0, config.vocab_size, (batch_size, seq_len))
        logits, loss = model(input_ids, targets)
        assert logits.shape == (batch_size, seq_len, config.vocab_size), (
            f"{name}: unexpected logits shape {tuple(logits.shape)}"
        )
        assert loss.dim() == 0, f"{name}: loss is not a scalar"

        # Generation appends exactly max_new_tokens tokens to the prompt.
        prompt = input_ids[:, :4]
        gen_ids = model.generate(prompt, max_new_tokens=8, temperature=0.8, top_k=10)
        assert gen_ids.shape == (batch_size, 12), (
            f"{name}: unexpected generation shape {tuple(gen_ids.shape)}"
        )

        # Causality: the logits at position 4 must be identical whether or not
        # extra tokens follow. Dropout and DropPath would make this comparison
        # random, so switch to eval mode explicitly instead of relying on
        # generate() having done so as a side effect.
        prefix = torch.randint(0, config.vocab_size, (1, 5))
        extra = torch.randint(0, config.vocab_size, (1, 3))
        model.eval()
        with torch.no_grad():
            logits_prefix = model(prefix)[:, -1, :]
            logits_extended = model(torch.cat([prefix, extra], dim=1))[:, 4, :]
        assert torch.allclose(logits_prefix, logits_extended, atol=1e-5), (
            f"{name}: logits at position 4 changed when later tokens were appended"
        )
    print("All V4 tests passed!")