|
|
import os |
|
|
import re |
|
|
from abc import ABC, abstractmethod |
|
|
from typing import List |
|
|
|
|
|
import h5py |
|
|
import numpy as np |
|
|
from filelock import FileLock |
|
|
|
|
|
from .config import AutoParallelConfig, CostModel |
|
|
from .tensor_parallel.shape_consistency import ShapeConsistencyManager |
|
|
|
|
|
|
|
|
class ProfileDB(ABC):
    """A database that stores profiling results for multiple device mesh
    shapes."""

    @abstractmethod
    def query(self, cluster_key, data_key):
        """Return the stored profiling result for (cluster_key, data_key),
        or None when no entry exists."""
        ...

    @abstractmethod
    def update(self, cluster_key, data_key, mesh_result):
        """Store ``mesh_result`` under the (cluster_key, data_key) pair."""
        ...

    def close(self):
        """Release any resources held by the database.

        Default is a no-op; subclasses backed by external resources
        (e.g. file locks) override this.
        """
        pass
|
|
|
|
|
|
|
|
class MemDB(ProfileDB):
    """In-memory ProfileDB backed by a plain dict; entries do not persist
    across processes."""

    def __init__(self):
        # Maps (cluster_key, data_key) -> mesh_result.
        self.data = {}

    def query(self, cluster_key, data_key):
        """Return the first element of the stored result, or None if absent."""
        entry = self.data.get((cluster_key, data_key))
        return None if entry is None else entry[0]

    def update(self, cluster_key, data_key, mesh_result):
        """Record ``mesh_result`` under (cluster_key, data_key)."""
        self.data[(cluster_key, data_key)] = mesh_result
|
|
|
|
|
|
|
|
class Hdf5DB(ProfileDB):
    """ProfileDB persisted to ``<name>.hdf5``, guarded by a ``<name>.lock``
    file lock so multiple processes can share the cache safely."""

    def __init__(self, name):
        # `name` is the path prefix of both the HDF5 file and its lock file.
        self.name = name
        lock_name = self.name + ".lock"
        # thread_local=False allows acquire/release from different threads.
        self.lock = FileLock(lock_name, thread_local=False)

    def query(self, cluster_key, data_key):
        """Return the cached result for (cluster_key, data_key), or None.

        Bug fix: the lock is now always released via try/finally. Previously
        the miss path returned None while still holding the file lock, which
        stalled every subsequent cross-process query/update.
        """
        file_name = f"{self.name}.hdf5"
        key = str((cluster_key, data_key))
        self.lock.acquire()
        try:
            # Mode 'a' creates the file on first use.
            with h5py.File(file_name, 'a') as f:
                if key in f:
                    # Read the dataset element while the file is still open.
                    return f[key][0]
                return None
        finally:
            self.lock.release()

    def update(self, cluster_key, data_key, mesh_result):
        """Store ``mesh_result`` under (cluster_key, data_key).

        Now holds the file lock while writing (the original wrote unlocked,
        racing concurrent readers/writers of the same HDF5 file). FileLock is
        reentrant per instance, so nesting with query's lock is safe.
        """
        key = str((cluster_key, data_key))
        file_name = f"{self.name}.hdf5"
        with self.lock:
            with h5py.File(file_name, 'a') as f:
                f[key] = mesh_result

    def close(self):
        # Force-release in case an earlier code path left the lock held.
        self.lock.release(force=True)
|
|
|
|
|
|
|
|
class LogicalDeviceMesh(object):
    """A 2-D logical arrangement of physical devices.

    Holds per-mesh hardware constants (alpha = bandwidth, beta = latency,
    sharp = whether in-network SHARP reduction is available), derives
    per-collective algorithmic alpha/beta values, and estimates or profiles
    communication costs on this mesh.
    """

    def __init__(self,
                 phy_mesh_shape,
                 mesh_shape,
                 phy_ids,
                 config: AutoParallelConfig,
                 alpha,
                 beta,
                 sharp,
                 prof_database=None,
                 shape_consistency_manager=None,
                 host_ips=None):
        # phy_mesh_shape: (num_hosts, num_devices_per_host) of the backing
        # physical mesh; mesh_shape: this logical mesh's 2-D shape;
        # phy_ids: physical device ids arranged in mesh_shape.
        self.phy_mesh_shape = phy_mesh_shape
        self.mesh_shape = mesh_shape
        self.phy_ids = phy_ids
        self.host_ips = host_ips
        # Key identifying (cluster, mesh shape) in the profiling database.
        self.cluster_key = config.cluster_key + '_mesh_shape{}'.format('_'.join(
            [str(i) for i in mesh_shape]))
        # Message-size range swept when profiling: 1 B .. 2**34 B.
        self.prof_min_max_size = [1, 2**34]
        # Dtypes profiled for every collective.
        self.prof_comm_dtypes = [
            "int8", "uint8", "int32", "uint32", "int64", "uint64", "float16",
            "float32", "float64", "bfloat16"
        ]
        # Per mesh-dimension tuple: [device-id layout, NCCL_TESTS_SPLIT_MASK
        # value] used to form communicator groups along that dimension.
        self.devices_group = {
            (0, ): [self.phy_ids.transpose(), self.mesh_shape[1] - 1],
            (1, ): [self.phy_ids, self.mesh_shape[1]],
            (0, 1): [self.phy_ids.reshape([1, self.phy_ids.size]), 0]
        }
        self.prof_database = prof_database
        self.shape_consistency_manager = shape_consistency_manager
        self.config = config
        self.cluster_info = config.get_cluster_info()
        # Raw hardware numbers, one entry per mesh dimension.
        self.hw_alpha = alpha
        self.hw_beta = beta
        self.hw_sharp = sharp
        # Effective per-collective alpha/beta derived from the raw numbers.
        self.algo_alpha_beta = self._estimate_algo_alpha_beta()
        # Collective op name -> nccl-tests binary name ('split' is modeled
        # analytically, not run through nccl-tests).
        self.comm_op_to_nccl_test_func_name = {
            'all_reduce': 'all_reduce_perf_mpi',
            'all_gather': 'all_gather_perf_mpi',
            'all_to_all': 'alltoall_perf_mpi',
            'reduce_scatter': 'reduce_scatter_perf_mpi',
            'split': 'split',
        }

    @property
    def size(self) -> int:
        # Total number of devices in this logical mesh.
        return self.phy_ids.size

    def _estimate_algo_alpha_beta(self):
        """Derive per-collective (alpha, beta) dicts keyed by mesh-dim tuple.

        Uses ring-collective scaling: with n ranks, all-gather/reduce-scatter
        move (n-1)/n of the data per link, so effective bandwidth is
        alpha * n / (n - 1); a ring all-reduce does both phases, halving that
        again unless SHARP offload is available.
        """
        ret = {}
        ar_alpha, ar_beta = {}, {}
        ag_alpha, ag_beta = {}, {}
        rs_alpha, rs_beta = {}, {}
        a2a_alpha, a2a_beta = {}, {}
        phy_num_hosts, phy_num_devices_per_host = self.phy_mesh_shape
        if phy_num_hosts == 1 or phy_num_devices_per_host == 1:
            # Flat physical topology: only hw_alpha[0]/hw_beta[0] apply,
            # whatever the logical dims are.
            for dims in [(0, ), (1, ), (0, 1), (1, 0)]:
                num_devices = 1
                for dim in dims:
                    num_devices = self.mesh_shape[dim] * num_devices
                if num_devices != 1:
                    ar_alpha[dims] = self.hw_alpha[0] if self.hw_sharp[
                        0] else self.hw_alpha[0] * num_devices / 2 / (
                            num_devices - 1)
                    ar_beta[dims] = self.hw_beta[0]
                    ag_alpha[dims] = self.hw_alpha[0] * num_devices / (
                        num_devices - 1)
                    ag_beta[dims] = self.hw_beta[0]
                    rs_alpha[dims] = self.hw_alpha[0] * num_devices / (
                        num_devices - 1)
                    rs_beta[dims] = self.hw_beta[0]
                    a2a_alpha[dims] = self.hw_alpha[0] * num_devices / (
                        num_devices - 1)
                    a2a_beta[dims] = self.hw_beta[0]

        else:
            # True 2-D topology: dim 0 is inter-node, dim 1 is intra-node.
            for dims in [(0, ), (1, ), (0, 1), (1, 0)]:
                num_devices = 1
                for dim in dims:
                    num_devices = self.mesh_shape[dim] * num_devices
                if num_devices != 1:
                    if len(dims) == 1:
                        # Collective along one mesh dimension: use that
                        # dimension's hardware numbers directly.
                        dim = dims[0]
                        ar_alpha[dims] = self.hw_alpha[dim] if self.hw_sharp[
                            dim] else self.hw_alpha[dim] * num_devices / 2 / (
                                num_devices - 1)
                        ar_beta[dims] = self.hw_beta[dim]
                        ag_alpha[dims] = self.hw_alpha[dim] * num_devices / (
                            num_devices - 1)
                        ag_beta[dims] = self.hw_beta[dim]
                        rs_alpha[dims] = self.hw_alpha[dim] * num_devices / (
                            num_devices - 1)
                        rs_beta[dims] = self.hw_beta[dim]
                        a2a_alpha[dims] = self.hw_alpha[dim] * num_devices / (
                            num_devices - 1)
                        a2a_beta[dims] = self.hw_beta[dim]
                    elif len(dims) == 2:
                        # Collective spanning the whole mesh: the cost is
                        # bottlenecked by the slower of the inter-node and
                        # intra-node stages.
                        num_hosts, num_devices_per_host = phy_num_hosts, phy_num_devices_per_host
                        # All devices of a host share the node uplink.
                        inter_node_col_alpha = self.hw_alpha[
                            0] * num_devices_per_host
                        inter_node_ar_alpha = inter_node_col_alpha if self.hw_sharp[
                            0] else inter_node_col_alpha * num_hosts / 2 / (
                                num_hosts - 1)
                        intra_node_ar_alpha = self.hw_alpha[1]
                        intra_node_ar_alpha = intra_node_ar_alpha if self.hw_sharp[
                            1] else intra_node_ar_alpha * num_devices_per_host / 2 / (
                                num_devices_per_host - 1)
                        ar_alpha[dims] = min(inter_node_ar_alpha,
                                             intra_node_ar_alpha)
                        ar_beta[dims] = max(self.hw_beta)
                        ag_alpha[dims] = min(
                            inter_node_col_alpha * num_hosts / (num_hosts - 1),
                            self.hw_alpha[1] * num_devices_per_host /
                            (num_devices_per_host - 1))
                        ag_beta[dims] = max(self.hw_beta)
                        rs_alpha[dims] = ag_alpha[dims]
                        rs_beta[dims] = ag_beta[dims]
                        a2a_alpha[dims] = min(
                            num_hosts * self.hw_alpha[0] / (num_hosts - 1),
                            self.hw_alpha[1] * num_hosts)
                        a2a_beta[dims] = max(self.hw_beta)
                    else:
                        # Only 1-D and 2-D dim tuples are expected here.
                        pass
        ret['all_to_all'] = [a2a_alpha, a2a_beta]
        ret['all_reduce'] = [ar_alpha, ar_beta]
        ret['all_gather'] = [ag_alpha, ag_beta]
        ret['reduce_scatter'] = [rs_alpha, rs_beta]
        # Point-to-point costs come straight from the cluster description.
        ret['p2p_cross_device'] = [
            self.cluster_info.intra_node_bw_per_device,
            self.cluster_info.intra_node_latency
        ]
        ret['p2p_cross_host'] = [
            self.cluster_info.inter_node_bw_per_device,
            self.cluster_info.inter_node_latency
        ]
        return ret

    def _profile_split(self, min_max_comm_size):
        """Model 'split' as a local device-memory copy and return an s-curve
        array: row 0 = message sizes (doubling), row 1 = elapsed times."""
        comm_size, elapsed_time = [], []
        size = min_max_comm_size[0]
        while size <= min_max_comm_size[1]:
            # Factor 2: the copy both reads and writes each byte.
            time = size * 2 / self.cluster_info.memory_bw
            comm_size.append(size)
            elapsed_time.append(time)
            size = size * 2
        return np.array([comm_size, elapsed_time])

    def _prase_nccl_test_results(self, f_nccl_test_out_log):
        '''Parse an nccl-tests output log into (comm_size, elapsed_time) lists.

        [ToDo][KDuan] Some dtypes may not be supported by nccl-tests, which
        then uses its default dtype (float).
        '''
        start_parse = False
        comm_size, elapsed_time = [], []
        try:
            with open(f_nccl_test_out_log, 'r') as lines:
                for line in lines:
                    if start_parse:
                        # Data rows have exactly 13 whitespace-separated
                        # fields; column 0 is the message size and column 5 a
                        # time column — presumably the out-of-place time in
                        # us; confirm against the nccl-tests output format.
                        prof_data = re.split(r"[ ]+", line.strip())
                        if len(prof_data) != 13:
                            continue
                        comm_size.append(float(prof_data[0]))
                        elapsed_time.append(float(prof_data[5]))
                    # The header row containing both units marks the start of
                    # the results table.
                    if 'GB/s' in line and 'us' in line:
                        start_parse = True
        except Exception:
            # Best effort: a missing or garbled log yields empty results.
            print(f'failed to parse {f_nccl_test_out_log}')
        return comm_size, elapsed_time

    def _profile_with_nccl_test(self, min_max_comm_size, dtype, device_group,
                                func_name, step, workload_key):
        """Two-step nccl-tests based profiling.

        step 1: generate an sbatch script for this workload and append its
        submission to ./preprofiling_step1.sh (nothing runs here); returns
        None. step 2: parse the log produced by the step-1 job and return an
        s-curve array [sizes, times]. 'split' is modeled analytically.
        """
        if func_name == 'split':
            if 2 == step:
                return self._profile_split(min_max_comm_size)
            else:
                return None
        # NOTE(review): config is indexed like a mapping here but accessed
        # via attributes elsewhere — assumes AutoParallelConfig supports
        # item access; confirm.
        workspace_dir = self.config['profiling_workspace'] + f'/{workload_key}'
        os.makedirs(workspace_dir, exist_ok=True)
        outfile, errfile = workspace_dir + '/profile.out', workspace_dir + '/profile.err'
        if 1 == step:
            num_nodes = len(self.host_ips)
            num_gpus = self.mesh_shape[0] * self.mesh_shape[1]
            ntasks_per_node = num_gpus // num_nodes
            # NCCL_TESTS_SPLIT_MASK picks the communicator split for this
            # mesh dimension; '-f 2' doubles the message size each step.
            nccl_test_command = '"export NCCL_TESTS_SPLIT_MASK={} && export NCCL_COLLNET_ENABLE=1 && {} -b {} -e {} -g 1 -d {} -f {}"'.format(
                device_group[1], func_name, min_max_comm_size[0],
                min_max_comm_size[1], dtype, 2)
            sbatch_command = '#!/bin/bash\n'
            sbatch_command += '#SBATCH -p {}\n'.format(self.config['partition'])
            sbatch_command += '#SBATCH -A {}\n'.format(self.config['account'])
            sbatch_command += '#SBATCH -J {}\n'.format(self.config['jobname'])
            sbatch_command += '#SBATCH -N {}\n'.format(num_nodes)
            sbatch_command += '#SBATCH -t {}\n'.format(self.config['time'])
            sbatch_command += '#SBATCH --ntasks-per-node={}\n'.format(
                ntasks_per_node)
            sbatch_command += '#SBATCH --exclusive\n'
            sbatch_command += '#SBATCH --mem=0\n'
            sbatch_command += '#SBATCH --network=sharp\n'
            sbatch_command += '#SBATCH --mail-type=FAIL\n'
            srun_command = 'srun --nodes={} --mpi=pmix --ntasks-per-node={} --network=sharp -o {} -e {} --container-image={} bash -c '.format(
                num_nodes, ntasks_per_node, outfile, errfile,
                self.config['container'])
            command = sbatch_command + srun_command + nccl_test_command
            with open(workspace_dir + '/workload.sub', 'w') as f:
                f.write(command)
            # Collect all generated jobs into a single submission script.
            with open('./preprofiling_step1.sh', 'a') as f:
                f.write(f'sbatch {workspace_dir}/workload.sub\n')
            return None

        else:
            comm_size, elapsed_time = self._prase_nccl_test_results(outfile)
            if len(comm_size) < 2:
                # Fewer than 2 points: the step-1 job did not produce usable
                # output, so interpolation would be impossible.
                assert 0, 'the profiling for {} was failed at step1, please try again'.format(
                    workload_key)
            else:
                print(workload_key, comm_size, elapsed_time)
                return np.array([comm_size, elapsed_time])

    def _profile_single_comm_perf(self, device_group, comm_op, step, data_key):
        """Profile one collective over every dtype; returns a dict mapping
        dtype -> s-curve array (None entries for step 1)."""
        results = {}
        func_name = self.comm_op_to_nccl_test_func_name[comm_op]
        for dtype in self.prof_comm_dtypes:
            size_time = self._profile_with_nccl_test(
                self.prof_min_max_size, dtype, device_group, func_name, step,
                data_key + f'_dtype{dtype}')
            results[dtype] = size_time
        return results

    def profile_all_comms_perf(self, step):
        """Profile all collectives on this mesh, or return cached results.

        Returns {comm_op: {dims: {dtype: s-curve}}}; persists to the
        profiling database after step 2. A 1x1 mesh needs no communication.
        """
        if self.mesh_shape == (1, 1):
            return None
        mesh_results = self.prof_database.query(self.cluster_key,
                                                self.mesh_shape)
        if mesh_results:
            return mesh_results

        mesh_results = {}
        data_key = self.cluster_key + f'_mesh_shape{self.mesh_shape[0]}x{self.mesh_shape[1]}'
        for comm_op in [
                'all_reduce', 'all_to_all', 'all_gather', 'reduce_scatter',
                'split'
        ]:
            comm_perf = {}
            for dim, device_group in self.devices_group.items():

                # Skip degenerate dimensions of size 1 (no communication).
                if len(dim) == 1 and self.mesh_shape[dim[0]] == 1:
                    continue

                comm_perf[dim] = self._profile_single_comm_perf(
                    device_group, comm_op, step, data_key +
                    '_comm_op{}_dim{}'.format(comm_op, ''.join(map(str, dim))))
            mesh_results[comm_op] = comm_perf
        if 2 == step:
            self.prof_database.update(self.cluster_key, self.mesh_shape,
                                      mesh_results)

        return mesh_results

    def _model_comm_cost_from_s_curve(self, size_time_array, realsize):
        """Interpolate elapsed time for ``realsize`` bytes from a profiled
        s-curve (row 0: sizes, row 1: times); ``realsize`` must lie within
        the profiled range."""
        assert size_time_array[0][0] <= realsize <= size_time_array[0][-1],\
            'the comm_size: {} is not in the profile range: [{}{}]'\
            .format(realsize, size_time_array[0][0], size_time_array[0][-1])
        return np.interp(realsize, size_time_array[0], size_time_array[1])

    def _model_comm_cost_from_alpha_beta(self, comm_op, dim_key, size_in_bytes):
        """Estimate elapsed time with the alpha-beta model.

        'split' is modeled as a device-memory copy; other collectives use
        size / (alpha * efficiency) + beta. The 1e-3 factor converts units —
        presumably bandwidth units to the time unit used elsewhere; confirm.
        """
        elapsed_time = 0.0
        if 'split' == comm_op:
            elapsed_time = size_in_bytes * 2 / (
                self.cluster_info.memory_bw *
                self.cluster_info.memory_efficiency) * 1e-3
        else:
            dict_alpha, dict_beta = self.algo_alpha_beta[comm_op]
            alpha, beta = dict_alpha[dim_key], dict_beta[dim_key]
            elapsed_time = (size_in_bytes /
                            (alpha * self.cluster_info.communication_efficiency)
                            * 1e-3) + beta
        return elapsed_time

    def _input_size_to_comm_size(self, comm_op, dims, input_size):
        """Convert an input tensor size to the size moved on the wire; only
        all_gather scales up (each rank ends with the concatenated result)."""
        ret = input_size
        if 'all_gather' == comm_op:
            for dim in dims:
                ret = ret * self.mesh_shape[dim]
        return ret

    def estimate_comm_cost(self, comm_op, dim, input_size, dtype):
        """Estimate the cost of ``comm_op`` over mesh dims ``dim`` for
        ``input_size`` data, using the configured communication cost model.

        CostModel.PROFILE is rejected; CostModel.ZERO returns 0.0.
        """

        size = self._input_size_to_comm_size(comm_op, dim, input_size)
        if self.config.comm_cost_model == CostModel.S_CURVE:
            mesh_perf = self.prof_database.query(self.cluster_key,
                                                 self.mesh_shape)
            assert mesh_perf is not None, 'the mesh is not profiled, mesh_shape = {}'.format(
                self.mesh_shape)
            comm_op_perf = mesh_perf.get(comm_op, None)
            assert comm_op_perf is not None, '{} is not profiled'.format(
                comm_op)
            elapsed_time = self._model_comm_cost_from_s_curve(
                comm_op_perf[tuple(dim)][dtype], size)
            return elapsed_time
        elif self.config.comm_cost_model == CostModel.ALPHA_BETA:
            elapsed_time = self._model_comm_cost_from_alpha_beta(
                comm_op, tuple(dim), size)
        elif self.config.comm_cost_model == CostModel.PROFILE:
            assert False, 'Unsupported profile based communication cost model now'
        elif self.config.comm_cost_model == CostModel.ZERO:
            elapsed_time = 0.0

        return elapsed_time
|
|
|
|
|
|
|
|
class PhysicalDeviceMesh(object): |
|
|
|
|
|
    def __init__(self,
                 phy_devices_id,
                 config: AutoParallelConfig,
                 prof_database=None,
                 shape_consistency_manager=None,
                 host_ips=None):
        """Describe the physical (hosts x devices-per-host) device grid.

        phy_devices_id: 2-D array-like of device ids, one row per host.
        Raises ValueError for unrecognized cost-model settings.
        """
        self.phy_devices_id = np.array(phy_devices_id)
        self.num_hosts, self.num_devices_per_host = self.phy_devices_id.shape
        self.host_ips = host_ips
        if host_ips is None:
            # Placeholder IPs when the caller does not supply them.
            self.host_ips = [''] * self.num_hosts
        self.config = config
        self.cluster_info = config.get_cluster_info()
        self.prof_database: ProfileDB = prof_database
        self.shape_consistency_manager = shape_consistency_manager
        # Validate the configured cost models up front.
        if self.config.comm_cost_model not in CostModel:
            raise ValueError(
                f'unsupported communication cost model: {self.config.comm_cost_model}'
            )
        if self.config.sharding_cost_model not in CostModel:
            raise ValueError(
                f'unsupported sharding cost model: {self.config.sharding_cost_model}'
            )
        # S-curve / profile-based costs need a database: in-memory by
        # default, HDF5-backed when a profile cache path is configured.
        if self.config.comm_cost_model == CostModel.S_CURVE or self.config.sharding_cost_model == CostModel.PROFILE:
            if self.prof_database is None:
                profile_cache = config.profile_cache
                if profile_cache is None:
                    self.prof_database = MemDB()
                else:
                    self.prof_database = Hdf5DB(profile_cache)
        elif self.config.comm_cost_model == CostModel.ALPHA_BETA:
            # Alpha-beta modeling requires real bandwidth numbers.
            assert self.cluster_info.intra_node_bw_per_device > 0, 'intra_node_bw_per_device is needed for alpha_beta method'
            assert self.cluster_info.inter_node_bw_per_device > 0, 'inter_node_bw_per_device is needed for alpha_beta method'
        if self.config.sharding_cost_model == CostModel.ALPHA_BETA:
            assert self.cluster_info.memory_bw > 0, 'memory_bw is needed for alpha_beta method'

        if not shape_consistency_manager:
            self.shape_consistency_manager = ShapeConsistencyManager()
|
|
|
|
|
@property |
|
|
def size(self) -> int: |
|
|
return self.phy_devices_id.size |
|
|
|
|
|
def close(self): |
|
|
if self.prof_database is not None: |
|
|
self.prof_database.close() |
|
|
|
|
|
    def split_pipeline_meshes(
            self, num_stages,
            num_devices_per_stage) -> List["PhysicalDeviceMesh"]:
        """Split this mesh into ``num_stages`` pipeline-stage sub-meshes of
        ``num_devices_per_stage`` devices each.

        Stages are assigned round-robin over the available device clusters,
        so several stages may share the same physical devices when
        num_stages exceeds the number of clusters. Raises AssertionError
        when the mesh does not divide evenly.
        """
        sub_meshes = []
        if num_devices_per_stage <= self.num_devices_per_host:
            # Case 1: a stage fits within one host; each host is carved into
            # equally-sized contiguous device groups.
            assert self.num_devices_per_host % num_devices_per_stage == 0, \
                "num_devices_per_host ({}) % num_devices_per_stage ({}) != 0"\
                .format(self.num_devices_per_host, num_devices_per_stage)
            num_clusters_per_host = self.num_devices_per_host // num_devices_per_stage
            num_clusters = self.num_hosts * num_clusters_per_host
            assert num_stages % num_clusters == 0, \
                "num_stages({}) % num_clusters({}) !=0".format(num_stages, num_clusters)
            for mesh_id in range(num_stages):
                cluster_id = mesh_id % num_clusters
                cluster_col = cluster_id % num_clusters_per_host
                cluster_row = cluster_id // num_clusters_per_host
                # One slice of consecutive devices from a single host row.
                sub_devices_id = [
                    self.phy_devices_id[cluster_row][cluster_col *
                                                     num_devices_per_stage:(
                                                         (cluster_col + 1) *
                                                         num_devices_per_stage)]
                ]
                sub_meshes.append(
                    PhysicalDeviceMesh(sub_devices_id, self.config,
                                       self.prof_database,
                                       self.shape_consistency_manager,
                                       [self.host_ips[cluster_row]]))
        else:
            # Case 2: a stage spans multiple whole hosts.
            assert num_devices_per_stage % self.num_devices_per_host == 0, \
                "num_devices_per_stage ({}) % num_devices_per_host ({}) != 0"\
                .format(num_devices_per_stage, self.num_devices_per_host)
            num_host_per_cluster = num_devices_per_stage // self.num_devices_per_host
            assert self.num_hosts % num_host_per_cluster == 0, \
                "num_hosts ({}) % num_host_per_cluster({}) != 0".format(self.num_hosts, num_host_per_cluster)
            num_clusters = self.num_hosts // num_host_per_cluster
            for mesh_id in range(num_stages):
                cluster_id = mesh_id % num_clusters
                cluster_row = cluster_id * num_host_per_cluster
                # A contiguous block of host rows forms this stage's mesh.
                sub_devices_id = self.phy_devices_id[cluster_row:(
                    cluster_row + num_host_per_cluster)]
                host_ips = self.host_ips[cluster_row:(cluster_row +
                                                      num_host_per_cluster)]
                sub_meshes.append(
                    PhysicalDeviceMesh(sub_devices_id, self.config,
                                       self.prof_database,
                                       self.shape_consistency_manager,
                                       host_ips))
        return sub_meshes
|
|
|
|
|
def _profile_logical_meshes(self, logical_meshes, step): |
|
|
for lmesh in logical_meshes: |
|
|
lmesh.profile_all_comms_perf(step) |
|
|
|
|
|
def as_logical_mesh(self) -> LogicalDeviceMesh: |
|
|
alpha = [ |
|
|
self.cluster_info.inter_node_bw_per_device, |
|
|
self.cluster_info.intra_node_bw_per_device |
|
|
] |
|
|
beta = [ |
|
|
self.cluster_info.inter_node_latency, |
|
|
self.cluster_info.intra_node_latency |
|
|
] |
|
|
sharp = [ |
|
|
self.cluster_info.inter_node_sharp, |
|
|
self.cluster_info.intra_node_sharp |
|
|
] |
|
|
return LogicalDeviceMesh( |
|
|
self.phy_devices_id.shape, |
|
|
self.phy_devices_id.shape, |
|
|
self.phy_devices_id, |
|
|
self.config, |
|
|
alpha, |
|
|
beta, |
|
|
sharp, |
|
|
self.prof_database, |
|
|
self.shape_consistency_manager, |
|
|
self.host_ips, |
|
|
) |
|
|
|
|
|
def get_logical_meshes(self): |
|
|
logical_meshes = [] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.num_hosts == 1: |
|
|
alpha = [self.cluster_info.intra_node_bw_per_device] |
|
|
beta = [self.cluster_info.intra_node_latency] |
|
|
sharp = [self.cluster_info.intra_node_sharp] |
|
|
for i in range(2, self.num_devices_per_host): |
|
|
if self.num_devices_per_host % i == 0 and i * i <= self.num_devices_per_host: |
|
|
lmesh_shape = (i, self.num_devices_per_host // i) |
|
|
lmesh_phy_ids = self.phy_devices_id.reshape(lmesh_shape) |
|
|
logical_meshes.append( |
|
|
LogicalDeviceMesh(self.phy_devices_id.shape, |
|
|
lmesh_shape, lmesh_phy_ids, |
|
|
self.config, alpha, beta, sharp, |
|
|
self.prof_database, |
|
|
self.shape_consistency_manager, |
|
|
self.host_ips)) |
|
|
|
|
|
|
|
|
elif self.num_devices_per_host == 1: |
|
|
alpha = [self.cluster_info.inter_node_bw_per_device] |
|
|
beta = [self.cluster_info.inter_node_latency] |
|
|
sharp = [self.cluster_info.inter_node_sharp] |
|
|
for i in range(2, self.num_hosts): |
|
|
if self.num_hosts % i == 0 and i * i <= self.num_hosts: |
|
|
lmesh_shape = (i, self.num_hosts // i) |
|
|
lmesh_phy_ids = self.phy_devices_id.reshape(lmesh_shape) |
|
|
logical_meshes.append( |
|
|
LogicalDeviceMesh(self.phy_devices_id.shape, |
|
|
lmesh_phy_ids, self.config, alpha, |
|
|
beta, sharp, self.prof_database, |
|
|
self.shape_consistency_manager, |
|
|
self.host_ips)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if 0 == len(logical_meshes): |
|
|
logical_meshes.append(self.as_logical_mesh()) |
|
|
return logical_meshes |
|
|
|
|
|
    '''
    We assume the pipeline stages and the device mesh can always be split evenly.
    '''
|
|
|
|
|
def _list_all_sub_meshes(self): |
|
|
sub_meshes = [] |
|
|
for num_devices_per_stage in range(1, self.num_devices_per_host + 1): |
|
|
if self.num_devices_per_host % num_devices_per_stage == 0: |
|
|
num_stages = self.num_hosts * self.num_devices_per_host // num_devices_per_stage |
|
|
sub_meshes.append( |
|
|
self.split_pipeline_meshes(num_stages, |
|
|
num_devices_per_stage)[0]) |
|
|
for num_hosts_per_stage in range(2, self.num_hosts + 1): |
|
|
if self.num_hosts % num_hosts_per_stage == 0: |
|
|
num_stages = self.num_hosts // num_hosts_per_stage |
|
|
sub_meshes.append( |
|
|
self.split_pipeline_meshes( |
|
|
num_stages, |
|
|
num_hosts_per_stage * self.num_devices_per_host)[0]) |
|
|
return sub_meshes |
|
|
|
|
|
def list_all_pipeline_configs(self): |
|
|
configs = [] |
|
|
for num_devices_per_stage in range(1, self.num_devices_per_host + 1): |
|
|
if self.num_devices_per_host % num_devices_per_stage == 0: |
|
|
num_stages = self.num_hosts * self.num_devices_per_host // num_devices_per_stage |
|
|
configs.append((num_stages, num_devices_per_stage)) |
|
|
for num_hosts_per_stage in range(2, self.num_hosts + 1): |
|
|
if self.num_hosts % num_hosts_per_stage == 0: |
|
|
num_stages = self.num_hosts // num_hosts_per_stage |
|
|
configs.append( |
|
|
(num_stages, |
|
|
num_hosts_per_stage * self.num_devices_per_host)) |
|
|
return configs |
|
|
|
|
|
def profile_s_curve(self, step): |
|
|
sub_phy_device_meshes = self._list_all_sub_meshes() |
|
|
for phy_mesh in sub_phy_device_meshes: |
|
|
lmeshes = phy_mesh.get_logical_meshes() |
|
|
self._profile_logical_meshes(lmeshes, step) |
|
|
if 2 == step: |
|
|
self.save_profile_database() |
|
|
|
|
|
def profile_alpha_beta(self): |
|
|
alpha = [250, 25] |
|
|
beta = [100, 100] |
|
|
return alpha, beta |
|
|
|