""" ASTRAI Pluto — native architecture for the Pluto family. A standalone decoder-only Transformer with: * RMSNorm + RoPE (no learned positional embeddings) * Causal SDPA attention (multi-head, optional GQA) * Top-K Mixture-of-Experts (SwiGLU experts), no required shared expert * Multi-Token Prediction heads (training-only) * Tied input/output embedding * Router auxiliary loss (load balance) + z-loss Not derived from any HuggingFace base model — fresh implementation in plain PyTorch. Save/load uses a `pluto_config.json` + a safetensors weights file. Naming: `PlutoModel` / `PlutoForCausalLM`. The `_meta` dict on the config holds size hyper-params; routing / aux-loss config is on its own dataclass. """ from __future__ import annotations import json import math import os from dataclasses import asdict, dataclass, field from pathlib import Path from typing import Optional import torch import torch.nn as nn import torch.nn.functional as F # ─── Config ───────────────────────────────────────────────────────────── @dataclass class PlutoConfig: # Architecture (multilingual Nano — d=384, layers=16, GQA, 32k vocab) vocab_size: int = 32768 hidden_size: int = 384 intermediate_size_expert: int = 1536 intermediate_size_shared: int = 0 # 0 = no shared expert n_layers: int = 16 n_heads: int = 6 n_kv_heads: int = 2 # GQA: 6→2 → ~50 % attn-param saving n_experts: int = 35 # 5 langs × 7 experts each top_k: int = 1 # max sparsity → ~50 M active inference n_languages: int = 5 # en, pt, es, zh, hi max_position_embeddings: int = 4096 rope_theta: float = 1_000_000.0 rms_norm_eps: float = 1e-6 tie_word_embeddings: bool = True # MTP — training-only aux heads mtp_depth: int = 2 mtp_loss_weight: float = 0.15 # Routing aux losses router_aux_loss_coef: float = 0.01 router_z_loss_coef: float = 0.001 # Bookkeeping model_type: str = "astrai_pluto" pad_token_id: int | None = None bos_token_id: int | None = None eos_token_id: int | None = None # Tokenizer config (saved for convenience) tokenizer_name: str | None = None def to_dict(self) -> dict: return asdict(self) @classmethod def from_dict(cls, d: dict) -> "PlutoConfig": # ignore extra keys silently for forward-compat known = {f.name for f in cls.__dataclass_fields__.values()} return cls(**{k: v for k, v in d.items() if k in known}) def save(self, output_dir: str | Path) -> None: os.makedirs(output_dir, exist_ok=True) with open(Path(output_dir) / "pluto_config.json", "w") as f: json.dump(self.to_dict(), f, indent=2) @classmethod def load(cls, model_dir: str | Path) -> "PlutoConfig": with open(Path(model_dir) / "pluto_config.json") as f: return cls.from_dict(json.load(f)) # ─── Layers ───────────────────────────────────────────────────────────── class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.weight = nn.Parameter(torch.ones(dim)) self.eps = eps def forward(self, x: torch.Tensor) -> torch.Tensor: # Compute in fp32 for numerical stability, return in input dtype out = x.float() norm = out.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt() return (out * norm).to(x.dtype) * self.weight def _rope_freqs(dim: int, base: float, device, dtype=torch.float32) -> torch.Tensor: inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=dtype) / dim)) return inv_freq def _rope_cache(seq_len: int, dim: int, base: float, device) -> tuple[torch.Tensor, torch.Tensor]: inv_freq = _rope_freqs(dim, base, device) t = torch.arange(seq_len, device=device, dtype=torch.float32) freqs = torch.outer(t, inv_freq) cos = freqs.cos() sin = freqs.sin() return cos, sin def _apply_rope(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor): # q, k: [B, H, T, Dh]; cos, sin: [T, Dh/2] def rotate(x: torch.Tensor) -> torch.Tensor: x1, x2 = x[..., ::2], x[..., 1::2] rot = torch.stack((-x2 * sin + x1 * cos, x1 * sin + x2 * cos), dim=-1) return rot.flatten(-2) return rotate(q), rotate(k) class PlutoAttention(nn.Module): """Causal SDPA attention with optional GQA + RoPE.""" def __init__(self, cfg: PlutoConfig): super().__init__() assert cfg.hidden_size % cfg.n_heads == 0 self.cfg = cfg self.head_dim = cfg.hidden_size // cfg.n_heads self.q_proj = nn.Linear(cfg.hidden_size, cfg.n_heads * self.head_dim, bias=False) self.k_proj = nn.Linear(cfg.hidden_size, cfg.n_kv_heads * self.head_dim, bias=False) self.v_proj = nn.Linear(cfg.hidden_size, cfg.n_kv_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=False) def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: B, T, D = x.shape H = self.cfg.n_heads Hk = self.cfg.n_kv_heads Dh = self.head_dim q = self.q_proj(x).view(B, T, H, Dh).transpose(1, 2) # [B, H, T, Dh] k = self.k_proj(x).view(B, T, Hk, Dh).transpose(1, 2) # [B, Hk, T, Dh] v = self.v_proj(x).view(B, T, Hk, Dh).transpose(1, 2) q, k = _apply_rope(q, k, cos[:T].to(q.dtype), sin[:T].to(q.dtype)) # GQA: expand kv if Hk < H if Hk != H: repeats = H // Hk k = k.repeat_interleave(repeats, dim=1) v = v.repeat_interleave(repeats, dim=1) y = F.scaled_dot_product_attention(q, k, v, is_causal=True) y = y.transpose(1, 2).contiguous().view(B, T, D) return self.o_proj(y) class SwiGLU(nn.Module): def __init__(self, dim: int, hidden: int): super().__init__() self.w_gate = nn.Linear(dim, hidden, bias=False) self.w_up = nn.Linear(dim, hidden, bias=False) self.w_down = nn.Linear(hidden, dim, bias=False) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x)) class PlutoMoE(nn.Module): """Top-K MoE using grouped matmul (torch._grouped_mm). Expert weights are kept as 3 stacked tensors of shape [E, D, H] (gate, up) and [E, H, D] (down) so the whole layer is 3 grouped GEMMs per forward. Currently specialised for top_k == 1 (sort once, no aggregation). Top-K>1 falls back to the per-expert loop. Optional shared expert (always active) if intermediate_size_shared > 0. """ def __init__(self, cfg: PlutoConfig): super().__init__() self.cfg = cfg E, D, H = cfg.n_experts, cfg.hidden_size, cfg.intermediate_size_expert self.router = nn.Linear(D, E, bias=False) # SwiGLU expert weights stacked along the expert dim. # `_grouped_mm(A, B, offs)` expects B in [E, K, N] for A in [M, K] # → output [M, N]. So we store: # W_gate: [E, D, H] → x @ W_gate → [M, H] # W_up: [E, D, H] # W_down: [E, H, D] self.W_gate = nn.Parameter(torch.empty(E, D, H)) self.W_up = nn.Parameter(torch.empty(E, D, H)) self.W_down = nn.Parameter(torch.empty(E, H, D)) # Init: Kaiming-like, scaled down so initial residual is well-behaved. std_in = 1.0 / math.sqrt(D) std_h = 1.0 / math.sqrt(H) nn.init.normal_(self.W_gate, std=std_in) nn.init.normal_(self.W_up, std=std_in) nn.init.normal_(self.W_down, std=std_h) self.shared = (SwiGLU(D, cfg.intermediate_size_shared) if cfg.intermediate_size_shared > 0 else None) @staticmethod def _offsets_from_counts(counts: torch.Tensor) -> torch.Tensor: # Convert [E] counts → end-offset tensor [E] of int32. # `torch._grouped_mm` consumes end-offsets (exclusive cumsum). return counts.cumsum(0).to(torch.int32) def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, dict]: B, T, D = x.shape E = self.cfg.n_experts x_flat = x.reshape(B * T, D) logits = self.router(x_flat) # [B*T, E] if self.cfg.top_k == 1: # Sort tokens by expert id → contiguous expert ranges → grouped GEMM top_idx = logits.argmax(dim=-1) # [B*T] sort_idx = top_idx.argsort(stable=True) x_sorted = x_flat[sort_idx] # [B*T, D] counts = torch.bincount(top_idx, minlength=E) # [E] offsets = self._offsets_from_counts(counts) # [E] end-offsets # Grouped SwiGLU: each token uses ONE expert. gate = torch._grouped_mm(x_sorted, self.W_gate, offsets) # [B*T, H] up = torch._grouped_mm(x_sorted, self.W_up, offsets) # [B*T, H] hidden = F.silu(gate) * up out_sorted = torch._grouped_mm(hidden, self.W_down, offsets) # [B*T, D] # Un-sort inverse = torch.empty_like(sort_idx) inverse[sort_idx] = torch.arange(sort_idx.size(0), device=x.device) out = out_sorted[inverse] else: # Top-K>1 fallback: slower loop. Kept for completeness. topk_vals, topk_idx = logits.topk(self.cfg.top_k, dim=-1) topk_w = F.softmax(topk_vals, dim=-1) out = torch.zeros_like(x_flat) for k in range(self.cfg.top_k): ids = topk_idx[..., k] w = topk_w[..., k].unsqueeze(-1) # Per-K grouped GEMM sort_idx = ids.argsort(stable=True) x_sorted = x_flat[sort_idx] counts = torch.bincount(ids, minlength=E) offsets = self._offsets_from_counts(counts) gate = torch._grouped_mm(x_sorted, self.W_gate, offsets) up = torch._grouped_mm(x_sorted, self.W_up, offsets) hidden = F.silu(gate) * up out_sorted = torch._grouped_mm(hidden, self.W_down, offsets) inverse = torch.empty_like(sort_idx) inverse[sort_idx] = torch.arange(sort_idx.size(0), device=x.device) out = out + out_sorted[inverse] * w top_idx = topk_idx[..., 0] # for aux-loss bookkeeping below if self.shared is not None: out = out + self.shared(x_flat) out = out.reshape(B, T, D) # Auxiliary losses (Switch Transformer load-balance + ST-MoE z-loss) aux: dict = {} if self.training: probs = F.softmax(logits.float(), dim=-1) expert_freq = probs.mean(dim=0) # [E] counts_norm = (counts.float() / counts.float().sum().clamp_min(1.0)) aux["aux_load"] = (expert_freq * counts_norm).sum() * self.cfg.n_experts aux["aux_z"] = (logits.float().logsumexp(-1) ** 2).mean() return out, aux class PlutoBlock(nn.Module): def __init__(self, cfg: PlutoConfig): super().__init__() self.ln1 = RMSNorm(cfg.hidden_size, cfg.rms_norm_eps) self.attn = PlutoAttention(cfg) self.ln2 = RMSNorm(cfg.hidden_size, cfg.rms_norm_eps) self.moe = PlutoMoE(cfg) def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> tuple[torch.Tensor, dict]: x = x + self.attn(self.ln1(x), cos, sin) y, aux = self.moe(self.ln2(x)) x = x + y return x, aux # ─── Models ───────────────────────────────────────────────────────────── class PlutoModel(nn.Module): """Decoder backbone: token embed → N blocks → final RMSNorm.""" def __init__(self, cfg: PlutoConfig): super().__init__() self.cfg = cfg self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.hidden_size) self.blocks = nn.ModuleList([PlutoBlock(cfg) for _ in range(cfg.n_layers)]) self.final_norm = RMSNorm(cfg.hidden_size, cfg.rms_norm_eps) self.register_buffer("_rope_initialised", torch.tensor(False), persistent=False) self._rope_cos = None self._rope_sin = None def _ensure_rope(self, seq_len: int, device, dtype): head_dim = self.cfg.hidden_size // self.cfg.n_heads if (self._rope_cos is None or self._rope_cos.size(0) < seq_len or self._rope_cos.device != device): cos, sin = _rope_cache(self.cfg.max_position_embeddings, head_dim, self.cfg.rope_theta, device) self._rope_cos = cos.to(dtype) self._rope_sin = sin.to(dtype) def forward(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, list[dict]]: B, T = input_ids.shape h = self.embed_tokens(input_ids) self._ensure_rope(T, h.device, h.dtype) aux_list = [] for blk in self.blocks: h, aux = blk(h, self._rope_cos, self._rope_sin) aux_list.append(aux) h = self.final_norm(h) return h, aux_list class PlutoForCausalLM(nn.Module): """LM head + optional MTP heads. Returns full loss in `forward`.""" def __init__(self, cfg: PlutoConfig): super().__init__() self.cfg = cfg self.model = PlutoModel(cfg) self.lm_head = nn.Linear(cfg.hidden_size, cfg.vocab_size, bias=False) if cfg.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight # MTP — training-only auxiliary heads that predict tokens further ahead. self.mtp_heads = nn.ModuleList([ nn.Linear(cfg.hidden_size, cfg.vocab_size, bias=False) for _ in range(cfg.mtp_depth) ]) def forward(self, input_ids: torch.Tensor, labels: torch.Tensor | None = None, attention_mask: torch.Tensor | None = None, ) -> dict: # We only honour `labels` from the training harness (HF API). if labels is None: labels = input_ids h, aux_list = self.model(input_ids) logits = self.lm_head(h) out = {"logits": logits} # Main next-token loss. Trainer is expected to pass `input_ids = ids[:-1]` # and `labels = ids[1:]` so they already align (no internal shift). if labels is not None and labels.size(1) == logits.size(1): ce = F.cross_entropy( logits.float().view(-1, logits.size(-1)), labels.view(-1), ignore_index=-100, ) loss = ce # MTP auxiliary losses: head d predicts the token d positions ahead. # Skip entirely when mtp_loss_weight == 0 to save the per-head matmul # against the full vocab — that head alone is ~15-20 % of step time. if self.cfg.mtp_depth > 0 and self.cfg.mtp_loss_weight > 0: mtp_total = 0.0 for d, head in enumerate(self.mtp_heads, start=1): if labels.size(1) <= d: continue logits_d = head(h)[:, :-d, :].contiguous() labels_d = labels[:, d:].contiguous() mtp_total = mtp_total + F.cross_entropy( logits_d.float().view(-1, logits_d.size(-1)), labels_d.view(-1), ignore_index=-100, ) loss = loss + self.cfg.mtp_loss_weight * (mtp_total / max(self.cfg.mtp_depth, 1)) # Router aux losses (averaged over layers) if aux_list and "aux_load" in aux_list[0]: aux_load = torch.stack([a["aux_load"] for a in aux_list]).mean() aux_z = torch.stack([a["aux_z"] for a in aux_list]).mean() loss = (loss + self.cfg.router_aux_loss_coef * aux_load + self.cfg.router_z_loss_coef * aux_z) out["loss"] = loss return out # ─── Save / load ──────────────────────────────────────────────────────── def save_pluto(model: PlutoForCausalLM, output_dir: str | Path) -> None: model.cfg.save(output_dir) from safetensors.torch import save_model # `save_model` handles tied weights (embed↔lm_head) by deduplicating them. # We must NOT permanently move the model to CPU — restore device after save. devices = {p.device for p in model.parameters()} device = next(iter(devices)) if len(devices) == 1 else None model_cpu = model.cpu() save_model(model_cpu, str(Path(output_dir) / "model.safetensors")) if device is not None and device.type != "cpu": model.to(device) def load_pluto(model_dir: str | Path, dtype=torch.bfloat16, map_location="cpu") -> PlutoForCausalLM: cfg = PlutoConfig.load(model_dir) model = PlutoForCausalLM(cfg).to(dtype) from safetensors.torch import load_file state = load_file(str(Path(model_dir) / "model.safetensors"), device=str(map_location)) model.load_state_dict(state, strict=False) return model # ─── Param accounting ────────────────────────────────────────────────── def count_params(model: nn.Module) -> int: return sum(p.numel() for p in model.parameters()) def estimate_active_params(cfg: PlutoConfig) -> dict: """At-inference active params (MTP heads NOT counted, since they are training-only).""" head_dim = cfg.hidden_size // cfg.n_heads attn_per_layer = ( cfg.hidden_size * cfg.n_heads * head_dim # q_proj + cfg.hidden_size * cfg.n_kv_heads * head_dim # k_proj + cfg.hidden_size * cfg.n_kv_heads * head_dim # v_proj + cfg.hidden_size * cfg.hidden_size # o_proj ) expert_size = 3 * cfg.hidden_size * cfg.intermediate_size_expert # SwiGLU shared_size = (3 * cfg.hidden_size * cfg.intermediate_size_shared if cfg.intermediate_size_shared > 0 else 0) active_per_layer = attn_per_layer + cfg.top_k * expert_size + shared_size active_total = active_per_layer * cfg.n_layers # lm_head is also "active" (full matmul against vocab) active_total += cfg.vocab_size * cfg.hidden_size total_experts = expert_size * cfg.n_experts * cfg.n_layers total_shared = shared_size * cfg.n_layers total_attn = attn_per_layer * cfg.n_layers emb_params = cfg.vocab_size * cfg.hidden_size lm_head_params = 0 if cfg.tie_word_embeddings else cfg.vocab_size * cfg.hidden_size mtp_params = cfg.mtp_depth * cfg.vocab_size * cfg.hidden_size total_params = (total_experts + total_shared + total_attn + emb_params + lm_head_params + mtp_params + 2 * cfg.n_layers * cfg.hidden_size # RMSNorm weights + cfg.hidden_size) return { "total_params": total_params, "active_inference_params": active_total, "expert_total_params": total_experts, "attn_total_params": total_attn, "embedding_params": emb_params, "lm_head_params": lm_head_params, "mtp_head_params": mtp_params, } if __name__ == "__main__": cfg = PlutoConfig() stats = estimate_active_params(cfg) for k, v in stats.items(): print(f" {k:<28} {v/1e6:>8.2f} M") print(f" active/total ratio {stats['active_inference_params']/stats['total_params']*100:>5.2f} %") m = PlutoForCausalLM(cfg) n_real = count_params(m) print(f"\n real (actual) total {n_real/1e6:>8.2f} M") x = torch.randint(0, cfg.vocab_size, (2, 32)) out = m(x, labels=x) print(f" fwd OK logits {tuple(out['logits'].shape)} loss={out['loss'].item():.4f}")