khala / models /Megatron /tests /unit_tests /test_optimizer.py
multimodalart's picture
multimodalart HF Staff
Initial best-effort ZeroGPU port of Khala song generation
d1f1097 verified
import os
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD, Adam
# FP8 recipe will be used to test precision-aware-optimizer.
from transformer_engine.pytorch.fp8 import fp8_autocast
from megatron.core.distributed import DistributedDataParallel, DistributedDataParallelConfig
from megatron.core.optimizer import ChainedOptimizer, OptimizerConfig, get_megatron_optimizer
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer import TransformerConfig
from megatron.core.utils import is_te_min_version, is_torch_min_version
from tests.unit_tests.test_utilities import Utils
from tests.unit_tests.test_utils import _deinit_distributed, _init_distributed
# Optional Transformer Engine (TE) features. Each probe records an availability flag so
# that tests can pick an FP8 recipe without failing at import time on older TE builds.
try:
    # Check if FP8 block scaling is available.
    from transformer_engine.pytorch.fp8 import check_fp8_block_scaling_support

    fp8_block_scaling_available, reason_for_no_fp8_block_scaling = check_fp8_block_scaling_support()
    from transformer_engine.common.recipe import Float8BlockScaling, Format
except Exception:
    # Deliberately broad: the try body calls the support probe in addition to importing.
    # Unlike a bare ``except:``, this still lets KeyboardInterrupt/SystemExit propagate.
    fp8_block_scaling_available = False
    reason_for_no_fp8_block_scaling = "FP8 block scaled GEMM requires Hopper and CUDA >= 12.9."

try:
    from transformer_engine.common.recipe import DelayedScaling
except Exception:
    delayed_scaling_available = False
else:
    # Bind the flag on the success path too, so reading it never raises NameError.
    delayed_scaling_available = True
class Net(nn.Module):
    """Small convolutional classifier used as a model fixture by the optimizer tests.

    Two conv -> ReLU -> max-pool stages feed three fully connected layers that emit
    10 values per sample. The ``16 * 5 * 5`` input width of ``fc1`` corresponds to
    3x32x32 inputs.
    """

    def __init__(self):
        super().__init__()
        # Submodules are registered in this order, which fixes the parameter order.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Return the 10 output values for each sample in the batch ``x``."""
        # Feature extractor: conv -> ReLU -> pool, applied twice with the shared pool.
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        # Collapse every dimension except the batch dimension.
        x = x.flatten(start_dim=1)
        # Two hidden layers with ReLU, followed by the plain linear output layer.
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))
        return self.fc3(x)
def test_chained_optimizer():
    """Check that ChainedOptimizer shares param groups and state with the optimizers it wraps.

    An Adam optimizer (first two parameters of ``Net``) and an SGD optimizer (the rest)
    are chained. Writes made through the chained optimizer's ``param_groups`` and
    ``state`` must be observable on the wrapped optimizers.
    """
    net = Net()
    params = list(net.parameters())
    optimizer_1 = Adam(params[:2], lr=0.01)
    optimizer_2 = SGD(params[2:], lr=0.1, momentum=0.9)
    chained_optimizer = ChainedOptimizer([optimizer_1, optimizer_2])

    # Param groups: an lr written through the chained optimizer shows up on Adam.
    assert optimizer_1.param_groups[0]["lr"] == 0.01
    chained_optimizer.param_groups[0]["lr"] = 0.02
    assert optimizer_1.param_groups[0]["lr"] == 0.02

    # State: empty before any step, populated after both wrapped optimizers step.
    assert len(chained_optimizer.state) == 0
    batch = torch.randn(1, 3, 32, 32)
    net(batch).sum().backward()
    for wrapped in (optimizer_1, optimizer_2):
        wrapped.step()
    assert len(chained_optimizer.state) != 0

    def first_state_entry(optimizer, key):
        # Entry `key` of the first per-parameter state dict held by `optimizer`.
        return list(optimizer.state.values())[0][key]

    # The state tensors start out on the CPU.
    assert not first_state_entry(optimizer_1, "exp_avg").is_cuda
    assert not first_state_entry(optimizer_2, "momentum_buffer").is_cuda

    def to_cuda(mapping):
        # Move tensors to the GPU in place, recursing into nested dicts; return `mapping`.
        for key, value in mapping.items():
            if isinstance(value, dict):
                to_cuda(value)
            elif isinstance(value, torch.Tensor):
                mapping[key] = value.to("cuda")
        return mapping

    # Moving the state through the chained optimizer must move the wrapped optimizers' state.
    for key, value in chained_optimizer.state.items():
        chained_optimizer.state[key] = to_cuda(value)

    assert first_state_entry(optimizer_1, "exp_avg").is_cuda
    assert first_state_entry(optimizer_2, "momentum_buffer").is_cuda
def test_precision_aware_fused_adam():
    """Check TE FusedAdam with BF16 params and master weights against an FP32 reference.

    Both optimizers receive identical random gradients for 1000 steps. Afterwards the
    FP32 reference parameters must match the BF16 optimizer's master parameters
    bit-for-bit, and the reference cast to BF16 must match the BF16 parameters
    bit-for-bit.

    The test is reported as skipped (rather than silently passing) when the installed
    Transformer Engine has no FusedAdam or lacks the precision-aware arguments.
    """
    try:
        from transformer_engine.pytorch.optimizers import FusedAdam
    except ImportError:
        # Older versions of TE don't have FusedAdam.
        pytest.skip("transformer_engine.pytorch.optimizers.FusedAdam is not available")

    import inspect

    adam_args = inspect.signature(FusedAdam).parameters
    arg_names = ["master_weight_dtype", "exp_avg_dtype", "exp_avg_sq_dtype", "use_decoupled_grad"]
    missing_args = [name for name in arg_names if name not in adam_args]
    if missing_args:
        # TE doesn't support precision aware FusedAdam.
        pytest.skip(f"TE FusedAdam lacks precision-aware arguments: {missing_args}")

    tensor = torch.rand(278011, dtype=torch.bfloat16).cuda()
    params_1 = [torch.nn.Parameter(tensor.float())]  # FP32 reference
    params_2 = [torch.nn.Parameter(tensor.clone())]  # BF16
    options = {"lr": 1, "betas": (0.1, 0.25), "eps": 1e-08, "weight_decay": 0, "amsgrad": False}
    optimizer_1 = FusedAdam(params_1, **options)
    optimizer_2 = FusedAdam(params_2, master_weights=True, use_decoupled_grad=True, **options)

    for _ in range(1000):
        for p_1, p_2 in zip(params_1, params_2):
            # Feed the same gradient to both; the BF16 side reads it from `decoupled_grad`.
            p_1.grad = torch.rand_like(p_1)
            p_2.decoupled_grad = p_1.grad.clone()
        optimizer_1.step()
        optimizer_2.step()

    master_params = [optimizer_2.get_unscaled_state(p, "master_param") for p in params_2]
    for p_1, p_2 in zip(params_1, master_params):
        bytes_1 = p_1.data.view(torch.uint8)
        bytes_2 = p_2.data.view(torch.uint8)
        # Make sure bit-wise matched
        assert torch.all(bytes_1 == bytes_2)

    for p_1, p_2 in zip(params_1, params_2):
        bytes_1 = p_1.data.bfloat16().view(torch.uint8)
        bytes_2 = p_2.data.view(torch.uint8)
        # Make sure bit-wise matched
        assert torch.all(bytes_1 == bytes_2)
@pytest.mark.skipif(
    not is_te_min_version("1.13.0"), reason="TE 1.13.0 is required for precision aware optimizer"
)
@pytest.mark.parametrize("precision", ['bf16', 'fp8'])
@pytest.mark.parametrize("main_params_dtype", [torch.float32, torch.float16])
@pytest.mark.parametrize("main_grads_dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize(
    # use the same dtype for exp_avg and exp_avg_sq to reduce the number of tests
    "moment_dtype",
    [torch.float32, torch.float16, torch.bfloat16, torch.uint8],
)
def test_precision_aware_optimizer(
    precision: str,
    main_params_dtype: torch.dtype,
    main_grads_dtype: torch.dtype,
    moment_dtype: torch.dtype,
):
    """Compare one training step of the precision-aware optimizer against a baseline.

    The baseline disables the precision-aware optimizer and uses float32 for all
    optimizer dtypes; the test configuration enables it with the parametrized dtypes
    (``moment_dtype`` is used for both ``exp_avg`` and ``exp_avg_sq``). Loss and grad
    norm must agree within ``rtol=1e-3`` or ``atol=1e-5``. Finally the test optimizer's
    state dict is saved and loaded back.
    """
    # bf16 optimizer states are not supported before TE 2.3.0.
    if moment_dtype == torch.bfloat16 and not is_te_min_version("2.3.0"):
        pytest.skip("bfloat16 for moment_dtype requires TE >= 2.3.0")

    # Pick the FP8 recipe: blockwise scaling when available, delayed scaling otherwise.
    fp8_recipe = None
    fp8_recipe_settings = None
    if precision == 'fp8':
        if fp8_block_scaling_available:
            fp8_recipe = "blockwise"
            fp8_recipe_settings = Float8BlockScaling(fp8_format=Format.E4M3)
        else:
            fp8_recipe = "delayed"
            fp8_recipe_settings = DelayedScaling()

    # Setup: distributed, model, mock_args.
    _init_distributed(int(os.getenv('WORLD_SIZE', '1')), int(os.getenv('RANK', '0')))
    Utils.initialize_model_parallel()

    def build_ddp_linear():
        # A bf16 100x100 linear layer with all-ones weight, wrapped in Megatron DDP.
        layer = torch.nn.Linear(100, 100, bias=False, dtype=torch.bfloat16, device='cuda')
        layer.requires_grad_(True)
        layer.weight.data.fill_(1.0)
        layer_ddp_config = DistributedDataParallelConfig(use_distributed_optimizer=True)
        return DistributedDataParallel(
            TransformerConfig(num_attention_heads=1, num_layers=1), layer_ddp_config, layer
        )

    # Settings shared by the baseline and the test optimizer.
    shared_settings = dict(optimizer='adam', lr=0.01, bf16=True, use_distributed_optimizer=True)

    # Baseline: float32 optimizer states, precision-aware optimizer disabled.
    baseline_model = build_ddp_linear()
    baseline_optimizer_config = OptimizerConfig(
        use_precision_aware_optimizer=False,
        main_params_dtype=torch.float32,
        main_grads_dtype=torch.float32,
        exp_avg_dtype=torch.float32,
        exp_avg_sq_dtype=torch.float32,
        **shared_settings,
    )
    baseline_optim = get_megatron_optimizer(baseline_optimizer_config, [baseline_model])

    # Test: precision-aware optimizer with the parametrized dtypes.
    test_model = build_ddp_linear()
    test_optimizer_config = OptimizerConfig(
        fp8_recipe=fp8_recipe,
        use_precision_aware_optimizer=True,
        main_params_dtype=main_params_dtype,
        main_grads_dtype=main_grads_dtype,
        exp_avg_dtype=moment_dtype,
        exp_avg_sq_dtype=moment_dtype,
        **shared_settings,
    )
    test_optim = get_megatron_optimizer(test_optimizer_config, [test_model])

    # Both models see the same batch.
    batch = torch.randn(8, 100, dtype=torch.bfloat16, device='cuda')

    def run_model(model, batch, optim, recipe, recipe_settings):
        # Forward (under fp8_autocast when a recipe is set), backward, one optimizer step.
        if recipe:
            with fp8_autocast(enabled=True, fp8_recipe=recipe_settings):
                output = model(batch)
        else:
            output = model(batch)
        loss = output.sum()
        loss.backward()
        optim.step()
        return loss.item(), optim.get_grad_norm()

    baseline_loss, baseline_grad_norm = run_model(
        baseline_model, batch, baseline_optim, fp8_recipe, fp8_recipe_settings
    )
    test_loss, test_grad_norm = run_model(
        test_model, batch, test_optim, fp8_recipe, fp8_recipe_settings
    )

    rtol = 1e-3  # relative tolerance
    atol = 1e-5  # absolute tolerance

    # Grad norms may differ slightly because of the reduced precision.
    abs_diff = abs(test_grad_norm - baseline_grad_norm)
    rel_diff = abs_diff / (abs(baseline_grad_norm) + 1e-7)  # epsilon avoids div by 0
    assert (
        rel_diff <= rtol or abs_diff <= atol
    ), f"Grad norm mismatch: baseline={baseline_grad_norm}, test={test_grad_norm}, rel_diff={rel_diff}, abs_diff={abs_diff}"

    # Same tolerance check for the losses.
    loss_abs_diff = abs(test_loss - baseline_loss)
    loss_rel_diff = loss_abs_diff / (abs(baseline_loss) + 1e-7)
    assert (
        loss_rel_diff <= rtol or loss_abs_diff <= atol
    ), f"Loss mismatch: baseline={baseline_loss}, test={test_loss}, rel_diff={loss_rel_diff}, abs_diff={loss_abs_diff}"

    # Save and reload state dict for the test model.
    test_optim.load_state_dict(test_optim.state_dict())
@pytest.mark.parametrize("use_distributed_optimizer", [False, True])
@pytest.mark.parametrize("precision", ['bf16', 'fp32'])
def test_optim_sharded_state_dict(use_distributed_optimizer: bool, precision: str):
    """Check that the optimizer's sharded state dict never carries ``common_step=None``.

    Builds an Adam optimizer (bf16 or fp32, with or without the distributed optimizer)
    for a DDP-wrapped linear layer and inspects ``optimizer.state`` of the sharded
    state dict when those keys are present.
    """
    # Setup: distributed, model, mock_args.
    _init_distributed(int(os.getenv('WORLD_SIZE', '1')), int(os.getenv('RANK', '0')))
    Utils.initialize_model_parallel()

    linear = torch.nn.Linear(100, 100, bias=False, dtype=torch.bfloat16, device='cuda')
    linear.requires_grad_(True)
    linear.weight.data.fill_(1.0)
    ddp_config = DistributedDataParallelConfig(use_distributed_optimizer=use_distributed_optimizer)
    model = DistributedDataParallel(
        TransformerConfig(num_attention_heads=1, num_layers=1), ddp_config, linear
    )
    assert all(param.requires_grad for param in model.parameters())

    # Only the precision flags differ between the two configurations.
    if precision == 'bf16':
        precision_flags = {'bf16': True}
    elif precision == 'fp32':
        precision_flags = {'bf16': False, 'fp16': False}
    optimizer_config = OptimizerConfig(
        optimizer='adam', use_distributed_optimizer=use_distributed_optimizer, **precision_flags
    )
    optim = get_megatron_optimizer(optimizer_config, [model])

    sharded_state_dict = optim.sharded_state_dict(model.sharded_state_dict())

    # Nothing to verify unless the dict exposes optimizer -> state.
    if 'optimizer' not in sharded_state_dict or 'state' not in sharded_state_dict['optimizer']:
        return
    optimizer_state = sharded_state_dict['optimizer']['state']
    assert (
        'common_step' not in optimizer_state or optimizer_state['common_step'] is not None
    ), "Found 'optimizer.state.common_step=None' in sharded state dict."
def test_optimizer_reload_model_params():
    """Check ``reload_model_params`` with and without an explicit state dict.

    The optimizer's float32 main params must (1) keep their initial value of 1 while
    the model params are overwritten with 2, (2) become 2 after
    ``reload_model_params()``, and (3) become 3 after
    ``reload_model_params(new_state_dict)`` while the model params stay at 2.
    """
    _init_distributed(int(os.getenv('WORLD_SIZE', '1')), int(os.getenv('RANK', '0')))
    Utils.initialize_model_parallel()

    # Initial values of model params are 1.
    net = Net().bfloat16().cuda()
    for param in net.parameters():
        param.data.fill_(1.0)

    ddp_config = DistributedDataParallelConfig(use_distributed_optimizer=True)
    model = DistributedDataParallel(
        TransformerConfig(num_attention_heads=1, num_layers=1), ddp_config, net
    )
    optimizer_config = OptimizerConfig(optimizer='adam', bf16=True, use_distributed_optimizer=True)
    optim = get_megatron_optimizer(optimizer_config, [model])

    def assert_main_params_equal(value):
        # Every main param must be float32 and exactly equal to `value` everywhere.
        for group in optim.param_groups:
            for main_param in group['params']:
                assert main_param.dtype == torch.float32
                torch.testing.assert_close(
                    main_param, torch.full_like(main_param, value), atol=0, rtol=0
                )

    # Set all model params to 2.
    for param in model.parameters():
        param.data.fill_(2.0)

    # reload_model_params() has not been called yet, so main params are still 1.
    assert_main_params_equal(1.0)

    # Copy model params to main params, so main params should be 2 now.
    optim.reload_model_params()
    assert_main_params_equal(2.0)

    # Create a new state_dict with all params set to 3.
    new_state_dict = {
        name: torch.full_like(param, 3.0) for name, param in model.state_dict().items()
    }

    # Initialize main params from the new state_dict: main params become 3, but the
    # model params should still be 2.
    optim.reload_model_params(new_state_dict)
    for param in model.parameters():
        torch.testing.assert_close(param, torch.full_like(param, 2.0), atol=0, rtol=0)
    assert_main_params_equal(3.0)
@pytest.mark.skipif(
    not is_torch_min_version("2.4.0"),
    reason="torch.distributed.init_device_mesh requires torch >= 2.4.0",
)
@pytest.mark.parametrize(
    "world_size, tp_size, cp_size, dp_size",
    [
        (1, 1, 1, 1),  # Single GPU, no parallelism
        (2, 1, 2, 1),  # 2 GPUs, 1 TP, 2 CP
        (2, 2, 1, 1),  # 2 GPUs, 2 TP, 1 CP
        (8, 8, 1, 1),  # 8 GPUs, 8 TP, 1 CP
        (8, 2, 4, 1),  # 8 GPUs, 2 TP, 4 CP
        (8, 4, 2, 1),  # 8 GPUs, 4 TP, 2 CP
        (8, 1, 1, 8),  # 8 GPUs, 1 TP, 1 CP, 8 DP
        (8, 2, 1, 4),  # 8 GPUs, 2 TP, 1 CP, 4 DP
        (8, 2, 2, 2),  # 8 GPUs, 2 TP, 2 CP, 2 DP
    ],
)
def test_get_megatron_optimizer_with_custom_process_groups(world_size, tp_size, cp_size, dp_size):
    """
    Build an optimizer from process groups passed via ``pg_collection`` and run one step.

    The groups come from a ("pp", "dp", "ep", "cp", "tp") device mesh. The test checks
    that the optimizer has param groups, that a step changes the weight, and (for
    world_size == 1 only) that it has as many param groups as a default optimizer.
    """
    # Skip unless the number of visible GPUs equals the requested world size.
    available_gpus = torch.cuda.device_count()
    if available_gpus != world_size:
        pytest.skip(f"Test requires world_size={world_size}, but got {available_gpus}")

    # Initialize model parallel with default settings first.
    Utils.initialize_model_parallel(
        tensor_model_parallel_size=tp_size, context_parallel_size=cp_size
    )

    # Device mesh from which the custom process groups are taken.
    mesh = torch.distributed.init_device_mesh(
        "cuda", (1, dp_size, 1, cp_size, tp_size), mesh_dim_names=("pp", "dp", "ep", "cp", "tp")
    )

    # Per-dimension groups first, then the flattened dp+cp and pp+tp groups.
    dp_group = mesh.get_group(mesh_dim="dp")
    cp_group = mesh.get_group(mesh_dim="cp")
    tp_group = mesh.get_group(mesh_dim="tp")
    pp_group = mesh.get_group(mesh_dim="pp")
    dp_cp_group = mesh["dp", "cp"]._flatten().get_group()
    mp_group = mesh["pp", "tp"]._flatten().get_group()

    # Expert parallelism is not used in this test, hence the two None entries.
    pg_collection = ProcessGroupCollection()
    for attr_name, group in (
        ("dp", dp_group),
        ("dp_cp", dp_cp_group),
        ("expt_dp", None),
        ("tp", tp_group),
        ("cp", cp_group),
        ("pp", pp_group),
        ("mp", mp_group),
        ("tp_ep_pp", None),
    ):
        setattr(pg_collection, attr_name, group)

    # A simple DDP-wrapped linear layer with all-ones weight.
    linear = torch.nn.Linear(100, 100, bias=False, device='cuda')
    linear.requires_grad_(True)
    linear.weight.data.fill_(1.0)
    ddp_config = DistributedDataParallelConfig(use_distributed_optimizer=True)
    model = DistributedDataParallel(
        TransformerConfig(num_attention_heads=1, num_layers=1), ddp_config, linear
    )
    assert all(param.requires_grad for param in model.parameters())
    model_chunks = [model]

    optimizer_config = OptimizerConfig(
        optimizer='adam',
        lr=0.001,
        weight_decay=0.01,
        adam_beta1=0.9,
        adam_beta2=0.999,
        adam_eps=1e-8,
    )

    # Create the optimizer with the custom process groups.
    optimizer = get_megatron_optimizer(
        config=optimizer_config,
        model_chunks=model_chunks,
        use_gloo_process_groups=False,  # Required when using custom process groups
        pg_collection=pg_collection,
    )
    assert optimizer is not None, "Optimizer should not be None"
    assert hasattr(optimizer, 'param_groups'), "Optimizer should have param_groups"
    assert len(optimizer.param_groups) > 0, "Optimizer should have at least one parameter group"

    input_tensor = torch.randn(32, 100, device='cuda', requires_grad=True)

    def forward_backward():
        # One forward pass followed by backward on the summed output.
        model(input_tensor).sum().backward()

    # A forward/backward pass works with the optimizer in place.
    forward_backward()

    # Clear the gradients, recompute them, then take an optimizer step.
    optimizer.zero_grad()
    forward_backward()

    # Snapshot the parameters before the step (the bias only if the layer has one).
    module = model.module
    original_weight = module.weight.data.clone()
    original_bias = None if module.bias is None else module.bias.data.clone()

    optimizer.step()

    assert not torch.equal(
        module.weight.data, original_weight
    ), "Weight should be updated after optimizer step"
    if module.bias is not None:
        assert not torch.equal(
            module.bias.data, original_bias
        ), "Bias should be updated after optimizer step"

    # Compare with a default-process-group optimizer on a single GPU only, to avoid
    # a more complex setup.
    if world_size != 1:
        return
    default_optimizer = get_megatron_optimizer(config=optimizer_config, model_chunks=model_chunks)
    assert len(optimizer.param_groups) == len(
        default_optimizer.param_groups
    ), "Custom and default optimizers should have same number of parameter groups"
def test_get_megatron_optimizer_custom_process_groups_validation():
    """
    Check that get_megatron_optimizer rejects incomplete custom process group setups.

    Each case passes a ``ProcessGroupCollection`` that lacks one more required
    attribute than the next (dp, expt_dp, mp, tp_ep_pp) and expects a ``ValueError``
    with the matching message; the last case expects Gloo process groups to be
    rejected when a collection is supplied.
    """
    Utils.initialize_model_parallel(tensor_model_parallel_size=1)

    # A simple DDP-wrapped linear layer with all-ones weight.
    linear = torch.nn.Linear(100, 100, bias=False, device='cuda')
    linear.requires_grad_(True)
    linear.weight.data.fill_(1.0)
    ddp_config = DistributedDataParallelConfig(use_distributed_optimizer=True)
    model = DistributedDataParallel(
        TransformerConfig(num_attention_heads=1, num_layers=1), ddp_config, linear
    )
    assert all(param.requires_grad for param in model.parameters())
    model_chunks = [model]
    optimizer_config = OptimizerConfig(optimizer='adam', lr=0.001)

    def expect_value_error(message, collection, **extra_kwargs):
        # Building the optimizer with `collection` must raise ValueError matching `message`.
        with pytest.raises(ValueError, match=message):
            get_megatron_optimizer(
                config=optimizer_config,
                model_chunks=model_chunks,
                pg_collection=collection,
                **extra_kwargs,
            )

    # Missing dp process group.
    pg_collection_no_dp = ProcessGroupCollection()
    expect_value_error("dp process group is required", pg_collection_no_dp)

    # dp present, but the required 'expt_dp' attribute is missing.
    pg_collection_no_expt_dp = ProcessGroupCollection()
    pg_collection_no_expt_dp.dp = torch.distributed.new_group()
    expect_value_error("expt_dp process group is required", pg_collection_no_expt_dp)

    # dp and expt_dp present, but the required 'mp' attribute is missing.
    pg_collection_complete = ProcessGroupCollection()
    pg_collection_complete.dp = torch.distributed.new_group()
    pg_collection_complete.expt_dp = None  # Explicitly set to None as allowed
    expect_value_error("mp process group is required", pg_collection_complete)

    # mp set as well, but the required 'tp_ep_pp' attribute is missing.
    pg_collection_complete.mp = None  # Explicitly set to None as allowed
    expect_value_error("tp_ep_pp process group is required", pg_collection_complete)

    # All attributes set: Gloo process groups must not be combined with custom groups.
    pg_collection_complete.mp = None  # Explicitly set to None as allowed
    pg_collection_complete.tp_ep_pp = None  # Explicitly set to None as allowed
    expect_value_error(
        "Gloo process groups are not supported",
        pg_collection_complete,
        use_gloo_process_groups=True,  # Should be False when using custom groups
    )