Spaces:
Running on Zero
Running on Zero
File size: 5,122 Bytes
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
import random
import numpy as np
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD, Adam
try:
from transformer_engine.pytorch.optimizers import FusedAdam as GPUAdam
from transformer_engine.pytorch.optimizers import FusedSGD as GPUSGD
except:
# Handle environment where transformer_engine is not installed
from torch.optim import SGD as GPUSGD
from torch.optim import Adam as GPUAdam
from megatron.core.optimizer.cpu_offloading import HybridDeviceOptimizer
class Net(nn.Module):
    """Small LeNet-style CNN used as the model under test.

    Expects input of shape (batch, 3, 32, 32) -- 3 channels for ``conv1`` and a
    32x32 spatial size so the flattened features match ``fc1``'s 16 * 5 * 5
    inputs -- and returns 10 outputs per sample.
    """

    def __init__(self):
        super().__init__()
        # Submodules are created in this exact order so that, under a fixed
        # seed, parameter initialization is reproducible.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # Two conv -> ReLU -> 2x2 max-pool stages (32x32 -> 14x14 -> 5x5).
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        # Keep the batch dimension and flatten everything else.
        x = x.flatten(start_dim=1)
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))
        return self.fc3(x)
def setup_seed(seed):
    """Seed every RNG the test relies on and force deterministic cuDNN.

    Args:
        seed: Integer seed passed to Python's ``random``, NumPy's global RNG,
            PyTorch's CPU generator, the current CUDA device, and all CUDA
            devices.
    """
    seeders = (
        random.seed,  # Python built-in RNG
        np.random.seed,  # NumPy global RNG
        torch.manual_seed,  # PyTorch CPU generator
        torch.cuda.manual_seed,  # current CUDA device
        torch.cuda.manual_seed_all,  # every CUDA device
    )
    for seeder in seeders:
        seeder(seed)

    # Trade cuDNN auto-tuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def _split_param_groups(params):
    """Split ``params`` into two halves tagged with different ``wd_mult``/``lr_mult`` values."""
    half = len(params) // 2
    return [
        {"params": params[:half], "wd_mult": 1.0, "lr_mult": 1e-4},
        {"params": params[half:], "wd_mult": 0.0, "lr_mult": 2e-4},
    ]


def _train_step(net, optimizer, inp):
    """Run one forward/backward pass of ``net`` on ``inp`` and step ``optimizer``."""
    # Gradients are not zeroed between steps; both models see the same inputs
    # and accumulate gradients the same way, so the final comparison stays
    # like-for-like.
    net(inp).sum().backward()
    optimizer.step()


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires a CUDA device.")
@pytest.mark.skipif(
    torch.__version__ < '2.3.0',
    reason=(
        "Requires PyTorch 2.3.0 or higher, lower versions of pytorch have "
        "misaligned optimizer accuracy for CPU and GPU."
    ),
)
@pytest.mark.parametrize('n_steps', [1, 10])
@pytest.mark.parametrize('overlap_cpu_optimizer_d2h_h2d', [False, True])
@pytest.mark.parametrize('offload_fraction', [0, 0.5, 1.0])
@pytest.mark.parametrize('optimizer', ['sgd', 'adam'])
@pytest.mark.parametrize('with_param_groups', [False, True])
def test_multi_device_hybrid_optimizer(
    with_param_groups, optimizer, offload_fraction, overlap_cpu_optimizer_d2h_h2d, n_steps
):
    """Compare HybridDeviceOptimizer against a GPU-only reference optimizer.

    Two identically initialized ``Net`` copies are trained on the same random
    inputs for ``n_steps`` steps. One uses ``HybridDeviceOptimizer``, the
    other uses the GPU optimizer class alone. For Adam, the test also checks
    where the ``exp_avg`` state lives after the first step. When
    ``offload_fraction > 0``, the first parameter's state must not be on CUDA.
    When ``offload_fraction < 1``, the last parameter's state must be on CUDA.
    Finally, all weights must match within ``atol=1e-3``.
    """
    setup_seed(42)
    net1 = Net().cuda()
    net2 = Net().cuda()
    net2.load_state_dict(net1.state_dict())

    base_lr = 1e-3
    params = list(net1.parameters())
    ref_params = list(net2.parameters())
    if with_param_groups:
        params = _split_param_groups(params)
        ref_params = _split_param_groups(ref_params)

    if optimizer == 'adam':
        cls_kwargs = dict(cpu_optimizer_cls=Adam, gpu_optimizer_cls=GPUAdam)
    else:
        cls_kwargs = dict(cpu_optimizer_cls=SGD, gpu_optimizer_cls=GPUSGD)

    hdo = HybridDeviceOptimizer(
        params,
        offload_fraction=offload_fraction,
        lr=base_lr,
        overlap_cpu_optimizer_d2h_h2d=overlap_cpu_optimizer_d2h_h2d,
        **cls_kwargs,
    )
    ref_optimizer = cls_kwargs['gpu_optimizer_cls'](ref_params, lr=base_lr)

    # 1. run step on optimizer, make sure there is state generated
    assert len(hdo.state_dict()["state"]) == 0  # state is empty
    # Keep `.cuda()` on a CPU-generated tensor so the seeded CPU RNG stream
    # drives the inputs.
    inp = torch.randn(1, 3, 32, 32).cuda()
    _train_step(net1, hdo, inp)
    _train_step(net2, ref_optimizer, inp)

    # PyTorch SGD will not generate state
    if optimizer != 'sgd':
        hdo_state = hdo.state_dict()
        assert len(hdo_state["state"]) != 0
        # 2. check the state is on right device
        if optimizer == 'adam':
            first_param_id = hdo_state["param_groups"][0]["params"][0]
            last_param_id = hdo_state["param_groups"][-1]["params"][-1]
            if offload_fraction > 0:
                assert not hdo_state["state"][first_param_id]["exp_avg"].is_cuda
            if offload_fraction < 1:
                assert hdo_state["state"][last_param_id]["exp_avg"].is_cuda

    # 3. check parameters allclose
    for _ in range(1, n_steps):
        inp = torch.randn(1, 3, 32, 32).cuda()
        _train_step(net1, hdo, inp)
        _train_step(net2, ref_optimizer, inp)

    ref_state = net2.state_dict()
    for name, value in net1.state_dict().items():
        ref_value = ref_state[name]
        assert (value.isnan() == ref_value.isnan()).all(), f"Weight {name} NaN pattern mismatch"
        # state_dict() tensors share storage with the live parameters, so
        # compare with equal_nan instead of rewriting them via nan_to_num_.
        assert torch.allclose(value, ref_value, atol=1e-03, equal_nan=True), (
            f"Weight {name} value mismatch, max error: "
            f"{(value.nan_to_num(0) - ref_value.nan_to_num(0)).abs().max()}"
        )
|