| """PostSemClawModel — full-architecture model assembly. |
|
|
| Extracted from the monolithic train.py (W1 modularization). Semantics |
| unchanged. Imports `GPUEngram` from `hydra.engram` and `MuonAdamW` from |
| `hydra.optimizer`. |
|
|
| Triton kernel integration status (Phase 2): |
| HYDRA_FUSED_BCNORM — DEFERRED. The bcnorm_fused Triton kernel fuses |
| LayerNorm + RoPE on B/C projections. However, mamba-ssm's Mamba3 block |
| uses RMSNormGated (not LayerNorm) for B/C, and RoPE is applied inside |
| the mamba3_siso_combined CUDA kernel via the Angles parameter. Replacing |
| would require either (a) monkey-patching RMSNormGated + intercepting the |
| fused CUDA scan — invasive, 50+ lines, high breakage risk — or (b) a |
| full custom Mamba3Block reimplementation. Both are out of scope for |
| Phase 2. The kernel is validated standalone; integration deferred to |
| Phase 3 when HYDRA moves to a custom SSM block. |
|
|
| HYDRA_FUSED_SSD — DEFERRED. The ssd_exp_trap Triton kernel implements |
| exponential-trapezoidal discretization as a sequential scan. mamba-ssm's |
| Mamba3 block delegates the entire scan + gating + output projection to |
| mamba3_siso_combined (a compiled CUDA kernel with tilelang). Replacing |
| it would require decomposing the combined kernel into constituent ops |
| and substituting only the scan — not feasible without a custom block. |
| Same Phase 3 gate as above. |
|
|
| Both env vars are accepted but currently no-ops (gates read, logged, but |
| the code path is unchanged). This avoids silent regression if someone |
| sets them expecting a speedup. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import os |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.utils.checkpoint as _ckpt |
|
|
| from mamba_ssm import Mamba3 |
|
|
| from subsystems.hestia_mini import HestiaQAT |
| from subsystems.htm import HTMLayer |
| from subsystems.mhc_mini import ManifoldHyperConnection |
| from subsystems.sdr_semantic import SemanticFoldingSDR |
|
|
| from hydra.engram import GPUEngram |
| from hydra.hyena_block import HyenaBlock |
| # GDNBlock is imported lazily inside __init__ so the `fla` dependency is |
| # only required when HYDRA_GDN_LAYERS is actually non-empty. Baseline |
| # pure-Mamba3 runs continue to work without flash-linear-attention installed. |
| from hydra.optimizer import MuonAdamW |
|
|
|
|
| def norm(x: torch.Tensor) -> torch.Tensor: |
| """RMSNorm over the last dim — stateless, autocast-friendly.""" |
| return F.rms_norm(x, (x.size(-1),)) |
|
|
|
|
| class PostSemClawModel(nn.Module): |
| """Full Post-SEM-Claw model assembly. |
|
|
| Architecture: |
| Token Embedding -> [Mamba3 + residual] x n_layer |
| -> SDR + Engram (at configured layer) -> norm -> LM head |
|
|
| Interface (must match prepare.py evaluate_bpb): |
| model(x, y, reduction='none').view(-1) -> per-token losses |
| model(x, y, reduction='mean') -> scalar loss |
| """ |
|
|
| def __init__(self, config): |
| super().__init__() |
| self.config = config |
|
|
| # Token embedding |
| self.wte = nn.Embedding(config.vocab_size, config.d_model) |
|
|
| # Mamba-3 blocks — official mamba-ssm fused CUDA kernel. No fallbacks. |
| # RoPE is applied internally by the Mamba3 CUDA kernel via the Angles |
| # parameter; external cos/sin buffers are not needed. |
| # |
| # Hyena supplement: layers whose index appears in `config.hyena_layers` |
| # are instantiated as HyenaBlock instead of Mamba3. The config field |
| # is populated from HYDRA_HYENA_LAYERS at construction time and then |
| # persisted to checkpoints, so resume is safe even when the env var |
| # is unset. Empty tuple → all-Mamba3, byte-identical to pre-port. |
| _hyena_layer_set = set(getattr(config, "hyena_layers", ()) or ()) |
| _gdn_layer_set = set(getattr(config, "gdn_layers", ()) or ()) |
| # Hyena wins on overlap; conflict is logged at construction time. |
| _both = _hyena_layer_set & _gdn_layer_set |
| if _both: |
| print(f"[WARN] layers in both hyena_layers and gdn_layers; using Hyena: {sorted(_both)}", flush=True) |
| _gdn_layer_set -= _hyena_layer_set |
|
|
| if _gdn_layer_set: |
| from hydra.gdn_block import GDNBlock # requires `fla` package |
|
|
| def _build_block(i: int) -> nn.Module: |
| if i in _hyena_layer_set: |
| return HyenaBlock( |
| d_model=config.d_model, |
| seq_len=config.sequence_len, |
| order=int(os.environ.get("HYDRA_HYENA_ORDER", "2")), |
| filter_order=int(os.environ.get("HYDRA_HYENA_FILTER_DIM", "64")), |
| ) |
| if i in _gdn_layer_set: |
| return GDNBlock( |
| d_model=config.d_model, |
| n_heads=config.n_heads, |
| ) |
| block = Mamba3( |
| d_model=config.d_model, |
| d_state=config.d_state, |
| expand=config.expand, |
| headdim=config.headdim, |
| is_mimo=False, # SISO path uses stable mamba3_siso_combined kernel |
| chunk_size=int(os.environ.get("HYDRA_MAMBA3_CHUNK", "64")), # 64 is the validated default; 128 tripped a Triton autotune hang (>8min, no progress) |
| is_outproj_norm=False, |
| dtype=torch.bfloat16, |
| ) |
| # Fix dt_bias gradient starvation: Mamba3 init samples dt uniformly |
| # in log-space from [0.001, 0.1], giving dt_bias in [-6.9, -2.25]. |
| # softplus'(dt_bias) = sigmoid(dt_bias). At -6.9: 0.1% grad survives; |
| # at -2.25: 9.5% survives. Shift up so sigmoid is 2-68% instead. |
| # The SSM can still learn to make dt_bias more negative if finer |
| # temporal resolution is needed — now the gradient will survive to |
| # do so. |
| with torch.no_grad(): |
| block.dt_bias.add_(3.0) |
| return block |
| |
| self.blocks = nn.ModuleList([_build_block(i) for i in range(config.n_layer)]) |
| |
| # Full-architecture SDR: offline semantic retina + STE (no-bypass). |
| self.sdr_semantic = SemanticFoldingSDR( |
| vocab_size=config.vocab_size, |
| n_bits=config.sdr_n_bits, |
| target_active=config.sdr_target_active, |
| delta_rank=config.sdr_delta_rank, |
| som_warmup_steps=config.sdr_som_warmup, |
| som_update_interval=config.sdr_som_interval, |
| ) |
| |
| # HTM spatial pooler + temporal memory (Rust, Hebbian). |
| self.htm = HTMLayer( |
| input_bits=config.sdr_n_bits, |
| n_columns=config.htm_n_columns, |
| cells_per_column=config.htm_cells_per_column, |
| batch_size=1, # grows lazily to actual B on first forward |
| seed=42, |
| learn=True, |
| reset_each_forward=True, |
| ) |
| |
| # Gradient bridge: (n_columns + anomaly) -> d_model. |
| self.htm_proj = nn.Linear(config.htm_n_columns + 1, config.d_model, bias=False) |
| |
| # GPU Engram with Hebbian writes — runs EVERY step. |
| self.engram = GPUEngram( |
| d_model=config.d_model, |
| n_columns=config.engram_n_columns, |
| max_ngram=3, |
| ) |
| self.engram_layer_idx = config.engram_layer_idx |
| |
| # SDR-to-d_model projection: routes differentiable SDR signal (STE) |
| # into the engram's residual stream. One bf16 matmul per forward step |
| # at (B*T, n_bits) @ (n_bits, d_model) — ~2.6M MACs vs 5B total, |
| # <0.05% overhead. Zero extra memory (no intermediates materialized |
| # beyond the matmul output). This single projection backpropagates LM |
| # loss gradients through sdr_semantic.delta_u/delta_v, finally giving |
| # the curated semantic retina real learning signal. |
| self.sdr_proj = nn.Linear(config.sdr_n_bits, config.d_model, bias=False) |
|
|
| # Manifold-Constrained Hyper-Connections (one per Mamba-3 block). |
| self.mhc = nn.ModuleList([ |
| ManifoldHyperConnection(d_model=config.d_model, n_streams=2, sinkhorn_iters=3) |
| for _ in range(config.n_layer) |
| ]) |
|
|
| # Hestia QAT — ternary weight quantization applied post-optimizer-step. |
| self.hestia = HestiaQAT(enabled=True, bits=1.58) |
|
|
| # LM head |
| self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) |
|
|
| # Learnability knob 1: Multi-Token Prediction (Llama-3 style). |
| # MTP_K=1 -> standard next-token. MTP_K>1 -> extra heads predict |
| # tokens at positions t+1, t+2, ..., t+K. Heads are weight-tied to |
| # lm_head (we share Parameters), so the only extra compute is |
| # additional CE losses; no new params. Activated via HYDRA_MTP_K. |
| self._mtp_k = max(1, int(os.environ.get("HYDRA_MTP_K", "1"))) |
|
|
| # Learnability knob 3: gradient checkpointing on Mamba3 blocks. |
| self._grad_ckpt = os.environ.get("HYDRA_GRAD_CKPT", "0") == "1" |
|
|
| # Learnability knob 4: doc-separator BOS masking in packed sequences. |
| self._doc_sep_mask = os.environ.get("HYDRA_DOC_SEP_MASK", "0") == "1" |
| # BOS token id is looked up lazily on first forward (requires tokenizer |
| # load); -1 means uninitialized. |
| self._bos_token_id = -1 |
|
|
| # Learnability knob 5: explicit stop-grad on HTM tensor (htm_rust |
| # outputs already have requires_grad=False; this is defense-in-depth). |
| self._htm_stop_grad = os.environ.get("HYDRA_HTM_STOP_GRAD", "0") == "1" |
|
|
| # Learnability knob 6: entropy penalty coefficient on LM logits. |
| self._entropy_penalty = float(os.environ.get("HYDRA_ENTROPY_PENALTY", "0.0")) |
|
|
| # Residual dropout |
| self.drop = nn.Dropout(float(os.environ.get("HYDRA_DROPOUT", "0.2"))) |
|
|
| # Logits soft-capping |
| self.softcap = 15.0 |
|
|
| # Secondary metrics storage |
| self._metrics = {} |
|
|
| # Per-layer diagnostic panel. Env-gated; zero overhead when off. |
| # Emits residual-contribution (delta_ratio), feature std, effective rank, |
| # gradient norm per layer; used to identify minimum viable n_layer + find |
| # entropy leakage / dead layers. See docs/depth-sweep.md. |
| self._diag_enabled = os.environ.get("HYDRA_LAYER_DIAGNOSTICS", "0") == "1" |
| self._diag_step = 0 |
| self._diag_svd_every = int(os.environ.get("HYDRA_LAYER_DIAG_SVD_EVERY", "100")) |
|
|
| # 3060 hot-path: cache env-driven flags at construction time so the |
| # forward pass doesn't pay a dict-lookup + string-parse per step. |
| # Each os.environ.get() call costs ~1-2us; with ~30 forwards/sec |
| # this saves single-digit-percent overhead per training step. |
| self._cache_profile_forward = os.environ.get("HYDRA_PROFILE_FORWARD", "0") == "1" |
| self._cache_htm_subsample = int(os.environ.get("HYDRA_HTM_SUBSAMPLE", "8")) |
| self._cache_htm_log_anomaly = os.environ.get("HYDRA_HTM_LOG_ANOMALY", "0") == "1" |
| self._cache_softcap_clamp = os.environ.get("HYDRA_SOFTCAP_CLAMP", "0") == "1" |
| self._cache_sampled_softmax_k = int(os.environ.get("HYDRA_SAMPLED_SOFTMAX", "4096")) |
| self._cache_ce_chunk = int(os.environ.get("HYDRA_CE_CHUNK", "1024")) |
| if self._diag_enabled: |
| # Gradient-norm backward hooks on each Mamba3 block output. |
| for _i, _block in enumerate(self.blocks): |
| def _mk_grad_hook(_layer_idx): |
| def _hook(module, grad_input, grad_output): |
| if grad_output and grad_output[0] is not None: |
| g = grad_output[0].detach() |
| self._metrics[f'layer_{_layer_idx}_grad_norm'] = float( |
| g.pow(2).mean().sqrt().item() |
| ) |
| return _hook |
| _block.register_full_backward_hook(_mk_grad_hook(_i)) |
|
|
| # Forward hooks on each Mamba3 block capture the block's OUTPUT |
| # directly. This is the clean measurement: unlike merge_streams() |
| # sampling which sees (streams + M*block_output) in bf16 — where |
| # small block contributions round to zero against unit-norm |
| # residuals — this captures `block_output` itself as produced. |
| # Reports both its absolute RMS norm and its ratio to the block |
| # INPUT's RMS norm (contribution magnitude relative to the |
| # residual it's added to). |
| for _i, _block in enumerate(self.blocks): |
| def _mk_fwd_hook(_layer_idx): |
| def _hook(module, inputs, output): |
| with torch.no_grad(): |
| inp = inputs[0].detach().float() if inputs else None |
| out = output.detach().float() if isinstance(output, torch.Tensor) else None |
| if out is not None: |
| out_rms = out.pow(2).mean().sqrt().item() |
| self._metrics[f'layer_{_layer_idx}_block_out_rms'] = float(out_rms) |
| if inp is not None: |
| in_rms = inp.pow(2).mean().sqrt().item() |
| self._metrics[f'layer_{_layer_idx}_block_in_rms'] = float(in_rms) |
| self._metrics[f'layer_{_layer_idx}_contrib_ratio'] = float( |
| out_rms / (in_rms + 1e-8) |
| ) |
| return _hook |
| _block.register_forward_hook(_mk_fwd_hook(_i)) |
| |
| # Triton kernel integration gates (Phase 2 — deferred, see module docstring). |
| self._fused_bcnorm = os.environ.get("HYDRA_FUSED_BCNORM", "0") == "1" |
| self._fused_ssd = os.environ.get("HYDRA_FUSED_SSD", "0") == "1" |
| if self._fused_bcnorm or self._fused_ssd: |
| import sys |
| _active = [] |
| if self._fused_bcnorm: |
| _active.append("HYDRA_FUSED_BCNORM") |
| if self._fused_ssd: |
| _active.append("HYDRA_FUSED_SSD") |
| print( |
| f"[HYDRA] Triton kernel gates set: {', '.join(_active)}. " |
| f"NOTE: Both are DEFERRED (mamba-ssm Mamba3 uses internal " |
| f"CUDA kernels). Gates accepted but currently no-ops.", |
| file=sys.stderr, |
| ) |
| |
| # R6 optional torch.compile on the impl forward. Gated (default OFF). |
| if os.environ.get("HYDRA_MODEL_COMPILE", "0") == "1": |
| self._forward_impl = torch.compile( |
| self._forward_impl, |
| fullgraph=False, |
| dynamic=True, |
| mode="default", |
| ) |
| |
| @torch.no_grad() |
| def init_weights(self) -> None: |
| s = 3 ** 0.5 * self.config.d_model ** -0.5 |
| |
| # Move SDR retina indices (plain attribute, not buffer) to same device as params. |
| # Required because to_empty() only moves params/buffers, and _retina_indices |
| # is loaded from numpy (always CPU) by SemanticFoldingSDR.__init__. |
| device = self.wte.weight.device |
| if hasattr(self.sdr_semantic, '_retina_indices'): |
| self.sdr_semantic._retina_indices = self.sdr_semantic._retina_indices.to(device) |
| |
| # Embedding init: GPT-2 / LLaMA convention. std=1.0 was chosen for |
| # vocab=8192; at larger vocabs, smaller std prevents logit blowup. |
| # Use std = 1/sqrt(d_model) which scales sensibly with model width. |
| import math as _math |
| _d_model = self.wte.weight.shape[1] |
| wte_std = float(os.environ.get("HYDRA_WTE_STD", str(1.0 / _math.sqrt(_d_model)))) |
| nn.init.normal_(self.wte.weight, mean=0.0, std=wte_std) |
| # LM head init: was std=0.001 — PATHOLOGICAL at vocab>=32k because |
| # logits collapse to zero, loss locks at log(V)~=11, gradient through |
| # head ∝ 1/V is too small to escape. GPT-2 uses std=0.02; LLaMA uses |
| # std=1/sqrt(d_model). Pick 0.02 as robust default, env-overridable. |
| lm_head_std = float(os.environ.get("HYDRA_LM_HEAD_STD", "0.02")) |
| nn.init.normal_(self.lm_head.weight, mean=0.0, std=lm_head_std) |
| # F8 (NOT APPLIED): Weight tying would save V*D params but current LR |
| # groups have embedding_lr=1.0 and unembedding_lr=0.005 × d_model_scale |
| # — tying forces the shared tensor under a single LR group and either |
| # the embeddings learn 200x too slow (under unembed LR) or the LM head |
| # becomes unstable (under embed LR). Short 15-step smoke with tying + |
| # embed-group update showed initial loss jump 9 -> 20. Deferred until |
| # LR groups are re-tuned; see docs/OPTIMIZATION_PLAN.md Post-plan. |
| |
| for li, block in enumerate(self.blocks): |
| if hasattr(block, 'in_proj') and hasattr(block.in_proj, 'weight'): |
| nn.init.uniform_(block.in_proj.weight, -s, s) |
| if hasattr(block, 'out_proj') and hasattr(block.out_proj, 'weight'): |
| # GPT-2 residual init: std = 0.02 / sqrt(2 * n_layer). |
| # NOT zeros — zero init makes the block a permanent pass-through |
| # (block_out_rms=0, zero gradient flow to SSM internals). |
| # With non-zero init the block contributes to the residual stream |
| # from step 1, giving the SSM scan actual gradient signal. |
| n_layer = self.config.n_layer |
| out_std = float(os.environ.get( |
| "HYDRA_OUT_PROJ_STD", |
| str(0.02 / (2 * n_layer) ** 0.5), |
| )) |
| nn.init.normal_(block.out_proj.weight, mean=0.0, std=out_std) |
| |
| nn.init.normal_(self.htm_proj.weight, mean=0.0, std=s) |
| |
| # SDR proj: tiny init preserves the existing residual dynamics at step 0. |
| # The signal ramps up gradually as delta_u/delta_v learn via STE gradients. |
| nn.init.normal_(self.sdr_proj.weight, mean=0.0, std=1e-4) |
| |
| # Cast to bf16 to match Mamba3 dtype; Muon groups by shape so mixed |
| # dtypes in the same shape group would break lerp_ dtype checks. |
| self.wte.to(dtype=torch.bfloat16) |
| self.htm_proj.to(dtype=torch.bfloat16) |
| self.sdr_proj.to(dtype=torch.bfloat16) |
| self.engram.to(dtype=torch.bfloat16) |
| |
| def set_bos_token_id(self, bos_id: int) -> None: |
| """Inform the model of the tokenizer's BOS id so doc-separator |
| masking (learnability #4) knows which positions to skip. Called from |
| training setup once the tokenizer is loaded.""" |
| self._bos_token_id = int(bos_id) |
|
|
| def invalidate_hyena_caches(self) -> None: |
| """Invalidate filter-rfft caches on all Hyena blocks. |
|
|
| MUST be called after each `optimizer.step()` when |
| `HYDRA_HYENA_FILTER_CACHE=1` is set, otherwise cached rfft values |
| will be reused with stale filter parameters. |
|
|
| No-op for blocks that are not HyenaBlock (Mamba3, etc.). |
| """ |
| for block in self.blocks: |
| if hasattr(block, "operator") and hasattr(block.operator, "invalidate_filter_cache"): |
| block.operator.invalidate_filter_cache() |
|
|
| def flush_hyena_pending_grads(self) -> None: |
| """Push pending train-cache filter gradients into filter params. |
|
|
| Used ONLY when HYDRA_HYENA_TRAIN_CACHE=1. Must be called exactly once |
| per optimizer step, BEFORE `optimizer.step()` and BEFORE |
| `invalidate_hyena_caches()`. The lightning_module wires this in |
| `optimizer_step` around the existing optimizer.step() call. |
|
|
| No-op if: |
| * No HyenaBlocks are in the model, OR |
| * No micro-batch ever ran with grad enabled (e.g. all-eval step). |
| """ |
| for block in self.blocks: |
| if hasattr(block, "operator") and hasattr(block.operator, "flush_pending_filter_grads"): |
| block.operator.flush_pending_filter_grads() |
|
|
| def estimate_flops(self) -> int: |
| nparams = sum(p.numel() for p in self.parameters()) |
| embed_params = self.wte.weight.numel() |
| return 6 * (nparams - embed_params) |
|
|
| def num_scaling_params(self) -> dict: |
| wte = sum(p.numel() for p in self.wte.parameters()) |
| lm_head = sum(p.numel() for p in self.lm_head.parameters()) |
| blocks = sum(p.numel() for p in self.blocks.parameters()) |
| sdr = sum(p.numel() for p in self.sdr_semantic.parameters()) |
| htm_proj = sum(p.numel() for p in self.htm_proj.parameters()) |
| engram = sum(p.numel() for p in self.engram.parameters()) |
| total = sum(p.numel() for p in self.parameters()) |
| return { |
| 'wte': wte, 'lm_head': lm_head, 'blocks': blocks, |
| 'sdr_semantic': sdr, 'htm_proj': htm_proj, |
| 'engram': engram, 'total': total, |
| } |
|
|
| def get_secondary_metrics(self) -> dict: |
| """Flush any lingering CUDA tensors to host (single sync).""" |
| flushed = {} |
| for k, v in self._metrics.items(): |
| if hasattr(v, 'item'): |
| try: |
| flushed[k] = float(v.item()) |
| except Exception: |
| flushed[k] = v |
| else: |
| flushed[k] = v |
| return flushed |
|
|
| def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.6, matrix_lr=0.04, |
| weight_decay=0.2, adam_betas=(0.8, 0.95), scalar_lr=0.5): |
| """Setup MuonAdamW optimizer with per-component LR groups.""" |
| model_dim = self.config.d_model |
|
|
| embedding_params = list(self.wte.parameters()) |
| lm_head_params = list(self.lm_head.parameters()) |
|
|
| # Muon routing guard: 2D parameters are NOT automatically matrices. |
| # Exclude: |
| # (a) params whose name ends in `.freq` — Sin frequency vectors used |
| # by Hyena's implicit filter MLP. Shape (1, dim) is nominally 2D |
| # but semantically a per-dim scalar. Muon's polar-express |
| # orthogonalization would force it toward an orthogonal matrix, |
| # destroying the learned modulation frequencies. |
| # (b) 2-D params with min(shape) < MUON_MIN_DIM. Tiny projections |
| # (e.g. HyenaFilter.implicit_filter.0.weight of shape (64, 3)) |
| # get collapsed toward near-identity by orthogonalization on the |
| # narrow axis, damaging expressivity. These belong in AdamW. |
| # These exclusions route the params into the AdamW scalar/vector group. |
| MUON_MIN_DIM = 8 |
|
|
| def _muon_eligible(name: str, p: torch.Tensor) -> bool: |
| if p.dim() != 2: |
| return False |
| if name.endswith(".freq"): |
| return False |
| if min(p.shape) < MUON_MIN_DIM: |
| return False |
| return True |
|
|
| # Matrix params -> Muon (2D weight matrices passing the routing guard). |
| matrix_params = [] |
| for name, p in self.blocks.named_parameters(): |
| if _muon_eligible(name, p): |
| matrix_params.append(p) |
| # NOTE (W1 audit REG-2): SemanticFoldingSDR.delta_u / delta_v are now |
| # GRADIENT-ALIVE via the STE path reconnected in _forward_impl: the |
| # differentiable SDR output feeds into sdr_proj which feeds into the |
| # engram residual. Gradient flows: LM_loss -> engram -> ... -> sdr_proj |
| # -> sdr_semantic.forward() -> STE -> delta_u/delta_v. These are 1D |
| # params (not 2D matrices) so they automatically land in scalar AdamW |
| # via the exclusion below — no special treatment needed. |
| # for p in self.sdr_semantic.parameters(): |
| # if p.dim() == 2: |
| # matrix_params.append(p) |
| for name, p in self.htm_proj.named_parameters(): |
| if _muon_eligible(name, p): |
| matrix_params.append(p) |
| for name, p in self.engram.named_parameters(): |
| if _muon_eligible(name, p): |
| matrix_params.append(p) |
|
|
| # SDR params are intentionally not in any optimizer group — they |
| # receive no gradient in the current forward, so any update would be |
| # pure noise (weight_decay × lr on a zero-grad param). |
| sdr_param_ids = set(id(p) for p in self.sdr_semantic.parameters()) |
| assigned = set(id(p) for p in embedding_params + lm_head_params + matrix_params) |
| # Extract dt_bias from each Mamba3 block into its own high-LR group. |
| # dt_bias controls SSM temporal discretization: softplus(dt_bias) gives |
| # the step size delta_t. A single dt_bias per head is 1D — falls into |
| # scalar_params by default. Give it a dedicated group with HYDRA_DT_BIAS_LR |
| # so heads can differentiate their timescales instead of all tracking at |
| # ln(2) across all layers. |
| dt_bias_params = [p for name, p in self.blocks.named_parameters() |
| if name.endswith('dt_bias')] |
| dt_bias_ids = set(id(p) for p in dt_bias_params) if dt_bias_params else set() |
| scalar_params = [ |
| p for p in self.parameters() |
| if id(p) not in assigned and id(p) not in sdr_param_ids and id(p) not in dt_bias_ids |
| ] |
|
|
| total_assigned = len(embedding_params) + len(lm_head_params) + len(matrix_params) + len(scalar_params) + len(dt_bias_params) |
| total_params = len(list(self.parameters())) |
| sdr_excluded = len(list(self.sdr_semantic.parameters())) |
| assert total_assigned + sdr_excluded == total_params, ( |
| f"Parameter count mismatch: assigned {total_assigned} + sdr_excluded " |
| f"{sdr_excluded} vs total {total_params}" |
| ) |
|
|
| dmodel_lr_scale = (model_dim / 768) ** -0.5 |
| print(f"Scaling AdamW LRs by 1/sqrt({model_dim}/768) = {dmodel_lr_scale:.6f}") |
|
|
| param_groups = [ |
| dict(kind='adamw', params=lm_head_params, |
| lr=unembedding_lr * dmodel_lr_scale, betas=adam_betas, |
| eps=1e-10, weight_decay=0.0), |
| dict(kind='adamw', params=embedding_params, |
| lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, |
| eps=1e-10, weight_decay=0.0), |
| ] |
|
|
| if scalar_params: |
| param_groups.append( |
| dict(kind='adamw', params=scalar_params, |
| lr=scalar_lr * dmodel_lr_scale, betas=adam_betas, |
| eps=1e-10, weight_decay=0.0) |
| ) |
|
|
| # dt_bias: dedicated group with embed-level LR so each head learns its |
| # own temporal discretization. Env-overridable for sweeps. |
| if dt_bias_params: |
| _dt_bias_lr = float(os.environ.get("HYDRA_DT_BIAS_LR", str(embedding_lr))) |
| param_groups.append(dict( |
| kind='adamw', params=dt_bias_params, |
| lr=_dt_bias_lr * dmodel_lr_scale, betas=adam_betas, |
| eps=1e-10, weight_decay=0.0, |
| )) |
|
|
| for shape in sorted({p.shape for p in matrix_params}): |
| group_params = [p for p in matrix_params if p.shape == shape] |
| # ns_steps: Muon polar-express inner iterations. Default 5 (paper), |
| # but 3 converges on small matrices (d_model ~ 384) with ~40% lower |
| # optimizer step cost. Env-tunable for experimentation. |
| _ns_steps = int(os.environ.get("HYDRA_MUON_NS_STEPS", "3")) |
| param_groups.append(dict( |
| kind='muon', params=group_params, lr=matrix_lr, |
| momentum=0.95, ns_steps=_ns_steps, beta2=0.95, weight_decay=weight_decay, |
| )) |
|
|
| optimizer = MuonAdamW(param_groups) |
| for group in optimizer.param_groups: |
| group["initial_lr"] = group["lr"] |
| return optimizer |
|
|
| def forward(self, idx, targets=None, reduction='mean'): |
| """idx: (B, T) int64. Returns loss if targets given, else logits. |
|
|
| Nested bf16 autocast is a no-op when ambient autocast is already on; |
| when it's off (e.g. integration tests) we establish the dtype contract. |
| """ |
| if torch.is_autocast_enabled(): |
| return self._forward_impl(idx, targets=targets, reduction=reduction) |
| with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): |
| return self._forward_impl(idx, targets=targets, reduction=reduction) |
| |
| def _forward_impl(self, idx, targets=None, reduction='mean'): |
| B, T = idx.shape |
| |
| # Diagnostic: per-subsystem CUDA event timing. Env-gated; zero overhead |
| # when disabled. Logs one timing line per forward call. Used to isolate |
| # which subsystem is the tps bottleneck on paid hardware. |
| _profile = self._cache_profile_forward |
| if _profile: |
| def _ev(): |
| e = torch.cuda.Event(enable_timing=True) |
| e.record() |
| return e |
| _t0 = _ev() |
| else: |
| _t0 = None |
| |
| # Compute SDR ONCE via the differentiable STE forward. |
| # The STE output is float (autocast dtype); we detach+cast a uint8 copy |
| # for HTM (which only needs binary patterns). This replaces the old |
| # binary_only() call — same scatter cost, but routes LM loss gradients |
| # through sdr_semantic.delta_u/delta_v for the first time. |
| sdr_ste = self.sdr_semantic(idx) # float, with STE grad path |
| sdr_binary = sdr_ste.detach().to(dtype=torch.uint8) # detached uint8 for HTM |
| self._last_sdr = sdr_binary |
| |
| # HTM subsampling: run HTM on 1 of every N micro-batches within a |
| # gradient accumulation step, reuse the cached result for the other |
| # N-1 micro-batches. Cooperative launch monopolizes all SMs (grid.sync |
| # requires full-grid residency), so HTM and mamba can't overlap via |
| # streams. Subsampling removes HTM from most micro-batches' critical |
| # path instead. |
| # |
| # Math: N=8, 64 accum steps → 8 HTM calls (10.6ms each) + 56 fast |
| # calls (4ms each). Total = 84.8 + 224 = 309ms → 106k tps. |
| # |
| # HYDRA_HTM_SUBSAMPLE=N (default 8). Set =1 for every-microbatch HTM. |
| _htm_sub = self._cache_htm_subsample |
| if not hasattr(self, '_htm_call_idx'): |
| self._htm_call_idx = 0 |
| |
| _run_htm = (self._htm_call_idx % _htm_sub == 0) |
| self._htm_call_idx += 1 |
| |
| if _run_htm: |
| htm_handle = self.htm.forward_async(sdr_binary) |
| else: |
| htm_handle = None |
| |
| if _profile: _t_htm_async = _ev() |
| |
| dense_emb = self.wte(idx) # (B, T, d_model) bf16 |
| |
| if _profile: _t_wte = _ev() |
| |
| if _run_htm: |
| htm_out = self.htm.forward_await(htm_handle) |
| self._htm_cache = htm_out.detach() # cache for non-HTM micro-batches |
| elif hasattr(self, '_htm_cache') and self._htm_cache is not None \ |
| and self._htm_cache.shape[0] == B and self._htm_cache.shape[1] == T: |
| htm_out = self._htm_cache |
| else: |
| # Very first call with subsample > 1: run HTM anyway. |
| htm_handle = self.htm.forward_async(sdr_binary) |
| htm_out = self.htm.forward_await(htm_handle) |
| self._htm_cache = htm_out.detach() |
| |
| if _profile: _t_htm_await = _ev() |
| # 3060 hot-path: skip the per-step htm_anomaly mean kernel + sync |
| # unless explicitly requested via HYDRA_HTM_LOG_ANOMALY=1. Saves |
| # one tiny GPU launch + one .item()-style copy per forward call. |
| sdr_active_bits = float(self.sdr_semantic.target_active) |
| if self._cache_htm_log_anomaly: |
| with torch.no_grad(): |
| htm_anomaly = htm_out[..., -1].mean() |
| else: |
| htm_anomaly = None |
| |
| # Learnability #5: explicit stop-grad on HTM output. htm_rust already |
| # produces a detached tensor, but making it explicit here hardens the |
| # contract against future refactors that might route HTM through a |
| # grad-enabled op. |
| if self._htm_stop_grad: |
| htm_out = htm_out.detach() |
| |
| # Gradient bridge: HTM columns+anomaly -> d_model. |
| # 3060 hot-path: only cast when dtypes actually differ (htm_out is |
| # bf16 in normal training, dense_emb is bf16 under autocast — the |
| # .to() in that case is a no-op but still costs Python overhead). |
| if htm_out.dtype != dense_emb.dtype: |
| htm_out = htm_out.to(dense_emb.dtype) |
| htm_proj_out = self.htm_proj(htm_out) |
| x = dense_emb + htm_proj_out |
| x = norm(x) |
| |
| if _profile: _t_htm_proj = _ev() |
| |
| # mHC-routed Mamba-3 stack with Engram injection at configured layer. |
| streams = self.mhc[0].init_streams(x) |
| _engram_ev = None |
| |
| # Per-layer diagnostic panel. The pre-layer merged state h_pre lets us |
| # measure residual contribution of each layer: delta_N = h_post - h_pre. |
| # All reads are detached no-grad to avoid autograd graph pollution. |
| _diag = self._diag_enabled |
| if _diag: |
| # Cast to float32 for the diagnostic arithmetic: the layer's |
| # residual contribution is small (~0.5 × rms-normed block output), |
| # which underflows in bf16 subtraction (3-digit mantissa) and |
| # reports delta_ratio=0 at the boundaries. float32 snapshot is |
| # ~3.8 MB extra memory per diag sample (B=1, T=2048, d=96) — |
| # negligible vs peak VRAM. |
| with torch.no_grad(): |
| h_pre = self.mhc[0].merge_streams(streams).detach().float() |
| _run_svd = (self._diag_step % self._diag_svd_every) == 0 |
|
|
| # 3060 hot-path: hoist grad-ckpt branch + import out of the per-layer |
| # loop. Module-level torch.utils.checkpoint is imported at the top of |
| # the file (see _ckpt alias) — no per-step `import` cost. |
| _use_ckpt = self._grad_ckpt and self.training |
| for i, (block, mhc_layer) in enumerate(zip(self.blocks, self.mhc)): |
| if _use_ckpt: |
| def _block_fn(h, _block=block): |
| return _ckpt.checkpoint( |
| lambda _h: self.drop(_block(norm(_h))), h, use_reentrant=False |
| ) |
| else: |
| def _block_fn(h, _block=block): |
| return self.drop(_block(norm(h))) |
|
|
| streams = mhc_layer(streams, _block_fn) |
|
|
| if i == self.engram_layer_idx: |
| if _profile: _t_pre_engram = _ev() |
| x_mid = mhc_layer.merge_streams(streams) |
| # Inject differentiable SDR signal into the engram query via |
| # a lightweight projection (one bf16 matmul per forward step). |
| # This is the first time LM loss gradients reach the SDR retina |
| # via delta_u/delta_v — the curated semantic folding patterns |
| # can now adapt during training instead of staying frozen. |
| sdr_feat = self.sdr_proj(sdr_ste.to(x_mid.dtype)) |
| # Norm the projection to prevent magnitude blowup: the raw STE |
| # output has 327/16384 1.0 activations per token, and a single |
| # matmul through sdr_proj (16384→256) with no normalization |
| # grows weight norm from 1e-4 to ~182 within 2K steps, |
| # overwhelming the engram residual. |
| sdr_feat = norm(sdr_feat) |
| x_mid = x_mid + sdr_feat |
| x_mid, hit_rate = self.engram(x_mid, idx) |
| streams = mhc_layer.init_streams(x_mid) |
| self._metrics['engram_hit_rate'] = hit_rate |
| if _profile: _engram_ev = _ev() |
|
|
| if _diag: |
| with torch.no_grad(): |
| h_post = mhc_layer.merge_streams(streams).detach().float() |
| in_n = h_pre.pow(2).mean().sqrt() |
| out_n = h_post.pow(2).mean().sqrt() |
| d_n = (h_post - h_pre).pow(2).mean().sqrt() |
| self._metrics[f'layer_{i}_in_norm'] = float(in_n.item()) |
| self._metrics[f'layer_{i}_out_norm'] = float(out_n.item()) |
| self._metrics[f'layer_{i}_delta_ratio'] = float((d_n / (in_n + 1e-6)).item()) |
| self._metrics[f'layer_{i}_feat_std'] = float(h_post.std(dim=-1).mean().item()) |
| if _run_svd: |
| # Effective rank via participation ratio of singular values. |
| # eff_rank = (Σσ)^2 / Σσ² — smooth rank proxy, bounded by d_model. |
| # Sampled to keep overhead low (SVD is O(min(B*T, D)^2·D)). |
| flat = h_post.reshape(-1, h_post.shape[-1])[:512].float() |
| try: |
| s = torch.linalg.svdvals(flat) |
| eff_rank = float(((s.sum() ** 2) / (s.pow(2).sum() + 1e-6)).item()) |
| self._metrics[f'layer_{i}_eff_rank'] = eff_rank |
| except Exception: |
| pass |
| h_pre = h_post |
|
|
| if _diag: |
| self._diag_step += 1 |
|
|
| if _profile: _t_blocks = _ev() |
|
|
| self._metrics['sdr_active_bits'] = sdr_active_bits |
| if htm_anomaly is not None: |
| self._metrics['htm_anomaly'] = htm_anomaly |
|
|
| x = self.mhc[-1].merge_streams(streams) |
| x = norm(x) |
|
|
| if _profile: _t_merge = _ev() |
|
|
| softcap = self.softcap |
| _softcap_clamp = self._cache_softcap_clamp |
| if targets is not None: |
| smoothing = self.config.label_smoothing |
| V = self.config.vocab_size |
|
|
| # Learnability #4: doc-separator masking. In packed rows, |
| # tokenizer.encode(..., prepend=bos_token) places a BOS at every |
| # document boundary. Without masking, the model is penalized for |
| # failing to predict "doc B's BOS" from the last tokens of doc A |
| # — pure noise. We set targets==bos to -1 (ignore_index). Done |
| # BEFORE MTP/entropy/sampled-softmax branches so all downstream |
| # losses inherit the mask. |
| if self._doc_sep_mask and self._bos_token_id >= 0: |
| targets = torch.where( |
| targets == self._bos_token_id, |
| torch.full_like(targets, -1), |
| targets, |
| ) |
|
|
| # Sampled softmax: instead of computing logits for ALL V tokens, |
| # compute only for the target + K random negatives. Reduces the |
| # lm_head matmul from (B*T, d) × (d, V) to (B*T, d) × (d, K+1). |
| # At V=65536 and K=4096: 16× less compute, ~4× tps improvement. |
| # The log-sum-exp correction adjusts for the sampling bias. |
| # Set HYDRA_SAMPLED_SOFTMAX=0 to disable (full softmax). |
| K_neg = self._cache_sampled_softmax_k |
| use_sampled = K_neg > 0 and K_neg < V and self.training |
|
|
| if use_sampled: |
| # Flatten hidden states + targets |
| h_flat = x.reshape(-1, x.shape[-1]) # (B*T, d) |
| t_flat = targets.reshape(-1) # (B*T,) |
| n = h_flat.shape[0] |
|
|
| # Learnability #4 hardening: sampled-softmax gather crashes on |
| # negative ids (-1 from doc-sep mask). Replace -1 with 0 for |
| # gather; the actual loss is masked below. |
| valid_mask_flat = (t_flat >= 0) |
| t_flat_safe = torch.where(valid_mask_flat, t_flat, torch.zeros_like(t_flat)) |
|
|
| # Sample K negatives uniformly from [0, V) |
| neg_ids = torch.randint(0, V, (K_neg,), device=x.device) |
| # Gather lm_head weights for target + negatives |
| all_ids = torch.cat([t_flat_safe, neg_ids]) # (B*T + K,) |
| sampled_w = self.lm_head.weight[all_ids] # (B*T + K, d) |
|
|
| # Compute sampled logits: for each position, dot with its |
| # target weight and all K negative weights. |
| # Target logit: dot product of h[i] with w[target[i]] |
| target_w = sampled_w[:n] # (B*T, d) |
| neg_w = sampled_w[n:] # (K, d) |
| target_logit = (h_flat * target_w).sum(-1) # (B*T,) |
| neg_logits = h_flat @ neg_w.t() # (B*T, K) |
|
|
| if not _softcap_clamp: |
| target_logit = softcap * torch.tanh(target_logit / softcap) |
| neg_logits = softcap * torch.tanh(neg_logits / softcap) |
|
|
| # Sampled softmax loss: -log(exp(target) / (exp(target) + sum(exp(neg)))) |
| # With log-sum-exp correction for sampling K of V negatives. |
| # Correction: add log(V/K) to negative logits to account for |
| # the fact that we're only seeing K of V possible negatives. |
| log_correction = torch.tensor(V / K_neg, device=x.device).log() |
| all_logits = torch.cat([ |
| target_logit.unsqueeze(-1), # (B*T, 1) |
| neg_logits + log_correction, # (B*T, K) |
| ], dim=-1).float() # (B*T, K+1) |
| |
| # CE with target always at index 0 |
| ce_targets = torch.zeros(n, dtype=torch.long, device=x.device) |
| if reduction == 'none': |
| per_tok = F.cross_entropy(all_logits, ce_targets, reduction='none') |
| if self._doc_sep_mask and self._bos_token_id >= 0: |
| per_tok = torch.where(valid_mask_flat, per_tok, torch.zeros_like(per_tok)) |
| return per_tok |
| per_tok_ce = F.cross_entropy( |
| all_logits, ce_targets, reduction='none', |
| label_smoothing=smoothing, |
| ) |
| # Mask doc-separator positions. valid_mask_flat is always |
| # computed; when doc_sep_mask is off every token is valid so |
| # this reduces to a plain mean. |
| valid_f = valid_mask_flat.float() |
| valid_n = valid_f.sum().clamp(min=1) |
| out = (per_tok_ce * valid_f).sum() / valid_n |
| else: |
| # Full softmax path (eval or HYDRA_SAMPLED_SOFTMAX=0) |
| chunk_size = self._cache_ce_chunk |
| if chunk_size <= 0: |
| MAX_LOGITS_BYTES = 256 * 1024 * 1024 |
| tokens_per_chunk = max(V, MAX_LOGITS_BYTES // (V * 4)) |
| chunk_size = max(1, tokens_per_chunk // max(1, B)) |
| chunk_size = min(chunk_size, T) |
| |
| if reduction == 'none': |
| loss_parts = [] |
| for start in range(0, T, chunk_size): |
| end = min(start + chunk_size, T) |
| chunk_logits = self.lm_head(x[:, start:end, :]).float() |
| if _softcap_clamp: |
| chunk_logits = torch.clamp(chunk_logits, -softcap, softcap) |
| else: |
| chunk_logits = softcap * torch.tanh(chunk_logits / softcap) |
| chunk_targets = targets[:, start:end].reshape(-1) |
| chunk_loss = F.cross_entropy( |
| chunk_logits.view(-1, chunk_logits.size(-1)), |
| chunk_targets, ignore_index=-1, reduction='none', |
| ) |
| loss_parts.append(chunk_loss) |
| return torch.cat(loss_parts) |
| |
| total_loss = 0.0 |
| total_tokens = 0 |
| for start in range(0, T, chunk_size): |
| end = min(start + chunk_size, T) |
| chunk_logits = self.lm_head(x[:, start:end, :]).float() |
| if _softcap_clamp: |
| chunk_logits = torch.clamp(chunk_logits, -softcap, softcap) |
| else: |
| chunk_logits = softcap * torch.tanh(chunk_logits / softcap) |
| chunk_targets = targets[:, start:end].reshape(-1) |
| chunk_loss = F.cross_entropy( |
| chunk_logits.view(-1, chunk_logits.size(-1)), |
| chunk_targets, ignore_index=-1, reduction='sum', |
| label_smoothing=smoothing, |
| ) |
| total_loss = total_loss + chunk_loss |
| total_tokens += (chunk_targets != -1).sum() |
| out = total_loss / total_tokens |
| |
| # ----------------------------------------------------------- |
| # Learnability #1: Multi-Token Prediction. |
| # For k in {2..K}, add a CE loss at position (t) predicting |
| # the token at position (t+k), using the SAME lm_head weights |
| # (weight-tied). Cost: K-1 extra CEs on a subset of positions. |
| # Only triggered in reduction='mean' path, training only. |
| # ----------------------------------------------------------- |
| if reduction == 'mean' and self._mtp_k > 1 and self.training and use_sampled: |
| # TRUE zero-cost MTP: reuse primary's neg_logits (B*T, K_neg) |
| # entirely. Only cost per extra head: O(B*T*d) target-weight |
| # gather + dot product. neg_logits is sliced (view) to match. |
| mtp_loss_sum = out.new_tensor(0.0) |
| mtp_terms = 0 |
| # Reshape primary neg_logits back to (B, T, K_neg) so we can slice positions |
| neg_logits_bt = neg_logits.view(B, T, K_neg) |
| for k in range(2, self._mtp_k + 1): |
| shift = k - 1 |
| if T <= shift: |
| continue |
| n_k = B * (T - shift) |
| h_k_flat = x[:, :T - shift, :].reshape(n_k, -1) # (n_k, d) |
| t_k = targets[:, shift:].reshape(-1) # (n_k,) |
| mask_k = (t_k >= 0) |
| t_k_safe = torch.where(mask_k, t_k, torch.zeros_like(t_k)) |
| tgt_w_k = self.lm_head.weight[t_k_safe] # (n_k, d) |
| tgt_logit_k = (h_k_flat * tgt_w_k).sum(-1) # (n_k,) |
| if not _softcap_clamp: |
| tgt_logit_k = softcap * torch.tanh(tgt_logit_k / softcap) |
| # REUSE primary neg_logits — slice positions [:T-shift] |
| neg_logits_k = neg_logits_bt[:, :T - shift, :].reshape(n_k, K_neg) |
| all_logits_k = torch.cat([ |
| tgt_logit_k.unsqueeze(-1), |
| neg_logits_k + log_correction, |
| ], dim=-1).float() |
| ce_targets_k = torch.zeros(n_k, dtype=torch.long, device=x.device) |
| per_tok_ce_k = F.cross_entropy( |
| all_logits_k, ce_targets_k, reduction='none', |
| label_smoothing=smoothing, |
| ) |
| per_tok_ce_k = torch.where(mask_k, per_tok_ce_k, torch.zeros_like(per_tok_ce_k)) |
| n_valid_k = mask_k.sum().clamp(min=1) |
| mtp_loss_sum = mtp_loss_sum + per_tok_ce_k.sum() / n_valid_k |
| mtp_terms += 1 |
| if mtp_terms > 0: |
| out = (out + mtp_loss_sum) / float(mtp_terms + 1) |
|
|
| # |
| # Learnability #6: output entropy penalty. |
| # L += -lambda * H(softmax(logits)). Negative entropy penalizes |
| # peaked distributions; encourages diverse predictions and |
| # breaks repetition loops. Computed on a small subset of |
| # positions to keep V-sized logits cost bounded. |
| # |
| if reduction == 'mean' and self._entropy_penalty > 0.0 and self.training: |
| # Sample up to 64 random positions. V-sized logits on 64 |
| # positions = 64 * V * 4 bytes (~50 MB at V=200k) — fits |
| # on the 3060 and adds ~2 ms. |
| h_flat = x.reshape(-1, x.shape[-1]) |
| n_pos = h_flat.shape[0] |
| n_sample = min(64, n_pos) |
| idx_sample = torch.randint(0, n_pos, (n_sample,), device=x.device) |
| h_sample = h_flat[idx_sample] |
| logits_s = F.linear(h_sample, self.lm_head.weight).float() |
| if _softcap_clamp: |
| logits_s = torch.clamp(logits_s, -softcap, softcap) |
| else: |
| logits_s = softcap * torch.tanh(logits_s / softcap) |
| log_probs = F.log_softmax(logits_s, dim=-1) |
| probs = log_probs.exp() |
| entropy = -(probs * log_probs).sum(-1).mean() # scalar, nats |
| out = out - self._entropy_penalty * entropy |
|
|
| if _profile: |
| _t_end = _ev() |
| torch.cuda.synchronize() |
| def _ms(a, b): return a.elapsed_time(b) |
| print( |
| f"[PROFILE B={B} T={T}] " |
| f"htm_launch={_ms(_t0, _t_htm_async):.2f} " |
| f"wte={_ms(_t_htm_async, _t_wte):.2f} " |
| f"htm_await={_ms(_t_wte, _t_htm_await):.2f} " |
| f"htm_proj={_ms(_t_htm_await, _t_htm_proj):.2f} " |
| f"mamba_mhc_engram={_ms(_t_htm_proj, _t_blocks):.2f} " |
| f"merge={_ms(_t_blocks, _t_merge):.2f} " |
| f"lm_head_loss={_ms(_t_merge, _t_end):.2f} " |
| f"total={_ms(_t0, _t_end):.2f} ms", |
| flush=True, |
| ) |
| return out |
|
|
| logits = self.lm_head(x).float() |
| if _softcap_clamp: |
| logits = torch.clamp(logits, -softcap, softcap) |
| else: |
| logits = softcap * torch.tanh(logits / softcap) |
| return logits |
|
|