pluto-nano-0.5-base / modeling_pluto.py
ShinMK3's picture
Upload folder using huggingface_hub
184079e verified
Raw
History Blame Contribute Delete
20.4 kB
"""
ASTRAI Pluto -- native architecture for the Pluto family.

A standalone decoder-only Transformer with:
  * RMSNorm + RoPE (no learned positional embeddings)
  * Causal SDPA attention (multi-head, optional GQA)
  * Top-K Mixture-of-Experts (SwiGLU experts); the shared expert is optional
    and disabled by default
  * Multi-Token Prediction heads (training-only)
  * Tied input/output embedding (configurable)
  * Router auxiliary loss (load balance) + z-loss

Not derived from any HuggingFace base model -- fresh implementation in plain
PyTorch. Save/load uses a `pluto_config.json` + a `model.safetensors` weights
file.

Naming: `PlutoModel` (backbone) / `PlutoForCausalLM` (LM head + losses). All
hyper-parameters -- sizes as well as routing / aux-loss coefficients -- live on
the single `PlutoConfig` dataclass.
"""
from __future__ import annotations
import json
import math
import os
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
# ─── Config ─────────────────────────────────────────────────────────────
@dataclass
class PlutoConfig:
    """Every hyper-parameter of a Pluto model, persisted as `pluto_config.json`.

    The defaults describe the multilingual Nano variant. `from_dict` / `load`
    silently drop keys that are not fields, so configs written by newer code
    still load.
    """

    # Architecture (multilingual Nano -- d=384, layers=16, GQA, 32k vocab)
    vocab_size: int = 32768
    hidden_size: int = 384
    intermediate_size_expert: int = 1536
    intermediate_size_shared: int = 0  # 0 = no shared expert
    n_layers: int = 16
    n_heads: int = 6
    n_kv_heads: int = 2  # GQA: 6 query heads share 2 K/V heads (~1/3 fewer attn params than MHA)
    n_experts: int = 35  # 5 langs x 7 experts each
    top_k: int = 1  # max sparsity -> ~50 M active params at inference
    n_languages: int = 5  # en, pt, es, zh, hi
    max_position_embeddings: int = 4096
    rope_theta: float = 1_000_000.0
    rms_norm_eps: float = 1e-6
    tie_word_embeddings: bool = True
    # MTP -- training-only aux heads
    mtp_depth: int = 2
    mtp_loss_weight: float = 0.15
    # Routing aux losses
    router_aux_loss_coef: float = 0.01
    router_z_loss_coef: float = 0.001
    # Bookkeeping
    model_type: str = "astrai_pluto"
    pad_token_id: int | None = None
    bos_token_id: int | None = None
    eos_token_id: int | None = None
    # Tokenizer config (saved for convenience)
    tokenizer_name: str | None = None

    def to_dict(self) -> dict:
        """Return the config as a plain dict (one entry per field)."""
        return asdict(self)

    @classmethod
    def from_dict(cls, d: dict) -> "PlutoConfig":
        """Build a config from `d`, silently ignoring unknown keys (forward-compat).

        Fields missing from `d` keep their defaults.
        """
        from dataclasses import fields

        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in d.items() if k in known})

    def save(self, output_dir: str | Path) -> None:
        """Write `pluto_config.json` into `output_dir`, creating it if needed."""
        out_dir = Path(output_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        with open(out_dir / "pluto_config.json", "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def load(cls, model_dir: str | Path) -> "PlutoConfig":
        """Read `pluto_config.json` from `model_dir` (see `from_dict`)."""
        with open(Path(model_dir) / "pluto_config.json", encoding="utf-8") as f:
            return cls.from_dict(json.load(f))
# ─── Layers ─────────────────────────────────────────────────────────────
class RMSNorm(nn.Module):
    """Root-mean-square norm with a learned per-channel gain (no bias, no centering).

    The RMS statistic is computed in float32; the normalised activations are
    cast back to the input dtype before the gain is applied.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x32 = x.float()
        inv_rms = torch.rsqrt(x32.square().mean(dim=-1, keepdim=True) + self.eps)
        normed = (x32 * inv_rms).to(dtype=x.dtype)
        return normed * self.weight
def _rope_freqs(dim: int, base: float, device, dtype=torch.float32) -> torch.Tensor:
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=dtype) / dim))
return inv_freq
def _rope_cache(seq_len: int, dim: int, base: float, device) -> tuple[torch.Tensor, torch.Tensor]:
    """Return float32 (cos, sin) tables of shape [seq_len, len(_rope_freqs(dim, ...))].

    Row t holds the rotation angles t * inv_freq for position t.
    """
    positions = torch.arange(seq_len, device=device, dtype=torch.float32)
    angles = positions[:, None] * _rope_freqs(dim, base, device)[None, :]
    return angles.cos(), angles.sin()
def _apply_rope(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
# q, k: [B, H, T, Dh]; cos, sin: [T, Dh/2]
def rotate(x: torch.Tensor) -> torch.Tensor:
x1, x2 = x[..., ::2], x[..., 1::2]
rot = torch.stack((-x2 * sin + x1 * cos, x1 * sin + x2 * cos), dim=-1)
return rot.flatten(-2)
return rotate(q), rotate(k)
class PlutoAttention(nn.Module):
    """Causal self-attention with RoPE and optional grouped-query attention (GQA).

    Projections are bias-free. When n_kv_heads < n_heads, each K/V head is
    repeated (n_heads // n_kv_heads) times to line up with the query heads.
    Attention uses `F.scaled_dot_product_attention(..., is_causal=True)`; no
    padding mask is applied.
    """

    def __init__(self, cfg: PlutoConfig):
        super().__init__()
        assert cfg.hidden_size % cfg.n_heads == 0, (
            f"hidden_size ({cfg.hidden_size}) must be divisible by n_heads ({cfg.n_heads})"
        )
        # Catch a bad GQA layout at construction time. Otherwise the module builds
        # fine and only fails on the first forward: repeat_interleave yields the
        # wrong K/V head count (SDPA shape error) or H // 0 raises.
        if cfg.n_kv_heads <= 0 or cfg.n_heads % cfg.n_kv_heads != 0:
            raise ValueError(
                f"n_heads ({cfg.n_heads}) must be a positive multiple of "
                f"n_kv_heads ({cfg.n_kv_heads})"
            )
        self.cfg = cfg
        self.head_dim = cfg.hidden_size // cfg.n_heads
        self.q_proj = nn.Linear(cfg.hidden_size, cfg.n_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(cfg.hidden_size, cfg.n_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(cfg.hidden_size, cfg.n_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=False)

    def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        """Attend over x ([B, T, hidden_size]) and return [B, T, hidden_size].

        `cos` / `sin` are RoPE tables with at least T rows; only the first T are
        used (cast to the activation dtype).
        """
        B, T, D = x.shape
        H = self.cfg.n_heads
        Hk = self.cfg.n_kv_heads
        Dh = self.head_dim
        q = self.q_proj(x).view(B, T, H, Dh).transpose(1, 2)   # [B, H, T, Dh]
        k = self.k_proj(x).view(B, T, Hk, Dh).transpose(1, 2)  # [B, Hk, T, Dh]
        v = self.v_proj(x).view(B, T, Hk, Dh).transpose(1, 2)
        q, k = _apply_rope(q, k, cos[:T].to(q.dtype), sin[:T].to(q.dtype))
        # GQA: expand K/V heads to match the query heads.
        if Hk != H:
            repeats = H // Hk
            k = k.repeat_interleave(repeats, dim=1)
            v = v.repeat_interleave(repeats, dim=1)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        y = y.transpose(1, 2).contiguous().view(B, T, D)
        return self.o_proj(y)
class SwiGLU(nn.Module):
    """Gated feed-forward: w_down(silu(w_gate(x)) * w_up(x)); all projections bias-free."""

    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.w_gate = nn.Linear(dim, hidden, bias=False)
        self.w_up = nn.Linear(dim, hidden, bias=False)
        self.w_down = nn.Linear(hidden, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = F.silu(self.w_gate(x))
        return self.w_down(gated * self.w_up(x))
class PlutoMoE(nn.Module):
    """Top-K Mixture-of-Experts layer with SwiGLU experts.

    Expert weights are three stacked tensors -- W_gate / W_up of shape [E, D, H]
    and W_down of shape [E, H, D] -- so one routing pass is three grouped GEMMs
    via `torch._grouped_mm` (a private PyTorch API) on CUDA. On other devices,
    or torch builds without `_grouped_mm`, the same math runs as a per-expert
    matmul loop over each expert's contiguous row range.

    Routing:
      * top_k == 1: every token goes through its argmax expert. The expert
        output is NOT scaled by a router probability, so the router only gets
        gradient from the aux losses (left unchanged: scaling would alter the
        outputs of already-trained weights).
      * top_k > 1: one routing pass per rank k; the k-th expert's output is
        weighted by the softmax over the token's top-k router logits.

    Optional shared expert (always active) if intermediate_size_shared > 0.

    forward(x [B, T, D]) -> (out [B, T, D], aux). In training mode aux holds
    "aux_load" (Switch-style load balance over each token's top-1 expert) and
    "aux_z" (router z-loss); otherwise aux is empty.
    """

    def __init__(self, cfg: PlutoConfig):
        super().__init__()
        self.cfg = cfg
        E, D, H = cfg.n_experts, cfg.hidden_size, cfg.intermediate_size_expert
        self.router = nn.Linear(D, E, bias=False)
        # SwiGLU expert weights stacked along the expert dim.
        # `_grouped_mm(A, B, offs)` expects B in [E, K, N] for A in [M, K]
        # -> output [M, N]. So we store:
        #   W_gate: [E, D, H]   x @ W_gate -> [M, H]
        #   W_up:   [E, D, H]
        #   W_down: [E, H, D]
        self.W_gate = nn.Parameter(torch.empty(E, D, H))
        self.W_up = nn.Parameter(torch.empty(E, D, H))
        self.W_down = nn.Parameter(torch.empty(E, H, D))
        # Init: normal with std 1/sqrt(fan_in) for each projection.
        std_in = 1.0 / math.sqrt(D)
        std_h = 1.0 / math.sqrt(H)
        nn.init.normal_(self.W_gate, std=std_in)
        nn.init.normal_(self.W_up, std=std_in)
        nn.init.normal_(self.W_down, std=std_h)
        self.shared = (SwiGLU(D, cfg.intermediate_size_shared)
                       if cfg.intermediate_size_shared > 0 else None)

    @staticmethod
    def _offsets_from_counts(counts: torch.Tensor) -> torch.Tensor:
        # Convert [E] per-expert row counts -> [E] int32 END offsets, i.e. the
        # inclusive cumsum: expert e owns rows [offsets[e-1], offsets[e]).
        return counts.cumsum(0).to(torch.int32)

    def _expert_swiglu(self, x_sorted: torch.Tensor, counts: torch.Tensor) -> torch.Tensor:
        """Apply expert e's SwiGLU to its contiguous row range of x_sorted.

        x_sorted: [N, D] rows grouped by expert id (ascending); counts: [E] rows
        per expert. Returns [N, D] in the same (sorted) row order.
        """
        if x_sorted.is_cuda and hasattr(torch, "_grouped_mm"):
            offsets = self._offsets_from_counts(counts)
            gate = torch._grouped_mm(x_sorted, self.W_gate, offsets)  # [N, H]
            up = torch._grouped_mm(x_sorted, self.W_up, offsets)      # [N, H]
            return torch._grouped_mm(F.silu(gate) * up, self.W_down, offsets)
        # Fallback: identical math, one matmul triple per expert (empty ranges
        # produce empty chunks, which keeps the concatenation well-formed).
        pieces = []
        for e, rows in enumerate(x_sorted.split(counts.tolist())):
            hidden = F.silu(rows @ self.W_gate[e]) * (rows @ self.W_up[e])
            pieces.append(hidden @ self.W_down[e])
        return torch.cat(pieces)

    def _dispatch(self, x_flat: torch.Tensor, expert_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Run row i of x_flat through expert expert_ids[i].

        Returns (out [N, D] in the original row order, counts [E]).
        """
        sort_idx = expert_ids.argsort(stable=True)
        counts = torch.bincount(expert_ids, minlength=self.cfg.n_experts)
        out_sorted = self._expert_swiglu(x_flat[sort_idx], counts)
        # Un-sort with the inverse permutation.
        inverse = torch.empty_like(sort_idx)
        inverse[sort_idx] = torch.arange(sort_idx.size(0), device=sort_idx.device)
        return out_sorted[inverse], counts

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, dict]:
        B, T, D = x.shape
        E = self.cfg.n_experts
        x_flat = x.reshape(B * T, D)
        logits = self.router(x_flat)  # [B*T, E]

        if self.cfg.top_k == 1:
            top_idx = logits.argmax(dim=-1)  # [B*T]
            out, counts = self._dispatch(x_flat, top_idx)
        else:
            topk_vals, topk_idx = logits.topk(self.cfg.top_k, dim=-1)
            topk_w = F.softmax(topk_vals, dim=-1)
            out = torch.zeros_like(x_flat)
            for k in range(self.cfg.top_k):
                out_k, _ = self._dispatch(x_flat, topk_idx[..., k])
                out = out + out_k * topk_w[..., k].unsqueeze(-1)
            # Load-balance statistics use each token's top-1 expert -- not the
            # dispatch counts of whichever rank happened to run last.
            counts = torch.bincount(topk_idx[..., 0], minlength=E)

        if self.shared is not None:
            out = out + self.shared(x_flat)
        out = out.reshape(B, T, D)

        # Auxiliary losses (Switch Transformer load-balance + ST-MoE z-loss)
        aux: dict = {}
        if self.training:
            probs = F.softmax(logits.float(), dim=-1)
            expert_freq = probs.mean(dim=0)  # [E] mean router probability
            counts_norm = counts.float() / counts.float().sum().clamp_min(1.0)
            aux["aux_load"] = (expert_freq * counts_norm).sum() * self.cfg.n_experts
            aux["aux_z"] = (logits.float().logsumexp(-1) ** 2).mean()
        return out, aux
class PlutoBlock(nn.Module):
    """Pre-norm decoder layer: attention, then MoE, each inside a residual add.

    forward(x, cos, sin) -> (x_out, aux), where aux is the MoE's aux-loss dict
    (filled only in training mode).
    """

    def __init__(self, cfg: PlutoConfig):
        super().__init__()
        self.ln1 = RMSNorm(cfg.hidden_size, cfg.rms_norm_eps)
        self.attn = PlutoAttention(cfg)
        self.ln2 = RMSNorm(cfg.hidden_size, cfg.rms_norm_eps)
        self.moe = PlutoMoE(cfg)

    def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> tuple[torch.Tensor, dict]:
        after_attn = x + self.attn(self.ln1(x), cos, sin)
        moe_out, aux = self.moe(self.ln2(after_attn))
        return after_attn + moe_out, aux
# ─── Models ─────────────────────────────────────────────────────────────
class PlutoModel(nn.Module):
    """Decoder backbone: token embedding -> n_layers PlutoBlocks -> final RMSNorm.

    RoPE cos/sin tables are cached as plain attributes (not buffers), so
    `.to()` does not move them; `_ensure_rope` rebuilds them whenever the
    device, dtype or required length changes.
    """

    def __init__(self, cfg: PlutoConfig):
        super().__init__()
        self.cfg = cfg
        self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.hidden_size)
        self.blocks = nn.ModuleList([PlutoBlock(cfg) for _ in range(cfg.n_layers)])
        self.final_norm = RMSNorm(cfg.hidden_size, cfg.rms_norm_eps)
        # Not read anywhere in this file; kept (non-persistent, so not in the
        # state_dict) so the module's buffer listing is unchanged.
        self.register_buffer("_rope_initialised", torch.tensor(False), persistent=False)
        self._rope_cos = None
        self._rope_sin = None

    def _ensure_rope(self, seq_len: int, device, dtype):
        """Make the cached RoPE tables cover `seq_len` positions on `device` in `dtype`.

        Tables are built for max(seq_len, max_position_embeddings) positions:
        the old cache was always max_position_embeddings long, so longer inputs
        rebuilt it on every call and still hit a shape mismatch in attention.
        dtype is part of the cache key so a table rounded to bf16 is not reused
        after the model is cast to fp32.
        """
        cached = self._rope_cos
        if (cached is not None and cached.size(0) >= seq_len
                and cached.device == device and cached.dtype == dtype):
            return
        head_dim = self.cfg.hidden_size // self.cfg.n_heads
        n_pos = max(seq_len, self.cfg.max_position_embeddings)
        cos, sin = _rope_cache(n_pos, head_dim, self.cfg.rope_theta, device)
        self._rope_cos = cos.to(dtype)
        self._rope_sin = sin.to(dtype)

    def forward(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, list[dict]]:
        """Embed `input_ids` ([B, T]), run every block and apply the final norm.

        Returns (hidden [B, T, hidden_size], list of per-layer MoE aux dicts).
        There is no padding mask; attention is purely causal.
        """
        B, T = input_ids.shape
        h = self.embed_tokens(input_ids)
        self._ensure_rope(T, h.device, h.dtype)
        aux_list = []
        for blk in self.blocks:
            h, aux = blk(h, self._rope_cos, self._rope_sin)
            aux_list.append(aux)
        h = self.final_norm(h)
        return h, aux_list
class PlutoForCausalLM(nn.Module):
    """Causal LM: PlutoModel backbone + LM head + optional training-only MTP heads.

    `forward` returns {"logits": ...} and, when `labels` is supplied and aligned
    with the logits, also "loss" (next-token CE + weighted MTP CE + router aux
    losses).
    """

    def __init__(self, cfg: PlutoConfig):
        super().__init__()
        self.cfg = cfg
        self.model = PlutoModel(cfg)
        self.lm_head = nn.Linear(cfg.hidden_size, cfg.vocab_size, bias=False)
        if cfg.tie_word_embeddings:
            # One Parameter shared by the input embedding and output projection.
            self.lm_head.weight = self.model.embed_tokens.weight
        # MTP -- training-only auxiliary heads that predict tokens further ahead.
        self.mtp_heads = nn.ModuleList([
            nn.Linear(cfg.hidden_size, cfg.vocab_size, bias=False)
            for _ in range(cfg.mtp_depth)
        ])

    def forward(self, input_ids: torch.Tensor, labels: torch.Tensor | None = None,
                attention_mask: torch.Tensor | None = None,
                ) -> dict:
        """Run the model and, if `labels` is given, compute the training loss.

        Args:
            input_ids: [B, T] token ids.
            labels: [B, T] targets already aligned with `input_ids` -- the
                trainer passes input_ids = ids[:-1] and labels = ids[1:]; there
                is no internal shift. Positions equal to -100 are ignored. When
                None (e.g. inference) no loss is computed and the MTP heads are
                not run.
            attention_mask: accepted for HF-style call compatibility; ignored
                (attention is purely causal, no padding mask).

        Returns:
            dict with "logits" [B, T, vocab_size] and, when `labels` has the same
            length as the logits, a scalar "loss".
        """
        h, aux_list = self.model(input_ids)
        logits = self.lm_head(h)
        out = {"logits": logits}
        # labels used to default to input_ids, which produced an unshifted
        # (meaningless) loss and ran every full-vocab MTP head on inference calls.
        if labels is None or labels.size(1) != logits.size(1):
            return out

        vocab = logits.size(-1)
        # reshape (not view): sliced labels such as ids[:, 1:] are non-contiguous.
        loss = F.cross_entropy(
            logits.float().reshape(-1, vocab),
            labels.reshape(-1),
            ignore_index=-100,
        )

        # MTP: head d predicts labels[:, t + d] from position t. Skipped entirely
        # when mtp_loss_weight == 0, since each head is a full-vocab matmul.
        if self.cfg.mtp_depth > 0 and self.cfg.mtp_loss_weight > 0:
            mtp_losses = []
            for d, head in enumerate(self.mtp_heads, start=1):
                if labels.size(1) <= d:
                    continue  # sequence too short to have a target d steps ahead
                # Project only the positions that actually have a target.
                logits_d = head(h[:, :-d])
                mtp_losses.append(F.cross_entropy(
                    logits_d.float().reshape(-1, vocab),
                    labels[:, d:].reshape(-1),
                    ignore_index=-100,
                ))
            if mtp_losses:
                # Average over the heads that actually contributed.
                loss = loss + self.cfg.mtp_loss_weight * torch.stack(mtp_losses).mean()

        # Router aux losses (averaged over layers); MoE layers only emit them in
        # training mode.
        if aux_list and "aux_load" in aux_list[0]:
            aux_load = torch.stack([a["aux_load"] for a in aux_list]).mean()
            aux_z = torch.stack([a["aux_z"] for a in aux_list]).mean()
            loss = (loss + self.cfg.router_aux_loss_coef * aux_load
                    + self.cfg.router_z_loss_coef * aux_z)

        out["loss"] = loss
        return out
# ─── Save / load ────────────────────────────────────────────────────────
def save_pluto(model: PlutoForCausalLM, output_dir: str | Path) -> None:
    """Write `pluto_config.json` and `model.safetensors` into `output_dir`.

    The model is left exactly where it is. The previous version moved it to the
    CPU and moved it back only when every parameter shared one device, so
    multi-device models -- and any model, if the save raised -- were left on
    the CPU. Saving in place relies on safetensors copying non-CPU tensors to
    host memory during serialisation (see the safetensors torch API).
    `save_model` de-duplicates the tied embed_tokens <-> lm_head weight.
    """
    model.cfg.save(output_dir)  # also creates output_dir
    from safetensors.torch import save_model

    save_model(model, str(Path(output_dir) / "model.safetensors"))
def load_pluto(model_dir: str | Path, dtype=torch.bfloat16, map_location="cpu") -> PlutoForCausalLM:
    """Rebuild a PlutoForCausalLM from a directory written by `save_pluto`.

    The model is constructed on torch's default device (normally the CPU) and
    cast to `dtype`. Weights are read onto `map_location` and copied into the
    existing parameters, so the returned model's placement does not follow
    `map_location`. No `.eval()` is called; the model is in training mode.

    Loading stays non-strict, but missing / unexpected keys are now reported
    with `warnings.warn` instead of being silently dropped. Missing keys keep
    their random initialisation. With tied embeddings safetensors stores the
    shared tensor under one name, so the other name is not reported.
    """
    import warnings

    cfg = PlutoConfig.load(model_dir)
    model = PlutoForCausalLM(cfg).to(dtype)
    from safetensors.torch import load_file

    state = load_file(str(Path(model_dir) / "model.safetensors"), device=str(map_location))
    result = model.load_state_dict(state, strict=False)

    tied = ("lm_head.weight", "model.embed_tokens.weight") if cfg.tie_word_embeddings else ()
    tied_loaded = any(name in state for name in tied)
    missing = [k for k in result.missing_keys if not (tied_loaded and k in tied)]
    unexpected = list(result.unexpected_keys)
    if missing or unexpected:
        warnings.warn(
            f"load_pluto({model_dir!s}): missing keys {missing}, unexpected keys {unexpected}",
            stacklevel=2,
        )
    return model
# ─── Param accounting ──────────────────────────────────────────────────
def count_params(model: nn.Module) -> int:
    """Sum of `numel()` over `model.parameters()`.

    PyTorch yields a shared (tied) Parameter once, so it is counted once.
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
def estimate_active_params(cfg: PlutoConfig) -> dict:
    """Analytic parameter counts for `cfg` (no model is built).

    "active_inference_params" is what one token touches at inference: per layer
    the attention projections, the router, top_k routed experts and the shared
    expert, plus the LM-head matmul. MTP heads (training-only), the embedding
    lookup and norm weights are excluded.

    "total_params" counts every stored parameter once (a tied LM head shares
    the embedding), so it matches `count_params(PlutoForCausalLM(cfg))`. The
    router (hidden_size x n_experts per layer) was previously missing from both
    totals; it is now included and also reported as "router_total_params".
    """
    head_dim = cfg.hidden_size // cfg.n_heads
    attn_per_layer = (
        cfg.hidden_size * cfg.n_heads * head_dim        # q_proj
        + cfg.hidden_size * cfg.n_kv_heads * head_dim   # k_proj
        + cfg.hidden_size * cfg.n_kv_heads * head_dim   # v_proj
        + cfg.hidden_size * cfg.hidden_size             # o_proj
    )
    router_per_layer = cfg.hidden_size * cfg.n_experts  # bias-free Linear(D, E)
    expert_size = 3 * cfg.hidden_size * cfg.intermediate_size_expert  # SwiGLU
    shared_size = (3 * cfg.hidden_size * cfg.intermediate_size_shared
                   if cfg.intermediate_size_shared > 0 else 0)

    active_per_layer = attn_per_layer + router_per_layer + cfg.top_k * expert_size + shared_size
    active_total = active_per_layer * cfg.n_layers
    # lm_head is also "active" (full matmul against vocab)
    active_total += cfg.vocab_size * cfg.hidden_size

    total_experts = expert_size * cfg.n_experts * cfg.n_layers
    total_shared = shared_size * cfg.n_layers
    total_attn = attn_per_layer * cfg.n_layers
    total_router = router_per_layer * cfg.n_layers
    emb_params = cfg.vocab_size * cfg.hidden_size
    lm_head_params = 0 if cfg.tie_word_embeddings else cfg.vocab_size * cfg.hidden_size
    mtp_params = cfg.mtp_depth * cfg.vocab_size * cfg.hidden_size
    norm_params = 2 * cfg.n_layers * cfg.hidden_size + cfg.hidden_size  # ln1/ln2 per layer + final
    total_params = (total_experts + total_shared + total_attn + total_router
                    + emb_params + lm_head_params + mtp_params + norm_params)
    return {
        "total_params": total_params,
        "active_inference_params": active_total,
        "expert_total_params": total_experts,
        "attn_total_params": total_attn,
        "embedding_params": emb_params,
        "lm_head_params": lm_head_params,
        "mtp_head_params": mtp_params,
        "router_total_params": total_router,
    }
if __name__ == "__main__":
    # Smoke test: analytic vs. real parameter counts + one training-style forward.
    cfg = PlutoConfig()
    stats = estimate_active_params(cfg)
    for k, v in stats.items():
        print(f" {k:<28} {v/1e6:>8.2f} M")
    print(f" active/total ratio {stats['active_inference_params']/stats['total_params']*100:>5.2f} %")
    m = PlutoForCausalLM(cfg)
    n_real = count_params(m)
    print(f"\n real (actual) total {n_real/1e6:>8.2f} M")
    # Follow the training contract: inputs are ids[:-1], labels are ids[1:]
    # (the model does no internal shift). Labels are made contiguous so this
    # also works with implementations that call .view() on them.
    ids = torch.randint(0, cfg.vocab_size, (2, 33))
    x, labels = ids[:, :-1], ids[:, 1:].contiguous()
    out = m(x, labels=labels)
    print(f" fwd OK logits {tuple(out['logits'].shape)} loss={out['loss'].item():.4f}")