"""ComfyUI custom nodes that export a diffusion UNet to ONNX (Dynamo
exporter, ``torch.onnx.export(dynamo=True)``) and compile the result into a
TensorRT engine, with optional dynamic shapes via ``torch.export.Dim``.

Registers two nodes:
  - DYNAMIC_TRT_MODEL_CONVERSION: min/opt/max ranges per dimension.
  - STATIC_TRT_MODEL_CONVERSION:  a single fixed shape.
"""

import os
import sys
import time

import torch
import tensorrt as trt
from tqdm import tqdm

import comfy.model_base  # needed for the isinstance() model-type checks below
import comfy.model_management
import folder_paths

# -------------------------------------------------------------------------
# torch.export dynamic shapes support
# -------------------------------------------------------------------------
try:
    from torch.export import Dim
except Exception as e:
    raise RuntimeError(
        "[TensorRTExport] torch.export.Dim not available. "
        "Please upgrade PyTorch to >= 2.1 / 2.5+ to use the Dynamo-based "
        "ONNX exporter with dynamic shapes."
    ) from e


def trtlog(msg: str):
    """Print a tagged, flushed log line (flush so it interleaves with tqdm)."""
    print(f"[TensorRTExport] {msg}", flush=True)


# Opset handling:
# - If COMFY_TRT_ONNX_OPSET is set, use that integer.
# - Otherwise, leave opset_version=None so torch.onnx uses the
#   recommended opset for this PyTorch version (e.g. 20 in 2.9+).
DEFAULT_ONNX_OPSET = None
_env_opset = os.getenv("COMFY_TRT_ONNX_OPSET")
if _env_opset is not None:
    try:
        DEFAULT_ONNX_OPSET = int(_env_opset)
        trtlog(f"Using opset_version from COMFY_TRT_ONNX_OPSET={DEFAULT_ONNX_OPSET}")
    except ValueError:
        trtlog(
            f"WARNING: invalid COMFY_TRT_ONNX_OPSET={_env_opset!r}, "
            "falling back to PyTorch recommended opset (None)."
        )
        DEFAULT_ONNX_OPSET = None

# -------------------------------------------------------------------------
# Add output directory to TensorRT search path (ComfyUI integration)
# -------------------------------------------------------------------------
if "tensorrt" in folder_paths.folder_names_and_paths:
    folder_paths.folder_names_and_paths["tensorrt"][0].append(
        os.path.join(folder_paths.get_output_directory(), "tensorrt")
    )
    folder_paths.folder_names_and_paths["tensorrt"][1].add(".engine")
else:
    folder_paths.folder_names_and_paths["tensorrt"] = (
        [os.path.join(folder_paths.get_output_directory(), "tensorrt")],
        {".engine"},
    )


# -------------------------------------------------------------------------
# Progress monitor for TensorRT builds
# -------------------------------------------------------------------------
class TQDMProgressMonitor(trt.IProgressMonitor):
    """Renders nested TensorRT build phases as stacked tqdm progress bars.

    A KeyboardInterrupt inside any callback sets ``_step_result`` to False,
    which is returned from ``step_complete`` to ask TensorRT to cancel the
    build (callbacks themselves cannot cancel it directly).
    """

    def __init__(self):
        trt.IProgressMonitor.__init__(self)
        # phase_name -> {"tq": tqdm, "nbIndents": int, "parent_phase": str|None}
        self._active_phases = {}
        self._step_result = True
        # Phases nested deeper than this are not displayed.
        self.max_indent = 5

    def phase_start(self, phase_name, parent_phase, num_steps):
        leave = False
        try:
            if parent_phase is not None:
                nbIndents = (
                    self._active_phases.get(parent_phase, {}).get(
                        "nbIndents", self.max_indent
                    )
                    + 1
                )
                if nbIndents >= self.max_indent:
                    return
            else:
                nbIndents = 0
                leave = True  # keep only top-level bars on screen afterwards
            self._active_phases[phase_name] = {
                "tq": tqdm(
                    total=num_steps, desc=phase_name, leave=leave, position=nbIndents
                ),
                "nbIndents": nbIndents,
                "parent_phase": parent_phase,
            }
        except KeyboardInterrupt:
            # The phase_start callback cannot directly cancel the build,
            # so request the cancellation from within step_complete.
            self._step_result = False

    def phase_finish(self, phase_name):
        try:
            if phase_name in self._active_phases.keys():
                # Snap the bar to 100% even if TensorRT skipped steps.
                self._active_phases[phase_name]["tq"].update(
                    self._active_phases[phase_name]["tq"].total
                    - self._active_phases[phase_name]["tq"].n
                )
                # Refresh ancestors so their bars redraw above the closed one.
                parent_phase = self._active_phases[phase_name].get("parent_phase", None)
                while parent_phase is not None:
                    self._active_phases[parent_phase]["tq"].refresh()
                    parent_phase = self._active_phases[parent_phase].get(
                        "parent_phase", None
                    )
                if (
                    self._active_phases[phase_name]["parent_phase"]
                    in self._active_phases.keys()
                ):
                    self._active_phases[
                        self._active_phases[phase_name]["parent_phase"]
                    ]["tq"].refresh()
                del self._active_phases[phase_name]
        except KeyboardInterrupt:
            self._step_result = False

    def step_complete(self, phase_name, step):
        try:
            if phase_name in self._active_phases.keys():
                self._active_phases[phase_name]["tq"].update(
                    step - self._active_phases[phase_name]["tq"].n
                )
            return self._step_result
        except KeyboardInterrupt:
            # There is no need to propagate this exception to TensorRT.
            # We can simply cancel the build.
            return False


# -------------------------------------------------------------------------
# Base class for ONNX -> TensorRT conversion
# -------------------------------------------------------------------------
class TRT_MODEL_CONVERSION_BASE:
    """Shared ONNX-export + TensorRT-build pipeline for the two nodes below."""

    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()
        self.temp_dir = folder_paths.get_temp_directory()
        # Timing cache lives next to this source file and is reused across builds.
        self.timing_cache_path = os.path.normpath(
            os.path.join(
                os.path.dirname(os.path.realpath(__file__)), "timing_cache.trt"
            )
        )

    RETURN_TYPES = ()
    FUNCTION = "convert"
    OUTPUT_NODE = True
    CATEGORY = "TensorRT"

    @classmethod
    def INPUT_TYPES(s):
        raise NotImplementedError

    # Sets up the builder to use the timing cache file, and creates it if it
    # does not already exist
    def _setup_timing_cache(self, config: trt.IBuilderConfig):
        buffer = b""
        if os.path.exists(self.timing_cache_path):
            with open(self.timing_cache_path, mode="rb") as timing_cache_file:
                buffer = timing_cache_file.read()
            trtlog(f"Read {len(buffer)} bytes from timing cache.")
        else:
            trtlog("No timing cache found; initializing a new one.")
        timing_cache: trt.ITimingCache = config.create_timing_cache(buffer)
        config.set_timing_cache(timing_cache, ignore_mismatch=True)

    # Saves the config's timing cache to file
    def _save_timing_cache(self, config: trt.IBuilderConfig):
        timing_cache: trt.ITimingCache = config.get_timing_cache()
        with open(self.timing_cache_path, "wb") as timing_cache_file:
            serialized = timing_cache.serialize()
            timing_cache_file.write(memoryview(serialized))
        trtlog(f"Timing cache saved to {self.timing_cache_path}")

    def _convert(
        self,
        model,
        filename_prefix,
        batch_size_min,
        batch_size_opt,
        batch_size_max,
        height_min,
        height_opt,
        height_max,
        width_min,
        width_opt,
        width_max,
        context_min,
        context_opt,
        context_max,
        num_video_frames,
        is_static: bool,
    ):
        """Export ``model``'s UNet to ONNX, then build and save a TRT engine.

        heights/widths are in pixels (latent = value // 8); context_* count
        multiples of the text-encoder token length. Returns an empty tuple
        (ComfyUI OUTPUT_NODE convention), also on handled failures.
        """
        # -----------------------------------------------------------------
        # Basic logging: versions & configuration
        # -----------------------------------------------------------------
        trtlog(
            f"PyTorch version: {torch.__version__}, TensorRT version: {trt.__version__}"
        )
        trtlog(
            f"Requested {'STATIC' if is_static else 'DYNAMIC'} TensorRT engine "
            f"(b=[{batch_size_min},{batch_size_opt},{batch_size_max}], "
            f"h=[{height_min},{height_opt},{height_max}], "
            f"w=[{width_min},{width_opt},{width_max}], "
            f"context=[{context_min},{context_opt},{context_max}], "
            f"num_video_frames={num_video_frames})"
        )

        output_onnx = os.path.normpath(
            os.path.join(
                os.path.join(self.temp_dir, "{}".format(time.time())), "model.onnx"
            )
        )
        trtlog(f"Temporary ONNX path: {output_onnx}")

        # -----------------------------------------------------------------
        # Load model to GPU
        # -----------------------------------------------------------------
        comfy.model_management.unload_all_models()
        comfy.model_management.load_models_gpu(
            [model], force_patch_weights=True, force_full_load=True
        )
        unet = model.model.diffusion_model
        model_type = type(model.model).__name__
        trtlog(f"Detected model type: {model_type}")

        context_dim = model.model.model_config.unet_config.get("context_dim", None)
        context_len = 77
        context_len_min = context_len
        y_dim = model.model.adm_channels
        extra_input = {}
        dtype = torch.float16

        # -----------------------------------------------------------------
        # Model-type specific tweaks
        # -----------------------------------------------------------------
        if isinstance(model.model, comfy.model_base.SD3):  # SD3
            context_embedder_config = model.model.model_config.unet_config.get(
                "context_embedder_config", None
            )
            if context_embedder_config is not None:
                context_dim = context_embedder_config.get(
                    "params", {}
                ).get("in_features", None)
                # SD3 can have 77 or 154 depending on TE usage
                context_len = 154
            trtlog(f"SD3 context_dim={context_dim}, context_len={context_len}")
        elif isinstance(model.model, comfy.model_base.AuraFlow):
            context_dim = 2048
            context_len_min = 256
            context_len = 256
            trtlog(
                f"AuraFlow context_dim={context_dim}, "
                f"context_len_min={context_len_min}, context_len={context_len}"
            )
        elif isinstance(model.model, comfy.model_base.Flux):
            context_dim = model.model.model_config.unet_config.get(
                "context_in_dim", None
            )
            context_len_min = 256
            context_len = 256
            y_dim = model.model.model_config.unet_config.get("vec_in_dim", None)
            extra_input = {"guidance": ()}
            dtype = torch.bfloat16
            trtlog(
                f"Flux context_dim={context_dim}, y_dim={y_dim}, "
                f"context_len_min={context_len_min}, context_len={context_len}, "
                f"extra_input={list(extra_input.keys())}, dtype={dtype}"
            )

        if context_dim is None:
            trtlog("ERROR: model not supported (no context_dim).")
            comfy.model_management.unload_all_models()
            comfy.model_management.soft_empty_cache()
            return ()

        input_names = ["x", "timesteps", "context"]
        output_names = ["h"]

        transformer_options = model.model_options["transformer_options"].copy()
        use_temporal = model.model.model_config.unet_config.get(
            "use_temporal_resblock", False
        )

        # -----------------------------------------------------------------
        # Wrap UNet so argument names are stable for dynamic_shapes
        # -----------------------------------------------------------------
        if use_temporal:  # SVD
            trtlog("Model uses temporal resblock (SVD-like). Adjusting batch sizes.")
            # SVD batches are (frames * batch) along dim 0.
            batch_size_min = num_video_frames * batch_size_min
            batch_size_opt = num_video_frames * batch_size_opt
            batch_size_max = num_video_frames * batch_size_max

            class SVD_UNET(torch.nn.Module):
                def __init__(self, unet, transformer_options, num_video_frames):
                    super().__init__()
                    self.unet = unet
                    self.transformer_options = transformer_options
                    self.num_video_frames = num_video_frames

                def forward(self, x, timesteps, context, y):
                    return self.unet(
                        x,
                        timesteps,
                        context,
                        y,
                        num_video_frames=self.num_video_frames,
                        transformer_options=self.transformer_options,
                    )

            unet = SVD_UNET(unet, transformer_options, num_video_frames)
            context_len_min = context_len = 1
            trtlog(
                f"SVD adjusted batch: "
                f"b=[{batch_size_min},{batch_size_opt},{batch_size_max}], "
                f"context_len_min={context_len_min}, context_len={context_len}"
            )
        else:
            # Generic wrapper with named extras (y, guidance)
            extra_keys = list(extra_input.keys())

            class UNET(torch.nn.Module):
                def __init__(self, unet, transformer_options, y_dim, extra_keys):
                    super().__init__()
                    self.unet = unet
                    self.transformer_options = transformer_options
                    self.y_dim = y_dim
                    self.extra_keys = extra_keys

                def forward(self, x, timesteps, context, y=None, guidance=None):
                    extra_args = {}
                    if self.y_dim is not None and self.y_dim > 0 and y is not None:
                        extra_args["y"] = y
                    if "guidance" in self.extra_keys and guidance is not None:
                        extra_args["guidance"] = guidance
                    return self.unet(
                        x,
                        timesteps,
                        context,
                        transformer_options=self.transformer_options,
                        **extra_args,
                    )

            unet = UNET(unet, transformer_options, y_dim, extra_keys)

        # -----------------------------------------------------------------
        # Compute input shapes (min / opt / max)
        # -----------------------------------------------------------------
        # FIX: clamp inverted context min/max BEFORE computing shapes, so the
        # TensorRT optimization profile sees the corrected range (previously
        # the swap happened after the shapes were already built).
        if context_max < context_min:
            trtlog(
                f"WARNING: context_max({context_max}) < context_min({context_min}), swapping."
            )
            context_min, context_max = context_max, context_min

        input_channels = model.model.model_config.unet_config.get("in_channels", 4)
        inputs_shapes_min = (
            (batch_size_min, input_channels, height_min // 8, width_min // 8),
            (batch_size_min,),
            (batch_size_min, context_len_min * context_min, context_dim),
        )
        inputs_shapes_opt = (
            (batch_size_opt, input_channels, height_opt // 8, width_opt // 8),
            (batch_size_opt,),
            (batch_size_opt, context_len * context_opt, context_dim),
        )
        inputs_shapes_max = (
            (batch_size_max, input_channels, height_max // 8, width_max // 8),
            (batch_size_max,),
            (batch_size_max, context_len * context_max, context_dim),
        )

        if y_dim is not None and y_dim > 0:
            input_names.append("y")
            inputs_shapes_min += ((batch_size_min, y_dim),)
            inputs_shapes_opt += ((batch_size_opt, y_dim),)
            inputs_shapes_max += ((batch_size_max, y_dim),)

        # Extra inputs (currently used for Flux guidance)
        for k in extra_input:
            input_names.append(k)
            shape_suffix = extra_input[k]  # e.g. () for scalar per batch
            inputs_shapes_min += ((batch_size_min,) + shape_suffix,)
            inputs_shapes_opt += ((batch_size_opt,) + shape_suffix,)
            inputs_shapes_max += ((batch_size_max,) + shape_suffix,)

        trtlog("Input names: " + ", ".join(input_names))
        for idx, name in enumerate(input_names):
            trtlog(
                f"  {name}: "
                f"min={inputs_shapes_min[idx]}, "
                f"opt={inputs_shapes_opt[idx]}, "
                f"max={inputs_shapes_max[idx]}"
            )

        # -----------------------------------------------------------------
        # Build dynamic_shapes spec for torch.export / dynamo=True
        #   - STATIC node: no dynamic_shapes at all (fully static export)
        #   - DYNAMIC node: only create Dim if max > min
        # -----------------------------------------------------------------
        dynamic_shapes = None

        def _maybe_dim(name: str, min_v: int, max_v: int):
            """Create Dim only if there is real dynamism (max > min)."""
            if max_v < min_v:
                trtlog(
                    f"WARNING: Dim {name} has min ({min_v}) > max ({max_v}); swapping to fix."
                )
                min_v, max_v = max_v, min_v
            if max_v > min_v:
                trtlog(f"Dim {name}: dynamic range [{min_v}, {max_v}]")
                return Dim(name, min=min_v, max=max_v)
            else:
                trtlog(f"Dim {name}: static value {min_v}, not using Dim.")
                return None

        if not is_static:
            # Only build dynamic_shapes for the DYNAMIC node
            B = _maybe_dim("batch", batch_size_min, batch_size_max)
            H = _maybe_dim("height", height_min // 8, height_max // 8)
            W = _maybe_dim("width", width_min // 8, width_max // 8)
            tokens_min = context_len_min * context_min
            tokens_max = context_len * context_max
            T = _maybe_dim("tokens", tokens_min, tokens_max)

            dynamic_shapes = {}
            # x: [B, C, H, W]
            x_dyn = {}
            if B is not None:
                x_dyn[0] = B
            if H is not None:
                x_dyn[2] = H
            if W is not None:
                x_dyn[3] = W
            if x_dyn:
                dynamic_shapes["x"] = x_dyn
            # timesteps: [B]
            if B is not None:
                dynamic_shapes["timesteps"] = {0: B}
            # context: [B, T, context_dim]
            ctx_dyn = {}
            if B is not None:
                ctx_dyn[0] = B
            if T is not None:
                ctx_dyn[1] = T
            if ctx_dyn:
                dynamic_shapes["context"] = ctx_dyn
            # y: [B, y_dim]
            if "y" in input_names and B is not None:
                dynamic_shapes["y"] = {0: B}
            # guidance: [B, ...]
            if "guidance" in input_names and B is not None:
                dynamic_shapes["guidance"] = {0: B}

            if not dynamic_shapes:
                trtlog(
                    "No dimensions are actually dynamic for DYNAMIC node. "
                    "Export will effectively be static."
                )
                dynamic_shapes = None
            else:
                trtlog(f"dynamic_shapes spec: {dynamic_shapes}")
        else:
            trtlog("STATIC node: skipping torch.export.Dim and dynamic_shapes entirely.")

        # -----------------------------------------------------------------
        # Build example inputs (using OPT shapes)
        # -----------------------------------------------------------------
        inputs = ()
        for shape in inputs_shapes_opt:
            inputs += (
                torch.zeros(
                    shape,
                    device=comfy.model_management.get_torch_device(),
                    dtype=dtype,
                ),
            )

        # -----------------------------------------------------------------
        # ONNX export with Dynamo (dynamo=True)
        #   - For static: dynamic_shapes=None, so shapes are fully specialized.
        #   - For dynamic: dynamic_shapes guides symbolic shapes.
        # -----------------------------------------------------------------
        os.makedirs(os.path.dirname(output_onnx), exist_ok=True)
        trtlog(
            f"Exporting UNet to ONNX with dynamo=True, "
            f"opset_version={DEFAULT_ONNX_OPSET}, dtype={dtype}, "
            f"output={output_onnx}"
        )
        if dynamic_shapes is None:
            trtlog("ONNX export will be STATIC (no dynamic_shapes).")
        else:
            trtlog("ONNX export will use dynamic_shapes (see spec above).")

        try:
            torch.onnx.export(
                unet,
                inputs,
                output_onnx,
                verbose=False,
                input_names=input_names,
                output_names=output_names,
                opset_version=DEFAULT_ONNX_OPSET,
                dynamo=True,
                dynamic_shapes=dynamic_shapes,
                # NOTE:
                # - We intentionally do NOT pass dynamic_axes here.
                #   dynamic_axes is for the legacy TorchScript exporter.
            )
            trtlog("torch.onnx.export completed successfully.")
        except Exception as e:
            trtlog(f"ERROR during torch.onnx.export: {e}")
            # Clean up GPU state before re-raising
            comfy.model_management.unload_all_models()
            comfy.model_management.soft_empty_cache()
            raise

        comfy.model_management.unload_all_models()
        comfy.model_management.soft_empty_cache()

        # -----------------------------------------------------------------
        # TensorRT conversion starts here
        # -----------------------------------------------------------------
        logger = trt.Logger(trt.Logger.INFO)
        builder = trt.Builder(logger)
        trtlog("Created TensorRT builder.")
        network = builder.create_network(
            1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        )
        parser = trt.OnnxParser(network, logger)
        trtlog(f"Parsing ONNX file: {output_onnx}")
        success = parser.parse_from_file(output_onnx)
        for idx in range(parser.num_errors):
            trtlog(str(parser.get_error(idx)))
        if not success:
            trtlog("ONNX load ERROR (TensorRT parser.parse_from_file returned False).")
            return ()

        config = builder.create_builder_config()
        profile = builder.create_optimization_profile()
        self._setup_timing_cache(config)
        config.progress_monitor = TQDMProgressMonitor()

        trtlog("Creating optimization profile:")
        prefix_encode = ""

        def _encode(a):
            """Join shape components with dots for the filename encoding."""
            return ".".join(map(str, a))

        for k in range(len(input_names)):
            min_shape = inputs_shapes_min[k]
            opt_shape = inputs_shapes_opt[k]
            max_shape = inputs_shapes_max[k]
            trtlog(
                f"  {input_names[k]}: min={min_shape}, opt={opt_shape}, max={max_shape}"
            )
            profile.set_shape(input_names[k], min_shape, opt_shape, max_shape)

            # Encode shapes to filename
            prefix_encode += "{}#{}#{}#{};".format(
                input_names[k], _encode(min_shape), _encode(opt_shape), _encode(max_shape)
            )

        if dtype == torch.float16:
            trtlog("Enabling FP16 mode in TensorRT builder config.")
            config.set_flag(trt.BuilderFlag.FP16)
        if dtype == torch.bfloat16:
            trtlog("Enabling BF16 mode in TensorRT builder config.")
            config.set_flag(trt.BuilderFlag.BF16)

        config.add_optimization_profile(profile)

        if is_static:
            filename_prefix = "{}_${}".format(
                filename_prefix,
                "-".join(
                    (
                        "stat",
                        "b",
                        str(batch_size_opt),
                        "h",
                        str(height_opt),
                        "w",
                        str(width_opt),
                    )
                ),
            )
        else:
            filename_prefix = "{}_${}".format(
                filename_prefix,
                "-".join(
                    (
                        "dyn",
                        "b",
                        str(batch_size_min),
                        str(batch_size_max),
                        str(batch_size_opt),
                        "h",
                        str(height_min),
                        str(height_max),
                        str(height_opt),
                        "w",
                        str(width_min),
                        str(width_max),
                        str(width_opt),
                    )
                ),
            )

        trtlog("Building serialized TensorRT engine. This may take a while...")
        serialized_engine = builder.build_serialized_network(network, config)
        if serialized_engine is None:
            trtlog("ERROR: builder.build_serialized_network returned None.")
            return ()

        full_output_folder, filename, counter, subfolder, filename_prefix = (
            folder_paths.get_save_image_path(filename_prefix, self.output_dir)
        )
        # FIX: use the resolved filename from get_save_image_path instead of a
        # hard-coded "(unknown)" placeholder that leaked into the template.
        output_trt_engine = os.path.join(
            full_output_folder, f"{filename}_{counter:05}_.engine"
        )
        trtlog(f"Writing TensorRT engine to: {output_trt_engine}")

        os.makedirs(full_output_folder, exist_ok=True)
        with open(output_trt_engine, "wb") as f:
            f.write(serialized_engine)

        self._save_timing_cache(config)
        trtlog("TensorRT conversion complete.")
        return ()


# -------------------------------------------------------------------------
# Dynamic / Static wrapper nodes
# -------------------------------------------------------------------------
class DYNAMIC_TRT_MODEL_CONVERSION(TRT_MODEL_CONVERSION_BASE):
    """Node that builds an engine with min/opt/max ranges per dimension."""

    def __init__(self):
        super().__init__()

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "filename_prefix": ("STRING", {"default": "tensorrt/ComfyUI_DYN"}),
                "batch_size_min": (
                    "INT",
                    {
                        "default": 1,
                        "min": 1,
                        "max": 100,
                        "step": 1,
                    },
                ),
                "batch_size_opt": (
                    "INT",
                    {
                        "default": 1,
                        "min": 1,
                        "max": 100,
                        "step": 1,
                    },
                ),
                "batch_size_max": (
                    "INT",
                    {
                        "default": 1,
                        "min": 1,
                        "max": 100,
                        "step": 1,
                    },
                ),
                "height_min": (
                    "INT",
                    {
                        "default": 512,
                        "min": 256,
                        "max": 4096,
                        "step": 64,
                    },
                ),
                "height_opt": (
                    "INT",
                    {
                        "default": 512,
                        "min": 256,
                        "max": 4096,
                        "step": 64,
                    },
                ),
                "height_max": (
                    "INT",
                    {
                        "default": 512,
                        "min": 256,
                        "max": 4096,
                        "step": 64,
                    },
                ),
                "width_min": (
                    "INT",
                    {
                        "default": 512,
                        "min": 256,
                        "max": 4096,
                        "step": 64,
                    },
                ),
                "width_opt": (
                    "INT",
                    {
                        "default": 512,
                        "min": 256,
                        "max": 4096,
                        "step": 64,
                    },
                ),
                "width_max": (
                    "INT",
                    {
                        "default": 512,
                        "min": 256,
                        "max": 4096,
                        "step": 64,
                    },
                ),
                "context_min": (
                    "INT",
                    {
                        "default": 1,
                        "min": 1,
                        "max": 128,
                        "step": 1,
                    },
                ),
                "context_opt": (
                    "INT",
                    {
                        "default": 1,
                        "min": 1,
                        "max": 128,
                        "step": 1,
                    },
                ),
                "context_max": (
                    "INT",
                    {
                        "default": 1,
                        "min": 1,
                        "max": 128,
                        "step": 1,
                    },
                ),
                "num_video_frames": (
                    "INT",
                    {
                        "default": 14,
                        "min": 0,
                        "max": 1000,
                        "step": 1,
                    },
                ),
            },
        }

    def convert(
        self,
        model,
        filename_prefix,
        batch_size_min,
        batch_size_opt,
        batch_size_max,
        height_min,
        height_opt,
        height_max,
        width_min,
        width_opt,
        width_max,
        context_min,
        context_opt,
        context_max,
        num_video_frames,
    ):
        return super()._convert(
            model,
            filename_prefix,
            batch_size_min,
            batch_size_opt,
            batch_size_max,
            height_min,
            height_opt,
            height_max,
            width_min,
            width_opt,
            width_max,
            context_min,
            context_opt,
            context_max,
            num_video_frames,
            is_static=False,
        )


class STATIC_TRT_MODEL_CONVERSION(TRT_MODEL_CONVERSION_BASE):
    """Node that builds a fully static engine (min == opt == max)."""

    def __init__(self):
        super().__init__()

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "filename_prefix": ("STRING", {"default": "tensorrt/ComfyUI_STAT"}),
                "batch_size_opt": (
                    "INT",
                    {
                        "default": 1,
                        "min": 1,
                        "max": 100,
                        "step": 1,
                    },
                ),
                "height_opt": (
                    "INT",
                    {
                        "default": 512,
                        "min": 256,
                        "max": 4096,
                        "step": 64,
                    },
                ),
                "width_opt": (
                    "INT",
                    {
                        "default": 512,
                        "min": 256,
                        "max": 4096,
                        "step": 64,
                    },
                ),
                "context_opt": (
                    "INT",
                    {
                        "default": 1,
                        "min": 1,
                        "max": 128,
                        "step": 1,
                    },
                ),
                "num_video_frames": (
                    "INT",
                    {
                        "default": 14,
                        "min": 0,
                        "max": 1000,
                        "step": 1,
                    },
                ),
            },
        }

    def convert(
        self,
        model,
        filename_prefix,
        batch_size_opt,
        height_opt,
        width_opt,
        context_opt,
        num_video_frames,
    ):
        # STATIC: all min/opt/max are identical
        return super()._convert(
            model,
            filename_prefix,
            batch_size_opt,
            batch_size_opt,
            batch_size_opt,
            height_opt,
            height_opt,
            height_opt,
            width_opt,
            width_opt,
            width_opt,
            context_opt,
            context_opt,
            context_opt,
            num_video_frames,
            is_static=True,
        )


NODE_CLASS_MAPPINGS = {
    "DYNAMIC_TRT_MODEL_CONVERSION": DYNAMIC_TRT_MODEL_CONVERSION,
    "STATIC_TRT_MODEL_CONVERSION": STATIC_TRT_MODEL_CONVERSION,
}