|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _add_trt_llm_dll_directory(): |
|
|
import platform |
|
|
on_windows = platform.system() == "Windows" |
|
|
if on_windows: |
|
|
import os |
|
|
import sysconfig |
|
|
from pathlib import Path |
|
|
os.add_dll_directory( |
|
|
Path(sysconfig.get_paths()['purelib']) / "tensorrt_llm" / "libs") |
|
|
|
|
|
|
|
|
_add_trt_llm_dll_directory() |
|
|
|
|
|
import sys |
|
|
|
|
|
import tensorrt_llm.functional as functional |
|
|
import tensorrt_llm.models as models |
|
|
import tensorrt_llm.quantization as quantization |
|
|
import tensorrt_llm.runtime as runtime |
|
|
import tensorrt_llm.tools as tools |
|
|
|
|
|
from ._common import _init, default_net, default_trtnet, precision |
|
|
|
|
|
|
|
|
from ._utils import mpi_barrier |
|
|
from ._utils import str_dtype_to_torch |
|
|
from ._utils import (mpi_rank, mpi_world_size, str_dtype_to_trt, |
|
|
torch_dtype_to_trt) |
|
|
from .auto_parallel import AutoParallelConfig, auto_parallel |
|
|
from .builder import BuildConfig, Builder, BuilderConfig, build |
|
|
from .functional import Tensor, constant |
|
|
from .hlapi.llm import LLM, LlmArgs, SamplingParams |
|
|
from .logger import logger |
|
|
from .mapping import Mapping |
|
|
from .module import Module |
|
|
from .network import Network, net_guard |
|
|
from .parameter import Parameter |
|
|
from .version import __version__ |
|
|
|
|
|
# Public names exported by ``from tensorrt_llm import *``.
# Every entry must be a name bound in this module (by the imports above);
# otherwise the star-import raises AttributeError. Keep a trailing comma on
# every entry: adjacent string literals are silently concatenated into one
# bogus name.
__all__ = [
    'logger',
    'str_dtype_to_trt',
    'torch_dtype_to_trt',
    'str_dtype_to_torch',
    'mpi_barrier',
    'mpi_rank',
    'mpi_world_size',
    'constant',
    'default_net',
    'default_trtnet',
    'precision',
    'net_guard',
    'Network',
    'Mapping',
    'Builder',
    'BuilderConfig',
    'build',
    'BuildConfig',
    'Tensor',
    'Parameter',
    'runtime',
    'Module',
    'functional',
    'models',
    'auto_parallel',
    'AutoParallelConfig',
    'quantization',
    'tools',
    'LLM',
    'LlmArgs',
    'SamplingParams',
    # NOTE(review): 'KvCacheConfig' was listed here but is never imported in
    # this module, which made ``from tensorrt_llm import *`` raise
    # AttributeError. Re-add it together with an import that binds it.
    '__version__',
]
|
|
|
|
|
# Initialise the package through the private ``_init`` helper (imported from
# ``._common``), requesting log level "error".
_init(log_level="error")

# Print the version banner on import and flush it immediately.
# ``flush=True`` flushes the same stream as a separate
# ``sys.stdout.flush()`` call would. The difference is that ``print`` does
# nothing when ``sys.stdout`` is None (e.g. under ``pythonw.exe`` / Windows
# GUI apps), whereas ``sys.stdout.flush()`` raised AttributeError and aborted
# the import.
print(f"[TensorRT-LLM] TensorRT-LLM version: {__version__}", flush=True)
|
|
|