Text Generation
Transformers
Safetensors
astrai_pluto
mixture-of-experts
Mixture of Experts
astrai
pluto-nano
base
causal-lm
custom_code
Instructions to use ASTRAI-labs/pluto-nano-0.5-base with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use ASTRAI-labs/pluto-nano-0.5-base with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("text-generation", model="ASTRAI-labs/pluto-nano-0.5-base", trust_remote_code=True)

# Load model directly
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("ASTRAI-labs/pluto-nano-0.5-base", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- vLLM
How to use ASTRAI-labs/pluto-nano-0.5-base with vLLM:
Install from pip and serve model
# Install vLLM from pip:
pip install vllm

# Start the vLLM server:
vllm serve "ASTRAI-labs/pluto-nano-0.5-base"

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:8000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{ "model": "ASTRAI-labs/pluto-nano-0.5-base", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'
Use Docker
docker model run hf.co/ASTRAI-labs/pluto-nano-0.5-base
- SGLang
How to use ASTRAI-labs/pluto-nano-0.5-base with SGLang:
Install from pip and serve model
# Install SGLang from pip:
pip install sglang

# Start the SGLang server:
python3 -m sglang.launch_server \
  --model-path "ASTRAI-labs/pluto-nano-0.5-base" \
  --host 0.0.0.0 \
  --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{ "model": "ASTRAI-labs/pluto-nano-0.5-base", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'
Use Docker images
docker run --gpus all \
  --shm-size 32g \
  -p 30000:30000 \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  --env "HF_TOKEN=<secret>" \
  --ipc=host \
  lmsysorg/sglang:latest \
  python3 -m sglang.launch_server \
  --model-path "ASTRAI-labs/pluto-nano-0.5-base" \
  --host 0.0.0.0 \
  --port 30000

# Call the server using curl (OpenAI-compatible API):
curl -X POST "http://localhost:30000/v1/completions" \
  -H "Content-Type: application/json" \
  --data '{ "model": "ASTRAI-labs/pluto-nano-0.5-base", "prompt": "Once upon a time,", "max_tokens": 512, "temperature": 0.5 }'
- Docker Model Runner
How to use ASTRAI-labs/pluto-nano-0.5-base with Docker Model Runner:
docker model run hf.co/ASTRAI-labs/pluto-nano-0.5-base
| """ | |
| ASTRAI Pluto — native architecture for the Pluto family. | |
| A standalone decoder-only Transformer with: | |
| * RMSNorm + RoPE (no learned positional embeddings) | |
| * Causal SDPA attention (multi-head, optional GQA) | |
| * Top-K Mixture-of-Experts (SwiGLU experts), no required shared expert | |
| * Multi-Token Prediction heads (training-only) | |
| * Tied input/output embedding | |
| * Router auxiliary loss (load balance) + z-loss | |
| Not derived from any HuggingFace base model — fresh implementation in plain | |
| PyTorch. Save/load uses a `pluto_config.json` + a safetensors weights file. | |
| Naming: `PlutoModel` / `PlutoForCausalLM`. The `_meta` dict on the config holds | |
| size hyper-params; routing / aux-loss config is on its own dataclass. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import math | |
| import os | |
| from dataclasses import asdict, dataclass, field | |
| from pathlib import Path | |
| from typing import Optional | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| # ─── Config ───────────────────────────────────────────────────────────── | |
@dataclass
class PlutoConfig:
    """Hyper-parameters of a Pluto model, serialisable to ``pluto_config.json``.

    Every attribute is a dataclass field with a default, so ``PlutoConfig()``
    yields the multilingual Nano configuration described by the defaults below.
    """

    # Architecture (multilingual Nano — d=384, layers=16, GQA, 32k vocab)
    vocab_size: int = 32768
    hidden_size: int = 384
    intermediate_size_expert: int = 1536
    intermediate_size_shared: int = 0  # 0 = no shared expert
    n_layers: int = 16
    n_heads: int = 6
    n_kv_heads: int = 2  # GQA: 6→2 → ~50 % attn-param saving
    n_experts: int = 35  # 5 langs × 7 experts each
    top_k: int = 1  # max sparsity → ~50 M active inference
    n_languages: int = 5  # en, pt, es, zh, hi
    max_position_embeddings: int = 4096
    rope_theta: float = 1_000_000.0
    rms_norm_eps: float = 1e-6
    tie_word_embeddings: bool = True
    # MTP — training-only aux heads
    mtp_depth: int = 2
    mtp_loss_weight: float = 0.15
    # Routing aux losses
    router_aux_loss_coef: float = 0.01
    router_z_loss_coef: float = 0.001
    # Bookkeeping
    model_type: str = "astrai_pluto"
    pad_token_id: int | None = None
    bos_token_id: int | None = None
    eos_token_id: int | None = None
    # Tokenizer config (saved for convenience)
    tokenizer_name: str | None = None

    def to_dict(self) -> dict:
        """Return all fields as a plain ``dict`` (JSON-serialisable)."""
        return asdict(self)

    @classmethod
    def from_dict(cls, d: dict) -> "PlutoConfig":
        """Build a config from ``d``, keeping only keys that are known fields.

        Unknown keys are ignored silently for forward-compatibility with
        config files written by newer versions.
        """
        known = set(cls.__dataclass_fields__)
        return cls(**{k: v for k, v in d.items() if k in known})

    def save(self, output_dir: str | Path) -> None:
        """Write the config to ``<output_dir>/pluto_config.json``.

        The directory is created if it does not exist.
        """
        os.makedirs(output_dir, exist_ok=True)
        with open(Path(output_dir) / "pluto_config.json", "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def load(cls, model_dir: str | Path) -> "PlutoConfig":
        """Read ``<model_dir>/pluto_config.json`` and return the config."""
        with open(Path(model_dir) / "pluto_config.json", encoding="utf-8") as f:
            return cls.from_dict(json.load(f))
| # ─── Layers ───────────────────────────────────────────────────────────── | |
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dim with a learned scale."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Statistics are taken in fp32 for stability; the normalised tensor is
        # cast back to the caller's dtype before the learned scale is applied.
        in_dtype = x.dtype
        x32 = x.float()
        mean_sq = x32.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps)
        normed = (x32 * inv_rms).to(in_dtype)
        return normed * self.weight
def _rope_freqs(dim: int, base: float, device, dtype=torch.float32) -> torch.Tensor:
    """Return the ``dim // 2`` inverse RoPE frequencies ``base ** (-2i / dim)``."""
    exponents = torch.arange(0, dim, 2, device=device, dtype=dtype) / dim
    return 1.0 / (base ** exponents)
def _rope_cache(seq_len: int, dim: int, base: float, device) -> tuple[torch.Tensor, torch.Tensor]:
    """Build the RoPE lookup tables for positions ``0 .. seq_len - 1``.

    Returns ``(cos, sin)``, each of shape ``[seq_len, dim // 2]`` in fp32.
    """
    positions = torch.arange(seq_len, device=device, dtype=torch.float32)
    angles = torch.outer(positions, _rope_freqs(dim, base, device))
    return angles.cos(), angles.sin()
def _apply_rope(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    """Rotate adjacent (even, odd) channel pairs of ``q`` and ``k`` by the RoPE angles.

    q, k: [B, H, T, Dh]; cos, sin: [T, Dh/2]. Returns the rotated ``(q, k)``
    with the same shapes as the inputs.
    """

    def _rotate_pairs(t: torch.Tensor) -> torch.Tensor:
        even = t[..., 0::2]
        odd = t[..., 1::2]
        new_even = -odd * sin + even * cos
        new_odd = even * sin + odd * cos
        # Re-interleave so channel 2i / 2i+1 hold the rotated pair i.
        return torch.stack((new_even, new_odd), dim=-1).flatten(-2)

    return _rotate_pairs(q), _rotate_pairs(k)
class PlutoAttention(nn.Module):
    """Causal SDPA attention with optional GQA + RoPE.

    Raises:
        ValueError: if ``n_heads`` is not a multiple of ``n_kv_heads`` (the
            K/V heads could not be expanded to match the query heads) or if
            the per-head dimension is odd (RoPE rotates channel pairs).
    """

    def __init__(self, cfg: PlutoConfig):
        super().__init__()
        assert cfg.hidden_size % cfg.n_heads == 0
        if cfg.n_heads % cfg.n_kv_heads != 0:
            raise ValueError(
                f"n_heads ({cfg.n_heads}) must be a multiple of "
                f"n_kv_heads ({cfg.n_kv_heads})"
            )
        self.cfg = cfg
        self.head_dim = cfg.hidden_size // cfg.n_heads
        if self.head_dim % 2 != 0:
            raise ValueError(f"head_dim ({self.head_dim}) must be even for RoPE")
        self.q_proj = nn.Linear(cfg.hidden_size, cfg.n_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(cfg.hidden_size, cfg.n_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(cfg.hidden_size, cfg.n_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(cfg.hidden_size, cfg.hidden_size, bias=False)

    def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        """Attend over ``x`` ([B, T, D]) with a causal mask; returns [B, T, D].

        ``cos`` / ``sin`` are RoPE tables; only their first ``T`` rows are used.
        """
        B, T, D = x.shape
        H = self.cfg.n_heads
        Hk = self.cfg.n_kv_heads
        Dh = self.head_dim
        q = self.q_proj(x).view(B, T, H, Dh).transpose(1, 2)   # [B, H, T, Dh]
        k = self.k_proj(x).view(B, T, Hk, Dh).transpose(1, 2)  # [B, Hk, T, Dh]
        v = self.v_proj(x).view(B, T, Hk, Dh).transpose(1, 2)
        q, k = _apply_rope(q, k, cos[:T].to(q.dtype), sin[:T].to(q.dtype))
        # GQA: expand kv if Hk < H so every query head has a matching K/V head.
        if Hk != H:
            repeats = H // Hk
            k = k.repeat_interleave(repeats, dim=1)
            v = v.repeat_interleave(repeats, dim=1)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        y = y.transpose(1, 2).contiguous().view(B, T, D)
        return self.o_proj(y)
class SwiGLU(nn.Module):
    """Bias-free gated feed-forward: ``w_down(silu(w_gate(x)) * w_up(x))``."""

    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.w_gate = nn.Linear(dim, hidden, bias=False)
        self.w_up = nn.Linear(dim, hidden, bias=False)
        self.w_down = nn.Linear(hidden, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = F.silu(self.w_gate(x))
        lifted = self.w_up(x)
        return self.w_down(gated * lifted)
class PlutoMoE(nn.Module):
    """Top-K Mixture-of-Experts feed-forward layer using ``torch._grouped_mm``.

    Expert weights are kept as 3 stacked tensors of shape [E, D, H] (gate, up)
    and [E, H, D] (down), so one routed pass is 3 grouped GEMMs. Tokens are
    sorted by expert id so each expert owns a contiguous range of rows.

    ``top_k == 1`` is the fast path: one routed pass, output not re-weighted.
    ``top_k > 1`` runs one routed pass per slot and sums the results weighted
    by a softmax over the selected router logits.

    Optional shared expert (always active) if ``intermediate_size_shared > 0``.
    """

    def __init__(self, cfg: PlutoConfig):
        super().__init__()
        self.cfg = cfg
        E, D, H = cfg.n_experts, cfg.hidden_size, cfg.intermediate_size_expert
        self.router = nn.Linear(D, E, bias=False)
        # SwiGLU expert weights stacked along the expert dim.
        # `_grouped_mm(A, B, offs)` expects B in [E, K, N] for A in [M, K]
        # → output [M, N]. So we store:
        #   W_gate: [E, D, H]   → x @ W_gate → [M, H]
        #   W_up:   [E, D, H]
        #   W_down: [E, H, D]
        self.W_gate = nn.Parameter(torch.empty(E, D, H))
        self.W_up = nn.Parameter(torch.empty(E, D, H))
        self.W_down = nn.Parameter(torch.empty(E, H, D))
        # Init: Kaiming-like, scaled down so initial residual is well-behaved.
        std_in = 1.0 / math.sqrt(D)
        std_h = 1.0 / math.sqrt(H)
        nn.init.normal_(self.W_gate, std=std_in)
        nn.init.normal_(self.W_up, std=std_in)
        nn.init.normal_(self.W_down, std=std_h)
        self.shared = (SwiGLU(D, cfg.intermediate_size_shared)
                       if cfg.intermediate_size_shared > 0 else None)

    @staticmethod
    def _offsets_from_counts(counts: torch.Tensor) -> torch.Tensor:
        """Convert per-expert token counts [E] into int32 end-offsets [E].

        The result is the inclusive cumulative sum, i.e. entry ``e`` is the
        row index one past expert ``e``'s last token in the sorted batch.
        """
        return counts.cumsum(0).to(torch.int32)

    def _run_experts(self, x_flat: torch.Tensor,
                     expert_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Send each row of ``x_flat`` through the expert named in ``expert_ids``.

        Returns ``(out, counts)``: ``out`` is [N, D] in the original row order,
        ``counts`` is the [E] number of rows assigned to each expert.
        """
        # Stable sort by expert id → contiguous expert ranges → grouped GEMM.
        sort_idx = expert_ids.argsort(stable=True)
        x_sorted = x_flat[sort_idx]
        counts = torch.bincount(expert_ids, minlength=self.cfg.n_experts)
        offsets = self._offsets_from_counts(counts)
        gate = torch._grouped_mm(x_sorted, self.W_gate, offsets)   # [N, H]
        up = torch._grouped_mm(x_sorted, self.W_up, offsets)       # [N, H]
        hidden = F.silu(gate) * up
        out_sorted = torch._grouped_mm(hidden, self.W_down, offsets)  # [N, D]
        # Un-sort: inverse[i] is where original row i landed after sorting.
        inverse = torch.empty_like(sort_idx)
        inverse[sort_idx] = torch.arange(sort_idx.size(0), device=x_flat.device)
        return out_sorted[inverse], counts

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, dict]:
        """Route ``x`` ([B, T, D]) through the experts.

        Returns ``(out, aux)`` where ``out`` is [B, T, D] and ``aux`` holds the
        scalar tensors ``aux_load`` and ``aux_z`` in training mode (empty dict
        in eval mode).
        """
        B, T, D = x.shape
        E = self.cfg.n_experts
        x_flat = x.reshape(B * T, D)
        logits = self.router(x_flat)  # [B*T, E]
        if self.cfg.top_k == 1:
            # NOTE(review): argmax routing with an un-weighted expert output means
            # the router only receives gradient through the auxiliary losses here.
            out, counts = self._run_experts(x_flat, logits.argmax(dim=-1))
        else:
            topk_vals, topk_idx = logits.topk(self.cfg.top_k, dim=-1)
            topk_w = F.softmax(topk_vals, dim=-1)
            out = torch.zeros_like(x_flat)
            for k in range(self.cfg.top_k):
                slot_out, _ = self._run_experts(x_flat, topk_idx[..., k])
                out = out + slot_out * topk_w[..., k].unsqueeze(-1)
            # Load statistics cover every (token, slot) assignment, not just
            # whichever slot happened to be processed last.
            counts = torch.bincount(topk_idx.reshape(-1), minlength=E)
        if self.shared is not None:
            out = out + self.shared(x_flat)
        out = out.reshape(B, T, D)
        # Auxiliary losses (Switch Transformer load-balance + ST-MoE z-loss)
        aux: dict = {}
        if self.training:
            probs = F.softmax(logits.float(), dim=-1)
            expert_freq = probs.mean(dim=0)  # [E]
            counts_norm = counts.float() / counts.float().sum().clamp_min(1.0)
            aux["aux_load"] = (expert_freq * counts_norm).sum() * self.cfg.n_experts
            aux["aux_z"] = (logits.float().logsumexp(-1) ** 2).mean()
        return out, aux
class PlutoBlock(nn.Module):
    """Pre-norm decoder block: attention sub-layer followed by an MoE sub-layer."""

    def __init__(self, cfg: PlutoConfig):
        super().__init__()
        width, eps = cfg.hidden_size, cfg.rms_norm_eps
        self.ln1 = RMSNorm(width, eps)
        self.attn = PlutoAttention(cfg)
        self.ln2 = RMSNorm(width, eps)
        self.moe = PlutoMoE(cfg)

    def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> tuple[torch.Tensor, dict]:
        """Apply both residual sub-layers; returns the new hidden state and the MoE aux dict."""
        attn_out = self.attn(self.ln1(x), cos, sin)
        x = x + attn_out
        moe_out, aux = self.moe(self.ln2(x))
        return x + moe_out, aux
| # ─── Models ───────────────────────────────────────────────────────────── | |
class PlutoModel(nn.Module):
    """Decoder backbone: token embed → N blocks → final RMSNorm."""

    def __init__(self, cfg: PlutoConfig):
        super().__init__()
        self.cfg = cfg
        self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.hidden_size)
        self.blocks = nn.ModuleList([PlutoBlock(cfg) for _ in range(cfg.n_layers)])
        self.final_norm = RMSNorm(cfg.hidden_size, cfg.rms_norm_eps)
        # Non-persistent flag kept for compatibility; this class never reads it.
        self.register_buffer("_rope_initialised", torch.tensor(False), persistent=False)
        # RoPE tables are built lazily on the first forward (see `_ensure_rope`).
        self._rope_cos = None
        self._rope_sin = None

    def _ensure_rope(self, seq_len: int, device, dtype):
        """(Re)build the cached RoPE tables when missing, too short, or on another device.

        The tables cover at least ``max_position_embeddings`` positions and are
        grown to ``seq_len`` when a longer sequence arrives, so slicing the
        first ``seq_len`` rows downstream always yields exactly that many rows.
        """
        head_dim = self.cfg.hidden_size // self.cfg.n_heads
        cached = self._rope_cos
        if cached is None or cached.size(0) < seq_len or cached.device != device:
            table_len = max(seq_len, self.cfg.max_position_embeddings)
            cos, sin = _rope_cache(table_len, head_dim, self.cfg.rope_theta, device)
            self._rope_cos = cos.to(dtype)
            self._rope_sin = sin.to(dtype)

    def forward(self, input_ids: torch.Tensor) -> tuple[torch.Tensor, list[dict]]:
        """Encode ``input_ids`` ([B, T]).

        Returns the final-normed hidden states and one aux dict per block, in
        block order.
        """
        B, T = input_ids.shape
        h = self.embed_tokens(input_ids)
        self._ensure_rope(T, h.device, h.dtype)
        aux_list = []
        for blk in self.blocks:
            h, aux = blk(h, self._rope_cos, self._rope_sin)
            aux_list.append(aux)
        h = self.final_norm(h)
        return h, aux_list
class PlutoForCausalLM(nn.Module):
    """Backbone + LM head + optional MTP heads; `forward` also returns the full loss."""

    def __init__(self, cfg: PlutoConfig):
        super().__init__()
        self.cfg = cfg
        self.model = PlutoModel(cfg)
        self.lm_head = nn.Linear(cfg.hidden_size, cfg.vocab_size, bias=False)
        if cfg.tie_word_embeddings:
            # Share one Parameter between the input embedding and the output head.
            self.lm_head.weight = self.model.embed_tokens.weight
        # MTP — training-only auxiliary heads that predict tokens further ahead.
        self.mtp_heads = nn.ModuleList(
            [nn.Linear(cfg.hidden_size, cfg.vocab_size, bias=False)
             for _ in range(cfg.mtp_depth)]
        )

    def forward(self, input_ids: torch.Tensor, labels: torch.Tensor | None = None,
                attention_mask: torch.Tensor | None = None,
                ) -> dict:
        """Run the model and return ``{"logits": ...}`` plus ``"loss"`` when computable.

        ``attention_mask`` is accepted for HF-API compatibility but never read.
        No internal shift is applied: the trainer is expected to pass
        ``input_ids = ids[:-1]`` and ``labels = ids[1:]`` so they already align.
        The loss is skipped only when the label length differs from the logits
        length.
        """
        # NOTE(review): with `labels=None` the targets fall back to `input_ids`,
        # so a loss is still computed against the un-shifted inputs.
        targets = input_ids if labels is None else labels
        hidden, router_aux = self.model(input_ids)
        logits = self.lm_head(hidden)
        out = {"logits": logits}
        if targets.size(1) != logits.size(1):
            return out

        # Main next-token loss.
        loss = F.cross_entropy(
            logits.float().view(-1, logits.size(-1)),
            targets.view(-1),
            ignore_index=-100,
        )

        # MTP auxiliary losses: head d predicts the token d positions ahead.
        # Skipped entirely when mtp_loss_weight == 0 to save the per-head
        # matmul against the full vocab.
        cfg = self.cfg
        if cfg.mtp_depth > 0 and cfg.mtp_loss_weight > 0:
            mtp_sum = 0.0
            for offset, head in enumerate(self.mtp_heads, start=1):
                if targets.size(1) <= offset:
                    continue
                ahead_logits = head(hidden)[:, :-offset, :].contiguous()
                ahead_targets = targets[:, offset:].contiguous()
                mtp_sum = mtp_sum + F.cross_entropy(
                    ahead_logits.float().view(-1, ahead_logits.size(-1)),
                    ahead_targets.view(-1),
                    ignore_index=-100,
                )
            loss = loss + cfg.mtp_loss_weight * (mtp_sum / max(cfg.mtp_depth, 1))

        # Router aux losses (averaged over layers); only present in training mode.
        if router_aux and "aux_load" in router_aux[0]:
            aux_load = torch.stack([entry["aux_load"] for entry in router_aux]).mean()
            aux_z = torch.stack([entry["aux_z"] for entry in router_aux]).mean()
            loss = (loss + cfg.router_aux_loss_coef * aux_load
                    + cfg.router_z_loss_coef * aux_z)

        out["loss"] = loss
        return out
| # ─── Save / load ──────────────────────────────────────────────────────── | |
def save_pluto(model: PlutoForCausalLM, output_dir: str | Path) -> None:
    """Write ``pluto_config.json`` and ``model.safetensors`` into ``output_dir``.

    The model is moved to CPU for serialisation and moved back to its original
    device afterwards, even if saving raises.
    """
    from safetensors.torch import save_model

    model.cfg.save(output_dir)
    # `save_model` handles tied weights (embed↔lm_head) by deduplicating them.
    # We must NOT permanently move the model to CPU — restore device after save.
    devices = {p.device for p in model.parameters()}
    # NOTE(review): a model spread over several devices has no single device to
    # restore to and is left on CPU, as before.
    device = next(iter(devices)) if len(devices) == 1 else None
    try:
        save_model(model.cpu(), str(Path(output_dir) / "model.safetensors"))
    finally:
        if device is not None and device.type != "cpu":
            model.to(device)
def load_pluto(model_dir: str | Path, dtype=torch.bfloat16, map_location="cpu") -> PlutoForCausalLM:
    """Rebuild a model from ``pluto_config.json`` + ``model.safetensors``.

    Args:
        model_dir: directory holding both files.
        dtype: dtype the freshly built model is cast to before loading.
        map_location: device the weights are loaded onto; the returned model
            lives on this device.

    Loading stays non-strict because the checkpoint stores tied weights only
    once, but any other missing or unexpected key is reported with a warning
    instead of being dropped silently.
    """
    import warnings

    from safetensors.torch import load_file

    cfg = PlutoConfig.load(model_dir)
    model = PlutoForCausalLM(cfg).to(dtype).to(map_location)
    state = load_file(str(Path(model_dir) / "model.safetensors"), device=str(map_location))
    result = model.load_state_dict(state, strict=False)

    # With tied embeddings one of these two names is legitimately absent.
    tied = {"lm_head.weight", "model.embed_tokens.weight"} if cfg.tie_word_embeddings else set()
    missing = [key for key in result.missing_keys if key not in tied]
    unexpected = list(result.unexpected_keys)
    if missing or unexpected:
        warnings.warn(
            f"load_pluto: checkpoint/model key mismatch "
            f"(missing={missing}, unexpected={unexpected})",
            stacklevel=2,
        )
    return model
| # ─── Param accounting ─────────────────────────────────────────────────── | |
def count_params(model: nn.Module) -> int:
    """Return the total number of elements across ``model.parameters()``."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
def estimate_active_params(cfg: PlutoConfig) -> dict:
    """Estimate parameter counts analytically from ``cfg`` (no model is built).

    ``active_inference_params`` counts what one token touches at inference:
    attention, the router, ``top_k`` experts and the shared expert per layer,
    plus the LM head. MTP heads are NOT counted there since they are
    training-only, and the RMSNorm weights are left out of the active figure.

    ``total_params`` counts every weight, including the per-layer router
    (``hidden_size * n_experts``).
    """
    head_dim = cfg.hidden_size // cfg.n_heads
    attn_per_layer = (
        cfg.hidden_size * cfg.n_heads * head_dim       # q_proj
        + cfg.hidden_size * cfg.n_kv_heads * head_dim  # k_proj
        + cfg.hidden_size * cfg.n_kv_heads * head_dim  # v_proj
        + cfg.hidden_size * cfg.hidden_size            # o_proj
    )
    expert_size = 3 * cfg.hidden_size * cfg.intermediate_size_expert  # SwiGLU
    shared_size = (3 * cfg.hidden_size * cfg.intermediate_size_shared
                   if cfg.intermediate_size_shared > 0 else 0)
    router_per_layer = cfg.hidden_size * cfg.n_experts  # bias-free router Linear

    active_per_layer = (attn_per_layer + router_per_layer
                        + cfg.top_k * expert_size + shared_size)
    active_total = active_per_layer * cfg.n_layers
    # lm_head is also "active" (full matmul against vocab)
    active_total += cfg.vocab_size * cfg.hidden_size

    total_experts = expert_size * cfg.n_experts * cfg.n_layers
    total_shared = shared_size * cfg.n_layers
    total_attn = attn_per_layer * cfg.n_layers
    total_router = router_per_layer * cfg.n_layers
    emb_params = cfg.vocab_size * cfg.hidden_size
    lm_head_params = 0 if cfg.tie_word_embeddings else cfg.vocab_size * cfg.hidden_size
    mtp_params = cfg.mtp_depth * cfg.vocab_size * cfg.hidden_size
    norm_params = (2 * cfg.n_layers * cfg.hidden_size  # two RMSNorms per block
                   + cfg.hidden_size)                  # final RMSNorm
    total_params = (total_experts + total_shared + total_attn + total_router
                    + emb_params + lm_head_params + mtp_params + norm_params)
    return {
        "total_params": total_params,
        "active_inference_params": active_total,
        "expert_total_params": total_experts,
        "attn_total_params": total_attn,
        "router_total_params": total_router,
        "embedding_params": emb_params,
        "lm_head_params": lm_head_params,
        "mtp_head_params": mtp_params,
    }
if __name__ == "__main__":
    # Smoke test: print the analytic parameter estimate, then build the default
    # model and run one forward pass on random token ids.
    cfg = PlutoConfig()
    stats = estimate_active_params(cfg)
    for k in stats:
        v = stats[k]
        print(f" {k:<28} {v/1e6:>8.2f} M")
    print(f" active/total ratio {stats['active_inference_params']/stats['total_params']*100:>5.2f} %")

    m = PlutoForCausalLM(cfg)
    n_real = count_params(m)
    print(f"\n real (actual) total {n_real/1e6:>8.2f} M")

    x = torch.randint(0, cfg.vocab_size, (2, 32))
    out = m(x, labels=x)
    print(f" fwd OK logits {tuple(out['logits'].shape)} loss={out['loss'].item():.4f}")