| """LightningModule wrapping PostSemClawModel. |
| |
| Thin adapter. The model and the MuonAdamW optimizer are unchanged. This |
| module implements: |
| |
| • configure_optimizers — returns the existing MuonAdamW (subclass of |
| torch.optim.Optimizer) built by model.setup_optimizer. Lightning accepts |
| this directly. |
| • training_step — splits (B, T+1) batches into (x, y), forwards through |
| the model, logs loss / bpb / tps / mfu / vram. Preserves the |
| sampled-softmax path inside PostSemClawModel (no changes there). |
| • optimizer_step — before each step we update LR + muon momentum + WD |
| using the same time-progress schedule as hydra/training.py |
| (get_lr_multiplier / get_muon_momentum / get_weight_decay). Lightning |
| handles grad accumulation via Trainer(accumulate_grad_batches=N). |
| |
The SDR SOM update and the Hestia QAT snap run inline on the main thread
(not asynchronously) and are intended to match the legacy loop's cadence.
Lightning callbacks can host async work if these need to be extracted
later; for now this is kept simple. Note that the SOM update only fires
when ``_last_x`` is set, and this module itself never sets it to a
non-None value.
| |
| Env vars respected: |
| HYDRA_TIME_BUDGET — wall-clock budget (s) used for LR schedule |
| and as Trainer max_time |
| HYDRA_HESTIA_INTERVAL — steps between Hestia snaps (default 100) |
| HYDRA_BATCH_SIZE — device batch size (for throughput calc) |
| HYDRA_SEQ_LEN — sequence length (for throughput calc) |
| """ |
| from __future__ import annotations |
|
|
| import math |
| import os |
| import time |
|
|
| import torch |
| import lightning as L |
|
|
| from hydra.config import ( |
| ADAM_BETAS, |
| EMBEDDING_LR, |
| FINAL_LR_FRAC, |
| GPU_BF16_PEAK_FLOPS, |
| MATRIX_LR, |
| SCALAR_LR, |
| UNEMBEDDING_LR, |
| WARMUP_RATIO, |
| WEIGHT_DECAY, |
| PostSemClawConfig, |
| ) |
| from hydra.model import PostSemClawModel |
|
|
|
|
| |
| |
| |
| |
|
|
|
|
def _lr_multiplier(progress: float) -> float:
    """Return the learning-rate scale for a training progress value.

    While ``progress < WARMUP_RATIO`` the scale ramps linearly from 0 to 1
    (or is 1.0 when there is no warm-up). Afterwards it follows a half-cosine
    from 1.0 (end of warm-up) down to ``FINAL_LR_FRAC`` (progress == 1).
    """
    in_warmup = progress < WARMUP_RATIO
    if in_warmup:
        if WARMUP_RATIO > 0:
            return progress / WARMUP_RATIO
        return 1.0
    # Fraction of the post-warm-up span already consumed; the 1e-9 floor
    # avoids dividing by zero when WARMUP_RATIO == 1.
    decay_span = max(1.0 - WARMUP_RATIO, 1e-9)
    frac = (progress - WARMUP_RATIO) / decay_span
    return FINAL_LR_FRAC + 0.5 * (1.0 - FINAL_LR_FRAC) * (1 + math.cos(math.pi * frac))
|
|
|
|
| def _muon_momentum(step: int) -> float: |
| frac = min(step / 300.0, 1.0) |
| return (1 - frac) * 0.85 + frac * 0.95 |
|
|
|
|
def _weight_decay(progress: float) -> float:
    """Return the weight decay for a training progress value.

    Falls linearly from ``WEIGHT_DECAY`` at progress 0 to 0 at progress 1.
    """
    remaining = 1 - progress
    return WEIGHT_DECAY * remaining
|
|
|
|
| |
|
|
|
|
class HydraLightningModule(L.LightningModule):
    """Lightning wrapper around PostSemClawModel.

    Public attributes: ``self.model`` (the wrapped PostSemClawModel) and
    ``self.config`` (the PostSemClawConfig it was built from).
    """

    def __init__(self, config: PostSemClawConfig):
        super().__init__()
        self.config = config
        self.model = PostSemClawModel(config)

        # Wall-clock budget in seconds that drives the LR / WD schedules.
        # Parsed as float so fractional values (e.g. "90.5") are accepted.
        self.time_budget = float(os.environ.get("HYDRA_TIME_BUDGET", "300"))
        self._train_start_time: float | None = None
        # Sum of optimizer-step durations, excluding the first steps (see
        # optimizer_step). The schedules use this, not raw wall time, as progress.
        self._total_training_time = 0.0
        self._last_step_end: float | None = None
        self._hestia_interval = int(os.environ.get("HYDRA_HESTIA_INTERVAL", "100"))
        self._flops_per_token = 0
        self._tokens_per_step = 0

        # Loss EMA for logging, plus the number of updates folded into it.
        # The count is used for bias correction.
        self._ema_beta = 0.9
        self._smooth_loss = 0.0
        self._loss_ema_updates = 0
        # EMA of the mean byte count per target token. bits/token is divided
        # by this to report bits/byte.
        self._bpt_ema = 0.0
        self._token_bytes: torch.Tensor | None = None
        # Input batch consumed by the SDR SOM update in optimizer_step.
        self._last_x: torch.Tensor | None = None

    def on_train_start(self) -> None:
        """Start the wall clock and cache per-run constants used for logging."""
        self._train_start_time = time.time()
        self._last_step_end = self._train_start_time
        self._flops_per_token = self.model.estimate_flops()

        # Tokens per device batch; used only for tps / MFU reporting.
        batch_size = int(os.environ.get("HYDRA_BATCH_SIZE", "1"))
        seq_len = int(os.environ.get("HYDRA_SEQ_LEN", "512"))
        self._tokens_per_step = batch_size * seq_len

        # `prepare` supplies the per-token byte-count table used for bpb.
        import prepare as _p
        self._token_bytes = _p.get_token_bytes(device=self.device)

    def configure_optimizers(self):
        """Return the model's own optimizer, built with the configured LRs.

        The per-step LR / momentum / WD schedule is applied in optimizer_step.
        """
        return self.model.setup_optimizer(
            unembedding_lr=UNEMBEDDING_LR,
            embedding_lr=EMBEDDING_LR,
            scalar_lr=SCALAR_LR,
            adam_betas=ADAM_BETAS,
            matrix_lr=MATRIX_LR,
            weight_decay=WEIGHT_DECAY,
        )

    def training_step(self, batch: torch.Tensor, batch_idx: int):
        """Split a (B, T+1) token batch into inputs/targets and return the loss.

        Raises:
            RuntimeError: if ``batch`` is not 2-D.
        """
        if batch.dim() != 2:
            raise RuntimeError(f"Expected (B, T+1) batch, got shape {tuple(batch.shape)}")
        # Next-token prediction: targets are the inputs shifted left by one.
        x = batch[:, :-1].contiguous()
        y = batch[:, 1:].contiguous()

        loss = self.model(x, y)
        self._log_step(loss.detach(), y)
        return loss

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
        """Apply the schedules, take one optimizer step, then run post-step hooks.

        ``optimizer_closure`` is Lightning's closure that runs training_step and
        backward for the current batch. Gradients for this step therefore only
        exist once ``optimizer.step`` has invoked it. That is why clipping lives
        in ``on_before_optimizer_step`` rather than here.
        """
        now = time.time()
        if self._train_start_time is None:
            # on_train_start has not run; start the clock on the first step.
            self._train_start_time = now
            self._last_step_end = now
        progress = min(self._total_training_time / max(self.time_budget, 1.0), 1.0)

        step = self.global_step
        self._apply_schedules(optimizer, progress, step)

        if hasattr(self.model, "flush_hyena_pending_grads"):
            def _closure_with_flush():
                result = optimizer_closure()
                # Runs right after backward, presumably to write deferred
                # Hyena gradients into .grad before clipping and the update.
                self.model.flush_hyena_pending_grads()
                return result

            step_closure = _closure_with_flush
        else:
            step_closure = optimizer_closure

        optimizer.step(closure=step_closure)
        self.model.zero_grad(set_to_none=True)

        # Parameters just changed, so cached Hyena state is dropped.
        if hasattr(self.model, "invalidate_hyena_caches"):
            self.model.invalidate_hyena_caches()

        # Hestia: anneal using wall-clock progress since training start. This
        # is unclamped and, unlike `progress`, includes the excluded early
        # steps. Snap every `_hestia_interval` optimizer steps.
        progress_now = (now - self._train_start_time) / max(self.time_budget, 1.0)
        self.model.hestia.anneal_temperature(progress_now)
        if self._hestia_interval > 0 and step % self._hestia_interval == 0:
            self.model.hestia.apply_to(self.model)

        # SDR SOM update.
        # NOTE(review): nothing in this class assigns a non-None `_last_x`
        # (_log_step resets it to None), so this branch is skipped unless
        # `_last_x` is populated from elsewhere.
        last_sdr = getattr(self.model, "_last_sdr", None)
        if last_sdr is not None and hasattr(self.model.sdr_semantic, "maybe_som_update"):
            last_x = getattr(self, "_last_x", None)
            if last_x is not None:
                self.model.sdr_semantic.maybe_som_update(last_x, last_sdr)

        # Step timing. Steps 0..10 are excluded from the schedule clock,
        # presumably to skip warm-up / compilation overhead.
        dt = now - (self._last_step_end or now)
        self._last_step_end = now
        if step > 10:
            self._total_training_time += dt

    def _apply_schedules(self, optimizer, progress: float, step: int) -> None:
        """Set lr on every param group, and momentum / WD on "muon" groups."""
        lrm = _lr_multiplier(progress)
        mom = _muon_momentum(step)
        wd = _weight_decay(progress)
        for group in optimizer.param_groups:
            group["lr"] = group["initial_lr"] * lrm
            if group.get("kind") == "muon":
                group["momentum"] = mom
                group["weight_decay"] = wd

    def on_before_optimizer_step(self, optimizer) -> None:
        """Clip the global gradient L2 norm to 1.0 just before the update.

        Lightning invokes this hook after the training closure has run backward
        (including the Hyena grad flush wrapped into it in optimizer_step). With
        gradient accumulation it runs once all micro-batches have accumulated,
        and with AMP it runs after unscaling. The gradients seen here are
        exactly the ones the optimizer applies.
        """
        torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)

    def _log_step(self, loss: torch.Tensor, y: torch.Tensor) -> None:
        """Update smoothed metrics and report them via Lightning and stdout.

        Args:
            loss: detached scalar loss of the current (micro-)batch.
            y: target token ids of the batch; used for the bytes-per-token EMA.
        """
        self._last_x = None

        loss_f = float(loss.item())
        if not math.isfinite(loss_f) or loss_f > 100:
            # Non-finite / exploding loss: flag it and keep it out of the EMAs.
            self.log("train_loss_nan", 1.0)
            return

        step = self.global_step
        beta = self._ema_beta
        self._loss_ema_updates += 1
        self._smooth_loss = beta * self._smooth_loss + (1 - beta) * loss_f
        # Bias-correct by the number of EMA updates actually made, not by
        # global_step. This method runs once per micro-batch under gradient
        # accumulation. After a checkpoint resume, global_step does not restart
        # at 0 while the EMA does.
        debiased = self._smooth_loss / max(1 - beta ** self._loss_ema_updates, 1e-9)

        dt = max(time.time() - (self._last_step_end or time.time()), 1e-6)
        tps = int(self._tokens_per_step / dt)
        mfu = 100.0 * self._flops_per_token * self._tokens_per_step / dt / GPU_BF16_PEAK_FLOPS

        # loss / ln 2: bits per token, assuming a natural-log cross-entropy.
        bpt = debiased / math.log(2)
        if self._token_bytes is not None:
            self._update_bytes_per_token(y)
        bpb = bpt / max(self._bpt_ema, 1e-6)
        vram = (
            torch.cuda.memory_allocated() / 1024 / 1024
            if torch.cuda.is_available()
            else 0.0
        )

        self.log_dict(
            {
                "train/loss": debiased,
                "train/bpb": bpb,
                "train/bpt": bpt,
                "train/tps": float(tps),
                "train/mfu": float(mfu),
                "train/vram_mib": float(vram),
            },
            prog_bar=False,
            on_step=True,
            on_epoch=False,
        )

        print(
            f"step={step:05d} loss={debiased:.4f} bpb={bpb:.4f} "
            f"bpt={bpt:.3f} bpt_div={self._bpt_ema:.2f} "
            f"tps={tps} dt_ms={dt*1000:.0f} mfu={mfu:.1f} "
            f"vram={vram:.0f}MiB",
            flush=True,
        )

    def _update_bytes_per_token(self, y: torch.Tensor) -> None:
        """Fold the batch's mean bytes per target token into ``_bpt_ema``.

        Targets whose byte count is 0 are excluded from the mean.
        """
        with torch.no_grad():
            nbytes = self._token_bytes[y.reshape(-1)]
            mask = nbytes > 0
            denom = mask.sum().clamp(min=1).float()
            avg_bpt = (nbytes.float() * mask.float()).sum() / denom
            bpt_batch = float(avg_bpt.item())
        if self._bpt_ema <= 0.0:
            # The first usable measurement seeds the EMA directly. Only done
            # once, so accumulation micro-batches at step 0 no longer re-seed it.
            self._bpt_ema = bpt_batch
        else:
            self._bpt_ema = 0.98 * self._bpt_ema + 0.02 * bpt_batch
|
|