"""AAM Diffusion LLM — Matryoshka Elastic Inference SwiGLU FFN with nested submodel extraction. One training → many deployment sizes. Also replaces the old GELU FFN with SwiGLU (proven better in LLaMA/Mistral). """ from __future__ import annotations import math import copy from dataclasses import dataclass, field from typing import Optional, List, Dict, Any, Tuple import torch import torch.nn as nn import torch.nn.functional as F @dataclass class MatryoshkaConfig: d_model: int = 768 d_ff: int = 3072 granularity_factors: List[float] = field(default_factory=lambda: [0.25, 0.5, 0.75, 1.0]) matryoshka_loss_weight: float = 0.1 use_adaptive: bool = True def __post_init__(self) -> None: if not self.granularity_factors: raise ValueError("granularity_factors cannot be empty") if not all(0 < f <= 1.0 for f in self.granularity_factors): raise ValueError("All granularity_factors must be in (0, 1.0]") class MatryoshkaLayer(nn.Module): """Matryoshka FFN Layer — SwiGLU with nested elastic inference. SwiGLU: output = down_proj(SiLU(gate_proj(x)) * up_proj(x)) Nested structure allows extracting smaller valid submodels. """ def __init__(self, config: MatryoshkaConfig) -> None: super().__init__() self.config = config self.d_model = config.d_model self.d_ff = config.d_ff self.granularity_factors = sorted(config.granularity_factors) self.gate_proj = nn.Linear(config.d_model, config.d_ff, bias=False) self.up_proj = nn.Linear(config.d_model, config.d_ff, bias=False) self.down_proj = nn.Linear(config.d_ff, config.d_model, bias=False) if config.use_adaptive: self.size_selector = nn.Sequential( nn.Linear(config.d_model, config.d_model // 8, bias=False), nn.SiLU(), nn.Linear(config.d_model // 8, 1, bias=False), nn.Sigmoid(), ) def forward( self, x: torch.Tensor, granularity_factor: Optional[float] = None, ) -> Tuple[torch.Tensor, Dict[str, Any]]: factor = granularity_factor or 1.0 factor = min(max(factor, min(self.granularity_factors)), 1.0) if granularity_factor is None and self.config.use_adaptive: score = self.size_selector(x.mean(dim=1, keepdim=False)) factor = self._score_to_factor(score.mean().item()) d_ff_active = max(1, int(self.d_ff * factor)) if factor >= 1.0: gate = F.silu(self.gate_proj(x)) up = self.up_proj(x) output = self.down_proj(gate * up) else: gate_weight = self.gate_proj.weight[:d_ff_active, :] up_weight = self.up_proj.weight[:d_ff_active, :] down_weight = self.down_proj.weight[:, :d_ff_active] gate = F.silu(F.linear(x, gate_weight)) up = F.linear(x, up_weight) output = F.linear(gate * up, down_weight) info = { "granularity_factor": factor, "d_ff_active": d_ff_active, "d_ff_total": self.d_ff, } return output, info def _score_to_factor(self, score: float) -> float: min_dist = float("inf") best_factor = self.granularity_factors[-1] for f in self.granularity_factors: dist = abs(score - f) if dist < min_dist: min_dist = dist best_factor = f return best_factor def compute_matryoshka_loss( self, x: torch.Tensor, target: torch.Tensor, loss_fn: Any = None, ) -> torch.Tensor: if loss_fn is None: loss_fn = nn.MSELoss() total_loss = torch.tensor(0.0, device=x.device) for factor in self.granularity_factors: output, _ = self.forward(x, granularity_factor=factor) sub_loss = loss_fn(output, target) total_loss = total_loss + sub_loss total_loss = total_loss / len(self.granularity_factors) return total_loss * self.config.matryoshka_loss_weight def extract_submodel(self, granularity_factor: float) -> Dict[str, nn.Parameter]: d_ff_sub = max(1, int(self.d_ff * granularity_factor)) return { "gate_proj.weight": self.gate_proj.weight[:d_ff_sub, :].clone(), "up_proj.weight": self.up_proj.weight[:d_ff_sub, :].clone(), "down_proj.weight": self.down_proj.weight[:, :d_ff_sub].clone(), } class ElasticExtractor: """Extract model at various sizes for deployment.""" def __init__(self, model: nn.Module) -> None: self.model = model def extract(self, granularity_factor: float) -> nn.Module: submodel = copy.deepcopy(self.model) for name, module in submodel.named_modules(): if isinstance(module, MatryoshkaLayer): d_ff_sub = max(1, int(module.d_ff * granularity_factor)) with torch.no_grad(): module.gate_proj.weight.data = module.gate_proj.weight.data[:d_ff_sub, :].clone() module.up_proj.weight.data = module.up_proj.weight.data[:d_ff_sub, :].clone() module.down_proj.weight.data = module.down_proj.weight.data[:, :d_ff_sub].clone() module.d_ff = d_ff_sub module.gate_proj.out_features = d_ff_sub module.up_proj.out_features = d_ff_sub module.down_proj.in_features = d_ff_sub return submodel def get_available_sizes(self) -> List[Dict[str, Any]]: factors = set() for name, module in self.model.named_modules(): if isinstance(module, MatryoshkaLayer): factors.update(module.granularity_factors) factors = sorted(factors) total_params = sum(p.numel() for p in self.model.parameters()) sizes = [] for factor in factors: estimated = int(total_params * factor) sizes.append({ "granularity_factor": factor, "estimated_parameters": estimated, "parameter_label": f"~{estimated / 1e6:.0f}M" if estimated < 1e9 else f"~{estimated / 1e9:.1f}B", }) return sizes def mix_and_match(self, layer_factors: Dict[int, float]) -> nn.Module: submodel = copy.deepcopy(self.model) layer_idx = 0 for name, module in submodel.named_modules(): if isinstance(module, MatryoshkaLayer): factor = layer_factors.get(layer_idx, 1.0) d_ff_sub = max(1, int(module.d_ff * factor)) with torch.no_grad(): module.gate_proj.weight.data = module.gate_proj.weight.data[:d_ff_sub, :].clone() module.up_proj.weight.data = module.up_proj.weight.data[:d_ff_sub, :].clone() module.down_proj.weight.data = module.down_proj.weight.data[:, :d_ff_sub].clone() module.d_ff = d_ff_sub module.gate_proj.out_features = d_ff_sub module.up_proj.out_features = d_ff_sub module.down_proj.in_features = d_ff_sub layer_idx += 1 return submodel