| import os |
| import unittest |
| from argparse import Namespace |
| from unittest import TestCase, mock |
|
|
| from trainer import TrainerArgs |
| from trainer.distribute import get_gpus |
|
|
|
|
class TestGpusStringParsingMethods(TestCase):
    """Tests for how ``get_gpus`` resolves GPU ids from the environment and ``args.gpus``.

    The expectations below pin two behaviors:

    * ``CUDA_VISIBLE_DEVICES`` takes precedence over ``args.gpus`` when both are set.
    * Whitespace around comma-separated ids is stripped.
    """

    def _assert_parsed_gpus(self, args, expected_value):
        """Call ``get_gpus(args)`` and assert it returns ``expected_value``."""
        gpus = get_gpus(args)
        self.assertEqual(expected_value, gpus, msg_for_test_failure(expected_value))

    @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"})
    def test_parse_gpus_set_in_env_var_and_args(self):
        # Both sources are set; the environment variable is expected to win.
        self._assert_parsed_gpus(Namespace(gpus="0,1"), ["0"])

    @mock.patch.dict(os.environ)
    def test_parse_gpus_set_in_args(self):
        # patch.dict(os.environ, {}) would NOT remove an inherited
        # CUDA_VISIBLE_DEVICES (clear defaults to False). Because the env var
        # takes precedence, it must be removed explicitly. patch.dict restores
        # the original environment when the test finishes.
        os.environ.pop("CUDA_VISIBLE_DEVICES", None)
        self._assert_parsed_gpus(Namespace(gpus="0,1"), ["0", "1"])

    @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1"})
    def test_parse_gpus_set_in_env_var(self):
        # No ``gpus`` attribute on args: the env var is the only source.
        self._assert_parsed_gpus(Namespace(), ["0", "1"])

    @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0, 1 "})
    def test_parse_gpus_set_in_env_var_with_spaces(self):
        self._assert_parsed_gpus(Namespace(), ["0", "1"])

    @mock.patch.dict(os.environ)
    def test_parse_gpus_set_in_args_with_spaces(self):
        # See test_parse_gpus_set_in_args for why the variable is popped here.
        os.environ.pop("CUDA_VISIBLE_DEVICES", None)
        self._assert_parsed_gpus(Namespace(gpus="0, 1, 2, 3 "), ["0", "1", "2", "3"])
|
|
|
|
def msg_for_test_failure(expected_value):
    """Build the assertion message shown when a parsed GPU list does not match.

    Args:
        expected_value: The GPU list the test expected; rendered with ``str()``.

    Returns:
        A human-readable failure message.
    """
    # ``%s`` calls str() on its argument; the 1-tuple keeps tuple inputs from being unpacked.
    return "GPU Values are expected to be %s" % (expected_value,)
|
|
|
|
def create_args_parser():
    """Return an argparse parser built from ``TrainerArgs`` plus a ``--gpus`` string option.

    ``TrainerArgs.init_argparse`` is called with an empty ``arg_prefix``.
    """
    trainer_args = TrainerArgs()
    arg_parser = trainer_args.init_argparse(arg_prefix="")
    # Extra option not added by init_argparse here; it is parsed as a raw string.
    arg_parser.add_argument("--gpus", type=str)
    return arg_parser
|
|
|
|
if __name__ == "__main__":
    # Discover and run the TestCase classes defined in this module.
    unittest.main(module="__main__")
|
|