"""AAM Diffusion LLM — Matryoshka Elastic Inference.

SwiGLU FFN with nested submodel extraction: one training run yields many
deployment sizes. Also replaces the old GELU FFN with SwiGLU
(proven better in LLaMA/Mistral).
"""
from __future__ import annotations
import math
import copy
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Any, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
@dataclass
class MatryoshkaConfig:
    """Hyper-parameters for a Matryoshka (nested-width) SwiGLU feed-forward layer.

    Construction fails with ``ValueError`` when ``granularity_factors`` is
    empty or contains a value outside the half-open interval ``(0, 1.0]``.
    """

    # Width of the layer's input and output features.
    d_model: int = 768
    # Full hidden width of the feed-forward block.
    d_ff: int = 3072
    # Fractions of ``d_ff`` that should remain usable as nested sub-widths.
    granularity_factors: List[float] = field(
        default_factory=lambda: [0.25, 0.5, 0.75, 1.0]
    )
    # Multiplier applied to the averaged multi-width training loss.
    matryoshka_loss_weight: float = 0.1
    # When true, the layer builds a selector that picks a width per call.
    use_adaptive: bool = True

    def __post_init__(self) -> None:
        """Validate ``granularity_factors`` right after dataclass init."""
        factors = self.granularity_factors
        if not factors:
            raise ValueError("granularity_factors cannot be empty")
        for value in factors:
            if not 0 < value <= 1.0:
                raise ValueError("All granularity_factors must be in (0, 1.0]")
class MatryoshkaLayer(nn.Module):
    """Matryoshka FFN layer — SwiGLU with nested elastic inference.

    SwiGLU: ``output = down_proj(SiLU(gate_proj(x)) * up_proj(x))``.

    A sub-width is obtained by keeping only the first ``k`` hidden units:
    the first ``k`` rows of ``gate_proj`` / ``up_proj`` and the first ``k``
    columns of ``down_proj``. Smaller widths are therefore prefixes of
    larger ones, which is what makes sub-model extraction possible.
    """

    def __init__(self, config: "MatryoshkaConfig") -> None:
        """Build the three bias-free projections and, optionally, the selector.

        Args:
            config: Layer hyper-parameters. ``granularity_factors`` is stored
                sorted in ascending order.
        """
        super().__init__()
        self.config = config
        self.d_model = config.d_model
        self.d_ff = config.d_ff
        self.granularity_factors = sorted(config.granularity_factors)

        self.gate_proj = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.up_proj = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.down_proj = nn.Linear(config.d_ff, config.d_model, bias=False)

        if config.use_adaptive:
            # Guard against d_model < 8, where ``d_model // 8`` would create a
            # zero-width bottleneck and pin the selector output at 0.5.
            selector_hidden = max(1, config.d_model // 8)
            self.size_selector = nn.Sequential(
                nn.Linear(config.d_model, selector_hidden, bias=False),
                nn.SiLU(),
                nn.Linear(selector_hidden, 1, bias=False),
                nn.Sigmoid(),
            )

    def forward(
        self,
        x: torch.Tensor,
        granularity_factor: Optional[float] = None,
    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
        """Apply the SwiGLU FFN at the requested (or selected) width.

        Args:
            x: Input whose last dimension is ``d_model``. For adaptive
                selection, inputs with more than two dimensions are averaged
                over dim 1 (presumably the sequence axis — confirm with the
                caller); inputs with one or two dimensions are fed to the
                selector as-is.
            granularity_factor: Fraction of ``d_ff`` to use. Clamped to
                ``[min(granularity_factors), 1.0]``. When ``None``, the
                selector chooses a factor if ``config.use_adaptive`` is set,
                otherwise the full width is used.

        Returns:
            ``(output, info)`` where ``info`` reports the factor used, the
            active hidden width and the total hidden width.
        """
        if granularity_factor is not None:
            # Compare against None explicitly: an explicit 0.0 must clamp to
            # the smallest supported factor instead of meaning "full width".
            factor = min(max(granularity_factor, min(self.granularity_factors)), 1.0)
        elif self.config.use_adaptive:
            pooled = x.mean(dim=1, keepdim=False) if x.dim() > 2 else x
            score = self.size_selector(pooled)
            # NOTE: ``.item()` yields a Python float, so this choice carries no
            # gradient back to ``size_selector`` and one factor is shared by
            # the whole batch.
            factor = self._score_to_factor(score.mean().item())
        else:
            factor = 1.0

        d_ff_active = max(1, int(self.d_ff * factor))

        if factor >= 1.0:
            # Full width: go through the Linear modules themselves.
            gate = F.silu(self.gate_proj(x))
            up = self.up_proj(x)
            output = self.down_proj(gate * up)
        else:
            # Reduced width: use views onto the leading hidden units only.
            gate_weight = self.gate_proj.weight[:d_ff_active, :]
            up_weight = self.up_proj.weight[:d_ff_active, :]
            down_weight = self.down_proj.weight[:, :d_ff_active]
            gate = F.silu(F.linear(x, gate_weight))
            up = F.linear(x, up_weight)
            output = F.linear(gate * up, down_weight)

        info = {
            "granularity_factor": factor,
            "d_ff_active": d_ff_active,
            "d_ff_total": self.d_ff,
        }
        return output, info

    def _score_to_factor(self, score: float) -> float:
        """Return the granularity factor closest to ``score``.

        Ties go to the smaller factor. If no distance compares as smaller
        than infinity (e.g. ``score`` is NaN), the largest factor is returned.
        """
        best_factor = self.granularity_factors[-1]
        best_distance = float("inf")
        for candidate in self.granularity_factors:
            distance = abs(score - candidate)
            if distance < best_distance:
                best_distance = distance
                best_factor = candidate
        return best_factor

    def compute_matryoshka_loss(
        self,
        x: torch.Tensor,
        target: torch.Tensor,
        loss_fn: Any = None,
    ) -> torch.Tensor:
        """Average ``loss_fn`` over every granularity factor.

        Each factor is passed explicitly, so adaptive selection is not used
        here.

        Args:
            x: Layer input.
            target: Target compared against the layer output at each width.
            loss_fn: Callable ``(output, target) -> loss``; defaults to
                ``nn.MSELoss()``.

        Returns:
            The mean per-width loss multiplied by
            ``config.matryoshka_loss_weight``.
        """
        if loss_fn is None:
            loss_fn = nn.MSELoss()

        total_loss = torch.tensor(0.0, device=x.device)
        for factor in self.granularity_factors:
            # Call the module rather than ``forward`` so registered hooks run.
            output, _ = self(x, granularity_factor=factor)
            total_loss = total_loss + loss_fn(output, target)

        total_loss = total_loss / len(self.granularity_factors)
        return total_loss * self.config.matryoshka_loss_weight

    def extract_submodel(self, granularity_factor: float) -> Dict[str, torch.Tensor]:
        """Return detached copies of the weights for a nested sub-width.

        Args:
            granularity_factor: Fraction of ``d_ff`` to keep; the resulting
                width is clamped to ``[1, d_ff]``.

        Returns:
            Mapping from ``"<proj>.weight"`` names to standalone tensors that
            share neither storage nor autograd history with this layer.
        """
        d_ff_sub = min(self.d_ff, max(1, int(self.d_ff * granularity_factor)))
        return {
            "gate_proj.weight": self.gate_proj.weight[:d_ff_sub, :].detach().clone(),
            "up_proj.weight": self.up_proj.weight[:d_ff_sub, :].detach().clone(),
            "down_proj.weight": self.down_proj.weight[:, :d_ff_sub].detach().clone(),
        }
class ElasticExtractor:
    """Extract copies of a model at reduced FFN widths for deployment.

    Only ``MatryoshkaLayer`` modules are resized; every other module is
    copied unchanged. The wrapped model itself is never modified.

    NOTE(review): a resized layer keeps its ``granularity_factors`` and, if
    built with ``use_adaptive``, its selector, so calling it without an
    explicit factor may narrow it further relative to its new width —
    confirm whether that is intended for deployed sub-models.
    """

    def __init__(self, model: nn.Module) -> None:
        """Remember ``model`` as the source for all later extractions."""
        self.model = model

    @staticmethod
    def _matryoshka_layers(model: nn.Module) -> List["MatryoshkaLayer"]:
        """List the model's MatryoshkaLayer modules in traversal order."""
        return [m for m in model.modules() if isinstance(m, MatryoshkaLayer)]

    @staticmethod
    def _target_width(layer: "MatryoshkaLayer", factor: float) -> int:
        """Hidden width for ``factor``, clamped to ``[1, layer.d_ff]``."""
        return min(layer.d_ff, max(1, int(layer.d_ff * factor)))

    @classmethod
    def _shrink_layer(cls, layer: "MatryoshkaLayer", factor: float) -> None:
        """Truncate ``layer`` in place to its leading hidden units."""
        width = cls._target_width(layer, factor)
        with torch.no_grad():
            layer.gate_proj.weight.data = layer.gate_proj.weight.data[:width, :].clone()
            layer.up_proj.weight.data = layer.up_proj.weight.data[:width, :].clone()
            layer.down_proj.weight.data = layer.down_proj.weight.data[:, :width].clone()
        # Keep the bookkeeping attributes in step with the new weight shapes.
        layer.d_ff = width
        layer.gate_proj.out_features = width
        layer.up_proj.out_features = width
        layer.down_proj.in_features = width
        # Layers may share one config object, so give this layer its own copy
        # before recording the new width on it.
        layer.config = copy.copy(layer.config)
        layer.config.d_ff = width

    @classmethod
    def _params_removed(cls, layers: List["MatryoshkaLayer"], factor: float) -> int:
        """Count the parameters that shrinking ``layers`` by ``factor`` drops."""
        counted = set()
        removed = 0
        for layer in layers:
            dropped_units = layer.d_ff - cls._target_width(layer, factor)
            per_unit = (
                (layer.gate_proj.weight, layer.gate_proj.weight.shape[1]),
                (layer.up_proj.weight, layer.up_proj.weight.shape[1]),
                (layer.down_proj.weight, layer.down_proj.weight.shape[0]),
            )
            for weight, params_per_unit in per_unit:
                # ``model.parameters()`` counts a shared tensor once; match it.
                if id(weight) in counted:
                    continue
                counted.add(id(weight))
                removed += dropped_units * params_per_unit
        return removed

    def extract(self, granularity_factor: float) -> nn.Module:
        """Return a deep copy with every Matryoshka layer at one width.

        Args:
            granularity_factor: Fraction of each layer's ``d_ff`` to keep; the
                resulting width is clamped to ``[1, d_ff]``.
        """
        submodel = copy.deepcopy(self.model)
        for layer in self._matryoshka_layers(submodel):
            self._shrink_layer(layer, granularity_factor)
        return submodel

    def get_available_sizes(self) -> List[Dict[str, Any]]:
        """Describe the sub-model size for each granularity factor.

        Returns:
            One dict per distinct factor found on the model's Matryoshka
            layers, in ascending order, with the factor, an estimated
            parameter count and a human-readable label. The estimate removes
            only the hidden units that extraction would drop; parameters
            outside the Matryoshka projections are counted in full.
        """
        layers = self._matryoshka_layers(self.model)
        factors = sorted({f for layer in layers for f in layer.granularity_factors})
        total_params = sum(p.numel() for p in self.model.parameters())

        sizes = []
        for factor in factors:
            estimated = total_params - self._params_removed(layers, factor)
            sizes.append({
                "granularity_factor": factor,
                "estimated_parameters": estimated,
                "parameter_label": f"~{estimated / 1e6:.0f}M" if estimated < 1e9 else f"~{estimated / 1e9:.1f}B",
            })
        return sizes

    def mix_and_match(self, layer_factors: Dict[int, float]) -> nn.Module:
        """Return a deep copy with a separate width per Matryoshka layer.

        Args:
            layer_factors: Maps the zero-based index of a Matryoshka layer (in
                module traversal order) to its factor. Layers without an
                entry keep their full width.
        """
        submodel = copy.deepcopy(self.model)
        for layer_idx, layer in enumerate(self._matryoshka_layers(submodel)):
            self._shrink_layer(layer, layer_factors.get(layer_idx, 1.0))
        return submodel