Spaces:
Running
Running
| """MicroMixer-2 V4: MLP-Mixer architecture optimized for language models. | |
| V4 innovations based on research: | |
| - DropPath (stochastic depth): Regularization via random residual skipping | |
| - FourierMixing: Parameter-free FFT token mixing (FNet-inspired) | |
| - Padding-aware loss: Ignore padding tokens in cross-entropy | |
| - Label smoothing: Regularize overconfident predictions | |
| - Increased depth: 6-12 layers for larger models | |
| - HyperMixing (ACL 2023): O(S) token mixing via hypernetwork | |
| - RoPE: Rotary position embedding for length generalization | |
| - Standard MLP: Better knowledge capacity than GatedMLP | |
| """ | |
| import math | |
| from dataclasses import dataclass, field | |
| from enum import Enum, auto | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class TokenMixerType(Enum):
    """Selects which token-mixing module a mixer layer is built with."""

    # HyperMixing: O(S) via hypernetwork
    HYPER = 1
    # FourierMixing: FFT-based mixing
    FOURIER = 2
class DropPath(nn.Module):
    """Stochastic depth (DropPath) applied per sample.

    During training each sample's residual branch is either zeroed or kept
    and rescaled by ``1 / keep_prob`` so the expected value is unchanged.
    In eval mode, or when ``drop_prob`` is 0, the input is returned as is.
    The per-layer drop probability is chosen by whoever constructs the module.

    Args:
        drop_prob: Probability of dropping the branch for a given sample.
    """

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with whole samples randomly zeroed (training only)."""
        if not self.training or self.drop_prob == 0.0:
            return x
        keep_prob = 1 - self.drop_prob
        if keep_prob <= 0.0:
            # Every sample is dropped. Dividing by keep_prob here would give
            # inf * 0 = NaN, so return zeros (still attached to the graph).
            return x * 0.0
        # One random draw per sample, broadcast over all remaining dims.
        mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        keep_mask = torch.rand(mask_shape, dtype=x.dtype, device=x.device)
        # floor(u + keep_prob) is 1 with probability keep_prob, else 0.
        keep_mask = torch.floor(keep_mask + keep_prob)
        return x / keep_prob * keep_mask
class MlpBlock(nn.Module):
    """Two-layer feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        in_dim: Size of the input and output feature dimension.
        hidden_dim: Size of the intermediate feature dimension.
        dropout: Dropout probability used after the activation and the output.
    """

    def __init__(self, in_dim: int, hidden_dim: int, dropout: float = 0.1):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, in_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project up, apply GELU, project back; dropout follows each stage."""
        hidden = self.dropout(F.gelu(self.fc1(x)))
        return self.dropout(self.fc2(hidden))
class RotaryPositionEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE) applied to the leading ``dim`` channels.

    Adjacent channel pairs ``(2k, 2k + 1)`` are rotated by the angle
    ``position * inv_freq[k]``. Channels beyond ``dim`` pass through unchanged.

    Args:
        dim: Number of leading channels to rotate; must be even.
        max_seq_len: Accepted for interface compatibility; angles are computed
            on the fly in ``forward`` and this value is not used.

    Raises:
        ValueError: If ``dim`` is odd (channels could not be paired).
    """

    def __init__(self, dim: int, max_seq_len: int = 512):
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(f"RoPE dim must be even, got {dim}")
        self.dim = dim
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x: torch.Tensor, seq_len: int) -> torch.Tensor:
        """Rotate ``x`` (sequence on dim -2, channels on dim -1) by position."""
        positions = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
        angles = torch.einsum("i,j->ij", positions, self.inv_freq)
        # Channels are paired as (0,1), (2,3), ... below, so both members of a
        # pair must share one angle: [t0, t0, t1, t1, ...]. The half-split
        # layout cat(freqs, freqs) gives the two members different angles,
        # which is not a rotation.
        angles = torch.repeat_interleave(angles, 2, dim=-1)
        cos = angles.cos().unsqueeze(0)
        sin = angles.sin().unsqueeze(0)

        x_rot = x[..., : self.dim]
        x_rest = x[..., self.dim :] if x.shape[-1] > self.dim else None

        # For each pair (a, b) build (-b, a); then (a, b)*cos + (-b, a)*sin
        # is the 2-D rotation of (a, b).
        x_even, x_odd = x_rot[..., ::2], x_rot[..., 1::2]
        rotated = torch.stack([-x_odd, x_even], dim=-1).flatten(-2)
        x_rotated = x_rot * cos + rotated * sin

        if x_rest is not None:
            return torch.cat([x_rotated, x_rest], dim=-1)
        return x_rotated
class HyperMixing(nn.Module):
    """Token mixing driven by a hypernetwork over the running prefix mean.

    For every position the mean of all tokens up to and including that
    position is fed to a small MLP, which emits a per-channel scale and
    shift that are applied to the token at that position.

    Args:
        hidden_dim: Channel size of the input tokens.
        hyper_hidden_dim: Hidden width of the hypernetwork MLP.
        dropout: Dropout probability applied to the output.
    """

    def __init__(self, hidden_dim: int, hyper_hidden_dim: int, dropout: float = 0.1):
        super().__init__()
        self.hidden_dim = hidden_dim
        hyper_layers = [
            nn.Linear(hidden_dim, hyper_hidden_dim),
            nn.GELU(),
            # Output is split in half: first half scales, second half shifts.
            nn.Linear(hyper_hidden_dim, hidden_dim * 2),
        ]
        self.hyper = nn.Sequential(*hyper_layers)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the prefix-conditioned affine transform to a (B, S, H) input."""
        _, seq_len, _ = x.shape

        # Running mean over positions 0..t only looks at the past and present.
        prefix_sizes = torch.arange(1, seq_len + 1, device=x.device)
        prefix_sizes = prefix_sizes.view(1, seq_len, 1).float()
        prefix_mean = torch.cumsum(x, dim=1) / prefix_sizes

        scale, shift = self.hyper(prefix_mean).chunk(2, dim=-1)
        return self.dropout(x * scale + shift)
class FourierMixing(nn.Module):
    """Causal Fourier-style mixing with a single learnable blend weight.

    The real part of an FFT over the hidden dimension is computed per token,
    then blended with its running mean over positions ``0..t``. Only the
    running mean mixes information across tokens, and it never looks ahead,
    so output position ``t`` depends only on input positions ``<= t``.

    NOTE(review): an FFT along the sequence axis makes every output position
    depend on every input position, including future ones, and a cumulative
    mean applied afterwards cannot undo that. The FFT is therefore taken over
    the hidden axis instead. Confirm this trade-off is acceptable for models
    trained with the sequence-axis variant.

    Args:
        hidden_dim: Channel size of the input (kept for interface symmetry
            with the other mixers; not used by the computation).
        dropout: Dropout probability applied to the output.
    """

    def __init__(self, hidden_dim: int, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # Learnable weight of the running-mean term in the blend below.
        self.scale = nn.Parameter(torch.ones(1) * 0.1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Mix a (B, S, H) input without using future positions."""
        _, seq_len, _ = x.shape

        # Per-token FFT over channels: no information crosses positions here.
        x_fft = torch.fft.fft(x, dim=-1).real

        # Causal token mixing: mean of the transformed tokens seen so far.
        prefix_sizes = torch.arange(1, seq_len + 1, device=x.device)
        prefix_sizes = prefix_sizes.view(1, seq_len, 1).float()
        x_causal = torch.cumsum(x_fft, dim=1) / prefix_sizes

        mixed = x_fft * (1 - self.scale) + x_causal * self.scale
        return self.dropout(mixed)
class MicroMixerLayer(nn.Module):
    """One pre-norm mixer layer: token mixing, then channel mixing.

    Each sub-block is LayerNorm -> mixer -> DropPath, added back onto its
    input. A single DropPath module is shared by both sub-blocks.

    Args:
        hidden_dim: Channel size of the tokens.
        hyper_hidden_dim: Hidden width handed to HyperMixing.
        channel_mlp_dim: Inner width of the channel-mixing MLP.
        dropout: Dropout probability passed to the sub-modules.
        drop_path: DropPath probability; values <= 0 use ``nn.Identity``.
        mixer_type: Which token mixer to build.

    Raises:
        ValueError: If ``mixer_type`` is not a known ``TokenMixerType``.
    """

    def __init__(
        self,
        hidden_dim: int,
        hyper_hidden_dim: int,
        channel_mlp_dim: int,
        dropout: float = 0.1,
        drop_path: float = 0.0,
        mixer_type: TokenMixerType = TokenMixerType.HYPER,
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)

        if mixer_type == TokenMixerType.FOURIER:
            token_mixer = FourierMixing(hidden_dim, dropout)
        elif mixer_type == TokenMixerType.HYPER:
            token_mixer = HyperMixing(hidden_dim, hyper_hidden_dim, dropout)
        else:
            raise ValueError(f"Unknown mixer type: {mixer_type}")
        self.token_mixer = token_mixer

        self.channel_mlp = MlpBlock(hidden_dim, channel_mlp_dim, dropout)

        if drop_path > 0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run both residual sub-blocks and return the updated tokens."""
        # Token-mixing branch with stochastic depth.
        token_branch = self.token_mixer(self.norm1(x))
        x = x + self.drop_path(token_branch)

        # Channel-mixing branch with stochastic depth.
        channel_branch = self.channel_mlp(self.norm2(x))
        return x + self.drop_path(channel_branch)
@dataclass
class MicroMixerConfig:
    """Configuration for MicroMixer-2 V4.

    Declared as a dataclass so the fields below can be overridden by keyword
    at construction time, e.g. ``MicroMixerConfig(max_seq_len=64)``.

    Attributes:
        vocab_size: Vocabulary size (256 for byte-level).
        max_seq_len: Maximum sequence length.
        hidden_dim: Hidden dimension for embeddings and mixer layers.
        hyper_hidden_dim: Hidden dimension for HyperMixing hypernetwork.
        channel_mlp_dim: Inner dimension of channel-mixing MLP.
        num_layers: Number of mixer layers.
        dropout: Dropout probability.
        drop_path: DropPath probability (0 = disabled).
        label_smoothing: Label smoothing for cross-entropy (0 = disabled).
        tie_weights: Tie input/output embeddings.
        mixer_type: Token mixing strategy (HYPER or FOURIER).
        pad_token_id: Padding token ID for masked loss.
    """

    vocab_size: int = 256
    max_seq_len: int = 128
    hidden_dim: int = 128
    hyper_hidden_dim: int = 64
    channel_mlp_dim: int = 256
    num_layers: int = 2
    dropout: float = 0.1
    drop_path: float = 0.0
    label_smoothing: float = 0.0
    tie_weights: bool = True
    mixer_type: TokenMixerType = TokenMixerType.HYPER
    pad_token_id: int = 0
class MicroMixer(nn.Module):
    """MicroMixer-2 V4: MLP-Mixer style language model.

    Pipeline: token embedding -> RoPE -> dropout -> stacked mixer layers ->
    LayerNorm -> linear LM head (optionally weight-tied to the embedding).

    Args:
        config: Hyperparameters; see ``MicroMixerConfig``.
    """

    # Target value that the cross-entropy call is told to skip.
    _IGNORE_INDEX = -100

    def __init__(self, config: MicroMixerConfig):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size
        self.max_seq_len = config.max_seq_len
        self.hidden_dim = config.hidden_dim
        self.pad_token_id = config.pad_token_id

        self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_dim)
        self.rope = RotaryPositionEmbedding(config.hidden_dim, config.max_seq_len)
        self.dropout = nn.Dropout(config.dropout)

        # DropPath: linear schedule from 0 at the first layer up to
        # config.drop_path at the last layer.
        dpr = [x.item() for x in torch.linspace(0, config.drop_path, config.num_layers)]
        self.mixer_layers = nn.ModuleList([
            MicroMixerLayer(
                config.hidden_dim,
                config.hyper_hidden_dim,
                config.channel_mlp_dim,
                config.dropout,
                drop_path=dpr[i],
                mixer_type=config.mixer_type,
            )
            for i in range(config.num_layers)
        ])

        self.layer_norm = nn.LayerNorm(config.hidden_dim)
        self.lm_head = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)
        if config.tie_weights:
            # Share one weight matrix between input embedding and output head.
            self.lm_head.weight = self.token_embedding.weight

        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module):
        """Initialise Linear/Embedding weights ~ N(0, 0.02); reset LayerNorm."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if getattr(module, "bias", None) is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.ones_(module.weight)
            torch.nn.init.zeros_(module.bias)

    def _compute_loss(
        self,
        logits: torch.Tensor,
        targets: torch.Tensor,
        attention_mask: torch.Tensor | None,
    ) -> torch.Tensor:
        """Cross-entropy that skips pad targets and attention-masked positions."""
        # Targets equal to the pad id never contribute to the loss.
        ignore = targets == self.pad_token_id
        if attention_mask is not None:
            # Position 0 is always kept; position i (i >= 1) is kept only when
            # attention_mask[:, i - 1] is non-zero.
            # NOTE(review): the shift direction assumes a particular alignment
            # between `targets` and `input_ids` - confirm against the data
            # pipeline.
            shifted_mask = torch.ones_like(attention_mask)
            shifted_mask[:, 1:] = attention_mask[:, :-1]
            ignore = ignore | (shifted_mask == 0)

        # Mark ignored positions in a copy of the targets so the mask is
        # actually applied per position.
        loss_targets = targets.masked_fill(ignore, self._IGNORE_INDEX)

        # reshape (not view): after truncation to max_seq_len the target
        # tensor is a non-contiguous column slice.
        return F.cross_entropy(
            logits.reshape(-1, self.vocab_size),
            loss_targets.reshape(-1),
            ignore_index=self._IGNORE_INDEX,
            label_smoothing=self.config.label_smoothing,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        targets: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Compute logits, and the loss when ``targets`` is given.

        Inputs longer than ``max_seq_len`` are truncated to their last
        ``max_seq_len`` positions (``targets`` and ``attention_mask`` too).

        Args:
            input_ids: Token ids, 2-D (batch, sequence).
            targets: Optional target ids with the same layout as ``input_ids``.
            attention_mask: Optional mask; zero entries exclude positions from
                the loss (see ``_compute_loss``). It does not affect logits.

        Returns:
            ``logits`` alone, or ``(logits, loss)`` when ``targets`` is given.
        """
        _, seq_len = input_ids.shape
        if seq_len > self.max_seq_len:
            input_ids = input_ids[:, -self.max_seq_len :]
            seq_len = self.max_seq_len
            if targets is not None:
                targets = targets[:, -self.max_seq_len :]
            if attention_mask is not None:
                attention_mask = attention_mask[:, -self.max_seq_len :]

        token_emb = self.token_embedding(input_ids)
        x = self.rope(token_emb, seq_len)
        x = self.dropout(x)

        for layer in self.mixer_layers:
            x = layer(x)

        x = self.layer_norm(x)
        logits = self.lm_head(x)

        if targets is None:
            return logits
        return logits, self._compute_loss(logits, targets, attention_mask)

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        top_k: int | None = None,
    ) -> torch.Tensor:
        """Autoregressively append ``max_new_tokens`` tokens to ``input_ids``.

        Switches the module to eval mode and leaves it there. Runs without
        gradient tracking, since sampled token ids carry no gradient.

        Args:
            input_ids: Prompt token ids, 2-D (batch, sequence).
            max_new_tokens: Number of tokens to append.
            temperature: Logit divisor; exactly 0.0 selects greedy argmax.
            top_k: If set, sample only from the ``top_k`` highest logits.

        Returns:
            The prompt with the generated tokens concatenated on dim 1.
        """
        self.eval()
        device = next(self.parameters()).device
        input_ids = input_ids.to(device)

        for _ in range(max_new_tokens):
            # forward() itself truncates the context to max_seq_len.
            logits = self(input_ids)[:, -1, :]

            if temperature == 0.0:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                if top_k is not None:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    # Everything below the k-th largest logit gets zero mass.
                    logits[logits < v[:, [-1]]] = -float("Inf")
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

            input_ids = torch.cat([input_ids, next_token], dim=1)

        return input_ids
def count_parameters(model: nn.Module) -> int:
    """Return the number of elements across all parameters that require grad."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
def micromixer_100k() -> MicroMixerConfig:
    """Return the smallest preset (nominally ~100K parameters) for quick tests."""
    settings = {
        "max_seq_len": 64,
        "hidden_dim": 84,
        "hyper_hidden_dim": 48,
        "channel_mlp_dim": 128,
        "num_layers": 3,
        "dropout": 0.1,
        # Regularisation extras are switched off for this preset.
        "drop_path": 0.0,
        "label_smoothing": 0.0,
    }
    return MicroMixerConfig(**settings)
def micromixer_300k() -> MicroMixerConfig:
    """Return the small-scale preset (nominally ~300K parameters)."""
    settings = {
        "max_seq_len": 128,
        "hidden_dim": 128,
        "hyper_hidden_dim": 64,
        "channel_mlp_dim": 288,
        "num_layers": 4,
        "dropout": 0.1,
        "drop_path": 0.05,
        "label_smoothing": 0.05,
    }
    return MicroMixerConfig(**settings)
def micromixer_500k() -> MicroMixerConfig:
    """Return the medium-scale preset (nominally ~500K parameters)."""
    settings = {
        "max_seq_len": 128,
        "hidden_dim": 176,
        "hyper_hidden_dim": 88,
        "channel_mlp_dim": 384,
        "num_layers": 4,
        "dropout": 0.1,
        "drop_path": 0.1,
        "label_smoothing": 0.05,
    }
    return MicroMixerConfig(**settings)
def micromixer_1m() -> MicroMixerConfig:
    """Return the standard preset (nominally ~1M parameters)."""
    settings = {
        "max_seq_len": 256,
        "hidden_dim": 168,
        "hyper_hidden_dim": 84,
        "channel_mlp_dim": 448,
        "num_layers": 5,
        "dropout": 0.1,
        "drop_path": 0.1,
        "label_smoothing": 0.1,
    }
    return MicroMixerConfig(**settings)
def micromixer_1m_long(max_seq_len: int = 4096) -> MicroMixerConfig:
    """Return the ~1M-parameter preset with a caller-chosen context length.

    Args:
        max_seq_len: Maximum sequence length to configure (default 4096).
    """
    settings = {
        "max_seq_len": max_seq_len,
        "hidden_dim": 168,
        "hyper_hidden_dim": 84,
        "channel_mlp_dim": 448,
        "num_layers": 5,
        "dropout": 0.1,
        "drop_path": 0.1,
        "label_smoothing": 0.1,
    }
    return MicroMixerConfig(**settings)
def micromixer_1m_fourier() -> MicroMixerConfig:
    """Return the ~1M-parameter preset that uses FourierMixing for tokens."""
    settings = {
        "max_seq_len": 256,
        "hidden_dim": 168,
        "hyper_hidden_dim": 84,
        "channel_mlp_dim": 448,
        "num_layers": 5,
        "dropout": 0.1,
        "drop_path": 0.1,
        "label_smoothing": 0.1,
        # The only setting that selects the Fourier token mixer.
        "mixer_type": TokenMixerType.FOURIER,
    }
    return MicroMixerConfig(**settings)
if __name__ == "__main__":
    # Smoke test: build every preset and check shapes, loss, generation and
    # that logits at a position do not change when later tokens are appended.
    print("Testing MicroMixer-2 V4...")

    presets = (
        ("100k", micromixer_100k),
        ("300k", micromixer_300k),
        ("500k", micromixer_500k),
        ("1M", micromixer_1m),
        ("1M-fourier", micromixer_1m_fourier),
    )

    for name, config_fn in presets:
        config = config_fn()
        model = MicroMixer(config)
        params = count_parameters(model)
        print(f" {name}: {params:,} parameters, {config.num_layers} layers")

        batch_size = 2
        seq_len = min(32, config.max_seq_len)
        expected_shape = (batch_size, seq_len, config.vocab_size)

        # Logits-only forward pass.
        input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
        logits = model(input_ids)
        assert logits.shape == expected_shape

        # Forward pass with targets returns (logits, scalar loss).
        targets = torch.randint(0, config.vocab_size, (batch_size, seq_len))
        logits, loss = model(input_ids, targets)
        assert logits.shape == expected_shape
        assert loss.dim() == 0

        # Generation appends exactly max_new_tokens tokens to the prompt.
        prompt = input_ids[:, :4]
        gen_ids = model.generate(prompt, max_new_tokens=8, temperature=0.8, top_k=10)
        assert gen_ids.shape == (batch_size, 12)

        # Causality: the model is in eval mode here because generate() set it.
        prefix = torch.randint(0, config.vocab_size, (1, 5))
        extra = torch.randint(0, config.vocab_size, (1, 3))
        logits_prefix = model(prefix)[:, -1, :]
        extended = torch.cat([prefix, extra], dim=1)
        logits_extended = model(extended)[:, 4, :]
        assert torch.allclose(logits_prefix, logits_extended, atol=1e-5)

    print("All V4 tests passed!")