"""PostSemClawModel — full-architecture model assembly.

Extracted from the monolithic train.py (W1 modularization); semantics are
unchanged. Imports ``GPUEngram`` from ``hydra.engram`` and ``MuonAdamW`` from
``hydra.optimizer``.

Triton kernel integration status (Phase 2)
------------------------------------------

HYDRA_FUSED_BCNORM — DEFERRED.
    The ``bcnorm_fused`` Triton kernel fuses LayerNorm + RoPE on the B/C
    projections. However, mamba-ssm's Mamba3 block uses RMSNormGated (not
    LayerNorm) for B/C, and RoPE is applied inside the
    ``mamba3_siso_combined`` CUDA kernel via the ``Angles`` parameter.
    Swapping the kernel in would require either (a) monkey-patching
    RMSNormGated and intercepting the fused CUDA scan — invasive, 50+ lines,
    high breakage risk — or (b) a full custom Mamba3Block reimplementation.
    Both are out of scope for Phase 2. The kernel is validated standalone;
    integration is deferred to Phase 3, when HYDRA moves to a custom SSM block.

HYDRA_FUSED_SSD — DEFERRED.
    The ``ssd_exp_trap`` Triton kernel implements exponential-trapezoidal
    discretization as a sequential scan. mamba-ssm's Mamba3 block delegates
    the entire scan + gating + output projection to ``mamba3_siso_combined``
    (a compiled CUDA kernel built with tilelang). Replacing it would require
    decomposing the combined kernel into its constituent ops and substituting
    only the scan — not feasible without a custom block. Same Phase 3 gate as
    above.

Both env vars are accepted but are currently no-ops: the gates are read and a
notice is printed to stderr, but the code path is unchanged. The notice exists
so that anyone setting them in expectation of a speedup is told explicitly
that nothing changed, instead of being left to assume it did.
"""
from __future__ import annotations

import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as _ckpt
from mamba_ssm import Mamba3

from subsystems.hestia_mini import HestiaQAT
from subsystems.htm import HTMLayer
from subsystems.mhc_mini import ManifoldHyperConnection
from subsystems.sdr_semantic import SemanticFoldingSDR
from hydra.engram import GPUEngram
from hydra.hyena_block import HyenaBlock
# GDNBlock is imported lazily inside __init__ so the `fla` dependency is
# only required when HYDRA_GDN_LAYERS is actually non-empty. Baseline
# pure-Mamba3 runs continue to work without flash-linear-attention installed.
from hydra.optimizer import MuonAdamW


def norm(x: torch.Tensor) -> torch.Tensor:
    """Apply parameter-free RMSNorm over the last dimension of ``x``.

    No learnable scale is passed to ``F.rms_norm``, so the function holds no
    state and can be called anywhere (autocast-friendly).
    """
    return F.rms_norm(x, (x.size(-1),))


class PostSemClawModel(nn.Module):
    """Full Post-SEM-Claw model assembly.

    Architecture (see ``_forward_impl``):
        Token Embedding + HTM projection -> norm
        -> mHC-routed [Mamba3 | Hyena | GDN block] x n_layer
           (SDR projection + Engram injected after layer ``engram_layer_idx``)
        -> merge streams -> norm -> soft-capped LM head

    Interface (must match prepare.py evaluate_bpb):
        model(x, y, reduction='none').view(-1) -> per-token losses
        model(x, y, reduction='mean') -> scalar loss
    """

    def __init__(self, config):
        """Build every submodule and cache the env-driven flags.

        Args:
            config: Model config object. Attributes read here: ``vocab_size``,
                ``d_model``, ``n_layer``, ``d_state``, ``expand``, ``headdim``,
                ``sdr_n_bits``, ``sdr_target_active``, ``sdr_delta_rank``,
                ``sdr_som_warmup``, ``sdr_som_interval``, ``htm_n_columns``,
                ``htm_cells_per_column``, ``engram_n_columns``,
                ``engram_layer_idx``; ``sequence_len`` only when Hyena layers
                are configured, ``n_heads`` only when GDN layers are; and the
                optional ``hyena_layers`` / ``gdn_layers`` index collections.

        Beyond each submodule's own constructor defaults (plus the Mamba3
        ``dt_bias`` shift below), weight initialization lives in
        :meth:`init_weights`.
        """
        super().__init__()
        self.config = config

        # Token embedding
        self.wte = nn.Embedding(config.vocab_size, config.d_model)

        # Mamba-3 blocks — official mamba-ssm fused CUDA kernel. No fallbacks.
        # RoPE is applied internally by the Mamba3 CUDA kernel via the Angles
        # parameter; external cos/sin buffers are not needed.
        #
        # Hyena supplement: layers whose index appears in `config.hyena_layers`
        # are instantiated as HyenaBlock instead of Mamba3. The config field
        # is populated from HYDRA_HYENA_LAYERS at construction time and then
        # persisted to checkpoints, so resume is safe even when the env var
        # is unset. Indices in `config.gdn_layers` likewise become GDNBlock.
        # Empty tuples → all-Mamba3, byte-identical to pre-port.
        _hyena_layer_set = set(getattr(config, "hyena_layers", ()) or ())
        _gdn_layer_set = set(getattr(config, "gdn_layers", ()) or ())
        # Hyena wins on overlap; conflict is logged at construction time.
        _both = _hyena_layer_set & _gdn_layer_set
        if _both:
            print(f"[WARN] layers in both hyena_layers and gdn_layers; using Hyena: {sorted(_both)}", flush=True)
        _gdn_layer_set -= _hyena_layer_set
        if _gdn_layer_set:
            from hydra.gdn_block import GDNBlock  # requires `fla` package

        def _build_block(i: int) -> nn.Module:
            # Layer-type dispatch: Hyena > GDN > Mamba3 (default).
            if i in _hyena_layer_set:
                return HyenaBlock(
                    d_model=config.d_model,
                    seq_len=config.sequence_len,
                    order=int(os.environ.get("HYDRA_HYENA_ORDER", "2")),
                    filter_order=int(os.environ.get("HYDRA_HYENA_FILTER_DIM", "64")),
                )
            if i in _gdn_layer_set:
                return GDNBlock(
                    d_model=config.d_model,
                    n_heads=config.n_heads,
                )
            block = Mamba3(
                d_model=config.d_model,
                d_state=config.d_state,
                expand=config.expand,
                headdim=config.headdim,
                is_mimo=False,  # SISO path uses stable mamba3_siso_combined kernel
                chunk_size=int(os.environ.get("HYDRA_MAMBA3_CHUNK", "64")),  # 64 is the validated default; 128 tripped a Triton autotune hang (>8min, no progress)
                is_outproj_norm=False,
                dtype=torch.bfloat16,
            )
            # Fix dt_bias gradient starvation: Mamba3 init samples dt uniformly
            # in log-space from [0.001, 0.1], giving dt_bias in [-6.9, -2.25].
            # softplus'(dt_bias) = sigmoid(dt_bias). At -6.9: 0.1% grad survives;
            # at -2.25: 9.5% survives. Shift up by 3 so sigmoid is 2-68% instead.
            # The SSM can still learn to make dt_bias more negative if finer
            # temporal resolution is needed — now the gradient will survive to
            # do so.
            with torch.no_grad():
                block.dt_bias.add_(3.0)
            return block

        self.blocks = nn.ModuleList([_build_block(i) for i in range(config.n_layer)])

        # Full-architecture SDR: offline semantic retina + STE (no-bypass).
        self.sdr_semantic = SemanticFoldingSDR(
            vocab_size=config.vocab_size,
            n_bits=config.sdr_n_bits,
            target_active=config.sdr_target_active,
            delta_rank=config.sdr_delta_rank,
            som_warmup_steps=config.sdr_som_warmup,
            som_update_interval=config.sdr_som_interval,
        )
        # HTM spatial pooler + temporal memory (Rust, Hebbian).
        self.htm = HTMLayer(
            input_bits=config.sdr_n_bits,
            n_columns=config.htm_n_columns,
            cells_per_column=config.htm_cells_per_column,
            batch_size=1,  # grows lazily to actual B on first forward
            seed=42,
            learn=True,
            reset_each_forward=True,
        )
        # Gradient bridge: (n_columns + anomaly) -> d_model. The +1 input is
        # the anomaly channel (the forward pass reads it as htm_out[..., -1]).
        self.htm_proj = nn.Linear(config.htm_n_columns + 1, config.d_model, bias=False)
        # GPU Engram with Hebbian writes — runs EVERY step.
        self.engram = GPUEngram(
            d_model=config.d_model,
            n_columns=config.engram_n_columns,
            max_ngram=3,
        )
        self.engram_layer_idx = config.engram_layer_idx
        # SDR-to-d_model projection: routes differentiable SDR signal (STE)
        # into the engram's residual stream. One bf16 matmul per forward step
        # at (B*T, n_bits) @ (n_bits, d_model) — ~2.6M MACs vs 5B total,
        # <0.05% overhead. Zero extra memory (no intermediates materialized
        # beyond the matmul output). This single projection backpropagates LM
        # loss gradients through sdr_semantic.delta_u/delta_v, giving the
        # curated semantic retina a real learning signal.
        self.sdr_proj = nn.Linear(config.sdr_n_bits, config.d_model, bias=False)
        # Manifold-Constrained Hyper-Connections: one per entry of self.blocks
        # (whatever the block type).
        self.mhc = nn.ModuleList([
            ManifoldHyperConnection(d_model=config.d_model, n_streams=2, sinkhorn_iters=3)
            for _ in range(config.n_layer)
        ])
        # Hestia QAT — ternary weight quantization applied post-optimizer-step.
        self.hestia = HestiaQAT(enabled=True, bits=1.58)
        # LM head
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Learnability knob 1: Multi-Token Prediction (Llama-3 style).
        # MTP_K=1 -> standard next-token. MTP_K>1 -> extra heads predict
        # tokens at positions t+1, t+2, ..., t+K. Heads are weight-tied to
        # lm_head (they index self.lm_head.weight directly), so the only extra
        # compute is additional CE losses; no new params. Activated via
        # HYDRA_MTP_K. NOTE: the forward pass only applies MTP on the
        # training + sampled-softmax + reduction='mean' path.
        self._mtp_k = max(1, int(os.environ.get("HYDRA_MTP_K", "1")))
        # Learnability knob 3: gradient checkpointing on every block in
        # self.blocks (training mode only).
        self._grad_ckpt = os.environ.get("HYDRA_GRAD_CKPT", "0") == "1"
        # Learnability knob 4: doc-separator BOS masking in packed sequences.
        self._doc_sep_mask = os.environ.get("HYDRA_DOC_SEP_MASK", "0") == "1"
        # BOS token id is set via set_bos_token_id() once the tokenizer is
        # loaded; -1 means uninitialized (masking is skipped while < 0).
        self._bos_token_id = -1
        # Learnability knob 5: explicit stop-grad on HTM tensor (htm_rust
        # outputs already have requires_grad=False; this is defense-in-depth).
        self._htm_stop_grad = os.environ.get("HYDRA_HTM_STOP_GRAD", "0") == "1"
        # Learnability knob 6: entropy penalty coefficient on LM logits.
        self._entropy_penalty = float(os.environ.get("HYDRA_ENTROPY_PENALTY", "0.0"))
        # Dropout on each block's output before it is merged into the mHC
        # residual streams.
        self.drop = nn.Dropout(float(os.environ.get("HYDRA_DROPOUT", "0.2")))
        # Logits soft-capping: logits -> softcap * tanh(logits / softcap).
        self.softcap = 15.0
        # Secondary metrics storage (read via get_secondary_metrics()).
        self._metrics = {}

        # Per-layer diagnostic panel. Env-gated; zero overhead when off.
        # Emits residual-contribution (delta_ratio), feature std, effective rank,
        # gradient RMS per layer; used to identify minimum viable n_layer + find
        # entropy leakage / dead layers. See docs/depth-sweep.md.
        self._diag_enabled = os.environ.get("HYDRA_LAYER_DIAGNOSTICS", "0") == "1"
        self._diag_step = 0
        self._diag_svd_every = int(os.environ.get("HYDRA_LAYER_DIAG_SVD_EVERY", "100"))

        # 3060 hot-path: cache env-driven flags at construction time so the
        # forward pass doesn't pay a dict-lookup + string-parse per step.
        # Each os.environ.get() call costs ~1-2us; with ~30 forwards/sec
        # this saves single-digit-percent overhead per training step.
        self._cache_profile_forward = os.environ.get("HYDRA_PROFILE_FORWARD", "0") == "1"
        # Run HTM on 1 of every N forward calls; reuse the cached result otherwise.
        self._cache_htm_subsample = int(os.environ.get("HYDRA_HTM_SUBSAMPLE", "8"))
        self._cache_htm_log_anomaly = os.environ.get("HYDRA_HTM_LOG_ANOMALY", "0") == "1"
        # Use a hard clamp instead of tanh soft-capping on the full-softmax path.
        # NOTE(review): on the sampled-softmax path, clamp mode currently
        # applies NO cap at all (only the tanh branch exists there).
        self._cache_softcap_clamp = os.environ.get("HYDRA_SOFTCAP_CLAMP", "0") == "1"
        # Number of sampled negatives for training CE; 0 disables sampling.
        self._cache_sampled_softmax_k = int(os.environ.get("HYDRA_SAMPLED_SOFTMAX", "4096"))
        # Sequence chunk for full-softmax CE; <= 0 selects an automatic size.
        self._cache_ce_chunk = int(os.environ.get("HYDRA_CE_CHUNK", "1024"))

        if self._diag_enabled:
            # Backward hooks: record the RMS (sqrt of mean square — despite the
            # "_grad_norm" key name) of the gradient w.r.t. each block's
            # output. Registered on every entry of self.blocks, whatever its
            # type. The factory binds the layer index per hook (avoids the
            # late-binding closure pitfall).
            for _i, _block in enumerate(self.blocks):
                def _mk_grad_hook(_layer_idx):
                    def _hook(module, grad_input, grad_output):
                        if grad_output and grad_output[0] is not None:
                            g = grad_output[0].detach()
                            self._metrics[f'layer_{_layer_idx}_grad_norm'] = float(
                                g.pow(2).mean().sqrt().item()
                            )
                    return _hook
                _block.register_full_backward_hook(_mk_grad_hook(_i))
            # Forward hooks on each block capture the block's OUTPUT
            # directly. This is the clean measurement: unlike merge_streams()
            # sampling which sees (streams + M*block_output) in bf16 — where
            # small block contributions round to zero against unit-norm
            # residuals — this captures `block_output` itself as produced.
            # Reports both its absolute RMS and its ratio to the block
            # INPUT's RMS (contribution magnitude relative to the residual it
            # is added to). NOTE: each .item() forces a host sync; acceptable
            # only because diagnostics are opt-in.
            for _i, _block in enumerate(self.blocks):
                def _mk_fwd_hook(_layer_idx):
                    def _hook(module, inputs, output):
                        with torch.no_grad():
                            inp = inputs[0].detach().float() if inputs else None
                            out = output.detach().float() if isinstance(output, torch.Tensor) else None
                            if out is not None:
                                out_rms = out.pow(2).mean().sqrt().item()
                                self._metrics[f'layer_{_layer_idx}_block_out_rms'] = float(out_rms)
                                if inp is not None:
                                    in_rms = inp.pow(2).mean().sqrt().item()
                                    self._metrics[f'layer_{_layer_idx}_block_in_rms'] = float(in_rms)
                                    # 1e-8 guards against a zero-RMS input.
                                    self._metrics[f'layer_{_layer_idx}_contrib_ratio'] = float(
                                        out_rms / (in_rms + 1e-8)
                                    )
                    return _hook
                _block.register_forward_hook(_mk_fwd_hook(_i))

        # Triton kernel integration gates (Phase 2 — deferred, see module docstring).
        # Read and reported to stderr only; no code path depends on them.
        self._fused_bcnorm = os.environ.get("HYDRA_FUSED_BCNORM", "0") == "1"
        self._fused_ssd = os.environ.get("HYDRA_FUSED_SSD", "0") == "1"
        if self._fused_bcnorm or self._fused_ssd:
            import sys
            _active = []
            if self._fused_bcnorm:
                _active.append("HYDRA_FUSED_BCNORM")
            if self._fused_ssd:
                _active.append("HYDRA_FUSED_SSD")
            print(
                f"[HYDRA] Triton kernel gates set: {', '.join(_active)}. "
                f"NOTE: Both are DEFERRED (mamba-ssm Mamba3 uses internal "
                f"CUDA kernels). Gates accepted but currently no-ops.",
                file=sys.stderr,
            )
        # R6 optional torch.compile on the impl forward. Gated (default OFF).
        # The compiled callable shadows the bound method on this instance only.
        if os.environ.get("HYDRA_MODEL_COMPILE", "0") == "1":
            self._forward_impl = torch.compile(
                self._forward_impl, fullgraph=False, dynamic=True, mode="default",
            )

    @torch.no_grad()
    def init_weights(self) -> None:
        """Re-initialize selected weights in place and cast some modules to bf16.

        Touches: the SDR retina index tensor (moved to ``wte``'s device),
        ``wte``, ``lm_head``, ``in_proj`` / ``out_proj`` of every block that
        has them (matched by attribute name, so any Hyena/GDN block exposing
        the same names is re-initialized too), ``htm_proj`` and ``sdr_proj``.
        Finally casts ``wte``, ``htm_proj``, ``sdr_proj`` and ``engram`` to
        bfloat16. All other parameters are left as they are.

        Env overrides: HYDRA_WTE_STD, HYDRA_LM_HEAD_STD, HYDRA_OUT_PROJ_STD.

        NOTE(review): if the model was materialized with ``to_empty()`` (as
        the retina comment below suggests), parameters not listed above are
        NOT re-initialized here — confirm they are loaded or initialized
        elsewhere.
        """
        # s = sqrt(3 / d_model): U(-s, s) has std 1/sqrt(d_model).
        s = 3 ** 0.5 * self.config.d_model ** -0.5
        # Move SDR retina indices (plain attribute, not buffer) to same device as params.
        # Required because to_empty() only moves params/buffers, and _retina_indices
        # is loaded from numpy (always CPU) by SemanticFoldingSDR.__init__.
        device = self.wte.weight.device
        if hasattr(self.sdr_semantic, '_retina_indices'):
            self.sdr_semantic._retina_indices = self.sdr_semantic._retina_indices.to(device)

        # Embedding init: GPT-2 / LLaMA convention. std=1.0 was chosen for
        # vocab=8192; at larger vocabs, smaller std prevents logit blowup.
        # Use std = 1/sqrt(d_model) which scales sensibly with model width.
        import math as _math
        _d_model = self.wte.weight.shape[1]
        wte_std = float(os.environ.get("HYDRA_WTE_STD", str(1.0 / _math.sqrt(_d_model))))
        nn.init.normal_(self.wte.weight, mean=0.0, std=wte_std)

        # LM head init: was std=0.001 — PATHOLOGICAL at vocab>=32k because
        # logits collapse to zero, loss locks at log(V)~=11, gradient through
        # head ∝ 1/V is too small to escape. GPT-2 uses std=0.02; LLaMA uses
        # std=1/sqrt(d_model). Pick 0.02 as robust default, env-overridable.
        lm_head_std = float(os.environ.get("HYDRA_LM_HEAD_STD", "0.02"))
        nn.init.normal_(self.lm_head.weight, mean=0.0, std=lm_head_std)

        # F8 (NOT APPLIED): Weight tying would save V*D params but current LR
        # groups have embedding_lr=1.0 and unembedding_lr=0.005 × d_model_scale
        # — tying forces the shared tensor under a single LR group and either
        # the embeddings learn 200x too slow (under unembed LR) or the LM head
        # becomes unstable (under embed LR). Short 15-step smoke with tying +
        # embed-group update showed initial loss jump 9 -> 20. Deferred until
        # LR groups are re-tuned; see docs/OPTIMIZATION_PLAN.md Post-plan.
        for li, block in enumerate(self.blocks):
            if hasattr(block, 'in_proj') and hasattr(block.in_proj, 'weight'):
                nn.init.uniform_(block.in_proj.weight, -s, s)
            if hasattr(block, 'out_proj') and hasattr(block.out_proj, 'weight'):
                # GPT-2 residual init: std = 0.02 / sqrt(2 * n_layer).
                # NOT zeros — zero init makes the block a permanent pass-through
                # (block_out_rms=0, zero gradient flow to SSM internals).
                # With non-zero init the block contributes to the residual stream
                # from step 1, giving the SSM scan actual gradient signal.
                n_layer = self.config.n_layer
                out_std = float(os.environ.get(
                    "HYDRA_OUT_PROJ_STD", str(0.02 / (2 * n_layer) ** 0.5),
                ))
                nn.init.normal_(block.out_proj.weight, mean=0.0, std=out_std)
        nn.init.normal_(self.htm_proj.weight, mean=0.0, std=s)
        # SDR proj: tiny init preserves the existing residual dynamics at step 0.
        # The signal ramps up gradually as delta_u/delta_v learn via STE gradients.
        nn.init.normal_(self.sdr_proj.weight, mean=0.0, std=1e-4)
        # Cast to bf16 to match Mamba3 dtype; Muon groups by shape so mixed
        # dtypes in the same shape group would break lerp_ dtype checks.
        self.wte.to(dtype=torch.bfloat16)
        self.htm_proj.to(dtype=torch.bfloat16)
        self.sdr_proj.to(dtype=torch.bfloat16)
        self.engram.to(dtype=torch.bfloat16)

    def set_bos_token_id(self, bos_id: int) -> None:
        """Record the tokenizer's BOS id for doc-separator masking.

        Learnability #4 uses this id to find the positions to skip. Call it
        from training setup once the tokenizer is loaded; until then the id
        stays -1 and masking is disabled.
        """
        self._bos_token_id = int(bos_id)

    def invalidate_hyena_caches(self) -> None:
        """Invalidate filter-rfft caches on all Hyena blocks.

        MUST be called after each `optimizer.step()` when
        `HYDRA_HYENA_FILTER_CACHE=1` is set; otherwise cached rfft values
        are reused with stale filter parameters. Blocks without an
        ``operator.invalidate_filter_cache`` method (Mamba3, etc.) are skipped.
        """
        for block in self.blocks:
            if hasattr(block, "operator") and hasattr(block.operator, "invalidate_filter_cache"):
                block.operator.invalidate_filter_cache()

    def flush_hyena_pending_grads(self) -> None:
        """Push pending train-cache filter gradients into filter params.

        Used ONLY when HYDRA_HYENA_TRAIN_CACHE=1. Must be called exactly once
        per optimizer step, BEFORE ``optimizer.step()`` and BEFORE
        ``invalidate_hyena_caches()``. The lightning_module wires this in
        ``optimizer_step`` around the existing ``optimizer.step()`` call.

        Blocks without an ``operator.flush_pending_filter_grads`` method are
        skipped. No-op if:
          * No HyenaBlocks are in the model, OR
          * No micro-batch ever ran with grad enabled (e.g. all-eval step).
""" for block in self.blocks: if hasattr(block, "operator") and hasattr(block.operator, "flush_pending_filter_grads"): block.operator.flush_pending_filter_grads() def estimate_flops(self) -> int: nparams = sum(p.numel() for p in self.parameters()) embed_params = self.wte.weight.numel() return 6 * (nparams - embed_params) def num_scaling_params(self) -> dict: wte = sum(p.numel() for p in self.wte.parameters()) lm_head = sum(p.numel() for p in self.lm_head.parameters()) blocks = sum(p.numel() for p in self.blocks.parameters()) sdr = sum(p.numel() for p in self.sdr_semantic.parameters()) htm_proj = sum(p.numel() for p in self.htm_proj.parameters()) engram = sum(p.numel() for p in self.engram.parameters()) total = sum(p.numel() for p in self.parameters()) return { 'wte': wte, 'lm_head': lm_head, 'blocks': blocks, 'sdr_semantic': sdr, 'htm_proj': htm_proj, 'engram': engram, 'total': total, } def get_secondary_metrics(self) -> dict: """Flush any lingering CUDA tensors to host (single sync).""" flushed = {} for k, v in self._metrics.items(): if hasattr(v, 'item'): try: flushed[k] = float(v.item()) except Exception: flushed[k] = v else: flushed[k] = v return flushed def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.6, matrix_lr=0.04, weight_decay=0.2, adam_betas=(0.8, 0.95), scalar_lr=0.5): """Setup MuonAdamW optimizer with per-component LR groups.""" model_dim = self.config.d_model embedding_params = list(self.wte.parameters()) lm_head_params = list(self.lm_head.parameters()) # Muon routing guard: 2D parameters are NOT automatically matrices. # Exclude: # (a) params whose name ends in `.freq` — Sin frequency vectors used # by Hyena's implicit filter MLP. Shape (1, dim) is nominally 2D # but semantically a per-dim scalar. Muon's polar-express # orthogonalization would force it toward an orthogonal matrix, # destroying the learned modulation frequencies. # (b) 2-D params with min(shape) < MUON_MIN_DIM. Tiny projections # (e.g. 
HyenaFilter.implicit_filter.0.weight of shape (64, 3)) # get collapsed toward near-identity by orthogonalization on the # narrow axis, damaging expressivity. These belong in AdamW. # These exclusions route the params into the AdamW scalar/vector group. MUON_MIN_DIM = 8 def _muon_eligible(name: str, p: torch.Tensor) -> bool: if p.dim() != 2: return False if name.endswith(".freq"): return False if min(p.shape) < MUON_MIN_DIM: return False return True # Matrix params -> Muon (2D weight matrices passing the routing guard). matrix_params = [] for name, p in self.blocks.named_parameters(): if _muon_eligible(name, p): matrix_params.append(p) # NOTE (W1 audit REG-2): SemanticFoldingSDR.delta_u / delta_v are now # GRADIENT-ALIVE via the STE path reconnected in _forward_impl: the # differentiable SDR output feeds into sdr_proj which feeds into the # engram residual. Gradient flows: LM_loss -> engram -> ... -> sdr_proj # -> sdr_semantic.forward() -> STE -> delta_u/delta_v. These are 1D # params (not 2D matrices) so they automatically land in scalar AdamW # via the exclusion below — no special treatment needed. # for p in self.sdr_semantic.parameters(): # if p.dim() == 2: # matrix_params.append(p) for name, p in self.htm_proj.named_parameters(): if _muon_eligible(name, p): matrix_params.append(p) for name, p in self.engram.named_parameters(): if _muon_eligible(name, p): matrix_params.append(p) # SDR params are intentionally not in any optimizer group — they # receive no gradient in the current forward, so any update would be # pure noise (weight_decay × lr on a zero-grad param). sdr_param_ids = set(id(p) for p in self.sdr_semantic.parameters()) assigned = set(id(p) for p in embedding_params + lm_head_params + matrix_params) # Extract dt_bias from each Mamba3 block into its own high-LR group. # dt_bias controls SSM temporal discretization: softplus(dt_bias) gives # the step size delta_t. A single dt_bias per head is 1D — falls into # scalar_params by default. 
Give it a dedicated group with HYDRA_DT_BIAS_LR # so heads can differentiate their timescales instead of all tracking at # ln(2) across all layers. dt_bias_params = [p for name, p in self.blocks.named_parameters() if name.endswith('dt_bias')] dt_bias_ids = set(id(p) for p in dt_bias_params) if dt_bias_params else set() scalar_params = [ p for p in self.parameters() if id(p) not in assigned and id(p) not in sdr_param_ids and id(p) not in dt_bias_ids ] total_assigned = len(embedding_params) + len(lm_head_params) + len(matrix_params) + len(scalar_params) + len(dt_bias_params) total_params = len(list(self.parameters())) sdr_excluded = len(list(self.sdr_semantic.parameters())) assert total_assigned + sdr_excluded == total_params, ( f"Parameter count mismatch: assigned {total_assigned} + sdr_excluded " f"{sdr_excluded} vs total {total_params}" ) dmodel_lr_scale = (model_dim / 768) ** -0.5 print(f"Scaling AdamW LRs by 1/sqrt({model_dim}/768) = {dmodel_lr_scale:.6f}") param_groups = [ dict(kind='adamw', params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0), dict(kind='adamw', params=embedding_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0), ] if scalar_params: param_groups.append( dict(kind='adamw', params=scalar_params, lr=scalar_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0) ) # dt_bias: dedicated group with embed-level LR so each head learns its # own temporal discretization. Env-overridable for sweeps. if dt_bias_params: _dt_bias_lr = float(os.environ.get("HYDRA_DT_BIAS_LR", str(embedding_lr))) param_groups.append(dict( kind='adamw', params=dt_bias_params, lr=_dt_bias_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0, )) for shape in sorted({p.shape for p in matrix_params}): group_params = [p for p in matrix_params if p.shape == shape] # ns_steps: Muon polar-express inner iterations. 
Default 5 (paper), # but 3 converges on small matrices (d_model ~ 384) with ~40% lower # optimizer step cost. Env-tunable for experimentation. _ns_steps = int(os.environ.get("HYDRA_MUON_NS_STEPS", "3")) param_groups.append(dict( kind='muon', params=group_params, lr=matrix_lr, momentum=0.95, ns_steps=_ns_steps, beta2=0.95, weight_decay=weight_decay, )) optimizer = MuonAdamW(param_groups) for group in optimizer.param_groups: group["initial_lr"] = group["lr"] return optimizer def forward(self, idx, targets=None, reduction='mean'): """idx: (B, T) int64. Returns loss if targets given, else logits. Nested bf16 autocast is a no-op when ambient autocast is already on; when it's off (e.g. integration tests) we establish the dtype contract. """ if torch.is_autocast_enabled(): return self._forward_impl(idx, targets=targets, reduction=reduction) with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): return self._forward_impl(idx, targets=targets, reduction=reduction) def _forward_impl(self, idx, targets=None, reduction='mean'): B, T = idx.shape # Diagnostic: per-subsystem CUDA event timing. Env-gated; zero overhead # when disabled. Logs one timing line per forward call. Used to isolate # which subsystem is the tps bottleneck on paid hardware. _profile = self._cache_profile_forward if _profile: def _ev(): e = torch.cuda.Event(enable_timing=True) e.record() return e _t0 = _ev() else: _t0 = None # Compute SDR ONCE via the differentiable STE forward. # The STE output is float (autocast dtype); we detach+cast a uint8 copy # for HTM (which only needs binary patterns). This replaces the old # binary_only() call — same scatter cost, but routes LM loss gradients # through sdr_semantic.delta_u/delta_v for the first time. 
sdr_ste = self.sdr_semantic(idx) # float, with STE grad path sdr_binary = sdr_ste.detach().to(dtype=torch.uint8) # detached uint8 for HTM self._last_sdr = sdr_binary # HTM subsampling: run HTM on 1 of every N micro-batches within a # gradient accumulation step, reuse the cached result for the other # N-1 micro-batches. Cooperative launch monopolizes all SMs (grid.sync # requires full-grid residency), so HTM and mamba can't overlap via # streams. Subsampling removes HTM from most micro-batches' critical # path instead. # # Math: N=8, 64 accum steps → 8 HTM calls (10.6ms each) + 56 fast # calls (4ms each). Total = 84.8 + 224 = 309ms → 106k tps. # # HYDRA_HTM_SUBSAMPLE=N (default 8). Set =1 for every-microbatch HTM. _htm_sub = self._cache_htm_subsample if not hasattr(self, '_htm_call_idx'): self._htm_call_idx = 0 _run_htm = (self._htm_call_idx % _htm_sub == 0) self._htm_call_idx += 1 if _run_htm: htm_handle = self.htm.forward_async(sdr_binary) else: htm_handle = None if _profile: _t_htm_async = _ev() dense_emb = self.wte(idx) # (B, T, d_model) bf16 if _profile: _t_wte = _ev() if _run_htm: htm_out = self.htm.forward_await(htm_handle) self._htm_cache = htm_out.detach() # cache for non-HTM micro-batches elif hasattr(self, '_htm_cache') and self._htm_cache is not None \ and self._htm_cache.shape[0] == B and self._htm_cache.shape[1] == T: htm_out = self._htm_cache else: # Very first call with subsample > 1: run HTM anyway. htm_handle = self.htm.forward_async(sdr_binary) htm_out = self.htm.forward_await(htm_handle) self._htm_cache = htm_out.detach() if _profile: _t_htm_await = _ev() # 3060 hot-path: skip the per-step htm_anomaly mean kernel + sync # unless explicitly requested via HYDRA_HTM_LOG_ANOMALY=1. Saves # one tiny GPU launch + one .item()-style copy per forward call. 
sdr_active_bits = float(self.sdr_semantic.target_active) if self._cache_htm_log_anomaly: with torch.no_grad(): htm_anomaly = htm_out[..., -1].mean() else: htm_anomaly = None # Learnability #5: explicit stop-grad on HTM output. htm_rust already # produces a detached tensor, but making it explicit here hardens the # contract against future refactors that might route HTM through a # grad-enabled op. if self._htm_stop_grad: htm_out = htm_out.detach() # Gradient bridge: HTM columns+anomaly -> d_model. # 3060 hot-path: only cast when dtypes actually differ (htm_out is # bf16 in normal training, dense_emb is bf16 under autocast — the # .to() in that case is a no-op but still costs Python overhead). if htm_out.dtype != dense_emb.dtype: htm_out = htm_out.to(dense_emb.dtype) htm_proj_out = self.htm_proj(htm_out) x = dense_emb + htm_proj_out x = norm(x) if _profile: _t_htm_proj = _ev() # mHC-routed Mamba-3 stack with Engram injection at configured layer. streams = self.mhc[0].init_streams(x) _engram_ev = None # Per-layer diagnostic panel. The pre-layer merged state h_pre lets us # measure residual contribution of each layer: delta_N = h_post - h_pre. # All reads are detached no-grad to avoid autograd graph pollution. _diag = self._diag_enabled if _diag: # Cast to float32 for the diagnostic arithmetic: the layer's # residual contribution is small (~0.5 × rms-normed block output), # which underflows in bf16 subtraction (3-digit mantissa) and # reports delta_ratio=0 at the boundaries. float32 snapshot is # ~3.8 MB extra memory per diag sample (B=1, T=2048, d=96) — # negligible vs peak VRAM. with torch.no_grad(): h_pre = self.mhc[0].merge_streams(streams).detach().float() _run_svd = (self._diag_step % self._diag_svd_every) == 0 # 3060 hot-path: hoist grad-ckpt branch + import out of the per-layer # loop. Module-level torch.utils.checkpoint is imported at the top of # the file (see _ckpt alias) — no per-step `import` cost. 
_use_ckpt = self._grad_ckpt and self.training for i, (block, mhc_layer) in enumerate(zip(self.blocks, self.mhc)): if _use_ckpt: def _block_fn(h, _block=block): return _ckpt.checkpoint( lambda _h: self.drop(_block(norm(_h))), h, use_reentrant=False ) else: def _block_fn(h, _block=block): return self.drop(_block(norm(h))) streams = mhc_layer(streams, _block_fn) if i == self.engram_layer_idx: if _profile: _t_pre_engram = _ev() x_mid = mhc_layer.merge_streams(streams) # Inject differentiable SDR signal into the engram query via # a lightweight projection (one bf16 matmul per forward step). # This is the first time LM loss gradients reach the SDR retina # via delta_u/delta_v — the curated semantic folding patterns # can now adapt during training instead of staying frozen. sdr_feat = self.sdr_proj(sdr_ste.to(x_mid.dtype)) # Norm the projection to prevent magnitude blowup: the raw STE # output has 327/16384 1.0 activations per token, and a single # matmul through sdr_proj (16384→256) with no normalization # grows weight norm from 1e-4 to ~182 within 2K steps, # overwhelming the engram residual. sdr_feat = norm(sdr_feat) x_mid = x_mid + sdr_feat x_mid, hit_rate = self.engram(x_mid, idx) streams = mhc_layer.init_streams(x_mid) self._metrics['engram_hit_rate'] = hit_rate if _profile: _engram_ev = _ev() if _diag: with torch.no_grad(): h_post = mhc_layer.merge_streams(streams).detach().float() in_n = h_pre.pow(2).mean().sqrt() out_n = h_post.pow(2).mean().sqrt() d_n = (h_post - h_pre).pow(2).mean().sqrt() self._metrics[f'layer_{i}_in_norm'] = float(in_n.item()) self._metrics[f'layer_{i}_out_norm'] = float(out_n.item()) self._metrics[f'layer_{i}_delta_ratio'] = float((d_n / (in_n + 1e-6)).item()) self._metrics[f'layer_{i}_feat_std'] = float(h_post.std(dim=-1).mean().item()) if _run_svd: # Effective rank via participation ratio of singular values. # eff_rank = (Σσ)^2 / Σσ² — smooth rank proxy, bounded by d_model. # Sampled to keep overhead low (SVD is O(min(B*T, D)^2·D)). 
flat = h_post.reshape(-1, h_post.shape[-1])[:512].float() try: s = torch.linalg.svdvals(flat) eff_rank = float(((s.sum() ** 2) / (s.pow(2).sum() + 1e-6)).item()) self._metrics[f'layer_{i}_eff_rank'] = eff_rank except Exception: pass h_pre = h_post if _diag: self._diag_step += 1 if _profile: _t_blocks = _ev() self._metrics['sdr_active_bits'] = sdr_active_bits if htm_anomaly is not None: self._metrics['htm_anomaly'] = htm_anomaly x = self.mhc[-1].merge_streams(streams) x = norm(x) if _profile: _t_merge = _ev() softcap = self.softcap _softcap_clamp = self._cache_softcap_clamp if targets is not None: smoothing = self.config.label_smoothing V = self.config.vocab_size # Learnability #4: doc-separator masking. In packed rows, # tokenizer.encode(..., prepend=bos_token) places a BOS at every # document boundary. Without masking, the model is penalized for # failing to predict "doc B's BOS" from the last tokens of doc A # — pure noise. We set targets==bos to -1 (ignore_index). Done # BEFORE MTP/entropy/sampled-softmax branches so all downstream # losses inherit the mask. if self._doc_sep_mask and self._bos_token_id >= 0: targets = torch.where( targets == self._bos_token_id, torch.full_like(targets, -1), targets, ) # Sampled softmax: instead of computing logits for ALL V tokens, # compute only for the target + K random negatives. Reduces the # lm_head matmul from (B*T, d) × (d, V) to (B*T, d) × (d, K+1). # At V=65536 and K=4096: 16× less compute, ~4× tps improvement. # The log-sum-exp correction adjusts for the sampling bias. # Set HYDRA_SAMPLED_SOFTMAX=0 to disable (full softmax). K_neg = self._cache_sampled_softmax_k use_sampled = K_neg > 0 and K_neg < V and self.training if use_sampled: # Flatten hidden states + targets h_flat = x.reshape(-1, x.shape[-1]) # (B*T, d) t_flat = targets.reshape(-1) # (B*T,) n = h_flat.shape[0] # Learnability #4 hardening: sampled-softmax gather crashes on # negative ids (-1 from doc-sep mask). 
Replace -1 with 0 for # gather; the actual loss is masked below. valid_mask_flat = (t_flat >= 0) t_flat_safe = torch.where(valid_mask_flat, t_flat, torch.zeros_like(t_flat)) # Sample K negatives uniformly from [0, V) neg_ids = torch.randint(0, V, (K_neg,), device=x.device) # Gather lm_head weights for target + negatives all_ids = torch.cat([t_flat_safe, neg_ids]) # (B*T + K,) sampled_w = self.lm_head.weight[all_ids] # (B*T + K, d) # Compute sampled logits: for each position, dot with its # target weight and all K negative weights. # Target logit: dot product of h[i] with w[target[i]] target_w = sampled_w[:n] # (B*T, d) neg_w = sampled_w[n:] # (K, d) target_logit = (h_flat * target_w).sum(-1) # (B*T,) neg_logits = h_flat @ neg_w.t() # (B*T, K) if not _softcap_clamp: target_logit = softcap * torch.tanh(target_logit / softcap) neg_logits = softcap * torch.tanh(neg_logits / softcap) # Sampled softmax loss: -log(exp(target) / (exp(target) + sum(exp(neg)))) # With log-sum-exp correction for sampling K of V negatives. # Correction: add log(V/K) to negative logits to account for # the fact that we're only seeing K of V possible negatives. log_correction = torch.tensor(V / K_neg, device=x.device).log() all_logits = torch.cat([ target_logit.unsqueeze(-1), # (B*T, 1) neg_logits + log_correction, # (B*T, K) ], dim=-1).float() # (B*T, K+1) # CE with target always at index 0 ce_targets = torch.zeros(n, dtype=torch.long, device=x.device) if reduction == 'none': per_tok = F.cross_entropy(all_logits, ce_targets, reduction='none') if self._doc_sep_mask and self._bos_token_id >= 0: per_tok = torch.where(valid_mask_flat, per_tok, torch.zeros_like(per_tok)) return per_tok per_tok_ce = F.cross_entropy( all_logits, ce_targets, reduction='none', label_smoothing=smoothing, ) # Mask doc-separator positions. valid_mask_flat is always # computed; when doc_sep_mask is off every token is valid so # this reduces to a plain mean. 
valid_f = valid_mask_flat.float() valid_n = valid_f.sum().clamp(min=1) out = (per_tok_ce * valid_f).sum() / valid_n else: # Full softmax path (eval or HYDRA_SAMPLED_SOFTMAX=0) chunk_size = self._cache_ce_chunk if chunk_size <= 0: MAX_LOGITS_BYTES = 256 * 1024 * 1024 tokens_per_chunk = max(V, MAX_LOGITS_BYTES // (V * 4)) chunk_size = max(1, tokens_per_chunk // max(1, B)) chunk_size = min(chunk_size, T) if reduction == 'none': loss_parts = [] for start in range(0, T, chunk_size): end = min(start + chunk_size, T) chunk_logits = self.lm_head(x[:, start:end, :]).float() if _softcap_clamp: chunk_logits = torch.clamp(chunk_logits, -softcap, softcap) else: chunk_logits = softcap * torch.tanh(chunk_logits / softcap) chunk_targets = targets[:, start:end].reshape(-1) chunk_loss = F.cross_entropy( chunk_logits.view(-1, chunk_logits.size(-1)), chunk_targets, ignore_index=-1, reduction='none', ) loss_parts.append(chunk_loss) return torch.cat(loss_parts) total_loss = 0.0 total_tokens = 0 for start in range(0, T, chunk_size): end = min(start + chunk_size, T) chunk_logits = self.lm_head(x[:, start:end, :]).float() if _softcap_clamp: chunk_logits = torch.clamp(chunk_logits, -softcap, softcap) else: chunk_logits = softcap * torch.tanh(chunk_logits / softcap) chunk_targets = targets[:, start:end].reshape(-1) chunk_loss = F.cross_entropy( chunk_logits.view(-1, chunk_logits.size(-1)), chunk_targets, ignore_index=-1, reduction='sum', label_smoothing=smoothing, ) total_loss = total_loss + chunk_loss total_tokens += (chunk_targets != -1).sum() out = total_loss / total_tokens # ----------------------------------------------------------- # Learnability #1: Multi-Token Prediction. # For k in {2..K}, add a CE loss at position (t) predicting # the token at position (t+k), using the SAME lm_head weights # (weight-tied). Cost: K-1 extra CEs on a subset of positions. # Only triggered in reduction='mean' path, training only. 
# ----------------------------------------------------------- if reduction == 'mean' and self._mtp_k > 1 and self.training and use_sampled: # TRUE zero-cost MTP: reuse primary's neg_logits (B*T, K_neg) # entirely. Only cost per extra head: O(B*T*d) target-weight # gather + dot product. neg_logits is sliced (view) to match. mtp_loss_sum = out.new_tensor(0.0) mtp_terms = 0 # Reshape primary neg_logits back to (B, T, K_neg) so we can slice positions neg_logits_bt = neg_logits.view(B, T, K_neg) for k in range(2, self._mtp_k + 1): shift = k - 1 if T <= shift: continue n_k = B * (T - shift) h_k_flat = x[:, :T - shift, :].reshape(n_k, -1) # (n_k, d) t_k = targets[:, shift:].reshape(-1) # (n_k,) mask_k = (t_k >= 0) t_k_safe = torch.where(mask_k, t_k, torch.zeros_like(t_k)) tgt_w_k = self.lm_head.weight[t_k_safe] # (n_k, d) tgt_logit_k = (h_k_flat * tgt_w_k).sum(-1) # (n_k,) if not _softcap_clamp: tgt_logit_k = softcap * torch.tanh(tgt_logit_k / softcap) # REUSE primary neg_logits — slice positions [:T-shift] neg_logits_k = neg_logits_bt[:, :T - shift, :].reshape(n_k, K_neg) all_logits_k = torch.cat([ tgt_logit_k.unsqueeze(-1), neg_logits_k + log_correction, ], dim=-1).float() ce_targets_k = torch.zeros(n_k, dtype=torch.long, device=x.device) per_tok_ce_k = F.cross_entropy( all_logits_k, ce_targets_k, reduction='none', label_smoothing=smoothing, ) per_tok_ce_k = torch.where(mask_k, per_tok_ce_k, torch.zeros_like(per_tok_ce_k)) n_valid_k = mask_k.sum().clamp(min=1) mtp_loss_sum = mtp_loss_sum + per_tok_ce_k.sum() / n_valid_k mtp_terms += 1 if mtp_terms > 0: out = (out + mtp_loss_sum) / float(mtp_terms + 1) # ----------------------------------------------------------- # Learnability #6: output entropy penalty. # L += -lambda * H(softmax(logits)). Negative entropy penalizes # peaked distributions; encourages diverse predictions and # breaks repetition loops. Computed on a small subset of # positions to keep V-sized logits cost bounded. 
# ----------------------------------------------------------- if reduction == 'mean' and self._entropy_penalty > 0.0 and self.training: # Sample up to 64 random positions. V-sized logits on 64 # positions = 64 * V * 4 bytes (~50 MB at V=200k) — fits # on the 3060 and adds ~2 ms. h_flat = x.reshape(-1, x.shape[-1]) n_pos = h_flat.shape[0] n_sample = min(64, n_pos) idx_sample = torch.randint(0, n_pos, (n_sample,), device=x.device) h_sample = h_flat[idx_sample] logits_s = F.linear(h_sample, self.lm_head.weight).float() if _softcap_clamp: logits_s = torch.clamp(logits_s, -softcap, softcap) else: logits_s = softcap * torch.tanh(logits_s / softcap) log_probs = F.log_softmax(logits_s, dim=-1) probs = log_probs.exp() entropy = -(probs * log_probs).sum(-1).mean() # scalar, nats out = out - self._entropy_penalty * entropy if _profile: _t_end = _ev() torch.cuda.synchronize() def _ms(a, b): return a.elapsed_time(b) print( f"[PROFILE B={B} T={T}] " f"htm_launch={_ms(_t0, _t_htm_async):.2f} " f"wte={_ms(_t_htm_async, _t_wte):.2f} " f"htm_await={_ms(_t_wte, _t_htm_await):.2f} " f"htm_proj={_ms(_t_htm_await, _t_htm_proj):.2f} " f"mamba_mhc_engram={_ms(_t_htm_proj, _t_blocks):.2f} " f"merge={_ms(_t_blocks, _t_merge):.2f} " f"lm_head_loss={_ms(_t_merge, _t_end):.2f} " f"total={_ms(_t0, _t_end):.2f} ms", flush=True, ) return out logits = self.lm_head(x).float() if _softcap_clamp: logits = torch.clamp(logits, -softcap, softcap) else: logits = softcap * torch.tanh(logits / softcap) return logits