File size: 10,444 Bytes
5000658
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
import contextlib
import datetime
import enum
import hashlib
import json
import os
import shutil
from dataclasses import dataclass
from pathlib import Path
from typing import Any, List, Optional

import filelock

import tensorrt_llm
from tensorrt_llm.hlapi.llm_utils import BuildConfig
from tensorrt_llm.logger import logger


def get_build_cache_config_from_env() -> tuple[bool, str]:
    """
    Read the build-cache settings from the process environment.

    Returns:
        A ``(enabled, root)`` pair. ``enabled`` is True only when
        ``TLLM_HLAPI_BUILD_CACHE`` equals the string ``'1'``; ``root`` is the
        value of ``TLLM_HLAPI_BUILD_CACHE_ROOT``, falling back to
        ``'/tmp/.cache/tensorrt_llm/hlapi/'`` when that variable is unset.
    """
    env = os.environ
    fallback_root = '/tmp/.cache/tensorrt_llm/hlapi/'  # nosec B108
    enabled = env.get('TLLM_HLAPI_BUILD_CACHE') == '1'
    root = env.get('TLLM_HLAPI_BUILD_CACHE_ROOT', fallback_root)
    return enabled, root


class BuildCacheConfig:
    """
    Configuration for the build cache.

    Attributes:
        cache_root (Path): The root directory for the build cache. A ``str`` is
            also accepted and is converted to ``Path`` on access. When no root
            is given, the value of ``TLLM_HLAPI_BUILD_CACHE_ROOT`` (or its
            default, see ``get_build_cache_config_from_env``) is read each time
            the property is accessed.
        max_records (int): The maximum number of records to store in the cache.
        max_cache_storage_gb (float): The maximum amount of storage (in GB) to use for the cache.
    """

    def __init__(self,
                 cache_root: Optional[Path] = None,
                 max_records: int = 10,
                 max_cache_storage_gb: float = 256):
        self._cache_root = cache_root
        self._max_records = max_records
        self._max_cache_storage_gb = max_cache_storage_gb

    @property
    def cache_root(self) -> Path:
        if self._cache_root:
            # Normalize so callers can always join with ``/`` (a plain str would fail).
            return Path(self._cache_root)
        # Only consult the environment when no explicit root was configured.
        _, env_cache_root = get_build_cache_config_from_env()
        return Path(env_cache_root)

    @property
    def max_records(self) -> int:
        return self._max_records

    @property
    def max_cache_storage_gb(self) -> float:
        return self._max_cache_storage_gb


class BuildCache:
    """
    The BuildCache class is a class that manages the intermediate products from the build steps.

    Every record is a sub-directory of ``cache_root`` containing a
    ``metadata.json`` file and a ``content`` directory (see ``is_cache_valid``).

    NOTE: currently, only engine-building is supported
    TODO[chunweiy]: add support for other build steps, such as quantization, convert_checkpoint, etc.
    """
    # The version of the cache, will be used to determine if the cache is compatible
    CACHE_VERSION = 0

    def __init__(self, config: Optional[BuildCacheConfig] = None):
        '''
        Args:
            config: Location and limits of the cache; a default
                ``BuildCacheConfig()`` is used when omitted.

        Raises:
            ValueError: If ``config.max_records`` is smaller than 1.
        '''
        config = config or BuildCacheConfig()
        if config.max_records < 1:
            raise ValueError("max_records should be greater than 0")

        cache_root = config.cache_root
        if not cache_root:
            _, cache_root = get_build_cache_config_from_env()
        # Normalize to Path so ``/`` joins work even if a str root was supplied.
        self.cache_root = Path(cache_root)
        self.max_records = config.max_records
        self.max_cache_storage_gb = config.max_cache_storage_gb

    def get_engine_building_cache_stage(self,
                                        build_config: BuildConfig,
                                        model_path: Optional[Path] = None,
                                        **kwargs) -> 'CachedStage':
        '''
        Get the build step for engine building.

        Args:
            build_config: Engine build configuration; its ``to_dict()`` output,
                minus the keys dropped by ``prune_build_config_for_cache_key``,
                is part of the cache key.
            model_path: Model location, also part of the cache key.
            **kwargs: Extra values folded into the cache key. A
                ``parallel_config`` with ``auto_parallel`` set, or a
                ``model_format`` other than ``_ModelFormatKind.HF``, forces a
                rebuild (the cache is bypassed).

        Returns:
            A ``CachedStage`` describing an engine record under ``cache_root``.
        '''
        from tensorrt_llm.hlapi.llm_utils import \
            _ModelFormatKind  # avoid cyclic import
        force_rebuild = False
        if parallel_config := kwargs.get('parallel_config'):
            if parallel_config.auto_parallel:
                force_rebuild = True
        if model_format := kwargs.get('model_format'):
            if model_format is not _ModelFormatKind.HF:
                force_rebuild = True

        build_config_str = BuildCache.prune_build_config_for_cache_key(
            build_config.to_dict())

        return CachedStage(parent=self,
                           kind=CacheRecord.Kind.Engine,
                           cache_root=self.cache_root,
                           force_rebuild=force_rebuild,
                           inputs=[build_config_str, model_path, kwargs])

    def prune_caches(self, has_incoming_record: bool = False):
        '''
        Clean up the cache records to make sure the cache size is within the limit

        Invalid entries are removed first, then the oldest records are evicted
        until both ``max_records`` and ``max_cache_storage_gb`` are satisfied.

        Args:
            has_incoming_record (bool): If the cache has incoming record, the existing records will be further pruned to
            reserve space for the incoming record
        '''
        if not self.cache_root.exists():
            return
        self._clean_up_cache_dir()
        # Newest first, so ``pop()`` below always evicts the oldest record.
        records = sorted(self.load_cache_records(),
                         key=lambda record: record.time,
                         reverse=True)
        max_records = self.max_records - 1 if has_incoming_record else self.max_records
        # prune the cache to meet max_records and max_cache_storage_gb limitation;
        # the ``records and`` guard avoids popping from an empty list when the
        # storage limit is non-positive.
        while records and (len(records) > max_records or sum(
                r.storage_gb for r in records) > self.max_cache_storage_gb):
            record = records.pop()
            # remove the directory and its content
            shutil.rmtree(record.path)

    @staticmethod
    def prune_build_config_for_cache_key(build_config: dict) -> dict:
        '''
        Return a shallow copy of ``build_config`` without the keys that must not
        contribute to the cache key. The input dict is not modified.
        '''
        # The BuildCache will be disabled once auto_pp is enabled, so 'auto_parallel_config' should be removed
        black_list = ['auto_parallel_config', 'dry_run']
        dic = build_config.copy()
        for key in black_list:
            dic.pop(key, None)
        return dic

    def load_cache_records(self) -> List["CacheRecord"]:
        '''
        Load all the cache records from the cache directory

        Every entry of ``cache_root`` is loaded; callers are expected to have
        removed invalid entries first (see ``_clean_up_cache_dir``).
        '''
        if not self.cache_root.exists():
            return []
        return [
            self._load_cache_record(entry)
            for entry in self.cache_root.iterdir()
        ]

    def _load_cache_record(self, cache_dir) -> "CacheRecord":
        '''
        Get the cache record from the cache directory

        ``storage_gb`` is the total size of all regular files under
        ``cache_dir`` (recursively), in units of 1024**3 bytes.
        '''
        metadata = json.loads((cache_dir / 'metadata.json').read_text())
        storage_gb = sum(f.stat().st_size for f in cache_dir.glob('**/*')
                         if f.is_file()) / 1024**3
        return CacheRecord(kind=CacheRecord.Kind.__members__[metadata['kind']],
                           storage_gb=storage_gb,
                           path=cache_dir,
                           time=datetime.datetime.fromisoformat(
                               metadata['datetime']))

    def _clean_up_cache_dir(self):
        '''
        Clean up the files in the cache directory, remove anything that is not in the cache
        '''
        # get all the files and directories in the cache_root
        if not self.cache_root.exists():
            return
        for file_or_dir in self.cache_root.iterdir():
            if not self.is_cache_valid(file_or_dir):
                logger.info(f"Removing invalid cache directory {file_or_dir}")
                if file_or_dir.is_file():
                    file_or_dir.unlink()
                else:
                    shutil.rmtree(file_or_dir)

    def is_cache_valid(self, cache_dir: Path) -> bool:
        '''
        Check if the cache directory is valid

        A valid record has a readable ``metadata.json`` holding a JSON object
        whose ``version`` equals ``CACHE_VERSION``, and a ``content`` entry.
        Unreadable or corrupted metadata makes the record invalid rather than
        raising, so one broken entry cannot abort cache pruning.
        '''
        if not cache_dir.exists():
            return False

        metadata_path = cache_dir / 'metadata.json'
        if not metadata_path.exists():
            return False

        try:
            metadata = json.loads(metadata_path.read_text())
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            return False
        if not isinstance(metadata, dict):
            return False
        if metadata.get('version') != BuildCache.CACHE_VERSION:
            return False

        return (cache_dir / 'content').exists()

@dataclass
class CachedStage:
    '''
    CachedStage is a class that represents a stage in the build process, it helps to manage the intermediate product.

    The cache is organized as follows:

    this_cache_dir/     # name is like "engine-<hash>"
        metadata.json   # the metadata of the cache
        content/        # the actual product of the build step, such trt-llm engine directory
    '''
    # The parent should be kept alive by CachedStage instance
    parent: BuildCache
    cache_root: Path
    # The inputs will be used to determine if the step needs to be re-run, so all the variables should be put here
    inputs: List[Any]
    kind: "CacheRecord.Kind"
    # If force_rebuild is set to True, the cache will be ignored
    force_rebuild: bool = False

    def get_hash_key(self):
        '''
        Return the MD5 hex digest of the tensorrt_llm version and the ``str()``
        of every input. MD5 is used only as a cache key, not for security.
        '''
        lib_version = tensorrt_llm.__version__
        input_strs = [str(i) for i in self.inputs]
        return hashlib.md5(
            f"{lib_version}-{input_strs}".encode()).hexdigest()  # nosec B324

    def get_cache_path(self) -> Path:
        '''
        The path to the product of the build step, will be overwritten if the step is re-run
        '''
        return self.cache_root / f"{self.kind.value}-{self.get_hash_key()}"

    def get_engine_path(self) -> Path:
        '''The ``content`` directory inside this stage's cache path.'''
        return self.get_cache_path() / 'content'

    def get_cache_metadata(self) -> dict:
        '''Metadata written to ``metadata.json``: cache version, creation time (ISO format) and record kind name.'''
        res = {
            "version": BuildCache.CACHE_VERSION,
            "datetime": datetime.datetime.now().isoformat(),
            "kind": self.kind.name,
        }
        return res

    def cache_hitted(self) -> bool:
        '''
        Check if the product of the build step is in the cache

        Returns False when ``force_rebuild`` is set, when ``metadata.json`` is
        missing, unreadable, not a JSON object or from another
        ``CACHE_VERSION``, or when the ``content`` directory does not exist.
        '''
        if self.force_rebuild:
            return False
        cache_path = self.get_cache_path()
        try:
            metadata = json.loads((cache_path / 'metadata.json').read_text())
        except (OSError, ValueError):
            # Missing file, I/O error, or corrupted JSON: treat as a miss.
            return False
        if not isinstance(metadata, dict):
            return False
        if metadata.get("version") != BuildCache.CACHE_VERSION:
            return False
        # A record without its product is not usable.
        return (cache_path / 'content').exists()

    @contextlib.contextmanager
    def write_guard(self):
        '''
        Write the filelock to indicate that the build step is in progress

        Existing records are pruned first to leave room for this one. The
        ``<cache_path>/.filelock`` lock is held for the duration of the ``with``
        body, which receives the ``content`` path to write into. The lock is
        released even if the body raises; in that case ``metadata.json`` is
        removed so the half-written record is not reported as a cache hit.
        '''
        self.parent.prune_caches(has_incoming_record=True)

        target_dir = self.get_cache_path()
        target_dir.mkdir(parents=True, exist_ok=True)
        metadata_path = target_dir / 'metadata.json'
        # TODO[chunweiy]: deal with the cache modification conflict
        lock = filelock.FileLock(target_dir / '.filelock', timeout=10)

        with lock:
            metadata_path.write_text(json.dumps(self.get_cache_metadata()))
            try:
                yield target_dir / 'content'
            except BaseException:
                # Invalidate the record so a failed build is never reused.
                metadata_path.unlink(missing_ok=True)
                raise

@dataclass(unsafe_hash=True)
class CacheRecord:
    '''
    CacheRecord is a class that represents a record in the cache directory.

    Records are built by ``BuildCache._load_cache_record`` from a record
    directory's ``metadata.json`` and the files it contains. ``unsafe_hash``
    makes instances hashable over all fields even though they are mutable.
    '''

    class Kind(enum.Enum):
        # ``value`` is the prefix of the record directory name ("<value>-<hash>");
        # ``name`` is what is stored under "kind" in metadata.json.
        Engine = 'engine'
        Checkpoint = 'checkpoint'

    # Which build step produced the record.
    kind: Kind
    # Total size of all regular files under ``path``, in units of 1024**3 bytes.
    storage_gb: float
    # The record's directory inside the cache root.
    path: Path
    # Creation time parsed from the "datetime" field of metadata.json; records
    # are sorted by it so the oldest are evicted first when pruning.
    time: datetime.datetime