File size: 7,166 Bytes
5000658 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 |
'''
Script that refits TRT-LLM engine(s) with weights in a TRT-LLM checkpoint.
'''
import argparse
import copy
import json
import os
import re
import shutil
import time
from pathlib import Path
import tensorrt as trt
from tensorrt_llm._common import _is_building
from tensorrt_llm._utils import np_dtype_to_trt
from tensorrt_llm.builder import EngineConfig, optimize_model_with_config
from tensorrt_llm.models import MODEL_MAP
from ..logger import logger
# Matches per-rank engine filenames like `rank0.engine`; group(1) is the rank.
# Raw string avoids the invalid-escape warning for `\d`, and the dot is
# escaped so e.g. `rank0Xengine` no longer matches accidentally.
ENGINE_RE = re.compile(r'rank(\d+)\.engine')
@_is_building
def refit_engine(engine_path: str, refit_engine_dir: str, checkpoint_dir: str,
                 engine_config: EngineConfig, fixed_weights_names: list):
    """Refit a single rank's TRT engine with weights from a TRT-LLM checkpoint.

    Matching is purely textual: every named parameter of the freshly loaded
    model whose name appears in the engine's refittable-weight set is pushed
    into the engine via ``trt.Refitter``. The refitted engine is serialized
    (weights included) under ``refit_engine_dir`` with its original file name.

    Args:
        engine_path: Path to a ``rank<N>.engine`` file; N selects the rank config.
        refit_engine_dir: Directory the refitted engine is written to.
        checkpoint_dir: TRT-LLM checkpoint directory supplying the new weights.
        engine_config: Parsed engine ``config.json`` (pretrained + build configs).
        fixed_weights_names: Names of weights hardcoded in the model itself
            (not loaded from checkpoints); exempt from the "not refittable"
            warning below.

    Raises:
        RuntimeError: If the checkpoint does not provide every refittable weight.
    """
    rank = int(ENGINE_RE.fullmatch(os.path.basename(engine_path)).group(1))

    # Deserialize the existing (weight-strippable) engine.
    tik = time.time()
    with open(engine_path,
              "rb") as f, trt.Runtime(logger.trt_logger) as runtime:
        engine = runtime.deserialize_cuda_engine(f.read())
    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    logger.info(f'Load TRT engine time: {t}')

    refitter = trt.Refitter(engine, logger.trt_logger)
    refittable_weights = set(refitter.get_all_weights())

    # Load this rank's weights from the checkpoint.
    tik = time.time()
    rank_config = copy.deepcopy(engine_config.pretrained_config)
    rank_config.set_rank(rank)
    architecture = rank_config.architecture
    assert architecture in MODEL_MAP, \
        f"Unsupported model architecture: {architecture}"
    model_cls = MODEL_MAP[architecture]
    model = model_cls.from_checkpoint(checkpoint_dir, config=rank_config)
    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    logger.info(f'Load checkpoint(s) time: {t}')

    # Weights are preprocessed during model optimization, so the same
    # optimization pass must run here for names/values to match the engine.
    tik = time.time()
    build_config = copy.deepcopy(engine_config.build_config)
    optimize_model_with_config(model, build_config)
    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    logger.info(f'Preprocess weights time: {t}')

    # Refit engine.
    tik = time.time()
    refitted_weights = []
    for name, buf in model.named_parameters():
        if name in refittable_weights:
            # Fix: `is_inited` is a method (see the `param.is_inited()` call
            # in `main`); the original referenced the bound method itself,
            # which is always truthy, so this assert could never fire.
            assert buf.is_inited(
            ), f"Failed because weight `{name}` is not initialized in model."
            weight = buf._value
            weights_value = trt.Weights(np_dtype_to_trt(weight.dtype),
                                        weight.ctypes.data, weight.size)
            assert refitter.set_named_weights(
                name, weights_value), f'Failed to refit weight: `{name}`'
            refitted_weights.append(name)
        elif name not in fixed_weights_names:
            logger.warning(
                f"model weights `{name}` (shape: {buf._value.shape}) is not refittable, this means that we might not be able to update the engine using fine-tuned checkpoint!"
            )

    # Validate all refittable weights are provided.
    if len(refitted_weights) != len(refittable_weights):
        raise RuntimeError(
            f'Missing refittable weights {refittable_weights.difference(refitted_weights)} from {checkpoint_dir}'
        )
    # Fix: dropped the pointless `f` prefix on a placeholder-free message.
    assert refitter.refit_cuda_engine(), 'Failed to refit engine.'
    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    logger.info(f'Execute GPU refit graph time: {t}')

    # Serialize the refitted engine, clearing EXCLUDE_WEIGHTS so the output
    # plan file carries the weights (a normal, non-stripped engine).
    tik = time.time()
    refit_engine_path = os.path.join(refit_engine_dir,
                                     os.path.basename(engine_path))
    with open(refit_engine_path, 'wb') as f:
        logger.info(f'\nWriting refitted engine to `{refit_engine_path}`')
        s_config = engine.create_serialization_config()
        s_config.flags &= ~(1 << int(trt.SerializationFlag.EXCLUDE_WEIGHTS))
        f.write(engine.serialize_with_config(s_config))
    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    logger.info(f'Write TRT engine to disk time: {t}')

    del refitter
def refit(engine_dir: str, checkpoint_dir: str, engine_config: EngineConfig,
          output_dir: str, fixed_weights_names: list):
    """Refit every `*.engine` file found in `engine_dir`.

    Copies the engine `config.json` into `output_dir` (so the result is a
    self-contained engine directory), then refits each rank's engine with
    weights from `checkpoint_dir`.
    """
    os.makedirs(output_dir, exist_ok=True)
    # Mirror the engine config alongside the refitted engines.
    shutil.copyfile(os.path.join(engine_dir, 'config.json'),
                    os.path.join(output_dir, 'config.json'))
    for engine_file in Path(engine_dir).glob('*.engine'):
        refit_engine(engine_file, output_dir, checkpoint_dir, engine_config,
                     fixed_weights_names)
def main():
    """CLI entry point: validate arguments, then refit all engines in a
    directory with weights taken from a TRT-LLM checkpoint."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--engine_dir',
        type=str,
        default=None,
        help=
        'Path to trt-llm engines. These engines must have been built from a pruned checkpoint, or otherwise be refittable.'
    )
    parser.add_argument('--checkpoint_dir',
                        type=str,
                        default=None,
                        help='Path to checkpoint containing desired weights')
    parser.add_argument('--output_dir',
                        type=str,
                        default=None,
                        help="Output path of the refit model")
    parser.add_argument('--log_level', type=str, default='info')
    args = parser.parse_args()

    logger.set_level(args.log_level)

    # Fail fast on bad paths before doing any expensive work.
    if args.engine_dir is None or not Path(args.engine_dir).exists():
        raise RuntimeError(
            f'Please supply a valid --engine_dir (found `{args.engine_dir}`)')
    if args.checkpoint_dir is None or not Path(args.checkpoint_dir).exists():
        raise RuntimeError(
            f'Please supply a valid --checkpoint_dir (found `{args.checkpoint_dir}`)'
        )
    # Fix: the original passed a missing --output_dir straight to
    # os.path.normpath, raising an opaque TypeError; validate it like the
    # other path arguments instead.
    if args.output_dir is None:
        raise RuntimeError(
            f'Please supply a valid --output_dir (found `{args.output_dir}`)')

    engine_config = EngineConfig.from_json_file(
        os.path.join(args.engine_dir, 'config.json'))
    with open(os.path.join(args.checkpoint_dir, 'config.json'), 'r') as f:
        checkpoint_config = json.load(f)

    # Refuse to mix architectures: refitting matches weights by name, which
    # is only meaningful within a single architecture.
    engine_arch = engine_config.pretrained_config.architecture
    checkpoint_arch = checkpoint_config['architecture']
    if engine_arch != checkpoint_arch:
        raise RuntimeError(
            f'Engine Architecture and Checkpoint Architecture do not match. ' +
            f'Engine Architecture: `{engine_arch}`, Checkpoint Architecture: `{checkpoint_arch}`'
        )

    # The fixed weights are not read from checkpoint, they are hardcoded
    # buffers from the model itself. These values remain constant across
    # different fine-tuned checkpoints.
    fixed_wts_in_model = []
    model_cls = MODEL_MAP[engine_arch]
    model = model_cls.from_config(engine_config.pretrained_config)
    for name, param in model.named_parameters():
        if param.is_inited():
            fixed_wts_in_model.append(name)

    refit(engine_dir=os.path.normpath(args.engine_dir),
          checkpoint_dir=os.path.normpath(args.checkpoint_dir),
          engine_config=engine_config,
          output_dir=os.path.normpath(args.output_dir),
          fixed_weights_names=fixed_wts_in_model)


if __name__ == '__main__':
    main()
|