saliacoel commited on
Commit
c6054f8
·
verified ·
1 Parent(s): 1306c71

Upload tensorrt_loader.py

Browse files
Files changed (1) hide show
  1. tensorrt_loader.py +77 -96
tensorrt_loader.py CHANGED
@@ -1,5 +1,3 @@
1
- #Put this in the custom_nodes folder, put your tensorrt engine files in ComfyUI/models/tensorrt/ (you will have to create the directory)
2
-
3
  import torch
4
  import os
5
 
@@ -24,59 +22,59 @@ trt.init_libnvinfer_plugins(None, "")
24
  logger = trt.Logger(trt.Logger.INFO)
25
  runtime = trt.Runtime(logger)
26
 
27
- # Is there a function that already exists for this?
28
  def trt_datatype_to_torch(datatype):
29
- if datatype == trt.float16:
 
30
  return torch.float16
31
- elif datatype == trt.float32:
32
  return torch.float32
33
- elif datatype == trt.int32:
34
- return torch.int32
35
- elif datatype == trt.bfloat16:
 
36
  return torch.bfloat16
 
 
 
 
 
37
 
38
  class TrTUnet:
39
  def __init__(self, engine_path):
40
  with open(engine_path, "rb") as f:
41
- self.engine = runtime.deserialize_cuda_engine(f.read())
 
42
  self.context = self.engine.create_execution_context()
43
- # default dtype in case something doesn't have a specific TRT dtype
44
  self.dtype = torch.float16
45
 
46
- def set_bindings_shape(self, inputs, split_batch):
47
- # still here in case something else calls it, but the new __call__
48
- # no longer uses this split-batch path
49
- for k in inputs:
50
- shape = inputs[k].shape
51
- shape = [shape[0] // split_batch] + list(shape[1:])
52
- self.context.set_input_shape(k, shape)
53
-
54
  def __call__(self, x, timesteps, context, y=None,
55
  control=None, transformer_options=None, **kwargs):
56
  """
57
- Run the TensorRT UNet.
58
-
59
- - `control` and `transformer_options` are accepted for API compatibility
60
- with Comfy, but ignored by the TRT engine.
61
- - Any extra tensor inputs (e.g. `guidance` for Flux) are taken from
62
- **kwargs and matched by name to the engine’s input tensors.
63
  """
64
 
65
- # Collect all tensors we might need by name
66
- available = {
67
- "x": x,
68
- "timesteps": timesteps,
69
- "context": context,
70
- }
71
- if y is not None:
72
- available["y"] = y
73
 
74
- # Extra conds (e.g. 'guidance', etc.) may come in via kwargs
75
- for name, value in kwargs.items():
76
- if isinstance(value, torch.Tensor):
77
- available[name] = value
 
78
 
79
- # Query engine IO tensors
 
 
 
 
 
80
  tensor_names = [
81
  self.engine.get_tensor_name(i)
82
  for i in range(self.engine.num_io_tensors)
@@ -90,87 +88,70 @@ class TrTUnet:
90
  if self.engine.get_tensor_mode(n) == trt.TensorIOMode.OUTPUT
91
  ]
92
 
93
- # Sanity check: we must have a tensor for every input
94
- missing = [n for n in input_names if n not in available]
95
- if missing:
96
- raise RuntimeError(
97
- f"Missing tensors for TensorRT engine inputs: {missing}. "
98
- f"Available: {list(available.keys())}"
99
- )
100
-
101
- device = x.device
102
-
103
- # Bind inputs: fix dtype + device, set shapes and addresses
104
- for name in input_names:
105
- t = available[name]
106
-
107
- if not t.is_contiguous():
108
- t = t.contiguous()
109
 
110
- # Match engine dtype
111
- trt_dtype = self.engine.get_tensor_dtype(name)
112
- torch_dtype = trt_datatype_to_torch(trt_dtype)
113
- if torch_dtype is None:
 
 
 
 
 
 
 
 
 
 
 
114
  raise RuntimeError(
115
- f"Unsupported TensorRT dtype {trt_dtype} for input '{name}'"
116
  )
117
-
118
- if t.dtype != torch_dtype:
119
- t = t.to(dtype=torch_dtype)
120
-
121
- if t.device != device:
122
- t = t.to(device)
123
-
124
- # Save back in case we changed it
125
- available[name] = t
126
-
127
- # Tell TRT the runtime shape and bind the memory
128
  self.context.set_input_shape(name, tuple(t.shape))
129
  self.context.set_tensor_address(name, t.data_ptr())
130
 
131
- # Let TRT resolve all dynamic shapes (outputs etc.)
132
- unresolved = self.context.infer_shapes()
133
- if unresolved:
134
  raise RuntimeError(
135
- f"TensorRT shape inference failed, unresolved tensors: {unresolved}"
136
  )
137
 
 
 
 
138
  # Allocate and bind outputs
139
  outputs = []
140
  for name in output_names:
141
- dims = self.context.get_tensor_shape(name) # trt.Dims
 
 
 
 
 
142
 
143
- # Guard against the old nbDims == -1 issue
144
- if hasattr(dims, "nb_dims") and dims.nb_dims < 0:
145
- raise RuntimeError(f"Output '{name}' has invalid dims: {dims}")
146
 
147
- shape = [int(d) for d in dims]
148
- trt_dtype = self.engine.get_tensor_dtype(name)
149
- torch_dtype = trt_datatype_to_torch(trt_dtype)
150
-
151
- out = torch.empty(shape, device=device, dtype=torch_dtype)
152
- self.context.set_tensor_address(name, out.data_ptr())
153
- outputs.append(out)
154
-
155
- # Run on the default torch CUDA stream
156
- stream = torch.cuda.default_stream(device)
157
- self.context.execute_async_v3(stream_handle=stream.cuda_stream)
158
-
159
- # Return single tensor or a tuple
160
- if len(outputs) == 1:
161
- return outputs[0]
162
- return tuple(outputs)
163
 
164
  def load_state_dict(self, sd, strict=False):
165
- # Nothing to load for a serialized TensorRT engine
166
- pass
167
 
168
  def state_dict(self):
169
- # Keep API compatible with nn.Module
170
  return {}
171
 
172
 
173
 
 
174
  class TensorRTLoader:
175
  @classmethod
176
  def INPUT_TYPES(s):
 
 
 
1
  import torch
2
  import os
3
 
 
22
  logger = trt.Logger(trt.Logger.INFO)
23
  runtime = trt.Runtime(logger)
24
 
25
+
26
  def trt_datatype_to_torch(datatype):
27
+ # Works for TRT 8/9/10
28
+ if datatype in (getattr(trt, "float16", None), getattr(trt.DataType, "HALF", None)):
29
  return torch.float16
30
+ if datatype in (getattr(trt, "float32", None), getattr(trt.DataType, "FLOAT", None)):
31
  return torch.float32
32
+ if hasattr(trt, "bfloat16") and datatype in (
33
+ getattr(trt, "bfloat16", None),
34
+ getattr(trt.DataType, "BF16", None),
35
+ ):
36
  return torch.bfloat16
37
+ if datatype in (getattr(trt, "int32", None), getattr(trt.DataType, "INT32", None)):
38
+ return torch.int32
39
+ # Fallback – shouldn't normally hit this for UNets
40
+ return torch.float32
41
+
42
 
43
  class TrTUnet:
44
  def __init__(self, engine_path):
45
  with open(engine_path, "rb") as f:
46
+ engine_bytes = f.read()
47
+ self.engine = runtime.deserialize_cuda_engine(engine_bytes)
48
  self.context = self.engine.create_execution_context()
49
+ # Default precision overridden to bfloat16 for Flux in TensorRTLoader
50
  self.dtype = torch.float16
51
 
 
 
 
 
 
 
 
 
52
  def __call__(self, x, timesteps, context, y=None,
53
  control=None, transformer_options=None, **kwargs):
54
  """
55
+ x: [B, C, H, W]
56
+ timesteps: [B]
57
+ context: [B, T, Ctxt]
58
+ y: [B, adm_dim] (SDXL / SD3 / etc.)
59
+ Other kwargs (control, transformer_options, guidance, ...) are ignored
60
+ at TensorRT level, but must be accepted to match Comfy's callsite.
61
  """
62
 
63
+ # Use latent device as canonical device
64
+ device = x.device
 
 
 
 
 
 
65
 
66
+ # Helper to put everything on the right device / dtype and contiguous
67
+ def _prep(t):
68
+ if t is None:
69
+ return None
70
+ return t.to(device=device, dtype=self.dtype).contiguous()
71
 
72
+ x = _prep(x)
73
+ timesteps = _prep(timesteps)
74
+ context = _prep(context)
75
+ y = _prep(y)
76
+
77
+ # Discover engine IO tensors
78
  tensor_names = [
79
  self.engine.get_tensor_name(i)
80
  for i in range(self.engine.num_io_tensors)
 
88
  if self.engine.get_tensor_mode(n) == trt.TensorIOMode.OUTPUT
89
  ]
90
 
91
+ # Build a dict of available tensors by name
92
+ available = {"x": x, "timesteps": timesteps, "context": context}
93
+ if y is not None:
94
+ available["y"] = y
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
+ # Allow passing extra inputs (e.g. "guidance" for Flux) via kwargs
97
+ for k, v in kwargs.items():
98
+ if isinstance(v, torch.Tensor):
99
+ available[k] = _prep(v)
100
+
101
+ # Canonical order, so we never accidentally swap x/timesteps/context/y
102
+ canonical_order = {"x": 0, "timesteps": 1, "context": 2, "y": 3}
103
+ input_names_sorted = sorted(
104
+ input_names,
105
+ key=lambda n: canonical_order.get(n, 100),
106
+ )
107
+
108
+ # Bind all inputs – every engine input must get a valid tensor
109
+ for name in input_names_sorted:
110
+ if name not in available or available[name] is None:
111
  raise RuntimeError(
112
+ f"TensorRT engine expects input '{name}' but no tensor was provided."
113
  )
114
+ t = available[name]
 
 
 
 
 
 
 
 
 
 
115
  self.context.set_input_shape(name, tuple(t.shape))
116
  self.context.set_tensor_address(name, t.data_ptr())
117
 
118
+ # Infer shapes (resolve dynamic dims)
119
+ missing = self.context.infer_shapes()
120
+ if missing:
121
  raise RuntimeError(
122
+ f"TensorRT shape inference failed, unresolved tensors: {missing}"
123
  )
124
 
125
+ # Ensure the context has enough device memory for the resolved shapes
126
+ self.context.update_device_memory_size_for_shapes()
127
+
128
  # Allocate and bind outputs
129
  outputs = []
130
  for name in output_names:
131
+ out_dims = self.context.get_tensor_shape(name)
132
+ out_shape = tuple(int(d) for d in out_dims)
133
+ out_dtype = trt_datatype_to_torch(self.engine.get_tensor_dtype(name))
134
+ out_tensor = torch.empty(out_shape, device=device, dtype=out_dtype)
135
+ self.context.set_tensor_address(name, out_tensor.data_ptr())
136
+ outputs.append(out_tensor)
137
 
138
+ # Execute on the current PyTorch stream for correct ordering
139
+ stream = torch.cuda.current_stream(device).cuda_stream
140
+ self.context.execute_async_v3(stream_handle=stream)
141
 
142
+ # Comfy's apply_model() will call .float() on this anyway
143
+ return outputs[0] if len(outputs) == 1 else tuple(outputs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
  def load_state_dict(self, sd, strict=False):
146
+ # No-op weights are inside the TensorRT engine file.
147
+ return
148
 
149
  def state_dict(self):
 
150
  return {}
151
 
152
 
153
 
154
+
155
  class TensorRTLoader:
156
  @classmethod
157
  def INPUT_TYPES(s):