|
|
import contextlib |
|
|
import datetime |
|
|
import enum |
|
|
import hashlib |
|
|
import json |
|
|
import os |
|
|
import shutil |
|
|
from dataclasses import dataclass |
|
|
from pathlib import Path |
|
|
from typing import Any, List, Optional |
|
|
|
|
|
import filelock |
|
|
|
|
|
import tensorrt_llm |
|
|
from tensorrt_llm.hlapi.llm_utils import BuildConfig |
|
|
from tensorrt_llm.logger import logger |
|
|
|
|
|
|
|
|
def get_build_cache_config_from_env() -> tuple[bool, str]:
    """
    Read the build-cache settings from the process environment.

    Returns:
        A ``(enabled, root)`` pair. ``enabled`` is True only when
        ``TLLM_HLAPI_BUILD_CACHE`` is exactly ``'1'``; ``root`` is the value of
        ``TLLM_HLAPI_BUILD_CACHE_ROOT``, or the built-in default path when that
        variable is not set.
    """
    env = os.environ
    root = env.get('TLLM_HLAPI_BUILD_CACHE_ROOT',
                   '/tmp/.cache/tensorrt_llm/hlapi/')
    enabled = env.get('TLLM_HLAPI_BUILD_CACHE') == '1'
    return enabled, root
|
|
|
|
|
|
|
|
class BuildCacheConfig:
    """
    Settings that control where the build cache lives and how large it may grow.

    Attributes:
        cache_root (Path): The root directory for the build cache. When no root
            was given to the constructor, the root reported by
            ``get_build_cache_config_from_env`` is used instead.
        max_records (int): The maximum number of records to store in the cache.
        max_cache_storage_gb (float): The maximum amount of storage (in GB) to
            use for the cache.
    """

    def __init__(self,
                 cache_root: Optional[Path] = None,
                 max_records: int = 10,
                 max_cache_storage_gb: float = 256):
        # Values are stored privately and exposed read-only via properties.
        self._max_cache_storage_gb = max_cache_storage_gb
        self._max_records = max_records
        self._cache_root = cache_root

    @property
    def cache_root(self) -> Path:
        # Only the root half of the environment configuration matters here.
        env_root = get_build_cache_config_from_env()[1]
        if self._cache_root:
            return self._cache_root
        return Path(env_root)

    @property
    def max_records(self) -> int:
        return self._max_records

    @property
    def max_cache_storage_gb(self) -> float:
        return self._max_cache_storage_gb
|
|
|
|
|
|
|
|
class BuildCache:
    """
    The BuildCache class manages the intermediate products from the build steps.

    Every record is a sub-directory of ``cache_root`` holding a ``metadata.json``
    file and a ``content`` entry; anything else found in ``cache_root`` is
    treated as invalid and removed when the cache is pruned.

    NOTE: currently, only engine-building is supported
    TODO[chunweiy]: add support for other build steps, such as quantization, convert_checkpoint, etc.
    """

    # Stamped into each record's metadata.json; a record whose stored version
    # differs from this value is considered invalid.
    CACHE_VERSION = 0

    def __init__(self, config: Optional["BuildCacheConfig"] = None):
        """
        Args:
            config: Cache settings; a default ``BuildCacheConfig`` is used when omitted.

        Raises:
            ValueError: If ``config.max_records`` is smaller than 1.
        """
        _, default_cache_root = get_build_cache_config_from_env()
        config = config or BuildCacheConfig()

        if config.max_records < 1:
            raise ValueError("max_records should be greater than 0")

        # Normalise to Path so a root supplied as a plain string also works
        # with the Path-based operations used throughout this class.
        self.cache_root = Path(config.cache_root or default_cache_root)
        self.max_records = config.max_records
        self.max_cache_storage_gb = config.max_cache_storage_gb

    def get_engine_building_cache_stage(self,
                                        build_config: "BuildConfig",
                                        model_path: Optional[Path] = None,
                                        **kwargs) -> 'CachedStage':
        '''
        Get the build step for engine building.

        Args:
            build_config: Its ``to_dict()`` output, minus the keys dropped by
                ``prune_build_config_for_cache_key``, becomes part of the cache key.
            model_path: Also part of the cache key.
            **kwargs: Also part of the cache key. A truthy ``parallel_config``
                with ``auto_parallel`` set, or a ``model_format`` other than
                ``_ModelFormatKind.HF``, marks the stage as ``force_rebuild``.

        Returns:
            A ``CachedStage`` of kind ``CacheRecord.Kind.Engine`` rooted at ``self.cache_root``.
        '''
        from tensorrt_llm.hlapi.llm_utils import _ModelFormatKind

        force_rebuild = False
        parallel_config = kwargs.get('parallel_config')
        if parallel_config and parallel_config.auto_parallel:
            force_rebuild = True
        model_format = kwargs.get('model_format')
        if model_format and model_format is not _ModelFormatKind.HF:
            force_rebuild = True

        cache_key_config = BuildCache.prune_build_config_for_cache_key(
            build_config.to_dict())

        return CachedStage(parent=self,
                           kind=CacheRecord.Kind.Engine,
                           cache_root=self.cache_root,
                           force_rebuild=force_rebuild,
                           inputs=[cache_key_config, model_path, kwargs])

    def prune_caches(self, has_incoming_record: bool = False):
        '''
        Clean up the cache records to make sure the cache size is within the limit

        Invalid entries are removed first; after that the oldest records are
        deleted until both the record-count and the storage limits are met.

        Args:
            has_incoming_record (bool): If the cache has incoming record, the existing records will be further pruned to
                reserve space for the incoming record
        '''
        if not self.cache_root.exists():
            return
        self._clean_up_cache_dir()

        # Newest first, so that pop() always evicts the oldest record.
        records = sorted(
            (self._load_cache_record(entry)
             for entry in self.cache_root.iterdir()),
            key=lambda record: record.time,
            reverse=True)
        max_records = self.max_records - 1 if has_incoming_record else self.max_records

        # `records and ...` stops the loop once nothing is left to evict, even
        # if the storage limit can never be satisfied (e.g. a negative limit).
        while records and (len(records) > max_records
                           or sum(r.storage_gb for r in records)
                           > self.max_cache_storage_gb):
            oldest = records.pop()
            shutil.rmtree(oldest.path)

    @staticmethod
    def prune_build_config_for_cache_key(build_config: dict) -> dict:
        '''
        Return a shallow copy of `build_config` without the keys that are
        excluded from the cache key; the input dict is left untouched.
        '''
        black_list = ('auto_parallel_config', 'dry_run')
        pruned = build_config.copy()
        for key in black_list:
            pruned.pop(key, None)
        return pruned

    def load_cache_records(self) -> List["CacheRecord"]:
        '''
        Load all the cache records from the cache directory

        Entries that do not pass `is_cache_valid` (stray files, directories
        without readable metadata, other cache versions) are skipped instead
        of raising. Returns an empty list when the cache root does not exist.
        '''
        if not self.cache_root.exists():
            return []
        return [
            self._load_cache_record(entry)
            for entry in self.cache_root.iterdir()
            if self.is_cache_valid(entry)
        ]

    def _load_cache_record(self, cache_dir) -> "CacheRecord":
        '''
        Get the cache record from the cache directory

        The record's size is the total size of all regular files below
        `cache_dir`, expressed in GiB (bytes / 1024**3).
        '''
        metadata = json.loads((cache_dir / 'metadata.json').read_text())
        total_bytes = sum(f.stat().st_size for f in cache_dir.glob('**/*')
                          if f.is_file())
        return CacheRecord(kind=CacheRecord.Kind.__members__[metadata['kind']],
                           storage_gb=total_bytes / 1024**3,
                           path=cache_dir,
                           time=datetime.datetime.fromisoformat(
                               metadata['datetime']))

    def _clean_up_cache_dir(self):
        '''
        Clean up the files in the cache directory, remove anything that is not in the cache
        '''
        if not self.cache_root.exists():
            return
        for entry in self.cache_root.iterdir():
            if self.is_cache_valid(entry):
                continue
            logger.info(f"Removing invalid cache directory {entry}")
            # shutil.rmtree refuses to operate on symlinks, so links (including
            # dangling ones) are unlinked just like regular files.
            if entry.is_symlink() or entry.is_file():
                entry.unlink()
            else:
                shutil.rmtree(entry)

    def is_cache_valid(self, cache_dir: Path) -> bool:
        '''
        Check if the cache directory is valid

        A directory is valid when it exists, contains a `metadata.json` that
        parses to a JSON object whose `version` equals `CACHE_VERSION`, and
        contains a `content` entry. Unreadable or malformed metadata makes the
        directory invalid rather than raising.
        '''
        if not cache_dir.exists():
            return False

        metadata_path = cache_dir / 'metadata.json'
        if not metadata_path.exists():
            return False

        try:
            metadata = json.loads(metadata_path.read_text())
        except (OSError, ValueError):
            # ValueError covers both JSON syntax errors and undecodable bytes.
            return False
        if not isinstance(metadata, dict):
            return False
        if metadata.get('version') != BuildCache.CACHE_VERSION:
            return False

        return (cache_dir / 'content').exists()
|
|
|
|
|
|
|
|
@dataclass
class CachedStage:
    '''
    CachedStage is a class that represents a stage in the build process, it helps to manage the intermediate product.

    The cache is organized as follows:

    this_cache_dir/ # name is like "engine-<hash>"
        metadata.json # the metadata of the cache
        content/ # the actual product of the build step, such trt-llm engine directory
    '''

    # The BuildCache that owns this stage; asked to prune before a write.
    parent: "BuildCache"
    # Directory below which this stage's cache directory is created.
    cache_root: Path

    # Values whose string forms are hashed to identify the build product.
    inputs: List[Any]
    # Kind of product; its value prefixes the cache directory name.
    kind: "CacheRecord.Kind"

    # When True, cache_hitted() always reports a miss.
    force_rebuild: bool = False

    def get_hash_key(self):
        '''
        Return the hex MD5 digest of the library version combined with the
        string form of every input. The digest only names the cache directory.
        '''
        lib_version = tensorrt_llm.__version__
        input_strs = [str(i) for i in self.inputs]
        return hashlib.md5(
            f"{lib_version}-{input_strs}".encode()).hexdigest()

    def get_cache_path(self) -> Path:
        '''
        The path to the product of the build step, will be overwritten if the step is re-run
        '''
        return self.cache_root / f"{self.kind.value}-{self.get_hash_key()}"

    def get_engine_path(self) -> Path:
        '''
        The `content` entry inside this stage's cache directory.
        '''
        return self.get_cache_path() / 'content'

    def get_cache_metadata(self) -> dict:
        '''
        Build the dict stored in `metadata.json`: the cache version, the
        current local time in ISO format and the name of the record kind.
        '''
        return {
            "version": BuildCache.CACHE_VERSION,
            "datetime": datetime.datetime.now().isoformat(),
            "kind": self.kind.name,
        }

    def cache_hitted(self) -> bool:
        '''
        Check if the product of the build step is in the cache

        Returns False when `force_rebuild` is set, when the cache directory is
        missing, when its metadata cannot be read, or when the recorded
        version differs from `BuildCache.CACHE_VERSION`.
        '''
        if self.force_rebuild:
            return False
        try:
            cache_path = self.get_cache_path()
            if not cache_path.exists():
                return False
            metadata = json.loads((cache_path / 'metadata.json').read_text())
            return metadata["version"] == BuildCache.CACHE_VERSION
        except Exception:
            # Best effort: any failure while probing the cache counts as a
            # miss. Unlike a bare `except:`, KeyboardInterrupt and SystemExit
            # still propagate.
            return False

    @contextlib.contextmanager
    def write_guard(self):
        '''
        Write the filelock to indicate that the build step is in progress

        Asks the parent cache to prune itself (reserving room for one more
        record), creates this stage's cache directory, then holds
        `<cache_dir>/.filelock` while the caller produces the content.

        Yields:
            The `content` path inside the cache directory, to be filled by the caller.
        '''
        self.parent.prune_caches(has_incoming_record=True)

        target_dir = self.get_cache_path()
        target_dir.mkdir(parents=True, exist_ok=True)

        lock = filelock.FileLock(target_dir / '.filelock', timeout=10)

        # `with lock` releases the lock even if the guarded body raises; the
        # metadata is written while the lock is held.
        with lock:
            with open(target_dir / 'metadata.json', 'w') as f:
                f.write(json.dumps(self.get_cache_metadata()))
            yield target_dir / 'content'
|
|
|
|
|
|
|
|
@dataclass(unsafe_hash=True)
class CacheRecord:
    '''
    Description of a single entry found in the cache directory.

    Instances are hashable; the hash is derived from all four fields.
    '''

    class Kind(enum.Enum):
        # Kinds of build product a record can hold.
        Engine = 'engine'
        Checkpoint = 'checkpoint'

    # Which kind of build product the record holds.
    kind: Kind
    # Size of the record on disk, in GB.
    storage_gb: float
    # Location of the record's directory.
    path: Path
    # Timestamp associated with the record.
    time: datetime.datetime
|
|
|