Bug Report: 4 Critical Issues Found in NAS / GPU Path (branch do-not-merge-gpu-tests-from-23)
Reporter: ML Intern
Branch audited: do-not-merge-gpu-tests-from-23
Fix branch: fix-gpu-tests-from-23-v2
Severity: High. NAS trials have been wasted on dead parameters, and GPU training fails with OOM, NaN, or compile crashes.
Summary
During a neural-architecture-search audit of nas_helixlm.py v2.3 on the do-not-merge-gpu-tests-from-23 branch, four interrelated bugs were found in graph.py, mamba2.py, nodes.py, trainer.py, and nas_helixlm.py. All four have been fixed in branch fix-gpu-tests-from-23-v2 (see commit d6619f4).
| # | Bug | File(s) | Impact |
|---|---|---|---|
| 1 | nodes_per_column is ignored by _build_node_spec() | graph.py | NAS permutes a dead parameter; topology is not reproducible by seed alone |
| 2 | Mamba2 / SSM scan builds a 256-step autograd chain, causing OOM | mamba2.py, nodes.py | GPU OOM on SSM configs; ~1.1 GB of saved tensors per mamba2 node |
| 3 | fp16 AMP forced for seq_len > 128, causing NaN | nas_helixlm.py, trainer.py | Immediate NaN on small models where there is no memory pressure |
| 4 | torch.compile crashes on custom Python loops | mamba2.py, nodes.py, nas_helixlm.py | Inductor graph breaks / invalid kernels on GPU; compile silently skipped for SSM/Titans configs |
1. nodes_per_column is dead: _build_node_spec() ignores it
Evidence
config.py docstrings and presets (lines 22-28, 311-380) advertise nodes_per_column tuples like (2,2), (2,3,2), (3,4,4,3). Validation logic (lines 249-256) even pads/truncates the tuple to match n_columns.
However, graph.py lines 88-148 hardcode every column to:
```
column = [
    ("linear_attn" | "full_attn", {…}),
    ("swiglu", {…}),
    ("mamba2" | "ssm", {…}),   # if use_ssm
    ("titans", {…}),           # if use_titans and ci == 0
    ("gate", {…}),             # always appended
]
```
nodes_per_column is never read after validation.
Fix
Wire nodes_per_column into _build_node_spec() by repeating the [attention, swiglu] base pattern until the target count is reached. Optional SSM/Titans nodes consume one slot each. A gate is appended when there are multiple compute nodes or when ci > 0.
```python
for ci in range(cfg.n_columns):
    target = cfg.nodes_per_column[ci]  # e.g. 3
    # Build base: attn + swiglu
    # Insert optional SSM/Titans if there is room
    # Repeat [attn, swiglu] to fill the remaining slots
    # Append gate for aggregation
    ...
```
The RNG seed already controls lateral/vertical wiring; with this fix the node count is also deterministic and reproducible.
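One possible reading of the fill rule above, as a pure-Python sketch. build_column_spec is a hypothetical helper that returns node-type names only (the real spec pairs each name with a kwargs dict), defaults to "linear_attn" as the attention type, and assumes the gate is not counted toward target, which the description leaves open.

```python
def build_column_spec(target, ci, use_ssm=False, use_titans=False,
                      attn="linear_attn", ssm="mamba2"):
    """Hypothetical sketch: node types for column ci with `target` compute nodes.

    Assumption: the trailing gate is not counted toward `target`.
    """
    base = [attn, "swiglu"]
    nodes = base[:target]                  # base pattern, truncated if target < 2
    if use_ssm and len(nodes) < target:
        nodes.append(ssm)                  # optional SSM consumes one slot
    if use_titans and ci == 0 and len(nodes) < target:
        nodes.append("titans")             # Titans only in the first column
    i = 0
    while len(nodes) < target:             # repeat [attn, swiglu] to fill the rest
        nodes.append(base[i % 2])
        i += 1
    if len(nodes) > 1 or ci > 0:
        nodes.append("gate")               # aggregate multiple compute nodes
    return nodes

# Spec for a (2, 3, 2) layout with SSM enabled
spec = [build_column_spec(t, ci, use_ssm=True) for ci, t in enumerate((2, 3, 2))]
print(spec)
# [['linear_attn', 'swiglu', 'gate'], ['linear_attn', 'swiglu', 'mamba2', 'gate'], ['linear_attn', 'swiglu', 'gate']]
```

Because the output depends only on nodes_per_column and the flags, the same config always yields the same node list, independent of the RNG.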
2. OOM from mamba2 scan's 256-step autograd chain
Evidence
_ssd_chunked_scan (and _ssm_chunked_scan in nodes.py) keeps every intermediate h tensor for backward:
```python
h = A_c[:, t] * h + B_c[:, t] * x_c[:, t].unsqueeze(-1)
# h is (B, d_inner, d_state) and is kept for ALL timesteps
```
For B=32, d_inner=768, d_state=64, each mamba2 node alone stores ~1.1 GB of saved tensors. Multiple columns and loops multiply this.
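As a back-of-the-envelope check of how this scales, the sketch below counts only the h states in fp32 (4 bytes per element) and ignores the other operands autograd saves for each multiply, so it is a rough estimate under those assumptions rather than a reproduction of the measured figure above. saved_state_bytes is a hypothetical helper, and the chunk size of 16 used for the checkpointed comparison is an assumption.

```python
MiB = 2 ** 20

def saved_state_bytes(batch, d_inner, d_state, n_states, bytes_per_elem=4):
    """Bytes for n_states copies of h with shape (batch, d_inner, d_state)."""
    return batch * d_inner * d_state * bytes_per_elem * n_states

B, D_INNER, D_STATE = 32, 768, 64
SEQ_LEN, CHUNK = 256, 16  # CHUNK is an assumed chunk size

per_step = saved_state_bytes(B, D_INNER, D_STATE, n_states=1)
full_chain = saved_state_bytes(B, D_INNER, D_STATE, n_states=SEQ_LEN)
# Checkpointed: one saved h per chunk boundary + one chunk's states rebuilt in backward
checkpointed = saved_state_bytes(B, D_INNER, D_STATE, n_states=SEQ_LEN // CHUNK + CHUNK)

print(per_step / MiB)      # 6.0
print(full_chain / MiB)    # 1536.0
print(checkpointed / MiB)  # 192.0
```

Under these assumptions the checkpointed scan holds 8x fewer h states; with the chunk size near sqrt(seq_len), the ratio is roughly sqrt(seq_len) / 2.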
Fix
Wrap each chunk's inner loop in torch.utils.checkpoint:
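```python
import torch
import torch.utils.checkpoint

def _chunk_scan(h_in, A_c_in, B_c_in, x_c_in, C_c_in):
    # Per-step states created here are not kept for backward; they are recomputed
    h = h_in
    ys_c = []
    for t in range(x_c_in.shape[1]):  # chunk length; also handles a shorter last chunk
        h = A_c_in[:, t] * h + B_c_in[:, t] * x_c_in[:, t].unsqueeze(-1)
        y_t = ...  # readout of h with C_c_in[:, t] (elided)
        ys_c.append(y_t)
    return h, torch.stack(ys_c, dim=1)

# Called once per chunk inside the outer chunk loop
h, ys_chunk = torch.utils.checkpoint.checkpoint(
    _chunk_scan, h, A_c, B_c, x_c, C_c,
    use_reentrant=False,
)
```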
Only the chunk-boundary h states are kept for backward; the per-step states inside each chunk are recomputed during the backward pass. This trades ~10-20% extra compute for an order-of-magnitude memory reduction.
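A minimal, self-contained sketch that checks per-chunk checkpointing leaves outputs and gradients unchanged. It simplifies the state to shape (B, D) with a diagonal recurrence h_t = a_t * h_{t-1} + bx_t instead of the real (B, d_inner, d_state) scan; chunk_scan and scan are hypothetical names, not the functions in mamba2.py.

```python
import torch
from torch.utils.checkpoint import checkpoint

def chunk_scan(h, a, bx):
    """h_t = a_t * h_{t-1} + bx_t over one chunk; returns (last h, stacked states)."""
    hs = []
    for t in range(bx.shape[1]):
        h = a[:, t] * h + bx[:, t]
        hs.append(h)
    return h, torch.stack(hs, dim=1)

def scan(a, bx, chunk_size, use_checkpoint):
    """Chunked scan over a, bx of shape (B, T, D), optionally checkpointing each chunk."""
    h = torch.zeros_like(bx[:, 0])
    outs = []
    for start in range(0, bx.shape[1], chunk_size):
        a_c = a[:, start:start + chunk_size]
        bx_c = bx[:, start:start + chunk_size]
        if use_checkpoint:
            # Keeps only the chunk inputs (boundary h plus views of a, bx);
            # the per-step states are recomputed during backward
            h, hs = checkpoint(chunk_scan, h, a_c, bx_c, use_reentrant=False)
        else:
            h, hs = chunk_scan(h, a_c, bx_c)
        outs.append(hs)
    return torch.cat(outs, dim=1)
```

Choosing chunk_size near sqrt(T) balances the number of saved boundary states against the size of the chunk that is rebuilt during backward.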
3. fp16 AMP causes NaN on seq_len ≥ 256
Evidence
nas_helixlm.py line 395 forced dtype_str = "float32" everywhere because fp16 caused immediate NaN on d ≥ 256 with LR=3e-3. The root cause is not fp16 in general but this architecture's SSM scan, Titans memory updates, and ELU+1.0 feature maps, all of which produce values that underflow in fp16's narrow dynamic range.
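The range difference is easy to check directly. This sketch round-trips illustrative magnitudes (1e-8 for a tiny state value and 7e4 for a large one, both assumptions rather than values measured in this model) through each dtype; survives is a hypothetical helper.

```python
import torch

def survives(value: float, dtype: torch.dtype) -> bool:
    """True if `value` is still finite and nonzero after a round-trip through `dtype`."""
    t = torch.tensor(value, dtype=torch.float32).to(dtype)
    return bool(torch.isfinite(t).item()) and t.item() != 0.0

for dtype in (torch.float16, torch.bfloat16, torch.float32):
    info = torch.finfo(dtype)
    print(f"{str(dtype):15} smallest normal={info.tiny:.3e} max={info.max:.3e}")

print(survives(1e-8, torch.float16))   # False: rounds to zero in fp16
print(survives(1e-8, torch.bfloat16))  # True
print(survives(7e4, torch.float16))    # False: overflows to inf (fp16 max is 65504)
```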
Fix
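Use bfloat16 instead of fp16 on GPU. bf16 has the same 8-bit exponent as fp32, and therefore the same dynamic range, so small values in the scan and memory updates that would underflow in fp16 remain representable (at the cost of fewer mantissa bits: 7 versus fp16's 10). It is natively supported on Ampere and newer GPUs (A100, H100, L4) and is typically as fast as fp16.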
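```python
import torch

# Use bf16 AMP on GPUs that report bf16 support; otherwise stay in fp32
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    dtype_str = "bfloat16"
    use_amp = True
else:
    dtype_str = "float32"
    use_amp = False
```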
The Trainer was also updated to skip GradScaler when the AMP dtype is bf16 (no loss-scaling needed) and to pass torch.bfloat16 to torch.amp.autocast.
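The Trainer change can be sketched as follows. make_scaler and train_step are hypothetical helpers, not the actual trainer.py code, and torch.amp.GradScaler("cuda") assumes PyTorch 2.3 or newer (older releases expose torch.cuda.amp.GradScaler instead). Computing the loss in fp32 outside the autocast region is a common choice, not something this report specifies.

```python
import torch
import torch.nn.functional as F

def make_scaler(amp_dtype: torch.dtype, use_amp: bool):
    # Loss scaling protects fp16 gradients from underflow; bf16 does not need it
    if use_amp and amp_dtype == torch.float16:
        return torch.amp.GradScaler("cuda")
    return None

def train_step(model, x, y, optimizer, device_type, amp_dtype, use_amp, scaler):
    optimizer.zero_grad(set_to_none=True)
    with torch.amp.autocast(device_type=device_type, dtype=amp_dtype, enabled=use_amp):
        logits = model(x)
    # Compute the loss in fp32 outside autocast
    loss = F.cross_entropy(logits.float(), y)
    if scaler is None:
        loss.backward()
        optimizer.step()
    else:
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    return loss.item()
```

With bf16, make_scaler returns None and the step is a plain backward() and step(); with fp16 on CUDA, the same train_step routes through the scaler.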
4. torch.compile breaks on custom Python loops
Evidence
The inductor backend cannot compile the for t in range(chunk_size) loops with in-place h mutations inside _ssd_chunked_scan and TitansMemoryNode.forward. The old workaround in nas_helixlm.py simply skipped compilation entirely for any SSM/Titans config:
```python
if use_ssm or use_titans:
    return model, False, "skipped: SSM/Titans autograd not compile-safe"
```
Fix
Decorate the loop functions with @torch.compiler.disable. TorchDynamo then skips tracing them: each call becomes a graph break and runs eagerly, while the rest of the model (embeddings, linear projections, attention, SwiGLU) is still compiled by inductor.
Decorated functions:
- mamba2._ssd_chunked_scan
- nodes._ssm_chunked_scan
- nodes.TitansMemoryNode.forward
try_compile_model in nas_helixlm.py now removes the SSM/Titans skip and attempts compilation for all configs.
Reproducing the Baseline
The sacred CPU baseline (1000 samples, seq_len=96, 14 epochs) should yield:
- Train perplexity ~ 23
- Val perplexity ~ 85-86
- Throughput ~ 1,892 tok/s
With these fixes applied, GPU configs should reach the same perplexities, with equal or better throughput thanks to bf16 + compile, and without NaN or OOM.
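Since perplexity is the exponential of the mean per-token cross-entropy, logged losses can be compared with the baseline directly, assuming the trainer logs natural-log (nats) cross-entropy. The helpers, the use of 85.5 (midpoint of the quoted 85-86 range), and the 5% tolerance below are hypothetical.

```python
import math

# Baseline perplexities quoted above; 85.5 is the midpoint of the 85-86 range
BASELINE_PPL = {"train": 23.0, "val": 85.5}

def ppl_from_nll(mean_nll: float) -> float:
    """Perplexity = exp(mean per-token cross-entropy in nats)."""
    return math.exp(mean_nll)

def matches_baseline(split: str, mean_nll: float, rel_tol: float = 0.05) -> bool:
    """Hypothetical check: measured perplexity within rel_tol of the baseline."""
    return math.isclose(ppl_from_nll(mean_nll), BASELINE_PPL[split], rel_tol=rel_tol)

print(round(math.log(BASELINE_PPL["train"]), 3))  # 3.135 nats
print(round(math.log(BASELINE_PPL["val"]), 3))    # 4.449 nats
```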
Patched Files
| File | Lines changed | Nature of change |
|---|---|---|
| helix_lm/graph.py | +56 / -20 | _build_node_spec() now reads nodes_per_column |
| helix_lm/mamba2.py | +25 / -6 | @torch.compiler.disable + checkpoint on chunk scan |
| helix_lm/nodes.py | +23 / -7 | Same for _ssm_chunked_scan + TitansMemoryNode.forward |
| helix_lm/trainer.py | +29 / -11 | bf16 autocast, conditional GradScaler |
| nas_helixlm.py | +25 / -12 | bf16 dtype selection, remove compile skip for SSM/Titans |
Branch
fix-gpu-tests-from-23-v2 (forked from do-not-merge-gpu-tests-from-23)
Note: I do not have write access to push this branch to GitHub. The commit d6619f4 is ready in the local workspace; please pull / cherry-pick / review before merging to main.