import copy
import os
import pickle
import shutil
import socket
import subprocess
import sys
import tarfile
import tempfile
import unittest
from collections import OrderedDict
from collections.abc import Mapping
from os.path import expanduser

import numpy as np
import requests

from modelscope.hub.constants import DEFAULT_CREDENTIALS_PATH

TEST_LEVEL = 2
TEST_LEVEL_STR = 'TEST_LEVEL'

TEST_ACCESS_TOKEN1 = os.environ.get('TEST_ACCESS_TOKEN_CITEST', None)
TEST_ACCESS_TOKEN2 = os.environ.get('TEST_ACCESS_TOKEN_SDKDEV', None)

TEST_MODEL_CHINESE_NAME = '内部测试模型'
TEST_MODEL_ORG = 'citest'


def delete_credential():
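    """Delete the locally cached ModelScope credentials directory."""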
    path_credential = expanduser(DEFAULT_CREDENTIALS_PATH)
    shutil.rmtree(path_credential, ignore_errors=True)


def test_level():
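    """Return the current test level.

    The ``TEST_LEVEL`` environment variable, when set, overrides the module default.
    """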
    global TEST_LEVEL
    if TEST_LEVEL_STR in os.environ:
        TEST_LEVEL = int(os.environ[TEST_LEVEL_STR])
    return TEST_LEVEL


def require_tf(test_case):
    # NOTE: tests decorated with require_tf are skipped unconditionally.
    test_case = unittest.skip('test requires TensorFlow')(test_case)
    return test_case


def require_torch(test_case):
    # NOTE: PyTorch is assumed to be available, so this decorator is a no-op.
    return test_case


def set_test_level(level: int):
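    """Override the module-level test level programmatically."""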
    global TEST_LEVEL
    TEST_LEVEL = level


class DummyTorchDataset:
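    """Map-style dummy dataset.

    Every index returns the same ``feat``/``labels`` pair as torch tensors;
    ``num`` sets the dataset length.
    """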

    def __init__(self, feat, label, num) -> None:
        self.feat = feat
        self.label = label
        self.num = num

    def __getitem__(self, index):
        import torch
        return {
            'feat': torch.Tensor(self.feat),
            'labels': torch.Tensor(self.label)
        }

    def __len__(self):
        return self.num


def create_dummy_test_dataset(feat, label, num):
    return DummyTorchDataset(feat, label, num)


def download_and_untar(fpath, furl, dst) -> str:
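    """Download ``furl`` to ``fpath`` (unless it already exists) and extract the tarball into ``dst``.

    Returns the extraction directory derived from the archive name, located next to ``fpath``.
    """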
    if not os.path.exists(fpath):
        r = requests.get(furl)
        with open(fpath, 'wb') as f:
            f.write(r.content)

    file_name = os.path.basename(fpath)
    root_dir = os.path.dirname(fpath)
    target_dir_name = os.path.splitext(os.path.splitext(file_name)[0])[0]
    target_dir_path = os.path.join(root_dir, target_dir_name)

    t = tarfile.open(fpath)
    t.extractall(path=dst)

    return target_dir_path


def get_case_model_info():
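    """Grep the ``tests/`` tree for ``damo/`` model ids.

    Returns a mapping from each model name to the set of test modules that reference it.
    """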
    status_code, result = subprocess.getstatusoutput(
        'grep -rn "damo/" tests/ | grep -v ".pyc" | grep -v "Binary file" | grep -v run.py ')
    lines = result.split('\n')
    test_cases = OrderedDict()
    model_cases = OrderedDict()
    for line in lines:
        line = line.strip()
        elements = line.split(':')
        test_file = elements[0]
        model_pos = line.find('damo')
        left_quote = line[model_pos - 1]
        rquote_idx = line.rfind(left_quote)
        model_name = line[model_pos:rquote_idx]
        if test_file not in test_cases:
            test_cases[test_file] = set()
        model_info = test_cases[test_file]
        model_info.add(model_name)

        if model_name not in model_cases:
            model_cases[model_name] = set()
        case_info = model_cases[model_name]
        case_info.add(test_file.replace('tests/', '').replace('.py', '').replace('/', '.'))

    return model_cases


def compare_arguments_nested(print_content,
                             arg1,
                             arg2,
                             rtol=1.e-3,
                             atol=1.e-8,
                             ignore_unknown_type=True):
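    """Recursively compare two (possibly nested) arguments.

    Scalars are compared directly, floats and numpy arrays with ``rtol``/``atol``,
    and tuples, lists and mappings element by element. Returns True when the two
    arguments are considered equal; if ``print_content`` is not None, the first
    difference found is printed with that prefix.
    """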
    type1 = type(arg1)
    type2 = type(arg2)
    if type1.__name__ != type2.__name__:
        if print_content is not None:
            print(f'{print_content}, type not equal:{type1.__name__} and {type2.__name__}')
        return False

    if arg1 is None:
        return True
    elif isinstance(arg1, (int, str, bool, np.bool_, np.integer, np.str_)):
        if arg1 != arg2:
            if print_content is not None:
                print(f'{print_content}, arg1:{arg1}, arg2:{arg2}')
            return False
        return True
    elif isinstance(arg1, (float, np.floating)):
        if not np.isclose(arg1, arg2, rtol=rtol, atol=atol, equal_nan=True):
            if print_content is not None:
                print(f'{print_content}, arg1:{arg1}, arg2:{arg2}')
            return False
        return True
    elif isinstance(arg1, (tuple, list)):
        if len(arg1) != len(arg2):
            if print_content is not None:
                print(f'{print_content}, length is not equal:{len(arg1)}, {len(arg2)}')
            return False
        if not all([
                compare_arguments_nested(None, sub_arg1, sub_arg2, rtol=rtol, atol=atol)
                for sub_arg1, sub_arg2 in zip(arg1, arg2)
        ]):
            if print_content is not None:
                print(f'{print_content}')
            return False
        return True
    elif isinstance(arg1, Mapping):
        keys1 = arg1.keys()
        keys2 = arg2.keys()
        if len(keys1) != len(keys2):
            if print_content is not None:
                print(f'{print_content}, key length is not equal:{len(keys1)}, {len(keys2)}')
            return False
        if len(set(keys1) - set(keys2)) > 0:
            if print_content is not None:
                print(f'{print_content}, key diff:{set(keys1) - set(keys2)}')
            return False
        if not all([
                compare_arguments_nested(None, arg1[key], arg2[key], rtol=rtol, atol=atol)
                for key in keys1
        ]):
            if print_content is not None:
                print(f'{print_content}')
            return False
        return True
    elif isinstance(arg1, np.ndarray):
        # Replace None entries with NaN so the numeric comparison with equal_nan works.
        arg1 = np.where(np.equal(arg1, None), np.nan, arg1).astype(dtype=float)
        arg2 = np.where(np.equal(arg2, None), np.nan, arg2).astype(dtype=float)
        if not all(np.isclose(arg1, arg2, rtol=rtol, atol=atol, equal_nan=True).flatten()):
            if print_content is not None:
                print(f'{print_content}')
            return False
        return True
    else:
        if ignore_unknown_type:
            return True
        else:
            raise ValueError(f'type not supported: {type1}')


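# Launcher script template used by DistributedTestCase: it imports the test module, calls the
# target function and pickles the returned results so the parent process can load and check them.
# The four ``{}`` placeholders are filled with: module name, module name, function name and the
# formatted argument list (see ``DistributedTestCase._start``).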
_DIST_SCRIPT_TEMPLATE = """
import ast
import argparse
import pickle
import torch
from torch import distributed as dist
from modelscope.utils.torch_utils import get_dist_info
import {}

parser = argparse.ArgumentParser()
parser.add_argument('--save_all_ranks', type=ast.literal_eval, help='save all ranks results')
parser.add_argument('--save_file', type=str, help='save file')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()


def main():
    results = {}.{}({})  # module.func(params)
    if args.save_all_ranks:
        save_file = args.save_file + str(dist.get_rank())
        with open(save_file, 'wb') as f:
            pickle.dump(results, f)
    else:
        rank, _ = get_dist_info()
        if rank == 0:
            with open(args.save_file, 'wb') as f:
                pickle.dump(results, f)


if __name__ == '__main__':
    main()
"""


class DistributedTestCase(unittest.TestCase):
    """Base TestCase for running a test function in distributed mode.

    Examples:
        >>> import torch
        >>> from torch import distributed as dist
        >>> from modelscope.utils.torch_utils import init_dist

        >>> def _test_func(*args, **kwargs):
        >>>     init_dist(launcher='pytorch')
        >>>     rank = dist.get_rank()
        >>>     if rank == 0:
        >>>         value = torch.tensor(1.0).cuda()
        >>>     else:
        >>>         value = torch.tensor(2.0).cuda()
        >>>     dist.all_reduce(value)
        >>>     return value.cpu().numpy()

        >>> class DistTest(DistributedTestCase):
        >>>     def test_function_dist(self):
        >>>         args = ()  # args should be python builtin types
        >>>         kwargs = {}  # kwargs should be python builtin types
        >>>         self.start(
        >>>             _test_func,
        >>>             num_gpus=2,
        >>>             assert_callback=lambda x: self.assertEqual(x, 3.0),
        >>>             *args,
        >>>             **kwargs,
        >>>         )
    """

    def _start(self,
               dist_start_cmd,
               func,
               num_gpus,
               assert_callback=None,
               save_all_ranks=False,
               *args,
               **kwargs):
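        """Render ``_DIST_SCRIPT_TEMPLATE`` for ``func``, run it with ``dist_start_cmd`` in a
        subprocess and return the unpickled result(s); ``assert_callback``, if given, is called
        on the result before the subprocess exit code is checked.
        """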
        script_path = func.__code__.co_filename
        script_dir, script_name = os.path.split(script_path)
        script_name = os.path.splitext(script_name)[0]
        func_name = func.__qualname__

        # Render positional and keyword arguments as python source for the generated script.
        func_params = []
        for arg in args:
            if isinstance(arg, str):
                arg = '\'{}\''.format(arg)
            func_params.append(str(arg))

        for k, v in kwargs.items():
            if isinstance(v, str):
                v = '\'{}\''.format(v)
            func_params.append('{}={}'.format(k, v))

        func_params = ','.join(func_params).strip(',')

        tmp_run_file = tempfile.NamedTemporaryFile(suffix='.py').name
        tmp_res_file = tempfile.NamedTemporaryFile(suffix='.pkl').name

        with open(tmp_run_file, 'w') as f:
            print('save temporary run file to: {}'.format(tmp_run_file))
            print('save results to: {}'.format(tmp_res_file))
            run_file_content = _DIST_SCRIPT_TEMPLATE.format(
                script_name, script_name, func_name, func_params)
            f.write(run_file_content)

        tmp_res_files = []
        if save_all_ranks:
            for i in range(num_gpus):
                tmp_res_files.append(tmp_res_file + str(i))
        else:
            tmp_res_files = [tmp_res_file]
        self.addCleanup(self.clean_tmp, [tmp_run_file] + tmp_res_files)

        tmp_env = copy.deepcopy(os.environ)
        tmp_env['PYTHONPATH'] = ':'.join(
            (tmp_env.get('PYTHONPATH', ''), script_dir)).lstrip(':')
        tmp_env['NCCL_P2P_DISABLE'] = '1'
        script_params = '--save_all_ranks=%s --save_file=%s' % (save_all_ranks, tmp_res_file)
        script_cmd = '%s %s %s' % (dist_start_cmd, tmp_run_file, script_params)
        print('script command: %s' % script_cmd)
        res = subprocess.call(script_cmd, shell=True, env=tmp_env)

        script_res = []
        for res_file in tmp_res_files:
            with open(res_file, 'rb') as f:
                script_res.append(pickle.load(f))
        if not save_all_ranks:
            script_res = script_res[0]

        if assert_callback:
            assert_callback(script_res)

        self.assertEqual(
            res, 0,
            msg='The test function ``{}`` in ``{}`` run failed!'.format(func_name, script_name))

        return script_res

    def start(self,
              func,
              num_gpus,
              assert_callback=None,
              save_all_ranks=False,
              *args,
              **kwargs):
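        """Run ``func`` with ``torch.distributed.launch`` across ``num_gpus`` processes.

        A custom launch command can be supplied via a ``dist_start_cmd`` keyword argument;
        otherwise one is built from the current interpreter, the host IP and a free port.
        """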
        from .torch_utils import _find_free_port
        ip = socket.gethostbyname(socket.gethostname())
        if 'dist_start_cmd' in kwargs:
            dist_start_cmd = kwargs.pop('dist_start_cmd')
        else:
            dist_start_cmd = '%s -m torch.distributed.launch --nproc_per_node=%d ' \
                             '--master_addr=\'%s\' --master_port=%s' % (
                                 sys.executable, num_gpus, ip, _find_free_port())

        return self._start(
            dist_start_cmd=dist_start_cmd,
            func=func,
            num_gpus=num_gpus,
            assert_callback=assert_callback,
            save_all_ranks=save_all_ranks,
            *args,
            **kwargs)

    def clean_tmp(self, tmp_file_list):
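        """Delete temporary files or directories created for a distributed run."""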
        for file in tmp_file_list:
            if os.path.exists(file):
                if os.path.isdir(file):
                    shutil.rmtree(file)
                else:
                    os.remove(file)