File size: 7,166 Bytes
5000658
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
'''
Script that refits TRT-LLM engine(s) with weights in a TRT-LLM checkpoint.
'''
import argparse
import copy
import json
import os
import re
import shutil
import time
from pathlib import Path

import tensorrt as trt

from tensorrt_llm._common import _is_building
from tensorrt_llm._utils import np_dtype_to_trt
from tensorrt_llm.builder import EngineConfig, optimize_model_with_config
from tensorrt_llm.models import MODEL_MAP

from ..logger import logger

# Matches per-rank engine file names such as `rank0.engine`; group 1 is the rank.
# Raw string so `\d` is a regex class (not a string escape), and the dot is
# escaped so it only matches a literal `.` before the extension.
ENGINE_RE = re.compile(r'rank(\d+)\.engine')


@_is_building
def refit_engine(engine_path: str, refit_engine_dir: str, checkpoint_dir: str,
                 engine_config: EngineConfig, fixed_weights_names: list):
    '''
    Refit a single serialized engine with checkpoint weights and save the result.

    The engine is deserialized, the model for the engine's rank is loaded from
    `checkpoint_dir` and preprocessed with the engine's build config, and every
    model parameter whose name textually matches a refittable engine weight is
    handed to the TRT refitter. The refitted engine is written to
    `refit_engine_dir` under the same file name as `engine_path`.

    Args:
        engine_path: Path to a `rank<N>.engine` file; the rank is parsed from
            the file name.
        refit_engine_dir: Directory the refitted engine is written to.
        checkpoint_dir: Directory of the checkpoint providing the new weights.
        engine_config: Engine config supplying the pretrained and build configs.
        fixed_weights_names: Names of model weights for which no
            "not refittable" warning should be logged.

    Raises:
        RuntimeError: If the engine file name does not match `ENGINE_RE`, or if
            some refittable engine weights are not provided by the model.
        AssertionError: If the architecture is unsupported, a matched weight is
            not initialized, or the refitter rejects a weight / fails to refit.
    '''
    # This function loops through all weights in the model and does a textual match between
    # checkpoint weight names and engine weight names.
    engine_name = os.path.basename(engine_path)
    rank_match = ENGINE_RE.fullmatch(engine_name)
    if rank_match is None:
        # Fail with a readable message instead of an AttributeError on None.
        raise RuntimeError(
            f'Cannot parse rank from engine file name `{engine_name}`')
    rank = int(rank_match.group(1))

    tik = time.time()
    with open(engine_path,
              "rb") as f, trt.Runtime(logger.trt_logger) as runtime:
        engine = runtime.deserialize_cuda_engine(f.read())
    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    logger.info(f'Load TRT engine time: {t}')

    refitter = trt.Refitter(engine, logger.trt_logger)
    refittable_weights = set(refitter.get_all_weights())

    # Load model.
    tik = time.time()
    rank_config = copy.deepcopy(engine_config.pretrained_config)
    rank_config.set_rank(rank)

    architecture = rank_config.architecture
    assert architecture in MODEL_MAP, \
        f"Unsupported model architecture: {architecture}"
    model_cls = MODEL_MAP[architecture]
    model = model_cls.from_checkpoint(checkpoint_dir, config=rank_config)

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    logger.info(f'Load checkpoint(s) time: {t}')

    # There are weights preprocess during optimize model.
    tik = time.time()
    build_config = copy.deepcopy(engine_config.build_config)
    optimize_model_with_config(model, build_config)
    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    logger.info(f'Preprocess weights time: {t}')

    # Refit engine.
    tik = time.time()
    # Build the lookup set once; membership is tested for every non-refittable weight.
    fixed_weights = set(fixed_weights_names)
    refitted_weights = []
    for name, buf in model.named_parameters():
        if name in refittable_weights:
            # `is_inited` is a method (it is called as such in `main`); testing the
            # bound method itself would always be truthy and never catch anything.
            assert buf.is_inited(
            ), f"Failed because weight `{name}` is not initialized in model."
            weight = buf._value
            weights_value = trt.Weights(np_dtype_to_trt(weight.dtype),
                                        weight.ctypes.data, weight.size)
            # Keep the side-effecting call outside of `assert` so that it still
            # runs (and is still checked) when Python is started with -O.
            if not refitter.set_named_weights(name, weights_value):
                raise AssertionError(f'Failed to refit weight: `{name}`')
            refitted_weights.append(name)
        elif name not in fixed_weights:
            logger.warning(
                f"model weights `{name}` (shape: {buf._value.shape}) is not refittable, this means that we might not be able to update the engine using fine-tuned checkpoint!"
            )

    # Validate all refittable weights are provided.
    missing_weights = refittable_weights.difference(refitted_weights)
    if missing_weights:
        raise RuntimeError(
            f'Missing refittable weights {missing_weights} from {checkpoint_dir}'
        )

    # Same reasoning as above: the refit itself must not live inside an `assert`.
    if not refitter.refit_cuda_engine():
        raise AssertionError(f'Failed to refit engine.')
    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    logger.info(f'Execute GPU refit graph time: {t}')

    tik = time.time()
    refit_engine_path = os.path.join(refit_engine_dir,
                                     os.path.basename(engine_path))
    with open(refit_engine_path, 'wb') as f:
        logger.info(f'\nWriting refitted engine to `{refit_engine_path}`')
        s_config = engine.create_serialization_config()
        # Clear the EXCLUDE_WEIGHTS bit so the serialized engine carries its weights.
        s_config.flags &= ~(1 << int(trt.SerializationFlag.EXCLUDE_WEIGHTS))
        f.write(engine.serialize_with_config(s_config))
    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    logger.info(f'Write TRT engine to disk time: {t}')

    del refitter


def refit(engine_dir: str, checkpoint_dir: str, engine_config: EngineConfig,
          output_dir: str, fixed_weights_names: list):
    '''
    Refit every `*.engine` file found in `engine_dir` and store the results in `output_dir`.

    The output directory is created if needed and the engine directory's
    `config.json` is copied into it before the engines are processed one by
    one with `refit_engine`.
    '''
    os.makedirs(output_dir, exist_ok=True)

    config_name = 'config.json'
    config_src = os.path.join(engine_dir, config_name)
    config_dst = os.path.join(output_dir, config_name)
    shutil.copyfile(config_src, config_dst)

    # Collect the engine files up front, then refit them in that order.
    engine_files = list(Path(engine_dir).glob('*.engine'))
    for engine_file in engine_files:
        refit_engine(engine_file, output_dir, checkpoint_dir, engine_config,
                     fixed_weights_names)


def main():
    '''
    Command-line entry point: validate arguments and refit all engines.

    Parses `--engine_dir`, `--checkpoint_dir`, `--output_dir` and `--log_level`,
    checks that the engine and checkpoint architectures agree, collects the
    names of weights that are already initialized in a model built from the
    config alone, and then calls `refit`.

    Raises:
        RuntimeError: If `--engine_dir` or `--checkpoint_dir` is missing or does
            not exist, if `--output_dir` is not supplied, or if the engine and
            checkpoint architectures differ.
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--engine_dir',
        type=str,
        default=None,
        help=
        'Path to trt-llm engines. These engines must have been built from a pruned checkpoint, or otherwise be refittable.'
    )
    parser.add_argument('--checkpoint_dir',
                        type=str,
                        default=None,
                        help='Path to checkpoint containing desired weights')
    parser.add_argument('--output_dir',
                        type=str,
                        default=None,
                        help="Output path of the refit model")
    parser.add_argument('--log_level', type=str, default='info')
    args = parser.parse_args()

    logger.set_level(args.log_level)
    if args.engine_dir is None or not Path(args.engine_dir).exists():
        raise RuntimeError(
            f'Please supply a valid --engine_dir (found `{args.engine_dir}`)')
    if args.checkpoint_dir is None or not Path(args.checkpoint_dir).exists():
        raise RuntimeError(
            f'Please supply a valid --checkpoint_dir (found `{args.checkpoint_dir}`)'
        )
    # Check this up front: otherwise os.path.normpath(None) below raises a
    # TypeError only after the configs and the model have been loaded.
    if args.output_dir is None:
        raise RuntimeError(
            f'Please supply a valid --output_dir (found `{args.output_dir}`)')

    engine_config = EngineConfig.from_json_file(
        os.path.join(args.engine_dir, 'config.json'))

    with open(os.path.join(args.checkpoint_dir, 'config.json'), 'r') as f:
        checkpoint_config = json.load(f)

    engine_arch = engine_config.pretrained_config.architecture
    checkpoint_arch = checkpoint_config['architecture']
    if engine_arch != checkpoint_arch:
        raise RuntimeError(
            'Engine Architecture and Checkpoint Architecture do not match. '
            f'Engine Architecture: `{engine_arch}`, Checkpoint Architecture: `{checkpoint_arch}`'
        )

    # The fixed weights are not read from checkpoint, they are hardcoded buffer from the model itself. These values remain constant across different fine-tuned checkpoints.
    model_cls = MODEL_MAP[engine_arch]
    model = model_cls.from_config(engine_config.pretrained_config)
    fixed_wts_in_model = [
        name for name, param in model.named_parameters() if param.is_inited()
    ]

    refit(engine_dir=os.path.normpath(args.engine_dir),
          checkpoint_dir=os.path.normpath(args.checkpoint_dir),
          engine_config=engine_config,
          output_dir=os.path.normpath(args.output_dir),
          fixed_weights_names=fixed_wts_in_model)


# Invoke the command-line entry point only when this module is run as the main program.
if __name__ == '__main__':
    main()