File size: 1,783 Bytes
e61fdc8 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 | from megatron.core.tensor_parallel.random import CudaRNGStatesTracker
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.tensor_parallel.random import _CUDA_RNG_STATE_TRACKER
from megatron.core.tensor_parallel.random import checkpoint
from tests.unit_tests.test_utilities import Utils
import pytest
import torch
def test_cuda_rng_states_tracker():
    """Exercise CudaRNGStatesTracker: set/get/reset, rejected duplicate adds, and fork()."""
    rng_tracker = CudaRNGStatesTracker()

    # set_states/get_states round-trip, then reset() empties the tracker.
    rng_tracker.set_states({"state1": 1234})
    assert rng_tracker.get_states()["state1"] == 1234
    rng_tracker.reset()
    assert rng_tracker.get_states() == {}

    seed = 1111
    rng_tracker.add("state2", seed)
    # Call add() directly inside pytest.raises. Wrapping it in `assert` is at
    # best noise, and if add() returned a falsy value without raising, the
    # resulting AssertionError would itself satisfy pytest.raises(Exception),
    # making the check vacuous.
    with pytest.raises(Exception):
        rng_tracker.add("state3", seed)  # reuses an already-registered seed
    with pytest.raises(Exception):
        rng_tracker.add("state2", 111)  # reuses an already-registered name
    assert rng_tracker.get_states()["state2"] is not None

    # Forking a name that was never added must fail.
    with pytest.raises(Exception):
        with rng_tracker.fork("unregistered"):
            pass

    # fork() is a context manager: calling it without `with` only builds the
    # manager object and never runs its body. Entering and leaving it without
    # drawing any random numbers is expected to leave the tracked state equal
    # to a fresh CUDA RNG state seeded with the same seed.
    with rng_tracker.fork("state2"):
        pass
    torch.cuda.manual_seed(seed)
    rng_state = torch.cuda.get_rng_state()
    assert torch.equal(rng_tracker.get_states()["state2"], rng_state)
def test_model_parallel_cuda_manual_seed():
    """Seeding under model parallelism must register a 'model-parallel-rng' state."""
    # NOTE(review): positional args presumably are (tensor, pipeline)
    # model-parallel sizes — confirm against Utils.initialize_model_parallel.
    Utils.initialize_model_parallel(4, 2)
    try:
        model_parallel_cuda_manual_seed(0)
        assert _CUDA_RNG_STATE_TRACKER.get_states()['model-parallel-rng'] is not None
    finally:
        # Tear down even when the assertion fails, so parallel state does not
        # leak into subsequent tests.
        Utils.destroy_model_parallel()
def test_checkpoint():
    """Check ``checkpoint`` returns the wrapped function's result and, when its
    second argument is True, rewrites the first input tensor's data.

    The rewritten data is expected to be a flat CUDA tensor holding all of
    the input's elements; presumably this reflects the tensor-parallel size
    set up by ``Utils.initialize_model_parallel()`` (TODO confirm).
    """
    def _add_forward(*inputs):
        # Toy forward function: element-wise sum of the first two inputs.
        return inputs[0] + inputs[1]

    # Before model-parallel init, with None as the second argument, the
    # forward result is returned unchanged (1 + 2 == 3 everywhere).
    assert torch.equal(
        torch.ones(16) * 3,
        checkpoint(_add_forward, None, torch.ones(16), torch.ones(16) * 2),
    )

    Utils.initialize_model_parallel()
    try:
        input1 = torch.ones((4, 4))
        output = checkpoint(_add_forward, True, input1, torch.ones((4, 4)) * 2)
        # The forward result must still be correct.
        assert torch.equal(output, torch.ones((4, 4)) * 3)
        # Side effect under test: input1 (a 4x4 CPU tensor above) is expected
        # to now equal a flat 1-D CUDA tensor of ones.
        assert torch.equal(torch.ones(input1.numel()).cuda(), input1)
    finally:
        # Always tear down so a failure here does not leak parallel state
        # into later tests.
        Utils.destroy_model_parallel()