Bug Report: 4 Critical Issues Found in NAS / GPU Path (branch do-not-merge-gpu-tests-from-23)

Reporter: ML Intern
Branch audited: do-not-merge-gpu-tests-from-23
Fix branch: fix-gpu-tests-from-23-v2
Severity: High. NAS trials have been wasted on dead parameters; GPU training fails with OOM / NaN / compile crashes.


Summary

During a neural-architecture-search audit of nas_helixlm.py v2.3 on the do-not-merge-gpu-tests-from-23 branch, four interrelated bugs were found in graph.py, mamba2.py, nodes.py, trainer.py, and nas_helixlm.py. All four have been fixed in branch fix-gpu-tests-from-23-v2 (see commit d6619f4).

| # | Bug | File(s) | Impact |
|---|-----|---------|--------|
| 1 | nodes_per_column is ignored by _build_node_spec() | graph.py | NAS permutes a dead parameter; topology is not reproducible by seed alone |
| 2 | Mamba2 / SSM scan builds a 256-step autograd chain → OOM | mamba2.py, nodes.py | GPU OOM on SSM configs; ~1.1 GB saved tensors per mamba2 node |
| 3 | fp16 AMP forced for seq_len > 128 → NaN | nas_helixlm.py, trainer.py | Immediate NaN on small models where there is no memory pressure |
| 4 | torch.compile crashes on custom Python loops | mamba2.py, nodes.py, nas_helixlm.py | Inductor graph-breaks / invalid kernels on GPU; compile silently skipped for SSM/Titans configs |

1. nodes_per_column is dead: _build_node_spec() ignores it

Evidence

config.py docstrings and presets (lines 22-28, 311-380) advertise nodes_per_column tuples like (2,2), (2,3,2), (3,4,4,3). Validation logic (lines 249-256) even pads/truncates the tuple to match n_columns.

However, graph.py lines 88-148 hardcode every column to:

column = [
    ("linear_attn" | "full_attn", {…}),
    ("swiglu", {…}),
    ("mamba2" | "ssm", {…}),       # if use_ssm
    ("titans", {…}),               # if use_titans and ci==0
    ("gate", {…}),                 # always appended
]

nodes_per_column is never read after validation.

Fix

Wire nodes_per_column into _build_node_spec() by repeating the [attention, swiglu] base pattern until the target count is reached. Optional SSM/Titans nodes consume one slot each. A gate is appended when there are multiple compute nodes or when ci > 0.

for ci in range(cfg.n_columns):
    target = cfg.nodes_per_column[ci]  # e.g. 3
    # Build base: attn + swiglu
    # Insert optional SSM/Titans if room
    # Repeat [attn, swiglu] to fill remaining slots
    # Append gate for aggregation

The RNG seed already controls lateral/vertical wiring; with this fix the node count is also deterministic and reproducible.
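The slot-filling rules above can be made concrete with a minimal pure-Python sketch. This is not the code in commit d6619f4. The names build_column and build_node_spec are hypothetical, and the sketch returns node-type names only, whereas the real spec tuples also carry config dicts. It also assumes two things: the gate is not counted toward nodes_per_column[ci], and Titans is only placed in column 0, as in the hardcoded layout above.

```python
from typing import List


def build_column(ci: int, target: int, *, attn_kind: str = "linear_attn",
                 use_ssm: bool = False, ssm_kind: str = "mamba2",
                 use_titans: bool = False) -> List[str]:
    """Node types for column `ci` holding `target` compute nodes (hypothetical helper).

    Assumptions: the gate is not counted toward `target`, optional nodes are
    only added while there is room, and Titans only goes in column 0.
    """
    base = [attn_kind, "swiglu"]
    compute = base[:target]                 # attention + swiglu (truncated if target < 2)
    optional = []
    if use_ssm:
        optional.append(ssm_kind)           # each optional node consumes one slot
    if use_titans and ci == 0:
        optional.append("titans")
    for node in optional:
        if len(compute) < target:
            compute.append(node)
    i = 0
    while len(compute) < target:            # repeat [attn, swiglu] to fill the remaining slots
        compute.append(base[i % 2])
        i += 1
    if len(compute) > 1 or ci > 0:          # gate aggregates multiple nodes or upstream input
        compute.append("gate")
    return compute


def build_node_spec(nodes_per_column, **kwargs) -> List[List[str]]:
    """One column per entry of nodes_per_column; no randomness is involved."""
    return [build_column(ci, target, **kwargs) for ci, target in enumerate(nodes_per_column)]
```

For example, build_node_spec((2, 3, 2), use_ssm=True) yields a 3-node middle column that includes mamba2, while the 2-node columns have no room for it. The layout depends only on the config, so the same nodes_per_column always reproduces the same columns.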


2. OOM from mamba2 scan's 256-step autograd chain

Evidence

_ssd_chunked_scan (and _ssm_chunked_scan in nodes.py) keeps every intermediate h tensor for backward:

h = A_c[:, t] * h + B_c[:, t] * x_c[:, t].unsqueeze(-1)
# h is (B, d_inner, d_state); kept for ALL timesteps

For B=32, d_inner=768, d_state=64, each mamba2 node alone stores ~1.1 GB of saved tensors. Multiple columns and loops multiply this.
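The sketch below gives a back-of-the-envelope view of where the memory goes. It assumes autograd keeps one (B, d_inner, d_state) state per scan step. It also estimates the per-chunk checkpointing described under Fix, counting the chunk-boundary states plus one chunk's worth of states while that chunk is recomputed. The helper names are hypothetical. Real usage also depends on dtype, sequence length, and which other intermediates autograd saves, so treat the results as order-of-magnitude checks.

```python
def per_state_bytes(batch, d_inner, d_state, bytes_per_elem=4):
    """Size of one (batch, d_inner, d_state) scan state."""
    return batch * d_inner * d_state * bytes_per_elem


def full_chain_bytes(batch, d_inner, d_state, steps, bytes_per_elem=4):
    """No checkpointing: assume autograd keeps one state per step."""
    return per_state_bytes(batch, d_inner, d_state, bytes_per_elem) * steps


def checkpointed_bytes(batch, d_inner, d_state, steps, chunk_size, bytes_per_elem=4):
    """Per-chunk checkpointing: boundary states + one chunk recomputed at a time."""
    n_chunks = -(-steps // chunk_size)      # ceiling division
    return per_state_bytes(batch, d_inner, d_state, bytes_per_elem) * (n_chunks + chunk_size)


if __name__ == "__main__":
    full = full_chain_bytes(32, 768, 64, steps=256)
    ckpt = checkpointed_bytes(32, 768, 64, steps=256, chunk_size=16)
    print(f"full chain: {full / 1e9:.2f} GB, checkpointed (chunk=16): {ckpt / 1e9:.2f} GB")
```

For the shapes above at fp32, one saved state per step over 256 steps comes to roughly 1.6 GB. That is the same order of magnitude as the reported ~1.1 GB. Under this estimate, peak state memory with checkpointing scales with n_chunks + chunk_size. The saving is therefore largest when chunk_size is near the square root of the step count (16 for 256 steps, an 8x reduction).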

Fix

Wrap each chunk's inner loop in a non-reentrant torch.utils.checkpoint.checkpoint call:

def _chunk_scan(h_in, A_c_in, ...):
    h = h_in
    ys_c = []  # per-step outputs for this chunk
    for t in range(chunk_size):
        h = ...  # per-step recurrence on this chunk's slices
        ys_c.append(y_t)
    return h, torch.stack(ys_c, dim=1)

h, ys_chunk = torch.utils.checkpoint.checkpoint(
    _chunk_scan, h, A_c, B_c, x_c, C_c,
    use_reentrant=False,
)

Only the chunk-boundary h states are materialised for backward. The per-step states inside each chunk are recomputed during the backward pass. This trades roughly 10–20 % extra compute for an order-of-magnitude memory reduction.
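The snippet above elides the per-step body. The self-contained sketch below uses a plain diagonal SSM recurrence so the checkpointing pattern can be run and checked end to end. The tensor shapes, the einsum readout, and these versions of _chunk_scan and chunked_scan are illustrative assumptions, not the exact mamba2.py code. The point it demonstrates is that checkpointing each chunk changes memory use, not results or gradients.

```python
import torch
from torch.utils.checkpoint import checkpoint


def _chunk_scan(h, A, B, x, C):
    """Sequential diagonal-SSM scan over one chunk (illustrative shapes).

    h: (batch, d_inner, d_state)     state carried in from the previous chunk
    A: (batch, T, d_inner, d_state)  per-step decay
    B: (batch, T, d_state)           per-step input projection
    x: (batch, T, d_inner)           per-step input
    C: (batch, T, d_state)           per-step readout
    """
    ys = []
    for t in range(x.shape[1]):
        # h_t = A_t * h_{t-1} + outer(x_t, B_t)
        h = A[:, t] * h + x[:, t].unsqueeze(-1) * B[:, t].unsqueeze(1)
        # y_t[b, i] = sum_s h_t[b, i, s] * C_t[b, s]
        ys.append(torch.einsum("bis,bs->bi", h, C[:, t]))
    return h, torch.stack(ys, dim=1)


def chunked_scan(h, A, B, x, C, chunk_size, use_checkpoint=True):
    """Run the scan chunk by chunk, optionally checkpointing each chunk."""
    ys = []
    for start in range(0, x.shape[1], chunk_size):
        sl = slice(start, start + chunk_size)
        args = (h, A[:, sl], B[:, sl], x[:, sl], C[:, sl])
        if use_checkpoint:
            # Per-step states inside the chunk are not saved; they are
            # recomputed from the chunk inputs during backward.
            h, y = checkpoint(_chunk_scan, *args, use_reentrant=False)
        else:
            h, y = _chunk_scan(*args)
        ys.append(y)
    return h, torch.cat(ys, dim=1)
```

Running chunked_scan with use_checkpoint=True and use_checkpoint=False on the same inputs should give matching outputs and gradients. Only peak activation memory differs. Non-reentrant mode is also what allows torch.autograd.grad to be used through the checkpointed chunks.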


3. fp16 AMP causes NaN on seq_len ≥ 256

Evidence

nas_helixlm.py line 395 forced dtype_str = "float32" everywhere because fp16 caused immediate NaN on d ≥ 256 with LR=3e-3. The root cause is not fp16 in general. It is this architecture's SSM scan, Titans memory updates, and ELU+1.0 feature maps, all of which underflow in fp16's narrow dynamic range.

Fix

Use bfloat16 instead of fp16 on GPU. bf16 has the same 8-bit exponent as fp32, so values that underflow in fp16 during the scan and memory updates remain representable. It is natively supported on Ampere and newer GPUs (e.g. A100, L4, H100) and is typically as fast as fp16.

import torch

if torch.cuda.is_available():
    dtype_str = "bfloat16"  # bf16 autocast on GPU; native on Ampere and newer
    use_amp = True
else:
    dtype_str = "float32"   # CPU: plain fp32, no autocast
    use_amp = False
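The dynamic-range argument can be checked with standard PyTorch dtype casts. A very small value, of the kind a long chain of decays produces, flushes to zero in fp16 but survives in bf16. A large value overflows fp16 but not bf16. The helper name survives_cast is hypothetical.

```python
import torch


def survives_cast(value: float, dtype: torch.dtype) -> bool:
    """True if `value` is still finite and non-zero after casting to `dtype`."""
    x = torch.tensor(value, dtype=torch.float32).to(dtype)
    return bool(torch.isfinite(x)) and x.item() != 0.0


# Range and precision of each dtype.
for dtype in (torch.float16, torch.bfloat16, torch.float32):
    fi = torch.finfo(dtype)
    print(f"{str(dtype):15} max={fi.max:.2e} smallest_normal={fi.tiny:.2e} eps={fi.eps:.2e}")

# A tiny state (e.g. after many decays) and a large activation.
for value in (1e-8, 7e4):
    print(value, {str(d): survives_cast(value, d) for d in (torch.float16, torch.bfloat16)})
```

The trade-off is precision. torch.finfo(torch.bfloat16).eps is 2^-7, compared with 2^-10 for fp16. This is one reason bf16 is normally used through autocast, which leaves the parameters themselves in fp32.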

The Trainer was also updated to skip GradScaler when the AMP dtype is bf16 (no loss-scaling needed) and to pass torch.bfloat16 to torch.amp.autocast.
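Here is a minimal sketch of that Trainer change. It assumes a recent PyTorch (2.3+), in which torch.amp.GradScaler takes the device type as its first argument. make_amp_tools and train_step are hypothetical names, not the actual Trainer API. A GradScaler is only created for fp16; bf16 and fp32 take the plain backward path.

```python
import contextlib

import torch


def make_amp_tools(dtype_str, device_type):
    """Pick the autocast dtype and optional GradScaler (hypothetical helper)."""
    amp_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}.get(dtype_str)
    # fp16 needs loss scaling so small gradients do not flush to zero;
    # bf16 has fp32's exponent range, so no GradScaler is created for it.
    scaler = torch.amp.GradScaler(device_type) if amp_dtype is torch.float16 else None
    return amp_dtype, scaler


def train_step(model, inputs, targets, optimizer, loss_fn, amp_dtype, scaler, device_type):
    optimizer.zero_grad(set_to_none=True)
    autocast = (torch.amp.autocast(device_type=device_type, dtype=amp_dtype)
                if amp_dtype is not None else contextlib.nullcontext())
    with autocast:
        loss = loss_fn(model(inputs), targets)
    if scaler is not None:        # fp16 path: scaled backward + unscaled step
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    else:                         # bf16 / fp32 path: plain backward
        loss.backward()
        optimizer.step()
    return loss.detach()
```

Keeping the scaler optional, instead of constructing it with enabled=False, makes the bf16 path explicit in the training loop.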


4. torch.compile breaks on custom Python loops

Evidence

The inductor backend cannot compile the for t in range(chunk_size) loops with in-place h mutations inside _ssd_chunked_scan and TitansMemoryNode.forward. The old workaround in nas_helixlm.py simply skipped compilation entirely for any SSM/Titans config:

if use_ssm or use_titans:
    return model, False, "skipped: SSM/Titans autograd not compile-safe"

Fix

Decorate the loop functions with @torch.compiler.disable. This makes TorchDynamo skip tracing them: each call becomes a graph break and the function runs eagerly. The rest of the model (embeddings, linear projections, attention, SwiGLU) still gets compiled.

Decorated functions:

  • mamba2._ssd_chunked_scan
  • nodes._ssm_chunked_scan
  • nodes.TitansMemoryNode.forward

try_compile_model in nas_helixlm.py now removes the SSM/Titans skip and attempts compilation for all configs.


Reproducing the Baseline

The sacred CPU baseline (1000 samples, seq_len=96, 14 epochs) should yield:

  • Train perplexity ~ 23
  • Val perplexity ~ 85–86
  • Throughput ~ 1,892 tok/s

With these fixes applied, GPU configs should hit the same numbers (or better, thanks to bf16 + compile) without NaN or OOM.
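A small sketch for comparing a run against these numbers. It assumes the trainer reports perplexity as exp of the mean per-token cross-entropy in nats, so a val perplexity of 85 corresponds to a loss of about 4.44. The 10 % tolerance and the helper names are illustrative choices, not part of the baseline.

```python
import math

# Upper ends of the baseline quoted above (CPU, 1000 samples, seq_len=96, 14 epochs).
BASELINE_TRAIN_PPL = 23.0
BASELINE_VAL_PPL = 86.0


def perplexity(mean_nll_nats: float) -> float:
    """Perplexity = exp(mean per-token cross-entropy in nats) (assumed convention)."""
    return math.exp(mean_nll_nats)


def meets_baseline(train_loss: float, val_loss: float, rtol: float = 0.10) -> bool:
    """True if both losses are finite and no worse than the baseline plus rtol slack."""
    if not (math.isfinite(train_loss) and math.isfinite(val_loss)):
        return False                          # a NaN/inf loss is an immediate failure
    return (perplexity(train_loss) <= BASELINE_TRAIN_PPL * (1 + rtol)
            and perplexity(val_loss) <= BASELINE_VAL_PPL * (1 + rtol))
```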


Patched Files

| File | Lines changed | Nature of change |
|------|---------------|------------------|
| helix_lm/graph.py | +56 / −20 | _build_node_spec() now reads nodes_per_column |
| helix_lm/mamba2.py | +25 / −6 | @torch.compiler.disable + checkpoint on chunk scan |
| helix_lm/nodes.py | +23 / −7 | Same for _ssm_chunked_scan + TitansMemoryNode.forward |
| helix_lm/trainer.py | +29 / −11 | bf16 autocast, conditional GradScaler |
| nas_helixlm.py | +25 / −12 | bf16 dtype selection, remove compile skip for SSM/Titans |

Branch

fix-gpu-tests-from-23-v2 (forked from do-not-merge-gpu-tests-from-23)

Note: I do not have write access to push this branch to GitHub. The commit d6619f4 is ready in the local workspace; please pull / cherry-pick / review before merging to main.