| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| from argparse import ArgumentParser |
|
|
| import torch |
| from pytorch_lightning import Trainer |
|
|
| from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector |
| from nemo.utils import logging, model_utils |
| from nemo.utils.app_state import AppState |
|
|
|
|
# Example invocation of this script. This bare string comes after the imports, so Python
# does not treat it as the module docstring; it is evaluated and discarded at import time.
# Further options (--model_class, --precision, --megatron_legacy, --tokenizer_model_path)
# are declared in main().
"""
Usage:
python megatron_change_num_partitions.py \
--model_file=PATH_TO_SRC_FILE \
--target_file=PATH_TO_TGT_FILE \
--tensor_model_parallel_size=2 \
--target_tensor_model_parallel_size=1
"""
|
|
|
|
def merge_partition(model, partitions, write_path=None):
    """Merge tensor-parallel (TP) partitions into the parameters of ``model``.

    ``model`` must already be constructed with the merged (target) configuration.
    Parameters are matched by position: ``partitions[rank][idx]`` is the rank-``rank``
    shard of the ``idx``-th entry of ``model.named_parameters()``. For each parameter:

    * if the rank-0 shard already has the parameter's shape, that shard is used as-is
      (the parameter is treated as replicated across ranks);
    * else if the leading dimension matches, shards are concatenated along the last dim;
    * otherwise shards are concatenated along dim 0.

    A merged tensor whose shape differs from the parameter's is narrowed (truncated)
    along dim 0 or along the last dim so that it fits.

    Args:
        model: module whose parameters are overwritten in place.
        partitions: one list of tensors per source TP rank, each aligned by position
            with ``model.named_parameters()``.
        write_path: if not None, ``model.save_to(write_path)`` is called after merging.

    Raises:
        ValueError: if ``partitions`` is empty, or if a partition's length differs from
            the model's parameter count (positional matching would be misaligned).
        RuntimeError: if a merged tensor cannot be narrowed to the parameter's shape.
    """
    named_params = list(model.named_parameters())
    if not partitions:
        raise ValueError("Can not merge an empty list of partitions.")
    # Validate everything before touching the model, so a bad input cannot leave it half-merged.
    for rank, partition in enumerate(partitions):
        if len(partition) != len(named_params):
            raise ValueError(
                f"Partition {rank} has {len(partition)} parameters, but the model has {len(named_params)}."
            )

    for idx, (name, param) in enumerate(named_params):
        shards = [partition[idx].data for partition in partitions]
        if param.shape == shards[0].shape:
            # A single shard already has the full shape: replicated parameter, keep rank 0's copy.
            concated = shards[0]
        elif param.shape[0] == shards[0].shape[0]:
            # Leading dim unchanged, so the shards were partitioned along the last dim.
            concated = torch.cat(shards, dim=-1)
        else:
            concated = torch.cat(shards, dim=0)

        if concated.shape != param.shape:
            logging.warning(
                f"Warning: Shape mismatch for parameter {name} required shape: {param.shape}, merged shape: {concated.shape}. Narrowing to match required size."
            )
            if concated.shape[1:] == param.shape[1:]:
                concated = torch.narrow(concated, 0, 0, param.shape[0])
            elif concated.shape[:-1] == param.shape[:-1]:
                concated = torch.narrow(concated, -1, 0, param.shape[-1])
            else:
                raise RuntimeError(
                    f"Can not handle parameter {name}, required shape: {param.shape}, merged shape: {concated.shape}."
                )
        param.data = concated

    if write_path is not None:
        model.save_to(write_path)
|
|
|
|
def _split_tensor_for_tp(name, full, target_shape, tp_size, megatron_legacy=False):
    """Split one TP=1 tensor into a sequence of per-rank pieces, indexed by TP rank.

    * ``target_shape == full.shape``: the same tensor is returned for every rank.
    * same leading dim: ``torch.split`` along the last dim into ``target_shape[-1]`` columns.
    * ``megatron_legacy`` and a fused ``query_key_value.weight`` / ``key_value.weight``:
      the rows are cut into ``num_fused * tp_size`` equal blocks and rank ``r`` takes blocks
      ``r, r + tp_size, ...`` concatenated along dim 0.
    * otherwise: ``torch.split`` along dim 0 into ``target_shape[0]`` rows.

    With ``torch.split`` the last piece may be smaller than ``target_shape``, and fewer than
    ``tp_size`` pieces may be produced; the caller handles both.

    Raises:
        ValueError: if a legacy fused tensor's leading dim is not divisible by
            ``num_fused * tp_size``.
    """
    if target_shape == full.shape:
        return [full] * tp_size
    if target_shape[0] == full.shape[0]:
        # Leading dim unchanged, so partition along the last dim.
        return torch.split(full, target_shape[-1], dim=-1)

    num_fused, label = None, None
    if megatron_legacy:
        # 'key_value.weight' is also a substring of 'query_key_value.weight', so test Q,K,V first.
        if 'query_key_value.weight' in name:
            num_fused, label = 3, "Q,K,V"
        elif 'key_value.weight' in name:
            num_fused, label = 2, "K,V"
    if num_fused is None:
        return torch.split(full, target_shape[0], dim=0)

    if full.shape[0] % (tp_size * num_fused) != 0:
        raise ValueError(
            f"Can not split {label} parameter {name} with shape {target_shape} into tensor parallel size {tp_size}. Not divisible by {tp_size * num_fused}."
        )
    # Assuming the rows of ``full`` are stacked as [Q; K; V] (or [K; V]) — the labels above suggest
    # this, confirm against the legacy checkpoint format — rank r gets its slice of each sub-matrix.
    blocks = torch.chunk(full, tp_size * num_fused, dim=0)
    return [torch.cat(blocks[rank::tp_size]) for rank in range(tp_size)]


def split_partition(model, partitions, tp_size, write_path=None, megatron_legacy=False):
    """Split the weights of a TP=1 model into ``tp_size`` tensor-parallel ranks of ``model``.

    ``model`` must already be constructed with the target TP configuration.
    ``partitions`` must contain exactly one list of tensors (the TP=1 weights), aligned by
    position with ``model.named_parameters()``. For every rank from ``tp_size - 1`` down to 0,
    the AppState TP rank is set, that rank's piece of each tensor is copied into the model
    (zero-padded along dim 0 or the last dim if it is too small), and, when ``write_path`` is
    given, ``model.save_to(write_path)`` is called.

    Args:
        model: module whose parameters are overwritten in place, once per rank.
        partitions: ``[tp1_params]``, a single list of TP=1 tensors.
        tp_size: target tensor-parallel size (>= 1).
        write_path: optional path passed to ``model.save_to`` after each rank is filled.
        megatron_legacy: split fused ``query_key_value.weight`` / ``key_value.weight`` tensors
            by interleaving their Q,K,V (or K,V) blocks instead of contiguously.

    Raises:
        ValueError: if ``partitions`` does not hold exactly one partition, ``tp_size < 1``,
            the parameter counts differ, a tensor yields fewer than ``tp_size`` pieces, or a
            legacy fused tensor is not evenly divisible.
        RuntimeError: if a piece cannot be padded to its parameter's shape.
    """
    if len(partitions) != 1:
        raise ValueError(
            "Can only split partitions of model with TP=1. For partitions of models with TP>1, merge first."
        )
    if tp_size < 1:
        raise ValueError("TP size must to be >= 1.")

    named_params = list(model.named_parameters())
    source = partitions[0]
    if len(source) != len(named_params):
        raise ValueError(f"Source partition has {len(source)} parameters, but the model has {len(named_params)}.")

    app_state = AppState()
    app_state.data_parallel_rank = 0
    app_state.pipeline_model_parallel_size = 1
    app_state.tensor_model_parallel_size = tp_size
    app_state.model_parallel_size = app_state.pipeline_model_parallel_size * app_state.tensor_model_parallel_size
    app_state.tensor_model_parallel_rank = tp_size - 1

    # Compute every rank's pieces up front, so a bad shape is reported before any parameter is changed.
    splits = []
    for (name, param), full in zip(named_params, source):
        pieces = _split_tensor_for_tp(name, full.data, param.shape, tp_size, megatron_legacy)
        if len(pieces) < tp_size:
            raise ValueError(
                f"Can not split parameter {name} with shape {tuple(full.shape)} into {tp_size} pieces "
                f"of shape {tuple(param.shape)}; got only {len(pieces)}."
            )
        splits.append(pieces)

    # Ranks are filled and saved from highest to lowest, so rank 0 is saved last.
    # NOTE(review): presumably the save/restore connector finalizes the archive on rank 0 — confirm.
    for rank in range(tp_size - 1, -1, -1):
        app_state.tensor_model_parallel_rank = rank

        for (name, param), pieces in zip(named_params, splits):
            split_val = pieces[rank].clone()

            if param.shape != split_val.shape:
                logging.warning(
                    f"Warning: Shape mismatch for parameter {name} required shape: {param.shape}, split shape: {split_val.shape}. Padding to match required size."
                )
                if split_val.shape[1:] == param.shape[1:]:
                    # F.pad takes (left, right) pairs starting from the LAST dim, so the final
                    # entry is the right-hand padding of dim 0.
                    pad = [0, 0] * split_val.dim()
                    pad[-1] = param.shape[0] - split_val.shape[0]
                elif split_val.shape[:-1] == param.shape[:-1]:
                    pad = [0, param.shape[-1] - split_val.shape[-1]]
                else:
                    raise RuntimeError(
                        f"Can not handle parameter {name}, required shape: {param.shape}, split shape: {split_val.shape}."
                    )
                split_val = torch.nn.functional.pad(split_val, pad, 'constant')

            param.data = split_val

        if write_path is not None:
            model.save_to(write_path)
|
|
|
|
def main():
    """Re-partition a NeMo Megatron ``.nemo`` checkpoint to a different tensor-parallel (TP) size.

    Pipeline-parallel size is fixed to 1. The flow is:

    1. Source TP > 1: restore every TP rank of ``--model_file``, build a TP=1 model from the
       restored config, and merge the ranks into it. The result is written to
       ``--target_file`` directly when the target TP size is 1.
    2. Source TP == 1: restore the model as-is. If the target TP size is also 1, the restored
       model is saved unchanged to ``--target_file``.
    3. Target TP > 1: build a model with the target TP size and split the TP=1 weights into it.
       ``split_partition`` calls ``save_to(--target_file)`` once per rank.
    """
    parser = ArgumentParser()
    parser.add_argument("--model_file", type=str, required=True, help="Path to source .nemo file")
    parser.add_argument("--target_file", type=str, required=True, help="Path to write target .nemo file")
    parser.add_argument("--tensor_model_parallel_size", type=int, required=True, help="TP size of source model")
    parser.add_argument("--target_tensor_model_parallel_size", type=int, required=True, help="TP size of target model")
    parser.add_argument(
        "--model_class",
        type=str,
        default="nemo.collections.nlp.models.language_modeling.megatron_gpt_model.MegatronGPTModel",
        help="NeMo model class. This script should support all NeMo megatron models that use Tensor Parallel",
    )
    parser.add_argument("--precision", default=16, help="PyTorch Lightning Trainer precision flag")
    parser.add_argument(
        "--megatron_legacy",
        action="store_true",
        help="Converter for legacy megatron models that have different q,k,v weight splits",
    )
    parser.add_argument(
        "--tokenizer_model_path",
        type=str,
        required=False,
        default=None,
        help="Path to the tokenizer model path if your model uses a tokenizer model as an artifact. This is needed if your model uses a sentencepiece tokenizer.",
    )

    args = parser.parse_args()

    tp_size = args.tensor_model_parallel_size
    tgt_tp_size = args.target_tensor_model_parallel_size
    # Sizes < 1 would otherwise fall through the "> 1" checks below: a bad source size would
    # behave like 1, and a bad target size would finish without writing anything.
    if tp_size < 1 or tgt_tp_size < 1:
        parser.error("--tensor_model_parallel_size and --target_tensor_model_parallel_size must be >= 1.")

    if args.megatron_legacy and tp_size > 1:
        # Only split_partition receives --megatron_legacy; merge_partition concatenates shards
        # without any legacy-specific handling.
        logging.warning(
            "--megatron_legacy is only applied when splitting to the target TP size; the source "
            "partitions are merged by plain concatenation, which may not match the legacy q,k,v layout."
        )

    # Numeric strings from the CLI become ints; any other value (including the int default) passes through.
    precision = args.precision
    if args.precision in ["32", "16"]:
        precision = int(float(args.precision))
    cls = model_utils.import_class_by_path(args.model_class)

    trainer = Trainer(devices=1, strategy=NLPDDPStrategy(), accelerator="cpu", precision=precision)
    app_state = AppState()
    app_state.data_parallel_rank = 0
    app_state.pipeline_model_parallel_size = 1
    app_state.tensor_model_parallel_size = tp_size
    app_state.model_parallel_size = app_state.pipeline_model_parallel_size * app_state.tensor_model_parallel_size

    if tp_size > 1:
        partitions = []
        for i in range(tp_size):
            app_state.tensor_model_parallel_rank = i
            model = cls.restore_from(restore_path=args.model_file, trainer=trainer, map_location=torch.device("cpu"))
            params = [p for _, p in model.named_parameters()]
            partitions.append(params)
            # The AppState is re-initialised after every restore; presumably restore_from
            # modifies it — confirm before removing these resets.
            app_state.data_parallel_rank = 0
            app_state.pipeline_model_parallel_size = 1
            app_state.tensor_model_parallel_size = tp_size
            app_state.model_parallel_size = (
                app_state.pipeline_model_parallel_size * app_state.tensor_model_parallel_size
            )

        # Build a fresh TP=1 model from the last restored config to receive the merged weights.
        model.cfg.tensor_model_parallel_size = 1
        app_state.model_parallel_size = 1
        trainer = Trainer(devices=1, strategy=NLPDDPStrategy(), accelerator="cpu", precision=precision)
        if args.tokenizer_model_path is not None:
            model.cfg.tokenizer.model = args.tokenizer_model_path
        model = cls(model.cfg, trainer).to('cpu')
        model._save_restore_connector = NLPSaveRestoreConnector()

        if tgt_tp_size > 1:
            # Merged weights are only an intermediate; the split below writes the target file.
            merge_partition(model, partitions)
        else:
            merge_partition(model, partitions, args.target_file)
    else:
        app_state.model_parallel_size = 1
        model = cls.restore_from(restore_path=args.model_file, trainer=trainer, map_location=torch.device("cpu"))
        if tgt_tp_size == 1:
            # TP=1 -> TP=1: there is nothing to re-partition, but the requested target file must
            # still be written before reporting success.
            model._save_restore_connector = NLPSaveRestoreConnector()
            model.save_to(args.target_file)

    if tgt_tp_size > 1:
        partitions = []
        params = [p for _, p in model.named_parameters()]
        partitions.append(params)

        # Build a fresh model with the target TP size to receive each rank's split weights.
        model.cfg.tensor_model_parallel_size = tgt_tp_size
        app_state.model_parallel_size = tgt_tp_size
        trainer = Trainer(devices=1, strategy=NLPDDPStrategy(), accelerator="cpu", precision=precision)
        if args.tokenizer_model_path is not None:
            model.cfg.tokenizer.model = args.tokenizer_model_path
        model = cls(model.cfg, trainer).to('cpu')
        model._save_restore_connector = NLPSaveRestoreConnector()
        split_partition(model, partitions, tgt_tp_size, args.target_file, args.megatron_legacy)

    logging.info("Successfully finished changing partitions!")
|
|
|
|
if __name__ == "__main__":
    # Run the command-line entry point only when executed as a script, not when imported.
    main()
|
|