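"""Build TensorRT-LLM engines from a TensorRT-LLM checkpoint or a model config.

Illustrative invocation (the script filename, paths, and sizes below are
placeholders, not prescriptions):

    python build.py --checkpoint_dir ./tllm_checkpoint \
                    --output_dir ./engine_outputs \
                    --max_batch_size 8 \
                    --max_input_len 1024 \
                    --workers 2
"""
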
import argparse
import copy
import os
import time
import traceback
from concurrent.futures import ProcessPoolExecutor, as_completed
from importlib.machinery import SourceFileLoader
from multiprocessing import get_context
from typing import Optional, Union

import torch

from tensorrt_llm.auto_parallel import infer_cluster_config
from tensorrt_llm.auto_parallel.cluster_info import cluster_infos
from tensorrt_llm.builder import BuildConfig, Engine, build
from tensorrt_llm.functional import PositionEmbeddingType
from tensorrt_llm.logger import logger
from tensorrt_llm.lora_manager import LoraConfig, LoraManager
from tensorrt_llm.models import MODEL_MAP, PretrainedConfig
from tensorrt_llm.models.modeling_utils import SpeculativeDecodingMode
from tensorrt_llm.plugin import PluginConfig, add_plugin_argument


def parse_arguments():
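    """Parse and validate the command-line arguments for the engine build."""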
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--checkpoint_dir', type=str, default=None)
    parser.add_argument('--model_config', type=str, default=None)
    parser.add_argument('--build_config', type=str, default=None)
    parser.add_argument('--model_cls_file', type=str, default=None)
    parser.add_argument('--model_cls_name', type=str, default=None)
    parser.add_argument(
        '--input_timing_cache',
        type=str,
        default=None,
        help=
        'The path to read the timing cache from; ignored if the file does not exist.'
    )
    parser.add_argument('--output_timing_cache',
                        type=str,
                        default='model.cache',
                        help='The path to write the timing cache to.')
    parser.add_argument('--log_level', type=str, default='info')
    parser.add_argument(
        '--profiling_verbosity',
        type=str,
        default='layer_names_only',
        choices=['layer_names_only', 'detailed', 'none'],
        help=
        'The profiling verbosity for the generated TRT engine. Setting it to "detailed" allows inspecting tactic choices and kernel parameters.'
    )
    parser.add_argument('--enable_debug_output',
                        default=False,
                        action='store_true')
    parser.add_argument(
        '--output_dir',
        type=str,
        default='engine_outputs',
        help='The path to save the serialized engine files and model configs.')
    parser.add_argument('--workers',
                        type=int,
                        default=1,
                        help='The number of workers for building in parallel.')
    parser.add_argument(
        '--max_batch_size',
        type=int,
        default=256,
        help="Max number of requests that the engine can handle.")
    parser.add_argument('--max_input_len',
                        type=int,
                        default=1024,
                        help="Max input length of one request.")
    parser.add_argument(
        '--max_seq_len',
        '--max_decoder_seq_len',
        dest='max_seq_len',
        type=int,
        default=None,
        help="Max total length of one request, including prompt and outputs. "
        "If unspecified, the value is deduced from the model config.")
    parser.add_argument('--max_beam_width', type=int, default=1)
    parser.add_argument(
        '--max_num_tokens',
        type=int,
        default=8192,
        help="Max number of batched input tokens after padding is removed "
        "(triggered by `--remove_input_padding`) in each batch.")
    parser.add_argument(
        '--opt_num_tokens',
        type=int,
        default=None,
        help='Equals max_batch_size * max_beam_width by default; set this '
        'value as close as possible to the actual number of tokens in your workload. '
        'Note that this argument might be removed in the future.')
    parser.add_argument('--tp_size', type=int, default=1)
    parser.add_argument('--pp_size', type=int, default=1)
    parser.add_argument(
        '--max_prompt_embedding_table_size',
        '--max_multimodal_len',
        type=int,
        default=0,
        help=
        'Setting to a value > 0 enables support for prompt tuning or multimodal input.'
    )
    parser.add_argument(
        '--use_fused_mlp',
        default=False,
        action='store_true',
        help=
        'Enable horizontal fusion in GatedMLP, which reduces layer input traffic and potentially improves performance. '
        'For FP8 PTQ, the downside is a slight reduction in accuracy because one of the quantization scaling factors is discarded. '
        '(An example for reference only: 0.45734 vs 0.45755 for LLaMA-v2 7B using `modelopt/examples/hf/instruct_eval/mmlu.py`.)'
    )
    parser.add_argument(
        '--gather_all_token_logits',
        action='store_true',
        default=False,
        help='Enable both gather_context_logits and gather_generation_logits.')
    parser.add_argument('--gather_context_logits',
                        action='store_true',
                        default=False,
                        help='Gather context logits.')
    parser.add_argument('--gather_generation_logits',
                        action='store_true',
                        default=False,
                        help='Gather generation logits.')

    parser.add_argument('--builder_opt', type=int, default=None)
    parser.add_argument('--builder_force_num_profiles', type=int, default=None)
    parser.add_argument('--logits_dtype',
                        type=str,
                        default=None,
                        choices=['float16', 'float32'])
    parser.add_argument('--weight_sparsity', default=False, action='store_true')
    parser.add_argument(
        '--max_draft_len',
        type=int,
        default=0,
        help=
        'Maximum length of draft tokens for the speculative decoding target model.'
    )
    parser.add_argument(
        '--lora_dir',
        type=str,
        default=None,
        nargs="+",
        help="The directory of LoRA weights. "
        "If multiple directories are provided, the config from the first one is used."
    )
    parser.add_argument('--lora_ckpt_source',
                        type=str,
                        default="hf",
                        choices=["hf", "nemo"],
                        help="The source of the LoRA checkpoint.")
    parser.add_argument(
        '--lora_target_modules',
        nargs='+',
        default=None,
        choices=LoraManager.LORA_MODULE_IDS.keys(),
        help=
        "The modules to add LoRA to. Only effective when the LoRA plugin is enabled."
    )
    parser.add_argument(
        '--max_lora_rank',
        type=int,
        default=64,
        help='Maximum LoRA rank across the LoRA modules. '
        'It is used to compute the workspace size of the LoRA plugin.')
    parser.add_argument('--auto_parallel',
                        type=int,
                        default=1,
                        help='MPI world size for auto parallel.')
    parser.add_argument(
        '--gpus_per_node',
        type=int,
        default=8,
        help=
        'Number of GPUs each node has in a multi-node setup. This is a cluster spec and can be greater or smaller than the world size.'
    )
    parser.add_argument(
        '--cluster_key',
        type=str,
        default=None,
        choices=cluster_infos.keys(),
        help=
        'Unique name for the target GPU type. Inferred from the current GPU type if not specified.'
    )
    parser.add_argument(
        '--strip_plan',
        default=False,
        action='store_true',
        help=
        'Strip weights from the final TRT engine, under the assumption that the refit weights will be identical to those provided at build time.'
    )
    parser.add_argument(
        '--max_encoder_input_len',
        type=int,
        default=1024,
        help=
        'Max encoder input length when using enc-dec models. Set max_input_len to 1 to start generation from decoder_start_token_id of length 1.'
    )
    parser.add_argument(
        '--visualize_network',
        default=False,
        action='store_true',
        help=
        'Export the TRT network to ONNX prior to the engine build, for debugging.'
    )
    parser.add_argument(
        '--dry_run',
        default=False,
        action='store_true',
        help=
        'Run through the build process, except the actual engine build, for debugging.'
    )
    parser.add_argument('--speculative_decoding_mode',
                        default=None,
                        choices=[
                            "draft_tokens_external",
                            "lookahead_decoding",
                            "medusa",
                            "explicit_draft_tokens",
                        ],
                        help='Mode of speculative decoding.')
    parser.add_argument(
        '--weight_streaming',
        default=False,
        action='store_true',
        help='Offload weights to CPU and stream-load them at runtime.',
    )

    plugin_config_parser = parser.add_argument_group("plugin_config")
    add_plugin_argument(plugin_config_parser)

    args = parser.parse_args()
    if args.gather_all_token_logits:
        args.gather_context_logits = True
        args.gather_generation_logits = True

    if args.gather_context_logits and args.max_draft_len > 0:
        raise RuntimeError(
            "Gathering context logits is not supported when max_draft_len > 0. "
            "To get the logits of the accepted tokens from the target model, "
            "enable gather_generation_logits instead.")
    return args


def build_model(
        build_config: BuildConfig,
        rank: int = 0,
        ckpt_dir: Optional[str] = None,
        model_config: Union[str, PretrainedConfig] = None,
        model_cls=None,
        dry_run: bool = False,
        **kwargs
) -> Union[Engine, BuildConfig]:
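    """Build the engine (or just the BuildConfig) for a single rank.

    Summary inferred from the implementation below: the per-rank model is
    created from `model_config` (optionally loading weights from `ckpt_dir`),
    LoRA / fused-MLP / strip-plan options passed via `kwargs` are applied to
    `build_config`, and the TRT engine is built. If `dry_run` is True, the
    prepared `build_config` is returned instead of a built `Engine`.
    """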
    model_config = copy.deepcopy(model_config)

    logits_dtype = kwargs.get('logits_dtype')
    if logits_dtype is not None:
        model_config.logits_dtype = logits_dtype

    architecture = model_config.architecture
    assert not build_config.plugin_config.streamingllm or architecture == "LlamaForCausalLM", \
        "StreamingLLM is only supported by LLaMA models."
    real_rank = rank

    if build_config.plugin_config.reduce_fusion and model_config.mapping.tp_size == 1:
        build_config.plugin_config.reduce_fusion = False

    model_config.mapping.gpus_per_node = build_config.auto_parallel_config.gpus_per_node
    if build_config.auto_parallel_config.enabled:
        assert rank < build_config.auto_parallel_config.world_size
        assert model_config.mapping.pp_size == 1 and model_config.mapping.tp_size == 1, \
            "You must convert to a full model with TP=1 and PP=1 to use the auto parallel planner."
        rank = 0
    else:
        assert rank < model_config.mapping.world_size

    rank_config = copy.deepcopy(model_config)
    rank_config.set_rank(rank)

    if model_cls is None:
        assert architecture in MODEL_MAP, \
            f"Unsupported model architecture: {architecture}"
        model_cls = MODEL_MAP[architecture]
    if ckpt_dir is None:
        model = model_cls(rank_config)
    else:
        model = model_cls.from_checkpoint(ckpt_dir, config=rank_config)
    is_checkpoint_pruned = getattr(rank_config, 'is_pruned', False)

    if build_config.plugin_config.lora_plugin is not None:
        lora_config = LoraConfig(lora_dir=kwargs['lora_dir'] or [],
                                 lora_ckpt_source=kwargs['lora_ckpt_source'],
                                 max_lora_rank=kwargs['max_lora_rank'])
        if kwargs['lora_target_modules'] is not None:
            lora_config.lora_target_modules = kwargs['lora_target_modules']
        build_config.lora_config = lora_config

    build_config.use_fused_mlp = kwargs.get('use_fused_mlp', False)

    if build_config.auto_parallel_config.enabled:
        model.config.mapping.rank = real_rank

    if is_checkpoint_pruned or kwargs.pop('strip_plan', False):
        build_config.use_strip_plan = True
    build_config.use_refit = kwargs.get('refit', False)

    if dry_run:
        return build_config

    return build(model, build_config)


def build_and_save(rank, gpu_id, ckpt_dir, build_config, output_dir, log_level,
                   model_config, model_cls, **kwargs):
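    """Build the engine for `rank` on device `gpu_id` and save it to `output_dir`."""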
    torch.cuda.set_device(gpu_id)
    logger.set_level(log_level)
    engine = build_model(build_config,
                         rank,
                         ckpt_dir,
                         model_config,
                         model_cls=model_cls,
                         **kwargs)
    assert engine is not None
    engine.save(output_dir)
    return True


def parallel_build(model_config: PretrainedConfig,
                   ckpt_dir: Optional[str],
                   build_config: BuildConfig,
                   output_dir: str,
                   workers: int = 1,
                   log_level: str = 'info',
                   model_cls=None,
                   **kwargs):
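    """Build engines for all ranks of the model.

    With `workers == 1`, ranks are built sequentially in this process;
    otherwise a spawn-based process pool builds up to `workers` ranks in
    parallel, assigning GPUs round-robin as `rank % workers`.
    """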
    if build_config.auto_parallel_config.enabled:
        if model_config.mapping.world_size > 1:
            raise RuntimeError(
                "Manually setting TP and PP is not supported in auto parallel mode.")
        if build_config.auto_parallel_config.debug_mode:
            world_size = 1
        else:
            world_size = build_config.auto_parallel_config.world_size
    else:
        world_size = model_config.mapping.world_size

    if workers == 1:
        for rank in range(world_size):
            passed = build_and_save(rank, rank % workers, ckpt_dir,
                                    build_config, output_dir, log_level,
                                    model_config, model_cls, **kwargs)
            assert passed, "Engine building failed, please check the error log."
    else:
        with ProcessPoolExecutor(mp_context=get_context('spawn'),
                                 max_workers=workers) as p:
            futures = [
                p.submit(build_and_save, rank, rank % workers, ckpt_dir,
                         build_config, output_dir, log_level, model_config,
                         model_cls, **kwargs) for rank in range(world_size)
            ]
            exceptions = []
            for future in as_completed(futures):
                try:
                    future.result()
                except Exception as e:
                    traceback.print_exc()
                    exceptions.append(e)
            assert len(exceptions) == 0, \
                "Engine building failed, please check the error log."


def main():
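    """Entry point: load the model config, assemble the BuildConfig from the
    CLI arguments (or a JSON file), and build engines for all ranks."""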
    args = parse_arguments()
    logger.set_level(args.log_level)
    tik = time.time()

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    model_cls = None
    if args.model_cls_file is not None:
        assert args.model_cls_name is not None
        loader = SourceFileLoader('models', args.model_cls_file)
        mod = loader.load_module()
        model_cls = getattr(mod, args.model_cls_name)

    workers = min(torch.cuda.device_count(), args.workers)

    plugin_config = PluginConfig.from_arguments(args)

    kwargs = {
        'logits_dtype': args.logits_dtype,
        'use_fused_mlp': args.use_fused_mlp,
        'tp_size': args.tp_size,
        'pp_size': args.pp_size,
        'lora_dir': args.lora_dir,
        'lora_ckpt_source': args.lora_ckpt_source,
        'max_lora_rank': args.max_lora_rank,
        'lora_target_modules': args.lora_target_modules,
        'strip_plan': args.strip_plan,
        'refit': False,
    }
    speculative_decoding_mode = SpeculativeDecodingMode.from_arguments(args)

    ckpt_dir_or_model_config = args.checkpoint_dir if args.checkpoint_dir is not None else args.model_config
    if ckpt_dir_or_model_config.lower().endswith('.json'):
        config_path = ckpt_dir_or_model_config
        ckpt_dir = None
    else:
        config_path = os.path.join(ckpt_dir_or_model_config, 'config.json')
        ckpt_dir = ckpt_dir_or_model_config

    model_config = PretrainedConfig.from_json_file(config_path)

    if args.build_config is None:
        if args.multiple_profiles == "enable" and args.opt_num_tokens is not None:
            raise RuntimeError(
                "multiple_profiles is enabled while opt_num_tokens is set. "
                "They are not supposed to be used at the same time for now.")
        if args.cluster_key is not None:
            cluster_config = dict(cluster_key=args.cluster_key)
        else:
            cluster_config = infer_cluster_config()

        rotary_scaling = getattr(model_config, "rotary_scaling", None)
        if rotary_scaling is not None:
            rotary_type = rotary_scaling['type']
            rotary_factor = rotary_scaling.get(
                'factor', 1.0) if rotary_type != 'su' else 1
        else:
            rotary_factor = 1

        if args.max_seq_len is None:
            deduced_max_seq_len = 2048
            if model_config.max_position_embeddings is not None:
                deduced_max_seq_len = model_config.max_position_embeddings

            if rotary_factor != 1:
                deduced_max_seq_len = int(deduced_max_seq_len * rotary_factor)
                logger.warning(
                    f'max_seq_len is scaled to {deduced_max_seq_len} by rotary scaling {rotary_factor}'
                )

            args.max_seq_len = deduced_max_seq_len
            logger.info(
                f'max_seq_len is not specified, using deduced value {deduced_max_seq_len}'
            )
        else:
            if not plugin_config.streamingllm and model_config.max_position_embeddings is not None \
                    and model_config.position_embedding_type != PositionEmbeddingType.relative:
                if args.max_seq_len > model_config.max_position_embeddings * rotary_factor:
                    logger.warning(
                        f'max_seq_len {args.max_seq_len} is larger than max_position_embeddings {model_config.max_position_embeddings} * rotary scaling {rotary_factor}; '
                        'the model accuracy might be affected')

        if args.max_input_len > args.max_seq_len:
            logger.warning(
                f'max_input_len {args.max_input_len} is larger than max_seq_len {args.max_seq_len}, clipping it to max_seq_len'
            )
            args.max_input_len = args.max_seq_len

        build_config = BuildConfig.from_dict(
            {
                'max_input_len': args.max_input_len,
                'max_seq_len': args.max_seq_len,
                'max_batch_size': args.max_batch_size,
                'max_beam_width': args.max_beam_width,
                'max_num_tokens': args.max_num_tokens,
                'opt_num_tokens': args.opt_num_tokens,
                'max_prompt_embedding_table_size':
                args.max_prompt_embedding_table_size,
                'gather_context_logits': args.gather_context_logits,
                'gather_generation_logits': args.gather_generation_logits,
                'strongly_typed': True,
                'builder_opt': args.builder_opt,
                'force_num_profiles': args.builder_force_num_profiles,
                'weight_sparsity': args.weight_sparsity,
                'profiling_verbosity': args.profiling_verbosity,
                'enable_debug_output': args.enable_debug_output,
                'max_draft_len': args.max_draft_len,
                'speculative_decoding_mode': speculative_decoding_mode,
                'input_timing_cache': args.input_timing_cache,
                'output_timing_cache': args.output_timing_cache,
                'auto_parallel_config': {
                    'world_size': args.auto_parallel,
                    'gpus_per_node': args.gpus_per_node,
                    'sharded_io_allowlist': [
                        'past_key_value_\\d+',
                        'present_key_value_\\d*',
                    ],
                    'same_buffer_io': {
                        'past_key_value_(\\d+)': 'present_key_value_\\1',
                    },
                    **cluster_config,
                },
                'dry_run': args.dry_run,
                'visualize_network': args.visualize_network,
                'max_encoder_input_len': args.max_encoder_input_len,
                'weight_streaming': args.weight_streaming,
            },
            plugin_config=plugin_config)
    else:
        build_config = BuildConfig.from_json_file(args.build_config,
                                                  plugin_config=plugin_config)

    parallel_build(model_config, ckpt_dir, build_config, args.output_dir,
                   workers, args.log_level, model_cls, **kwargs)

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    logger.info(f'Total time of building all engines: {t}')


if __name__ == '__main__':
    main()