saliacoel committed on
Commit
e24ff1f
·
verified ·
1 Parent(s): c6054f8

Upload 2 files

Browse files
Files changed (2) hide show
  1. tensorrt_convert.py +1 -0
  2. tensorrt_loader.py +91 -85
tensorrt_convert.py CHANGED
@@ -287,6 +287,7 @@ class TRT_MODEL_CONVERSION_BASE:
287
  input_names=input_names,
288
  output_names=output_names,
289
  opset_version=17,
 
290
  dynamo=False, # <— force legacy ONNX exporter, no torch.export/dynamic_shapes
291
  )
292
 
 
287
  input_names=input_names,
288
  output_names=output_names,
289
  opset_version=17,
290
+ dynamic_axes=dynamic_axes, # KEEP dynamic axes
291
  dynamo=False, # <— force legacy ONNX exporter, no torch.export/dynamic_shapes
292
  )
293
 
tensorrt_loader.py CHANGED
@@ -43,108 +43,113 @@ def trt_datatype_to_torch(datatype):
43
  class TrTUnet:
44
  def __init__(self, engine_path):
45
  with open(engine_path, "rb") as f:
46
- engine_bytes = f.read()
47
- self.engine = runtime.deserialize_cuda_engine(engine_bytes)
48
  self.context = self.engine.create_execution_context()
49
- # Default precision – overridden to bfloat16 for Flux in TensorRTLoader
50
- self.dtype = torch.float16
51
 
52
- def __call__(self, x, timesteps, context, y=None,
53
- control=None, transformer_options=None, **kwargs):
 
 
 
 
 
 
 
54
  """
55
- x: [B, C, H, W]
56
- timesteps: [B]
57
- context: [B, T, Ctxt]
58
- y: [B, adm_dim] (SDXL / SD3 / etc.)
59
- Other kwargs (control, transformer_options, guidance, ...) are ignored
60
- at TensorRT level, but must be accepted to match Comfy's callsite.
61
  """
62
 
63
- # Use latent device as canonical device
64
- device = x.device
65
-
66
- # Helper to put everything on the right device / dtype and contiguous
67
- def _prep(t):
68
- if t is None:
69
- return None
70
- return t.to(device=device, dtype=self.dtype).contiguous()
71
-
72
- x = _prep(x)
73
- timesteps = _prep(timesteps)
74
- context = _prep(context)
75
- y = _prep(y)
76
-
77
- # Discover engine IO tensors
78
- tensor_names = [
79
- self.engine.get_tensor_name(i)
80
- for i in range(self.engine.num_io_tensors)
81
- ]
82
- input_names = [
83
- n for n in tensor_names
84
- if self.engine.get_tensor_mode(n) == trt.TensorIOMode.INPUT
85
- ]
86
- output_names = [
87
- n for n in tensor_names
88
- if self.engine.get_tensor_mode(n) == trt.TensorIOMode.OUTPUT
89
- ]
90
-
91
- # Build a dict of available tensors by name
92
- available = {"x": x, "timesteps": timesteps, "context": context}
93
  if y is not None:
94
- available["y"] = y
95
-
96
- # Allow passing extra inputs (e.g. "guidance" for Flux) via kwargs
97
- for k, v in kwargs.items():
98
- if isinstance(v, torch.Tensor):
99
- available[k] = _prep(v)
100
-
101
- # Canonical order, so we never accidentally swap x/timesteps/context/y
102
- canonical_order = {"x": 0, "timesteps": 1, "context": 2, "y": 3}
103
- input_names_sorted = sorted(
104
- input_names,
105
- key=lambda n: canonical_order.get(n, 100),
106
- )
107
-
108
- # Bind all inputs – every engine input must get a valid tensor
109
- for name in input_names_sorted:
110
- if name not in available or available[name] is None:
111
- raise RuntimeError(
112
- f"TensorRT engine expects input '{name}' but no tensor was provided."
113
- )
114
- t = available[name]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  self.context.set_input_shape(name, tuple(t.shape))
116
- self.context.set_tensor_address(name, t.data_ptr())
117
 
118
- # Infer shapes (resolve dynamic dims)
119
  missing = self.context.infer_shapes()
120
  if missing:
121
- raise RuntimeError(
122
- f"TensorRT shape inference failed, unresolved tensors: {missing}"
123
- )
124
-
125
- # Ensure the context has enough device memory for the resolved shapes
126
- self.context.update_device_memory_size_for_shapes()
127
 
128
- # Allocate and bind outputs
129
- outputs = []
 
 
130
  for name in output_names:
131
- out_dims = self.context.get_tensor_shape(name)
132
  out_shape = tuple(int(d) for d in out_dims)
133
- out_dtype = trt_datatype_to_torch(self.engine.get_tensor_dtype(name))
134
- out_tensor = torch.empty(out_shape, device=device, dtype=out_dtype)
135
- self.context.set_tensor_address(name, out_tensor.data_ptr())
136
- outputs.append(out_tensor)
137
 
138
- # Execute on the current PyTorch stream for correct ordering
139
- stream = torch.cuda.current_stream(device).cuda_stream
140
- self.context.execute_async_v3(stream_handle=stream)
141
 
142
- # Comfy's apply_model() will call .float() on this anyway
143
- return outputs[0] if len(outputs) == 1 else tuple(outputs)
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
  def load_state_dict(self, sd, strict=False):
146
- # No-op – weights are inside the TensorRT engine file.
147
- return
148
 
149
  def state_dict(self):
150
  return {}
@@ -152,6 +157,7 @@ class TrTUnet:
152
 
153
 
154
 
 
155
  class TensorRTLoader:
156
  @classmethod
157
  def INPUT_TYPES(s):
 
43
  class TrTUnet:
44
  def __init__(self, engine_path):
45
  with open(engine_path, "rb") as f:
46
+ self.engine = runtime.deserialize_cuda_engine(f.read())
 
47
  self.context = self.engine.create_execution_context()
 
 
48
 
49
+ # Default torch device / dtype for allocations
50
+ self.device = comfy.model_management.get_torch_device()
51
+ self.default_dtype = torch.float16 # fallback if something unknown shows up
52
+
53
+ def _trt_dtype_to_torch(self, trt_dtype):
54
+ dt = trt_datatype_to_torch(trt_dtype)
55
+ return dt if dt is not None else self.default_dtype
56
+
57
+ def __call__(self, x, timesteps, context, y=None, control=None, transformer_options=None, **kwargs):
58
  """
59
+ x : [B, C, H, W]
60
+ timesteps : [B]
61
+ context : [B, N, D]
62
+ y : [B, y_dim] (optional, SDXL etc.)
 
 
63
  """
64
 
65
+ # -----------------------------
66
+ # 1. Build dict of actual inputs
67
+ # -----------------------------
68
+ model_inputs = {
69
+ "x": x,
70
+ "timesteps": timesteps,
71
+ "context": context,
72
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  if y is not None:
74
+ model_inputs["y"] = y
75
+
76
+ # If your engine has extra inputs (e.g. 'guidance' for Flux),
77
+ # they must either come from kwargs or be absent from the engine.
78
+ tensor_names = [self.engine.get_tensor_name(i) for i in range(self.engine.num_io_tensors)]
79
+ input_names = [n for n in tensor_names if self.engine.get_tensor_mode(n) == trt.TensorIOMode.INPUT]
80
+ output_names = [n for n in tensor_names if self.engine.get_tensor_mode(n) == trt.TensorIOMode.OUTPUT]
81
+
82
+ # Fill missing inputs from kwargs if present
83
+ for name in input_names:
84
+ if name in model_inputs:
85
+ continue
86
+ if name in kwargs:
87
+ model_inputs[name] = kwargs[name]
88
+
89
+ if len(model_inputs) != len(input_names):
90
+ missing = [n for n in input_names if n not in model_inputs]
91
+ raise RuntimeError(
92
+ f"TensorRT UNet: missing required inputs for engine: {missing} "
93
+ f"(have {list(model_inputs.keys())})"
94
+ )
95
+
96
+ # -----------------------------
97
+ # 2. Convert each input to engine dtype + bind it
98
+ # -----------------------------
99
+ for name in input_names:
100
+ t = model_inputs[name]
101
+
102
+ # Move to correct device
103
+ if t.device != self.device:
104
+ t = t.to(self.device)
105
+
106
+ # Match TensorRT's expected dtype for this tensor
107
+ trt_dtype = self.engine.get_tensor_dtype(name)
108
+ torch_dtype = self._trt_dtype_to_torch(trt_dtype)
109
+ if t.dtype != torch_dtype:
110
+ t = t.to(dtype=torch_dtype)
111
+
112
+ # Update back (so later code sees the converted tensor if needed)
113
+ model_inputs[name] = t
114
+
115
+ # Set runtime shape and bind memory
116
  self.context.set_input_shape(name, tuple(t.shape))
117
+ self.context.set_tensor_address(name, int(t.data_ptr()))
118
 
119
+ # Make sure all shapes are resolved
120
  missing = self.context.infer_shapes()
121
  if missing:
122
+ raise RuntimeError(f"TensorRT shape inference failed, unresolved tensors: {missing}")
 
 
 
 
 
123
 
124
+ # -----------------------------
125
+ # 3. Allocate & bind outputs
126
+ # -----------------------------
127
+ outputs = {}
128
  for name in output_names:
129
+ out_dims = self.context.get_tensor_shape(name) # trt.Dims
130
  out_shape = tuple(int(d) for d in out_dims)
 
 
 
 
131
 
132
+ trt_dtype = self.engine.get_tensor_dtype(name)
133
+ torch_dtype = self._trt_dtype_to_torch(trt_dtype)
 
134
 
135
+ out_tensor = torch.empty(out_shape, device=self.device, dtype=torch_dtype)
136
+ self.context.set_tensor_address(name, int(out_tensor.data_ptr()))
137
+ outputs[name] = out_tensor
138
+
139
+ # -----------------------------
140
+ # 4. Execute on the current torch CUDA stream
141
+ # -----------------------------
142
+ stream = torch.cuda.current_stream(self.device)
143
+ self.context.execute_async_v3(stream_handle=stream.cuda_stream)
144
+
145
+ # No need to sync explicitly; ComfyUI uses the same default stream.
146
+
147
+ # Return outputs in a stable order
148
+ out_list = [outputs[name] for name in output_names]
149
+ return out_list[0] if len(out_list) == 1 else tuple(out_list)
150
 
151
  def load_state_dict(self, sd, strict=False):
152
+ pass
 
153
 
154
  def state_dict(self):
155
  return {}
 
157
 
158
 
159
 
160
+
161
  class TensorRTLoader:
162
  @classmethod
163
  def INPUT_TYPES(s):