Upload 2 files
Browse files- tensorrt_convert.py +1 -0
- tensorrt_loader.py +91 -85
tensorrt_convert.py
CHANGED
|
@@ -287,6 +287,7 @@ class TRT_MODEL_CONVERSION_BASE:
|
|
| 287 |
input_names=input_names,
|
| 288 |
output_names=output_names,
|
| 289 |
opset_version=17,
|
|
|
|
| 290 |
dynamo=False, # <— force legacy ONNX exporter, no torch.export/dynamic_shapes
|
| 291 |
)
|
| 292 |
|
|
|
|
| 287 |
input_names=input_names,
|
| 288 |
output_names=output_names,
|
| 289 |
opset_version=17,
|
| 290 |
+
dynamic_axes=dynamic_axes, # KEEP dynamic axes
|
| 291 |
dynamo=False, # <— force legacy ONNX exporter, no torch.export/dynamic_shapes
|
| 292 |
)
|
| 293 |
|
tensorrt_loader.py
CHANGED
|
@@ -43,108 +43,113 @@ def trt_datatype_to_torch(datatype):
|
|
| 43 |
class TrTUnet:
|
| 44 |
def __init__(self, engine_path):
|
| 45 |
with open(engine_path, "rb") as f:
|
| 46 |
-
|
| 47 |
-
self.engine = runtime.deserialize_cuda_engine(engine_bytes)
|
| 48 |
self.context = self.engine.create_execution_context()
|
| 49 |
-
# Default precision – overridden to bfloat16 for Flux in TensorRTLoader
|
| 50 |
-
self.dtype = torch.float16
|
| 51 |
|
| 52 |
-
|
| 53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
"""
|
| 55 |
-
x: [B, C, H, W]
|
| 56 |
-
timesteps: [B]
|
| 57 |
-
context: [B,
|
| 58 |
-
y: [B,
|
| 59 |
-
Other kwargs (control, transformer_options, guidance, ...) are ignored
|
| 60 |
-
at TensorRT level, but must be accepted to match Comfy's callsite.
|
| 61 |
"""
|
| 62 |
|
| 63 |
-
#
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
x = _prep(x)
|
| 73 |
-
timesteps = _prep(timesteps)
|
| 74 |
-
context = _prep(context)
|
| 75 |
-
y = _prep(y)
|
| 76 |
-
|
| 77 |
-
# Discover engine IO tensors
|
| 78 |
-
tensor_names = [
|
| 79 |
-
self.engine.get_tensor_name(i)
|
| 80 |
-
for i in range(self.engine.num_io_tensors)
|
| 81 |
-
]
|
| 82 |
-
input_names = [
|
| 83 |
-
n for n in tensor_names
|
| 84 |
-
if self.engine.get_tensor_mode(n) == trt.TensorIOMode.INPUT
|
| 85 |
-
]
|
| 86 |
-
output_names = [
|
| 87 |
-
n for n in tensor_names
|
| 88 |
-
if self.engine.get_tensor_mode(n) == trt.TensorIOMode.OUTPUT
|
| 89 |
-
]
|
| 90 |
-
|
| 91 |
-
# Build a dict of available tensors by name
|
| 92 |
-
available = {"x": x, "timesteps": timesteps, "context": context}
|
| 93 |
if y is not None:
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
#
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
)
|
| 114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
self.context.set_input_shape(name, tuple(t.shape))
|
| 116 |
-
self.context.set_tensor_address(name, t.data_ptr())
|
| 117 |
|
| 118 |
-
#
|
| 119 |
missing = self.context.infer_shapes()
|
| 120 |
if missing:
|
| 121 |
-
raise RuntimeError(
|
| 122 |
-
f"TensorRT shape inference failed, unresolved tensors: {missing}"
|
| 123 |
-
)
|
| 124 |
-
|
| 125 |
-
# Ensure the context has enough device memory for the resolved shapes
|
| 126 |
-
self.context.update_device_memory_size_for_shapes()
|
| 127 |
|
| 128 |
-
#
|
| 129 |
-
|
|
|
|
|
|
|
| 130 |
for name in output_names:
|
| 131 |
-
out_dims = self.context.get_tensor_shape(name)
|
| 132 |
out_shape = tuple(int(d) for d in out_dims)
|
| 133 |
-
out_dtype = trt_datatype_to_torch(self.engine.get_tensor_dtype(name))
|
| 134 |
-
out_tensor = torch.empty(out_shape, device=device, dtype=out_dtype)
|
| 135 |
-
self.context.set_tensor_address(name, out_tensor.data_ptr())
|
| 136 |
-
outputs.append(out_tensor)
|
| 137 |
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
self.context.execute_async_v3(stream_handle=stream)
|
| 141 |
|
| 142 |
-
|
| 143 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
|
| 145 |
def load_state_dict(self, sd, strict=False):
|
| 146 |
-
|
| 147 |
-
return
|
| 148 |
|
| 149 |
def state_dict(self):
|
| 150 |
return {}
|
|
@@ -152,6 +157,7 @@ class TrTUnet:
|
|
| 152 |
|
| 153 |
|
| 154 |
|
|
|
|
| 155 |
class TensorRTLoader:
|
| 156 |
@classmethod
|
| 157 |
def INPUT_TYPES(s):
|
|
|
|
| 43 |
class TrTUnet:
    """Callable wrapper that runs a serialized TensorRT engine in place of a UNet.

    The engine's I/O tensors are discovered by name at call time. Inputs are
    converted to the dtype the engine declares for them, bound by device
    address, and the engine is launched on the current torch CUDA stream.
    """

    def __init__(self, engine_path):
        """Deserialize the engine stored at *engine_path* and create a context.

        Raises:
            RuntimeError: if TensorRT cannot deserialize the engine file.
        """
        with open(engine_path, "rb") as f:
            self.engine = runtime.deserialize_cuda_engine(f.read())
        if self.engine is None:
            # deserialize_cuda_engine reports failure by returning None; fail
            # here with a clear message instead of an AttributeError later.
            raise RuntimeError(
                f"TensorRT UNet: failed to deserialize engine: {engine_path}"
            )
        self.context = self.engine.create_execution_context()

        # Default torch device / dtype for allocations
        self.device = comfy.model_management.get_torch_device()
        self.default_dtype = torch.float16  # fallback if something unknown shows up

    def _trt_dtype_to_torch(self, trt_dtype):
        """Map a TensorRT dtype to a torch dtype, falling back to the default.

        NOTE(review): relies on the module-level ``trt_datatype_to_torch``
        yielding ``None`` for dtypes it does not know - confirm against it.
        """
        dt = trt_datatype_to_torch(trt_dtype)
        return dt if dt is not None else self.default_dtype

    def __call__(self, x, timesteps, context, y=None, control=None, transformer_options=None, **kwargs):
        """Run the engine once and return its output tensor(s).

        Documented shapes (as stated by the original author):
            x         : [B, C, H, W]
            timesteps : [B]
            context   : [B, N, D]
            y         : [B, y_dim]  (optional, SDXL etc.)

        Any engine input not covered by the named parameters is looked up in
        ``kwargs`` by tensor name (e.g. ``guidance``). Supplied values that the
        engine does not declare as inputs are ignored; ``control`` and
        ``transformer_options`` are accepted only for call-site compatibility.

        Returns:
            The single output tensor, or a tuple of tensors in engine order
            when the engine has several outputs.

        Raises:
            RuntimeError: if a required engine input is not supplied, if shape
                inference leaves tensors unresolved, or if execution fails.
        """
        engine = self.engine
        trt_context = self.context

        # -----------------------------
        # 1. Collect the inputs the engine actually asks for
        # -----------------------------
        model_inputs = {
            "x": x,
            "timesteps": timesteps,
            "context": context,
        }
        if y is not None:
            model_inputs["y"] = y

        tensor_names = [engine.get_tensor_name(i) for i in range(engine.num_io_tensors)]
        input_names = [
            n for n in tensor_names
            if engine.get_tensor_mode(n) == trt.TensorIOMode.INPUT
        ]
        output_names = [
            n for n in tensor_names
            if engine.get_tensor_mode(n) == trt.TensorIOMode.OUTPUT
        ]

        # Fill engine inputs that are not named parameters from kwargs.
        # A kwarg passed as None counts as "not supplied".
        for name in input_names:
            if name not in model_inputs and kwargs.get(name) is not None:
                model_inputs[name] = kwargs[name]

        # Check by name, not by count: extra supplied tensors (e.g. `y` for an
        # engine without a `y` input) must not trigger a bogus failure, and a
        # matching count must not hide a genuinely missing input.
        missing = [n for n in input_names if n not in model_inputs]
        if missing:
            raise RuntimeError(
                f"TensorRT UNet: missing required inputs for engine: {missing} "
                f"(have {list(model_inputs.keys())})"
            )

        # -----------------------------
        # 2. Convert each input to engine dtype + bind it
        # -----------------------------
        # Keeps the converted tensors referenced until this call returns so the
        # addresses handed to TensorRT stay valid while the launch is enqueued.
        bound_inputs = {}
        for name in input_names:
            torch_dtype = self._trt_dtype_to_torch(engine.get_tensor_dtype(name))

            # TensorRT reads raw memory from the bound address, so the tensor
            # must be on the right device, in the right dtype, and densely
            # laid out (a strided view would be misinterpreted).
            t = model_inputs[name].to(device=self.device, dtype=torch_dtype).contiguous()
            bound_inputs[name] = t

            # Set runtime shape and bind memory
            trt_context.set_input_shape(name, tuple(t.shape))
            trt_context.set_tensor_address(name, int(t.data_ptr()))

        # Make sure all shapes are resolved
        unresolved = trt_context.infer_shapes()
        if unresolved:
            raise RuntimeError(f"TensorRT shape inference failed, unresolved tensors: {unresolved}")

        # -----------------------------
        # 3. Allocate & bind outputs
        # -----------------------------
        outputs = {}
        for name in output_names:
            out_dims = trt_context.get_tensor_shape(name)  # trt.Dims
            out_shape = tuple(int(d) for d in out_dims)
            torch_dtype = self._trt_dtype_to_torch(engine.get_tensor_dtype(name))

            out_tensor = torch.empty(out_shape, device=self.device, dtype=torch_dtype)
            trt_context.set_tensor_address(name, int(out_tensor.data_ptr()))
            outputs[name] = out_tensor

        # -----------------------------
        # 4. Execute on the current torch CUDA stream
        # -----------------------------
        stream = torch.cuda.current_stream(self.device)
        if not trt_context.execute_async_v3(stream_handle=stream.cuda_stream):
            # execute_async_v3 signals enqueue failure through its return value.
            raise RuntimeError("TensorRT UNet: execute_async_v3 failed to enqueue inference")

        # No explicit sync: the work is enqueued on torch's current stream, so
        # later torch ops on that stream are ordered after it.

        # Return outputs in a stable order
        out_list = [outputs[name] for name in output_names]
        return out_list[0] if len(out_list) == 1 else tuple(out_list)

    def load_state_dict(self, sd, strict=False):
        """Accept and ignore a state dict; weights live inside the engine."""
        pass

    def state_dict(self):
        """Return an empty dict; this wrapper exposes no torch parameters."""
        return {}
|
|
|
|
| 157 |
|
| 158 |
|
| 159 |
|
| 160 |
+
|
| 161 |
class TensorRTLoader:
|
| 162 |
@classmethod
|
| 163 |
def INPUT_TYPES(s):
|