| """PostSemClawModel — full-architecture model assembly. |
| |
| Extracted from the monolithic train.py (W1 modularization). Semantics |
| unchanged. Imports `GPUEngram` from `hydra.engram` and `MuonAdamW` from |
| `hydra.optimizer`. |
| |
| Triton kernel integration status (Phase 2): |
| HYDRA_FUSED_BCNORM — DEFERRED. The bcnorm_fused Triton kernel fuses |
| LayerNorm + RoPE on B/C projections. However, mamba-ssm's Mamba3 block |
| uses RMSNormGated (not LayerNorm) for B/C, and RoPE is applied inside |
| the mamba3_siso_combined CUDA kernel via the Angles parameter. Replacing |
| would require either (a) monkey-patching RMSNormGated + intercepting the |
| fused CUDA scan — invasive, 50+ lines, high breakage risk — or (b) a |
| full custom Mamba3Block reimplementation. Both are out of scope for |
| Phase 2. The kernel is validated standalone; integration deferred to |
| Phase 3 when HYDRA moves to a custom SSM block. |
| |
| HYDRA_FUSED_SSD — DEFERRED. The ssd_exp_trap Triton kernel implements |
| exponential-trapezoidal discretization as a sequential scan. mamba-ssm's |
| Mamba3 block delegates the entire scan + gating + output projection to |
| mamba3_siso_combined (a compiled CUDA kernel with tilelang). Replacing |
| it would require decomposing the combined kernel into constituent ops |
| and substituting only the scan — not feasible without a custom block. |
| Same Phase 3 gate as above. |
| |
Both env vars are accepted but are currently no-ops: the gates are read and
logged to stderr, but the code path is unchanged. The log line tells anyone
who sets them expecting a speedup that nothing changed, instead of the flags
being silently ignored.
| """ |
|
|
| from __future__ import annotations |
|
|
| import os |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| try: |
| from mamba_ssm import Mamba3 |
| except ModuleNotFoundError: |
| Mamba3 = None |
|
|
| from subsystems.hestia_mini import HestiaQAT |
| from subsystems.htm import HTMLayer |
| from subsystems.mhc_mini import ManifoldHyperConnection |
| from subsystems.sdr_semantic import SemanticFoldingSDR |
|
|
| from hydra.engram import GPUEngram |
| from hydra.htm_cache import htm_cache_key, htm_cache_matches |
| from hydra.hyena_block import HyenaBlock |
| from hydra.reality_bridge import RealityPoincareBridge |
| |
| |
| |
| from hydra.optimizer import MuonAdamW |
| from hydra.sampled_softmax import UnigramSampler, sampled_softmax_loss |
|
|
| try: |
| from subsystems.cantor_router import CantorRouter |
| except ModuleNotFoundError: |
| from archive.cantor_router import CantorRouter |
|
|
|
|
| def norm(x: torch.Tensor) -> torch.Tensor: |
| """RMSNorm over the last dim — stateless, autocast-friendly.""" |
| return F.rms_norm(x, (x.size(-1),)) |
|
|
|
|
def paired_slow_fast_orthogonality(w: torch.Tensor) -> torch.Tensor:
    """Mean squared cosine similarity between interleaved slow/fast rows of ``w``.

    Rows are paired as (0, 1), (2, 3), ...; with an odd row count the last row
    is ignored. Even-indexed rows are the "slow" vectors and odd-indexed rows
    the "fast" ones. The penalty is 0 when every pair is orthogonal and 1 when
    every pair is (anti-)parallel. It is computed in float32 and cast back to
    ``w.dtype``; a zero scalar is returned when ``w`` has fewer than two rows.
    """
    num_pairs = w.shape[0] // 2
    if not num_pairs:
        return w.new_zeros(())
    paired_rows = w[: 2 * num_pairs].float()
    slow_dirs = F.normalize(paired_rows[0::2], dim=-1, eps=1e-8)
    fast_dirs = F.normalize(paired_rows[1::2], dim=-1, eps=1e-8)
    cosines = torch.sum(slow_dirs * fast_dirs, dim=-1)
    return cosines.pow(2).mean().to(dtype=w.dtype)
|
|
|
|
def semantic_gaussian_mollify(
    x: torch.Tensor,
    std: float = 0.0,
    training: bool = True,
    eval_enabled: bool = False,
) -> torch.Tensor:
    """Perturb ``x`` with i.i.d. Gaussian noise of standard deviation ``std``.

    The input tensor object itself is returned, untouched, when ``std`` is not
    positive, or when not training and ``eval_enabled`` is off. Otherwise a new
    tensor ``x + N(0, std^2)`` is returned.
    """
    if std <= 0.0:
        return x
    if not (training or eval_enabled):
        return x
    noise = torch.randn_like(x) * float(std)
    return x + noise
|
|
|
|
| class _LocalMamba3Fallback(nn.Identity): |
| """Shape-preserving local fallback used only when mamba_ssm is absent.""" |
| pass |
|
|
|
|
| class PostSemClawModel(nn.Module): |
| """Full Post-SEM-Claw model assembly. |
| |
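    Architecture:
        token ids -> SDR -> HTM -> htm_proj, added to the token embedding -> norm
            -> mHC residual streams through n_layer blocks (Mamba3 by default,
               Hyena/GDN on configured layers), with Engram memory applied
               after block ``engram_layer_idx``
            -> merge streams -> norm -> LM head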
| |
| Interface (must match prepare.py evaluate_bpb): |
| model(x, y, reduction='none').view(-1) -> per-token losses |
| model(x, y, reduction='mean') -> scalar loss |
| """ |
|
|
| def __init__(self, config): |
| super().__init__() |
| self.config = config |
|
|
| |
| |
| |
| |
| |
| |
| self._mdlm_active = os.environ.get("HYDRA_USE_MDLM", "0") == "1" |
| self._htm_cache_key = None |
|
|
| |
| self.wte = nn.Embedding(config.vocab_size, config.d_model) |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| _hyena_layer_set = set(getattr(config, "hyena_layers", ()) or ()) |
| _gdn_layer_set = set(getattr(config, "gdn_layers", ()) or ()) |
| |
| _both = _hyena_layer_set & _gdn_layer_set |
| if _both: |
| print(f"[WARN] layers in both hyena_layers and gdn_layers; using Hyena: {sorted(_both)}", flush=True) |
| _gdn_layer_set -= _hyena_layer_set |
| if os.environ.get("HYDRA_STRICT_OPTIMAL_COMPONENTS", "0") == "1": |
| if _hyena_layer_set or _gdn_layer_set: |
| raise RuntimeError( |
| "HYDRA_STRICT_OPTIMAL_COMPONENTS=1 requires all layers to use Mamba3; " |
| f"got hyena_layers={sorted(_hyena_layer_set)} gdn_layers={sorted(_gdn_layer_set)}." |
| ) |
| if Mamba3 is None: |
| raise RuntimeError("HYDRA_STRICT_OPTIMAL_COMPONENTS=1 requires mamba_ssm/Mamba3 to be importable.") |
|
|
| if _gdn_layer_set: |
| from hydra.gdn_block import GDNBlock |
|
|
| def _build_block(i: int) -> nn.Module: |
| if i in _hyena_layer_set: |
| return HyenaBlock( |
| d_model=config.d_model, |
| seq_len=config.sequence_len, |
| order=int(os.environ.get("HYDRA_HYENA_ORDER", "2")), |
| filter_order=int(os.environ.get("HYDRA_HYENA_FILTER_DIM", "64")), |
| ) |
| if i in _gdn_layer_set: |
| return GDNBlock( |
| d_model=config.d_model, |
| n_heads=config.n_heads, |
| ) |
| if Mamba3 is None: |
| return _LocalMamba3Fallback() |
| block = Mamba3( |
| d_model=config.d_model, |
| d_state=config.d_state, |
| expand=config.expand, |
| headdim=config.headdim, |
| is_mimo=False, |
| chunk_size=int(os.environ.get("HYDRA_MAMBA3_CHUNK", "64")), |
| is_outproj_norm=False, |
| dtype=torch.bfloat16, |
| ) |
| return block |
|
|
| self.blocks = nn.ModuleList([_build_block(i) for i in range(config.n_layer)]) |
|
|
| |
| self.sdr_semantic = SemanticFoldingSDR( |
| vocab_size=config.vocab_size, |
| n_bits=config.sdr_n_bits, |
| target_active=config.sdr_target_active, |
| delta_rank=config.sdr_delta_rank, |
| som_warmup_steps=config.sdr_som_warmup, |
| som_update_interval=config.sdr_som_interval, |
| ) |
|
|
| |
| self.htm = HTMLayer( |
| input_bits=config.sdr_n_bits, |
| n_columns=config.htm_n_columns, |
| cells_per_column=config.htm_cells_per_column, |
| batch_size=1, |
| seed=42, |
| learn=True, |
| reset_each_forward=True, |
| ) |
|
|
| |
| self.htm_proj = nn.Linear(config.htm_n_columns + 1, config.d_model, bias=False) |
|
|
| |
| self.engram = GPUEngram( |
| d_model=config.d_model, |
| n_columns=config.engram_n_columns, |
| max_ngram=3, |
| ) |
| self.reality_bridge = None |
| self.cantor = None |
| if os.environ.get("HYDRA_REALITY_BRIDGE", "0") == "1": |
| d_reality = int(os.environ.get("HYDRA_REALITY_D", "133")) |
| self.reality_bridge = RealityPoincareBridge( |
| d_model=config.d_model, |
| d_reality=d_reality, |
| l0_k=int(os.environ.get("HYDRA_REALITY_L0_K", "64")), |
| ) |
| if os.environ.get("HYDRA_CANTOR_DISABLE", "0") != "1": |
| self.cantor = CantorRouter( |
| depth=int(os.environ.get("HYDRA_CANTOR_DEPTH", "7")), |
| d_query=d_reality, |
| seed=int(os.environ.get("HYDRA_CANTOR_SEED", "42")), |
| device=self.wte.weight.device, |
| ) |
| self.engram_layer_idx = config.engram_layer_idx |
|
|
| |
| self.mhc = nn.ModuleList([ |
| ManifoldHyperConnection(d_model=config.d_model, n_streams=2, sinkhorn_iters=3) |
| for _ in range(config.n_layer) |
| ]) |
|
|
| |
| self.hestia = HestiaQAT(enabled=True, bits=1.58) |
|
|
| |
| self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) |
|
|
| |
| |
| |
| |
| |
| self._mtp_k = max(1, int(os.environ.get("HYDRA_MTP_K", "1"))) |
|
|
| |
| |
| |
| |
| self._unigram_sampler = None |
|
|
| |
| self._grad_ckpt = os.environ.get("HYDRA_GRAD_CKPT", "0") == "1" |
|
|
| |
| self._doc_sep_mask = os.environ.get("HYDRA_DOC_SEP_MASK", "0") == "1" |
| |
| |
| self._bos_token_id = -1 |
|
|
| |
| |
| self._htm_stop_grad = os.environ.get("HYDRA_HTM_STOP_GRAD", "0") == "1" |
|
|
| |
| self._entropy_penalty = float(os.environ.get("HYDRA_ENTROPY_PENALTY", "0.0")) |
|
|
| |
| |
| |
| |
| |
| |
| |
| self.drop = nn.Dropout(float(os.environ.get("HYDRA_DROPOUT", "0.05"))) |
|
|
| |
| |
| |
| self.softcap = float(os.environ.get("HYDRA_LOGIT_SOFTCAP", "30.0")) |
|
|
| |
| self._metrics = {} |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| self._diag_enabled = os.environ.get("HYDRA_LAYER_DIAGNOSTICS", "0") == "1" |
| _tripwire_ratio = float(os.environ.get("HYDRA_INERTNESS_TRIPWIRE_RATIO", "0.01")) |
| _tripwire_active = _tripwire_ratio > 0.0 |
| self._fwd_hooks_enabled = self._diag_enabled or _tripwire_active |
| self._diag_step = 0 |
| self._diag_svd_every = int(os.environ.get("HYDRA_LAYER_DIAG_SVD_EVERY", "100")) |
| if self._diag_enabled: |
| |
| for _i, _block in enumerate(self.blocks): |
| def _mk_grad_hook(_layer_idx): |
| def _hook(module, grad_input, grad_output): |
| if grad_output and grad_output[0] is not None: |
| g = grad_output[0].detach() |
| self._metrics[f'layer_{_layer_idx}_grad_norm'] = float( |
| g.pow(2).mean().sqrt().item() |
| ) |
| return _hook |
| _block.register_full_backward_hook(_mk_grad_hook(_i)) |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| if self._fwd_hooks_enabled: |
| for _i, _block in enumerate(self.blocks): |
| def _mk_fwd_hook(_layer_idx): |
| def _hook(module, inputs, output): |
| with torch.no_grad(): |
| inp = inputs[0].detach().float() if inputs else None |
| out = output.detach().float() if isinstance(output, torch.Tensor) else None |
| if out is not None: |
| out_rms = out.pow(2).mean().sqrt().item() |
| self._metrics[f'layer_{_layer_idx}_block_out_rms'] = float(out_rms) |
| if inp is not None: |
| in_rms = inp.pow(2).mean().sqrt().item() |
| self._metrics[f'layer_{_layer_idx}_block_in_rms'] = float(in_rms) |
| self._metrics[f'layer_{_layer_idx}_contrib_ratio'] = float( |
| out_rms / (in_rms + 1e-8) |
| ) |
| return _hook |
| _block.register_forward_hook(_mk_fwd_hook(_i)) |
|
|
| |
| self._fused_bcnorm = os.environ.get("HYDRA_FUSED_BCNORM", "0") == "1" |
| self._fused_ssd = os.environ.get("HYDRA_FUSED_SSD", "0") == "1" |
| if self._fused_bcnorm or self._fused_ssd: |
| import sys |
| _active = [] |
| if self._fused_bcnorm: |
| _active.append("HYDRA_FUSED_BCNORM") |
| if self._fused_ssd: |
| _active.append("HYDRA_FUSED_SSD") |
| print( |
| f"[HYDRA] Triton kernel gates set: {', '.join(_active)}. " |
| f"NOTE: Both are DEFERRED (mamba-ssm Mamba3 uses internal " |
| f"CUDA kernels). Gates accepted but currently no-ops.", |
| file=sys.stderr, |
| ) |
|
|
| |
| if os.environ.get("HYDRA_MODEL_COMPILE", "0") == "1": |
| self._forward_impl = torch.compile( |
| self._forward_impl, |
| fullgraph=False, |
| dynamic=True, |
| mode="default", |
| ) |
|
|
| @torch.no_grad() |
| def init_weights(self) -> None: |
| s = 3 ** 0.5 * self.config.d_model ** -0.5 |
|
|
| |
| |
| |
| device = self.wte.weight.device |
| if hasattr(self.sdr_semantic, '_retina_indices'): |
| self.sdr_semantic._retina_indices = self.sdr_semantic._retina_indices.to(device) |
|
|
| |
| |
| |
| import math as _math |
| _d_model = self.wte.weight.shape[1] |
| wte_std = float(os.environ.get("HYDRA_WTE_STD", str(1.0 / _math.sqrt(_d_model)))) |
| nn.init.normal_(self.wte.weight, mean=0.0, std=wte_std) |
| |
| |
| |
| |
| lm_head_std = float(os.environ.get("HYDRA_LM_HEAD_STD", "0.02")) |
| nn.init.normal_(self.lm_head.weight, mean=0.0, std=lm_head_std) |
| |
| |
| |
| |
| |
| |
| |
|
|
| for li, block in enumerate(self.blocks): |
| if hasattr(block, 'in_proj') and hasattr(block.in_proj, 'weight'): |
| nn.init.uniform_(block.in_proj.weight, -s, s) |
| if hasattr(block, 'out_proj') and hasattr(block.out_proj, 'weight'): |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| n_layer = self.config.n_layer |
| _legacy_divisor = os.environ.get("HYDRA_OUT_PROJ_DIVISOR", "0") == "1" |
| _base_std = 0.02 / (2 * n_layer) ** 0.5 if _legacy_divisor else 0.02 |
| out_std = float(os.environ.get("HYDRA_OUT_PROJ_STD", str(_base_std))) |
| nn.init.normal_(block.out_proj.weight, mean=0.0, std=out_std) |
|
|
| |
| |
| |
| |
| |
| |
| |
| if hasattr(block, 'dt_bias') and isinstance(block.dt_bias, nn.Parameter): |
| _dt_min = float(os.environ.get("HYDRA_DT_MIN", "0.001")) |
| _dt_max = float(os.environ.get("HYDRA_DT_MAX", "0.1")) |
| _log_dt_min = _math.log(max(_dt_min, 1e-6)) |
| _log_dt_max = _math.log(max(_dt_max, _dt_min * 1.0001)) |
| |
| |
| |
| with torch.no_grad(): |
| block.dt_bias.uniform_(_log_dt_min, _log_dt_max) |
|
|
| nn.init.normal_(self.htm_proj.weight, mean=0.0, std=s) |
|
|
| if hasattr(self.engram, "memory"): |
| nn.init.normal_(self.engram.memory, mean=0.0, std=0.01) |
| if hasattr(self.engram, "gate"): |
| nn.init.zeros_(self.engram.gate.weight) |
| nn.init.zeros_(self.engram.gate.bias) |
| if self.reality_bridge is not None: |
| nn.init.normal_(self.reality_bridge.to_reality.weight, mean=0.0, std=0.02) |
| nn.init.normal_(self.reality_bridge.to_tangent2.weight, mean=0.0, std=0.02) |
| if self.cantor is not None and hasattr(self.cantor, "branch"): |
| bound = (3.0 / float(self.cantor.d_query)) ** 0.5 |
| nn.init.uniform_(self.cantor.branch, -bound, bound) |
|
|
| |
| |
| self.wte.to(dtype=torch.bfloat16) |
| self.blocks.to(dtype=torch.bfloat16) |
| self.htm_proj.to(dtype=torch.bfloat16) |
| self.engram.to(dtype=torch.bfloat16) |
| if self.reality_bridge is not None: |
| self.reality_bridge.to(dtype=torch.bfloat16) |
| if self.cantor is not None: |
| self.cantor.to(dtype=torch.bfloat16) |
|
|
| def set_bos_token_id(self, bos_id: int) -> None: |
| """Inform the model of the tokenizer's BOS id so doc-separator |
| masking (learnability #4) knows which positions to skip. Called from |
| training setup once the tokenizer is loaded.""" |
| self._bos_token_id = int(bos_id) |
|
|
| def set_unigram_sampler(self, sampler) -> None: |
| """Attach a UnigramSampler used by the sampled-softmax LM-head loss. |
| |
| Audit 2026-05-09 issue #22 - Cluster E. With a sampler attached the |
| LM-head loss draws negatives from the empirical unigram distribution |
| and applies the per-id ``log p_unigram`` correction (Jean et al. |
| 2015, importance-sampled NCE). Without a sampler the legacy |
| uniform-negative path runs (uniform proposal + ``log(V/K)`` constant |
| correction) - kept as a fallback for environments that cannot build |
| the unigram cache (no tokenizer, CI, etc.). |
| |
| Eval always uses full softmax regardless of this setting. |
| """ |
| if sampler is not None and not isinstance(sampler, UnigramSampler): |
| raise TypeError( |
| f"set_unigram_sampler expects UnigramSampler or None, got {type(sampler)}" |
| ) |
| self._unigram_sampler = sampler |
|
|
| def invalidate_hyena_caches(self) -> None: |
| """Invalidate filter-rfft caches on all Hyena blocks. |
| |
| MUST be called after each `optimizer.step()` when |
| `HYDRA_HYENA_FILTER_CACHE=1` is set, otherwise cached rfft values |
| will be reused with stale filter parameters. |
| |
| No-op for blocks that are not HyenaBlock (Mamba3, etc.). |
| """ |
| for block in self.blocks: |
| if hasattr(block, "operator") and hasattr(block.operator, "invalidate_filter_cache"): |
| block.operator.invalidate_filter_cache() |
|
|
| def flush_hyena_pending_grads(self) -> None: |
| """Push pending train-cache filter gradients into filter params. |
| |
| Used ONLY when HYDRA_HYENA_TRAIN_CACHE=1. Must be called exactly once |
| per optimizer step, BEFORE `optimizer.step()` and BEFORE |
| `invalidate_hyena_caches()`. The lightning_module wires this in |
| `optimizer_step` around the existing optimizer.step() call. |
| |
| No-op if: |
| * No HyenaBlocks are in the model, OR |
| * No micro-batch ever ran with grad enabled (e.g. all-eval step). |
| """ |
| for block in self.blocks: |
| if hasattr(block, "operator") and hasattr(block.operator, "flush_pending_filter_grads"): |
| block.operator.flush_pending_filter_grads() |
|
|
| def estimate_flops(self) -> int: |
| nparams = sum(p.numel() for p in self.parameters()) |
| embed_params = self.wte.weight.numel() |
| return 6 * (nparams - embed_params) |
|
|
| def num_scaling_params(self) -> dict: |
| wte = sum(p.numel() for p in self.wte.parameters()) |
| lm_head = sum(p.numel() for p in self.lm_head.parameters()) |
| blocks = sum(p.numel() for p in self.blocks.parameters()) |
| sdr = sum(p.numel() for p in self.sdr_semantic.parameters()) |
| htm_proj = sum(p.numel() for p in self.htm_proj.parameters()) |
| engram = sum(p.numel() for p in self.engram.parameters()) |
| total = sum(p.numel() for p in self.parameters()) |
| return { |
| 'wte': wte, 'lm_head': lm_head, 'blocks': blocks, |
| 'sdr_semantic': sdr, 'htm_proj': htm_proj, |
| 'engram': engram, 'total': total, |
| } |
|
|
| def get_secondary_metrics(self) -> dict: |
| """Flush any lingering CUDA tensors to host (single sync).""" |
| flushed = {} |
| for k, v in self._metrics.items(): |
| if hasattr(v, 'item'): |
| try: |
| flushed[k] = float(v.item()) |
| except Exception: |
| flushed[k] = v |
| else: |
| flushed[k] = v |
| return flushed |
|
|
| def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.6, matrix_lr=0.04, |
| weight_decay=0.2, adam_betas=(0.8, 0.95), scalar_lr=0.5): |
| """Setup MuonAdamW optimizer with per-component LR groups.""" |
| model_dim = self.config.d_model |
|
|
| embedding_params = list(self.wte.parameters()) |
| lm_head_params = list(self.lm_head.parameters()) |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| MUON_MIN_DIM = 8 |
|
|
| def _muon_eligible(name: str, p: torch.Tensor) -> bool: |
| if p.dim() != 2: |
| return False |
| if name.endswith(".freq"): |
| return False |
| if min(p.shape) < MUON_MIN_DIM: |
| return False |
| return True |
|
|
| |
| matrix_params = [] |
| for name, p in self.blocks.named_parameters(): |
| if _muon_eligible(name, p): |
| matrix_params.append(p) |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| for name, p in self.htm_proj.named_parameters(): |
| if _muon_eligible(name, p): |
| matrix_params.append(p) |
| for name, p in self.engram.named_parameters(): |
| if _muon_eligible(name, p): |
| matrix_params.append(p) |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| _contrastive_enabled = os.environ.get("HYDRA_CONTRASTIVE_RETINA", "0") == "1" |
| _contrastive_lr = float(os.environ.get("HYDRA_CONTRASTIVE_LR", "1e-3")) |
| _sdr_lr = float(os.environ.get("HYDRA_SDR_LR", "1e-3")) |
|
|
| |
| |
| sdr_delta_ids = {id(self.sdr_semantic.delta_u), id(self.sdr_semantic.delta_v)} |
| sdr_contrastive_params = [] |
| if self.sdr_semantic.retina_contrastive is not None: |
| sdr_contrastive_params.append(self.sdr_semantic.retina_contrastive) |
| |
| sdr_delta_params = [self.sdr_semantic.delta_u, self.sdr_semantic.delta_v] |
| |
| |
| |
| sdr_binary_logits = [] |
| _binary_logits_param = getattr(self.sdr_semantic, 'retina_logits', None) |
| if isinstance(_binary_logits_param, torch.nn.Parameter): |
| sdr_binary_logits.append(_binary_logits_param) |
| sdr_delta_params.append(_binary_logits_param) |
|
|
| |
| sdr_param_ids = {id(p) for p in self.sdr_semantic.parameters()} |
|
|
| |
| |
| |
| |
| |
| |
| _dt_bias_lr = float(os.environ.get("HYDRA_DT_BIAS_LR", "1e-2")) |
| dt_bias_params = [] |
| for name, p in self.blocks.named_parameters(): |
| if name.endswith("dt_bias") or name == "dt_bias": |
| dt_bias_params.append(p) |
| dt_bias_ids = {id(p) for p in dt_bias_params} |
|
|
| assigned = set(id(p) for p in embedding_params + lm_head_params + matrix_params) |
| scalar_params = [ |
| p for p in self.parameters() |
| if id(p) not in assigned |
| and id(p) not in sdr_param_ids |
| and id(p) not in dt_bias_ids |
| ] |
|
|
| total_assigned = ( |
| len(embedding_params) + len(lm_head_params) + len(matrix_params) |
| + len(scalar_params) + len(sdr_delta_params) + len(sdr_contrastive_params) |
| + len(dt_bias_params) |
| ) |
| total_params = len(list(self.parameters())) |
| assert total_assigned == total_params, ( |
| f"Parameter count mismatch: assigned {total_assigned} vs total {total_params}" |
| ) |
|
|
| dmodel_lr_scale = (model_dim / 768) ** -0.5 |
| print(f"Scaling AdamW LRs by 1/sqrt({model_dim}/768) = {dmodel_lr_scale:.6f}") |
| if _contrastive_enabled: |
| print(f"[retina-d] Contrastive retina ENABLED. retina_contrastive LR={_contrastive_lr:.2e}") |
| else: |
| print("[retina-d] Contrastive retina DISABLED (set HYDRA_CONTRASTIVE_RETINA=1 to enable)") |
| print( |
| f"[sdr] delta_u/delta_v + binary retina_logits routed to SDR group " |
| f"(lr={_sdr_lr:.2e}) — audit issue #21." |
| ) |
| if dt_bias_params: |
| print( |
| f"[dt-bias] {len(dt_bias_params)} dt_bias parameter(s) routed to " |
| f"dedicated AdamW group (lr={_dt_bias_lr:.2e}, betas=(0.9, 0.95)) — " |
| f"audit issue #19." |
| ) |
| else: |
| print("[dt-bias] no dt_bias parameters detected (Hyena/GDN-only stack?)") |
|
|
| param_groups = [ |
| dict(kind='adamw', params=lm_head_params, |
| lr=unembedding_lr * dmodel_lr_scale, betas=adam_betas, |
| eps=1e-10, weight_decay=0.0), |
| dict(kind='adamw', params=embedding_params, |
| lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, |
| eps=1e-10, weight_decay=0.0), |
| ] |
| |
| |
| if sdr_contrastive_params: |
| |
| |
| param_groups.append( |
| dict(kind='retina_contrastive', params=sdr_contrastive_params, |
| lr=_contrastive_lr, betas=adam_betas, |
| eps=1e-10, weight_decay=0.01) |
| ) |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| param_groups.append( |
| dict(kind='adamw', params=sdr_delta_params, |
| lr=_sdr_lr, betas=adam_betas, |
| eps=1e-10, weight_decay=0.0) |
| ) |
|
|
| |
| |
| |
| |
| |
| if dt_bias_params: |
| param_groups.append(dict( |
| kind='dt_bias', params=dt_bias_params, |
| lr=_dt_bias_lr, betas=(0.9, 0.95), |
| eps=1e-8, weight_decay=0.0, |
| )) |
|
|
| if scalar_params: |
| param_groups.append( |
| dict(kind='adamw', params=scalar_params, |
| lr=scalar_lr * dmodel_lr_scale, betas=adam_betas, |
| eps=1e-10, weight_decay=0.0) |
| ) |
|
|
| for shape in sorted({p.shape for p in matrix_params}): |
| group_params = [p for p in matrix_params if p.shape == shape] |
| |
| |
| |
| _ns_steps = int(os.environ.get("HYDRA_MUON_NS_STEPS", "3")) |
| param_groups.append(dict( |
| kind='muon', params=group_params, lr=matrix_lr, |
| momentum=0.95, ns_steps=_ns_steps, beta2=0.95, weight_decay=weight_decay, |
| )) |
|
|
| optimizer = MuonAdamW(param_groups) |
| for group in optimizer.param_groups: |
| group["initial_lr"] = group["lr"] |
| return optimizer |
|
|
| def forward(self, idx, targets=None, reduction='mean'): |
| """idx: (B, T) int64. Returns loss if targets given, else logits. |
| |
| Nested bf16 autocast is a no-op when ambient autocast is already on; |
| when it's off (e.g. integration tests) we establish the dtype contract. |
| """ |
| if torch.is_autocast_enabled(): |
| return self._forward_impl(idx, targets=targets, reduction=reduction) |
| with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): |
| return self._forward_impl(idx, targets=targets, reduction=reduction) |
|
|
| def _forward_impl(self, idx, targets=None, reduction='mean'): |
| B, T = idx.shape |
|
|
| |
| |
| |
| _profile = os.environ.get("HYDRA_PROFILE_FORWARD", "0") == "1" |
| if _profile: |
| def _ev(): |
| e = torch.cuda.Event(enable_timing=True) |
| e.record() |
| return e |
| _t0 = _ev() |
| else: |
| _t0 = None |
|
|
| |
| sdr_binary = self.sdr_semantic.binary_only(idx) |
| self._last_sdr = sdr_binary |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| _htm_sub = int(os.environ.get("HYDRA_HTM_SUBSAMPLE", "8")) |
| if not hasattr(self, '_htm_call_idx'): |
| self._htm_call_idx = int(os.environ.get("HYDRA_HTM_INITIAL_OFFSET", "0")) |
|
|
| _run_htm = (self._htm_call_idx % _htm_sub == 0) |
| self._htm_call_idx += 1 |
|
|
| if _run_htm: |
| htm_handle = self.htm.forward_async(sdr_binary, output_dtype=self.wte.weight.dtype) |
| else: |
| htm_handle = None |
|
|
| if _profile: _t_htm_async = _ev() |
|
|
| dense_emb = self.wte(idx) |
| dense_emb = semantic_gaussian_mollify( |
| dense_emb, |
| std=float(os.environ.get("HYDRA_SEMANTIC_SMOOTH_STD", "0.0")), |
| training=self.training, |
| eval_enabled=os.environ.get("HYDRA_SEMANTIC_SMOOTH_EVAL", "0") == "1", |
| ) |
|
|
| if _profile: _t_wte = _ev() |
|
|
| |
| |
| |
| |
| |
| |
| |
| if _run_htm: |
| htm_out = self.htm.forward_await(htm_handle) |
| self._htm_cache = htm_out.detach() |
| self._htm_cache_key = htm_cache_key(sdr_binary.nonzero()) |
| self._htm_cache_shape = (B, T) |
| elif ( |
| |
| |
| |
| |
| |
| |
| self.training |
| and not self._mdlm_active |
| and os.environ.get("HYDRA_HTM_CACHE_MODE", "exact").lower() == "shape" |
| and hasattr(self, '_htm_cache') and self._htm_cache is not None |
| and getattr(self, '_htm_cache_shape', None) == (B, T) |
| ): |
| htm_out = self._htm_cache |
| elif ( |
| hasattr(self, '_htm_cache') and self._htm_cache is not None |
| and self._htm_cache.shape[0] == B and self._htm_cache.shape[1] == T |
| and not self._mdlm_active |
| and htm_cache_matches(self._htm_cache_key, sdr_binary.nonzero()) |
| ): |
| htm_out = self._htm_cache |
| elif ( |
| os.environ.get("HYDRA_HTM_ZERO_CACHE_ON_MISS", "0") == "1" |
| and self.training |
| and not self._mdlm_active |
| ): |
| htm_out = torch.zeros((B, T, self.config.htm_n_columns + 1), device=dense_emb.device, dtype=dense_emb.dtype) |
| self._htm_cache = htm_out.detach() |
| self._htm_cache_key = None |
| self._htm_cache_shape = (B, T) |
| else: |
| |
| |
| htm_handle = self.htm.forward_async(sdr_binary, output_dtype=self.wte.weight.dtype) |
| htm_out = self.htm.forward_await(htm_handle) |
| self._htm_cache = htm_out.detach() |
| self._htm_cache_key = htm_cache_key(sdr_binary.nonzero()) |
| self._htm_cache_shape = (B, T) |
|
|
| if _profile: _t_htm_await = _ev() |
| with torch.no_grad(): |
| sdr_active_bits = float(self.sdr_semantic.target_active) |
| htm_anomaly = htm_out[..., -1].mean() |
|
|
| |
| |
| |
| |
| if self._htm_stop_grad: |
| htm_out = htm_out.detach() |
|
|
| |
| htm_proj_out = self.htm_proj(htm_out.to(dense_emb.dtype)) |
| x = dense_emb + htm_proj_out |
| x = norm(x) |
|
|
| if _profile: _t_htm_proj = _ev() |
|
|
| |
| streams = self.mhc[0].init_streams(x) |
| _engram_ev = None |
|
|
| |
| |
| |
| _diag = self._diag_enabled |
| if _diag: |
| |
| |
| |
| |
| |
| |
| with torch.no_grad(): |
| h_pre = self.mhc[0].merge_streams(streams).detach().float() |
| _run_svd = (self._diag_step % self._diag_svd_every) == 0 |
|
|
| for i, (block, mhc_layer) in enumerate(zip(self.blocks, self.mhc)): |
| def _block_fn(h, _block=block): |
| return self.drop(_block(norm(h))) |
|
|
| |
| |
| |
| |
| if self._grad_ckpt and self.training: |
| import torch.utils.checkpoint as _ckpt |
| _raw_fn = _block_fn |
| def _block_fn(h, _raw=_raw_fn): |
| return _ckpt.checkpoint(_raw, h, use_reentrant=False) |
|
|
| streams = mhc_layer(streams, _block_fn) |
|
|
| if i == self.engram_layer_idx: |
| if _profile: _t_pre_engram = _ev() |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| x_mid = mhc_layer.merge_streams(streams) |
| if self.reality_bridge is not None and self.cantor is not None: |
| rb = self.reality_bridge(x_mid) |
| cantor_leaf_ids, _ = self.cantor(rb.reality, return_scores=False) |
| x_after_engram, hit_rate = self.engram( |
| x_mid, |
| idx, |
| sdr_active_indices=rb.l0_indices, |
| cantor_leaf_ids=cantor_leaf_ids, |
| cantor_n_leaves=self.cantor.n_leaves, |
| ) |
| else: |
| x_after_engram, hit_rate = self.engram(x_mid, idx) |
| if os.environ.get("HYDRA_ENGRAM_RESET_STREAMS", "0") == "1": |
| streams = mhc_layer.init_streams(x_after_engram) |
| else: |
| engram_delta = x_after_engram - x_mid |
| streams = streams.clone() |
| streams[0] = streams[0] + engram_delta |
| self._metrics['engram_hit_rate'] = hit_rate |
| if _diag: |
| with torch.no_grad(): |
| s0 = streams[0].detach().float() |
| s1 = streams[1].detach().float() if streams.shape[0] > 1 else None |
| self._metrics['engram_stream0_rms'] = float(s0.pow(2).mean().sqrt().item()) |
| if s1 is not None: |
| self._metrics['engram_stream1_rms'] = float(s1.pow(2).mean().sqrt().item()) |
| self._metrics['engram_stream_divergence_rms'] = float( |
| (s0 - s1).pow(2).mean().sqrt().item() |
| ) |
| if _profile: _engram_ev = _ev() |
|
|
| if _diag: |
| with torch.no_grad(): |
| h_post = mhc_layer.merge_streams(streams).detach().float() |
| in_n = h_pre.pow(2).mean().sqrt() |
| out_n = h_post.pow(2).mean().sqrt() |
| d_n = (h_post - h_pre).pow(2).mean().sqrt() |
| self._metrics[f'layer_{i}_in_norm'] = float(in_n.item()) |
| self._metrics[f'layer_{i}_out_norm'] = float(out_n.item()) |
| self._metrics[f'layer_{i}_delta_ratio'] = float((d_n / (in_n + 1e-6)).item()) |
| self._metrics[f'layer_{i}_feat_std'] = float(h_post.std(dim=-1).mean().item()) |
| if _run_svd: |
| |
| |
| |
| flat = h_post.reshape(-1, h_post.shape[-1])[:512].float() |
| try: |
| s = torch.linalg.svdvals(flat) |
| eff_rank = float(((s.sum() ** 2) / (s.pow(2).sum() + 1e-6)).item()) |
| self._metrics[f'layer_{i}_eff_rank'] = eff_rank |
| except Exception: |
| pass |
| h_pre = h_post |
|
|
| if _diag: |
| self._diag_step += 1 |
|
|
| if _profile: _t_blocks = _ev() |
|
|
| self._metrics['sdr_active_bits'] = sdr_active_bits |
| self._metrics['htm_anomaly'] = htm_anomaly |
|
|
| x = self.mhc[-1].merge_streams(streams) |
| x = norm(x) |
|
|
| if _profile: _t_merge = _ev() |
|
|
| softcap = self.softcap |
| _softcap_clamp = os.environ.get("HYDRA_SOFTCAP_CLAMP", "0") == "1" |
| if targets is not None: |
| smoothing = self.config.label_smoothing |
| V = self.config.vocab_size |
|
|
| |
| |
| |
| |
| |
| |
| |
| if self._doc_sep_mask and self._bos_token_id >= 0: |
| targets = torch.where( |
| targets == self._bos_token_id, |
| torch.full_like(targets, -1), |
| targets, |
| ) |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| K_neg = int(os.environ.get("HYDRA_SAMPLED_SOFTMAX", "4096")) |
| use_sampled = K_neg > 0 and K_neg < V and self.training |
| unigram_sampler = getattr(self, "_unigram_sampler", None) |
|
|
| |
| |
| |
| neg_logits = None |
| log_correction = None |
|
|
| if use_sampled and unigram_sampler is not None: |
| |
| h_flat = x.reshape(-1, x.shape[-1]) |
| t_flat = targets.reshape(-1) |
| valid_mask_flat = (t_flat >= 0) |
|
|
| if reduction == 'none': |
| per_tok = sampled_softmax_loss( |
| h_flat, t_flat, self.lm_head.weight, unigram_sampler, K_neg, |
| label_smoothing=0.0, softcap=softcap, |
| softcap_clamp=_softcap_clamp, |
| valid_mask=valid_mask_flat, reduction='none', |
| ) |
| return per_tok |
|
|
| out = sampled_softmax_loss( |
| h_flat, t_flat, self.lm_head.weight, unigram_sampler, K_neg, |
| label_smoothing=smoothing, softcap=softcap, |
| softcap_clamp=_softcap_clamp, |
| valid_mask=valid_mask_flat, reduction='mean', |
| ) |
| elif use_sampled: |
| |
| |
| h_flat = x.reshape(-1, x.shape[-1]) |
| t_flat = targets.reshape(-1) |
| n = h_flat.shape[0] |
|
|
| |
| |
| |
| valid_mask_flat = (t_flat >= 0) |
| t_flat_safe = torch.where(valid_mask_flat, t_flat, torch.zeros_like(t_flat)) |
|
|
| |
| neg_ids = torch.randint(0, V, (K_neg,), device=x.device) |
| |
| all_ids = torch.cat([t_flat_safe, neg_ids]) |
| sampled_w = self.lm_head.weight[all_ids] |
|
|
| |
| |
| |
| target_w = sampled_w[:n] |
| neg_w = sampled_w[n:] |
| target_logit = (h_flat * target_w).sum(-1) |
| neg_logits = h_flat @ neg_w.t() |
|
|
| if not _softcap_clamp: |
| target_logit = softcap * torch.tanh(target_logit / softcap) |
| neg_logits = softcap * torch.tanh(neg_logits / softcap) |
|
|
| |
| |
| |
| |
| log_correction = torch.tensor(V / K_neg, device=x.device).log() |
| all_logits = torch.cat([ |
| target_logit.unsqueeze(-1), |
| neg_logits + log_correction, |
| ], dim=-1).float() |
|
|
| |
| ce_targets = torch.zeros(n, dtype=torch.long, device=x.device) |
| if reduction == 'none': |
| per_tok = F.cross_entropy(all_logits, ce_targets, reduction='none') |
| if self._doc_sep_mask and self._bos_token_id >= 0: |
| per_tok = torch.where(valid_mask_flat, per_tok, torch.zeros_like(per_tok)) |
| return per_tok |
| per_tok_ce = F.cross_entropy( |
| all_logits, ce_targets, reduction='none', |
| label_smoothing=smoothing, |
| ) |
| |
| |
| |
| valid_f = valid_mask_flat.float() |
| valid_n = valid_f.sum().clamp(min=1) |
| out = (per_tok_ce * valid_f).sum() / valid_n |
| else: |
| |
| chunk_size = int(os.environ.get("HYDRA_CE_CHUNK", "1024")) |
| if chunk_size <= 0: |
| MAX_LOGITS_BYTES = 256 * 1024 * 1024 |
| tokens_per_chunk = max(V, MAX_LOGITS_BYTES // (V * 4)) |
| chunk_size = max(1, tokens_per_chunk // max(1, B)) |
| chunk_size = min(chunk_size, T) |
|
|
| if reduction == 'none': |
| loss_parts = [] |
| for start in range(0, T, chunk_size): |
| end = min(start + chunk_size, T) |
| chunk_logits = self.lm_head(x[:, start:end, :]).float() |
| if _softcap_clamp: |
| chunk_logits = torch.clamp(chunk_logits, -softcap, softcap) |
| else: |
| chunk_logits = softcap * torch.tanh(chunk_logits / softcap) |
| chunk_targets = targets[:, start:end].reshape(-1) |
| chunk_loss = F.cross_entropy( |
| chunk_logits.view(-1, chunk_logits.size(-1)), |
| chunk_targets, ignore_index=-1, reduction='none', |
| ) |
| loss_parts.append(chunk_loss) |
| return torch.cat(loss_parts) |
|
|
| total_loss = 0.0 |
| total_tokens = 0 |
| for start in range(0, T, chunk_size): |
| end = min(start + chunk_size, T) |
| chunk_logits = self.lm_head(x[:, start:end, :]).float() |
| if _softcap_clamp: |
| chunk_logits = torch.clamp(chunk_logits, -softcap, softcap) |
| else: |
| chunk_logits = softcap * torch.tanh(chunk_logits / softcap) |
| chunk_targets = targets[:, start:end].reshape(-1) |
| chunk_loss = F.cross_entropy( |
| chunk_logits.view(-1, chunk_logits.size(-1)), |
| chunk_targets, ignore_index=-1, reduction='sum', |
| label_smoothing=smoothing, |
| ) |
| total_loss = total_loss + chunk_loss |
| total_tokens += (chunk_targets != -1).sum() |
| out = total_loss / total_tokens |
|
|
| |
| |
| |
| |
| |
| |
| |
| if reduction == 'mean' and self._mtp_k > 1 and self.training and use_sampled: |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| mtp_loss_sum = out.new_tensor(0.0) |
| mtp_terms = 0 |
| neg_logits_bt = None |
| if unigram_sampler is None and neg_logits is not None: |
| neg_logits_bt = neg_logits.view(B, T, K_neg) |
|
|
| for k in range(2, self._mtp_k + 1): |
| shift = k - 1 |
| if T <= shift: |
| continue |
| n_k = B * (T - shift) |
| h_k_flat = x[:, :T - shift, :].reshape(n_k, -1) |
| t_k = targets[:, shift:].reshape(-1) |
| mask_k = (t_k >= 0) |
|
|
| if unigram_sampler is not None: |
| |
| per_tok_ce_k = sampled_softmax_loss( |
| h_k_flat, t_k, self.lm_head.weight, |
| unigram_sampler, K_neg, |
| label_smoothing=smoothing, softcap=softcap, |
| softcap_clamp=_softcap_clamp, |
| valid_mask=mask_k, reduction='none', |
| ) |
| |
| n_valid_k = mask_k.float().sum().clamp(min=1) |
| mtp_loss_sum = mtp_loss_sum + per_tok_ce_k.sum() / n_valid_k |
| else: |
| |
| |
| |
| t_k_safe = torch.where(mask_k, t_k, torch.zeros_like(t_k)) |
| tgt_w_k = self.lm_head.weight[t_k_safe] |
| tgt_logit_k = (h_k_flat * tgt_w_k).sum(-1) |
| if not _softcap_clamp: |
| tgt_logit_k = softcap * torch.tanh(tgt_logit_k / softcap) |
| neg_logits_k = neg_logits_bt[:, :T - shift, :].reshape(n_k, K_neg) |
| all_logits_k = torch.cat([ |
| tgt_logit_k.unsqueeze(-1), |
| neg_logits_k + log_correction, |
| ], dim=-1).float() |
| ce_targets_k = torch.zeros(n_k, dtype=torch.long, device=x.device) |
| per_tok_ce_k = F.cross_entropy( |
| all_logits_k, ce_targets_k, reduction='none', |
| label_smoothing=smoothing, |
| ) |
| per_tok_ce_k = torch.where(mask_k, per_tok_ce_k, torch.zeros_like(per_tok_ce_k)) |
| n_valid_k = mask_k.sum().clamp(min=1) |
| mtp_loss_sum = mtp_loss_sum + per_tok_ce_k.sum() / n_valid_k |
| mtp_terms += 1 |
| if mtp_terms > 0: |
| out = (out + mtp_loss_sum) / float(mtp_terms + 1) |
|
|
| |
| |
| |
| |
| |
| |
| |
| if reduction == 'mean' and self._entropy_penalty > 0.0 and self.training: |
| |
| |
| |
| h_flat = x.reshape(-1, x.shape[-1]) |
| n_pos = h_flat.shape[0] |
| n_sample = min(64, n_pos) |
| idx_sample = torch.randint(0, n_pos, (n_sample,), device=x.device) |
| h_sample = h_flat[idx_sample] |
| logits_s = F.linear(h_sample, self.lm_head.weight).float() |
| if _softcap_clamp: |
| logits_s = torch.clamp(logits_s, -softcap, softcap) |
| else: |
| logits_s = softcap * torch.tanh(logits_s / softcap) |
| log_probs = F.log_softmax(logits_s, dim=-1) |
| probs = log_probs.exp() |
| entropy = -(probs * log_probs).sum(-1).mean() |
| out = out - self._entropy_penalty * entropy |
|
|
| if _profile: |
| _t_end = _ev() |
| torch.cuda.synchronize() |
| def _ms(a, b): return a.elapsed_time(b) |
| print( |
| f"[PROFILE B={B} T={T}] " |
| f"htm_launch={_ms(_t0, _t_htm_async):.2f} " |
| f"wte={_ms(_t_htm_async, _t_wte):.2f} " |
| f"htm_await={_ms(_t_wte, _t_htm_await):.2f} " |
| f"htm_proj={_ms(_t_htm_await, _t_htm_proj):.2f} " |
| f"mamba_mhc_engram={_ms(_t_htm_proj, _t_blocks):.2f} " |
| f"merge={_ms(_t_blocks, _t_merge):.2f} " |
| f"lm_head_loss={_ms(_t_merge, _t_end):.2f} " |
| f"total={_ms(_t0, _t_end):.2f} ms", |
| flush=True, |
| ) |
| return out |
|
|
| logits = self.lm_head(x).float() |
| if _softcap_clamp: |
| logits = torch.clamp(logits, -softcap, softcap) |
| else: |
| logits = softcap * torch.tanh(logits / softcap) |
| return logits |
|
|