File size: 13,654 Bytes
e317e25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
"""LightningModule wrapping PostSemClawModel.

Thin adapter. The model and the MuonAdamW optimizer are unchanged. This
module implements:

  • configure_optimizers — returns the existing MuonAdamW (subclass of
    torch.optim.Optimizer) built by model.setup_optimizer. Lightning accepts
    this directly.
  • training_step — splits (B, T+1) batches into (x, y), forwards through
    the model, logs loss / bpb / tps / mfu / vram. Preserves the
    sampled-softmax path inside PostSemClawModel (no changes there).
  • optimizer_step — before each step we update LR + muon momentum + WD
    using the same time-progress schedule as hydra/training.py
    (get_lr_multiplier / get_muon_momentum / get_weight_decay). Lightning
    handles grad accumulation via Trainer(accumulate_grad_batches=N).

The SDR SOM update and Hestia QAT snap are called at the same cadence as
the legacy loop, but inline on the main thread (Lightning provides its own
callbacks for async work if we need to extract them later — keeping it
simple for now).

Env vars respected:
  HYDRA_TIME_BUDGET          — wall-clock budget (s) used for LR schedule
                                and as Trainer max_time
  HYDRA_HESTIA_INTERVAL      — steps between Hestia snaps (default 100)
  HYDRA_BATCH_SIZE           — device batch size (for throughput calc)
  HYDRA_SEQ_LEN              — sequence length (for throughput calc)
"""
from __future__ import annotations

import math
import os
import time

import torch
import lightning as L

from hydra.config import (
    ADAM_BETAS,
    EMBEDDING_LR,
    FINAL_LR_FRAC,
    GPU_BF16_PEAK_FLOPS,
    MATRIX_LR,
    SCALAR_LR,
    UNEMBEDDING_LR,
    WARMUP_RATIO,
    WEIGHT_DECAY,
    PostSemClawConfig,
)
from hydra.model import PostSemClawModel


# ---------------------------------------------------------------------------
# LR / momentum / wd schedules — verbatim copy of hydra/training.py so the
# curves match exactly. Kept here to avoid import cycles.
# ---------------------------------------------------------------------------


def _lr_multiplier(progress: float) -> float:
    """Return the learning-rate multiplier for a training-progress fraction.

    While ``progress`` is below ``WARMUP_RATIO`` the multiplier ramps
    linearly as ``progress / WARMUP_RATIO`` (or is 1.0 when the warmup ratio
    is not positive). Afterwards it follows a half-cosine that starts at 1.0
    and ends at ``FINAL_LR_FRAC`` when the decay progress reaches 1.
    """
    in_warmup = progress < WARMUP_RATIO
    if in_warmup:
        if WARMUP_RATIO > 0:
            return progress / WARMUP_RATIO
        return 1.0

    # Fraction of the post-warmup span already consumed; the floor on the
    # denominator guards against WARMUP_RATIO == 1.
    decay_span = max(1.0 - WARMUP_RATIO, 1e-9)
    decay_progress = (progress - WARMUP_RATIO) / decay_span
    cosine_term = 1 + math.cos(math.pi * decay_progress)
    return FINAL_LR_FRAC + 0.5 * (1.0 - FINAL_LR_FRAC) * cosine_term


def _muon_momentum(step: int) -> float:
    frac = min(step / 300.0, 1.0)
    return (1 - frac) * 0.85 + frac * 0.95


def _weight_decay(progress: float) -> float:
    """Return the weight decay for a training-progress fraction.

    Scales ``WEIGHT_DECAY`` linearly by the remaining fraction of training,
    so it equals ``WEIGHT_DECAY`` at progress 0 and 0 at progress 1.
    """
    remaining = 1 - progress
    return WEIGHT_DECAY * remaining


# ---------------------------------------------------------------------------


class HydraLightningModule(L.LightningModule):
    """Lightning wrapper around ``PostSemClawModel``.

    Public attrs: ``self.model``, ``self.config``.

    Responsibilities visible in this class:
      * ``training_step`` splits a 2-D batch into inputs/targets, runs the
        model and logs per-step metrics.
      * ``optimizer_step`` applies the wall-clock-driven LR / momentum /
        weight-decay schedules, runs the optimizer, then performs the
        Hyena cache maintenance, the Hestia anneal/snap and the SDR SOM
        update.
      * ``on_before_optimizer_step`` clips the gradient norm to 1.0 once the
        gradients for the step are complete.
    """

    def __init__(self, config: PostSemClawConfig):
        """Build the wrapped model and initialise schedule/logging state.

        Args:
            config: Model configuration forwarded to ``PostSemClawModel``.

        Reads ``HYDRA_TIME_BUDGET`` (integer seconds, default 300) and
        ``HYDRA_HESTIA_INTERVAL`` (steps, default 100) from the environment.
        """
        super().__init__()
        self.config = config
        self.model = PostSemClawModel(config)
        # Model weights init must be deferred to the correct device; done by
        # caller after construction (to match the meta-device + to_empty()
        # pattern used in the legacy loop).

        # Time-based progress tracks the legacy loop's semantics: LR cosine
        # is driven by wall-clock, not step count. We capture training start
        # in on_train_start and TIME_BUDGET from env.
        self.time_budget = float(
            int(os.environ.get("HYDRA_TIME_BUDGET", "300"))
        )
        self._train_start_time: float | None = None
        self._total_training_time = 0.0
        self._last_step_end: float | None = None
        self._hestia_interval = int(os.environ.get("HYDRA_HESTIA_INTERVAL", "100"))
        self._flops_per_token = 0
        self._tokens_per_step = 0

        # Smoothed loss for the header-line log (matches legacy format).
        self._ema_beta = 0.9
        self._smooth_loss = 0.0
        # Number of EMA updates applied so far; drives bias correction so it
        # stays right when several micro-batches share one global_step or
        # when a non-finite loss is skipped.
        self._ema_updates = 0
        self._bpt_ema = 0.0
        self._token_bytes: torch.Tensor | None = None

        # Inputs of the most recent training_step, consumed by the SOM
        # update in optimizer_step.
        self._last_x: torch.Tensor | None = None

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def on_train_start(self) -> None:
        """Record the training start time and cache throughput constants.

        Reads ``HYDRA_BATCH_SIZE`` (default 1) and ``HYDRA_SEQ_LEN``
        (default 512) to compute tokens per micro-batch, and loads the
        token-bytes lookup table from the ``prepare`` module.
        """
        self._train_start_time = time.time()
        self._last_step_end = self._train_start_time
        self._flops_per_token = self.model.estimate_flops()
        # Tokens processed per optimizer step (pre-accum).
        B = int(os.environ.get("HYDRA_BATCH_SIZE", "1"))
        T = int(os.environ.get("HYDRA_SEQ_LEN", "512"))
        self._tokens_per_step = B * T

        # Build/cache token_bytes LUT (for bits-per-byte live metric).
        # Imported lazily, as in the original, so importing this module does
        # not require `prepare`.
        import prepare as _p
        self._token_bytes = _p.get_token_bytes(device=self.device)

    def configure_optimizers(self):
        """Return the optimizer built by ``self.model.setup_optimizer``."""
        optimizer = self.model.setup_optimizer(
            unembedding_lr=UNEMBEDDING_LR,
            embedding_lr=EMBEDDING_LR,
            scalar_lr=SCALAR_LR,
            adam_betas=ADAM_BETAS,
            matrix_lr=MATRIX_LR,
            weight_decay=WEIGHT_DECAY,
        )
        return optimizer

    # ------------------------------------------------------------------
    # Training step. Lightning auto-handles: autocast (via precision flag
    # on Trainer), backward, grad-accum, zero_grad. We only:
    #   - split batch into (x, y)
    #   - forward through model (autocast is established by Trainer)
    #   - return loss (grads flow from return)
    # ------------------------------------------------------------------

    def training_step(self, batch: torch.Tensor, batch_idx: int):
        """Run one forward pass and return the loss.

        Args:
            batch: 2-D tensor; the last column is dropped to form the inputs
                and the first column is dropped to form the targets.
            batch_idx: Index of the batch (unused here).

        Returns:
            The loss returned by ``self.model(x, y)``.

        Raises:
            RuntimeError: If ``batch`` is not 2-dimensional.
        """
        # DataLoader produces (B, T+1) rows; split into input/target.
        if batch.dim() != 2:
            raise RuntimeError(f"Expected (B, T+1) batch, got shape {tuple(batch.shape)}")
        x = batch[:, :-1].contiguous()
        y = batch[:, 1:].contiguous()

        # Stash the inputs so optimizer_step can hand them to the SOM update
        # together with the SDR the model stashes during this forward.
        self._last_x = x

        loss = self.model(x, y)
        # Lightning applies the grad-accum divisor automatically; we just
        # return the raw loss. loss.detach() is passed on for logging.
        self._log_step(loss.detach(), y)
        return loss

    # ------------------------------------------------------------------
    # Gradient clipping (matches the legacy loop's max_norm=1.0).
    # ------------------------------------------------------------------

    def on_before_optimizer_step(self, optimizer) -> None:
        """Clip the global gradient norm to 1.0 before the parameter update.

        Lightning invokes this hook after the backward closure has run (so
        after the Hyena flush performed inside the wrapped closure in
        ``optimizer_step``) and before the optimizer consumes the gradients.
        Clipping earlier, before ``optimizer.step(closure=...)``, would act
        on gradients that do not yet include the current backward pass.
        """
        torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)

    # ------------------------------------------------------------------
    # Optimizer step hook: update LR / momentum / WD using time-progress.
    # Runs once per optimizer step (after all accum micro-batches).
    # ------------------------------------------------------------------

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
        """Apply schedules, run the optimizer, then post-step maintenance.

        Args:
            epoch: Current epoch (unused here).
            batch_idx: Current batch index (unused here).
            optimizer: Optimizer whose ``param_groups`` carry ``initial_lr``
                and, for Muon groups, ``kind == "muon"``.
            optimizer_closure: Lightning closure that runs the forward and
                backward pass; it is executed inside ``optimizer.step``.
        """
        # Update schedules from wall-clock progress.
        now = time.time()
        if self._train_start_time is None:
            self._train_start_time = now
            self._last_step_end = now
        progress = min(self._total_training_time / max(self.time_budget, 1.0), 1.0)

        step = self.global_step
        lrm = _lr_multiplier(progress)
        mom = _muon_momentum(step)
        wd = _weight_decay(progress)
        for group in optimizer.param_groups:
            group["lr"] = group["initial_lr"] * lrm
            if group.get("kind") == "muon":
                group["momentum"] = mom
                group["weight_decay"] = wd

        # Hyena train-cache: we must flush accumulated micro-batch grads BACK
        # into the filter MLP params AFTER the backward closure has run but
        # BEFORE the optimizer actually consumes the grads. Lightning runs
        # the closure inside optimizer.step(), so we wrap the closure to
        # insert the flush at that point.
        #
        # Ordering within the wrapped closure:
        #   1. optimizer_closure() — runs the current micro-batch's forward
        #      and backward; earlier micro-batches of the accumulation window
        #      have already accumulated into _k_leaf.grad.
        #   2. flush_hyena_pending_grads() — pushes the accumulated grads
        #      into the filter parameters.
        # Gradient clipping then happens in on_before_optimizer_step.
        #
        # No-op when the model does not expose flush_hyena_pending_grads.
        _has_flush = hasattr(self.model, "flush_hyena_pending_grads")
        if _has_flush:
            _orig_closure = optimizer_closure

            def _wrapped_closure():
                result = _orig_closure()
                self.model.flush_hyena_pending_grads()
                return result

            effective_closure = _wrapped_closure
        else:
            effective_closure = optimizer_closure

        # Run the step (this is what Lightning would have done for us).
        optimizer.step(closure=effective_closure)
        self.model.zero_grad(set_to_none=True)

        # Hyena filter-rfft cache invalidation; skipped when the model does
        # not expose invalidate_hyena_caches.
        if hasattr(self.model, "invalidate_hyena_caches"):
            self.model.invalidate_hyena_caches()

        # Hestia QAT snap every N steps. Temperature anneals every step.
        progress_now = (now - self._train_start_time) / max(self.time_budget, 1.0)
        self.model.hestia.anneal_temperature(progress_now)
        if self._hestia_interval > 0 and step % self._hestia_interval == 0:
            self.model.hestia.apply_to(self.model)

        # SDR SOM update when the model stashed an sdr in the last forward.
        # The matching inputs were stashed by training_step.
        _last_sdr = getattr(self.model, "_last_sdr", None)
        _last_x = getattr(self, "_last_x", None)
        if (
            _last_sdr is not None
            and _last_x is not None
            and hasattr(self.model.sdr_semantic, "maybe_som_update")
        ):
            self.model.sdr_semantic.maybe_som_update(_last_x, _last_sdr)
        # Drop the reference so the batch is not kept alive between steps.
        self._last_x = None

        # Advance the wall-clock counter for the LR schedule; the first
        # steps (step <= 10) are excluded as warm-up.
        dt = now - (self._last_step_end or now)
        self._last_step_end = now
        if step > 10:
            self._total_training_time += dt

    # ------------------------------------------------------------------
    # Logging — mirrors the step=NNNNN line format of the legacy loop so
    # grep/tee pipelines keep working.
    # ------------------------------------------------------------------

    def _log_step(self, loss: torch.Tensor, y: torch.Tensor) -> None:
        """Log smoothed loss, bits-per-byte, throughput, MFU and VRAM.

        Args:
            loss: Detached scalar loss for the current micro-batch.
            y: Target token ids; used to index the token-bytes table.

        A non-finite loss, or one above 100, is reported through the
        ``train_loss_nan`` metric and excluded from the running averages.
        """
        loss_f = float(loss.item())
        if not math.isfinite(loss_f) or loss_f > 100:
            # Let Lightning raise / the trainer callbacks handle this.
            self.log("train_loss_nan", 1.0)
            return

        step = self.global_step
        self._ema_updates += 1
        self._smooth_loss = (
            self._ema_beta * self._smooth_loss + (1 - self._ema_beta) * loss_f
        )
        # Bias-correct by the number of EMA updates actually applied.
        debiased = self._smooth_loss / max(
            1 - self._ema_beta ** self._ema_updates, 1e-9
        )

        # Time since the last optimizer step, floored so the divisions below
        # are always well-defined.
        wall_now = time.time()
        dt = max(wall_now - (self._last_step_end or wall_now), 1e-6)
        tps = int(self._tokens_per_step / dt)
        mfu = (
            100.0
            * self._flops_per_token
            * self._tokens_per_step
            / dt
            / GPU_BF16_PEAK_FLOPS
        )

        # bpb live: y flat -> token_bytes LUT -> avg bytes/token
        bpt = debiased / math.log(2)
        if self._token_bytes is not None:
            with torch.no_grad():
                y_flat = y.reshape(-1)
                nbytes = self._token_bytes[y_flat]
                # Only tokens with a positive byte count contribute.
                mask = nbytes > 0
                denom = mask.sum().clamp(min=1).float()
                avg_bpt = (nbytes.float() * mask.float()).sum() / denom
                bpt_batch = float(avg_bpt.item())
            if step == 0 or self._bpt_ema <= 0.0:
                self._bpt_ema = bpt_batch
            else:
                self._bpt_ema = 0.98 * self._bpt_ema + 0.02 * bpt_batch
        bpb = bpt / max(self._bpt_ema, 1e-6)
        vram = (
            torch.cuda.memory_allocated() / 1024 / 1024
            if torch.cuda.is_available()
            else 0.0
        )

        self.log_dict(
            {
                "train/loss": debiased,
                "train/bpb": bpb,
                "train/bpt": bpt,
                "train/tps": float(tps),
                "train/mfu": float(mfu),
                "train/vram_mib": float(vram),
            },
            prog_bar=False,
            on_step=True,
            on_epoch=False,
        )

        # Match legacy one-line format: "step=NNNNN loss=x bpb=y tps=z ..."
        print(
            f"step={step:05d} loss={debiased:.4f} bpb={bpb:.4f} "
            f"bpt={bpt:.3f} bpt_div={self._bpt_ema:.2f} "
            f"tps={tps} dt_ms={dt*1000:.0f} mfu={mfu:.1f} "
            f"vram={vram:.0f}MiB",
            flush=True,
        )