|
|
import os
import sys
import time

import tensorrt as trt
import torch
from tqdm import tqdm

import comfy.model_base
import comfy.model_management
import folder_paths
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
from torch.export import Dim |
|
|
except Exception as e: |
|
|
raise RuntimeError( |
|
|
"[TensorRTExport] torch.export.Dim not available. " |
|
|
"Please upgrade PyTorch to >= 2.1 / 2.5+ to use the Dynamo-based " |
|
|
"ONNX exporter with dynamic shapes." |
|
|
) from e |
|
|
|
|
|
|
|
|
def trtlog(msg: str):
    """Print *msg* with the TensorRTExport tag, flushing immediately."""
    print("[TensorRTExport] " + msg, flush=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Optional opset override via the environment; None lets PyTorch pick the
# recommended opset for the installed version.
DEFAULT_ONNX_OPSET = None
_env_opset = os.getenv("COMFY_TRT_ONNX_OPSET")
if _env_opset is not None:
    try:
        DEFAULT_ONNX_OPSET = int(_env_opset)
    except ValueError:
        trtlog(
            f"WARNING: invalid COMFY_TRT_ONNX_OPSET={_env_opset!r}, "
            "falling back to PyTorch recommended opset (None)."
        )
        DEFAULT_ONNX_OPSET = None
    else:
        trtlog(f"Using opset_version from COMFY_TRT_ONNX_OPSET={DEFAULT_ONNX_OPSET}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Register the "tensorrt" model category (search path under the output dir,
# ".engine" extension), extending any registration made by another node pack.
_trt_output_path = os.path.join(folder_paths.get_output_directory(), "tensorrt")
_existing = folder_paths.folder_names_and_paths.get("tensorrt")
if _existing is not None:
    _existing[0].append(_trt_output_path)
    _existing[1].add(".engine")
else:
    folder_paths.folder_names_and_paths["tensorrt"] = (
        [_trt_output_path],
        {".engine"},
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TQDMProgressMonitor(trt.IProgressMonitor):
    """Bridge TensorRT engine-build progress callbacks to nested tqdm bars.

    TensorRT invokes phase_start/step_complete/phase_finish while building;
    each phase gets its own tqdm bar positioned by nesting depth. Returning
    False from step_complete asks TensorRT to cancel the build, which is how
    a KeyboardInterrupt (Ctrl-C) is honored here.
    """

    def __init__(self):
        trt.IProgressMonitor.__init__(self)
        # phase_name -> {"tq": tqdm bar, "nbIndents": depth, "parent_phase": name}
        self._active_phases = {}
        # Flipped to False on KeyboardInterrupt; reported back via step_complete.
        self._step_result = True
        # Phases nested at or beyond this depth get no progress bar.
        self.max_indent = 5

    def phase_start(self, phase_name, parent_phase, num_steps):
        """Open a tqdm bar for a newly started build phase.

        Deeply nested phases (depth >= max_indent) are silently skipped; a
        parent that was itself skipped is treated as being at max depth, so
        its children are suppressed as well.
        """
        leave = False
        try:
            if parent_phase is not None:
                nbIndents = (
                    self._active_phases.get(parent_phase, {}).get(
                        "nbIndents", self.max_indent
                    )
                    + 1
                )
                if nbIndents >= self.max_indent:
                    return
            else:
                # Top-level phase: bar stays on screen after completion.
                nbIndents = 0
                leave = True
            self._active_phases[phase_name] = {
                "tq": tqdm(
                    total=num_steps, desc=phase_name, leave=leave, position=nbIndents
                ),
                "nbIndents": nbIndents,
                "parent_phase": parent_phase,
            }
        except KeyboardInterrupt:
            # phase_start cannot cancel the build directly; record the wish
            # and let step_complete return False to TensorRT.
            self._step_result = False

    def phase_finish(self, phase_name):
        """Jump the finished phase's bar to 100% and tear it down."""
        try:
            if phase_name in self._active_phases.keys():
                # Force the bar to its total regardless of the last step seen.
                self._active_phases[phase_name]["tq"].update(
                    self._active_phases[phase_name]["tq"].total
                    - self._active_phases[phase_name]["tq"].n
                )

                # Redraw every ancestor so the nested display stays intact
                # after this bar is removed.
                parent_phase = self._active_phases[phase_name].get("parent_phase", None)
                while parent_phase is not None:
                    self._active_phases[parent_phase]["tq"].refresh()
                    parent_phase = self._active_phases[parent_phase].get(
                        "parent_phase", None
                    )
                if (
                    self._active_phases[phase_name]["parent_phase"]
                    in self._active_phases.keys()
                ):
                    self._active_phases[
                        self._active_phases[phase_name]["parent_phase"]
                    ]["tq"].refresh()
                del self._active_phases[phase_name]
        except KeyboardInterrupt:
            self._step_result = False

    def step_complete(self, phase_name, step):
        """Advance the phase's bar to *step*; return False to cancel the build."""
        try:
            if phase_name in self._active_phases.keys():
                self._active_phases[phase_name]["tq"].update(
                    step - self._active_phases[phase_name]["tq"].n
                )
            return self._step_result
        except KeyboardInterrupt:
            # No need to propagate the interrupt into TensorRT; requesting
            # cancellation of the build is sufficient.
            return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TRT_MODEL_CONVERSION_BASE:
    """Shared machinery for the STATIC/DYNAMIC TensorRT conversion nodes.

    Subclasses provide ComfyUI's INPUT_TYPES() and a convert() wrapper that
    forwards into _convert(), which exports the loaded diffusion model to
    ONNX (Dynamo exporter) and compiles the result into a serialized
    TensorRT engine saved under the ComfyUI output directory.
    """

    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()
        self.temp_dir = folder_paths.get_temp_directory()
        # Keep the timing cache next to this source file so repeated
        # conversions (including across restarts) can reuse tactic timings.
        self.timing_cache_path = os.path.normpath(
            os.path.join(
                os.path.dirname(os.path.realpath(__file__)), "timing_cache.trt"
            )
        )

    RETURN_TYPES = ()
    FUNCTION = "convert"
    OUTPUT_NODE = True
    CATEGORY = "TensorRT"

    @classmethod
    def INPUT_TYPES(s):
        # Subclasses must declare the node's inputs.
        raise NotImplementedError

    def _setup_timing_cache(self, config: trt.IBuilderConfig):
        """Load the on-disk timing cache (if any) into the builder config."""
        buffer = b""
        if os.path.exists(self.timing_cache_path):
            with open(self.timing_cache_path, mode="rb") as timing_cache_file:
                buffer = timing_cache_file.read()
            trtlog(f"Read {len(buffer)} bytes from timing cache.")
        else:
            trtlog("No timing cache found; initializing a new one.")
        timing_cache: trt.ITimingCache = config.create_timing_cache(buffer)
        # ignore_mismatch: accept caches written by other TRT versions/GPUs
        # rather than failing the whole conversion.
        config.set_timing_cache(timing_cache, ignore_mismatch=True)

    def _save_timing_cache(self, config: trt.IBuilderConfig):
        """Serialize the builder's timing cache back to disk for reuse."""
        timing_cache: trt.ITimingCache = config.get_timing_cache()
        with open(self.timing_cache_path, "wb") as timing_cache_file:
            serialized = timing_cache.serialize()
            timing_cache_file.write(memoryview(serialized))
        trtlog(f"Timing cache saved to {self.timing_cache_path}")

    def _convert(
        self,
        model,
        filename_prefix,
        batch_size_min,
        batch_size_opt,
        batch_size_max,
        height_min,
        height_opt,
        height_max,
        width_min,
        width_opt,
        width_max,
        context_min,
        context_opt,
        context_max,
        num_video_frames,
        is_static: bool,
    ):
        """Export *model* to ONNX and build + save a TensorRT engine.

        The min/opt/max triples define the TensorRT optimization profile
        (batch size, pixel height/width, and context-chunk multiplier).
        ``num_video_frames`` is only meaningful for temporal (SVD-like)
        models, where it scales the batch dimension. With ``is_static`` the
        engine is single-shape and no dynamic axes are exported.

        Returns an empty tuple (ComfyUI output-node convention); recoverable
        failures (unsupported model, ONNX parse error, build failure) also
        return (). An ONNX export failure is re-raised after GPU cleanup.
        """
        trtlog(
            f"PyTorch version: {torch.__version__}, TensorRT version: {trt.__version__}"
        )
        trtlog(
            f"Requested {'STATIC' if is_static else 'DYNAMIC'} TensorRT engine "
            f"(b=[{batch_size_min},{batch_size_opt},{batch_size_max}], "
            f"h=[{height_min},{height_opt},{height_max}], "
            f"w=[{width_min},{width_opt},{width_max}], "
            f"context=[{context_min},{context_opt},{context_max}], "
            f"num_video_frames={num_video_frames})"
        )

        # BUG FIX: normalize an inverted context range *before* any shapes
        # are derived from it. Previously the swap happened after the
        # profile shapes were built, leaving min/max entries inconsistent.
        if context_max < context_min:
            trtlog(
                f"WARNING: context_max({context_max}) < context_min({context_min}), swapping."
            )
            context_min, context_max = context_max, context_min

        output_onnx = os.path.normpath(
            os.path.join(
                os.path.join(self.temp_dir, "{}".format(time.time())), "model.onnx"
            )
        )
        trtlog(f"Temporary ONNX path: {output_onnx}")

        # Ensure this model (and only this model) is fully resident on GPU
        # with LoRA/patches baked in before tracing.
        comfy.model_management.unload_all_models()
        comfy.model_management.load_models_gpu(
            [model], force_patch_weights=True, force_full_load=True
        )

        unet = model.model.diffusion_model
        model_type = type(model.model).__name__
        trtlog(f"Detected model type: {model_type}")

        # Defaults cover SD1/SD2/SDXL-style models; architecture-specific
        # branches below override them.
        context_dim = model.model.model_config.unet_config.get("context_dim", None)
        context_len = 77
        context_len_min = context_len
        y_dim = model.model.adm_channels
        extra_input = {}
        dtype = torch.float16

        if isinstance(model.model, comfy.model_base.SD3):
            context_embedder_config = model.model.model_config.unet_config.get(
                "context_embedder_config", None
            )
            if context_embedder_config is not None:
                context_dim = context_embedder_config.get(
                    "params", {}
                ).get("in_features", None)
                # SD3 can see 77 or 154 tokens depending on text encoders;
                # context_len_min stays at 77 to cover both.
                context_len = 154
                trtlog(f"SD3 context_dim={context_dim}, context_len={context_len}")
        elif isinstance(model.model, comfy.model_base.AuraFlow):
            context_dim = 2048
            context_len_min = 256
            context_len = 256
            trtlog(
                f"AuraFlow context_dim={context_dim}, "
                f"context_len_min={context_len_min}, context_len={context_len}"
            )
        elif isinstance(model.model, comfy.model_base.Flux):
            context_dim = model.model.model_config.unet_config.get(
                "context_in_dim", None
            )
            context_len_min = 256
            context_len = 256
            y_dim = model.model.model_config.unet_config.get("vec_in_dim", None)
            # Flux takes an extra per-sample scalar "guidance" input.
            extra_input = {"guidance": ()}
            dtype = torch.bfloat16
            trtlog(
                f"Flux context_dim={context_dim}, y_dim={y_dim}, "
                f"context_len_min={context_len_min}, context_len={context_len}, "
                f"extra_input={list(extra_input.keys())}, dtype={dtype}"
            )

        if context_dim is None:
            # Consistency fix: report through trtlog like every other message.
            trtlog("ERROR: model not supported (no context_dim).")
            comfy.model_management.unload_all_models()
            comfy.model_management.soft_empty_cache()
            return ()

        input_names = ["x", "timesteps", "context"]
        output_names = ["h"]

        transformer_options = model.model_options["transformer_options"].copy()
        use_temporal = model.model.model_config.unet_config.get(
            "use_temporal_resblock", False
        )

        if use_temporal:
            # SVD-like models fold video frames into the batch dimension and
            # use a single learned context token.
            trtlog("Model uses temporal resblock (SVD-like). Adjusting batch sizes.")
            batch_size_min = num_video_frames * batch_size_min
            batch_size_opt = num_video_frames * batch_size_opt
            batch_size_max = num_video_frames * batch_size_max

            class SVD_UNET(torch.nn.Module):
                """Export wrapper binding num_video_frames/transformer_options."""

                def __init__(self, unet, transformer_options, num_video_frames):
                    super().__init__()
                    self.unet = unet
                    self.transformer_options = transformer_options
                    self.num_video_frames = num_video_frames

                def forward(self, x, timesteps, context, y):
                    return self.unet(
                        x,
                        timesteps,
                        context,
                        y,
                        num_video_frames=self.num_video_frames,
                        transformer_options=self.transformer_options,
                    )

            unet = SVD_UNET(unet, transformer_options, num_video_frames)
            context_len_min = context_len = 1
            trtlog(
                f"SVD adjusted batch: "
                f"b=[{batch_size_min},{batch_size_opt},{batch_size_max}], "
                f"context_len_min={context_len_min}, context_len={context_len}"
            )

        else:
            extra_keys = list(extra_input.keys())

            class UNET(torch.nn.Module):
                """Export wrapper exposing only tensor inputs with a fixed signature."""

                def __init__(self, unet, transformer_options, y_dim, extra_keys):
                    super().__init__()
                    self.unet = unet
                    self.transformer_options = transformer_options
                    self.y_dim = y_dim
                    self.extra_keys = extra_keys

                def forward(self, x, timesteps, context, y=None, guidance=None):
                    extra_args = {}
                    # Forward optional inputs only when the model uses them.
                    if self.y_dim is not None and self.y_dim > 0 and y is not None:
                        extra_args["y"] = y
                    if "guidance" in self.extra_keys and guidance is not None:
                        extra_args["guidance"] = guidance
                    return self.unet(
                        x,
                        timesteps,
                        context,
                        transformer_options=self.transformer_options,
                        **extra_args,
                    )

            unet = UNET(unet, transformer_options, y_dim, extra_keys)

        input_channels = model.model.model_config.unet_config.get("in_channels", 4)

        # Latent spatial dims are pixel dims // 8; context length is
        # tokens-per-chunk * number of chunks.
        inputs_shapes_min = (
            (batch_size_min, input_channels, height_min // 8, width_min // 8),
            (batch_size_min,),
            (batch_size_min, context_len_min * context_min, context_dim),
        )
        inputs_shapes_opt = (
            (batch_size_opt, input_channels, height_opt // 8, width_opt // 8),
            (batch_size_opt,),
            (batch_size_opt, context_len * context_opt, context_dim),
        )
        inputs_shapes_max = (
            (batch_size_max, input_channels, height_max // 8, width_max // 8),
            (batch_size_max,),
            (batch_size_max, context_len * context_max, context_dim),
        )

        if y_dim is not None and y_dim > 0:
            input_names.append("y")
            inputs_shapes_min += ((batch_size_min, y_dim),)
            inputs_shapes_opt += ((batch_size_opt, y_dim),)
            inputs_shapes_max += ((batch_size_max, y_dim),)

        # Extra model-specific inputs (e.g. Flux "guidance"); each maps its
        # name to the per-sample shape suffix after the batch dim.
        for k in extra_input:
            input_names.append(k)
            shape_suffix = extra_input[k]
            inputs_shapes_min += ((batch_size_min,) + shape_suffix,)
            inputs_shapes_opt += ((batch_size_opt,) + shape_suffix,)
            inputs_shapes_max += ((batch_size_max,) + shape_suffix,)

        trtlog("Input names: " + ", ".join(input_names))
        for idx, name in enumerate(input_names):
            trtlog(
                f"  {name}: "
                f"min={inputs_shapes_min[idx]}, "
                f"opt={inputs_shapes_opt[idx]}, "
                f"max={inputs_shapes_max[idx]}"
            )

        dynamic_shapes = None

        def _maybe_dim(name: str, min_v: int, max_v: int):
            """Create Dim only if there is real dynamism (max > min)."""
            if max_v < min_v:
                # BUG FIX: message previously rendered as garbled
                # "min>{max}>{min}"; state the inversion clearly.
                trtlog(
                    f"WARNING: Dim {name} has min({min_v}) > max({max_v}), swapping to fix."
                )
                min_v, max_v = max_v, min_v
            if max_v > min_v:
                trtlog(f"Dim {name}: dynamic range [{min_v}, {max_v}]")
                return Dim(name, min=min_v, max=max_v)
            else:
                trtlog(f"Dim {name}: static value {min_v}, not using Dim.")
                return None

        if not is_static:
            B = _maybe_dim("batch", batch_size_min, batch_size_max)
            H = _maybe_dim("height", height_min // 8, height_max // 8)
            W = _maybe_dim("width", width_min // 8, width_max // 8)
            tokens_min = context_len_min * context_min
            tokens_max = context_len * context_max
            T = _maybe_dim("tokens", tokens_min, tokens_max)

            dynamic_shapes = {}

            # x: (batch, channels, height, width) — channels always static.
            x_dyn = {}
            if B is not None:
                x_dyn[0] = B
            if H is not None:
                x_dyn[2] = H
            if W is not None:
                x_dyn[3] = W
            if x_dyn:
                dynamic_shapes["x"] = x_dyn

            if B is not None:
                dynamic_shapes["timesteps"] = {0: B}

            # context: (batch, tokens, context_dim).
            ctx_dyn = {}
            if B is not None:
                ctx_dyn[0] = B
            if T is not None:
                ctx_dyn[1] = T
            if ctx_dyn:
                dynamic_shapes["context"] = ctx_dyn

            if "y" in input_names and B is not None:
                dynamic_shapes["y"] = {0: B}

            if "guidance" in input_names and B is not None:
                dynamic_shapes["guidance"] = {0: B}

            if not dynamic_shapes:
                trtlog(
                    "No dimensions are actually dynamic for DYNAMIC node. "
                    "Export will effectively be static."
                )
                dynamic_shapes = None
            else:
                trtlog(f"dynamic_shapes spec: {dynamic_shapes}")
        else:
            trtlog("STATIC node: skipping torch.export.Dim and dynamic_shapes entirely.")

        # Dummy tensors at the "opt" shapes drive the export trace.
        inputs = ()
        for shape in inputs_shapes_opt:
            inputs += (
                torch.zeros(
                    shape,
                    device=comfy.model_management.get_torch_device(),
                    dtype=dtype,
                ),
            )

        os.makedirs(os.path.dirname(output_onnx), exist_ok=True)

        trtlog(
            f"Exporting UNet to ONNX with dynamo=True, "
            f"opset_version={DEFAULT_ONNX_OPSET}, dtype={dtype}, "
            f"output={output_onnx}"
        )
        if dynamic_shapes is None:
            trtlog("ONNX export will be STATIC (no dynamic_shapes).")
        else:
            trtlog("ONNX export will use dynamic_shapes (see spec above).")

        try:
            torch.onnx.export(
                unet,
                inputs,
                output_onnx,
                verbose=False,
                input_names=input_names,
                output_names=output_names,
                opset_version=DEFAULT_ONNX_OPSET,
                dynamo=True,
                dynamic_shapes=dynamic_shapes,
            )
            trtlog("torch.onnx.export completed successfully.")
        except Exception as e:
            trtlog(f"ERROR during torch.onnx.export: {e}")
            # Free GPU memory before propagating so a failed export does not
            # leave the model pinned on-device.
            comfy.model_management.unload_all_models()
            comfy.model_management.soft_empty_cache()
            raise

        # The traced model is no longer needed on GPU while TRT builds.
        comfy.model_management.unload_all_models()
        comfy.model_management.soft_empty_cache()

        logger = trt.Logger(trt.Logger.INFO)
        builder = trt.Builder(logger)
        trtlog("Created TensorRT builder.")

        network = builder.create_network(
            1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        )
        parser = trt.OnnxParser(network, logger)
        trtlog(f"Parsing ONNX file: {output_onnx}")
        success = parser.parse_from_file(output_onnx)
        for idx in range(parser.num_errors):
            print(parser.get_error(idx))

        if not success:
            print("ONNX load ERROR (TensorRT parser.parse_from_file returned False).")
            return ()

        config = builder.create_builder_config()
        profile = builder.create_optimization_profile()
        self._setup_timing_cache(config)
        config.progress_monitor = TQDMProgressMonitor()

        trtlog("Creating optimization profile:")
        # NOTE: the old "prefix_encode" accumulator (name#min#opt#max;...)
        # was computed here but never used anywhere — removed as dead code.
        for k in range(len(input_names)):
            min_shape = inputs_shapes_min[k]
            opt_shape = inputs_shapes_opt[k]
            max_shape = inputs_shapes_max[k]
            trtlog(
                f"  {input_names[k]}: min={min_shape}, opt={opt_shape}, max={max_shape}"
            )
            profile.set_shape(input_names[k], min_shape, opt_shape, max_shape)

        if dtype == torch.float16:
            trtlog("Enabling FP16 mode in TensorRT builder config.")
            config.set_flag(trt.BuilderFlag.FP16)
        if dtype == torch.bfloat16:
            trtlog("Enabling BF16 mode in TensorRT builder config.")
            config.set_flag(trt.BuilderFlag.BF16)

        config.add_optimization_profile(profile)

        # Encode the shape configuration into the filename so engines built
        # for different resolutions/batches can coexist.
        if is_static:
            filename_prefix = "{}_${}".format(
                filename_prefix,
                "-".join(
                    (
                        "stat",
                        "b",
                        str(batch_size_opt),
                        "h",
                        str(height_opt),
                        "w",
                        str(width_opt),
                    )
                ),
            )
        else:
            filename_prefix = "{}_${}".format(
                filename_prefix,
                "-".join(
                    (
                        "dyn",
                        "b",
                        str(batch_size_min),
                        str(batch_size_max),
                        str(batch_size_opt),
                        "h",
                        str(height_min),
                        str(height_max),
                        str(height_opt),
                        "w",
                        str(width_min),
                        str(width_max),
                        str(width_opt),
                    )
                ),
            )

        trtlog("Building serialized TensorRT engine. This may take a while...")
        serialized_engine = builder.build_serialized_network(network, config)
        if serialized_engine is None:
            trtlog("ERROR: builder.build_serialized_network returned None.")
            return ()

        full_output_folder, filename, counter, subfolder, filename_prefix = (
            folder_paths.get_save_image_path(filename_prefix, self.output_dir)
        )
        # BUG FIX: use the computed `filename` rather than the literal
        # "(unknown)" so saved engines carry the requested filename prefix.
        output_trt_engine = os.path.join(
            full_output_folder, f"{filename}_{counter:05}_.engine"
        )

        trtlog(f"Writing TensorRT engine to: {output_trt_engine}")
        os.makedirs(full_output_folder, exist_ok=True)
        with open(output_trt_engine, "wb") as f:
            f.write(serialized_engine)

        self._save_timing_cache(config)
        trtlog("TensorRT conversion complete.")

        return ()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DYNAMIC_TRT_MODEL_CONVERSION(TRT_MODEL_CONVERSION_BASE):
    """Node building a TensorRT engine with dynamic batch/size/context ranges."""

    def __init__(self):
        super().__init__()

    @classmethod
    def INPUT_TYPES(s):
        # Each sized field gets a min/opt/max triple sharing one INT config.
        batch_cfg = {"default": 1, "min": 1, "max": 100, "step": 1}
        size_cfg = {"default": 512, "min": 256, "max": 4096, "step": 64}
        ctx_cfg = {"default": 1, "min": 1, "max": 128, "step": 1}

        required = {
            "model": ("MODEL",),
            "filename_prefix": ("STRING", {"default": "tensorrt/ComfyUI_DYN"}),
        }
        for field, cfg in (
            ("batch_size", batch_cfg),
            ("height", size_cfg),
            ("width", size_cfg),
            ("context", ctx_cfg),
        ):
            for bound in ("min", "opt", "max"):
                required[f"{field}_{bound}"] = ("INT", dict(cfg))
        required["num_video_frames"] = (
            "INT",
            {"default": 14, "min": 0, "max": 1000, "step": 1},
        )
        return {"required": required}

    def convert(
        self,
        model,
        filename_prefix,
        batch_size_min,
        batch_size_opt,
        batch_size_max,
        height_min,
        height_opt,
        height_max,
        width_min,
        width_opt,
        width_max,
        context_min,
        context_opt,
        context_max,
        num_video_frames,
    ):
        """Forward all ranges to the shared converter as a dynamic build."""
        return super()._convert(
            model,
            filename_prefix,
            batch_size_min, batch_size_opt, batch_size_max,
            height_min, height_opt, height_max,
            width_min, width_opt, width_max,
            context_min, context_opt, context_max,
            num_video_frames,
            is_static=False,
        )
|
|
|
|
|
|
|
|
class STATIC_TRT_MODEL_CONVERSION(TRT_MODEL_CONVERSION_BASE):
    """Node building a single-shape (static) TensorRT engine."""

    def __init__(self):
        super().__init__()

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "filename_prefix": ("STRING", {"default": "tensorrt/ComfyUI_STAT"}),
                "batch_size_opt": (
                    "INT",
                    {"default": 1, "min": 1, "max": 100, "step": 1},
                ),
                "height_opt": (
                    "INT",
                    {"default": 512, "min": 256, "max": 4096, "step": 64},
                ),
                "width_opt": (
                    "INT",
                    {"default": 512, "min": 256, "max": 4096, "step": 64},
                ),
                "context_opt": (
                    "INT",
                    {"default": 1, "min": 1, "max": 128, "step": 1},
                ),
                "num_video_frames": (
                    "INT",
                    {"default": 14, "min": 0, "max": 1000, "step": 1},
                ),
            },
        }

    def convert(
        self,
        model,
        filename_prefix,
        batch_size_opt,
        height_opt,
        width_opt,
        context_opt,
        num_video_frames,
    ):
        """Build a static engine: min == opt == max for every dimension."""
        return super()._convert(
            model,
            filename_prefix,
            batch_size_opt, batch_size_opt, batch_size_opt,
            height_opt, height_opt, height_opt,
            width_opt, width_opt, width_opt,
            context_opt, context_opt, context_opt,
            num_video_frames,
            is_static=True,
        )
|
|
|
|
|
|
|
|
# ComfyUI node registry: node-type string -> node class. The registered
# names deliberately match the class names.
NODE_CLASS_MAPPINGS = {
    node_cls.__name__: node_cls
    for node_cls in (DYNAMIC_TRT_MODEL_CONVERSION, STATIC_TRT_MODEL_CONVERSION)
}
|
|
|