|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
from unittest import mock |
|
|
|
|
|
import pytest |
|
|
import torch |
|
|
|
|
|
from nemo.utils.get_rank import get_last_rank, get_rank, is_global_rank_zero |
|
|
|
|
|
|
|
|
class TestIsGlobalRankZero:
    """Test the is_global_rank_zero function with various environment variable settings.

    Every test starts with none of the rank-related environment variables set.
    The process environment is restored after each test, so values assigned
    inside a test cannot leak into tests that run later.
    """

    # Environment variables removed before each test in this class.
    _RANK_ENV_VARS = ("RANK", "SLURM_PROCID", "OMPI_COMM_WORLD_RANK", "NODE_RANK", "GROUP_RANK", "LOCAL_RANK")

    @pytest.fixture(autouse=True)
    def setup_method(self):
        """Clear all relevant environment variables for the test, then restore the environment."""
        # patch.dict snapshots os.environ on entry and restores it on exit,
        # undoing both the removals below and any assignment made by the test.
        with mock.patch.dict(os.environ):
            for var in self._RANK_ENV_VARS:
                os.environ.pop(var, None)
            yield

    def test_default_behavior(self):
        """Test the default behavior when no environment variables are set."""
        assert is_global_rank_zero() is True

    def test_with_pytorch_rank_0(self):
        """Test when RANK=0 (pytorch environment)."""
        os.environ["RANK"] = "0"
        assert is_global_rank_zero() is True

    def test_with_pytorch_rank_nonzero(self):
        """Test when RANK is not 0 (pytorch environment)."""
        os.environ["RANK"] = "1"
        assert is_global_rank_zero() is False

    def test_with_slurm_rank_0(self):
        """Test when SLURM_PROCID=0 (SLURM environment)."""
        os.environ["SLURM_PROCID"] = "0"
        assert is_global_rank_zero() is True

    def test_with_slurm_rank_nonzero(self):
        """Test when SLURM_PROCID is not 0 (SLURM environment)."""
        os.environ["SLURM_PROCID"] = "1"
        assert is_global_rank_zero() is False

    def test_with_mpi_rank_0(self):
        """Test when OMPI_COMM_WORLD_RANK=0 (MPI environment)."""
        os.environ["OMPI_COMM_WORLD_RANK"] = "0"
        assert is_global_rank_zero() is True

    def test_with_mpi_rank_nonzero(self):
        """Test when OMPI_COMM_WORLD_RANK is not 0 (MPI environment)."""
        os.environ["OMPI_COMM_WORLD_RANK"] = "1"
        assert is_global_rank_zero() is False

    def test_with_node_rank_0_local_rank_0(self):
        """Test when NODE_RANK=0 and LOCAL_RANK=0."""
        os.environ["NODE_RANK"] = "0"
        os.environ["LOCAL_RANK"] = "0"
        assert is_global_rank_zero() is True

    def test_with_node_rank_0_local_rank_nonzero(self):
        """Test when NODE_RANK=0 but LOCAL_RANK is not 0."""
        os.environ["NODE_RANK"] = "0"
        os.environ["LOCAL_RANK"] = "1"
        assert is_global_rank_zero() is False

    def test_with_node_rank_nonzero(self):
        """Test when NODE_RANK is not 0."""
        os.environ["NODE_RANK"] = "1"
        os.environ["LOCAL_RANK"] = "0"
        assert is_global_rank_zero() is False

    def test_with_group_rank_fallback(self):
        """Test using GROUP_RANK as fallback for NODE_RANK."""
        os.environ["GROUP_RANK"] = "0"
        os.environ["LOCAL_RANK"] = "0"
        assert is_global_rank_zero() is True

        # Same LOCAL_RANK, but a non-zero GROUP_RANK must flip the result.
        os.environ["GROUP_RANK"] = "1"
        assert is_global_rank_zero() is False

    def test_env_var_precedence(self):
        """Test that environment variables are checked in the expected order of precedence."""
        # RANK=0 is expected to win over non-zero SLURM_PROCID and OMPI_COMM_WORLD_RANK.
        os.environ["RANK"] = "0"
        os.environ["SLURM_PROCID"] = "1"
        os.environ["OMPI_COMM_WORLD_RANK"] = "1"
        assert is_global_rank_zero() is True

        # A non-zero RANK is expected to win over SLURM_PROCID=0.
        os.environ["RANK"] = "1"
        os.environ["SLURM_PROCID"] = "0"
        assert is_global_rank_zero() is False

        # Without RANK, SLURM_PROCID=0 is expected to decide the result.
        del os.environ["RANK"]
        assert is_global_rank_zero() is True

        # A non-zero SLURM_PROCID is expected to win over OMPI_COMM_WORLD_RANK=0.
        os.environ["SLURM_PROCID"] = "1"
        os.environ["OMPI_COMM_WORLD_RANK"] = "0"
        assert is_global_rank_zero() is False

        # Without SLURM_PROCID, OMPI_COMM_WORLD_RANK=0 is expected to decide the result.
        del os.environ["SLURM_PROCID"]
        assert is_global_rank_zero() is True
|
|
|
|
|
|
|
|
class TestGetRank:
    """Test the get_rank function.

    Every test starts with none of the rank-related environment variables set.
    The process environment is restored after each test, so values assigned
    inside a test cannot leak into tests that run later.
    """

    # Environment variables removed before each test in this class.
    _RANK_ENV_VARS = ("RANK", "SLURM_PROCID", "OMPI_COMM_WORLD_RANK", "NODE_RANK", "GROUP_RANK", "LOCAL_RANK")

    @pytest.fixture(autouse=True)
    def setup_method(self):
        """Clear all relevant environment variables for the test, then restore the environment."""
        # patch.dict snapshots os.environ on entry and restores it on exit,
        # undoing both the removals below and any assignment made by the test.
        with mock.patch.dict(os.environ):
            for var in self._RANK_ENV_VARS:
                os.environ.pop(var, None)
            yield

    @mock.patch("torch.distributed.is_initialized", return_value=False)
    def test_not_distributed(self, mock_is_initialized):
        """Test when not in a distributed environment."""
        assert get_rank() == 0

    @mock.patch("torch.distributed.is_initialized", return_value=True)
    @mock.patch("torch.distributed.get_rank", return_value=2)
    def test_distributed_not_global_rank_zero(self, mock_dist_get_rank, mock_is_initialized):
        """Test when in a distributed environment and not global rank zero."""
        # RANK=1 marks this process as not being global rank zero, so the
        # mocked torch.distributed.get_rank value (2) is expected back.
        os.environ["RANK"] = "1"
        assert get_rank() == 2
        mock_dist_get_rank.assert_called_once()

    @mock.patch("torch.distributed.is_initialized", return_value=True)
    @mock.patch("torch.distributed.get_rank", return_value=0)
    def test_distributed_global_rank_zero(self, mock_dist_get_rank, mock_is_initialized):
        """Test when in a distributed environment and is global rank zero."""
        os.environ["RANK"] = "0"
        assert get_rank() == 0

        # With RANK=0 the test expects torch.distributed.get_rank never to be consulted.
        mock_dist_get_rank.assert_not_called()
|
|
|
|
|
|
|
|
class TestGetLastRank:
    """Checks for get_last_rank with the torch.distributed calls mocked out."""

    @mock.patch("torch.distributed.is_initialized", return_value=False)
    def test_not_distributed(self, mock_is_initialized):
        """Expect 0 when torch.distributed.is_initialized reports False."""
        last_rank = get_last_rank()

        assert last_rank == 0
        mock_is_initialized.assert_called_once()

    @mock.patch("torch.distributed.is_initialized", return_value=True)
    @mock.patch("torch.distributed.get_world_size", return_value=4)
    def test_distributed(self, mock_get_world_size, mock_is_initialized):
        """Expect 3 when initialized and the mocked world size is 4."""
        last_rank = get_last_rank()

        assert last_rank == 3
        # Each mocked torch.distributed function must be queried exactly once.
        mock_is_initialized.assert_called_once()
        mock_get_world_size.assert_called_once()
|
|
|