# MyCustomNodes / tensorrt_convert.py
# (uploaded by saliacoel, commit 724b368 verified)
import os
import sys
import time

import tensorrt as trt
import torch
from tqdm import tqdm

import comfy.model_base
import comfy.model_management
import folder_paths
# -------------------------------------------------------------------------
# torch.export dynamic shapes support
# -------------------------------------------------------------------------
try:
from torch.export import Dim
except Exception as e:
raise RuntimeError(
"[TensorRTExport] torch.export.Dim not available. "
"Please upgrade PyTorch to >= 2.1 / 2.5+ to use the Dynamo-based "
"ONNX exporter with dynamic shapes."
) from e
def trtlog(msg: str):
    """Print a `[TensorRTExport]`-tagged log line, flushed immediately."""
    print("[TensorRTExport] {}".format(msg), flush=True)
# Opset selection:
# - COMFY_TRT_ONNX_OPSET, when set to a valid integer, wins.
# - Otherwise opset_version stays None and torch.onnx picks the
#   recommended opset for the installed PyTorch (e.g. 20 on 2.9+).
DEFAULT_ONNX_OPSET = None
_env_opset = os.getenv("COMFY_TRT_ONNX_OPSET")
if _env_opset is not None:
    try:
        DEFAULT_ONNX_OPSET = int(_env_opset)
    except ValueError:
        # Non-numeric value: warn and keep the PyTorch default (None).
        trtlog(
            f"WARNING: invalid COMFY_TRT_ONNX_OPSET={_env_opset!r}, "
            "falling back to PyTorch recommended opset (None)."
        )
    else:
        trtlog(f"Using opset_version from COMFY_TRT_ONNX_OPSET={DEFAULT_ONNX_OPSET}")
# -------------------------------------------------------------------------
# Add output directory to TensorRT search path (ComfyUI integration)
# -------------------------------------------------------------------------
if "tensorrt" in folder_paths.folder_names_and_paths:
folder_paths.folder_names_and_paths["tensorrt"][0].append(
os.path.join(folder_paths.get_output_directory(), "tensorrt")
)
folder_paths.folder_names_and_paths["tensorrt"][1].add(".engine")
else:
folder_paths.folder_names_and_paths["tensorrt"] = (
[os.path.join(folder_paths.get_output_directory(), "tensorrt")],
{".engine"},
)
# -------------------------------------------------------------------------
# Progress monitor for TensorRT builds
# -------------------------------------------------------------------------
class TQDMProgressMonitor(trt.IProgressMonitor):
    """Renders TensorRT build progress as nested tqdm bars.

    TensorRT calls phase_start/step_complete/phase_finish during
    build_serialized_network; returning False from step_complete asks
    TensorRT to cancel the build (used here on KeyboardInterrupt).
    """

    def __init__(self):
        trt.IProgressMonitor.__init__(self)
        # phase_name -> {"tq": tqdm bar, "nbIndents": nesting level,
        #                "parent_phase": name of enclosing phase or None}
        self._active_phases = {}
        # Returned from step_complete; flipped to False to cancel the build.
        self._step_result = True
        # Phases nested deeper than this get no bar at all.
        self.max_indent = 5

    def phase_start(self, phase_name, parent_phase, num_steps):
        # Only the top-level bar is left on screen when finished.
        leave = False
        try:
            if parent_phase is not None:
                # Unknown parents default to max_indent, so their children
                # fall past the cap below and are simply not tracked.
                nbIndents = (
                    self._active_phases.get(parent_phase, {}).get(
                        "nbIndents", self.max_indent
                    )
                    + 1
                )
                if nbIndents >= self.max_indent:
                    return
            else:
                nbIndents = 0
                leave = True
            self._active_phases[phase_name] = {
                "tq": tqdm(
                    total=num_steps, desc=phase_name, leave=leave, position=nbIndents
                ),
                "nbIndents": nbIndents,
                "parent_phase": parent_phase,
            }
        except KeyboardInterrupt:
            # The phase_start callback cannot directly cancel the build,
            # so request the cancellation from within step_complete.
            self._step_result = False

    def phase_finish(self, phase_name):
        try:
            if phase_name in self._active_phases.keys():
                # Drive the bar to 100% regardless of how many steps ran.
                self._active_phases[phase_name]["tq"].update(
                    self._active_phases[phase_name]["tq"].total
                    - self._active_phases[phase_name]["tq"].n
                )
                # Repaint every ancestor bar so the nested display stays intact.
                parent_phase = self._active_phases[phase_name].get("parent_phase", None)
                while parent_phase is not None:
                    self._active_phases[parent_phase]["tq"].refresh()
                    parent_phase = self._active_phases[parent_phase].get(
                        "parent_phase", None
                    )
                # NOTE(review): this refreshes the direct parent a second time
                # (the while-loop above already covered it) — presumably
                # harmless belt-and-braces; kept as-is.
                if (
                    self._active_phases[phase_name]["parent_phase"]
                    in self._active_phases.keys()
                ):
                    self._active_phases[
                        self._active_phases[phase_name]["parent_phase"]
                    ]["tq"].refresh()
                del self._active_phases[phase_name]
        except KeyboardInterrupt:
            self._step_result = False

    def step_complete(self, phase_name, step):
        try:
            if phase_name in self._active_phases.keys():
                # tqdm.update takes a delta, TensorRT reports an absolute step.
                self._active_phases[phase_name]["tq"].update(
                    step - self._active_phases[phase_name]["tq"].n
                )
            return self._step_result
        except KeyboardInterrupt:
            # There is no need to propagate this exception to TensorRT.
            # We can simply cancel the build.
            return False
# -------------------------------------------------------------------------
# Base class for ONNX -> TensorRT conversion
# -------------------------------------------------------------------------
class TRT_MODEL_CONVERSION_BASE:
    """Shared ONNX-export + TensorRT-build logic for the conversion nodes.

    Subclasses provide INPUT_TYPES and a `convert` entry point that forwards
    to :meth:`_convert` with min/opt/max shape ranges (the static node passes
    the same value for min, opt and max of every axis).
    """

    RETURN_TYPES = ()
    FUNCTION = "convert"
    OUTPUT_NODE = True
    CATEGORY = "TensorRT"

    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()
        self.temp_dir = folder_paths.get_temp_directory()
        # The timing cache lives next to this file so repeated builds can
        # reuse kernel tactic timings across ComfyUI restarts.
        self.timing_cache_path = os.path.normpath(
            os.path.join(
                os.path.dirname(os.path.realpath(__file__)), "timing_cache.trt"
            )
        )

    @classmethod
    def INPUT_TYPES(s):
        raise NotImplementedError

    def _setup_timing_cache(self, config: trt.IBuilderConfig):
        """Attach a timing cache to `config`, seeding it from disk if present."""
        buffer = b""
        if os.path.exists(self.timing_cache_path):
            with open(self.timing_cache_path, mode="rb") as timing_cache_file:
                buffer = timing_cache_file.read()
            trtlog(f"Read {len(buffer)} bytes from timing cache.")
        else:
            trtlog("No timing cache found; initializing a new one.")
        timing_cache: trt.ITimingCache = config.create_timing_cache(buffer)
        config.set_timing_cache(timing_cache, ignore_mismatch=True)

    def _save_timing_cache(self, config: trt.IBuilderConfig):
        """Serialize `config`'s timing cache back to `self.timing_cache_path`."""
        timing_cache: trt.ITimingCache = config.get_timing_cache()
        with open(self.timing_cache_path, "wb") as timing_cache_file:
            serialized = timing_cache.serialize()
            timing_cache_file.write(memoryview(serialized))
        trtlog(f"Timing cache saved to {self.timing_cache_path}")

    def _convert(
        self,
        model,
        filename_prefix,
        batch_size_min,
        batch_size_opt,
        batch_size_max,
        height_min,
        height_opt,
        height_max,
        width_min,
        width_opt,
        width_max,
        context_min,
        context_opt,
        context_max,
        num_video_frames,
        is_static: bool,
    ):
        """Export the model's diffusion UNet to ONNX and build a TRT engine.

        heights/widths are image-space pixels (latents use size // 8);
        context_* multiply the per-model token length. Returns an empty tuple
        (output node); the serialized engine is written under the ComfyUI
        output directory.
        """
        trtlog(
            f"PyTorch version: {torch.__version__}, TensorRT version: {trt.__version__}"
        )
        trtlog(
            f"Requested {'STATIC' if is_static else 'DYNAMIC'} TensorRT engine "
            f"(b=[{batch_size_min},{batch_size_opt},{batch_size_max}], "
            f"h=[{height_min},{height_opt},{height_max}], "
            f"w=[{width_min},{width_opt},{width_max}], "
            f"context=[{context_min},{context_opt},{context_max}], "
            f"num_video_frames={num_video_frames})"
        )

        # BUGFIX: clamp an inverted context range BEFORE any shape tuples are
        # computed. Previously the swap happened after the min/opt/max shapes
        # were already built, so an inverted range still reached the export
        # shapes and the TensorRT optimization profile.
        if context_max < context_min:
            trtlog(
                f"WARNING: context_max({context_max}) < context_min({context_min}), swapping."
            )
            context_min, context_max = context_max, context_min

        output_onnx = os.path.normpath(
            os.path.join(
                os.path.join(self.temp_dir, "{}".format(time.time())), "model.onnx"
            )
        )
        trtlog(f"Temporary ONNX path: {output_onnx}")

        # -----------------------------------------------------------------
        # Load model to GPU so the exporter traces real weights
        # -----------------------------------------------------------------
        comfy.model_management.unload_all_models()
        comfy.model_management.load_models_gpu(
            [model], force_patch_weights=True, force_full_load=True
        )
        unet = model.model.diffusion_model
        model_type = type(model.model).__name__
        trtlog(f"Detected model type: {model_type}")

        context_dim = model.model.model_config.unet_config.get("context_dim", None)
        context_len = 77
        context_len_min = context_len
        y_dim = model.model.adm_channels
        extra_input = {}
        dtype = torch.float16

        # -----------------------------------------------------------------
        # Model-type specific tweaks (context size/length, extra inputs, dtype)
        # -----------------------------------------------------------------
        if isinstance(model.model, comfy.model_base.SD3):  # SD3
            context_embedder_config = model.model.model_config.unet_config.get(
                "context_embedder_config", None
            )
            if context_embedder_config is not None:
                context_dim = context_embedder_config.get("params", {}).get(
                    "in_features", None
                )
                # SD3 can have 77 or 154 depending on TE usage
                context_len = 154
            trtlog(f"SD3 context_dim={context_dim}, context_len={context_len}")
        elif isinstance(model.model, comfy.model_base.AuraFlow):
            context_dim = 2048
            context_len_min = 256
            context_len = 256
            trtlog(
                f"AuraFlow context_dim={context_dim}, "
                f"context_len_min={context_len_min}, context_len={context_len}"
            )
        elif isinstance(model.model, comfy.model_base.Flux):
            context_dim = model.model.model_config.unet_config.get(
                "context_in_dim", None
            )
            context_len_min = 256
            context_len = 256
            y_dim = model.model.model_config.unet_config.get("vec_in_dim", None)
            extra_input = {"guidance": ()}
            dtype = torch.bfloat16
            trtlog(
                f"Flux context_dim={context_dim}, y_dim={y_dim}, "
                f"context_len_min={context_len_min}, context_len={context_len}, "
                f"extra_input={list(extra_input.keys())}, dtype={dtype}"
            )

        if context_dim is None:
            # Unsupported architecture: release the GPU and bail out cleanly.
            print("ERROR: model not supported (no context_dim).")
            comfy.model_management.unload_all_models()
            comfy.model_management.soft_empty_cache()
            return ()

        input_names = ["x", "timesteps", "context"]
        output_names = ["h"]
        transformer_options = model.model_options["transformer_options"].copy()
        use_temporal = model.model.model_config.unet_config.get(
            "use_temporal_resblock", False
        )

        # -----------------------------------------------------------------
        # Wrap UNet so argument names are stable for dynamic_shapes
        # -----------------------------------------------------------------
        if use_temporal:  # SVD
            trtlog("Model uses temporal resblock (SVD-like). Adjusting batch sizes.")
            # SVD batches are (video frames x latent batches).
            batch_size_min = num_video_frames * batch_size_min
            batch_size_opt = num_video_frames * batch_size_opt
            batch_size_max = num_video_frames * batch_size_max

            class SVD_UNET(torch.nn.Module):
                def __init__(self, unet, transformer_options, num_video_frames):
                    super().__init__()
                    self.unet = unet
                    self.transformer_options = transformer_options
                    self.num_video_frames = num_video_frames

                def forward(self, x, timesteps, context, y):
                    return self.unet(
                        x,
                        timesteps,
                        context,
                        y,
                        num_video_frames=self.num_video_frames,
                        transformer_options=self.transformer_options,
                    )

            unet = SVD_UNET(unet, transformer_options, num_video_frames)
            context_len_min = context_len = 1
            trtlog(
                f"SVD adjusted batch: "
                f"b=[{batch_size_min},{batch_size_opt},{batch_size_max}], "
                f"context_len_min={context_len_min}, context_len={context_len}"
            )
        else:
            # Generic wrapper with named extras (y, guidance)
            extra_keys = list(extra_input.keys())

            class UNET(torch.nn.Module):
                def __init__(self, unet, transformer_options, y_dim, extra_keys):
                    super().__init__()
                    self.unet = unet
                    self.transformer_options = transformer_options
                    self.y_dim = y_dim
                    self.extra_keys = extra_keys

                def forward(self, x, timesteps, context, y=None, guidance=None):
                    extra_args = {}
                    if self.y_dim is not None and self.y_dim > 0 and y is not None:
                        extra_args["y"] = y
                    if "guidance" in self.extra_keys and guidance is not None:
                        extra_args["guidance"] = guidance
                    return self.unet(
                        x,
                        timesteps,
                        context,
                        transformer_options=self.transformer_options,
                        **extra_args,
                    )

            unet = UNET(unet, transformer_options, y_dim, extra_keys)

        # -----------------------------------------------------------------
        # Compute input shapes (min / opt / max), in input_names order
        # -----------------------------------------------------------------
        input_channels = model.model.model_config.unet_config.get("in_channels", 4)
        inputs_shapes_min = (
            (batch_size_min, input_channels, height_min // 8, width_min // 8),
            (batch_size_min,),
            (batch_size_min, context_len_min * context_min, context_dim),
        )
        inputs_shapes_opt = (
            (batch_size_opt, input_channels, height_opt // 8, width_opt // 8),
            (batch_size_opt,),
            (batch_size_opt, context_len * context_opt, context_dim),
        )
        inputs_shapes_max = (
            (batch_size_max, input_channels, height_max // 8, width_max // 8),
            (batch_size_max,),
            (batch_size_max, context_len * context_max, context_dim),
        )
        if y_dim is not None and y_dim > 0:
            input_names.append("y")
            inputs_shapes_min += ((batch_size_min, y_dim),)
            inputs_shapes_opt += ((batch_size_opt, y_dim),)
            inputs_shapes_max += ((batch_size_max, y_dim),)
        # Extra inputs (currently used for Flux guidance)
        for k in extra_input:
            input_names.append(k)
            shape_suffix = extra_input[k]  # e.g. () for scalar per batch
            inputs_shapes_min += ((batch_size_min,) + shape_suffix,)
            inputs_shapes_opt += ((batch_size_opt,) + shape_suffix,)
            inputs_shapes_max += ((batch_size_max,) + shape_suffix,)

        trtlog("Input names: " + ", ".join(input_names))
        for idx, name in enumerate(input_names):
            trtlog(
                f" {name}: "
                f"min={inputs_shapes_min[idx]}, "
                f"opt={inputs_shapes_opt[idx]}, "
                f"max={inputs_shapes_max[idx]}"
            )

        # -----------------------------------------------------------------
        # Build dynamic_shapes spec for torch.export / dynamo=True
        # - STATIC node: no dynamic_shapes at all (fully static export)
        # - DYNAMIC node: only create Dim if max > min
        # -----------------------------------------------------------------
        dynamic_shapes = None

        def _maybe_dim(name: str, min_v: int, max_v: int):
            """Create Dim only if there is real dynamism (max > min)."""
            if max_v < min_v:
                # BUGFIX: the old warning text was garbled ("min>{max}>{min}").
                trtlog(
                    f"WARNING: Dim {name} has min({min_v}) > max({max_v}), swapping to fix."
                )
                min_v, max_v = max_v, min_v
            if max_v > min_v:
                trtlog(f"Dim {name}: dynamic range [{min_v}, {max_v}]")
                return Dim(name, min=min_v, max=max_v)
            else:
                trtlog(f"Dim {name}: static value {min_v}, not using Dim.")
                return None

        if not is_static:
            # Only build dynamic_shapes for the DYNAMIC node
            B = _maybe_dim("batch", batch_size_min, batch_size_max)
            H = _maybe_dim("height", height_min // 8, height_max // 8)
            W = _maybe_dim("width", width_min // 8, width_max // 8)
            tokens_min = context_len_min * context_min
            tokens_max = context_len * context_max
            T = _maybe_dim("tokens", tokens_min, tokens_max)

            dynamic_shapes = {}
            # x: [B, C, H, W]
            x_dyn = {}
            if B is not None:
                x_dyn[0] = B
            if H is not None:
                x_dyn[2] = H
            if W is not None:
                x_dyn[3] = W
            if x_dyn:
                dynamic_shapes["x"] = x_dyn
            # timesteps: [B]
            if B is not None:
                dynamic_shapes["timesteps"] = {0: B}
            # context: [B, T, context_dim]
            ctx_dyn = {}
            if B is not None:
                ctx_dyn[0] = B
            if T is not None:
                ctx_dyn[1] = T
            if ctx_dyn:
                dynamic_shapes["context"] = ctx_dyn
            # y: [B, y_dim]
            if "y" in input_names and B is not None:
                dynamic_shapes["y"] = {0: B}
            # guidance: [B, ...]
            if "guidance" in input_names and B is not None:
                dynamic_shapes["guidance"] = {0: B}

            if not dynamic_shapes:
                trtlog(
                    "No dimensions are actually dynamic for DYNAMIC node. "
                    "Export will effectively be static."
                )
                dynamic_shapes = None
            else:
                trtlog(f"dynamic_shapes spec: {dynamic_shapes}")
        else:
            trtlog("STATIC node: skipping torch.export.Dim and dynamic_shapes entirely.")

        # -----------------------------------------------------------------
        # Build example inputs (using OPT shapes)
        # -----------------------------------------------------------------
        inputs = ()
        for shape in inputs_shapes_opt:
            inputs += (
                torch.zeros(
                    shape,
                    device=comfy.model_management.get_torch_device(),
                    dtype=dtype,
                ),
            )

        # -----------------------------------------------------------------
        # ONNX export with Dynamo (dynamo=True)
        # - For static: dynamic_shapes=None, so shapes are fully specialized.
        # - For dynamic: dynamic_shapes guides symbolic shapes.
        # -----------------------------------------------------------------
        os.makedirs(os.path.dirname(output_onnx), exist_ok=True)
        trtlog(
            f"Exporting UNet to ONNX with dynamo=True, "
            f"opset_version={DEFAULT_ONNX_OPSET}, dtype={dtype}, "
            f"output={output_onnx}"
        )
        if dynamic_shapes is None:
            trtlog("ONNX export will be STATIC (no dynamic_shapes).")
        else:
            trtlog("ONNX export will use dynamic_shapes (see spec above).")
        try:
            torch.onnx.export(
                unet,
                inputs,
                output_onnx,
                verbose=False,
                input_names=input_names,
                output_names=output_names,
                opset_version=DEFAULT_ONNX_OPSET,
                dynamo=True,
                dynamic_shapes=dynamic_shapes,
                # NOTE:
                # - We intentionally do NOT pass dynamic_axes here.
                #   dynamic_axes is for the legacy TorchScript exporter.
            )
            trtlog("torch.onnx.export completed successfully.")
        except Exception as e:
            trtlog(f"ERROR during torch.onnx.export: {e}")
            # Clean up GPU state before re-raising
            comfy.model_management.unload_all_models()
            comfy.model_management.soft_empty_cache()
            raise
        comfy.model_management.unload_all_models()
        comfy.model_management.soft_empty_cache()

        # -----------------------------------------------------------------
        # TensorRT conversion starts here
        # -----------------------------------------------------------------
        logger = trt.Logger(trt.Logger.INFO)
        builder = trt.Builder(logger)
        trtlog("Created TensorRT builder.")
        network = builder.create_network(
            1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        )
        parser = trt.OnnxParser(network, logger)
        trtlog(f"Parsing ONNX file: {output_onnx}")
        success = parser.parse_from_file(output_onnx)
        for idx in range(parser.num_errors):
            print(parser.get_error(idx))
        if not success:
            print("ONNX load ERROR (TensorRT parser.parse_from_file returned False).")
            return ()

        config = builder.create_builder_config()
        profile = builder.create_optimization_profile()
        self._setup_timing_cache(config)
        config.progress_monitor = TQDMProgressMonitor()

        trtlog("Creating optimization profile:")
        # NOTE: the old per-input shape string encoding (prefix_encode) was
        # dead code — never written to the filename or metadata — and has
        # been removed.
        for name, min_shape, opt_shape, max_shape in zip(
            input_names, inputs_shapes_min, inputs_shapes_opt, inputs_shapes_max
        ):
            trtlog(f" {name}: min={min_shape}, opt={opt_shape}, max={max_shape}")
            profile.set_shape(name, min_shape, opt_shape, max_shape)

        if dtype == torch.float16:
            trtlog("Enabling FP16 mode in TensorRT builder config.")
            config.set_flag(trt.BuilderFlag.FP16)
        if dtype == torch.bfloat16:
            trtlog("Enabling BF16 mode in TensorRT builder config.")
            config.set_flag(trt.BuilderFlag.BF16)
        config.add_optimization_profile(profile)

        # Encode the requested shapes into the output filename.
        if is_static:
            filename_prefix = "{}_${}".format(
                filename_prefix,
                "-".join(
                    (
                        "stat",
                        "b",
                        str(batch_size_opt),
                        "h",
                        str(height_opt),
                        "w",
                        str(width_opt),
                    )
                ),
            )
        else:
            filename_prefix = "{}_${}".format(
                filename_prefix,
                "-".join(
                    (
                        "dyn",
                        "b",
                        str(batch_size_min),
                        str(batch_size_max),
                        str(batch_size_opt),
                        "h",
                        str(height_min),
                        str(height_max),
                        str(height_opt),
                        "w",
                        str(width_min),
                        str(width_max),
                        str(width_opt),
                    )
                ),
            )

        trtlog("Building serialized TensorRT engine. This may take a while...")
        serialized_engine = builder.build_serialized_network(network, config)
        if serialized_engine is None:
            trtlog("ERROR: builder.build_serialized_network returned None.")
            return ()

        full_output_folder, filename, counter, subfolder, filename_prefix = (
            folder_paths.get_save_image_path(filename_prefix, self.output_dir)
        )
        # BUGFIX: use the `filename` returned by get_save_image_path. The old
        # code embedded the literal string "(unknown)" so every engine was
        # written as "(unknown)_00001_.engine" regardless of filename_prefix.
        output_trt_engine = os.path.join(
            full_output_folder, f"{filename}_{counter:05}_.engine"
        )
        trtlog(f"Writing TensorRT engine to: {output_trt_engine}")
        os.makedirs(full_output_folder, exist_ok=True)
        with open(output_trt_engine, "wb") as f:
            f.write(serialized_engine)

        self._save_timing_cache(config)
        trtlog("TensorRT conversion complete.")
        return ()
# -------------------------------------------------------------------------
# Dynamic / Static wrapper nodes
# -------------------------------------------------------------------------
class DYNAMIC_TRT_MODEL_CONVERSION(TRT_MODEL_CONVERSION_BASE):
    """ComfyUI node: build a TensorRT engine with dynamic shape ranges."""

    def __init__(self):
        super().__init__()

    @classmethod
    def INPUT_TYPES(s):
        def int_widget(default, lo, hi, step):
            # Shorthand for a ComfyUI INT widget spec.
            return ("INT", {"default": default, "min": lo, "max": hi, "step": step})

        required = {
            "model": ("MODEL",),
            "filename_prefix": ("STRING", {"default": "tensorrt/ComfyUI_DYN"}),
        }
        # min/opt/max triplets for every dynamic axis, in UI display order.
        for axis in ("min", "opt", "max"):
            required[f"batch_size_{axis}"] = int_widget(1, 1, 100, 1)
        for axis in ("min", "opt", "max"):
            required[f"height_{axis}"] = int_widget(512, 256, 4096, 64)
        for axis in ("min", "opt", "max"):
            required[f"width_{axis}"] = int_widget(512, 256, 4096, 64)
        for axis in ("min", "opt", "max"):
            required[f"context_{axis}"] = int_widget(1, 1, 128, 1)
        required["num_video_frames"] = int_widget(14, 0, 1000, 1)
        return {"required": required}

    def convert(
        self,
        model,
        filename_prefix,
        batch_size_min,
        batch_size_opt,
        batch_size_max,
        height_min,
        height_opt,
        height_max,
        width_min,
        width_opt,
        width_max,
        context_min,
        context_opt,
        context_max,
        num_video_frames,
    ):
        # Delegate to the shared conversion path with dynamic ranges.
        return self._convert(
            model,
            filename_prefix,
            batch_size_min,
            batch_size_opt,
            batch_size_max,
            height_min,
            height_opt,
            height_max,
            width_min,
            width_opt,
            width_max,
            context_min,
            context_opt,
            context_max,
            num_video_frames,
            is_static=False,
        )
class STATIC_TRT_MODEL_CONVERSION(TRT_MODEL_CONVERSION_BASE):
    """ComfyUI node: build a TensorRT engine with fully static shapes."""

    def __init__(self):
        super().__init__()

    @classmethod
    def INPUT_TYPES(s):
        def int_widget(default, lo, hi, step):
            # Shorthand for a ComfyUI INT widget spec.
            return ("INT", {"default": default, "min": lo, "max": hi, "step": step})

        return {
            "required": {
                "model": ("MODEL",),
                "filename_prefix": ("STRING", {"default": "tensorrt/ComfyUI_STAT"}),
                "batch_size_opt": int_widget(1, 1, 100, 1),
                "height_opt": int_widget(512, 256, 4096, 64),
                "width_opt": int_widget(512, 256, 4096, 64),
                "context_opt": int_widget(1, 1, 128, 1),
                "num_video_frames": int_widget(14, 0, 1000, 1),
            },
        }

    def convert(
        self,
        model,
        filename_prefix,
        batch_size_opt,
        height_opt,
        width_opt,
        context_opt,
        num_video_frames,
    ):
        # STATIC: min == opt == max for every axis.
        return self._convert(
            model,
            filename_prefix,
            *([batch_size_opt] * 3),
            *([height_opt] * 3),
            *([width_opt] * 3),
            *([context_opt] * 3),
            num_video_frames,
            is_static=True,
        )
# Registry consumed by ComfyUI to discover the nodes in this module;
# keys equal the class names, as before.
NODE_CLASS_MAPPINGS = {
    node_cls.__name__: node_cls
    for node_cls in (DYNAMIC_TRT_MODEL_CONVERSION, STATIC_TRT_MODEL_CONVERSION)
}