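"""Export a NeMo checkpoint to a TensorRT-LLM engine from the command line.

Example invocation (a minimal sketch; the script name, model type, and paths
below are hypothetical placeholders):

    python export_to_trt_llm.py \
        --nemo_checkpoint /path/to/model.nemo \
        --model_type llama \
        --model_repository /path/to/trt_llm_engine_dir
"""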
import argparse
import logging
import pprint
from typing import Optional

from nemo.export.tensorrt_llm import TensorRTLLM

LOGGER = logging.getLogger("NeMo")

def get_args():
    """Parse command-line arguments for the TensorRT-LLM export."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Exports a NeMo checkpoint to a TensorRT-LLM engine",
    )
    parser.add_argument("-nc", "--nemo_checkpoint", required=True, type=str, help="Source model path")
    parser.add_argument("-mt", "--model_type", type=str, help="Type of the TensorRT-LLM model.")
    parser.add_argument(
        "-mr", "--model_repository", required=True, type=str, help="Folder for the TRT-LLM model files"
    )
    parser.add_argument("-tps", "--tensor_parallelism_size", default=1, type=int, help="Tensor parallelism size")
    parser.add_argument("-pps", "--pipeline_parallelism_size", default=1, type=int, help="Pipeline parallelism size")
    parser.add_argument(
        "-dt",
        "--dtype",
        choices=["bfloat16", "float16"],
        help="Data type of the model on TensorRT-LLM",
    )
    parser.add_argument("-mil", "--max_input_len", default=256, type=int, help="Max input length of the model")
    parser.add_argument("-mol", "--max_output_len", default=256, type=int, help="Max output length of the model")
    parser.add_argument("-mbs", "--max_batch_size", default=8, type=int, help="Max batch size of the model")
    parser.add_argument("-mnt", "--max_num_tokens", default=None, type=int, help="Max number of tokens")
    parser.add_argument("-ont", "--opt_num_tokens", default=None, type=int, help="Optimum number of tokens")
    parser.add_argument(
        "-mpet", "--max_prompt_embedding_table_size", default=None, type=int, help="Max prompt embedding table size"
    )
    parser.add_argument(
        "-upe",
        "--use_parallel_embedding",
        default=False,
        action='store_true',
        help="Use parallel embedding.",
    )
    parser.add_argument(
        "-npkc", "--no_paged_kv_cache", default=False, action='store_true', help="Disable paged KV cache."
    )
    parser.add_argument(
        "-drip",
        "--disable_remove_input_padding",
        default=False,
        action='store_true',
        help="Disables the remove-input-padding option.",
    )
    parser.add_argument(
        "-mbm",
        '--multi_block_mode',
        default=False,
        action='store_true',
        help='Split long KV sequences into multiple blocks (applied to generation MHA kernels). '
        'It is beneficial when batch x num_heads cannot fully utilize the GPU. '
        'Only available when using the C++ runtime.',
    )
    parser.add_argument(
        '--use_lora_plugin',
        nargs='?',
        const=None,
        choices=['float16', 'float32', 'bfloat16'],
        help="Activates the LoRA plugin, which enables embedding sharing.",
    )
    parser.add_argument(
        '--lora_target_modules',
        nargs='+',
        default=None,
        choices=[
            "attn_qkv",
            "attn_q",
            "attn_k",
            "attn_v",
            "attn_dense",
            "mlp_h_to_4h",
            "mlp_gate",
            "mlp_4h_to_h",
        ],
        help="Which modules to add LoRA to. Only takes effect when use_lora_plugin is enabled.",
    )
    parser.add_argument(
        '--max_lora_rank',
        type=int,
        default=64,
        help='Maximum LoRA rank for the different LoRA modules. '
        'It is used to compute the workspace size of the LoRA plugin.',
    )
    parser.add_argument("-dm", "--debug_mode", default=False, action='store_true', help="Enable debug mode")
    parser.add_argument(
        "--use_mcore_path",
        action="store_true",
        help="Use the Megatron-Core implementation for exporting the model. If not set, use the local NeMo codebase.",
    )
    parser.add_argument(
        "-fp8",
        "--export_fp8_quantized",
        default="auto",
        type=str,
        help="Enables exporting to an FP8-quantized TRT-LLM checkpoint ('true', 'false', or 'auto')",
    )
    parser.add_argument(
        "-kv_fp8",
        "--use_fp8_kv_cache",
        default="auto",
        type=str,
        help="Enables exporting with an FP8-quantized KV cache ('true', 'false', or 'auto')",
    )
    args = parser.parse_args()

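    # The FP8 flags are parsed as plain strings so that, besides "true" and
    # "false", the value "auto" can be accepted; with optional=True, "auto"
    # maps to None, presumably letting the exporter infer the setting from
    # the checkpoint itself.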
    def str_to_bool(name: str, s: str, optional: bool = False) -> Optional[bool]:
        s = s.lower()
        true_strings = ["true", "1"]
        false_strings = ["false", "0"]
        if s in true_strings:
            return True
        if s in false_strings:
            return False
        if optional and s == 'auto':
            return None
        raise argparse.ArgumentTypeError(f"Invalid boolean value for argument --{name}: '{s}'")

    args.export_fp8_quantized = str_to_bool("export_fp8_quantized", args.export_fp8_quantized, optional=True)
    args.use_fp8_kv_cache = str_to_bool("use_fp8_kv_cache", args.use_fp8_kv_cache, optional=True)
    return args


def nemo_export_trt_llm():
    """Export the given NeMo checkpoint to a TensorRT-LLM engine."""
    args = get_args()

    loglevel = logging.DEBUG if args.debug_mode else logging.INFO
    LOGGER.setLevel(loglevel)
    LOGGER.info(f"Logging level set to {loglevel}")
    LOGGER.info(pprint.pformat(vars(args)))

    trt_llm_exporter = TensorRTLLM(
        model_dir=args.model_repository, load_model=False, multi_block_mode=args.multi_block_mode
    )

    LOGGER.info("Starting export to TensorRT-LLM.")
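    # The CLI exposes the negated options (--no_paged_kv_cache and
    # --disable_remove_input_padding); they are inverted below because the
    # export API takes the positive settings.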
    trt_llm_exporter.export(
        nemo_checkpoint_path=args.nemo_checkpoint,
        model_type=args.model_type,
        tensor_parallelism_size=args.tensor_parallelism_size,
        pipeline_parallelism_size=args.pipeline_parallelism_size,
        max_input_len=args.max_input_len,
        max_output_len=args.max_output_len,
        max_batch_size=args.max_batch_size,
        max_num_tokens=args.max_num_tokens,
        opt_num_tokens=args.opt_num_tokens,
        max_prompt_embedding_table_size=args.max_prompt_embedding_table_size,
        use_parallel_embedding=args.use_parallel_embedding,
        paged_kv_cache=not args.no_paged_kv_cache,
        remove_input_padding=not args.disable_remove_input_padding,
        dtype=args.dtype,
        use_lora_plugin=args.use_lora_plugin,
        lora_target_modules=args.lora_target_modules,
        max_lora_rank=args.max_lora_rank,
        fp8_quantized=args.export_fp8_quantized,
        fp8_kvcache=args.use_fp8_kv_cache,
        load_model=False,
        use_mcore_path=args.use_mcore_path,
    )

    LOGGER.info("Export completed successfully.")


if __name__ == '__main__':
    nemo_export_trt_llm()