|
|
''' |
|
|
This file is a script tool for generating TensorRT plugin library for Triton. |
|
|
''' |
|
|
import argparse |
|
|
import glob |
|
|
import logging |
|
|
import os |
|
|
import subprocess |
|
|
import sys |
|
|
from dataclasses import dataclass |
|
|
from typing import ClassVar, Iterable, List, Optional, Tuple, Union |
|
|
|
|
|
try: |
|
|
import triton |
|
|
except ImportError: |
|
|
raise ImportError("Triton is not installed. Please install it first.") |
|
|
|
|
|
from tensorrt_llm.tools.plugin_gen.core import (KernelMetaData, |
|
|
PluginCmakeCodegen, |
|
|
PluginCppCodegen, |
|
|
PluginPyCodegen, |
|
|
PluginRegistryCodegen) |
|
|
|
|
|
PYTHON_BIN = sys.executable |
|
|
|
|
|
TRITON_ROOT = os.path.dirname(triton.__file__) |
|
|
TRITON_COMPILE_BIN = os.path.join(TRITON_ROOT, 'tools', 'compile.py') |
|
|
TRITON_LINK_BIN = os.path.join(TRITON_ROOT, 'tools', 'link.py') |
|
|
|
|
|
|
|
|
@dataclass |
|
|
class StageArgs: |
|
|
workspace: str |
|
|
kernel_name: str |
|
|
|
|
|
@property |
|
|
def sub_workspace(self) -> str: |
|
|
return os.path.join(self.workspace, self.kernel_name, |
|
|
f"_{self.stage_name}") |
|
|
|
|
|
|
|
|
@dataclass |
|
|
class _TritonAotArgs(StageArgs): |
|
|
stage_name: ClassVar[str] = 'triton_aot' |
|
|
|
|
|
@dataclass |
|
|
class _AotConfig: |
|
|
output_name: str |
|
|
num_warps: int |
|
|
num_stages: int |
|
|
signature: str |
|
|
|
|
|
kernel_file: str |
|
|
configs: List[_AotConfig] |
|
|
grid_dims: Tuple[str, str, str] |
|
|
|
|
|
|
|
|
@dataclass |
|
|
class _TritonKernelCompileArgs(StageArgs): |
|
|
stage_name: ClassVar[str] = 'compile_triton_kernel' |
|
|
|
|
|
|
|
|
@dataclass |
|
|
class _TrtPluginGenArgs(StageArgs): |
|
|
stage_name: ClassVar[str] = 'generate_trt_plugin' |
|
|
|
|
|
config: KernelMetaData |
|
|
|
|
|
|
|
|
@dataclass |
|
|
class _TrtPluginCompileArgs: |
|
|
workspace: str |
|
|
trt_lib_dir: Optional[str] = None |
|
|
trt_include_dir: Optional[str] = None |
|
|
trt_llm_include_dir: Optional[str] = None |
|
|
|
|
|
stage_name: ClassVar[str] = 'compile_trt_plugin' |
|
|
|
|
|
@property |
|
|
def sub_workspace(self) -> str: |
|
|
return self.workspace |
|
|
|
|
|
|
|
|
@dataclass |
|
|
class _CopyOutputArgs: |
|
|
so_path: str |
|
|
functional_py_path: str |
|
|
output_dir: str |
|
|
|
|
|
stage_name: ClassVar[str] = 'copy_output' |
|
|
|
|
|
|
|
|
@dataclass |
|
|
class Stage: |
|
|
''' |
|
|
Stage represents a stage in the plugin generation process. e.g. Triton AOT could be a stage. |
|
|
''' |
|
|
|
|
|
config: Union[_TritonAotArgs, _TritonKernelCompileArgs, _TrtPluginGenArgs, |
|
|
_TrtPluginCompileArgs, _CopyOutputArgs] |
|
|
|
|
|
def run(self): |
|
|
stages = { |
|
|
_TritonAotArgs.stage_name: self.do_triton_aot, |
|
|
_TritonKernelCompileArgs.stage_name: self.do_compile_triton_kernel, |
|
|
_TrtPluginGenArgs.stage_name: self.do_generate_trt_plugin, |
|
|
_TrtPluginCompileArgs.stage_name: self.do_compile_trt_plugin, |
|
|
_CopyOutputArgs.stage_name: self.do_copy_output, |
|
|
} |
|
|
|
|
|
logging.info(f"Running stage {self.config.stage_name}") |
|
|
|
|
|
stages[self.config.stage_name]() |
|
|
|
|
|
def do_triton_aot(self): |
|
|
compile_dir = self.config.sub_workspace |
|
|
_clean_path(compile_dir) |
|
|
|
|
|
|
|
|
for config in self.config.configs: |
|
|
command = [ |
|
|
PYTHON_BIN, TRITON_COMPILE_BIN, self.config.kernel_file, '-n', |
|
|
self.config.kernel_name, '-o', f"{compile_dir}/kernel", |
|
|
'--out-name', config.output_name, '-w', |
|
|
str(config.num_warps), '-s', f"{config.signature}", '-g', |
|
|
','.join(self.config.grid_dims), '--num-stages', |
|
|
str(config.num_stages) |
|
|
] |
|
|
_run_command(command) |
|
|
|
|
|
|
|
|
h_files = glob.glob(os.path.join(compile_dir, '*.h')) |
|
|
command = [ |
|
|
PYTHON_BIN, |
|
|
TRITON_LINK_BIN, |
|
|
*h_files, |
|
|
'-o', |
|
|
os.path.join(compile_dir, 'launcher'), |
|
|
] |
|
|
_run_command(command) |
|
|
|
|
|
def do_compile_triton_kernel(self): |
|
|
''' |
|
|
Compile the triton kernel to library. |
|
|
''' |
|
|
|
|
|
|
|
|
from triton.common import cuda_include_dir, libcuda_dirs |
|
|
kernel_dir = os.path.join(self.config.workspace, |
|
|
self.config.kernel_name, '_triton_aot') |
|
|
compile_dir = self.config.sub_workspace |
|
|
_mkdir(compile_dir) |
|
|
_clean_path(compile_dir) |
|
|
|
|
|
c_files = glob.glob(os.path.join(os.getcwd(), kernel_dir, "*.c")) |
|
|
assert c_files |
|
|
_run_command([ |
|
|
"gcc", |
|
|
"-c", |
|
|
*c_files, |
|
|
"-I", |
|
|
cuda_include_dir(), |
|
|
"-L", |
|
|
libcuda_dirs()[0], |
|
|
"-l", |
|
|
"cuda", |
|
|
], |
|
|
cwd=compile_dir) |
|
|
|
|
|
o_files = glob.glob(os.path.join(os.getcwd(), compile_dir, "*.o")) |
|
|
assert o_files |
|
|
''' |
|
|
_run_command([ |
|
|
"ar", "rcs", |
|
|
f"lib{self.args.kernel_name}.a", *o_files |
|
|
], cwd=compile_dir) |
|
|
''' |
|
|
|
|
|
def do_generate_trt_plugin(self): |
|
|
''' |
|
|
Generate the trt plugin from the triton kernel library. |
|
|
''' |
|
|
|
|
|
workspace = self.config.sub_workspace |
|
|
_mkdir(workspace) |
|
|
_clean_path(workspace) |
|
|
|
|
|
PluginCppCodegen(output_dir=workspace, |
|
|
meta_data=self.config.config).generate() |
|
|
|
|
|
def do_compile_trt_plugin(self): |
|
|
''' |
|
|
Compile the trt plugin library. |
|
|
''' |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def collect_all_kernel_configs( |
|
|
workspace: str) -> Iterable[KernelMetaData]: |
|
|
ymls = glob.glob(os.path.join(workspace, '**/*.yml'), |
|
|
recursive=True) |
|
|
for path in ymls: |
|
|
yield KernelMetaData.load_from_yaml(path) |
|
|
|
|
|
configs = list(collect_all_kernel_configs(self.config.workspace)) |
|
|
|
|
|
kernel_names = [config.kernel_name for config in configs] |
|
|
|
|
|
PluginRegistryCodegen(out_path=os.path.join(self.config.sub_workspace, |
|
|
'tritonPlugins.cpp'), |
|
|
plugin_names=kernel_names).generate() |
|
|
PluginCmakeCodegen(out_path=os.path.join(self.config.sub_workspace, |
|
|
'CMakeLists.txt'), |
|
|
plugin_names=kernel_names, |
|
|
kernel_names=kernel_names, |
|
|
workspace=self.config.workspace).generate() |
|
|
|
|
|
functional_py = os.path.join(self.config.workspace, 'functional.py') |
|
|
plugin_lib_path = os.path.join(self.config.workspace, 'build', |
|
|
'libtriton_plugins.so') |
|
|
|
|
|
for i, kernel_config in enumerate(configs): |
|
|
PluginPyCodegen(out_path=functional_py, |
|
|
meta_data=kernel_config, |
|
|
add_header=i == 0, |
|
|
plugin_lib_path=plugin_lib_path).generate() |
|
|
|
|
|
|
|
|
_run_command(['mkdir', '-p', 'build'], cwd=self.config.sub_workspace) |
|
|
_run_command(['rm', '-rf', 'build/*'], cwd=self.config.sub_workspace) |
|
|
|
|
|
cmake_cmds = ['cmake', '..'] |
|
|
if self.config.trt_lib_dir is not None: |
|
|
cmake_cmds.append(f'-DTRT_LIB_DIR={self.config.trt_lib_dir}') |
|
|
if self.config.trt_include_dir: |
|
|
cmake_cmds.append( |
|
|
f'-DTRT_INCLUDE_DIR={self.config.trt_include_dir}') |
|
|
if self.config.trt_llm_include_dir is None: |
|
|
self.config.trt_llm_include_dir = os.path.join( |
|
|
os.path.dirname(os.path.abspath(__file__)), '../../../cpp') |
|
|
cmake_cmds.append( |
|
|
f'-DTRT_LLM_INCLUDE_DIR={self.config.trt_llm_include_dir}') |
|
|
_run_command(cmake_cmds, |
|
|
cwd=os.path.join(self.config.sub_workspace, "build")) |
|
|
_run_command(['make', '-j'], |
|
|
cwd=os.path.join(self.config.sub_workspace, "build")) |
|
|
|
|
|
def do_copy_output(self): |
|
|
''' |
|
|
Copy the output to the destination directory. |
|
|
''' |
|
|
_mkdir(self.config.output_dir) |
|
|
_clean_path(self.config.output_dir) |
|
|
|
|
|
_run_command(['cp', self.config.so_path, self.config.output_dir]) |
|
|
|
|
|
|
|
|
_run_command( |
|
|
['cp', self.config.functional_py_path, self.config.output_dir]) |
|
|
|
|
|
|
|
|
def gen_trt_plugins(workspace: str, |
|
|
metas: List[KernelMetaData], |
|
|
trt_lib_dir: Optional[str] = None, |
|
|
trt_include_dir: Optional[str] = None, |
|
|
trt_llm_include_dir: Optional[str] = None): |
|
|
''' |
|
|
Generate TRT plugins end-to-end. |
|
|
''' |
|
|
for meta in metas: |
|
|
Stage(meta.to_TritonAotArgs(workspace)).run() |
|
|
Stage(meta.to_TritonKernelCompileArgs(workspace)).run() |
|
|
Stage(meta.to_TrtPluginGenArgs(workspace)).run() |
|
|
|
|
|
|
|
|
compile_args = _TrtPluginCompileArgs( |
|
|
workspace=workspace, |
|
|
trt_lib_dir=trt_lib_dir, |
|
|
trt_include_dir=trt_include_dir, |
|
|
trt_llm_include_dir=trt_llm_include_dir, |
|
|
) |
|
|
Stage(compile_args).run() |
|
|
Stage(metas[0].to_CopyOutputArgs(workspace)).run() |
|
|
|
|
|
|
|
|
def _clean_path(path: str): |
|
|
''' |
|
|
Clean the content within this directory |
|
|
''' |
|
|
_rmdir(path) |
|
|
_mkdir(path) |
|
|
|
|
|
|
|
|
def _mkdir(path: str): |
|
|
''' |
|
|
mkdir if not exists |
|
|
''' |
|
|
subprocess.run(['/usr/bin/mkdir', '-p', path], check=True) |
|
|
|
|
|
|
|
|
def _rmdir(path: str): |
|
|
''' |
|
|
rmdir if exists |
|
|
''' |
|
|
subprocess.run(['/usr/bin/rm', '-rf', path], check=True) |
|
|
|
|
|
|
|
|
def _run_command(args, cwd=None): |
|
|
print(f"Running command: {' '.join(args)}") |
|
|
subprocess.run(args, check=True, cwd=cwd) |
|
|
|
|
|
|
|
|
def parse_arguments(): |
|
|
parser = argparse.ArgumentParser() |
|
|
parser.add_argument( |
|
|
'--workspace', |
|
|
type=str, |
|
|
required=True, |
|
|
help="The root path to store all the intermediate files") |
|
|
parser.add_argument( |
|
|
'--kernel_config', |
|
|
type=str, |
|
|
required=True, |
|
|
help= |
|
|
'The path to the kernel config file, which should be a python module ' |
|
|
'containing KernelMetaData instances') |
|
|
parser.add_argument( |
|
|
'--trt_lib_dir', |
|
|
type=str, |
|
|
default=None, |
|
|
help='Directory to find TensorRT library. If None, find the library ' |
|
|
'from the system path or /usr/local/tensorrt.') |
|
|
parser.add_argument( |
|
|
'--trt_include_dir', |
|
|
type=str, |
|
|
default=None, |
|
|
help='Directory to find TensorRT headers. If None, find the headers ' |
|
|
'from the system path or /usr/local/tensorrt.') |
|
|
parser.add_argument('--trt_llm_include_dir', |
|
|
type=str, |
|
|
default=None, |
|
|
help='Directory to find TensorRT LLM C++ headers.') |
|
|
args = parser.parse_args() |
|
|
|
|
|
assert args.kernel_config.endswith('.py'), \ |
|
|
f"Kernel config {args.kernel_config} should be a python module" |
|
|
|
|
|
return args |
|
|
|
|
|
|
|
|
def import_from_file(module_name, file_path): |
|
|
import importlib.util |
|
|
spec = importlib.util.spec_from_file_location(module_name, file_path) |
|
|
module = importlib.util.module_from_spec(spec) |
|
|
|
|
|
sys.modules[module_name] = module |
|
|
spec.loader.exec_module(module) |
|
|
return module |
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
args = parse_arguments() |
|
|
|
|
|
module_name = os.path.basename(args.kernel_config).replace('.py', '') |
|
|
config_module = import_from_file(module_name, args.kernel_config) |
|
|
kernel_configs: List[KernelMetaData] = config_module.KERNELS |
|
|
|
|
|
gen_trt_plugins(workspace=args.workspace, |
|
|
trt_lib_dir=args.trt_lib_dir, |
|
|
trt_include_dir=args.trt_include_dir, |
|
|
trt_llm_include_dir=args.trt_llm_include_dir, |
|
|
metas=kernel_configs) |
|
|
|