"""LightningModule wrapping PostSemClawModel. Thin adapter. The model and the MuonAdamW optimizer are unchanged. This module implements: • configure_optimizers — returns the existing MuonAdamW (subclass of torch.optim.Optimizer) built by model.setup_optimizer. Lightning accepts this directly. • training_step — splits (B, T+1) batches into (x, y), forwards through the model, logs loss / bpb / tps / mfu / vram. Preserves the sampled-softmax path inside PostSemClawModel (no changes there). • optimizer_step — before each step we update LR + muon momentum + WD using the same time-progress schedule as hydra/training.py (get_lr_multiplier / get_muon_momentum / get_weight_decay). Lightning handles grad accumulation via Trainer(accumulate_grad_batches=N). The SDR SOM update and Hestia QAT snap are called at the same cadence as the legacy loop, but inline on the main thread (Lightning provides its own callbacks for async work if we need to extract them later — keeping it simple for now). Env vars respected: HYDRA_TIME_BUDGET — wall-clock budget (s) used for LR schedule and as Trainer max_time HYDRA_HESTIA_INTERVAL — steps between Hestia snaps (default 100) HYDRA_BATCH_SIZE — device batch size (for throughput calc) HYDRA_SEQ_LEN — sequence length (for throughput calc) """ from __future__ import annotations import math import os import time import torch import lightning as L from hydra.config import ( ADAM_BETAS, EMBEDDING_LR, FINAL_LR_FRAC, GPU_BF16_PEAK_FLOPS, MATRIX_LR, SCALAR_LR, UNEMBEDDING_LR, WARMUP_RATIO, WEIGHT_DECAY, PostSemClawConfig, ) from hydra.model import PostSemClawModel # --------------------------------------------------------------------------- # LR / momentum / wd schedules — verbatim copy of hydra/training.py so the # curves match exactly. Kept here to avoid import cycles. 
# ---------------------------------------------------------------------------


def _lr_multiplier(progress: float) -> float:
    """Return the LR scale factor for a training progress in [0, 1].

    Linear warm-up from 0 to 1 while ``progress < WARMUP_RATIO``, then a
    cosine decay from 1.0 down to ``FINAL_LR_FRAC`` at ``progress == 1``.
    """
    if progress < WARMUP_RATIO:
        return progress / WARMUP_RATIO if WARMUP_RATIO > 0 else 1.0
    decay_progress = (progress - WARMUP_RATIO) / max(1.0 - WARMUP_RATIO, 1e-9)
    return FINAL_LR_FRAC + 0.5 * (1.0 - FINAL_LR_FRAC) * (
        1 + math.cos(math.pi * decay_progress)
    )


def _muon_momentum(step: int) -> float:
    """Return the Muon momentum: linear ramp 0.85 -> 0.95 over 300 steps."""
    frac = min(step / 300.0, 1.0)
    return (1 - frac) * 0.85 + frac * 0.95


def _weight_decay(progress: float) -> float:
    """Return the weight decay, linearly annealed to 0 at ``progress == 1``."""
    return WEIGHT_DECAY * (1 - progress)


# ---------------------------------------------------------------------------


class HydraLightningModule(L.LightningModule):
    """Lightning wrapper. Public attrs: self.model, self.config."""

    def __init__(self, config: PostSemClawConfig):
        """Build the wrapped model and initialise schedule/logging state.

        Reads ``HYDRA_TIME_BUDGET`` (default "300") and
        ``HYDRA_HESTIA_INTERVAL`` (default "100") from the environment.
        """
        super().__init__()
        self.config = config
        self.model = PostSemClawModel(config)
        # Model weights init must be deferred to the correct device; done by
        # caller after construction (to match the meta-device + to_empty()
        # pattern used in the legacy loop).

        # Time-based progress tracks the legacy loop's semantics: LR cosine
        # is driven by wall-clock, not step count. We capture training start
        # in on_train_start and TIME_BUDGET from env.
        self.time_budget = float(
            int(os.environ.get("HYDRA_TIME_BUDGET", "300"))
        )
        self._train_start_time: float | None = None
        self._total_training_time = 0.0
        self._last_step_end: float | None = None
        self._hestia_interval = int(os.environ.get("HYDRA_HESTIA_INTERVAL", "100"))
        self._flops_per_token = 0
        self._tokens_per_step = 0

        # Smoothed loss for the header-line log (matches legacy format).
        self._ema_beta = 0.9
        self._smooth_loss = 0.0
        # Number of updates folded into _smooth_loss. Used for de-biasing;
        # global_step is not a valid proxy under grad accumulation, after
        # skipped non-finite steps, or after a checkpoint resume.
        self._ema_updates = 0
        self._bpt_ema = 0.0
        self._token_bytes: torch.Tensor | None = None
        # Input tokens of the most recent training_step, consumed by the SOM
        # update in optimizer_step.
        self._last_x: torch.Tensor | None = None

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------
    def on_train_start(self) -> None:
        """Start the wall clock and cache throughput / bpb constants."""
        self._train_start_time = time.time()
        self._last_step_end = self._train_start_time
        self._flops_per_token = self.model.estimate_flops()
        # Tokens processed per optimizer step (pre-accum).
        B = int(os.environ.get("HYDRA_BATCH_SIZE", "1"))
        T = int(os.environ.get("HYDRA_SEQ_LEN", "512"))
        self._tokens_per_step = B * T
        # Build/cache token_bytes LUT (for bits-per-byte live metric).
        import prepare as _p

        self._token_bytes = _p.get_token_bytes(device=self.device)

    def configure_optimizers(self):
        """Return the optimizer built by ``model.setup_optimizer``."""
        optimizer = self.model.setup_optimizer(
            unembedding_lr=UNEMBEDDING_LR,
            embedding_lr=EMBEDDING_LR,
            scalar_lr=SCALAR_LR,
            adam_betas=ADAM_BETAS,
            matrix_lr=MATRIX_LR,
            weight_decay=WEIGHT_DECAY,
        )
        return optimizer

    # ------------------------------------------------------------------
    # Training step. Lightning auto-handles: autocast (via precision flag
    # on Trainer), backward, grad-accum, zero_grad. We only:
    #   - split batch into (x, y)
    #   - forward through model (autocast is established by Trainer)
    #   - return loss (grads flow from return)
    # ------------------------------------------------------------------
    def training_step(self, batch: torch.Tensor, batch_idx: int):
        """Run one forward pass on a 2-D batch and return the loss.

        Raises:
            RuntimeError: if ``batch`` is not 2-dimensional.
        """
        # DataLoader produces (B, T+1) rows; split into input/target.
        # Lightning's default collate already moved batch to self.device via
        # the accelerator callback when pin_memory=True and device != cpu.
        if batch.dim() != 2:
            raise RuntimeError(f"Expected (B, T+1) batch, got shape {tuple(batch.shape)}")
        x = batch[:, :-1].contiguous()
        y = batch[:, 1:].contiguous()
        loss = self.model(x, y)
        # Keep the inputs that produced the model's stashed _last_sdr so that
        # optimizer_step can hand the (x, sdr) pair to the SOM update.
        self._last_x = x
        # Lightning applies the grad-accum divisor automatically; we just
        # return the raw loss. loss.detach() is stored for logging.
        self._log_step(loss.detach(), y)
        return loss

    # ------------------------------------------------------------------
    # Optimizer step hook: update LR / momentum / WD using time-progress.
    # Runs once per optimizer step (after all accum micro-batches).
    # ------------------------------------------------------------------
    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
        """Apply schedules, run the optimizer step, then post-step hooks.

        Order of operations:
          1. Set lr / muon momentum / weight decay on the param groups.
          2. ``optimizer.step`` with a wrapped closure that runs the original
             closure, flushes pending Hyena grads (if the model supports it)
             and clips the gradient norm to 1.0.
          3. Zero grads, invalidate Hyena caches, anneal / snap Hestia,
             run the SOM update, advance the wall-clock counter.
        """
        # Update schedules from wall-clock progress.
        now = time.time()
        if self._train_start_time is None:
            self._train_start_time = now
            self._last_step_end = now

        progress = min(self._total_training_time / max(self.time_budget, 1.0), 1.0)
        step = self.global_step
        lrm = _lr_multiplier(progress)
        mom = _muon_momentum(step)
        wd = _weight_decay(progress)
        for group in optimizer.param_groups:
            # NOTE(review): assumes setup_optimizer stored "initial_lr" in
            # every param group — confirm against MuonAdamW.
            group["lr"] = group["initial_lr"] * lrm
            if group.get("kind") == "muon":
                group["momentum"] = mom
                group["weight_decay"] = wd

        # Lightning composes the step so that optimizer_closure (forward +
        # backward of the final micro-batch) runs INSIDE optimizer.step().
        # Anything that must see the complete gradients therefore has to run
        # inside the closure, after the original closure returns:
        #
        #   1. optimizer_closure() — forward + backward. Each Hyena
        #      micro-batch backward accumulates into _k_leaf.grad.
        #   2. flush_hyena_pending_grads() — one-shot
        #      torch.autograd.backward(_k_graph, _k_leaf.grad) per HyenaFilter.
        #      Now filter MLP / pos_emb / bias params have their correct grads.
        #      No-op when HYDRA_HYENA_TRAIN_CACHE=0 or no Hyena blocks exist.
        #   3. Grad clip (matches legacy loop, max_norm=1.0). Clipping before
        #      optimizer.step() would act on missing or partial gradients, so
        #      it is done here, after the flush, on the final gradients.
        flush_pending = getattr(self.model, "flush_hyena_pending_grads", None)

        def _closure_with_grad_fixups():
            result = optimizer_closure()
            if flush_pending is not None:
                flush_pending()
            torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
            return result

        # Run the step (this is what Lightning would have done for us).
        optimizer.step(closure=_closure_with_grad_fixups)
        self.model.zero_grad(set_to_none=True)

        # Hyena filter-rfft cache invalidation. No-op if:
        #   (a) no Hyena layers are in the model, or
        #   (b) HYDRA_HYENA_FILTER_CACHE=0 and HYDRA_HYENA_TRAIN_CACHE=0
        #       (the operators never populated either cache)
        # In either case this is a handful of Python attribute resets.
        if hasattr(self.model, "invalidate_hyena_caches"):
            self.model.invalidate_hyena_caches()

        # Hestia QAT snap every N steps. Temperature anneals every step.
        # NOTE(review): this progress is raw wall-clock since train start and
        # is not clamped to 1.0, unlike `progress` above — confirm intended.
        progress_now = (now - self._train_start_time) / max(self.time_budget, 1.0)
        self.model.hestia.anneal_temperature(progress_now)
        if self._hestia_interval > 0 and step % self._hestia_interval == 0:
            self.model.hestia.apply_to(self.model)

        # SDR SOM update when the model stashed an sdr in the last forward.
        # The legacy loop passed (x, _last_sdr); x comes from the buffer that
        # training_step fills on every call.
        _last_sdr = getattr(self.model, "_last_sdr", None)
        if _last_sdr is not None and hasattr(self.model.sdr_semantic, "maybe_som_update"):
            if getattr(self, "_last_x", None) is not None:
                self.model.sdr_semantic.maybe_som_update(self._last_x, _last_sdr)

        # Advance the wall-clock counter for LR schedule (matches legacy
        # behavior which incremented only after the first warm-up step).
        dt = now - (self._last_step_end or now)
        self._last_step_end = now
        if step > 10:
            self._total_training_time += dt

    # ------------------------------------------------------------------
    # Logging — mirrors the step=NNNNN line format of the legacy loop so
    # grep/tee pipelines keep working.
    # ------------------------------------------------------------------
    def _log_step(self, loss: torch.Tensor, y: torch.Tensor) -> None:
        """Log loss / bpb / bpt / tps / mfu / vram and print the legacy line.

        A non-finite loss, or one above 100, only logs ``train_loss_nan`` and
        leaves the running averages untouched.
        """
        loss_f = float(loss.item())
        if not math.isfinite(loss_f) or loss_f > 100:
            # Let Lightning raise / the trainer callbacks handle this.
            self.log("train_loss_nan", 1.0)
            return

        step = self.global_step
        self._smooth_loss = (
            self._ema_beta * self._smooth_loss + (1 - self._ema_beta) * loss_f
        )
        # De-bias with the number of EMA updates actually performed.
        self._ema_updates += 1
        debiased = self._smooth_loss / max(
            1 - self._ema_beta ** self._ema_updates, 1e-9
        )

        # dt is clamped to >= 1e-6, so the divisions below are always safe.
        dt = max(time.time() - (self._last_step_end or time.time()), 1e-6)
        tps = int(self._tokens_per_step / dt)
        mfu = (
            100.0 * self._flops_per_token * self._tokens_per_step / dt
            / GPU_BF16_PEAK_FLOPS
        )

        # bpb live: y flat -> token_bytes LUT -> avg bytes/token
        bpt = debiased / math.log(2)
        if self._token_bytes is not None:
            with torch.no_grad():
                y_flat = y.reshape(-1)
                nbytes = self._token_bytes[y_flat]
                # Tokens with a zero byte count are excluded from the average.
                mask = nbytes > 0
                denom = mask.sum().clamp(min=1).float()
                avg_bpt = (nbytes.float() * mask.float()).sum() / denom
                bpt_batch = float(avg_bpt.item())
            if step == 0 or self._bpt_ema <= 0.0:
                self._bpt_ema = bpt_batch
            else:
                self._bpt_ema = 0.98 * self._bpt_ema + 0.02 * bpt_batch
        bpb = bpt / max(self._bpt_ema, 1e-6)

        vram = (
            torch.cuda.memory_allocated() / 1024 / 1024
            if torch.cuda.is_available()
            else 0.0
        )

        self.log_dict(
            {
                "train/loss": debiased,
                "train/bpb": bpb,
                "train/bpt": bpt,
                "train/tps": float(tps),
                "train/mfu": float(mfu),
                "train/vram_mib": float(vram),
            },
            prog_bar=False,
            on_step=True,
            on_epoch=False,
        )

        # Match legacy one-line format: "step=NNNNN loss=x bpb=y tps=z ..."
        print(
            f"step={step:05d} loss={debiased:.4f} bpb={bpb:.4f} "
            f"bpt={bpt:.3f} bpt_div={self._bpt_ema:.2f} "
            f"tps={tps} dt_ms={dt*1000:.0f} mfu={mfu:.1f} "
            f"vram={vram:.0f}MiB",
            flush=True,
        )