| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import argparse |
| | import os |
| |
|
| | import tensorrt as trt |
| | from common_runtime import * |
| |
|
# Compatibility shim: probe whether the interpreter defines the builtin
# FileNotFoundError (added in Python 3.3). The bare name lookup below raises
# NameError when it is missing.
try:
    FileNotFoundError
except NameError:
    # Fall back to IOError so the `raise FileNotFoundError(...)` used later in
    # this module still works on such interpreters.
    FileNotFoundError = IOError
| |
|
| |
|
def GiB(val):
    """
    Converts a size expressed in gibibytes to bytes.

    Args:
        val (int or float): Number of GiB.

    Returns:
        The equivalent number of bytes, i.e. ``val * 2**30``. An int input
        yields an int; a fractional input yields a float.
    """
    # Parenthesize the shift: `val * 1 << 30` parses as `(val * 1) << 30`,
    # which raises TypeError for non-integer values such as GiB(0.5).
    return val * (1 << 30)
| |
|
| |
|
def add_help(description):
    """
    Adds ``-h``/``--help`` handling to a sample that takes no other options.

    Builds a throw-away argument parser carrying only the description and
    parses the command line with it. Unrecognized arguments are tolerated
    (``parse_known_args``) and the parsed result is discarded.

    Args:
        description (str): Description of the sample shown in the help text.
    """
    argparse.ArgumentParser(
        description=description,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    ).parse_known_args()
| |
|
| |
|
def find_sample_data(
    description="Runs a TensorRT Python sample", subfolder="", find_files=None, err_msg=""
):
    """
    Parses sample arguments and locates the sample's data files.

    Args:
        description (str): Description of the sample, shown in the help text.
        subfolder (str): The subfolder containing data relevant to this sample.
        find_files (List[str]): A list of filenames to find. Each filename will be
            replaced with an absolute path. Defaults to no files.
        err_msg (str): Extra text appended to the error raised when a file
            cannot be located.

    Returns:
        Tuple[List[str], List[str]]: The data directories that will be searched
        (one per ``-d`` entry, after the default data root), and the absolute
        paths of ``find_files`` in the same order.

    Raises:
        FileNotFoundError: Propagated from ``locate_files`` if a file could not
            be located in any data directory.
    """
    # Avoid a shared mutable default; None stands for "no files requested".
    if find_files is None:
        find_files = []

    # Standard command-line arguments for all samples.
    kDEFAULT_DATA_ROOT = os.path.join(os.sep, "usr", "src", "tensorrt", "data")
    parser = argparse.ArgumentParser(
        description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    # With action="append" and a non-empty default, user-supplied directories
    # are appended after the default root rather than replacing it.
    parser.add_argument(
        "-d",
        "--datadir",
        help="Location of the TensorRT sample data directory, and any additional data directories.",
        action="append",
        default=[kDEFAULT_DATA_ROOT],
    )
    args, _ = parser.parse_known_args()

    def get_data_path(data_dir):
        # Prefer <data_dir>/<subfolder>; otherwise fall back to data_dir as-is.
        data_path = os.path.join(data_dir, subfolder)
        if not os.path.exists(data_path):
            # Stay quiet about the default root: it is always in the list and
            # is legitimately absent on many machines.
            if data_dir != kDEFAULT_DATA_ROOT:
                print(
                    "WARNING: "
                    + data_path
                    + " does not exist. Trying "
                    + data_dir
                    + " instead."
                )
            data_path = data_dir

        # Warn (but do not fail) if a user-supplied directory is missing;
        # locate_files raises later if a required file cannot be found.
        if not os.path.exists(data_path) and data_dir != kDEFAULT_DATA_ROOT:
            print(
                "WARNING: {:} does not exist. Please provide the correct data path with the -d option.".format(
                    data_path
                )
            )
        return data_path

    data_paths = [get_data_path(data_dir) for data_dir in args.datadir]
    return data_paths, locate_files(data_paths, find_files, err_msg)
| |
|
| |
|
def locate_files(data_paths, filenames, err_msg=""):
    """
    Resolves each requested file against a list of data directories.

    Directories are searched in the order given; when a file is present in
    more than one directory, the earliest directory wins.

    Args:
        data_paths (List[str]): The data directories to search.
        filenames (List[str]): The names of the files to find.
        err_msg (str): Extra text appended to the error message when a file
            cannot be found.

    Returns:
        List[str]: The absolute paths of the files, in the order of ``filenames``.

    Raises:
        FileNotFoundError if a file could not be located.
    """
    resolved = [None] * len(filenames)
    for root in data_paths:
        for slot, name in enumerate(filenames):
            # Already resolved from an earlier (higher-priority) directory.
            if resolved[slot]:
                continue
            candidate = os.path.abspath(os.path.join(root, name))
            if os.path.exists(candidate):
                resolved[slot] = candidate

    # Fail on the first requested file that has no existing resolved path.
    for location, name in zip(resolved, filenames):
        if location and os.path.exists(location):
            continue
        raise FileNotFoundError(
            "Could not find {:}. Searched in data paths: {:}\n{:}".format(
                name, data_paths, err_msg
            )
        )
    return resolved
| |
|
| |
|
| | |
def setup_timing_cache(config: trt.IBuilderConfig, timing_cache_path: os.PathLike):
    """
    Attaches a timing cache to a builder config.

    If ``timing_cache_path`` exists, its raw bytes seed the cache; otherwise
    the cache is created from an empty buffer.

    Args:
        config (trt.IBuilderConfig): The builder config that receives the cache.
        timing_cache_path (os.PathLike): Path of a previously saved timing cache.
    """
    if os.path.exists(timing_cache_path):
        with open(timing_cache_path, mode="rb") as cache_file:
            serialized = cache_file.read()
    else:
        serialized = b""

    cache: trt.ITimingCache = config.create_timing_cache(serialized)
    # NOTE(review): the second positional argument is presumably TensorRT's
    # `ignore_mismatch` flag - confirm against the IBuilderConfig docs.
    config.set_timing_cache(cache, True)
| |
|
| |
|
| | |
def save_timing_cache(config: trt.IBuilderConfig, timing_cache_path: os.PathLike):
    """
    Writes the builder config's current timing cache to disk.

    The destination is opened in binary write mode, so any existing file at
    ``timing_cache_path`` is overwritten.

    Args:
        config (trt.IBuilderConfig): The builder config whose cache is saved.
        timing_cache_path (os.PathLike): Destination file for the serialized cache.
    """
    cache: trt.ITimingCache = config.get_timing_cache()
    with open(timing_cache_path, mode="wb") as sink:
        # Serialize only after the file is open, then write the buffer through
        # a memoryview.
        payload = memoryview(cache.serialize())
        sink.write(payload)
| |
|