Instructions to use KexuanShi/Megatron-LM with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- NeMo
How to use KexuanShi/Megatron-LM with NeMo:
# tag did not correspond to a valid NeMo domain.
- Notebooks
- Google Colab
- Kaggle
| import re | |
| from argparse import ArgumentParser | |
| from pathlib import Path | |
| from typing import Dict, Iterable, Optional, Tuple, Union | |
| import torch | |
| from torch.distributed.checkpoint.filesystem import FileSystemReader | |
| from torch.distributed.checkpoint.state_dict_loader import load | |
| TensorLike = Union[torch.Tensor, Iterable[torch.Tensor]] | |
| def _as_iter(x: TensorLike): | |
| return x if (isinstance(x, Iterable) and not isinstance(x, torch.Tensor)) else [x] | |
def _fro_norm(x: TensorLike) -> torch.Tensor:
    """Frobenius norm; supports sharded tensors (sqrt of the sum of shard ||·||_F^2).

    Args:
        x: A single tensor, or an iterable of tensor shards.

    Returns:
        A 0-dim float32 tensor holding the norm. It lives on the device of the
        first shard, or on the CPU when there are no shards (the norm is then 0).
    """
    # Materialize once: peeking at the first shard of a generator to learn its
    # device would otherwise consume that shard and silently drop it from the sum.
    shards = list(_as_iter(x))
    device = shards[0].device if shards else "cpu"
    total = torch.tensor(0.0, device=device)
    for shard in shards:
        # Accumulate in float32 regardless of the shard's storage dtype.
        total = total + shard.float().pow(2).sum()
    return torch.sqrt(total)
def machine_epsilon_for_dtype(dtype: torch.dtype) -> float:
    """Return machine epsilon for dtype. For FP8, use BF16 epsilon per paper.

    Args:
        dtype: The dtype whose unit round-off is wanted.

    Returns:
        ``torch.finfo(dtype).eps`` for float32/float16/bfloat16, the bfloat16
        epsilon for any FP8 dtype this torch build exposes, and the float32
        epsilon for everything else.
    """
    # Standard types
    if dtype in (torch.float32, torch.float16, torch.bfloat16):
        return float(torch.finfo(dtype).eps)
    # FP8 recipes: accum/store typically BF16/FP32; use BF16 epsilon.
    # Look the dtypes up by name so older torch builds without FP8 still work.
    # Note the E5M2 dtype is `float8_e5m2` (there is no `float8_e5m2fn`).
    fp8_names = ("float8_e4m3fn", "float8_e5m2", "float8_e4m3fnuz", "float8_e5m2fnuz")
    fp8_dtypes = tuple(getattr(torch, name) for name in fp8_names if hasattr(torch, name))
    # `None` keeps mapping to the BF16 epsilon, as it did before on FP8-capable builds.
    if dtype in fp8_dtypes or (dtype is None and fp8_dtypes):
        return float(torch.finfo(torch.bfloat16).eps)
    # Fallback
    return float(torch.finfo(torch.float32).eps)
def relative_grad_diff(g_hat: TensorLike, g_ref: TensorLike, eps_den: float = 1e-30) -> float:
    """
    Relative difference ||g_hat - g_ref||_F / ||g_ref||_F.

    Accepts a single tensor or an iterable of shards for each argument.
    If sharded, shards are assumed to align 1:1; otherwise pass the merged tensors.

    Args:
        g_hat: Gradient under test.
        g_ref: Reference gradient.
        eps_den: Small constant added to the denominator to avoid division by zero.

    Returns:
        The relative error as a Python float.

    Raises:
        ValueError: If the two arguments have a different number of shards.
    """
    # Materialize both sides exactly once so generators are neither exhausted
    # by the length check nor partially consumed before the main loop.
    hat_shards = list(_as_iter(g_hat))
    ref_shards = list(_as_iter(g_ref))
    if len(hat_shards) != len(ref_shards):
        # zip() would silently truncate to the shorter side and under-report the error.
        raise ValueError(
            f"Shard count mismatch: g_hat has {len(hat_shards)} shard(s), "
            f"g_ref has {len(ref_shards)}"
        )
    device = hat_shards[0].device if hat_shards else "cpu"
    num_sq = torch.tensor(0.0, device=device)
    for a, b in zip(hat_shards, ref_shards):
        num_sq = num_sq + (a.float() - b.float()).pow(2).sum()
    num = torch.sqrt(num_sq)
    den = _fro_norm(ref_shards)
    return float(num / (den + eps_den))
def expected_rel_bound(
    l: int,
    *,
    L: int = 32,
    C: float = 1.03,
    dtype: Optional[torch.dtype] = torch.bfloat16,
    k: float = 4.0,
) -> float:
    """
    Expected relative-error bound for the gradient of layer ``l``.

    The bound is ``k * C ** (L + 1 - l) * eps_mch`` with a 1-based layer index
    ``l`` (see https://www.arxiv.org/pdf/2506.09280 theorem 5.3).

    Args:
        l: Layer index. The exponent ``L + 1 - l`` is clamped at zero, so any
            ``l`` beyond ``L + 1`` yields ``k * eps_mch``.
        L: Total number of layers; defaults to 32.
        C: Per-layer growth factor, 'close to 1'; 1.01–1.05 are reasonable values.
        dtype: Selects ``eps_mch``; a falsy value means bfloat16. FP8 dtypes are
            meant to use the BF16 epsilon.
        k: Absorbs the hidden constant in the big-O; 2–8 are common choices.

    Returns:
        The bound as a Python float.
    """
    effective_dtype = dtype if dtype else torch.bfloat16
    eps_mch = machine_epsilon_for_dtype(effective_dtype)
    # Depth counted from the top of the network (as in the theorem), never negative.
    layers_from_top = max(L + 1 - l, 0)
    growth = C**layers_from_top
    return float(k * growth * eps_mch)
def check_gradient(
    g_hat: TensorLike,
    g_ref: TensorLike,
    l: int,
    *,
    L: int = 32,
    C: float = 1.03,
    dtype: Optional[torch.dtype] = None,
    k: float = 4.0,
) -> Tuple[float, float, bool]:
    """
    Compute (rel_error, bound, ok) for layer l.

    See https://www.arxiv.org/pdf/2506.09280 theorem 5.3.

    Args:
        g_hat: Gradient under test (tensor or iterable of shards).
        g_ref: Reference gradient (tensor or iterable of shards).
        l: Layer index forwarded to ``expected_rel_bound``.
        L, C, k: Bound parameters forwarded to ``expected_rel_bound``.
        dtype: Dtype used for the machine epsilon. If None, it is inferred from
            the first shard of ``g_ref``, or of ``g_hat`` when ``g_ref`` has no shards.

    Returns:
        ``(rel, bnd, rel <= bnd)``.
    """
    # Materialize shard iterables up front: peeking at the first shard to infer
    # the dtype must not consume it before the error is computed.
    hat = g_hat if isinstance(g_hat, torch.Tensor) else list(_as_iter(g_hat))
    ref = g_ref if isinstance(g_ref, torch.Tensor) else list(_as_iter(g_ref))

    # Infer dtype if not provided
    if dtype is None:
        for candidate in (ref, hat):
            first = next(iter(_as_iter(candidate)), None)
            if first is not None:
                dtype = first.dtype
                break

    rel = relative_grad_diff(hat, ref)
    bnd = expected_rel_bound(l, L=L, C=C, dtype=dtype, k=k)
    return rel, bnd, (rel <= bnd)
| def _filter_optimizer_tensors(plain_tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: | |
| """Return only optimizer-related tensors from a flat checkpoint tensor dict.""" | |
| return { | |
| k: v for k, v in plain_tensors.items() if k.startswith("optimizer.") and ".exp_avg." in k | |
| } | |
def assert_grads_close(left: torch.Tensor, right: torch.Tensor):
    """Raise ``AssertionError`` unless *left* matches *right* within the theorem-5.3 bound.

    Implements the check of theorem 5.3 in https://www.arxiv.org/pdf/2506.09280.
    The layer index is hard-coded to 0 because that is the most permissive bound,
    and the epsilon is always the bfloat16 one.

    When the check fails, two diagnostics are added to the error message: the
    relative error against *right* rolled by one position along its last
    dimension, and the outcome of ``torch.testing.assert_close(left, right)``.
    The function raises whether or not ``assert_close`` itself passes.
    """
    # This is the real test.
    rel, bnd, ok = check_gradient(left, right, l=0, dtype=torch.bfloat16)
    if ok:
        return

    # Diagnostics only from here on; an AssertionError is raised either way.
    rolled_right = torch.roll(right, shifts=-1, dims=-1)
    rel_shuff, _, ok_shuff = check_gradient(left, rolled_right, l=0, dtype=torch.bfloat16)

    headline = "AssertionError on relative norm magnitude "
    stats = f"(rel={rel}, bnd={bnd}, ok={ok}, rel_shuff={rel_shuff}, ok_shuff={ok_shuff})"
    tensor_dump = (
        f"Left: {left.shape}/{left.dtype} {left}\n"
        f"Right: {right.shape}/{right.dtype} {right}"
    )
    try:
        torch.testing.assert_close(left, right)
    except AssertionError as e:
        msg = headline + stats + f": {e}\n" + tensor_dump
    else:
        msg = (
            headline
            + stats
            + " but torch.testing.assert_close(left, right) passes. \n"
            + tensor_dump
        )
    raise AssertionError(msg)
def unshard_row_parallel_state(saved_state, out_features, in_features, tp):
    """Rebuild a ``[..., out_features, in_features]`` tensor from per-rank flat shards.

    Args:
        saved_state: Tensor viewed as ``[..., tp, out_features * (in_features // tp)]``;
            each entry along the ``tp`` axis is one rank's flattened
            ``[out_features, in_features // tp]`` slice.
        out_features: Size of the output (row) dimension.
        in_features: Size of the full input (column) dimension.
        tp: Number of shards along the input dimension.

    Returns:
        Tensor of shape ``[..., out_features, in_features]`` in which the shards
        are laid side by side along the last dimension, in rank order.
    """
    leading = saved_state.shape[:-2]
    shard_width = in_features // tp
    # [..., tp, O, I_shard]
    per_rank = saved_state.view(*leading, tp, out_features, shard_width)
    # [..., O, tp, I_shard]: put the rank axis next to the column axis it splits.
    rows_first = per_rank.transpose(-3, -2)
    # [..., O, I]
    return rows_first.reshape(*leading, out_features, in_features)
def _assert_optimizer_tensors_equal(
    left: Dict[str, torch.Tensor],
    right: Dict[str, torch.Tensor],
    left_empty: Dict[str, torch.Tensor],
    right_empty: Dict[str, torch.Tensor],
    eps=1e-4,
):
    """Assert that two flat dicts of optimizer tensors agree key by key.

    Args:
        left: Optimizer tensors of the first checkpoint, keyed by full name.
        right: Optimizer tensors of the second checkpoint, same keys expected.
        left_empty: Placeholder tensors for the first checkpoint; only their
            ``.shape`` is read, looked up by the key with the
            ``"optimizer.state.exp_avg."`` prefix removed.
        right_empty: Same as ``left_empty`` for the second checkpoint.
        eps: Tensors whose max-abs value is ``<= eps`` on both sides are reported
            as zero-ish and do not count towards the "some non-zero" requirement.

    Raises:
        AssertionError: On a key-set mismatch, an unreconcilable shape or dtype
            mismatch, any per-key gradient mismatch (collected and reported
            together after all keys were visited), or when every compared
            tensor was zero-ish.
    """
    left_keys = set(left.keys())
    right_keys = set(right.keys())
    only_left = sorted(left_keys - right_keys)
    only_right = sorted(right_keys - left_keys)
    assert (
        not only_left and not only_right
    ), f"Optimizer tensor keys mismatch.\nOnly in left: {only_left}\nOnly in right: {only_right}"

    some_non_zero = False
    assertions = []
    for key in sorted(left_keys):
        lt, rt = left[key], right[key]
        # Embedding / output-layer weights may be padded differently between runs.
        is_vocab_tensor = (
            "embedding.word_embeddings.weight" in key or ".output_layer.weight" in key
        )
        if lt.shape != rt.shape:
            # The saved optimizer state is flattened per shard, so its shape depends
            # on the parallel layout. Example for
            # decoder.layers.self_attention.linear_proj.weight (param [32, 3072, 4096]):
            #   unsharded exp_avg: [32, 1, 1, 12582912]
            #   sharded exp_avg:   [32, 1, 2, 6291456]
            # Use the parameter's own shape to bring both sides to a common layout.
            original_key = key.replace("optimizer.state.exp_avg.", "")
            left_shape = left_empty[original_key].shape
            right_shape = right_empty[original_key].shape
            skip_tp_check = False
            if left_shape != right_shape:
                if not is_vocab_tensor:
                    raise AssertionError(
                        f"Tensor shape mismatch for {key}: {left_shape} vs {right_shape}"
                    )
                # First handle different padding on the input/output dimensions:
                # compare only the rows both sides have.
                lt = lt.reshape(left_shape)
                rt = rt.reshape(right_shape)
                min_dim = min(left_shape[0], right_shape[0])
                lt = lt[:min_dim, ...]
                rt = rt[:min_dim, ...]
                left_shape = lt.shape
                right_shape = rt.shape
                skip_tp_check = True

            # Problem: we do not know the TP axis for this tensor. We can guess though.
            if not skip_tp_check:
                if len(left_shape) == 3:
                    # TP axis is 1
                    lt = lt.reshape(left_shape[0], 1, left_shape[1], left_shape[2])
                elif len(left_shape) == 2:
                    # TP axis is 2
                    lt = lt.reshape(left_shape[0], 1, left_shape[1])

            # Handle row parallel linear layers.
            # TODO come up with a better way to determine row parallel linear layers.
            is_row_parallel = key.endswith(
                ("mlp.linear_fc2.weight", "self_attention.linear_proj.weight")
            )
            if is_row_parallel and not skip_tp_check:
                rt = unshard_row_parallel_state(
                    rt, out_features=left_shape[1], in_features=left_shape[2], tp=rt.shape[2]
                )
            else:
                try:
                    rt = rt.reshape(lt.shape)
                except Exception as e:
                    msg = f"Tensor shape mismatch for {key}: {lt.shape} vs {rt.shape}, simple reshape failed: {e}"
                    if is_vocab_tensor:
                        print(
                            f"FIXME: Skipping {key} because it's a word embedding or output layer,"
                            "and something about padding changes under TP."
                        )
                        continue
                    raise AssertionError(msg) from e

        assert (
            lt.shape == rt.shape and lt.dtype == rt.dtype
        ), f"Tensor meta mismatch for {key}: {lt.shape}/{lt.dtype} vs {rt.shape}/{rt.dtype}"

        # Track whether at least one pair carries a real signal; a comparison in
        # which every tensor is (near) zero says nothing about the gradients.
        left_scale = torch.max(torch.abs(lt))
        right_scale = torch.max(torch.abs(rt))
        if left_scale <= eps and right_scale <= eps:
            print(
                f"WARNING: zero-ish scale tensors ({left_scale=} vs {right_scale=}) "
                f"so they will trivially pass comparing {key=}"
            )
        else:
            some_non_zero = True

        # Collect failures instead of stopping at the first one so that a single
        # run reports every mismatching key.
        try:
            assert_grads_close(lt, rt)
            print(f"Optimizer tensors match for {key}")
        except AssertionError as e:
            assertions.append(AssertionError(f"AssertionError for {key}: {e}"))

    assert not assertions, f"Assertion Errors found comparing keys: {assertions}"
    assert some_non_zero, "No non-zero tensors found in this comparison"
def load_dist_checkpoint_pt(
    ckpt_dir,
    metadata_ckpt_dir=None,
    pattern=r"optimizer",
    device="cpu",
    return_full_empty: bool = False,
):
    """Load tensors from a torch distributed checkpoint as ``{full_key: tensor}``.

    Args:
        ckpt_dir: Directory the tensor data is read from.
        metadata_ckpt_dir: Directory the checkpoint metadata is read from;
            defaults to ``ckpt_dir``.
        pattern: Regex passed to ``re.search``; only keys that match are loaded.
            Ignored when ``return_full_empty`` is true.
        device: Device on which the tensors are allocated.
        return_full_empty: When true, return uninitialized placeholders (correct
            shape and dtype, no data read) for every entry that has a ``size``,
            regardless of ``pattern``.

    Returns:
        A dict mapping each selected key to its tensor.

    Raises:
        ValueError: If no key is selected.
    """
    metadata_root = Path(metadata_ckpt_dir or ckpt_dir)

    # Metadata pass only: no tensor data is touched yet.
    checkpoint_meta = FileSystemReader(str(metadata_root)).read_metadata()
    tensor_meta = checkpoint_meta.state_dict_metadata

    def _is_wanted(name):
        # Entries without a `size` attribute cannot be turned into a placeholder.
        if not return_full_empty and not re.search(pattern, name):
            return False
        return hasattr(tensor_meta[name], "size")

    wanted = [name for name in tensor_meta if _is_wanted(name)]
    if not wanted:
        raise ValueError(f"No keys matching /{pattern}/ in {ckpt_dir}")

    # Allocate uninitialized destination tensors with the recorded shape and dtype.
    placeholders = {}
    for name in wanted:
        entry = tensor_meta[name]
        placeholders[name] = torch.empty(
            tuple(entry.size), dtype=entry.properties.dtype, device=device
        )
    if return_full_empty:
        return placeholders

    # Fill the placeholders in place; no_dist=True so no process group is needed.
    load(
        state_dict=placeholders,
        storage_reader=FileSystemReader(str(ckpt_dir)),
        no_dist=True,
    )
    return placeholders
def test_optimizer_states_match(checkpoint_dirs):
    """
    Compare optimizer state across the provided torch_dist checkpoints.

    The first directory is the baseline; every other directory is compared to it:
    - Keys: the sets of optimizer ``exp_avg`` tensor keys must match.
    - Values: corresponding tensors must pass the gradient closeness check.

    Failures are collected per directory pair, so all pairs are checked before a
    single AssertionError listing every failure is raised.
    """
    assert len(checkpoint_dirs) > 1, "This test requires 2 or more checkpoints <dir1> [<dir2> ...]."
    base_dir, *other_dirs = checkpoint_dirs

    # Baseline: real optimizer tensors plus shape-only placeholders on the meta device.
    base_plain = load_dist_checkpoint_pt(base_dir)
    base_empty = load_dist_checkpoint_pt(base_dir, return_full_empty=True, device="meta")
    base_opt_tensors = _filter_optimizer_tensors(base_plain)
    assert base_opt_tensors, f"No optimizer tensors found in checkpoint: {base_dir}"

    assertions = []
    for other_dir in other_dirs:
        try:
            other_plain = load_dist_checkpoint_pt(other_dir)
            other_empty = load_dist_checkpoint_pt(other_dir, return_full_empty=True, device="meta")
            other_opt_tensors = _filter_optimizer_tensors(other_plain)
            assert other_opt_tensors, f"No optimizer tensors found in checkpoint: {other_dir}"
            _assert_optimizer_tensors_equal(
                base_opt_tensors, other_opt_tensors, base_empty, other_empty
            )
            print(f"Optimizer tensors match for {base_dir} and {other_dir}")
            # Drop the references before the next checkpoint is loaded.
            del other_plain
            del other_opt_tensors
        except AssertionError as e:
            msg = f"AssertionError comparing {base_dir} to {other_dir}:\n{e}"
            print(f"Optimizer tensors mismatch for {base_dir} and {other_dir}:\n{msg}")
            assertions.append(AssertionError(msg))

    assert not assertions, f"AssertionErrors comparing {checkpoint_dirs}:\n{assertions}"
def main():
    """Command-line entry point: compare the checkpoints given as positional arguments."""
    description = (
        "Given checkpoints saved with adam b1,b2=0 trained for one step, "
        "we can check that the gradients match under different training configurations. "
        "Currently this test script has some hard-coded assumptions for GPT style models, "
        "namely which layers are RowParallel and require different unsharding logic."
    )
    cli = ArgumentParser(description=description)
    cli.add_argument(
        "checkpoints", nargs="+", type=Path, help="Path to the checkpoints to compare"
    )
    parsed = cli.parse_args()
    test_optimizer_states_match(parsed.checkpoints)


if __name__ == "__main__":
    main()