|
|
import torch
|
|
|
import sys
|
|
|
import os
|
|
|
import time
|
|
|
import comfy.model_management
|
|
|
|
|
|
import tensorrt as trt
|
|
|
import folder_paths
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Make "<output dir>/tensorrt" a search location for ".engine" files under the
# "tensorrt" folder key, creating the key when nothing has registered it yet.
_trt_engine_dir = os.path.join(folder_paths.get_output_directory(), "tensorrt")
if "tensorrt" not in folder_paths.folder_names_and_paths:
    folder_paths.folder_names_and_paths["tensorrt"] = (
        [_trt_engine_dir],
        {".engine"},
    )
else:
    # Entry layout as used here: (list of directories, set of file extensions)
    _trt_paths, _trt_extensions = folder_paths.folder_names_and_paths["tensorrt"][:2]
    _trt_paths.append(_trt_engine_dir)
    _trt_extensions.add(".engine")
|
|
|
|
|
|
class TQDMProgressMonitor(trt.IProgressMonitor):
    """Progress monitor that shows TensorRT build phases as nested tqdm bars.

    Every phase announced through ``phase_start`` gets its own ``tqdm`` bar,
    placed one row below its parent's bar.  Phases whose nesting depth would
    reach ``max_indent`` (or whose parent is not being tracked) are not shown.
    """

    def __init__(self):
        trt.IProgressMonitor.__init__(self)
        # phase name -> {"tq": tqdm bar, "nbIndents": depth, "parent_phase": name or None}
        self._active_phases = {}
        # Value handed back by step_complete(); flipped to False when a
        # KeyboardInterrupt is caught inside phase_start()/phase_finish().
        self._step_result = True
        # Nesting depth at which phases stop being displayed.
        self.max_indent = 5

    def phase_start(self, phase_name, parent_phase, num_steps):
        """Create a progress bar for a newly started phase.

        Args:
            phase_name: Name of the phase; used as the bar description and as
                the key under which the bar is tracked.
            parent_phase: Name of the enclosing phase, or ``None`` for a
                top-level phase.
            num_steps: Total number of steps, used as the bar's ``total``.
        """
        leave = False
        try:
            if parent_phase is not None:
                # An untracked parent defaults to max_indent, which makes the
                # child exceed the limit and be skipped as well.
                nbIndents = (
                    self._active_phases.get(parent_phase, {}).get(
                        "nbIndents", self.max_indent
                    )
                    + 1
                )
                if nbIndents >= self.max_indent:
                    return
            else:
                nbIndents = 0
                # Only top-level bars stay on screen after they finish.
                leave = True
            self._active_phases[phase_name] = {
                "tq": tqdm(
                    total=num_steps, desc=phase_name, leave=leave, position=nbIndents
                ),
                "nbIndents": nbIndents,
                "parent_phase": parent_phase,
            }
        except KeyboardInterrupt:
            # This callback returns nothing, so remember the interrupt on the
            # instance; step_complete() reports it on its next call.
            self._step_result = False

    def phase_finish(self, phase_name):
        """Complete and remove the bar of a finished phase.

        The bar is advanced to its total, every tracked ancestor bar is
        redrawn, and the phase is then dropped from the tracked set.  Phases
        that were never displayed are ignored.
        """
        try:
            entry = self._active_phases.get(phase_name)
            if entry is None:
                return

            bar = entry["tq"]
            bar.update(bar.total - bar.n)

            # Redraw the chain of ancestors; stop at the first one that is not
            # tracked (any more) instead of raising KeyError.
            parent_phase = entry.get("parent_phase", None)
            while parent_phase is not None and parent_phase in self._active_phases:
                parent = self._active_phases[parent_phase]
                parent["tq"].refresh()
                parent_phase = parent.get("parent_phase", None)

            # Close explicitly rather than relying on garbage collection.
            bar.close()
            del self._active_phases[phase_name]
        except KeyboardInterrupt:
            # Same reasoning as in phase_start(): record the interrupt so that
            # step_complete() can report it.
            self._step_result = False

    def step_complete(self, phase_name, step):
        """Advance the bar of ``phase_name`` to ``step``.

        Returns:
            ``self._step_result`` (``True`` unless an interrupt was recorded by
            another callback), or ``False`` when a KeyboardInterrupt is caught
            here.  Per the TensorRT IProgressMonitor documentation, ``False``
            asks the builder to stop.
        """
        try:
            entry = self._active_phases.get(phase_name)
            if entry is not None:
                bar = entry["tq"]
                bar.update(step - bar.n)
            return self._step_result
        except KeyboardInterrupt:
            return False
|
|
|
|
|
|
|
|
|
class TRT_MODEL_CONVERSION_BASE:
    """Shared logic for nodes that export a model to ONNX and build a TensorRT engine.

    Subclasses supply ``INPUT_TYPES`` and a ``convert`` method (the name given
    in ``FUNCTION``) that forwards its arguments to ``_convert``.
    """

    RETURN_TYPES = ()
    FUNCTION = "convert"
    OUTPUT_NODE = True
    CATEGORY = "TensorRT"

    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()
        self.temp_dir = folder_paths.get_temp_directory()
        # The timing cache lives next to this source file.
        self.timing_cache_path = os.path.normpath(
            os.path.join(os.path.dirname(os.path.realpath(__file__)), "timing_cache.trt")
        )

    @classmethod
    def INPUT_TYPES(s):
        """Must be overridden by subclasses; the base class has no inputs."""
        raise NotImplementedError

    def _setup_timing_cache(self, config: trt.IBuilderConfig):
        """Attach a timing cache to ``config``.

        The cache is seeded from ``self.timing_cache_path`` when that file
        exists, otherwise it is created from an empty buffer.
        """
        buffer = b""
        if os.path.exists(self.timing_cache_path):
            with open(self.timing_cache_path, mode="rb") as timing_cache_file:
                buffer = timing_cache_file.read()
            print("Read {} bytes from timing cache.".format(len(buffer)))
        else:
            print("No timing cache found; Initializing a new one.")
        timing_cache: trt.ITimingCache = config.create_timing_cache(buffer)
        config.set_timing_cache(timing_cache, ignore_mismatch=True)

    def _save_timing_cache(self, config: trt.IBuilderConfig):
        """Serialize the timing cache of ``config`` to ``self.timing_cache_path``."""
        timing_cache: trt.ITimingCache = config.get_timing_cache()
        with open(self.timing_cache_path, "wb") as timing_cache_file:
            timing_cache_file.write(memoryview(timing_cache.serialize()))

    def _convert(
        self,
        model,
        filename_prefix,
        batch_size_min,
        batch_size_opt,
        batch_size_max,
        height_min,
        height_opt,
        height_max,
        width_min,
        width_opt,
        width_max,
        context_min,
        context_opt,
        context_max,
        num_video_frames,
        is_static: bool,
    ):
        """Export ``model`` to ONNX, build a TensorRT engine and save it.

        Args:
            model: Model wrapper; this method reads ``model.model`` and
                ``model.model_options`` from it.
            filename_prefix: Prefix of the engine file; a tag describing the
                shape profile is appended to it.
            batch_size_min, batch_size_opt, batch_size_max: Batch sizes of the
                optimization profile.
            height_min, height_opt, height_max: Heights of the profile; they
                are divided by 8 to get the size of dimension 2 of ``x``.
            width_min, width_opt, width_max: Same for widths and dimension 3.
            context_min, context_opt, context_max: Multipliers applied to the
                per-model context length.
            num_video_frames: Frame count, only used when the model config has
                ``use_temporal_resblock`` set.
            is_static: Selects the "stat" or "dyn" file name tag.

        Returns:
            Always the empty tuple.  Unsupported models, ONNX parse errors and
            engine build failures are reported with ``print`` and no engine
            file is written.
        """
        # Timestamp-named temp sub-directory for the intermediate ONNX file.
        output_onnx = os.path.normpath(
            os.path.join(self.temp_dir, "{}".format(time.time()), "model.onnx")
        )

        comfy.model_management.unload_all_models()
        comfy.model_management.load_models_gpu([model], force_patch_weights=True, force_full_load=True)
        unet = model.model.diffusion_model
        unet_config = model.model.model_config.unet_config

        # Defaults; the model-specific branches below override them.
        context_dim = unet_config.get("context_dim", None)
        context_len = 77
        context_len_min = context_len
        y_dim = model.model.adm_channels
        extra_input = {}
        dtype = torch.float16

        if isinstance(model.model, comfy.model_base.SD3):
            context_embedder_config = unet_config.get("context_embedder_config", None)
            if context_embedder_config is not None:
                context_dim = context_embedder_config.get("params", {}).get("in_features", None)
                # Only opt/max use 154; context_len_min stays at 77.
                context_len = 154
        elif isinstance(model.model, comfy.model_base.AuraFlow):
            context_dim = 2048
            context_len_min = 256
            context_len = 256
        elif isinstance(model.model, comfy.model_base.Flux):
            context_dim = unet_config.get("context_in_dim", None)
            context_len_min = 256
            context_len = 256
            y_dim = unet_config.get("vec_in_dim", None)
            # Extra per-batch input with no further dimensions.
            extra_input = {"guidance": ()}
            dtype = torch.bfloat16

        if context_dim is None:
            print("ERROR: model not supported.")
            return ()

        input_names = ["x", "timesteps", "context"]
        output_names = ["h"]

        dynamic_axes = {
            "x": {0: "batch", 2: "height", 3: "width"},
            "timesteps": {0: "batch"},
            "context": {0: "batch", 1: "num_embeds"},
        }

        transformer_options = model.model_options['transformer_options'].copy()
        if unet_config.get("use_temporal_resblock", False):
            # The batch dimension carries one entry per video frame.
            batch_size_min = num_video_frames * batch_size_min
            batch_size_opt = num_video_frames * batch_size_opt
            batch_size_max = num_video_frames * batch_size_max

            class UNET(torch.nn.Module):
                """Wrapper that fixes num_video_frames and transformer_options."""

                def forward(self, x, timesteps, context, y):
                    return self.unet(
                        x,
                        timesteps,
                        context,
                        y,
                        num_video_frames=self.num_video_frames,
                        transformer_options=self.transformer_options,
                    )

            svd_unet = UNET()
            svd_unet.num_video_frames = num_video_frames
            svd_unet.unet = unet
            svd_unet.transformer_options = transformer_options
            unet = svd_unet
            context_len_min = context_len = 1
        else:

            class UNET(torch.nn.Module):
                """Wrapper that maps positional extras onto keyword arguments."""

                def forward(self, x, timesteps, context, *args):
                    # input_names is read at call time, so it already contains
                    # the "y"/extra names appended further down.
                    extra_args = dict(zip(input_names[3:], args))
                    return self.unet(x, timesteps, context, transformer_options=self.transformer_options, **extra_args)

            _unet = UNET()
            _unet.unet = unet
            _unet.transformer_options = transformer_options
            unet = _unet

        input_channels = unet_config.get("in_channels", 4)

        # One shape per entry of input_names, in the same order.
        inputs_shapes_min = (
            (batch_size_min, input_channels, height_min // 8, width_min // 8),
            (batch_size_min,),
            (batch_size_min, context_len_min * context_min, context_dim),
        )
        inputs_shapes_opt = (
            (batch_size_opt, input_channels, height_opt // 8, width_opt // 8),
            (batch_size_opt,),
            (batch_size_opt, context_len * context_opt, context_dim),
        )
        inputs_shapes_max = (
            (batch_size_max, input_channels, height_max // 8, width_max // 8),
            (batch_size_max,),
            (batch_size_max, context_len * context_max, context_dim),
        )

        # y_dim may be None when the Flux config has no "vec_in_dim"; treat
        # that like 0 instead of failing on the comparison.
        if y_dim is not None and y_dim > 0:
            input_names.append("y")
            dynamic_axes["y"] = {0: "batch"}
            inputs_shapes_min += ((batch_size_min, y_dim),)
            inputs_shapes_opt += ((batch_size_opt, y_dim),)
            inputs_shapes_max += ((batch_size_max, y_dim),)

        for name, trailing_dims in extra_input.items():
            input_names.append(name)
            dynamic_axes[name] = {0: "batch"}
            inputs_shapes_min += ((batch_size_min,) + trailing_dims,)
            inputs_shapes_opt += ((batch_size_opt,) + trailing_dims,)
            inputs_shapes_max += ((batch_size_max,) + trailing_dims,)

        # Dummy inputs for the export, built from the "opt" shapes.
        device = comfy.model_management.get_torch_device()
        inputs = tuple(
            torch.zeros(shape, device=device, dtype=dtype)
            for shape in inputs_shapes_opt
        )

        os.makedirs(os.path.dirname(output_onnx), exist_ok=True)
        torch.onnx.export(
            unet,
            inputs,
            output_onnx,
            verbose=False,
            input_names=input_names,
            output_names=output_names,
            opset_version=17,
            dynamic_axes=dynamic_axes,
        )

        # Unload everything and empty the cache before building the engine.
        comfy.model_management.unload_all_models()
        comfy.model_management.soft_empty_cache()

        logger = trt.Logger(trt.Logger.INFO)
        builder = trt.Builder(logger)

        network = builder.create_network(
            1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        )
        parser = trt.OnnxParser(network, logger)
        success = parser.parse_from_file(output_onnx)
        for idx in range(parser.num_errors):
            print(parser.get_error(idx))

        if not success:
            print("ONNX load ERROR")
            return ()

        config = builder.create_builder_config()
        profile = builder.create_optimization_profile()
        self._setup_timing_cache(config)
        config.progress_monitor = TQDMProgressMonitor()

        for name, min_shape, opt_shape, max_shape in zip(
            input_names, inputs_shapes_min, inputs_shapes_opt, inputs_shapes_max
        ):
            profile.set_shape(name, min_shape, opt_shape, max_shape)

        if dtype == torch.float16:
            config.set_flag(trt.BuilderFlag.FP16)
        if dtype == torch.bfloat16:
            config.set_flag(trt.BuilderFlag.BF16)

        config.add_optimization_profile(profile)

        # Tag the file name with the shape profile the engine was built for.
        if is_static:
            shape_tag = (
                "stat",
                "b",
                str(batch_size_opt),
                "h",
                str(height_opt),
                "w",
                str(width_opt),
            )
        else:
            shape_tag = (
                "dyn",
                "b",
                str(batch_size_min),
                str(batch_size_max),
                str(batch_size_opt),
                "h",
                str(height_min),
                str(height_max),
                str(height_opt),
                "w",
                str(width_min),
                str(width_max),
                str(width_opt),
            )
        filename_prefix = "{}_${}".format(filename_prefix, "-".join(shape_tag))

        serialized_engine = builder.build_serialized_network(network, config)
        if serialized_engine is None:
            # The builder returns None on failure; writing it would raise a
            # TypeError, so report the failure and write nothing.
            print("ERROR: TensorRT engine build failed.")
            return ()

        full_output_folder, filename, counter, subfolder, filename_prefix = (
            folder_paths.get_save_image_path(filename_prefix, self.output_dir)
        )
        output_trt_engine = os.path.join(
            full_output_folder, f"{filename}_{counter:05}_.engine"
        )

        with open(output_trt_engine, "wb") as f:
            f.write(serialized_engine)

        self._save_timing_cache(config)

        return ()
|
|
|
|
|
|
|
|
|
class DYNAMIC_TRT_MODEL_CONVERSION(TRT_MODEL_CONVERSION_BASE):
    """Conversion node exposing min/opt/max values for batch, height, width and context."""

    def __init__(self):
        super().__init__()

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs: model, file prefix and the INT ranges."""

        def int_input(default, lowest, highest, step):
            # Fresh spec per call so no widget shares its options dict.
            return (
                "INT",
                {"default": default, "min": lowest, "max": highest, "step": step},
            )

        required = {
            "model": ("MODEL",),
            "filename_prefix": ("STRING", {"default": "tensorrt/ComfyUI_DYN"}),
        }
        # Insertion order below fixes the order of the entries in the dict.
        for name in ("batch_size_min", "batch_size_opt", "batch_size_max"):
            required[name] = int_input(1, 1, 100, 1)
        for name in (
            "height_min",
            "height_opt",
            "height_max",
            "width_min",
            "width_opt",
            "width_max",
        ):
            required[name] = int_input(512, 256, 4096, 64)
        for name in ("context_min", "context_opt", "context_max"):
            required[name] = int_input(1, 1, 128, 1)
        required["num_video_frames"] = int_input(14, 0, 1000, 1)

        return {"required": required}

    def convert(
        self,
        model,
        filename_prefix,
        batch_size_min,
        batch_size_opt,
        batch_size_max,
        height_min,
        height_opt,
        height_max,
        width_min,
        width_opt,
        width_max,
        context_min,
        context_opt,
        context_max,
        num_video_frames,
    ):
        """Forward all inputs unchanged to ``_convert`` with ``is_static=False``."""
        return super()._convert(
            model=model,
            filename_prefix=filename_prefix,
            batch_size_min=batch_size_min,
            batch_size_opt=batch_size_opt,
            batch_size_max=batch_size_max,
            height_min=height_min,
            height_opt=height_opt,
            height_max=height_max,
            width_min=width_min,
            width_opt=width_opt,
            width_max=width_max,
            context_min=context_min,
            context_opt=context_opt,
            context_max=context_max,
            num_video_frames=num_video_frames,
            is_static=False,
        )
|
|
|
|
|
|
|
|
|
class STATIC_TRT_MODEL_CONVERSION(TRT_MODEL_CONVERSION_BASE):
    """Conversion node taking a single value each for batch, height, width and context."""

    def __init__(self):
        super().__init__()

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node inputs: model, file prefix and the INT values."""

        def int_input(default, lowest, highest, step):
            # Fresh spec per call so no widget shares its options dict.
            return (
                "INT",
                {"default": default, "min": lowest, "max": highest, "step": step},
            )

        required = {
            "model": ("MODEL",),
            "filename_prefix": ("STRING", {"default": "tensorrt/ComfyUI_STAT"}),
            "batch_size_opt": int_input(1, 1, 100, 1),
            "height_opt": int_input(512, 256, 4096, 64),
            "width_opt": int_input(512, 256, 4096, 64),
            "context_opt": int_input(1, 1, 128, 1),
            "num_video_frames": int_input(14, 0, 1000, 1),
        }
        return {"required": required}

    def convert(
        self,
        model,
        filename_prefix,
        batch_size_opt,
        height_opt,
        width_opt,
        context_opt,
        num_video_frames,
    ):
        """Call ``_convert`` with each value used as min, opt and max, and ``is_static=True``."""
        return super()._convert(
            model=model,
            filename_prefix=filename_prefix,
            batch_size_min=batch_size_opt,
            batch_size_opt=batch_size_opt,
            batch_size_max=batch_size_opt,
            height_min=height_opt,
            height_opt=height_opt,
            height_max=height_opt,
            width_min=width_opt,
            width_opt=width_opt,
            width_max=width_opt,
            context_min=context_opt,
            context_opt=context_opt,
            context_max=context_opt,
            num_video_frames=num_video_frames,
            is_static=True,
        )
|
|
|
|
|
|
|
|
|
# Maps each node identifier to the class that implements it.
NODE_CLASS_MAPPINGS = dict(
    DYNAMIC_TRT_MODEL_CONVERSION=DYNAMIC_TRT_MODEL_CONVERSION,
    STATIC_TRT_MODEL_CONVERSION=STATIC_TRT_MODEL_CONVERSION,
)
|
|
|
|