Spaces:
Running on Zero
Running on Zero
| # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | |
| """Validate the acceptance rate (AR) of a ModelOpt speculative-decoding model.""" | |
| import functools | |
| import os | |
| import sys | |
| import warnings | |
| import json | |
| sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))) | |
| import modelopt | |
| from modelopt.torch.speculative.plugins.megatron_eagle import MegatronARValidation | |
| import torch | |
| from datasets import load_dataset | |
| from tqdm import tqdm | |
| from megatron.core import mpu | |
| from megatron.core.inference.communication_utils import broadcast_from_last_pipeline_stage | |
| from megatron.core.pipeline_parallel import get_forward_backward_func | |
| from megatron.core.tensor_parallel.mappings import gather_from_tensor_model_parallel_region | |
| from megatron.post_training.arguments import add_modelopt_args | |
| from megatron.post_training.checkpointing import load_modelopt_checkpoint | |
| from megatron.post_training.model_provider import model_provider | |
| from megatron.post_training.utils import get_mtbench_chat_data | |
| from megatron.training import get_args, get_model, get_tokenizer, initialize_megatron | |
| from megatron.training.checkpointing import save_checkpoint | |
| from megatron.training.utils import get_ltor_masks_and_position_ids, print_rank_0, unwrap_model | |
| warnings.filterwarnings('ignore') | |
def add_ar_validation_args(parser):
    """Register the acceptance-rate (AR) validation command-line options.

    Every script-specific option is added to a single titled argument group so
    that they are listed together in ``--help`` output. The ModelOpt options
    are then appended by ``add_modelopt_args``.

    Args:
        parser: The argparse-style parser to extend.

    Returns:
        The same ``parser`` object, with the extra options registered.
    """
    group = parser.add_argument_group(title='ModelOpt ar validation')

    group.add_argument(
        "--osl", type=int, default=64, help="Output sequence length."
    )
    group.add_argument(
        "--prompts-path",
        type=str,
        default=None,
        help="Path to the prompts json file. If not provided, MTBench will be used.",
    )
    group.add_argument(
        "--ground-truth-path",
        type=str,
        default=None,
        help="Path to the ground truth pt file.",
    )
    group.add_argument(
        "--steps", type=int, default=1, help="Only used in EAGLE."
    )
    group.add_argument(
        "--save-ground-truth-path",
        type=str,
        default=None,
        help="Save path for the ground truth pt file.",
    )

    # ModelOpt registers its own options on the parser itself.
    add_modelopt_args(parser)

    return parser
def check_arguments():
    """Validate the parsed arguments and override unsupported settings.

    Terminates the process when an interleaved (virtual) pipeline schedule is
    requested, and forces ``moe_grouped_gemm`` off when it is enabled.

    Raises:
        SystemExit: With exit status 1 when
            ``num_layers_per_virtual_pipeline_stage`` is set.
    """
    args = get_args()

    if args.num_layers_per_virtual_pipeline_stage is not None:
        print_rank_0("Interleaved pipeline schedule is not yet supported for text generation.")
        # Use sys.exit rather than the site-provided exit() helper, which is
        # meant for the interactive shell; a non-zero status lets launchers
        # see that the run was rejected.
        sys.exit(1)

    # The attribute may be absent, so fall back to False.
    if getattr(args, 'moe_grouped_gemm', False):
        print_rank_0("WARNING: Forcing moe_grouped_gemm to False for PTQ and export.")
        args.moe_grouped_gemm = False
def get_current_memory_info():
    """Build a one-line summary of this rank's free and total CUDA memory.

    The figures come from ``torch.cuda.mem_get_info()`` and the rank from
    ``torch.distributed.get_rank()``.

    Returns:
        A string with the rank, the integer percentage of memory still free,
        and the free/total amounts divided by 1048576 (labelled MB).
    """
    bytes_per_mb = 1048576

    free_bytes, total_bytes = torch.cuda.mem_get_info()
    rank = torch.distributed.get_rank()

    free_percent = int(free_bytes * 100 / total_bytes)
    free_mb = free_bytes // bytes_per_mb
    total_mb = total_bytes // bytes_per_mb

    return f"rank {rank:02} memory remaining {free_percent:03}% ({free_mb}/{total_mb} MB) "
def report_current_memory_info():
    """Print this rank's memory summary, then wait at a distributed barrier."""
    memory_summary = get_current_memory_info()
    # Flush immediately so the line is emitted before blocking on the barrier.
    print(memory_summary, flush=True)
    torch.distributed.barrier()
if __name__ == "__main__":
    # Script flow: initialize Megatron, collect prompts (and optional ground
    # truth), load the model, run the validator on every prompt, then report
    # the per-prompt results and their average.
    initialize_megatron(
        extra_args_provider=add_ar_validation_args,
        args_defaults={
            'tokenizer_type': 'HuggingFaceTokenizer',
            'no_load_rng': True,
            'no_load_optim': True,
        },
    )

    check_arguments()

    args = get_args()

    if not args.prompts_path:
        # Default: the first conversation entry of each MTBench sample.
        dataset = get_mtbench_chat_data()
        prompts = [[sample["conversations"][0]] for sample in dataset]
    else:
        # The prompts file holds one JSON document per line; blank lines
        # (e.g. a trailing newline) are skipped instead of failing to parse.
        with open(args.prompts_path, "r", encoding="utf-8") as f:
            prompts = [json.loads(line) for line in f if line.strip()]

    if args.ground_truth_path is not None:
        device = torch.cuda.current_device()
        # Load onto the CPU first so the tensors are not restored to whichever
        # device they were saved from, then move them to this rank's device.
        loaded_truth = torch.load(args.ground_truth_path, map_location="cpu")
        ground_truth = [truth.to(device) for truth in loaded_truth]
        # zip() below would silently drop the unmatched tail on a mismatch.
        if len(ground_truth) != len(prompts):
            raise ValueError(
                "Ground truth has {} entries but there are {} prompts.".format(
                    len(ground_truth), len(prompts)
                )
            )
    else:
        ground_truth = [None] * len(prompts)

    tokenizer = get_tokenizer()._tokenizer
    model = get_model(functools.partial(model_provider, parallel_output=True), wrap_with_ddp=False)

    report_current_memory_info()

    if args.load is not None:
        load_modelopt_checkpoint(model, strict=not args.untie_embeddings_and_output_weights)
        print_rank_0("Done loading checkpoint")

    unwrapped_model = unwrap_model(model)[0]
    unwrapped_model.eval()

    validator = MegatronARValidation(unwrapped_model, tokenizer)

    # validate() output is indexed as (ground truth, acceptance rate).
    generated_truth = []
    acceptance_rates = []
    for prompt, truth in zip(prompts, ground_truth):
        output = validator.validate(args.osl, prompt, ground_truth=truth, steps=args.steps)
        generated_truth.append(output[0])
        acceptance_rates.append(output[1])

    print_rank_0("Acceptance Rate: " + str(acceptance_rates))
    if acceptance_rates:
        print_rank_0("Average: " + str(sum(acceptance_rates) / len(acceptance_rates)))
    else:
        # No prompts were evaluated, so there is nothing to average.
        print_rank_0("Average: N/A (no prompts were evaluated)")

    if args.save_ground_truth_path is not None:
        torch.save(generated_truth, args.save_ground_truth_path)