import os
import shutil
import sys
import time

import torch
import comfy.model_base
import comfy.model_management

import tensorrt as trt
import folder_paths
from tqdm import tqdm
|
|
| |
| |
| |
| try: |
| from torch.export import Dim |
| except Exception as e: |
| raise RuntimeError( |
| "[TensorRTExport] torch.export.Dim not available. " |
| "Please upgrade PyTorch to >= 2.1 / 2.5+ to use the Dynamo-based " |
| "ONNX exporter with dynamic shapes." |
| ) from e |
|
|
|
|
def trtlog(msg: str):
    """Print ``msg`` to stdout with the ``[TensorRTExport]`` tag, flushing immediately."""
    print("[TensorRTExport]", msg, flush=True)
|
|
|
|
| |
| |
| |
| |
# ONNX opset passed to torch.onnx.export. None lets PyTorch choose; the
# COMFY_TRT_ONNX_OPSET environment variable may override it with an integer.
_env_opset = os.getenv("COMFY_TRT_ONNX_OPSET")
DEFAULT_ONNX_OPSET = None
if _env_opset is not None:
    try:
        DEFAULT_ONNX_OPSET = int(_env_opset)
    except ValueError:
        # Unparseable override: keep DEFAULT_ONNX_OPSET as None.
        trtlog(
            f"WARNING: invalid COMFY_TRT_ONNX_OPSET={_env_opset!r}, "
            "falling back to PyTorch recommended opset (None)."
        )
    else:
        trtlog(f"Using opset_version from COMFY_TRT_ONNX_OPSET={DEFAULT_ONNX_OPSET}")
|
|
|
|
| |
| |
| |
# Register "<output>/tensorrt" as a search location for ".engine" files under
# the "tensorrt" model folder key, creating the key if nobody registered it yet.
_TRT_ENGINE_DIR = os.path.join(folder_paths.get_output_directory(), "tensorrt")
if "tensorrt" not in folder_paths.folder_names_and_paths:
    folder_paths.folder_names_and_paths["tensorrt"] = ([_TRT_ENGINE_DIR], {".engine"})
else:
    # Existing entry: element 0 holds search paths, element 1 the extensions.
    _trt_entry = folder_paths.folder_names_and_paths["tensorrt"]
    _trt_entry[0].append(_TRT_ENGINE_DIR)
    _trt_entry[1].add(".engine")
|
|
|
|
| |
| |
| |
class TQDMProgressMonitor(trt.IProgressMonitor):
    """TensorRT build progress monitor that draws each build phase as a tqdm bar.

    Nested phases are drawn one row (tqdm ``position``) below their parent.
    Phases at nesting depth ``max_indent`` or deeper, and children of phases
    that are not displayed, get no bar. A KeyboardInterrupt raised inside any
    callback is converted into a cancellation request that is reported
    through ``step_complete``'s return value.
    """

    def __init__(self):
        trt.IProgressMonitor.__init__(self)
        # phase_name -> {"tq": tqdm bar, "nbIndents": depth, "parent_phase": name or None}
        self._active_phases = {}
        # Value returned by step_complete; False asks TensorRT to stop the build.
        self._step_result = True
        # Phases whose depth would reach this value are not displayed.
        self.max_indent = 5

    def phase_start(self, phase_name, parent_phase, num_steps):
        """Open a progress bar for a newly started phase.

        Top-level phases sit at position 0 and remain visible after they
        finish (``leave=True``); nested phases are cleared when finished.
        """
        leave = False
        try:
            if parent_phase is not None:
                # An untracked parent counts as depth max_indent, so its
                # children are skipped as well.
                nbIndents = (
                    self._active_phases.get(parent_phase, {}).get(
                        "nbIndents", self.max_indent
                    )
                    + 1
                )
                if nbIndents >= self.max_indent:
                    return
            else:
                nbIndents = 0
                leave = True
            self._active_phases[phase_name] = {
                "tq": tqdm(
                    total=num_steps, desc=phase_name, leave=leave, position=nbIndents
                ),
                "nbIndents": nbIndents,
                "parent_phase": parent_phase,
            }
        except KeyboardInterrupt:
            # Only step_complete's return value can cancel the build, so
            # record the request for it.
            self._step_result = False

    def phase_finish(self, phase_name):
        """Fill and close the phase's bar, then redraw its displayed ancestors."""
        try:
            phase = self._active_phases.get(phase_name)
            if phase is None:
                # This phase was never displayed (too deeply nested).
                return
            tq = phase["tq"]
            tq.update(tq.total - tq.n)
            # Walk up the chain of still-tracked ancestors and redraw them so
            # the terminal layout stays consistent.
            ancestor = phase["parent_phase"]
            while ancestor is not None and ancestor in self._active_phases:
                ancestor_entry = self._active_phases[ancestor]
                ancestor_entry["tq"].refresh()
                ancestor = ancestor_entry["parent_phase"]
            # Close explicitly instead of relying on garbage collection.
            tq.close()
            del self._active_phases[phase_name]
        except KeyboardInterrupt:
            self._step_result = False

    def step_complete(self, phase_name, step):
        """Advance the phase's bar to ``step``; return False to cancel the build."""
        try:
            phase = self._active_phases.get(phase_name)
            if phase is not None:
                tq = phase["tq"]
                tq.update(step - tq.n)
            return self._step_result
        except KeyboardInterrupt:
            # Remember the cancellation so later callbacks also report it.
            self._step_result = False
            return False
|
|
|
|
| |
| |
| |
class TRT_MODEL_CONVERSION_BASE:
    """Shared implementation behind the TensorRT conversion nodes.

    Subclasses provide ``INPUT_TYPES`` and a ``convert`` method (the ComfyUI
    ``FUNCTION``) that forwards to :meth:`_convert`. ``_convert`` exports the
    model's diffusion network to a temporary ONNX file with
    ``torch.onnx.export(dynamo=True)``, builds a serialized TensorRT engine from
    it, and writes the engine below ComfyUI's output directory.
    """

    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()
        self.temp_dir = folder_paths.get_temp_directory()
        # Timing cache file stored next to this source file and reused across
        # conversions.
        self.timing_cache_path = os.path.normpath(
            os.path.join(
                os.path.dirname(os.path.realpath(__file__)), "timing_cache.trt"
            )
        )

    RETURN_TYPES = ()
    FUNCTION = "convert"
    OUTPUT_NODE = True
    CATEGORY = "TensorRT"

    @classmethod
    def INPUT_TYPES(s):
        raise NotImplementedError

    def _setup_timing_cache(self, config: trt.IBuilderConfig):
        """Attach a timing cache to ``config``, seeded from disk if the file exists.

        An empty buffer creates a fresh cache. ``ignore_mismatch=True`` is passed
        so a cache TensorRT reports as mismatched is still accepted.
        """
        buffer = b""
        if os.path.exists(self.timing_cache_path):
            with open(self.timing_cache_path, mode="rb") as timing_cache_file:
                buffer = timing_cache_file.read()
            trtlog(f"Read {len(buffer)} bytes from timing cache.")
        else:
            trtlog("No timing cache found; initializing a new one.")
        timing_cache: trt.ITimingCache = config.create_timing_cache(buffer)
        config.set_timing_cache(timing_cache, ignore_mismatch=True)

    def _save_timing_cache(self, config: trt.IBuilderConfig):
        """Serialize ``config``'s timing cache to ``self.timing_cache_path``."""
        timing_cache: trt.ITimingCache = config.get_timing_cache()
        with open(self.timing_cache_path, "wb") as timing_cache_file:
            serialized = timing_cache.serialize()
            timing_cache_file.write(memoryview(serialized))
        trtlog(f"Timing cache saved to {self.timing_cache_path}")

    @staticmethod
    def _normalize_range(name, lo, opt, hi):
        """Return ``(lo, opt, hi)`` rearranged so that ``lo <= opt <= hi``.

        An inverted range (``hi < lo``) is swapped, and an ``opt`` outside the
        range is clamped into it. Every correction is logged. TensorRT
        optimization profiles require min <= opt <= max in every dimension.
        """
        if hi < lo:
            trtlog(f"WARNING: {name}_max({hi}) < {name}_min({lo}), swapping.")
            lo, hi = hi, lo
        clamped = min(max(opt, lo), hi)
        if clamped != opt:
            trtlog(
                f"WARNING: {name}_opt({opt}) outside [{lo}, {hi}], "
                f"clamping to {clamped}."
            )
        return lo, clamped, hi

    def _convert(
        self,
        model,
        filename_prefix,
        batch_size_min,
        batch_size_opt,
        batch_size_max,
        height_min,
        height_opt,
        height_max,
        width_min,
        width_opt,
        width_max,
        context_min,
        context_opt,
        context_max,
        num_video_frames,
        is_static: bool,
    ):
        """Export ``model`` to ONNX, build a TensorRT engine from it and save it.

        Args:
            model: ComfyUI model patcher. Its ``model.diffusion_model`` is exported.
            filename_prefix: Output prefix relative to the output directory. A
                suffix describing the shapes is appended.
            batch_size_min/opt/max: Batch size range.
            height_min/opt/max, width_min/opt/max: Image size range in pixels.
                The latent input uses these values divided by 8.
            context_min/opt/max: Multipliers applied to the model's context length.
            num_video_frames: For temporal (SVD-like) models, the batch range is
                multiplied by this value.
            is_static: If True, no dynamic shapes are given to the ONNX exporter
                and the TensorRT profile uses only the opt shapes.

        Returns:
            ``()`` in every handled case (this is an output node). Unsupported
            models, invalid frame counts, ONNX parse failures and engine build
            failures are logged and return ``()``. Exceptions raised by
            ``torch.onnx.export`` are re-raised.
        """
        trtlog(
            f"PyTorch version: {torch.__version__}, TensorRT version: {trt.__version__}"
        )
        trtlog(
            f"Requested {'STATIC' if is_static else 'DYNAMIC'} TensorRT engine "
            f"(b=[{batch_size_min},{batch_size_opt},{batch_size_max}], "
            f"h=[{height_min},{height_opt},{height_max}], "
            f"w=[{width_min},{width_opt},{width_max}], "
            f"context=[{context_min},{context_opt},{context_max}], "
            f"num_video_frames={num_video_frames})"
        )

        # Normalize every range before any shape is derived from it, so the
        # ONNX dynamic-shape spec and the TensorRT profile agree.
        batch_size_min, batch_size_opt, batch_size_max = self._normalize_range(
            "batch_size", batch_size_min, batch_size_opt, batch_size_max
        )
        height_min, height_opt, height_max = self._normalize_range(
            "height", height_min, height_opt, height_max
        )
        width_min, width_opt, width_max = self._normalize_range(
            "width", width_min, width_opt, width_max
        )
        context_min, context_opt, context_max = self._normalize_range(
            "context", context_min, context_opt, context_max
        )

        output_onnx = os.path.normpath(
            os.path.join(self.temp_dir, "{}".format(time.time()), "model.onnx")
        )
        onnx_dir = os.path.dirname(output_onnx)
        trtlog(f"Temporary ONNX path: {output_onnx}")

        try:
            comfy.model_management.unload_all_models()
            comfy.model_management.load_models_gpu(
                [model], force_patch_weights=True, force_full_load=True
            )

            unet = model.model.diffusion_model
            model_type = type(model.model).__name__
            trtlog(f"Detected model type: {model_type}")

            unet_config = model.model.model_config.unet_config
            context_dim = unet_config.get("context_dim", None)
            context_len = 77
            context_len_min = context_len
            y_dim = model.model.adm_channels
            extra_input = {}
            dtype = torch.float16

            if isinstance(model.model, comfy.model_base.SD3):
                context_embedder_config = unet_config.get(
                    "context_embedder_config", None
                )
                if context_embedder_config is not None:
                    context_dim = context_embedder_config.get("params", {}).get(
                        "in_features", None
                    )
                    # context_len_min stays 77 while opt/max use 154 tokens;
                    # presumably SD3 may be fed either length depending on the
                    # text encoders in use (TODO confirm).
                    context_len = 154
                trtlog(f"SD3 context_dim={context_dim}, context_len={context_len}")
            elif isinstance(model.model, comfy.model_base.AuraFlow):
                context_dim = 2048
                context_len_min = 256
                context_len = 256
                trtlog(
                    f"AuraFlow context_dim={context_dim}, "
                    f"context_len_min={context_len_min}, context_len={context_len}"
                )
            elif isinstance(model.model, comfy.model_base.Flux):
                context_dim = unet_config.get("context_in_dim", None)
                context_len_min = 256
                context_len = 256
                y_dim = unet_config.get("vec_in_dim", None)
                extra_input = {"guidance": ()}
                dtype = torch.bfloat16
                trtlog(
                    f"Flux context_dim={context_dim}, y_dim={y_dim}, "
                    f"context_len_min={context_len_min}, context_len={context_len}, "
                    f"extra_input={list(extra_input.keys())}, dtype={dtype}"
                )

            if context_dim is None:
                print("ERROR: model not supported (no context_dim).")
                comfy.model_management.unload_all_models()
                comfy.model_management.soft_empty_cache()
                return ()

            input_names = ["x", "timesteps", "context"]
            output_names = ["h"]

            transformer_options = model.model_options.get(
                "transformer_options", {}
            ).copy()
            use_temporal = unet_config.get("use_temporal_resblock", False)

            if use_temporal:
                if num_video_frames < 1:
                    trtlog(
                        "ERROR: num_video_frames must be >= 1 for temporal models "
                        f"(got {num_video_frames})."
                    )
                    comfy.model_management.unload_all_models()
                    comfy.model_management.soft_empty_cache()
                    return ()
                trtlog("Model uses temporal resblock (SVD-like). Adjusting batch sizes.")
                # Every sample is num_video_frames frames flattened into the batch.
                batch_size_min = num_video_frames * batch_size_min
                batch_size_opt = num_video_frames * batch_size_opt
                batch_size_max = num_video_frames * batch_size_max

                class SVD_UNET(torch.nn.Module):
                    """Wrapper that binds num_video_frames/transformer_options for export."""

                    def __init__(self, unet, transformer_options, num_video_frames):
                        super().__init__()
                        self.unet = unet
                        self.transformer_options = transformer_options
                        self.num_video_frames = num_video_frames

                    def forward(self, x, timesteps, context, y):
                        return self.unet(
                            x,
                            timesteps,
                            context,
                            y,
                            num_video_frames=self.num_video_frames,
                            transformer_options=self.transformer_options,
                        )

                unet = SVD_UNET(unet, transformer_options, num_video_frames)
                context_len_min = context_len = 1
                trtlog(
                    f"SVD adjusted batch: "
                    f"b=[{batch_size_min},{batch_size_opt},{batch_size_max}], "
                    f"context_len_min={context_len_min}, context_len={context_len}"
                )
            else:
                extra_keys = list(extra_input.keys())

                class UNET(torch.nn.Module):
                    """Wrapper that forwards only the optional inputs this model uses."""

                    def __init__(self, unet, transformer_options, y_dim, extra_keys):
                        super().__init__()
                        self.unet = unet
                        self.transformer_options = transformer_options
                        self.y_dim = y_dim
                        self.extra_keys = extra_keys

                    def forward(self, x, timesteps, context, y=None, guidance=None):
                        extra_args = {}
                        if self.y_dim is not None and self.y_dim > 0 and y is not None:
                            extra_args["y"] = y
                        if "guidance" in self.extra_keys and guidance is not None:
                            extra_args["guidance"] = guidance
                        return self.unet(
                            x,
                            timesteps,
                            context,
                            transformer_options=self.transformer_options,
                            **extra_args,
                        )

                unet = UNET(unet, transformer_options, y_dim, extra_keys)

            input_channels = unet_config.get("in_channels", 4)

            inputs_shapes_min = (
                (batch_size_min, input_channels, height_min // 8, width_min // 8),
                (batch_size_min,),
                (batch_size_min, context_len_min * context_min, context_dim),
            )
            inputs_shapes_opt = (
                (batch_size_opt, input_channels, height_opt // 8, width_opt // 8),
                (batch_size_opt,),
                (batch_size_opt, context_len * context_opt, context_dim),
            )
            inputs_shapes_max = (
                (batch_size_max, input_channels, height_max // 8, width_max // 8),
                (batch_size_max,),
                (batch_size_max, context_len * context_max, context_dim),
            )

            if y_dim is not None and y_dim > 0:
                input_names.append("y")
                inputs_shapes_min += ((batch_size_min, y_dim),)
                inputs_shapes_opt += ((batch_size_opt, y_dim),)
                inputs_shapes_max += ((batch_size_max, y_dim),)

            for k, shape_suffix in extra_input.items():
                input_names.append(k)
                inputs_shapes_min += ((batch_size_min,) + shape_suffix,)
                inputs_shapes_opt += ((batch_size_opt,) + shape_suffix,)
                inputs_shapes_max += ((batch_size_max,) + shape_suffix,)

            if is_static:
                # The ONNX graph is exported with the opt shapes baked in, so the
                # TensorRT profile must not allow any other shape. Without this,
                # e.g. SD3's context_len_min(77) != context_len(154) would yield
                # an unsatisfiable profile.
                inputs_shapes_min = inputs_shapes_opt
                inputs_shapes_max = inputs_shapes_opt

            trtlog("Input names: " + ", ".join(input_names))
            for name, shape_min, shape_opt, shape_max in zip(
                input_names, inputs_shapes_min, inputs_shapes_opt, inputs_shapes_max
            ):
                trtlog(f" {name}: min={shape_min}, opt={shape_opt}, max={shape_max}")

            def _maybe_dim(name: str, min_v: int, max_v: int):
                """Return a torch.export Dim for a truly dynamic range, else None."""
                if max_v < min_v:
                    trtlog(
                        f"WARNING: Dim {name} has min({min_v}) > max({max_v}), "
                        "swapping to fix."
                    )
                    min_v, max_v = max_v, min_v
                if max_v > min_v:
                    trtlog(f"Dim {name}: dynamic range [{min_v}, {max_v}]")
                    return Dim(name, min=min_v, max=max_v)
                trtlog(f"Dim {name}: static value {min_v}, not using Dim.")
                return None

            dynamic_shapes = None
            B = None
            if not is_static:
                B = _maybe_dim("batch", batch_size_min, batch_size_max)
                H = _maybe_dim("height", height_min // 8, height_max // 8)
                W = _maybe_dim("width", width_min // 8, width_max // 8)
                T = _maybe_dim(
                    "tokens", context_len_min * context_min, context_len * context_max
                )

                dynamic_shapes = {}

                # x: (batch, channels, latent_h, latent_w)
                x_dyn = {}
                if B is not None:
                    x_dyn[0] = B
                if H is not None:
                    x_dyn[2] = H
                if W is not None:
                    x_dyn[3] = W
                if x_dyn:
                    dynamic_shapes["x"] = x_dyn

                if B is not None:
                    dynamic_shapes["timesteps"] = {0: B}

                # context: (batch, tokens, context_dim)
                ctx_dyn = {}
                if B is not None:
                    ctx_dyn[0] = B
                if T is not None:
                    ctx_dyn[1] = T
                if ctx_dyn:
                    dynamic_shapes["context"] = ctx_dyn

                if "y" in input_names and B is not None:
                    dynamic_shapes["y"] = {0: B}
                if "guidance" in input_names and B is not None:
                    dynamic_shapes["guidance"] = {0: B}

                if not dynamic_shapes:
                    trtlog(
                        "No dimensions are actually dynamic for DYNAMIC node. "
                        "Export will effectively be static."
                    )
                    dynamic_shapes = None
                else:
                    trtlog(f"dynamic_shapes spec: {dynamic_shapes}")
            else:
                trtlog("STATIC node: skipping torch.export.Dim and dynamic_shapes entirely.")

            # torch.export specializes dimensions whose example size is 0 or 1,
            # which conflicts with a dynamic batch Dim. Trace with batch >= 2 in
            # that case; this stays within [min, max] because a dynamic batch
            # implies max > min >= 1. The TensorRT profile keeps the requested opt.
            trace_batch = batch_size_opt
            if B is not None and trace_batch < 2:
                trace_batch = 2
                trtlog("Dynamic batch with opt=1: tracing ONNX export with batch=2.")

            device = comfy.model_management.get_torch_device()
            inputs = tuple(
                torch.zeros((trace_batch,) + tuple(shape[1:]), device=device, dtype=dtype)
                for shape in inputs_shapes_opt
            )

            os.makedirs(onnx_dir, exist_ok=True)

            trtlog(
                f"Exporting UNet to ONNX with dynamo=True, "
                f"opset_version={DEFAULT_ONNX_OPSET}, dtype={dtype}, "
                f"output={output_onnx}"
            )
            if dynamic_shapes is None:
                trtlog("ONNX export will be STATIC (no dynamic_shapes).")
            else:
                trtlog("ONNX export will use dynamic_shapes (see spec above).")

            try:
                torch.onnx.export(
                    unet,
                    inputs,
                    output_onnx,
                    verbose=False,
                    input_names=input_names,
                    output_names=output_names,
                    opset_version=DEFAULT_ONNX_OPSET,
                    dynamo=True,
                    dynamic_shapes=dynamic_shapes,
                )
                trtlog("torch.onnx.export completed successfully.")
            except Exception as e:
                trtlog(f"ERROR during torch.onnx.export: {e}")
                # Release GPU memory before propagating the failure.
                comfy.model_management.unload_all_models()
                comfy.model_management.soft_empty_cache()
                raise

            comfy.model_management.unload_all_models()
            comfy.model_management.soft_empty_cache()

            logger = trt.Logger(trt.Logger.INFO)
            builder = trt.Builder(logger)
            trtlog("Created TensorRT builder.")

            network = builder.create_network(
                1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
            )
            parser = trt.OnnxParser(network, logger)
            trtlog(f"Parsing ONNX file: {output_onnx}")
            success = parser.parse_from_file(output_onnx)
            for idx in range(parser.num_errors):
                print(parser.get_error(idx))

            if not success:
                print("ONNX load ERROR (TensorRT parser.parse_from_file returned False).")
                return ()

            config = builder.create_builder_config()
            profile = builder.create_optimization_profile()
            self._setup_timing_cache(config)
            config.progress_monitor = TQDMProgressMonitor()

            trtlog("Creating optimization profile:")
            for name, min_shape, opt_shape, max_shape in zip(
                input_names, inputs_shapes_min, inputs_shapes_opt, inputs_shapes_max
            ):
                trtlog(f" {name}: min={min_shape}, opt={opt_shape}, max={max_shape}")
                profile.set_shape(name, min_shape, opt_shape, max_shape)

            if dtype == torch.float16:
                trtlog("Enabling FP16 mode in TensorRT builder config.")
                config.set_flag(trt.BuilderFlag.FP16)
            elif dtype == torch.bfloat16:
                trtlog("Enabling BF16 mode in TensorRT builder config.")
                config.set_flag(trt.BuilderFlag.BF16)

            config.add_optimization_profile(profile)

            # The "$stat-..." / "$dyn-..." suffix encodes the engine's shapes in
            # its file name (dynamic order: min, max, opt).
            if is_static:
                shape_tag = (
                    "stat",
                    "b",
                    str(batch_size_opt),
                    "h",
                    str(height_opt),
                    "w",
                    str(width_opt),
                )
            else:
                shape_tag = (
                    "dyn",
                    "b",
                    str(batch_size_min),
                    str(batch_size_max),
                    str(batch_size_opt),
                    "h",
                    str(height_min),
                    str(height_max),
                    str(height_opt),
                    "w",
                    str(width_min),
                    str(width_max),
                    str(width_opt),
                )
            filename_prefix = "{}_${}".format(filename_prefix, "-".join(shape_tag))

            trtlog("Building serialized TensorRT engine. This may take a while...")
            serialized_engine = builder.build_serialized_network(network, config)
            if serialized_engine is None:
                trtlog("ERROR: builder.build_serialized_network returned None.")
                return ()

            full_output_folder, filename, counter, subfolder, filename_prefix = (
                folder_paths.get_save_image_path(filename_prefix, self.output_dir)
            )
            output_trt_engine = os.path.join(
                full_output_folder, f"{filename}_{counter:05}_.engine"
            )

            trtlog(f"Writing TensorRT engine to: {output_trt_engine}")
            os.makedirs(full_output_folder, exist_ok=True)
            with open(output_trt_engine, "wb") as f:
                f.write(serialized_engine)

            self._save_timing_cache(config)
            trtlog("TensorRT conversion complete.")
            return ()
        finally:
            # The per-run temp directory holds the ONNX model and any external
            # weight files written next to it; remove it whatever the outcome.
            shutil.rmtree(onnx_dir, ignore_errors=True)
|
|
|
|
| |
| |
| |
class DYNAMIC_TRT_MODEL_CONVERSION(TRT_MODEL_CONVERSION_BASE):
    """ComfyUI node that builds a TensorRT engine over dynamic shape ranges.

    Batch size, height, width and context each take a (min, opt, max) triple.
    Every value is forwarded unchanged to ``_convert`` with ``is_static=False``.
    """

    def __init__(self):
        super().__init__()

    @classmethod
    def INPUT_TYPES(s):
        def int_input(default, minimum, maximum, step):
            return (
                "INT",
                {"default": default, "min": minimum, "max": maximum, "step": step},
            )

        required = {
            "model": ("MODEL",),
            "filename_prefix": ("STRING", {"default": "tensorrt/ComfyUI_DYN"}),
        }
        # Each ranged input expands to <name>_min, <name>_opt, <name>_max, in
        # that order (ComfyUI shows widgets in dict insertion order).
        ranged_inputs = (
            ("batch_size", (1, 1, 100, 1)),
            ("height", (512, 256, 4096, 64)),
            ("width", (512, 256, 4096, 64)),
            ("context", (1, 1, 128, 1)),
        )
        for base_name, spec in ranged_inputs:
            for suffix in ("min", "opt", "max"):
                required[f"{base_name}_{suffix}"] = int_input(*spec)
        required["num_video_frames"] = int_input(14, 0, 1000, 1)
        return {"required": required}

    def convert(
        self,
        model,
        filename_prefix,
        batch_size_min,
        batch_size_opt,
        batch_size_max,
        height_min,
        height_opt,
        height_max,
        width_min,
        width_opt,
        width_max,
        context_min,
        context_opt,
        context_max,
        num_video_frames,
    ):
        """Node entry point: pass every range through to ``_convert`` as dynamic."""
        return super()._convert(
            model,
            filename_prefix,
            batch_size_min=batch_size_min,
            batch_size_opt=batch_size_opt,
            batch_size_max=batch_size_max,
            height_min=height_min,
            height_opt=height_opt,
            height_max=height_max,
            width_min=width_min,
            width_opt=width_opt,
            width_max=width_max,
            context_min=context_min,
            context_opt=context_opt,
            context_max=context_max,
            num_video_frames=num_video_frames,
            is_static=False,
        )
|
|
|
|
class STATIC_TRT_MODEL_CONVERSION(TRT_MODEL_CONVERSION_BASE):
    """ComfyUI node that builds a fixed-shape TensorRT engine.

    Only the "opt" value of each dimension is requested. It is forwarded to
    ``_convert`` as min, opt and max alike, with ``is_static=True``.
    """

    def __init__(self):
        super().__init__()

    @classmethod
    def INPUT_TYPES(s):
        def int_input(default, minimum, maximum, step):
            return (
                "INT",
                {"default": default, "min": minimum, "max": maximum, "step": step},
            )

        return {
            "required": {
                "model": ("MODEL",),
                "filename_prefix": ("STRING", {"default": "tensorrt/ComfyUI_STAT"}),
                "batch_size_opt": int_input(1, 1, 100, 1),
                "height_opt": int_input(512, 256, 4096, 64),
                "width_opt": int_input(512, 256, 4096, 64),
                "context_opt": int_input(1, 1, 128, 1),
                "num_video_frames": int_input(14, 0, 1000, 1),
            },
        }

    def convert(
        self,
        model,
        filename_prefix,
        batch_size_opt,
        height_opt,
        width_opt,
        context_opt,
        num_video_frames,
    ):
        """Node entry point: collapse each range to its opt value and convert."""
        return super()._convert(
            model,
            filename_prefix,
            batch_size_min=batch_size_opt,
            batch_size_opt=batch_size_opt,
            batch_size_max=batch_size_opt,
            height_min=height_opt,
            height_opt=height_opt,
            height_max=height_opt,
            width_min=width_opt,
            width_opt=width_opt,
            width_max=width_opt,
            context_min=context_opt,
            context_opt=context_opt,
            context_max=context_opt,
            num_video_frames=num_video_frames,
            is_static=True,
        )
|
|
|
|
# ComfyUI node registry: node type name -> implementing class.
NODE_CLASS_MAPPINGS = dict(
    DYNAMIC_TRT_MODEL_CONVERSION=DYNAMIC_TRT_MODEL_CONVERSION,
    STATIC_TRT_MODEL_CONVERSION=STATIC_TRT_MODEL_CONVERSION,
)
|
|