Spaces:
Runtime error
Runtime error
| """LightningModule wrapping PostSemClawModel. | |
| Thin adapter. The model and the MuonAdamW optimizer are unchanged. This | |
| module implements: | |
| • configure_optimizers — returns the existing MuonAdamW (subclass of | |
| torch.optim.Optimizer) built by model.setup_optimizer. Lightning accepts | |
| this directly. | |
| • training_step — splits (B, T+1) batches into (x, y), forwards through | |
| the model, logs loss / bpb / tps / mfu / vram. Preserves the | |
| sampled-softmax path inside PostSemClawModel (no changes there). | |
| • optimizer_step — before each step we update LR + muon momentum + WD | |
| using the same time-progress schedule as hydra/training.py | |
| (get_lr_multiplier / get_muon_momentum / get_weight_decay). Lightning | |
| handles grad accumulation via Trainer(accumulate_grad_batches=N). | |
| The SDR SOM update and Hestia QAT snap are called at the same cadence as | |
| the legacy loop, but inline on the main thread (Lightning provides its own | |
| callbacks for async work if we need to extract them later — keeping it | |
| simple for now). | |
| Env vars respected: | |
| HYDRA_TIME_BUDGET — wall-clock budget (s) used for LR schedule | |
| and as Trainer max_time | |
| HYDRA_HESTIA_INTERVAL — steps between Hestia snaps (default 100) | |
| HYDRA_BATCH_SIZE — device batch size (for throughput calc) | |
| HYDRA_SEQ_LEN — sequence length (for throughput calc) | |
| """ | |
| from __future__ import annotations | |
| import math | |
| import os | |
| import time | |
| import torch | |
| import lightning as L | |
| from hydra.config import ( | |
| ADAM_BETAS, | |
| EMBEDDING_LR, | |
| FINAL_LR_FRAC, | |
| GPU_BF16_PEAK_FLOPS, | |
| MATRIX_LR, | |
| SCALAR_LR, | |
| UNEMBEDDING_LR, | |
| WARMUP_RATIO, | |
| WEIGHT_DECAY, | |
| PostSemClawConfig, | |
| ) | |
| from hydra.model import PostSemClawModel | |
| # --------------------------------------------------------------------------- | |
| # LR / momentum / wd schedules — verbatim copy of hydra/training.py so the | |
| # curves match exactly. Kept here to avoid import cycles. | |
| # --------------------------------------------------------------------------- | |
def _lr_multiplier(progress: float) -> float:
    """Return the LR scale factor for a given training progress.

    Linear warm-up over the first ``WARMUP_RATIO`` fraction of progress,
    followed by a half-cosine decay that starts at 1.0 and ends at
    ``FINAL_LR_FRAC`` when ``progress`` reaches 1.0.

    Args:
        progress: fraction of the training budget consumed so far.

    Returns:
        Multiplier to apply to each param group's initial learning rate.
    """
    in_warmup = progress < WARMUP_RATIO
    if in_warmup:
        # Guard against a non-positive warm-up ratio (no warm-up configured).
        if WARMUP_RATIO > 0:
            return progress / WARMUP_RATIO
        return 1.0

    # Position inside the decay phase, rescaled to [0, 1]; the max() avoids
    # a zero division when the warm-up covers the whole budget.
    decay_span = max(1.0 - WARMUP_RATIO, 1e-9)
    decay_progress = (progress - WARMUP_RATIO) / decay_span
    cosine_term = 1 + math.cos(math.pi * decay_progress)
    return FINAL_LR_FRAC + 0.5 * (1.0 - FINAL_LR_FRAC) * cosine_term
| def _muon_momentum(step: int) -> float: | |
| frac = min(step / 300.0, 1.0) | |
| return (1 - frac) * 0.85 + frac * 0.95 | |
def _weight_decay(progress: float) -> float:
    """Return the weight decay for a given training progress.

    Decays linearly from ``WEIGHT_DECAY`` at progress 0 to zero at
    progress 1.

    Args:
        progress: fraction of the training budget consumed so far.

    Returns:
        Weight-decay coefficient for the current step.
    """
    remaining = 1 - progress
    return WEIGHT_DECAY * remaining
| # --------------------------------------------------------------------------- | |
class HydraLightningModule(L.LightningModule):
    """Lightning wrapper around ``PostSemClawModel``.

    Public attrs: ``self.model``, ``self.config``.

    Responsibilities:
      * ``configure_optimizers`` returns the optimizer built by
        ``model.setup_optimizer``.
      * ``training_step`` splits a ``(B, T+1)`` batch into inputs/targets,
        runs the model and logs per-step metrics.
      * ``optimizer_step`` applies the wall-clock-driven LR / Muon momentum /
        weight-decay schedules, clips gradients, steps the optimizer and
        then runs the Hestia and SDR SOM side updates.
    """

    def __init__(self, config: PostSemClawConfig):
        """Build the wrapped model and read schedule settings from the env.

        Args:
            config: model configuration forwarded to ``PostSemClawModel``.
        """
        super().__init__()
        self.config = config
        self.model = PostSemClawModel(config)
        # Model weights init must be deferred to the correct device; done by
        # caller after construction (to match the meta-device + to_empty()
        # pattern used in the legacy loop).

        # Time-based progress tracks the legacy loop's semantics: LR cosine
        # is driven by wall-clock, not step count. We capture training start
        # in on_train_start and TIME_BUDGET from env.
        self.time_budget = float(
            int(os.environ.get("HYDRA_TIME_BUDGET", "300"))
        )
        self._train_start_time: float | None = None
        self._total_training_time = 0.0
        self._last_step_end: float | None = None
        self._hestia_interval = int(os.environ.get("HYDRA_HESTIA_INTERVAL", "100"))
        self._flops_per_token = 0
        self._tokens_per_step = 0

        # Smoothed loss for the header-line log (matches legacy format).
        self._ema_beta = 0.9
        self._smooth_loss = 0.0
        self._bpt_ema = 0.0
        self._token_bytes: torch.Tensor | None = None

        # Input ids of the most recent training_step; consumed by the SDR SOM
        # update in optimizer_step.
        self._last_x: torch.Tensor | None = None

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------
    def on_train_start(self) -> None:
        """Start the wall-clock timers and cache throughput/bpb inputs."""
        self._train_start_time = time.time()
        self._last_step_end = self._train_start_time
        self._flops_per_token = self.model.estimate_flops()

        # Tokens processed per optimizer step (pre-accum).
        B = int(os.environ.get("HYDRA_BATCH_SIZE", "1"))
        T = int(os.environ.get("HYDRA_SEQ_LEN", "512"))
        self._tokens_per_step = B * T

        # Build/cache token_bytes LUT (for bits-per-byte live metric).
        # Imported lazily so the module can be imported without `prepare`.
        import prepare as _p

        self._token_bytes = _p.get_token_bytes(device=self.device)

    def configure_optimizers(self):
        """Return the optimizer built by ``model.setup_optimizer``."""
        optimizer = self.model.setup_optimizer(
            unembedding_lr=UNEMBEDDING_LR,
            embedding_lr=EMBEDDING_LR,
            scalar_lr=SCALAR_LR,
            adam_betas=ADAM_BETAS,
            matrix_lr=MATRIX_LR,
            weight_decay=WEIGHT_DECAY,
        )
        return optimizer

    # ------------------------------------------------------------------
    # Training step. Lightning auto-handles: autocast (via precision flag
    # on Trainer), backward, grad-accum, zero_grad. We only:
    #   - split batch into (x, y)
    #   - forward through model (autocast is established by Trainer)
    #   - return loss (grads flow from return)
    # ------------------------------------------------------------------
    def training_step(self, batch: torch.Tensor, batch_idx: int):
        """Run one forward pass and return the loss for Lightning to backprop.

        Args:
            batch: 2-D tensor of token ids; the last column supplies the
                final target, so rows are one token longer than the input.
            batch_idx: index of the batch within the epoch (unused).

        Returns:
            The raw loss returned by the model.

        Raises:
            RuntimeError: if ``batch`` is not 2-dimensional.
        """
        # DataLoader produces (B, T+1) rows; split into input/target.
        # Lightning has already transferred the batch to the module's device
        # before this hook runs.
        if batch.dim() != 2:
            raise RuntimeError(f"Expected (B, T+1) batch, got shape {tuple(batch.shape)}")
        x = batch[:, :-1].contiguous()
        y = batch[:, 1:].contiguous()
        loss = self.model(x, y)

        # Keep the inputs around so optimizer_step can pair them with the
        # model's stashed `_last_sdr` for the SOM update.
        self._last_x = x

        # Lightning applies the grad-accum divisor automatically; we just
        # return the raw loss. loss.detach() is stored for logging.
        self._log_step(loss.detach(), y)
        return loss

    # ------------------------------------------------------------------
    # Optimizer step hook: update LR / momentum / WD using time-progress.
    # Runs once per optimizer step (after all accum micro-batches).
    # ------------------------------------------------------------------
    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
        """Apply schedules, step the optimizer, then run post-step updates.

        Args:
            epoch: current epoch index (unused).
            batch_idx: index of the batch within the epoch (unused).
            optimizer: the optimizer returned by ``configure_optimizers``.
            optimizer_closure: Lightning's closure; it must be executed
                because it performs the forward/backward work for this step.
        """
        # Update schedules from wall-clock progress.
        now = time.time()
        if self._train_start_time is None:
            self._train_start_time = now
            self._last_step_end = now
        progress = min(self._total_training_time / max(self.time_budget, 1.0), 1.0)
        step = self.global_step
        lrm = _lr_multiplier(progress)
        mom = _muon_momentum(step)
        wd = _weight_decay(progress)
        for group in optimizer.param_groups:
            group["lr"] = group["initial_lr"] * lrm
            if group.get("kind") == "muon":
                group["momentum"] = mom
                group["weight_decay"] = wd

        # Lightning runs training_step / zero_grad / backward INSIDE the
        # closure, i.e. inside optimizer.step(). Anything that must see the
        # gradients of this step therefore has to run after the closure but
        # before the optimizer consumes the grads, so we wrap the closure.
        #
        # Ordering within the wrapped closure:
        #   1. optimizer_closure() — forward + backward of the micro-batch
        #      that triggers this optimizer step (earlier micro-batches of
        #      the accumulation window have already accumulated their grads).
        #      Each Hyena micro-batch backward accumulates into _k_leaf.grad.
        #   2. flush_hyena_pending_grads() — one-shot
        #      torch.autograd.backward(_k_graph, _k_leaf.grad) per HyenaFilter.
        #      Now filter MLP / pos_emb / bias params have their correct grads.
        #      No-op when HYDRA_HYENA_TRAIN_CACHE=0 or no Hyena blocks exist.
        #   3. clip_grad_norm_ — grad clip (matches legacy loop). Clipping
        #      before the closure would act on missing or partial gradients.
        #      NOTE(review): with a loss-scaling precision plugin the grads
        #      may still be scaled here — confirm if fp16 is ever used.
        flush_hyena = getattr(self.model, "flush_hyena_pending_grads", None)

        def _closure_with_grad_fixups():
            result = optimizer_closure()
            if flush_hyena is not None:
                flush_hyena()
            torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
            return result

        # Run the step (this is what Lightning would have done for us).
        optimizer.step(closure=_closure_with_grad_fixups)
        self.model.zero_grad(set_to_none=True)

        # Hyena filter-rfft cache invalidation. No-op if:
        #   (a) no Hyena layers are in the model, or
        #   (b) HYDRA_HYENA_FILTER_CACHE=0 and HYDRA_HYENA_TRAIN_CACHE=0
        #       (the operators never populated either cache)
        # In either case this is a handful of Python attribute resets.
        if hasattr(self.model, "invalidate_hyena_caches"):
            self.model.invalidate_hyena_caches()

        # Hestia QAT snap every N steps. Temperature anneals every step.
        progress_now = (now - self._train_start_time) / max(self.time_budget, 1.0)
        self.model.hestia.anneal_temperature(progress_now)
        if self._hestia_interval > 0 and step % self._hestia_interval == 0:
            self.model.hestia.apply_to(self.model)

        # SDR SOM update when the model stashed an sdr in the last forward.
        # The legacy loop passed (x, _last_sdr); x is the input stashed by
        # the most recent training_step.
        _last_sdr = getattr(self.model, "_last_sdr", None)
        if _last_sdr is not None and hasattr(self.model.sdr_semantic, "maybe_som_update"):
            last_x = getattr(self, "_last_x", None)
            if last_x is not None:
                self.model.sdr_semantic.maybe_som_update(last_x, _last_sdr)

        # Advance the wall-clock counter for LR schedule (matches legacy
        # behavior which incremented only after the first warm-up steps).
        previous_end = self._last_step_end if self._last_step_end is not None else now
        dt = now - previous_end
        self._last_step_end = now
        if step > 10:
            self._total_training_time += dt

    # ------------------------------------------------------------------
    # Logging — mirrors the step=NNNNN line format of the legacy loop so
    # grep/tee pipelines keep working.
    # ------------------------------------------------------------------
    def _log_step(self, loss: torch.Tensor, y: torch.Tensor) -> None:
        """Log smoothed loss, bpb, throughput, MFU and VRAM for one step.

        Args:
            loss: detached scalar loss of the current micro-batch.
            y: target token ids, used to look up bytes-per-token.
        """
        loss_f = float(loss.item())
        if not math.isfinite(loss_f) or loss_f > 100:
            # Flag the bad step and skip the metrics so the EMA is not
            # polluted; trainer callbacks decide what to do about it.
            self.log("train_loss_nan", 1.0)
            return

        step = self.global_step
        self._smooth_loss = (
            self._ema_beta * self._smooth_loss + (1 - self._ema_beta) * loss_f
        )
        # Bias-corrected EMA (the EMA starts from 0.0).
        debiased = self._smooth_loss / max(1 - self._ema_beta ** (step + 1), 1e-9)

        # Time since the last optimizer step ended; clamped so the divisions
        # below are always safe.
        dt = max(time.time() - (self._last_step_end or time.time()), 1e-6)
        tps = int(self._tokens_per_step / dt)
        mfu = (
            100.0
            * self._flops_per_token
            * self._tokens_per_step
            / dt
            / GPU_BF16_PEAK_FLOPS
        )

        # bpb live: y flat -> token_bytes LUT -> avg bytes/token
        bpt = debiased / math.log(2)
        if self._token_bytes is not None:
            with torch.no_grad():
                y_flat = y.reshape(-1)
                nbytes = self._token_bytes[y_flat]
                # Only tokens with a positive byte count enter the average.
                mask = nbytes > 0
                denom = mask.sum().clamp(min=1).float()
                avg_bpt = (nbytes.float() * mask.float()).sum() / denom
                bpt_batch = float(avg_bpt.item())
            if step == 0 or self._bpt_ema <= 0.0:
                self._bpt_ema = bpt_batch
            else:
                self._bpt_ema = 0.98 * self._bpt_ema + 0.02 * bpt_batch
        bpb = bpt / max(self._bpt_ema, 1e-6)

        vram = (
            torch.cuda.memory_allocated() / 1024 / 1024
            if torch.cuda.is_available()
            else 0.0
        )
        self.log_dict(
            {
                "train/loss": debiased,
                "train/bpb": bpb,
                "train/bpt": bpt,
                "train/tps": float(tps),
                "train/mfu": float(mfu),
                "train/vram_mib": float(vram),
            },
            prog_bar=False,
            on_step=True,
            on_epoch=False,
        )
        # Match legacy one-line format: "step=NNNNN loss=x bpb=y tps=z ..."
        print(
            f"step={step:05d} loss={debiased:.4f} bpb={bpb:.4f} "
            f"bpt={bpt:.3f} bpt_div={self._bpt_ema:.2f} "
            f"tps={tps} dt_ms={dt*1000:.0f} mfu={mfu:.1f} "
            f"vram={vram:.0f}MiB",
            flush=True,
        )