feather-a10g-large-runtime / overlay /hydra /model.py.current.bak
icarus112's picture
Upload folder using huggingface_hub
c383594 verified
"""PostSemClawModel — full-architecture model assembly.
Extracted from the monolithic train.py (W1 modularization). Semantics
unchanged. Imports `GPUEngram` from `hydra.engram` and `MuonAdamW` from
`hydra.optimizer`.
Triton kernel integration status (Phase 2):
HYDRA_FUSED_BCNORM — DEFERRED. The bcnorm_fused Triton kernel fuses
LayerNorm + RoPE on B/C projections. However, mamba-ssm's Mamba3 block
uses RMSNormGated (not LayerNorm) for B/C, and RoPE is applied inside
the mamba3_siso_combined CUDA kernel via the Angles parameter. Replacing
would require either (a) monkey-patching RMSNormGated + intercepting the
fused CUDA scan — invasive, 50+ lines, high breakage risk — or (b) a
full custom Mamba3Block reimplementation. Both are out of scope for
Phase 2. The kernel is validated standalone; integration deferred to
Phase 3 when HYDRA moves to a custom SSM block.
HYDRA_FUSED_SSD — DEFERRED. The ssd_exp_trap Triton kernel implements
exponential-trapezoidal discretization as a sequential scan. mamba-ssm's
Mamba3 block delegates the entire scan + gating + output projection to
mamba3_siso_combined (a compiled CUDA kernel with tilelang). Replacing
it would require decomposing the combined kernel into constituent ops
and substituting only the scan — not feasible without a custom block.
Same Phase 3 gate as above.
Both env vars are accepted but currently no-ops (gates read, logged, but
the code path is unchanged). This avoids silent regression if someone
sets them expecting a speedup.
"""
from __future__ import annotations
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as _ckpt
from mamba_ssm import Mamba3
from subsystems.hestia_mini import HestiaQAT
from subsystems.htm import HTMLayer
from subsystems.mhc_mini import ManifoldHyperConnection
from subsystems.sdr_semantic import SemanticFoldingSDR
from hydra.engram import GPUEngram
from hydra.hyena_block import HyenaBlock
# GDNBlock is imported lazily inside __init__ so the `fla` dependency is
# only required when HYDRA_GDN_LAYERS is actually non-empty. Baseline
# pure-Mamba3 runs continue to work without flash-linear-attention installed.
from hydra.optimizer import MuonAdamW
def norm(x: torch.Tensor) -> torch.Tensor:
"""RMSNorm over the last dim — stateless, autocast-friendly."""
return F.rms_norm(x, (x.size(-1),))
class PostSemClawModel(nn.Module):
"""Full Post-SEM-Claw model assembly.
Architecture:
Token Embedding -> [Mamba3 + residual] x n_layer
-> SDR + Engram (at configured layer) -> norm -> LM head
Interface (must match prepare.py evaluate_bpb):
model(x, y, reduction='none').view(-1) -> per-token losses
model(x, y, reduction='mean') -> scalar loss
"""
def __init__(self, config):
    """Assemble every sub-module and cache the env-driven runtime flags.

    Args:
        config: object exposing the attributes read below — ``vocab_size``,
            ``d_model``, ``n_layer``, ``sequence_len``, ``n_heads``,
            ``d_state``, ``expand``, ``headdim``, ``sdr_n_bits``,
            ``sdr_target_active``, ``sdr_delta_rank``, ``sdr_som_warmup``,
            ``sdr_som_interval``, ``htm_n_columns``,
            ``htm_cells_per_column``, ``engram_n_columns``,
            ``engram_layer_idx`` and, optionally, ``hyena_layers`` /
            ``gdn_layers`` (iterables of layer indices).

    Environment variables are read once here (``HYDRA_*``) so the forward
    pass does not have to re-parse them. Period-style knobs that are later
    used as modulo divisors (``HYDRA_HTM_SUBSAMPLE``,
    ``HYDRA_LAYER_DIAG_SVD_EVERY``) are clamped to ``>= 1`` — the same
    treatment ``HYDRA_MTP_K`` gets — so a value of ``0`` cannot surface as
    a ``ZeroDivisionError`` in the middle of a forward pass.
    """
    super().__init__()
    self.config = config

    # Token embedding
    self.wte = nn.Embedding(config.vocab_size, config.d_model)

    # Mamba-3 blocks — official mamba-ssm fused CUDA kernel. No fallbacks.
    # RoPE is applied internally by the Mamba3 CUDA kernel via the Angles
    # parameter; external cos/sin buffers are not needed.
    #
    # Hyena supplement: layers whose index appears in `config.hyena_layers`
    # are instantiated as HyenaBlock instead of Mamba3. The config field
    # is populated from HYDRA_HYENA_LAYERS at construction time and then
    # persisted to checkpoints, so resume is safe even when the env var
    # is unset. Empty tuple → all-Mamba3, byte-identical to pre-port.
    _hyena_layer_set = set(getattr(config, "hyena_layers", ()) or ())
    _gdn_layer_set = set(getattr(config, "gdn_layers", ()) or ())
    # Hyena wins on overlap; conflict is logged at construction time.
    _both = _hyena_layer_set & _gdn_layer_set
    if _both:
        print(f"[WARN] layers in both hyena_layers and gdn_layers; using Hyena: {sorted(_both)}", flush=True)
        _gdn_layer_set -= _hyena_layer_set
    if _gdn_layer_set:
        # Lazy import: `fla` is only required when GDN layers are requested.
        from hydra.gdn_block import GDNBlock  # requires `fla` package

    def _build_block(i: int) -> nn.Module:
        """Build the sequence-mixing block for layer ``i``."""
        if i in _hyena_layer_set:
            return HyenaBlock(
                d_model=config.d_model,
                seq_len=config.sequence_len,
                order=int(os.environ.get("HYDRA_HYENA_ORDER", "2")),
                filter_order=int(os.environ.get("HYDRA_HYENA_FILTER_DIM", "64")),
            )
        if i in _gdn_layer_set:
            return GDNBlock(
                d_model=config.d_model,
                n_heads=config.n_heads,
            )
        block = Mamba3(
            d_model=config.d_model,
            d_state=config.d_state,
            expand=config.expand,
            headdim=config.headdim,
            is_mimo=False,  # SISO path uses stable mamba3_siso_combined kernel
            chunk_size=int(os.environ.get("HYDRA_MAMBA3_CHUNK", "64")),  # 64 is the validated default; 128 tripped a Triton autotune hang (>8min, no progress)
            is_outproj_norm=False,
            dtype=torch.bfloat16,
        )
        # Fix dt_bias gradient starvation: Mamba3 init samples dt uniformly
        # in log-space from [0.001, 0.1], giving dt_bias in [-6.9, -2.25].
        # softplus'(dt_bias) = sigmoid(dt_bias). At -6.9: 0.1% grad survives;
        # at -2.25: 9.5% survives. Shift up so sigmoid is 2-68% instead.
        # The SSM can still learn to make dt_bias more negative if finer
        # temporal resolution is needed — now the gradient will survive to
        # do so.
        with torch.no_grad():
            block.dt_bias.add_(3.0)
        return block

    self.blocks = nn.ModuleList([_build_block(i) for i in range(config.n_layer)])

    # Full-architecture SDR: offline semantic retina + STE (no-bypass).
    self.sdr_semantic = SemanticFoldingSDR(
        vocab_size=config.vocab_size,
        n_bits=config.sdr_n_bits,
        target_active=config.sdr_target_active,
        delta_rank=config.sdr_delta_rank,
        som_warmup_steps=config.sdr_som_warmup,
        som_update_interval=config.sdr_som_interval,
    )

    # HTM spatial pooler + temporal memory (Rust, Hebbian).
    self.htm = HTMLayer(
        input_bits=config.sdr_n_bits,
        n_columns=config.htm_n_columns,
        cells_per_column=config.htm_cells_per_column,
        batch_size=1,  # grows lazily to actual B on first forward
        seed=42,
        learn=True,
        reset_each_forward=True,
    )
    # Gradient bridge: (n_columns + anomaly) -> d_model.
    self.htm_proj = nn.Linear(config.htm_n_columns + 1, config.d_model, bias=False)

    # GPU Engram with Hebbian writes — runs EVERY step.
    self.engram = GPUEngram(
        d_model=config.d_model,
        n_columns=config.engram_n_columns,
        max_ngram=3,
    )
    self.engram_layer_idx = config.engram_layer_idx

    # SDR-to-d_model projection: routes differentiable SDR signal (STE)
    # into the engram's residual stream. One bf16 matmul per forward step
    # at (B*T, n_bits) @ (n_bits, d_model) — ~2.6M MACs vs 5B total,
    # <0.05% overhead. Zero extra memory (no intermediates materialized
    # beyond the matmul output). This single projection backpropagates LM
    # loss gradients through sdr_semantic.delta_u/delta_v, finally giving
    # the curated semantic retina real learning signal.
    self.sdr_proj = nn.Linear(config.sdr_n_bits, config.d_model, bias=False)

    # Manifold-Constrained Hyper-Connections (one per Mamba-3 block).
    self.mhc = nn.ModuleList([
        ManifoldHyperConnection(d_model=config.d_model, n_streams=2, sinkhorn_iters=3)
        for _ in range(config.n_layer)
    ])

    # Hestia QAT — ternary weight quantization applied post-optimizer-step.
    self.hestia = HestiaQAT(enabled=True, bits=1.58)

    # LM head
    self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

    # Learnability knob 1: Multi-Token Prediction (Llama-3 style).
    # MTP_K=1 -> standard next-token. MTP_K>1 -> extra heads predict
    # tokens at positions t+1, t+2, ..., t+K. Heads are weight-tied to
    # lm_head (we share Parameters), so the only extra compute is
    # additional CE losses; no new params. Activated via HYDRA_MTP_K.
    self._mtp_k = max(1, int(os.environ.get("HYDRA_MTP_K", "1")))
    # Learnability knob 3: gradient checkpointing on Mamba3 blocks.
    self._grad_ckpt = os.environ.get("HYDRA_GRAD_CKPT", "0") == "1"
    # Learnability knob 4: doc-separator BOS masking in packed sequences.
    self._doc_sep_mask = os.environ.get("HYDRA_DOC_SEP_MASK", "0") == "1"
    # BOS token id is looked up lazily on first forward (requires tokenizer
    # load); -1 means uninitialized.
    self._bos_token_id = -1
    # Learnability knob 5: explicit stop-grad on HTM tensor (htm_rust
    # outputs already have requires_grad=False; this is defense-in-depth).
    self._htm_stop_grad = os.environ.get("HYDRA_HTM_STOP_GRAD", "0") == "1"
    # Learnability knob 6: entropy penalty coefficient on LM logits.
    self._entropy_penalty = float(os.environ.get("HYDRA_ENTROPY_PENALTY", "0.0"))

    # Residual dropout
    self.drop = nn.Dropout(float(os.environ.get("HYDRA_DROPOUT", "0.2")))
    # Logits soft-capping
    self.softcap = 15.0
    # Secondary metrics storage
    self._metrics = {}

    # Per-layer diagnostic panel. Env-gated; zero overhead when off.
    # Emits residual-contribution (delta_ratio), feature std, effective rank,
    # gradient norm per layer; used to identify minimum viable n_layer + find
    # entropy leakage / dead layers. See docs/depth-sweep.md.
    self._diag_enabled = os.environ.get("HYDRA_LAYER_DIAGNOSTICS", "0") == "1"
    self._diag_step = 0
    # Used as `step % every` in the forward pass, so it must never be 0.
    self._diag_svd_every = max(1, int(os.environ.get("HYDRA_LAYER_DIAG_SVD_EVERY", "100")))

    # 3060 hot-path: cache env-driven flags at construction time so the
    # forward pass doesn't pay a dict-lookup + string-parse per step.
    self._cache_profile_forward = os.environ.get("HYDRA_PROFILE_FORWARD", "0") == "1"
    # Used as `call_idx % subsample` in the forward pass; clamp so that
    # HYDRA_HTM_SUBSAMPLE=0 means "every micro-batch" instead of crashing.
    self._cache_htm_subsample = max(1, int(os.environ.get("HYDRA_HTM_SUBSAMPLE", "8")))
    self._cache_htm_log_anomaly = os.environ.get("HYDRA_HTM_LOG_ANOMALY", "0") == "1"
    self._cache_softcap_clamp = os.environ.get("HYDRA_SOFTCAP_CLAMP", "0") == "1"
    self._cache_sampled_softmax_k = int(os.environ.get("HYDRA_SAMPLED_SOFTMAX", "4096"))
    self._cache_ce_chunk = int(os.environ.get("HYDRA_CE_CHUNK", "1024"))

    if self._diag_enabled:
        def _mk_grad_hook(_layer_idx):
            """Backward hook factory: records RMS of the block-output grad."""
            def _hook(module, grad_input, grad_output):
                if grad_output and grad_output[0] is not None:
                    g = grad_output[0].detach()
                    self._metrics[f'layer_{_layer_idx}_grad_norm'] = float(
                        g.pow(2).mean().sqrt().item()
                    )
            return _hook

        # Forward hooks on each block capture the block's OUTPUT
        # directly. This is the clean measurement: unlike merge_streams()
        # sampling which sees (streams + M*block_output) in bf16 — where
        # small block contributions round to zero against unit-norm
        # residuals — this captures `block_output` itself as produced.
        # Reports both its absolute RMS norm and its ratio to the block
        # INPUT's RMS norm (contribution magnitude relative to the
        # residual it's added to).
        def _mk_fwd_hook(_layer_idx):
            """Forward hook factory: records input/output RMS and ratio."""
            def _hook(module, inputs, output):
                with torch.no_grad():
                    inp = inputs[0].detach().float() if inputs else None
                    out = output.detach().float() if isinstance(output, torch.Tensor) else None
                    if out is not None:
                        out_rms = out.pow(2).mean().sqrt().item()
                        self._metrics[f'layer_{_layer_idx}_block_out_rms'] = float(out_rms)
                        if inp is not None:
                            in_rms = inp.pow(2).mean().sqrt().item()
                            self._metrics[f'layer_{_layer_idx}_block_in_rms'] = float(in_rms)
                            self._metrics[f'layer_{_layer_idx}_contrib_ratio'] = float(
                                out_rms / (in_rms + 1e-8)
                            )
            return _hook

        # The factories take the layer index as an argument, so each hook
        # closes over its own index (no late-binding on the loop variable).
        for _i, _block in enumerate(self.blocks):
            _block.register_full_backward_hook(_mk_grad_hook(_i))
            _block.register_forward_hook(_mk_fwd_hook(_i))

    # Triton kernel integration gates (Phase 2 — deferred, see module docstring).
    self._fused_bcnorm = os.environ.get("HYDRA_FUSED_BCNORM", "0") == "1"
    self._fused_ssd = os.environ.get("HYDRA_FUSED_SSD", "0") == "1"
    if self._fused_bcnorm or self._fused_ssd:
        import sys
        _active = []
        if self._fused_bcnorm:
            _active.append("HYDRA_FUSED_BCNORM")
        if self._fused_ssd:
            _active.append("HYDRA_FUSED_SSD")
        print(
            f"[HYDRA] Triton kernel gates set: {', '.join(_active)}. "
            f"NOTE: Both are DEFERRED (mamba-ssm Mamba3 uses internal "
            f"CUDA kernels). Gates accepted but currently no-ops.",
            file=sys.stderr,
        )

    # R6 optional torch.compile on the impl forward. Gated (default OFF).
    if os.environ.get("HYDRA_MODEL_COMPILE", "0") == "1":
        self._forward_impl = torch.compile(
            self._forward_impl,
            fullgraph=False,
            dynamic=True,
            mode="default",
        )
@torch.no_grad()
def init_weights(self) -> None:
    """Initialize embedding, head, block and bridge weights, then cast to bf16.

    Runs under ``torch.no_grad()``. Reads the optional overrides
    ``HYDRA_WTE_STD``, ``HYDRA_LM_HEAD_STD`` and ``HYDRA_OUT_PROJ_STD``.
    The random draws happen in this fixed order: ``wte``, ``lm_head``,
    each block's ``in_proj`` / ``out_proj``, ``htm_proj``, ``sdr_proj``.
    """
    import math as _math

    # Shared bound: uniform half-width for in_proj, normal std for htm_proj.
    bound = 3 ** 0.5 * self.config.d_model ** -0.5

    # The SDR retina indices are a plain attribute (not a buffer), so they
    # have to be moved onto the parameters' device by hand.
    target_device = self.wte.weight.device
    retina_owner = self.sdr_semantic
    if hasattr(retina_owner, '_retina_indices'):
        retina_owner._retina_indices = retina_owner._retina_indices.to(target_device)

    # Embedding: std defaults to 1/sqrt(d_model) so it scales with width.
    embed_width = self.wte.weight.shape[1]
    default_wte_std = str(1.0 / _math.sqrt(embed_width))
    wte_std = float(os.environ.get("HYDRA_WTE_STD", default_wte_std))
    nn.init.normal_(self.wte.weight, mean=0.0, std=wte_std)

    # LM head: 0.02 default (GPT-2 convention); the earlier 0.001 collapsed
    # logits toward zero at large vocab sizes.
    lm_head_std = float(os.environ.get("HYDRA_LM_HEAD_STD", "0.02"))
    nn.init.normal_(self.lm_head.weight, mean=0.0, std=lm_head_std)

    # F8 (NOT APPLIED): wte/lm_head weight tying is deferred until the LR
    # groups are re-tuned; see docs/OPTIMIZATION_PLAN.md Post-plan.
    for block in self.blocks:
        has_in_proj = hasattr(block, 'in_proj') and hasattr(block.in_proj, 'weight')
        if has_in_proj:
            nn.init.uniform_(block.in_proj.weight, -bound, bound)

        has_out_proj = hasattr(block, 'out_proj') and hasattr(block.out_proj, 'weight')
        if has_out_proj:
            # GPT-2 residual init: std = 0.02 / sqrt(2 * n_layer). Deliberately
            # NOT zeros — a zero out_proj makes the block a pass-through that
            # sends no gradient to the SSM internals.
            depth = self.config.n_layer
            default_out_std = str(0.02 / (2 * depth) ** 0.5)
            out_std = float(os.environ.get("HYDRA_OUT_PROJ_STD", default_out_std))
            nn.init.normal_(block.out_proj.weight, mean=0.0, std=out_std)

    nn.init.normal_(self.htm_proj.weight, mean=0.0, std=bound)
    # SDR proj: tiny init so the residual dynamics at step 0 are preserved.
    nn.init.normal_(self.sdr_proj.weight, mean=0.0, std=1e-4)

    # Cast to bf16 to match the Mamba3 dtype; Muon groups by shape, so mixed
    # dtypes inside one shape group would break its lerp_ dtype checks.
    for module in (self.wte, self.htm_proj, self.sdr_proj, self.engram):
        module.to(dtype=torch.bfloat16)
def set_bos_token_id(self, bos_id: int) -> None:
    """Record the tokenizer's BOS id for doc-separator masking.

    Intended to be called from training setup once the tokenizer is
    loaded (learnability knob #4). The value is coerced with ``int()``
    and stored on ``self._bos_token_id``.
    """
    coerced = int(bos_id)
    self._bos_token_id = coerced
def invalidate_hyena_caches(self) -> None:
    """Drop the cached filter rfft on every block that keeps one.

    MUST run after each ``optimizer.step()`` when
    ``HYDRA_HYENA_FILTER_CACHE=1`` is set; otherwise the cached rfft
    values would be reused with stale filter parameters.

    Blocks without an ``operator`` attribute, or whose operator has no
    ``invalidate_filter_cache`` method (Mamba3, etc.), are skipped.
    """
    for block in self.blocks:
        if not hasattr(block, "operator"):
            continue
        operator = block.operator
        if hasattr(operator, "invalidate_filter_cache"):
            operator.invalidate_filter_cache()
def flush_hyena_pending_grads(self) -> None:
    """Push pending train-cache filter gradients into the filter params.

    Used ONLY when ``HYDRA_HYENA_TRAIN_CACHE=1``. Must be called exactly
    once per optimizer step, BEFORE ``optimizer.step()`` and BEFORE
    ``invalidate_hyena_caches()``.

    Blocks without an ``operator`` attribute, or whose operator has no
    ``flush_pending_filter_grads`` method, are skipped, so the call is a
    no-op for a model without such blocks.
    """
    for block in self.blocks:
        if not hasattr(block, "operator"):
            continue
        operator = block.operator
        if hasattr(operator, "flush_pending_filter_grads"):
            operator.flush_pending_filter_grads()
def estimate_flops(self) -> int:
    """Return ``6 * (total parameter count - token-embedding count)``.

    The token-embedding table (``self.wte.weight``) is subtracted from
    the total before applying the factor of 6.
    """
    total_numel = 0
    for param in self.parameters():
        total_numel += param.numel()
    embedding_numel = self.wte.weight.numel()
    non_embedding = total_numel - embedding_numel
    return 6 * non_embedding
def num_scaling_params(self) -> dict:
    """Return parameter counts per major component plus the model total.

    Keys: ``wte``, ``lm_head``, ``blocks``, ``sdr_semantic``, ``htm_proj``,
    ``engram`` and ``total``. ``total`` counts every parameter of the
    model, so it can exceed the sum of the listed components.
    """
    def _numel(module) -> int:
        # Element count over everything `module.parameters()` yields.
        return sum(p.numel() for p in module.parameters())

    counts = {
        'wte': _numel(self.wte),
        'lm_head': _numel(self.lm_head),
        'blocks': _numel(self.blocks),
        'sdr_semantic': _numel(self.sdr_semantic),
        'htm_proj': _numel(self.htm_proj),
        'engram': _numel(self.engram),
    }
    counts['total'] = _numel(self)
    return counts
def get_secondary_metrics(self) -> dict:
    """Return a copy of the metrics with tensor-like values turned into floats.

    Any value exposing ``.item()`` is converted via ``float(v.item())``;
    if that conversion raises, the original value is kept (best effort).
    Everything else is passed through unchanged. ``self._metrics`` itself
    is not modified.
    """
    def _to_host(value):
        # Values without .item() (plain floats, strings, ...) pass through.
        if not hasattr(value, 'item'):
            return value
        try:
            return float(value.item())
        except Exception:
            # Deliberate best-effort: keep the raw value rather than fail
            # metric collection.
            return value

    return {key: _to_host(value) for key, value in self._metrics.items()}
def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.6, matrix_lr=0.04,
                    weight_decay=0.2, adam_betas=(0.8, 0.95), scalar_lr=0.5,
                    train_sdr=False):
    """Build the MuonAdamW optimizer with per-component LR groups.

    Groups, in order: lm_head (AdamW), token embedding (AdamW), remaining
    scalar/vector params (AdamW, if any), ``dt_bias`` (AdamW, if any), then
    one Muon group per distinct matrix shape. All AdamW learning rates are
    scaled by ``(d_model / 768) ** -0.5``; the Muon ``matrix_lr`` is not.

    Args:
        unembedding_lr: base AdamW LR for ``lm_head``.
        embedding_lr: base AdamW LR for ``wte``; also the default base LR
            of the ``dt_bias`` group (override with ``HYDRA_DT_BIAS_LR``).
        matrix_lr: Muon LR for 2-D matrices passing the routing guard.
        weight_decay: Muon weight decay (AdamW groups use 0.0).
        adam_betas: betas for every AdamW group.
        scalar_lr: base AdamW LR for the scalar/vector group.
        train_sdr: when False (default, the historical behavior) the
            ``sdr_semantic`` parameters are left out of every optimizer
            group and are therefore never updated. When True they join the
            AdamW scalar group.

    Returns:
        The ``MuonAdamW`` instance; every group carries ``initial_lr``.

    Raises:
        AssertionError: if the groups do not account for every parameter.
    """
    model_dim = self.config.d_model
    embedding_params = list(self.wte.parameters())
    lm_head_params = list(self.lm_head.parameters())

    # Muon routing guard: 2D parameters are NOT automatically matrices.
    # Exclude:
    #   (a) params whose name ends in `.freq` — Sin frequency vectors used
    #       by Hyena's implicit filter MLP. Shape (1, dim) is nominally 2D
    #       but semantically a per-dim scalar. Muon's polar-express
    #       orthogonalization would force it toward an orthogonal matrix,
    #       destroying the learned modulation frequencies.
    #   (b) 2-D params with min(shape) < MUON_MIN_DIM. Tiny projections
    #       (e.g. HyenaFilter.implicit_filter.0.weight of shape (64, 3))
    #       get collapsed toward near-identity by orthogonalization on the
    #       narrow axis, damaging expressivity. These belong in AdamW.
    # These exclusions route the params into the AdamW scalar/vector group.
    MUON_MIN_DIM = 8

    def _muon_eligible(name: str, p: torch.Tensor) -> bool:
        if p.dim() != 2:
            return False
        if name.endswith(".freq"):
            return False
        if min(p.shape) < MUON_MIN_DIM:
            return False
        return True

    # Matrix params -> Muon (2D weight matrices passing the routing guard),
    # collected from the block stack, the HTM bridge and the engram.
    matrix_params = []
    for module in (self.blocks, self.htm_proj, self.engram):
        matrix_params.extend(
            p for name, p in module.named_parameters() if _muon_eligible(name, p)
        )

    # NOTE(review): the forward pass now routes LM-loss gradient into
    # sdr_semantic (STE output -> sdr_proj -> engram residual), yet these
    # params have historically been kept out of every optimizer group, so
    # that gradient was never applied. Changing the default would alter
    # training dynamics and the optimizer-state layout of existing
    # checkpoints, so joining the AdamW scalar group is opt-in via
    # `train_sdr=True`.
    sdr_params = list(self.sdr_semantic.parameters())
    excluded_ids = set() if train_sdr else {id(p) for p in sdr_params}

    assigned = {id(p) for p in embedding_params + lm_head_params + matrix_params}

    # Extract dt_bias from each Mamba3 block into its own high-LR group.
    # dt_bias controls SSM temporal discretization: softplus(dt_bias) gives
    # the step size delta_t. Give it a dedicated group with HYDRA_DT_BIAS_LR
    # so heads can differentiate their timescales instead of all tracking at
    # ln(2) across all layers. The `assigned` check keeps a parameter from
    # landing in two groups should a block ever expose a Muon-eligible 2-D
    # `dt_bias`.
    dt_bias_params = [
        p for name, p in self.blocks.named_parameters()
        if name.endswith('dt_bias') and id(p) not in assigned
    ]
    dt_bias_ids = {id(p) for p in dt_bias_params}

    scalar_params = [
        p for p in self.parameters()
        if id(p) not in assigned
        and id(p) not in excluded_ids
        and id(p) not in dt_bias_ids
    ]

    # Invariant: every parameter is in exactly one group or deliberately
    # excluded.
    total_assigned = (
        len(embedding_params) + len(lm_head_params) + len(matrix_params)
        + len(scalar_params) + len(dt_bias_params)
    )
    total_params = len(list(self.parameters()))
    sdr_excluded = 0 if train_sdr else len(sdr_params)
    assert total_assigned + sdr_excluded == total_params, (
        f"Parameter count mismatch: assigned {total_assigned} + sdr_excluded "
        f"{sdr_excluded} vs total {total_params}"
    )

    dmodel_lr_scale = (model_dim / 768) ** -0.5
    print(f"Scaling AdamW LRs by 1/sqrt({model_dim}/768) = {dmodel_lr_scale:.6f}")

    param_groups = [
        dict(kind='adamw', params=lm_head_params,
             lr=unembedding_lr * dmodel_lr_scale, betas=adam_betas,
             eps=1e-10, weight_decay=0.0),
        dict(kind='adamw', params=embedding_params,
             lr=embedding_lr * dmodel_lr_scale, betas=adam_betas,
             eps=1e-10, weight_decay=0.0),
    ]
    if scalar_params:
        param_groups.append(
            dict(kind='adamw', params=scalar_params,
                 lr=scalar_lr * dmodel_lr_scale, betas=adam_betas,
                 eps=1e-10, weight_decay=0.0)
        )
    # dt_bias: dedicated group with embed-level LR so each head learns its
    # own temporal discretization. Env-overridable for sweeps.
    if dt_bias_params:
        _dt_bias_lr = float(os.environ.get("HYDRA_DT_BIAS_LR", str(embedding_lr)))
        param_groups.append(dict(
            kind='adamw', params=dt_bias_params,
            lr=_dt_bias_lr * dmodel_lr_scale, betas=adam_betas,
            eps=1e-10, weight_decay=0.0,
        ))

    # ns_steps: Muon polar-express inner iterations. Default here is 3;
    # env-tunable for experimentation. Read once — it is the same for
    # every shape group.
    _ns_steps = int(os.environ.get("HYDRA_MUON_NS_STEPS", "3"))
    for shape in sorted({p.shape for p in matrix_params}):
        group_params = [p for p in matrix_params if p.shape == shape]
        param_groups.append(dict(
            kind='muon', params=group_params, lr=matrix_lr,
            momentum=0.95, ns_steps=_ns_steps, beta2=0.95, weight_decay=weight_decay,
        ))

    optimizer = MuonAdamW(param_groups)
    # Remember each group's starting LR (e.g. for LR schedules).
    for group in optimizer.param_groups:
        group["initial_lr"] = group["lr"]
    return optimizer
def forward(self, idx, targets=None, reduction='mean'):
"""idx: (B, T) int64. Returns loss if targets given, else logits.
Nested bf16 autocast is a no-op when ambient autocast is already on;
when it's off (e.g. integration tests) we establish the dtype contract.
"""
if torch.is_autocast_enabled():
return self._forward_impl(idx, targets=targets, reduction=reduction)
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
return self._forward_impl(idx, targets=targets, reduction=reduction)
def _forward_impl(self, idx, targets=None, reduction='mean'):
B, T = idx.shape
# Diagnostic: per-subsystem CUDA event timing. Env-gated; zero overhead
# when disabled. Logs one timing line per forward call. Used to isolate
# which subsystem is the tps bottleneck on paid hardware.
_profile = self._cache_profile_forward
if _profile:
def _ev():
e = torch.cuda.Event(enable_timing=True)
e.record()
return e
_t0 = _ev()
else:
_t0 = None
# Compute SDR ONCE via the differentiable STE forward.
# The STE output is float (autocast dtype); we detach+cast a uint8 copy
# for HTM (which only needs binary patterns). This replaces the old
# binary_only() call — same scatter cost, but routes LM loss gradients
# through sdr_semantic.delta_u/delta_v for the first time.
sdr_ste = self.sdr_semantic(idx) # float, with STE grad path
sdr_binary = sdr_ste.detach().to(dtype=torch.uint8) # detached uint8 for HTM
self._last_sdr = sdr_binary
# HTM subsampling: run HTM on 1 of every N micro-batches within a
# gradient accumulation step, reuse the cached result for the other
# N-1 micro-batches. Cooperative launch monopolizes all SMs (grid.sync
# requires full-grid residency), so HTM and mamba can't overlap via
# streams. Subsampling removes HTM from most micro-batches' critical
# path instead.
#
# Math: N=8, 64 accum steps → 8 HTM calls (10.6ms each) + 56 fast
# calls (4ms each). Total = 84.8 + 224 = 309ms → 106k tps.
#
# HYDRA_HTM_SUBSAMPLE=N (default 8). Set =1 for every-microbatch HTM.
_htm_sub = self._cache_htm_subsample
if not hasattr(self, '_htm_call_idx'):
self._htm_call_idx = 0
_run_htm = (self._htm_call_idx % _htm_sub == 0)
self._htm_call_idx += 1
if _run_htm:
htm_handle = self.htm.forward_async(sdr_binary)
else:
htm_handle = None
if _profile: _t_htm_async = _ev()
dense_emb = self.wte(idx) # (B, T, d_model) bf16
if _profile: _t_wte = _ev()
if _run_htm:
htm_out = self.htm.forward_await(htm_handle)
self._htm_cache = htm_out.detach() # cache for non-HTM micro-batches
elif hasattr(self, '_htm_cache') and self._htm_cache is not None \
and self._htm_cache.shape[0] == B and self._htm_cache.shape[1] == T:
htm_out = self._htm_cache
else:
# Very first call with subsample > 1: run HTM anyway.
htm_handle = self.htm.forward_async(sdr_binary)
htm_out = self.htm.forward_await(htm_handle)
self._htm_cache = htm_out.detach()
if _profile: _t_htm_await = _ev()
# 3060 hot-path: skip the per-step htm_anomaly mean kernel + sync
# unless explicitly requested via HYDRA_HTM_LOG_ANOMALY=1. Saves
# one tiny GPU launch + one .item()-style copy per forward call.
sdr_active_bits = float(self.sdr_semantic.target_active)
if self._cache_htm_log_anomaly:
with torch.no_grad():
htm_anomaly = htm_out[..., -1].mean()
else:
htm_anomaly = None
# Learnability #5: explicit stop-grad on HTM output. htm_rust already
# produces a detached tensor, but making it explicit here hardens the
# contract against future refactors that might route HTM through a
# grad-enabled op.
if self._htm_stop_grad:
htm_out = htm_out.detach()
# Gradient bridge: HTM columns+anomaly -> d_model.
# 3060 hot-path: only cast when dtypes actually differ (htm_out is
# bf16 in normal training, dense_emb is bf16 under autocast — the
# .to() in that case is a no-op but still costs Python overhead).
if htm_out.dtype != dense_emb.dtype:
htm_out = htm_out.to(dense_emb.dtype)
htm_proj_out = self.htm_proj(htm_out)
x = dense_emb + htm_proj_out
x = norm(x)
if _profile: _t_htm_proj = _ev()
# mHC-routed Mamba-3 stack with Engram injection at configured layer.
streams = self.mhc[0].init_streams(x)
_engram_ev = None
# Per-layer diagnostic panel. The pre-layer merged state h_pre lets us
# measure residual contribution of each layer: delta_N = h_post - h_pre.
# All reads are detached no-grad to avoid autograd graph pollution.
_diag = self._diag_enabled
if _diag:
# Cast to float32 for the diagnostic arithmetic: the layer's
# residual contribution is small (~0.5 × rms-normed block output),
# which underflows in bf16 subtraction (3-digit mantissa) and
# reports delta_ratio=0 at the boundaries. float32 snapshot is
# ~3.8 MB extra memory per diag sample (B=1, T=2048, d=96) —
# negligible vs peak VRAM.
with torch.no_grad():
h_pre = self.mhc[0].merge_streams(streams).detach().float()
_run_svd = (self._diag_step % self._diag_svd_every) == 0
# 3060 hot-path: hoist grad-ckpt branch + import out of the per-layer
# loop. Module-level torch.utils.checkpoint is imported at the top of
# the file (see _ckpt alias) — no per-step `import` cost.
_use_ckpt = self._grad_ckpt and self.training
for i, (block, mhc_layer) in enumerate(zip(self.blocks, self.mhc)):
if _use_ckpt:
def _block_fn(h, _block=block):
return _ckpt.checkpoint(
lambda _h: self.drop(_block(norm(_h))), h, use_reentrant=False
)
else:
def _block_fn(h, _block=block):
return self.drop(_block(norm(h)))
streams = mhc_layer(streams, _block_fn)
if i == self.engram_layer_idx:
if _profile: _t_pre_engram = _ev()
x_mid = mhc_layer.merge_streams(streams)
# Inject differentiable SDR signal into the engram query via
# a lightweight projection (one bf16 matmul per forward step).
# This is the first time LM loss gradients reach the SDR retina
# via delta_u/delta_v — the curated semantic folding patterns
# can now adapt during training instead of staying frozen.
sdr_feat = self.sdr_proj(sdr_ste.to(x_mid.dtype))
# Norm the projection to prevent magnitude blowup: the raw STE
# output has 327 of its 16384 bits set to 1.0 per token, and a single
# matmul through sdr_proj (16384 -> 256) with no normalization
# grows weight norm from 1e-4 to ~182 within 2K steps,
# overwhelming the engram residual.
sdr_feat = norm(sdr_feat)
x_mid = x_mid + sdr_feat
x_mid, hit_rate = self.engram(x_mid, idx)
streams = mhc_layer.init_streams(x_mid)
self._metrics['engram_hit_rate'] = hit_rate
if _profile: _engram_ev = _ev()
if _diag:
with torch.no_grad():
h_post = mhc_layer.merge_streams(streams).detach().float()
in_n = h_pre.pow(2).mean().sqrt()
out_n = h_post.pow(2).mean().sqrt()
d_n = (h_post - h_pre).pow(2).mean().sqrt()
self._metrics[f'layer_{i}_in_norm'] = float(in_n.item())
self._metrics[f'layer_{i}_out_norm'] = float(out_n.item())
self._metrics[f'layer_{i}_delta_ratio'] = float((d_n / (in_n + 1e-6)).item())
self._metrics[f'layer_{i}_feat_std'] = float(h_post.std(dim=-1).mean().item())
if _run_svd:
# Effective rank via participation ratio of singular values.
# eff_rank = (Σσ)² / Σσ² — smooth rank proxy, bounded by the number
# of singular values, min(rows, d_model). Only the first 512 rows of
# the flattened (B*T, D) activations are used, which keeps the SVD
# cost low.
flat = h_post.reshape(-1, h_post.shape[-1])[:512].float()
try:
s = torch.linalg.svdvals(flat)
eff_rank = float(((s.sum() ** 2) / (s.pow(2).sum() + 1e-6)).item())
self._metrics[f'layer_{i}_eff_rank'] = eff_rank
except Exception:
pass
h_pre = h_post
if _diag:
self._diag_step += 1
if _profile: _t_blocks = _ev()
self._metrics['sdr_active_bits'] = sdr_active_bits
if htm_anomaly is not None:
self._metrics['htm_anomaly'] = htm_anomaly
x = self.mhc[-1].merge_streams(streams)
x = norm(x)
if _profile: _t_merge = _ev()
softcap = self.softcap
_softcap_clamp = self._cache_softcap_clamp
if targets is not None:
smoothing = self.config.label_smoothing
V = self.config.vocab_size
# Learnability #4: doc-separator masking. In packed rows,
# tokenizer.encode(..., prepend=bos_token) places a BOS at every
# document boundary. Without masking, the model is penalized for
# failing to predict "doc B's BOS" from the last tokens of doc A
# — pure noise. We set targets==bos to -1 (ignore_index). Done
# BEFORE MTP/entropy/sampled-softmax branches so all downstream
# losses inherit the mask.
if self._doc_sep_mask and self._bos_token_id >= 0:
targets = torch.where(
targets == self._bos_token_id,
torch.full_like(targets, -1),
targets,
)
# Sampled softmax: instead of computing logits for ALL V tokens,
# compute only for the target + K random negatives. Reduces the
# lm_head matmul from (B*T, d) × (d, V) to (B*T, d) × (d, K+1).
# At V=65536 and K=4096: 16× less compute, ~4× tps improvement.
# The log-sum-exp correction adjusts for the sampling bias.
# Set HYDRA_SAMPLED_SOFTMAX=0 to disable (full softmax).
K_neg = self._cache_sampled_softmax_k
use_sampled = K_neg > 0 and K_neg < V and self.training
if use_sampled:
# Flatten hidden states + targets
h_flat = x.reshape(-1, x.shape[-1]) # (B*T, d)
t_flat = targets.reshape(-1) # (B*T,)
n = h_flat.shape[0]
# Learnability #4 hardening: sampled-softmax gather crashes on
# negative ids (-1 from doc-sep mask). Replace -1 with 0 for
# gather; the actual loss is masked below.
valid_mask_flat = (t_flat >= 0)
t_flat_safe = torch.where(valid_mask_flat, t_flat, torch.zeros_like(t_flat))
# Sample K negatives uniformly from [0, V)
neg_ids = torch.randint(0, V, (K_neg,), device=x.device)
# Gather lm_head weights for target + negatives
all_ids = torch.cat([t_flat_safe, neg_ids]) # (B*T + K,)
sampled_w = self.lm_head.weight[all_ids] # (B*T + K, d)
# Compute sampled logits: for each position, dot with its
# target weight and all K negative weights.
# Target logit: dot product of h[i] with w[target[i]]
target_w = sampled_w[:n] # (B*T, d)
neg_w = sampled_w[n:] # (K, d)
target_logit = (h_flat * target_w).sum(-1) # (B*T,)
neg_logits = h_flat @ neg_w.t() # (B*T, K)
if not _softcap_clamp:
target_logit = softcap * torch.tanh(target_logit / softcap)
neg_logits = softcap * torch.tanh(neg_logits / softcap)
# Sampled softmax loss: -log(exp(target) / (exp(target) + sum(exp(neg))))
# With log-sum-exp correction for sampling K of V negatives.
# Correction: add log(V/K) to negative logits to account for
# the fact that we're only seeing K of V possible negatives.
log_correction = torch.tensor(V / K_neg, device=x.device).log()
all_logits = torch.cat([
target_logit.unsqueeze(-1), # (B*T, 1)
neg_logits + log_correction, # (B*T, K)
], dim=-1).float() # (B*T, K+1)
# CE with target always at index 0
ce_targets = torch.zeros(n, dtype=torch.long, device=x.device)
if reduction == 'none':
per_tok = F.cross_entropy(all_logits, ce_targets, reduction='none')
if self._doc_sep_mask and self._bos_token_id >= 0:
per_tok = torch.where(valid_mask_flat, per_tok, torch.zeros_like(per_tok))
return per_tok
per_tok_ce = F.cross_entropy(
all_logits, ce_targets, reduction='none',
label_smoothing=smoothing,
)
# Mask doc-separator positions. valid_mask_flat is always
# computed; when doc_sep_mask is off every token is valid so
# this reduces to a plain mean.
valid_f = valid_mask_flat.float()
valid_n = valid_f.sum().clamp(min=1)
out = (per_tok_ce * valid_f).sum() / valid_n
else:
# Full softmax path (eval or HYDRA_SAMPLED_SOFTMAX=0)
chunk_size = self._cache_ce_chunk
if chunk_size <= 0:
MAX_LOGITS_BYTES = 256 * 1024 * 1024
tokens_per_chunk = max(V, MAX_LOGITS_BYTES // (V * 4))
chunk_size = max(1, tokens_per_chunk // max(1, B))
chunk_size = min(chunk_size, T)
if reduction == 'none':
loss_parts = []
for start in range(0, T, chunk_size):
end = min(start + chunk_size, T)
chunk_logits = self.lm_head(x[:, start:end, :]).float()
if _softcap_clamp:
chunk_logits = torch.clamp(chunk_logits, -softcap, softcap)
else:
chunk_logits = softcap * torch.tanh(chunk_logits / softcap)
chunk_targets = targets[:, start:end].reshape(-1)
chunk_loss = F.cross_entropy(
chunk_logits.view(-1, chunk_logits.size(-1)),
chunk_targets, ignore_index=-1, reduction='none',
)
loss_parts.append(chunk_loss)
return torch.cat(loss_parts)
total_loss = 0.0
total_tokens = 0
for start in range(0, T, chunk_size):
end = min(start + chunk_size, T)
chunk_logits = self.lm_head(x[:, start:end, :]).float()
if _softcap_clamp:
chunk_logits = torch.clamp(chunk_logits, -softcap, softcap)
else:
chunk_logits = softcap * torch.tanh(chunk_logits / softcap)
chunk_targets = targets[:, start:end].reshape(-1)
chunk_loss = F.cross_entropy(
chunk_logits.view(-1, chunk_logits.size(-1)),
chunk_targets, ignore_index=-1, reduction='sum',
label_smoothing=smoothing,
)
total_loss = total_loss + chunk_loss
total_tokens += (chunk_targets != -1).sum()
out = total_loss / total_tokens
# -----------------------------------------------------------
# Learnability #1: Multi-Token Prediction.
# For k in {2..K}, add a CE loss at position (t) predicting
# the token at position (t+k), using the SAME lm_head weights
# (weight-tied). Cost: K-1 extra CEs on a subset of positions.
# Only triggered when reduction='mean', in training, and when the
# sampled-softmax path is active (it reuses that path's neg_logits).
# -----------------------------------------------------------
if reduction == 'mean' and self._mtp_k > 1 and self.training and use_sampled:
# TRUE zero-cost MTP: reuse primary's neg_logits (B*T, K_neg)
# entirely. Only cost per extra head: O(B*T*d) target-weight
# gather + dot product. neg_logits is sliced (view) to match.
mtp_loss_sum = out.new_tensor(0.0)
mtp_terms = 0
# Reshape primary neg_logits back to (B, T, K_neg) so we can slice positions
neg_logits_bt = neg_logits.view(B, T, K_neg)
for k in range(2, self._mtp_k + 1):
shift = k - 1
if T <= shift:
continue
n_k = B * (T - shift)
h_k_flat = x[:, :T - shift, :].reshape(n_k, -1) # (n_k, d)
t_k = targets[:, shift:].reshape(-1) # (n_k,)
mask_k = (t_k >= 0)
t_k_safe = torch.where(mask_k, t_k, torch.zeros_like(t_k))
tgt_w_k = self.lm_head.weight[t_k_safe] # (n_k, d)
tgt_logit_k = (h_k_flat * tgt_w_k).sum(-1) # (n_k,)
if not _softcap_clamp:
tgt_logit_k = softcap * torch.tanh(tgt_logit_k / softcap)
# REUSE primary neg_logits — slice positions [:T-shift]
neg_logits_k = neg_logits_bt[:, :T - shift, :].reshape(n_k, K_neg)
all_logits_k = torch.cat([
tgt_logit_k.unsqueeze(-1),
neg_logits_k + log_correction,
], dim=-1).float()
ce_targets_k = torch.zeros(n_k, dtype=torch.long, device=x.device)
per_tok_ce_k = F.cross_entropy(
all_logits_k, ce_targets_k, reduction='none',
label_smoothing=smoothing,
)
per_tok_ce_k = torch.where(mask_k, per_tok_ce_k, torch.zeros_like(per_tok_ce_k))
n_valid_k = mask_k.sum().clamp(min=1)
mtp_loss_sum = mtp_loss_sum + per_tok_ce_k.sum() / n_valid_k
mtp_terms += 1
if mtp_terms > 0:
out = (out + mtp_loss_sum) / float(mtp_terms + 1)
# -----------------------------------------------------------
# Learnability #6: output entropy penalty.
# L += -lambda * H(softmax(logits)). Negative entropy penalizes
# peaked distributions; encourages diverse predictions and
# breaks repetition loops. Computed on a small subset of
# positions to keep V-sized logits cost bounded.
# -----------------------------------------------------------
if reduction == 'mean' and self._entropy_penalty > 0.0 and self.training:
# Sample up to 64 random positions. V-sized logits on 64
# positions = 64 * V * 4 bytes (~50 MB at V=200k) — fits
# on the 3060 and adds ~2 ms.
h_flat = x.reshape(-1, x.shape[-1])
n_pos = h_flat.shape[0]
n_sample = min(64, n_pos)
idx_sample = torch.randint(0, n_pos, (n_sample,), device=x.device)
h_sample = h_flat[idx_sample]
logits_s = F.linear(h_sample, self.lm_head.weight).float()
if _softcap_clamp:
logits_s = torch.clamp(logits_s, -softcap, softcap)
else:
logits_s = softcap * torch.tanh(logits_s / softcap)
log_probs = F.log_softmax(logits_s, dim=-1)
probs = log_probs.exp()
entropy = -(probs * log_probs).sum(-1).mean() # scalar, nats
out = out - self._entropy_penalty * entropy
if _profile:
_t_end = _ev()
torch.cuda.synchronize()
def _ms(a, b): return a.elapsed_time(b)
print(
f"[PROFILE B={B} T={T}] "
f"htm_launch={_ms(_t0, _t_htm_async):.2f} "
f"wte={_ms(_t_htm_async, _t_wte):.2f} "
f"htm_await={_ms(_t_wte, _t_htm_await):.2f} "
f"htm_proj={_ms(_t_htm_await, _t_htm_proj):.2f} "
f"mamba_mhc_engram={_ms(_t_htm_proj, _t_blocks):.2f} "
f"merge={_ms(_t_blocks, _t_merge):.2f} "
f"lm_head_loss={_ms(_t_merge, _t_end):.2f} "
f"total={_ms(_t0, _t_end):.2f} ms",
flush=True,
)
return out
logits = self.lm_head(x).float()
if _softcap_clamp:
logits = torch.clamp(logits, -softcap, softcap)
else:
logits = softcap * torch.tanh(logits / softcap)
return logits