# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. import psutil import torch def chunk_bias(bias, parallel_mode, tp_size=1, ep_size=1): assert parallel_mode in ["row", "column"] if bias.dim() == 2: num_experts, hidden_size = bias.shape if parallel_mode == 'column': bias = bias.reshape(ep_size, num_experts // ep_size, tp_size, hidden_size // tp_size) bias = bias.permute(0, 2, 1, 3) # (ep_size, tp_size, local_eps, hidden_size) else: bias = bias.reshape(ep_size, num_experts // ep_size, hidden_size) # (ep_size, local_eps, hidden_size) return bias else: hidden_size = bias.shape if parallel_mode == "column": bias = bias.reshape(tp_size, hidden_size[0] // tp_size) # (tp_size, hidden_size) return bias def chunk_weight(weight, parallel_mode, tp_size=1, ep_size=1): assert parallel_mode in ["row", "column"] if weight.dim() == 3: num_experts, out_features, in_features = weight.shape if parallel_mode == "column": weight = weight.reshape(ep_size, num_experts // ep_size, tp_size, out_features // tp_size, in_features) weight = weight.permute(0, 2, 1, 3, 4) else: weight = weight.reshape(ep_size, num_experts // ep_size, out_features, tp_size, in_features // tp_size) weight = weight.permute(0, 3, 1, 2, 4) return weight # (ep_size, tp_size, local_eps, output_features, in_features) else: out_features, in_features = weight.shape if parallel_mode == "column": weight = weight.reshape(tp_size, out_features // tp_size, in_features) else: weight = weight.reshape(out_features, tp_size, in_features // tp_size).permute(1, 0, 2) return weight # (tp_size, output_features, in_features) def print_memory_usage(key, rank, num_ranks): '''Print memory usage.''' process = psutil.Process() mem_info = process.memory_info() print("> memory usage: '%s', rank %d / %d, mem %.1f/%.1f gb." % ( key, rank, num_ranks, mem_info.rss / 1024**3, 100 * mem_info.rss / process.memory_percent() / 1024**3, )) class _ConverterFakeProcessGroup: def __init__(self, rank=0, size=1): self._rank = rank self._size = size def rank(self): return self._rank def size(self): return self._size def set_rank(self, rank): self._rank = rank def set_size(self, size): self._size = size