Spaces:
Runtime error
Runtime error
| """PostSemClawModel — full-architecture model assembly. | |
| Extracted from the monolithic train.py (W1 modularization). Semantics | |
| unchanged. Imports `GPUEngram` from `hydra.engram` and `MuonAdamW` from | |
| `hydra.optimizer`. | |
| Triton kernel integration status (Phase 2): | |
| HYDRA_FUSED_BCNORM — DEFERRED. The bcnorm_fused Triton kernel fuses | |
| LayerNorm + RoPE on B/C projections. However, mamba-ssm's Mamba3 block | |
| uses RMSNormGated (not LayerNorm) for B/C, and RoPE is applied inside | |
| the mamba3_siso_combined CUDA kernel via the Angles parameter. Replacing | |
| would require either (a) monkey-patching RMSNormGated + intercepting the | |
| fused CUDA scan — invasive, 50+ lines, high breakage risk — or (b) a | |
| full custom Mamba3Block reimplementation. Both are out of scope for | |
| Phase 2. The kernel is validated standalone; integration deferred to | |
| Phase 3 when HYDRA moves to a custom SSM block. | |
| HYDRA_FUSED_SSD — DEFERRED. The ssd_exp_trap Triton kernel implements | |
| exponential-trapezoidal discretization as a sequential scan. mamba-ssm's | |
| Mamba3 block delegates the entire scan + gating + output projection to | |
| mamba3_siso_combined (a compiled CUDA kernel with tilelang). Replacing | |
| it would require decomposing the combined kernel into constituent ops | |
| and substituting only the scan — not feasible without a custom block. | |
| Same Phase 3 gate as above. | |
| Both env vars are accepted but currently no-ops (gates read, logged, but | |
| the code path is unchanged). This avoids silent regression if someone | |
| sets them expecting a speedup. | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| try: | |
| from mamba_ssm import Mamba3 | |
| except ModuleNotFoundError: # local CPU tests may run outside the HF image wheel stack | |
| Mamba3 = None | |
| from subsystems.hestia_mini import HestiaQAT | |
| from subsystems.htm import HTMLayer | |
| from subsystems.mhc_mini import ManifoldHyperConnection | |
| from subsystems.sdr_semantic import SemanticFoldingSDR | |
| from subsystems.fused_sdr_project import FusedSDRProject | |
| from subsystems.cantor_router import CantorRouter | |
| from hydra.engram import GPUEngram | |
| from hydra.reality_bridge import RealityPoincareBridge | |
| from hydra.hyena_block import HyenaBlock | |
| # GDNBlock is imported lazily inside __init__ so the `fla` dependency is | |
| # only required when HYDRA_GDN_LAYERS is actually non-empty. Baseline | |
| # pure-Mamba3 runs continue to work without flash-linear-attention installed. | |
| from hydra.optimizer import MuonAdamW | |
| FLOAT32_BYTES = torch.finfo(torch.float32).bits // 8 | |
| def norm(x: torch.Tensor) -> torch.Tensor: | |
| """RMSNorm over the last dim — stateless, autocast-friendly.""" | |
| return F.rms_norm(x, (x.size(-1),)) | |
def semantic_gaussian_mollify(
    x: torch.Tensor,
    std: float,
    training: bool,
    eval_enabled: bool = False,
) -> torch.Tensor:
    """Add small Gaussian noise to SDR/Engram query tensors.

    The function is the identity (it returns ``x`` itself, not a copy) when
    ``std`` is not positive, or when the model is not training and
    ``eval_enabled`` is false. Otherwise it returns ``x`` plus standard-normal
    noise scaled by ``std``. Shapes are never changed, so checkpoints are
    unaffected.

    Args:
        x: Tensor to smooth.
        std: Noise standard deviation; values ``<= 0`` disable smoothing.
        training: Whether the caller is in training mode.
        eval_enabled: Allow smoothing outside training as well.

    Returns:
        ``x`` unchanged, or a new tensor ``x + noise * std``.
    """
    if std <= 0.0:
        return x
    if not (training or eval_enabled):
        return x
    noise = torch.randn_like(x)
    return x + noise * std
| def paired_slow_fast_orthogonality(w: torch.Tensor) -> torch.Tensor: | |
| """Cheap W_slow ⊕ W_fast row-pair orthogonality proxy.""" | |
| if w.dim() != 2 or w.shape[0] < 2: | |
| return w.new_zeros(()) | |
| slow = w[0::2].float() | |
| fast = w[1::2].float() | |
| n = min(slow.shape[0], fast.shape[0]) | |
| if n == 0: | |
| return w.new_zeros(()) | |
| slow = F.normalize(slow[:n], dim=-1, eps=1e-6) | |
| fast = F.normalize(fast[:n], dim=-1, eps=1e-6) | |
| return (slow * fast).sum(dim=-1).square().mean() | |
| class PostSemClawModel(nn.Module): | |
| """Full Post-SEM-Claw model assembly. | |
| Architecture: | |
| Token Embedding -> [Mamba3 + residual] x n_layer | |
| -> SDR + Engram (at configured layer) -> norm -> LM head | |
| Interface (must match prepare.py evaluate_bpb): | |
| model(x, y, reduction='none').view(-1) -> per-token losses | |
| model(x, y, reduction='mean') -> scalar loss | |
| """ | |
def __init__(self, config):
    """Construct all subsystems from ``config`` and the HYDRA_* environment.

    Submodules are registered in a fixed order: token embedding, sequence
    blocks, SDR, HTM, HTM projection, Engram, Cantor router, SDR projection,
    Reality bridge, hyper-connections, Hestia QAT, LM head, dropout. The
    remaining attributes are plain Python knobs and metric accumulators.

    Args:
        config: Model configuration object. Fields read here include
            ``vocab_size``, ``d_model``, ``n_layer``, ``sequence_len``,
            ``n_heads``, ``d_state``, ``expand``, ``headdim``, the ``sdr_*``,
            ``htm_*`` and ``engram_*`` fields, and the optional
            ``hyena_layers`` / ``gdn_layers`` index collections.

    Raises:
        RuntimeError: If a layer needs Mamba3 but ``mamba_ssm`` could not be
            imported.
    """
    super().__init__()
    self.config = config
    env = os.environ.get  # every HYDRA_* knob below is read through this

    # Token embedding.
    self.wte = nn.Embedding(config.vocab_size, config.d_model)

    # Sequence blocks default to Mamba-3 (official mamba-ssm fused CUDA
    # kernel; RoPE is applied inside that kernel via the Angles parameter,
    # so no external cos/sin buffers are kept here).
    #
    # Layers listed in `config.hyena_layers` become HyenaBlock and layers in
    # `config.gdn_layers` become GDNBlock. Per the original notes, the
    # hyena field is populated from HYDRA_HYENA_LAYERS at construction time
    # and persisted to checkpoints, and an empty tuple means all-Mamba3.
    hyena_ids = set(getattr(config, "hyena_layers", ()) or ())
    gdn_ids = set(getattr(config, "gdn_layers", ()) or ())
    # Hyena wins when a layer index is listed in both; report the conflict.
    overlap = hyena_ids & gdn_ids
    if overlap:
        print(f"[WARN] layers in both hyena_layers and gdn_layers; using Hyena: {sorted(overlap)}", flush=True)
        gdn_ids -= hyena_ids
    if gdn_ids:
        # Imported lazily so the `fla` dependency is only needed when GDN
        # layers are actually requested.
        from hydra.gdn_block import GDNBlock  # requires `fla` package

    def _build_block(i: int) -> nn.Module:
        """Create the sequence-mixing block for layer index ``i``."""
        if i in hyena_ids:
            return HyenaBlock(
                d_model=config.d_model,
                seq_len=config.sequence_len,
                order=int(env("HYDRA_HYENA_ORDER", "2")),
                filter_order=int(env("HYDRA_HYENA_FILTER_DIM", "64")),
            )
        if i in gdn_ids:
            return GDNBlock(
                d_model=config.d_model,
                n_heads=config.n_heads,
            )
        if Mamba3 is None:
            raise RuntimeError(
                "mamba_ssm is required for Mamba3 layers; set hyena_layers/gdn_layers "
                "to cover every layer or run inside the HF runtime image."
            )
        mamba = Mamba3(
            d_model=config.d_model,
            d_state=config.d_state,
            expand=config.expand,
            headdim=config.headdim,
            is_mimo=False,  # SISO path uses the mamba3_siso_combined kernel
            # 64 is the validated default; the original notes record that
            # 128 tripped a Triton autotune hang.
            chunk_size=int(env("HYDRA_MAMBA3_CHUNK", "64")),
            is_outproj_norm=False,
            dtype=torch.bfloat16,
        )
        # dt_bias gradient-starvation fix (from the original notes): Mamba3
        # initializes dt_bias very negative, where softplus'(dt_bias) =
        # sigmoid(dt_bias) passes almost no gradient. Shifting by +3 keeps
        # the gradient alive; training can still push it back down.
        with torch.no_grad():
            mamba.dt_bias.add_(3.0)
        return mamba

    self.blocks = nn.ModuleList([_build_block(i) for i in range(config.n_layer)])

    # Semantic-folding SDR encoder.
    self.sdr_semantic = SemanticFoldingSDR(
        vocab_size=config.vocab_size,
        n_bits=config.sdr_n_bits,
        target_active=config.sdr_target_active,
        delta_rank=config.sdr_delta_rank,
        som_warmup_steps=config.sdr_som_warmup,
        som_update_interval=config.sdr_som_interval,
    )

    # HTM spatial pooler + temporal memory.
    self.htm = HTMLayer(
        input_bits=config.sdr_n_bits,
        n_columns=config.htm_n_columns,
        cells_per_column=config.htm_cells_per_column,
        batch_size=1,  # original note: grows lazily to the real batch size
        seed=42,
        learn=True,
        reset_each_forward=True,
    )

    # Gradient bridge: (n_columns + anomaly) -> d_model.
    self.htm_proj = nn.Linear(config.htm_n_columns + 1, config.d_model, bias=False)

    # GPU Engram memory.
    self.engram = GPUEngram(
        d_model=config.d_model,
        n_columns=config.engram_n_columns,
        max_ngram=3,
    )
    self.engram_layer_idx = config.engram_layer_idx

    # Cantor router: per the original notes, a gradient-free routing engine
    # that partitions query space into 2^depth leaves, each restricting
    # which Engram columns are eligible for retrieval.
    cantor_depth = int(env("HYDRA_CANTOR_DEPTH", "7"))
    self.cantor = CantorRouter(
        depth=cantor_depth,
        d_query=config.d_model,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    self._cantor_enabled = env("HYDRA_CANTOR_DISABLE", "0") != "1"

    # SDR -> d_model projection. It carries the differentiable (STE) SDR
    # signal into the residual stream so LM-loss gradients can reach
    # sdr_semantic's delta parameters.
    self.sdr_proj = nn.Linear(config.sdr_n_bits, config.d_model, bias=False)

    # Reality/Poincare bridge, on unless HYDRA_REALITY_BRIDGE=0.
    self._reality_bridge_enabled = env("HYDRA_REALITY_BRIDGE", "1") != "0"
    self.reality_bridge = (
        RealityPoincareBridge(
            d_model=config.d_model,
            d_reality=int(env("HYDRA_REALITY_DIM", "133")),
            l0_k=int(env("HYDRA_REALITY_L0_K", "64")),
        )
        if self._reality_bridge_enabled
        else None
    )

    # One Manifold-Constrained Hyper-Connection per sequence block.
    self.mhc = nn.ModuleList([
        ManifoldHyperConnection(d_model=config.d_model, n_streams=2, sinkhorn_iters=3)
        for _ in range(config.n_layer)
    ])

    # Hestia QAT (ternary weight quantization).
    self.hestia = HestiaQAT(enabled=True, bits=1.58)

    # LM head.
    self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

    # Learnability knobs, all env-driven.
    # 1: multi-token prediction depth (1 = plain next-token).
    self._mtp_k = max(1, int(env("HYDRA_MTP_K", "1")))
    # 3: gradient checkpointing on the sequence blocks.
    self._grad_ckpt = env("HYDRA_GRAD_CKPT", "0") == "1"
    # 4: doc-separator BOS masking in packed sequences.
    self._doc_sep_mask = env("HYDRA_DOC_SEP_MASK", "0") == "1"
    # BOS id is supplied later via set_bos_token_id; -1 means "not set".
    self._bos_token_id = -1
    # 5: explicit stop-grad on the HTM tensor.
    self._htm_stop_grad = env("HYDRA_HTM_STOP_GRAD", "0") == "1"
    # 6: entropy penalty coefficient on LM logits.
    self._entropy_penalty = float(env("HYDRA_ENTROPY_PENALTY", "0.0"))

    # Gaussian query mollification and slow/fast orthogonality
    # probe/regularizer; both default on, set the env values to 0 to ablate.
    self._semantic_smooth_std = float(env("HYDRA_SEMANTIC_SMOOTH_STD", "0.01"))
    self._semantic_smooth_eval = env("HYDRA_SEMANTIC_SMOOTH_EVAL", "0") == "1"
    self._sf_ortho_lambda = float(env("HYDRA_SLOW_FAST_ORTHO_LAMBDA", "1e-4"))
    self._sf_ortho_metrics = env("HYDRA_SLOW_FAST_ORTHO_METRICS", "1") != "0"
    self._sf_ortho_every = max(1, int(env("HYDRA_SLOW_FAST_ORTHO_EVERY", "100")))
    self._sf_ortho_step = 0
    self._sf_ortho_targets = tuple(
        s.strip() for s in env(
            "HYDRA_SLOW_FAST_ORTHO_TARGETS",
            "cantor,engram,block_out_proj",
        ).split(",") if s.strip()
    )

    # Residual dropout.
    self.drop = nn.Dropout(float(env("HYDRA_DROPOUT", "0.2")))

    # Logits soft-capping.
    self.softcap = 15.0

    # Secondary metrics storage.
    self._metrics = {}
    # Training-window maximum of Cantor leaf utilization, kept separately
    # from the last-forward value so small post-training probe forwards do
    # not make routing look collapsed.
    self._cantor_active_leaves_train_max = 0
    # Same idea for the Engram hit rate: keep training-window max/mean.
    self._engram_hit_rate_train_max = 0.0
    self._engram_hit_rate_train_sum = 0.0
    self._engram_hit_rate_train_count = 0

    # Per-layer diagnostics (env-gated; no hooks are registered when off).
    self._diag_enabled = env("HYDRA_LAYER_DIAGNOSTICS", "0") == "1"
    self._diag_step = 0
    self._diag_svd_every = int(env("HYDRA_LAYER_DIAG_SVD_EVERY", "100"))
    if self._diag_enabled:

        def _mk_grad_hook(layer_idx):
            """Backward hook recording the RMS of the block's output gradient."""
            def _hook(module, grad_input, grad_output):
                if grad_output and grad_output[0] is not None:
                    g = grad_output[0].detach()
                    self._metrics[f'layer_{layer_idx}_grad_norm'] = float(
                        g.pow(2).mean().sqrt().item()
                    )
            return _hook

        def _mk_fwd_hook(layer_idx):
            """Forward hook recording block input/output RMS and their ratio."""
            def _hook(module, inputs, output):
                with torch.no_grad():
                    inp = inputs[0].detach().float() if inputs else None
                    out = output.detach().float() if isinstance(output, torch.Tensor) else None
                    if out is None:
                        return
                    out_rms = out.pow(2).mean().sqrt().item()
                    self._metrics[f'layer_{layer_idx}_block_out_rms'] = float(out_rms)
                    if inp is None:
                        return
                    in_rms = inp.pow(2).mean().sqrt().item()
                    self._metrics[f'layer_{layer_idx}_block_in_rms'] = float(in_rms)
                    self._metrics[f'layer_{layer_idx}_contrib_ratio'] = float(
                        out_rms / (in_rms + 1e-8)
                    )
            return _hook

        # Gradient-norm hooks first, then forward hooks that measure each
        # block's own output (not the merged residual stream).
        for layer_idx, blk in enumerate(self.blocks):
            blk.register_full_backward_hook(_mk_grad_hook(layer_idx))
        for layer_idx, blk in enumerate(self.blocks):
            blk.register_forward_hook(_mk_fwd_hook(layer_idx))

    # Triton kernel gates: read and reported, but they change no code path
    # here (integration is deferred, see the module docstring).
    self._fused_bcnorm = env("HYDRA_FUSED_BCNORM", "0") == "1"
    self._fused_ssd = env("HYDRA_FUSED_SSD", "0") == "1"
    if self._fused_bcnorm or self._fused_ssd:
        import sys
        active_gates = [
            gate_name
            for gate_name, is_set in (
                ("HYDRA_FUSED_BCNORM", self._fused_bcnorm),
                ("HYDRA_FUSED_SSD", self._fused_ssd),
            )
            if is_set
        ]
        print(
            f"[HYDRA] Triton kernel gates set: {', '.join(active_gates)}. "
            f"NOTE: Both are DEFERRED (mamba-ssm Mamba3 uses internal "
            f"CUDA kernels). Gates accepted but currently no-ops.",
            file=sys.stderr,
        )

    # Optional torch.compile of the forward implementation (default off).
    if env("HYDRA_MODEL_COMPILE", "0") == "1":
        self._forward_impl = torch.compile(
            self._forward_impl,
            fullgraph=False,
            dynamic=True,
            mode="default",
        )
def init_weights(self) -> None:
    """Initialize all trainable tensors and cast the bf16 submodules.

    Steps, in order:
      1. Move the SDR retina indices (a plain attribute, not a buffer) to
         the embedding's device.
      2. Initialize the embedding, LM head, block projections, HTM/SDR
         projections, Cantor branch vectors, Engram and Reality bridge.
      3. Cast embedding, blocks, projections, Engram and Reality bridge to
         bfloat16.

    Environment overrides: ``HYDRA_WTE_STD``, ``HYDRA_LM_HEAD_STD``,
    ``HYDRA_OUT_PROJ_STD``.
    """
    import math

    s = 3 ** 0.5 * self.config.d_model ** -0.5

    # The retina indices are a plain attribute, so they must be moved by
    # hand to wherever the parameters live.
    device = self.wte.weight.device
    if hasattr(self.sdr_semantic, '_retina_indices'):
        self.sdr_semantic._retina_indices = self.sdr_semantic._retina_indices.to(device)

    # Embedding: std = 1/sqrt(d_model) by default so it scales with width.
    d_model = self.wte.weight.shape[1]
    wte_std = float(os.environ.get("HYDRA_WTE_STD", str(1.0 / math.sqrt(d_model))))
    nn.init.normal_(self.wte.weight, mean=0.0, std=wte_std)

    # LM head: 0.02 default. The original notes record that the previous
    # std=0.001 collapsed logits at large vocabularies.
    lm_head_std = float(os.environ.get("HYDRA_LM_HEAD_STD", "0.02"))
    nn.init.normal_(self.lm_head.weight, mean=0.0, std=lm_head_std)

    # Weight tying between wte and lm_head is intentionally NOT applied:
    # the two tensors sit in different LR groups (see the original F8 note
    # and docs/OPTIMIZATION_PLAN.md).

    # Residual-scaled out_proj init: std = 0.02 / sqrt(2 * n_layer). It is
    # deliberately non-zero so every block contributes to the residual
    # stream from the first step. Loop-invariant, so computed once.
    out_std = float(os.environ.get(
        "HYDRA_OUT_PROJ_STD",
        str(0.02 / (2 * self.config.n_layer) ** 0.5),
    ))
    for block in self.blocks:
        if hasattr(block, 'in_proj') and hasattr(block.in_proj, 'weight'):
            nn.init.uniform_(block.in_proj.weight, -s, s)
        if hasattr(block, 'out_proj') and hasattr(block.out_proj, 'weight'):
            nn.init.normal_(block.out_proj.weight, mean=0.0, std=out_std)

    nn.init.normal_(self.htm_proj.weight, mean=0.0, std=s)
    # Tiny SDR projection init so the SDR path starts as a near no-op.
    nn.init.normal_(self.sdr_proj.weight, mean=0.0, std=1e-4)

    # Modules built on the meta device and moved with to_empty() hold
    # uninitialized storage, so the routing tensors are re-seeded here.
    if hasattr(self, "cantor") and hasattr(self.cantor, "branch"):
        g = torch.Generator(device="cpu")
        g.manual_seed(42)
        bound = math.sqrt(3.0 / self.cantor.d_query)
        branch = torch.empty(
            self.cantor.branch.shape,
            device="cpu",
            dtype=torch.float32,
        ).uniform_(-bound, bound, generator=g)
        # In-place writes to a leaf that requires grad are rejected by
        # autograd; no_grad keeps this safe whether `branch` is a buffer or
        # a Parameter (the nn.init helpers above do the same internally).
        with torch.no_grad():
            self.cantor.branch.copy_(branch.to(device=device, dtype=self.cantor.branch.dtype))
    nn.init.normal_(self.engram.memory, mean=0.0, std=0.01)
    nn.init.zeros_(self.engram.gate.weight)
    nn.init.constant_(self.engram.gate.bias, 0.0)
    if self.reality_bridge is not None:
        nn.init.normal_(self.reality_bridge.to_reality.weight, mean=0.0, std=0.02)
        nn.init.normal_(self.reality_bridge.to_tangent2.weight, mean=0.0, std=0.02)

    # Cast to bf16 to match the Mamba3 dtype; per the original note, Muon
    # groups parameters by shape and mixed dtypes within a group break its
    # lerp_ dtype checks.
    self.wte.to(dtype=torch.bfloat16)
    self.blocks.to(dtype=torch.bfloat16)
    self.htm_proj.to(dtype=torch.bfloat16)
    self.sdr_proj.to(dtype=torch.bfloat16)
    self.engram.to(dtype=torch.bfloat16)
    if self.reality_bridge is not None:
        self.reality_bridge.to(dtype=torch.bfloat16)
def set_bos_token_id(self, bos_id: int) -> None:
    """Record the tokenizer's BOS token id on the model.

    The id is coerced with ``int()`` and stored in ``_bos_token_id``, which
    doc-separator masking (learnability knob 4) uses to decide which
    positions to skip. Intended to be called from training setup once the
    tokenizer has been loaded.

    Args:
        bos_id: BOS token id; anything ``int()`` accepts.
    """
    bos_as_int = int(bos_id)
    self._bos_token_id = bos_as_int
def invalidate_hyena_caches(self) -> None:
    """Drop cached filter-rfft values on every block that keeps them.

    Per the original contract this MUST be called after each
    ``optimizer.step()`` when ``HYDRA_HYENA_FILTER_CACHE=1`` is set;
    otherwise cached rfft values would be reused with stale filter
    parameters.

    Blocks without an ``operator`` attribute, or whose operator has no
    ``invalidate_filter_cache`` method, are skipped.
    """
    for block in self.blocks:
        if not hasattr(block, "operator"):
            continue
        if not hasattr(block.operator, "invalidate_filter_cache"):
            continue
        block.operator.invalidate_filter_cache()
def flush_hyena_pending_grads(self) -> None:
    """Ask every capable block operator to flush pending filter gradients.

    Per the original contract this is used only with
    ``HYDRA_HYENA_TRAIN_CACHE=1`` and must run exactly once per optimizer
    step, BEFORE ``optimizer.step()`` and BEFORE
    ``invalidate_hyena_caches()``.

    Blocks without an ``operator`` attribute, or whose operator has no
    ``flush_pending_filter_grads`` method, are skipped, so the call does
    nothing for a model without such blocks.
    """
    for block in self.blocks:
        if not hasattr(block, "operator"):
            continue
        if not hasattr(block.operator, "flush_pending_filter_grads"):
            continue
        block.operator.flush_pending_filter_grads()
def estimate_flops(self) -> int:
    """Return a parameter-count-based FLOPs estimate.

    Computed as six times the number of parameters, with the token
    embedding table (``wte.weight``) excluded from the count.
    """
    total_count = 0
    for param in self.parameters():
        total_count += param.numel()
    non_embedding = total_count - self.wte.weight.numel()
    return 6 * non_embedding
def num_scaling_params(self) -> dict:
    """Count parameters per major component.

    Returns:
        A dict with element counts for ``wte``, ``lm_head``, ``blocks``,
        ``sdr_semantic``, ``htm_proj`` and ``engram`` (in that order),
        followed by ``total`` for the whole model. Components not listed
        are only reflected in ``total``.
    """
    def _count(params) -> int:
        return sum(p.numel() for p in params)

    counts = {}
    counts['wte'] = _count(self.wte.parameters())
    counts['lm_head'] = _count(self.lm_head.parameters())
    counts['blocks'] = _count(self.blocks.parameters())
    counts['sdr_semantic'] = _count(self.sdr_semantic.parameters())
    counts['htm_proj'] = _count(self.htm_proj.parameters())
    counts['engram'] = _count(self.engram.parameters())
    counts['total'] = _count(self.parameters())
    return counts
def get_secondary_metrics(self) -> dict:
    """Return a host-side snapshot of the secondary metrics.

    Each stored value that exposes ``.item()`` (e.g. a scalar tensor) is
    converted to a Python ``float``. If that conversion fails, or the value
    has no ``.item()``, the value is passed through unchanged.

    Returns:
        A new dict with the same keys as ``self._metrics``.
    """
    flushed = {}
    for key, value in self._metrics.items():
        converted = value
        if hasattr(value, 'item'):
            try:
                converted = float(value.item())
            except Exception:
                # Best-effort: keep the raw value if it cannot be reduced
                # to a scalar.
                converted = value
        flushed[key] = converted
    return flushed
| def _slow_fast_ortho_named_tensors(self): | |
| targets = set(self._sf_ortho_targets) | |
| if "cantor" in targets and hasattr(self, "cantor") and hasattr(self.cantor, "branch"): | |
| yield "cantor_branch", self.cantor.branch | |
| if "engram" in targets and hasattr(self, "engram") and hasattr(self.engram, "memory"): | |
| yield "engram_memory", self.engram.memory | |
| if "block_out_proj" in targets: | |
| for i, block in enumerate(self.blocks): | |
| out_proj = getattr(block, "out_proj", None) | |
| weight = getattr(out_proj, "weight", None) | |
| if weight is not None and weight.dim() == 2: | |
| yield f"block_{i}_out_proj", weight | |
| if "block_in_proj" in targets: | |
| for i, block in enumerate(self.blocks): | |
| in_proj = getattr(block, "in_proj", None) | |
| weight = getattr(in_proj, "weight", None) | |
| if weight is not None and weight.dim() == 2: | |
| yield f"block_{i}_in_proj", weight | |
def _slow_fast_ortho_loss(self) -> torch.Tensor:
    """Average the slow/fast orthogonality proxy over all selected tensors.

    Every 2-D tensor produced by ``_slow_fast_ortho_named_tensors`` is
    scored with ``paired_slow_fast_orthogonality``.

    Returns:
        The mean of the per-tensor scores, or a zero scalar created from
        ``wte.weight`` when nothing qualifies.
    """
    penalties = []
    for _name, tensor in self._slow_fast_ortho_named_tensors():
        if not torch.is_tensor(tensor) or tensor.dim() != 2:
            continue
        penalties.append(paired_slow_fast_orthogonality(tensor))
    if penalties:
        return torch.stack(penalties).mean()
    return self.wte.weight.new_zeros(())
def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.6, matrix_lr=0.04,
                    weight_decay=0.2, adam_betas=(0.8, 0.95), scalar_lr=0.5):
    """Build a ``MuonAdamW`` optimizer with per-component parameter groups.

    Group order in the returned optimizer:
      1. LM head (AdamW, ``unembedding_lr``)
      2. token embedding (AdamW, ``embedding_lr``)
      3. everything not claimed elsewhere (AdamW, ``scalar_lr``), if any
      4. ``dt_bias`` parameters of the blocks (AdamW, ``HYDRA_DT_BIAS_LR``
         or ``embedding_lr``), if any
      5. one Muon group per distinct matrix shape (``matrix_lr``)

    All AdamW learning rates are multiplied by ``(d_model / 768) ** -0.5``.
    Every group also gets an ``initial_lr`` entry equal to its ``lr``.

    Args:
        unembedding_lr: Base LR for the LM head.
        embedding_lr: Base LR for the embedding (and default for dt_bias).
        matrix_lr: LR for Muon matrix groups.
        weight_decay: Weight decay for Muon groups (AdamW groups use 0).
        adam_betas: Betas for every AdamW group.
        scalar_lr: Base LR for the catch-all AdamW group.

    Returns:
        The configured ``MuonAdamW`` instance.

    Raises:
        AssertionError: If the groups do not account for every parameter
            exactly once.
    """
    model_dim = self.config.d_model
    embedding_params = list(self.wte.parameters())
    lm_head_params = list(self.lm_head.parameters())

    # Muon routing guard: being 2-D does not make a parameter a matrix.
    # Two exclusions, both of which send the parameter to AdamW instead:
    #   (a) names ending in `.freq` — per the original notes these are
    #       Hyena filter frequency vectors of shape (1, dim), semantically
    #       per-dim scalars that orthogonalization would destroy;
    #   (b) 2-D params whose smaller side is below MUON_MIN_DIM — tiny
    #       projections that orthogonalization would collapse.
    MUON_MIN_DIM = 8

    def _muon_eligible(name: str, p: torch.Tensor) -> bool:
        return (
            p.dim() == 2
            and not name.endswith(".freq")
            and min(p.shape) >= MUON_MIN_DIM
        )

    # Matrix params -> Muon, collected from blocks, htm_proj and engram in
    # that order. SemanticFoldingSDR parameters are deliberately not
    # scanned here; per the original audit note (W1 REG-2) its delta_u /
    # delta_v receive gradient through sdr_proj and are handled by the
    # AdamW catch-all group.
    matrix_params = [
        p
        for module in (self.blocks, self.htm_proj, self.engram)
        for name, p in module.named_parameters()
        if _muon_eligible(name, p)
    ]

    # dt_bias gets its own group so its LR can be tuned separately; the
    # original notes motivate this as letting heads differentiate their
    # timescales.
    dt_bias_params = [
        p for name, p in self.blocks.named_parameters()
        if name.endswith('dt_bias')
    ]

    claimed = {id(p) for p in (*embedding_params, *lm_head_params, *matrix_params)}
    claimed |= {id(p) for p in dt_bias_params}
    scalar_params = [p for p in self.parameters() if id(p) not in claimed]

    every_group = (embedding_params, lm_head_params, matrix_params, scalar_params, dt_bias_params)
    total_assigned = sum(len(group) for group in every_group)
    total_params = len(list(self.parameters()))
    assert total_assigned == total_params, (
        f"Parameter count mismatch: assigned {total_assigned} vs total {total_params}"
    )

    dmodel_lr_scale = (model_dim / 768) ** -0.5
    print(f"Scaling AdamW LRs by 1/sqrt({model_dim}/768) = {dmodel_lr_scale:.6f}")

    def _adamw_group(params, base_lr):
        """AdamW group with the width-scaled LR and no weight decay."""
        return dict(kind='adamw', params=params,
                    lr=base_lr * dmodel_lr_scale, betas=adam_betas,
                    eps=1e-10, weight_decay=0.0)

    param_groups = [
        _adamw_group(lm_head_params, unembedding_lr),
        _adamw_group(embedding_params, embedding_lr),
    ]
    if scalar_params:
        param_groups.append(_adamw_group(scalar_params, scalar_lr))
    if dt_bias_params:
        # Defaults to the embedding LR; env-overridable for sweeps.
        dt_bias_lr = float(os.environ.get("HYDRA_DT_BIAS_LR", str(embedding_lr)))
        param_groups.append(_adamw_group(dt_bias_params, dt_bias_lr))

    for shape in sorted({p.shape for p in matrix_params}):
        same_shape = [p for p in matrix_params if p.shape == shape]
        # ns_steps: Muon inner iterations, env-tunable (default 3 here).
        ns_steps = int(os.environ.get("HYDRA_MUON_NS_STEPS", "3"))
        param_groups.append(dict(
            kind='muon', params=same_shape, lr=matrix_lr,
            momentum=0.95, ns_steps=ns_steps, beta2=0.95, weight_decay=weight_decay,
        ))

    optimizer = MuonAdamW(param_groups)
    # Remember each group's starting LR on the group itself.
    for group in optimizer.param_groups:
        group["initial_lr"] = group["lr"]
    return optimizer
def forward(self, idx, targets=None, reduction='mean'):
    """Run the model under a bf16 autocast contract.

    Args:
        idx: ``(B, T)`` integer token ids.
        targets: Optional targets, forwarded to ``_forward_impl``.
        reduction: Loss reduction mode, forwarded to ``_forward_impl``.

    Returns:
        Whatever ``_forward_impl`` returns (loss when targets are given,
        logits otherwise).

    If autocast is already active for the input's device, the ambient
    setting is used as-is. Otherwise a bfloat16 autocast region is opened.
    The autocast device follows ``idx``: CPU inputs get CPU autocast, and
    everything else keeps the CUDA autocast used previously. This keeps the
    dtype contract intact for CPU runs, where a CUDA autocast region has no
    effect.
    """
    device_type = "cpu" if idx.device.type == "cpu" else "cuda"
    if device_type == "cuda":
        ambient = torch.is_autocast_enabled()
    else:
        try:
            ambient = torch.is_autocast_enabled(device_type)
        except TypeError:
            # Older torch: is_autocast_enabled() takes no device argument.
            ambient = torch.is_autocast_cpu_enabled()
    if ambient:
        return self._forward_impl(idx, targets=targets, reduction=reduction)
    with torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16):
        return self._forward_impl(idx, targets=targets, reduction=reduction)
def _forward_impl(self, idx, targets=None, reduction='mean'):
    """Full forward pass: HTM-augmented embedding -> mHC-routed block stack -> loss or logits.

    Args:
        idx: token ids, unpacked as ``B, T = idx.shape``.
        targets: optional target ids laid out like ``idx``. Entries equal
            to -1 are ignored positions; when doc-separator masking is
            enabled, targets equal to the BOS id are rewritten to -1 too.
        reduction: ``'mean'`` returns a scalar loss (plus the optional MTP,
            entropy and slow/fast-ortho terms while training); ``'none'``
            returns the flattened per-token primary loss with ignored
            positions set to 0.

    Returns:
        The loss when ``targets`` is given, otherwise the soft-capped
        float32 logits produced by ``lm_head``.

    Side effects:
        Writes ``self._metrics``; advances the HTM subsample counter and
        refreshes ``_htm_cache`` / ``_last_sdr``; updates the running
        Cantor / Engram training statistics and the diag / ortho step
        counters.
    """
    B, T = idx.shape

    # Diagnostic: per-subsystem CUDA event timing. Env-gated; logs one
    # timing line per forward call that reaches the mean-loss return.
    _profile = os.environ.get("HYDRA_PROFILE_FORWARD", "0") == "1"
    if _profile:
        def _ev():
            e = torch.cuda.Event(enable_timing=True)
            e.record()
            return e
        _t0 = _ev()
    else:
        _t0 = None

    # Compact SDR support set used by Reality/Cantor/Engram and by the
    # sparse SDR projection below. Do NOT materialize the dense
    # (B,T,n_bits) SDR every step: the dense uint8 SDR is built only on
    # HTM subsample steps where the HTM subsystem actually consumes it.
    sdr_active_indices = self.sdr_semantic.active_indices(idx)
    sdr_binary = None

    # HTM subsampling: run HTM on 1 of every N micro-batches and reuse the
    # cached result for the other N-1. Cooperative launch monopolizes all
    # SMs, so HTM and mamba can't overlap via streams; subsampling removes
    # HTM from most micro-batches' critical path instead.
    #
    # HYDRA_HTM_SUBSAMPLE=N (default 8). Set =1 for every-microbatch HTM.
    # Values below 1 are clamped to 1 so the modulo can never divide by zero.
    _htm_sub = max(1, int(os.environ.get("HYDRA_HTM_SUBSAMPLE", "8")))
    if not hasattr(self, '_htm_call_idx'):
        self._htm_call_idx = 0
    _run_htm = (self._htm_call_idx % _htm_sub == 0)
    self._htm_call_idx += 1
    if _run_htm:
        sdr_binary = self.sdr_semantic.binary_only(idx)
        self._last_sdr = sdr_binary
        htm_handle = self.htm.forward_async(sdr_binary)
    else:
        htm_handle = None
    if _profile:
        _t_htm_async = _ev()

    dense_emb = self.wte(idx)  # (B, T, d_model) bf16
    if _profile:
        _t_wte = _ev()

    if _run_htm:
        htm_out = self.htm.forward_await(htm_handle)
        self._htm_cache = htm_out.detach()  # cache for non-HTM micro-batches
    elif (
        hasattr(self, '_htm_cache') and self._htm_cache is not None
        and self._htm_cache.shape[0] == B and self._htm_cache.shape[1] == T
    ):
        htm_out = self._htm_cache
    else:
        # No usable cache (first call at a non-sampled index, or the cached
        # batch/sequence shape differs): run HTM anyway.
        sdr_binary = self.sdr_semantic.binary_only(idx)
        self._last_sdr = sdr_binary
        htm_handle = self.htm.forward_async(sdr_binary)
        htm_out = self.htm.forward_await(htm_handle)
        self._htm_cache = htm_out.detach()
    if _profile:
        _t_htm_await = _ev()

    with torch.no_grad():
        sdr_active_bits = float(self.sdr_semantic.target_active)
        # Last feature channel of the HTM output is reported as the anomaly.
        htm_anomaly = htm_out[..., -1].mean()

    # Learnability #5: explicit stop-grad on HTM output. Hardens the
    # contract against future refactors that might route HTM through a
    # grad-enabled op.
    if self._htm_stop_grad:
        htm_out = htm_out.detach()

    # Gradient bridge: HTM columns+anomaly -> d_model.
    htm_proj_out = self.htm_proj(htm_out.to(dense_emb.dtype))
    x = dense_emb + htm_proj_out
    x = norm(x)
    if _profile:
        _t_htm_proj = _ev()

    # mHC-routed Mamba-3 stack with Engram injection at configured layer.
    streams = self.mhc[0].init_streams(x)
    _engram_ev = None

    # Per-layer diagnostic panel. The pre-layer merged state h_pre lets us
    # measure residual contribution of each layer: delta_N = h_post - h_pre.
    # All reads are detached no-grad to avoid autograd graph pollution.
    _diag = self._diag_enabled
    if _diag:
        # float32 snapshot: the per-layer residual is small and the
        # subtraction loses it in bf16, reporting delta_ratio=0.
        with torch.no_grad():
            h_pre = self.mhc[0].merge_streams(streams).detach().float()
        _run_svd = (self._diag_step % self._diag_svd_every) == 0

    for i, (block, mhc_layer) in enumerate(zip(self.blocks, self.mhc)):
        # `_block=block` binds the current block (avoids late-binding closures).
        def _block_fn(h, _block=block):
            return self.drop(_block(norm(h)))

        # Learnability #3: gradient checkpointing. Wrap the block-fn so
        # the mhc layer's internal uses of it re-run the block in backward
        # (trading compute for activation memory).
        if self._grad_ckpt and self.training:
            import torch.utils.checkpoint as _ckpt
            _raw_fn = _block_fn

            def _block_fn(h, _raw=_raw_fn):  # noqa: E731
                return _ckpt.checkpoint(_raw, h, use_reentrant=False)

        streams = mhc_layer(streams, _block_fn)

        if i == self.engram_layer_idx:
            if _profile:
                _t_pre_engram = _ev()
            x_mid = mhc_layer.merge_streams(streams)

            # Inject differentiable SDR signal into the engram query via a
            # lightweight projection, so LM loss gradients reach the SDR
            # retina through delta_u/delta_v.
            sdr_feat = FusedSDRProject.apply(
                sdr_active_indices,
                idx,
                self.sdr_proj.weight,
                self.sdr_semantic.delta_u,
                self.sdr_semantic.delta_v,
            )
            # Norm the projection to prevent magnitude blowup overwhelming
            # the engram residual.
            sdr_feat = norm(sdr_feat)
            x_mid = x_mid + sdr_feat
            x_mid = semantic_gaussian_mollify(
                x_mid,
                std=self._semantic_smooth_std,
                training=self.training,
                eval_enabled=self._semantic_smooth_eval,
            )

            # Cantor routing: leaf IDs can constrain Engram column
            # eligibility per query.
            leaf_ids = None
            if self._cantor_enabled:
                leaf_ids, scores = self.cantor(
                    x_mid,
                    return_scores=bool(self.cantor.score_grad),
                )
                if scores is not None and scores.requires_grad:
                    self._metrics['cantor_score_mean'] = scores.detach().mean()
                # Expose leaf distribution for monitoring. Keep both the
                # instantaneous last-forward count and a training-window
                # max; final factual probes are tiny and can otherwise
                # overwrite the metric with an artificial 1-2 leaf count.
                unique = leaf_ids.unique().numel()
                self._metrics['cantor_active_leaves'] = unique
                self._metrics['cantor_leaf_util'] = unique / self.cantor.n_leaves
                if self.training:
                    self._cantor_active_leaves_train_max = max(
                        self._cantor_active_leaves_train_max,
                        int(unique),
                    )
                self._metrics['cantor_active_leaves_train_max'] = self._cantor_active_leaves_train_max
                self._metrics['cantor_leaf_util_train_max'] = (
                    self._cantor_active_leaves_train_max / self.cantor.n_leaves
                )

            if self.reality_bridge is not None:
                reality = self.reality_bridge(x_mid)
                engram_active_indices = reality.l0_indices
                self._metrics['reality_poincare_radius'] = reality.poincare.float().norm(dim=-1).mean().detach()
            else:
                # NOTE(review): recomputes the same call made at the top of
                # this method; presumably identical to sdr_active_indices —
                # confirm before reusing it.
                engram_active_indices = self.sdr_semantic.active_indices(idx)

            x_mid, hit_rate = self.engram(
                x_mid,
                idx,
                sdr_active_indices=engram_active_indices,
                cantor_leaf_ids=leaf_ids,
                cantor_n_leaves=self.cantor.n_leaves if self._cantor_enabled else None,
            )
            # Restart the stream set from the Engram-augmented state.
            streams = mhc_layer.init_streams(x_mid)

            self._metrics['engram_hit_rate'] = hit_rate
            if self.training:
                hit = float(hit_rate.detach().item() if hasattr(hit_rate, 'detach') else hit_rate)
                self._engram_hit_rate_train_max = max(self._engram_hit_rate_train_max, hit)
                self._engram_hit_rate_train_sum += hit
                self._engram_hit_rate_train_count += 1
            self._metrics['engram_hit_rate_train_max'] = self._engram_hit_rate_train_max
            self._metrics['engram_hit_rate_train_mean'] = (
                self._engram_hit_rate_train_sum / max(1, self._engram_hit_rate_train_count)
            )
            self._metrics['engram_hit_rate_train_count'] = self._engram_hit_rate_train_count
            if _profile:
                _engram_ev = _ev()

        if _diag:
            with torch.no_grad():
                h_post = mhc_layer.merge_streams(streams).detach().float()
                in_n = h_pre.pow(2).mean().sqrt()
                out_n = h_post.pow(2).mean().sqrt()
                d_n = (h_post - h_pre).pow(2).mean().sqrt()
                self._metrics[f'layer_{i}_in_norm'] = float(in_n.item())
                self._metrics[f'layer_{i}_out_norm'] = float(out_n.item())
                self._metrics[f'layer_{i}_delta_ratio'] = float((d_n / (in_n + 1e-6)).item())
                self._metrics[f'layer_{i}_feat_std'] = float(h_post.std(dim=-1).mean().item())
                if _run_svd:
                    # Effective rank via participation ratio of singular
                    # values: eff_rank = (Σσ)^2 / Σσ². Only the first 512
                    # rows are used to keep the SVD cheap.
                    flat = h_post.reshape(-1, h_post.shape[-1])[:512].float()
                    try:
                        s = torch.linalg.svdvals(flat)
                        eff_rank = float(((s.sum() ** 2) / (s.pow(2).sum() + 1e-6)).item())
                        self._metrics[f'layer_{i}_eff_rank'] = eff_rank
                    except RuntimeError:
                        # Best-effort diagnostic: torch.linalg.LinAlgError
                        # (a RuntimeError) or a backend failure must never
                        # abort the forward pass; the metric is skipped.
                        pass
                h_pre = h_post

    if _diag:
        self._diag_step += 1
    if _profile:
        _t_blocks = _ev()

    self._metrics['sdr_active_bits'] = sdr_active_bits
    self._metrics['htm_anomaly'] = htm_anomaly

    x = self.mhc[-1].merge_streams(streams)
    x = norm(x)
    if _profile:
        _t_merge = _ev()

    softcap = self.softcap
    _softcap_clamp = os.environ.get("HYDRA_SOFTCAP_CLAMP", "0") == "1"

    if targets is not None:
        smoothing = self.config.label_smoothing
        V = self.config.vocab_size

        # Learnability #4: doc-separator masking. In packed rows a BOS sits
        # at every document boundary; predicting "doc B's BOS" from doc A
        # is pure noise. Set targets==bos to -1 (ignore_index). Done BEFORE
        # the MTP/entropy/sampled-softmax branches so all downstream losses
        # inherit the mask.
        if self._doc_sep_mask and self._bos_token_id >= 0:
            targets = torch.where(
                targets == self._bos_token_id,
                torch.full_like(targets, -1),
                targets,
            )

        # Sampled softmax: compute logits only for the target + K random
        # negatives instead of all V tokens. The log(V/K) correction
        # adjusts for the sampling bias.
        # Set HYDRA_SAMPLED_SOFTMAX=0 to disable (full softmax).
        K_neg = int(os.environ.get("HYDRA_SAMPLED_SOFTMAX", "4096"))
        use_sampled = K_neg > 0 and K_neg < V and self.training

        if use_sampled:
            # NOTE(review): with HYDRA_SOFTCAP_CLAMP=1 this path (and MTP)
            # applies no cap at all, whereas the full-softmax and entropy
            # paths clamp — confirm this asymmetry is intended.
            h_flat = x.reshape(-1, x.shape[-1])  # (B*T, d)
            t_flat = targets.reshape(-1)  # (B*T,)
            n = h_flat.shape[0]
            # Learnability #4 hardening: the weight gather crashes on
            # negative ids. Replace -1 with 0 for the gather; the loss at
            # those positions is masked below.
            valid_mask_flat = (t_flat >= 0)
            t_flat_safe = torch.where(valid_mask_flat, t_flat, torch.zeros_like(t_flat))
            # Sample K negatives uniformly from [0, V), shared by all positions.
            neg_ids = torch.randint(0, V, (K_neg,), device=x.device)
            # Gather lm_head weights for target + negatives.
            all_ids = torch.cat([t_flat_safe, neg_ids])  # (B*T + K,)
            sampled_w = self.lm_head.weight[all_ids]  # (B*T + K, d)
            target_w = sampled_w[:n]  # (B*T, d)
            neg_w = sampled_w[n:]  # (K, d)
            log_correction = torch.tensor(V / K_neg, device=x.device).log()

            # Optionally chunk the sampled loss to avoid materializing the
            # full (B*T, K+1) matrix — unless MTP needs the full
            # neg_logits view for reuse.
            sampled_chunk = int(os.environ.get("HYDRA_SAMPLED_CE_CHUNK", "0"))
            if sampled_chunk <= 0:
                sampled_chunk = n

            if reduction == 'mean' and self._mtp_k <= 1 and sampled_chunk < n:
                total_loss = x.new_tensor(0.0)
                total_tokens = x.new_tensor(0.0)
                ce_targets_chunk = None
                for start in range(0, n, sampled_chunk):
                    end = min(start + sampled_chunk, n)
                    h_c = h_flat[start:end]
                    target_w_c = target_w[start:end]
                    target_logit_c = (h_c * target_w_c).sum(-1)
                    neg_logits_c = h_c @ neg_w.t()
                    if not _softcap_clamp:
                        target_logit_c = softcap * torch.tanh(target_logit_c / softcap)
                        neg_logits_c = softcap * torch.tanh(neg_logits_c / softcap)
                    all_logits_c = torch.cat([
                        target_logit_c.unsqueeze(-1),
                        neg_logits_c + log_correction,
                    ], dim=-1).float()
                    # Target is always column 0; reuse the zeros buffer
                    # while the chunk length is unchanged.
                    if ce_targets_chunk is None or ce_targets_chunk.numel() != end - start:
                        ce_targets_chunk = torch.zeros(end - start, dtype=torch.long, device=x.device)
                    per_tok_ce_c = F.cross_entropy(
                        all_logits_c, ce_targets_chunk, reduction='none',
                        label_smoothing=smoothing,
                    )
                    valid_c = valid_mask_flat[start:end].float()
                    total_loss = total_loss + (per_tok_ce_c * valid_c).sum()
                    total_tokens = total_tokens + valid_c.sum()
                out = total_loss / total_tokens.clamp(min=1)
                neg_logits = None
            else:
                target_logit = (h_flat * target_w).sum(-1)  # (B*T,)
                neg_logits = h_flat @ neg_w.t()  # (B*T, K)
                if not _softcap_clamp:
                    target_logit = softcap * torch.tanh(target_logit / softcap)
                    neg_logits = softcap * torch.tanh(neg_logits / softcap)
                # Sampled softmax loss:
                # -log(exp(target) / (exp(target) + sum(exp(neg)))), with
                # log(V/K) added to the negatives to account for seeing
                # only K of V possible negatives.
                all_logits = torch.cat([
                    target_logit.unsqueeze(-1),  # (B*T, 1)
                    neg_logits + log_correction,  # (B*T, K)
                ], dim=-1).float()  # (B*T, K+1)
                # CE with target always at index 0.
                ce_targets = torch.zeros(n, dtype=torch.long, device=x.device)

                if reduction == 'none':
                    per_tok = F.cross_entropy(all_logits, ce_targets, reduction='none')
                    # Always zero ignored positions (any target < 0, not
                    # only doc-separator ones): their logits were gathered
                    # against token 0 and carry no meaning. This matches
                    # ignore_index=-1 in the full-softmax path.
                    per_tok = torch.where(valid_mask_flat, per_tok, torch.zeros_like(per_tok))
                    return per_tok

                per_tok_ce = F.cross_entropy(
                    all_logits, ce_targets, reduction='none',
                    label_smoothing=smoothing,
                )
                # Masked mean; when every token is valid this reduces to a
                # plain mean.
                valid_f = valid_mask_flat.float()
                valid_n = valid_f.sum().clamp(min=1)
                out = (per_tok_ce * valid_f).sum() / valid_n
        else:
            # Full softmax path (eval or HYDRA_SAMPLED_SOFTMAX=0), chunked
            # along the sequence axis.
            chunk_size = int(os.environ.get("HYDRA_CE_CHUNK", "1024"))
            if chunk_size <= 0:
                MAX_LOGITS_BYTES = 256 * 1024 * 1024
                bytes_per_logit = FLOAT32_BYTES
                # Bound by token logits memory: each token contributes V
                # logits, so the safe token count can be smaller than V.
                tokens_per_chunk = max(1, MAX_LOGITS_BYTES // (V * bytes_per_logit))
                chunk_size = max(1, tokens_per_chunk // max(1, B))
            # Never exceed T, and never drop to 0 (range() rejects a zero step).
            chunk_size = max(1, min(chunk_size, T))

            if reduction == 'none':
                loss_parts = []
                for start in range(0, T, chunk_size):
                    end = min(start + chunk_size, T)
                    chunk_logits = self.lm_head(x[:, start:end, :]).float()
                    if _softcap_clamp:
                        chunk_logits = torch.clamp(chunk_logits, -softcap, softcap)
                    else:
                        chunk_logits = softcap * torch.tanh(chunk_logits / softcap)
                    chunk_targets = targets[:, start:end].reshape(-1)
                    chunk_loss = F.cross_entropy(
                        chunk_logits.view(-1, chunk_logits.size(-1)),
                        chunk_targets, ignore_index=-1, reduction='none',
                    )
                    loss_parts.append(chunk_loss)
                return torch.cat(loss_parts)

            # Tensor accumulators so the final division is well-defined
            # even when no chunk runs or no token is valid.
            total_loss = x.new_zeros((), dtype=torch.float32)
            total_tokens = x.new_zeros((), dtype=torch.long)
            for start in range(0, T, chunk_size):
                end = min(start + chunk_size, T)
                chunk_logits = self.lm_head(x[:, start:end, :]).float()
                if _softcap_clamp:
                    chunk_logits = torch.clamp(chunk_logits, -softcap, softcap)
                else:
                    chunk_logits = softcap * torch.tanh(chunk_logits / softcap)
                chunk_targets = targets[:, start:end].reshape(-1)
                chunk_loss = F.cross_entropy(
                    chunk_logits.view(-1, chunk_logits.size(-1)),
                    chunk_targets, ignore_index=-1, reduction='sum',
                    label_smoothing=smoothing,
                )
                total_loss = total_loss + chunk_loss
                total_tokens = total_tokens + (chunk_targets != -1).sum()
            # clamp(min=1): an all-ignored batch yields 0 instead of NaN,
            # consistent with the sampled-softmax path.
            out = total_loss / total_tokens.clamp(min=1)

        # -----------------------------------------------------------
        # Learnability #1: Multi-Token Prediction.
        # For k in {2..K}, add a CE loss at position (t) predicting the
        # token at position (t+k), using the SAME lm_head weights.
        # Only triggered in the reduction='mean' sampled path, training only.
        # -----------------------------------------------------------
        if reduction == 'mean' and self._mtp_k > 1 and self.training and use_sampled:
            # Reuse the primary's neg_logits (B*T, K_neg) entirely; the
            # only extra cost per head is a target-weight gather + dot.
            # (The chunked branch above is never taken when _mtp_k > 1, so
            # neg_logits is a tensor here.)
            mtp_loss_sum = out.new_tensor(0.0)
            mtp_terms = 0
            # Reshape back to (B, T, K_neg) so positions can be sliced.
            neg_logits_bt = neg_logits.view(B, T, K_neg)
            for k in range(2, self._mtp_k + 1):
                shift = k - 1
                if T <= shift:
                    continue
                n_k = B * (T - shift)
                h_k_flat = x[:, :T - shift, :].reshape(n_k, -1)  # (n_k, d)
                t_k = targets[:, shift:].reshape(-1)  # (n_k,)
                mask_k = (t_k >= 0)
                t_k_safe = torch.where(mask_k, t_k, torch.zeros_like(t_k))
                tgt_w_k = self.lm_head.weight[t_k_safe]  # (n_k, d)
                tgt_logit_k = (h_k_flat * tgt_w_k).sum(-1)  # (n_k,)
                if not _softcap_clamp:
                    tgt_logit_k = softcap * torch.tanh(tgt_logit_k / softcap)
                # REUSE primary neg_logits — slice positions [:T-shift].
                neg_logits_k = neg_logits_bt[:, :T - shift, :].reshape(n_k, K_neg)
                all_logits_k = torch.cat([
                    tgt_logit_k.unsqueeze(-1),
                    neg_logits_k + log_correction,
                ], dim=-1).float()
                ce_targets_k = torch.zeros(n_k, dtype=torch.long, device=x.device)
                per_tok_ce_k = F.cross_entropy(
                    all_logits_k, ce_targets_k, reduction='none',
                    label_smoothing=smoothing,
                )
                per_tok_ce_k = torch.where(mask_k, per_tok_ce_k, torch.zeros_like(per_tok_ce_k))
                n_valid_k = mask_k.sum().clamp(min=1)
                mtp_loss_sum = mtp_loss_sum + per_tok_ce_k.sum() / n_valid_k
                mtp_terms += 1
            if mtp_terms > 0:
                # Average the primary loss with the extra heads.
                out = (out + mtp_loss_sum) / float(mtp_terms + 1)

        # -----------------------------------------------------------
        # Learnability #6: output entropy penalty.
        # L += -lambda * H(softmax(logits)). Penalizes peaked
        # distributions. Computed on at most 64 randomly sampled positions
        # to keep the V-sized logits cost bounded.
        # -----------------------------------------------------------
        if reduction == 'mean' and self._entropy_penalty > 0.0 and self.training:
            h_flat = x.reshape(-1, x.shape[-1])
            n_pos = h_flat.shape[0]
            n_sample = min(64, n_pos)
            idx_sample = torch.randint(0, n_pos, (n_sample,), device=x.device)
            h_sample = h_flat[idx_sample]
            logits_s = F.linear(h_sample, self.lm_head.weight).float()
            if _softcap_clamp:
                logits_s = torch.clamp(logits_s, -softcap, softcap)
            else:
                logits_s = softcap * torch.tanh(logits_s / softcap)
            log_probs = F.log_softmax(logits_s, dim=-1)
            probs = log_probs.exp()
            entropy = -(probs * log_probs).sum(-1).mean()  # scalar, nats
            out = out - self._entropy_penalty * entropy

        # Slow/fast orthogonality regulariser and/or its periodic metric.
        if reduction == 'mean' and self.training and (
            self._sf_ortho_lambda > 0.0 or self._sf_ortho_metrics
        ):
            run_metric = self._sf_ortho_metrics and (
                self._sf_ortho_step % self._sf_ortho_every == 0
            )
            if self._sf_ortho_lambda > 0.0:
                sf_ortho = self._slow_fast_ortho_loss()
                out = out + self._sf_ortho_lambda * sf_ortho
                if run_metric:
                    self._metrics['slow_fast_ortho_loss'] = sf_ortho.detach()
            elif run_metric:
                # Metric-only mode: compute without building a graph.
                with torch.no_grad():
                    self._metrics['slow_fast_ortho_loss'] = self._slow_fast_ortho_loss().detach()
            self._sf_ortho_step += 1

        if _profile:
            _t_end = _ev()
            # Events are only readable once the GPU has reached them.
            torch.cuda.synchronize()

            def _ms(a, b):
                return a.elapsed_time(b)

            print(
                f"[PROFILE B={B} T={T}] "
                f"htm_launch={_ms(_t0, _t_htm_async):.2f} "
                f"wte={_ms(_t_htm_async, _t_wte):.2f} "
                f"htm_await={_ms(_t_wte, _t_htm_await):.2f} "
                f"htm_proj={_ms(_t_htm_await, _t_htm_proj):.2f} "
                f"mamba_mhc_engram={_ms(_t_htm_proj, _t_blocks):.2f} "
                f"merge={_ms(_t_blocks, _t_merge):.2f} "
                f"lm_head_loss={_ms(_t_merge, _t_end):.2f} "
                f"total={_ms(_t0, _t_end):.2f} ms",
                flush=True,
            )
        return out

    # Inference path: full-vocabulary soft-capped logits.
    logits = self.lm_head(x).float()
    if _softcap_clamp:
        logits = torch.clamp(logits, -softcap, softcap)
    else:
        logits = softcap * torch.tanh(logits / softcap)
    return logits