saliacoel committed
Commit 7c128cf · verified · 1 Parent(s): 91e9495

Upload tensorrt_convert.py

Files changed (1):
  1. tensorrt_convert.py +325 -122
tensorrt_convert.py CHANGED
@@ -1,17 +1,52 @@
- import torch
- import sys
  import os
+ import sys
  import time
+
+ import torch
  import comfy.model_management

  import tensorrt as trt
  import folder_paths
  from tqdm import tqdm

- # TODO:
- # Make it more generic: less model specific code
+ # -------------------------------------------------------------------------
+ # torch.export dynamic shapes support
+ # -------------------------------------------------------------------------
+ try:
+     from torch.export import Dim
+ except Exception as e:
+     raise RuntimeError(
+         "[TensorRTExport] torch.export.Dim not available. "
+         "Please upgrade PyTorch to >= 2.1 / 2.5+ to use the Dynamo-based "
+         "ONNX exporter with dynamic shapes."
+     ) from e
+
+
+ def trtlog(msg: str):
+     print(f"[TensorRTExport] {msg}", flush=True)
+
+
+ # Opset handling:
+ #  - If COMFY_TRT_ONNX_OPSET is set, use that integer.
+ #  - Otherwise, leave opset_version=None so torch.onnx uses the
+ #    recommended opset for this PyTorch version (e.g. 20 on 2.9).
+ DEFAULT_ONNX_OPSET = None
+ _env_opset = os.getenv("COMFY_TRT_ONNX_OPSET")
+ if _env_opset is not None:
+     try:
+         DEFAULT_ONNX_OPSET = int(_env_opset)
+         trtlog(f"Using opset_version from COMFY_TRT_ONNX_OPSET={DEFAULT_ONNX_OPSET}")
+     except ValueError:
+         trtlog(
+             f"WARNING: invalid COMFY_TRT_ONNX_OPSET={_env_opset!r}, "
+             "falling back to PyTorch recommended opset (None)."
+         )
+         DEFAULT_ONNX_OPSET = None
+

- # add output directory to tensorrt search path
+ # -------------------------------------------------------------------------
+ # Add output directory to TensorRT search path (ComfyUI integration)
+ # -------------------------------------------------------------------------
  if "tensorrt" in folder_paths.folder_names_and_paths:
      folder_paths.folder_names_and_paths["tensorrt"][0].append(
          os.path.join(folder_paths.get_output_directory(), "tensorrt")
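For reference, a minimal standalone sketch (not part of this commit) of the torch.export.Dim / dynamo=True export path the new header code enables; the Toy module, "toy.onnx", and the dimension range are illustrative assumptions, requiring a recent PyTorch (2.5+):

    import torch
    from torch.export import Dim

    class Toy(torch.nn.Module):
        def forward(self, x):
            return x * 2.0

    # Declare dim 0 of "x" as dynamic in [1, 8]; the remaining dims are
    # specialized from the example input below.
    batch = Dim("batch", min=1, max=8)
    torch.onnx.export(
        Toy(),
        (torch.zeros(2, 4),),
        "toy.onnx",
        input_names=["x"],
        output_names=["y"],
        dynamo=True,                       # torch.export-based exporter
        dynamic_shapes={"x": {0: batch}},  # replaces legacy dynamic_axes
    )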
@@ -23,6 +58,10 @@ else:
          {".engine"},
      )

+
+ # -------------------------------------------------------------------------
+ # Progress monitor for TensorRT builds
+ # -------------------------------------------------------------------------
  class TQDMProgressMonitor(trt.IProgressMonitor):
      def __init__(self):
          trt.IProgressMonitor.__init__(self)
@@ -53,8 +92,9 @@ class TQDMProgressMonitor(trt.IProgressMonitor):
                      "parent_phase": parent_phase,
                  }
          except KeyboardInterrupt:
-             # The phase_start callback cannot directly cancel the build, so request the cancellation from within step_complete.
-             _step_result = False
+             # The phase_start callback cannot directly cancel the build,
+             # so request the cancellation from within step_complete.
+             self._step_result = False

      def phase_finish(self, phase_name):
          try:
@@ -78,9 +118,8 @@ class TQDMProgressMonitor(trt.IProgressMonitor):
                          self._active_phases[phase_name]["parent_phase"]
                      ]["tq"].refresh()
                  del self._active_phases[phase_name]
-                 pass
          except KeyboardInterrupt:
-             _step_result = False
+             self._step_result = False

      def step_complete(self, phase_name, step):
          try:
@@ -90,16 +129,22 @@ class TQDMProgressMonitor(trt.IProgressMonitor):
                  )
              return self._step_result
          except KeyboardInterrupt:
-             # There is no need to propagate this exception to TensorRT. We can simply cancel the build.
+             # There is no need to propagate this exception to TensorRT.
+             # We can simply cancel the build.
              return False
-

+
+ # -------------------------------------------------------------------------
+ # Base class for ONNX -> TensorRT conversion
+ # -------------------------------------------------------------------------
  class TRT_MODEL_CONVERSION_BASE:
      def __init__(self):
          self.output_dir = folder_paths.get_output_directory()
          self.temp_dir = folder_paths.get_temp_directory()
          self.timing_cache_path = os.path.normpath(
-             os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "timing_cache.trt"))
+             os.path.join(
+                 os.path.dirname(os.path.realpath(__file__)), "timing_cache.trt"
+             )
          )

      RETURN_TYPES = ()
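The TQDMProgressMonitor above implements TensorRT's build-progress callback interface; a stripped-down monitor (illustrative, not in the commit) shows the contract the builder expects:

    import tensorrt as trt

    class PrintMonitor(trt.IProgressMonitor):
        def __init__(self):
            trt.IProgressMonitor.__init__(self)

        def phase_start(self, phase_name, parent_phase, num_steps):
            print(f"start {phase_name}: {num_steps} steps")

        def step_complete(self, phase_name, step):
            print(f"{phase_name}: step {step}")
            return True  # returning False asks TensorRT to cancel the build

        def phase_finish(self, phase_name):
            print(f"finish {phase_name}")

    # usage: config.progress_monitor = PrintMonitor()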
@@ -117,9 +162,9 @@ class TRT_MODEL_CONVERSION_BASE:
          if os.path.exists(self.timing_cache_path):
              with open(self.timing_cache_path, mode="rb") as timing_cache_file:
                  buffer = timing_cache_file.read()
-                 print("Read {} bytes from timing cache.".format(len(buffer)))
+                 trtlog(f"Read {len(buffer)} bytes from timing cache.")
          else:
-             print("No timing cache found; Initializing a new one.")
+             trtlog("No timing cache found; initializing a new one.")
          timing_cache: trt.ITimingCache = config.create_timing_cache(buffer)
          config.set_timing_cache(timing_cache, ignore_mismatch=True)

@@ -127,7 +172,9 @@ class TRT_MODEL_CONVERSION_BASE:
      def _save_timing_cache(self, config: trt.IBuilderConfig):
          timing_cache: trt.ITimingCache = config.get_timing_cache()
          with open(self.timing_cache_path, "wb") as timing_cache_file:
-             timing_cache_file.write(memoryview(timing_cache.serialize()))
+             serialized = timing_cache.serialize()
+             timing_cache_file.write(memoryview(serialized))
+         trtlog(f"Timing cache saved to {self.timing_cache_path}")

      def _convert(
          self,
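The two timing-cache helpers amount to the round trip below (a sketch using the same TensorRT builder-config API, with hypothetical helper names and path):

    import os
    import tensorrt as trt

    def attach_timing_cache(config: trt.IBuilderConfig, path: str) -> None:
        # Seed the cache from disk if present; an empty buffer creates a new one.
        buffer = b""
        if os.path.exists(path):
            with open(path, "rb") as f:
                buffer = f.read()
        cache = config.create_timing_cache(buffer)
        config.set_timing_cache(cache, ignore_mismatch=True)

    def persist_timing_cache(config: trt.IBuilderConfig, path: str) -> None:
        # Serialize the (possibly updated) cache so later builds skip retiming.
        with open(path, "wb") as f:
            f.write(memoryview(config.get_timing_cache().serialize()))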
@@ -148,15 +195,39 @@ class TRT_MODEL_CONVERSION_BASE:
          num_video_frames,
          is_static: bool,
      ):
+         # -----------------------------------------------------------------
+         # Basic logging: versions & configuration
+         # -----------------------------------------------------------------
+         trtlog(
+             f"PyTorch version: {torch.__version__}, TensorRT version: {trt.__version__}"
+         )
+         trtlog(
+             f"Requested {'STATIC' if is_static else 'DYNAMIC'} TensorRT engine "
+             f"(b=[{batch_size_min},{batch_size_opt},{batch_size_max}], "
+             f"h=[{height_min},{height_opt},{height_max}], "
+             f"w=[{width_min},{width_opt},{width_max}], "
+             f"context=[{context_min},{context_opt},{context_max}], "
+             f"num_video_frames={num_video_frames})"
+         )
+
          output_onnx = os.path.normpath(
              os.path.join(
                  os.path.join(self.temp_dir, "{}".format(time.time())), "model.onnx"
              )
          )
+         trtlog(f"Temporary ONNX path: {output_onnx}")

+         # -----------------------------------------------------------------
+         # Load model to GPU
+         # -----------------------------------------------------------------
          comfy.model_management.unload_all_models()
-         comfy.model_management.load_models_gpu([model], force_patch_weights=True, force_full_load=True)
+         comfy.model_management.load_models_gpu(
+             [model], force_patch_weights=True, force_full_load=True
+         )
+
          unet = model.model.diffusion_model
+         model_type = type(model.model).__name__
+         trtlog(f"Detected model type: {model_type}")

          context_dim = model.model.model_config.unet_config.get("context_dim", None)
          context_len = 77
@@ -165,149 +236,265 @@ class TRT_MODEL_CONVERSION_BASE:
          extra_input = {}
          dtype = torch.float16

-         if isinstance(model.model, comfy.model_base.SD3): #SD3
-             context_embedder_config = model.model.model_config.unet_config.get("context_embedder_config", None)
+         # -----------------------------------------------------------------
+         # Model-type specific tweaks
+         # -----------------------------------------------------------------
+         if isinstance(model.model, comfy.model_base.SD3):  # SD3
+             context_embedder_config = model.model.model_config.unet_config.get(
+                 "context_embedder_config", None
+             )
              if context_embedder_config is not None:
-                 context_dim = context_embedder_config.get("params", {}).get("in_features", None)
-                 context_len = 154 #NOTE: SD3 can have 77 or 154 depending on which text encoders are used, this is why context_len_min stays 77
+                 context_dim = context_embedder_config.get(
+                     "params", {}
+                 ).get("in_features", None)
+                 # SD3 can have 77 or 154 depending on TE usage
+                 context_len = 154
+                 trtlog(f"SD3 context_dim={context_dim}, context_len={context_len}")
          elif isinstance(model.model, comfy.model_base.AuraFlow):
              context_dim = 2048
              context_len_min = 256
              context_len = 256
+             trtlog(
+                 f"AuraFlow context_dim={context_dim}, "
+                 f"context_len_min={context_len_min}, context_len={context_len}"
+             )
          elif isinstance(model.model, comfy.model_base.Flux):
-             context_dim = model.model.model_config.unet_config.get("context_in_dim", None)
+             context_dim = model.model.model_config.unet_config.get(
+                 "context_in_dim", None
+             )
              context_len_min = 256
              context_len = 256
              y_dim = model.model.model_config.unet_config.get("vec_in_dim", None)
              extra_input = {"guidance": ()}
              dtype = torch.bfloat16
+             trtlog(
+                 f"Flux context_dim={context_dim}, y_dim={y_dim}, "
+                 f"context_len_min={context_len_min}, context_len={context_len}, "
+                 f"extra_input={list(extra_input.keys())}, dtype={dtype}"
+             )

-         if context_dim is not None:
-             input_names = ["x", "timesteps", "context"]
-             output_names = ["h"]
-
-             dynamic_axes = {
-                 "x": {0: "batch", 2: "height", 3: "width"},
-                 "timesteps": {0: "batch"},
-                 "context": {0: "batch", 1: "num_embeds"},
-             }
-
-             transformer_options = model.model_options['transformer_options'].copy()
-             if model.model.model_config.unet_config.get(
-                 "use_temporal_resblock", False
-             ): # SVD
-                 batch_size_min = num_video_frames * batch_size_min
-                 batch_size_opt = num_video_frames * batch_size_opt
-                 batch_size_max = num_video_frames * batch_size_max
-
-                 class UNET(torch.nn.Module):
-                     def forward(self, x, timesteps, context, y):
-                         return self.unet(
-                             x,
-                             timesteps,
-                             context,
-                             y,
-                             num_video_frames=self.num_video_frames,
-                             transformer_options=self.transformer_options,
-                         )
-
-                 svd_unet = UNET()
-                 svd_unet.num_video_frames = num_video_frames
-                 svd_unet.unet = unet
-                 svd_unet.transformer_options = transformer_options
-                 unet = svd_unet
-                 context_len_min = context_len = 1
-             else:
-                 class UNET(torch.nn.Module):
-                     def forward(self, x, timesteps, context, *args):
-                         extras = input_names[3:]
-                         extra_args = {}
-                         for i in range(len(extras)):
-                             extra_args[extras[i]] = args[i]
-                         return self.unet(x, timesteps, context, transformer_options=self.transformer_options, **extra_args)
-
-                 _unet = UNET()
-                 _unet.unet = unet
-                 _unet.transformer_options = transformer_options
-                 unet = _unet
-
-             input_channels = model.model.model_config.unet_config.get("in_channels", 4)
-
-             inputs_shapes_min = (
-                 (batch_size_min, input_channels, height_min // 8, width_min // 8),
-                 (batch_size_min,),
-                 (batch_size_min, context_len_min * context_min, context_dim),
+         if context_dim is None:
+             print("ERROR: model not supported (no context_dim).")
+             comfy.model_management.unload_all_models()
+             comfy.model_management.soft_empty_cache()
+             return ()
+
+         input_names = ["x", "timesteps", "context"]
+         output_names = ["h"]
+
+         transformer_options = model.model_options["transformer_options"].copy()
+         use_temporal = model.model.model_config.unet_config.get(
+             "use_temporal_resblock", False
+         )
+
+         # -----------------------------------------------------------------
+         # Wrap UNet so argument names are stable for dynamic_shapes
+         # -----------------------------------------------------------------
+         if use_temporal:  # SVD
+             trtlog("Model uses temporal resblock (SVD-like). Adjusting batch sizes.")
+             batch_size_min = num_video_frames * batch_size_min
+             batch_size_opt = num_video_frames * batch_size_opt
+             batch_size_max = num_video_frames * batch_size_max
+
+             class SVD_UNET(torch.nn.Module):
+                 def __init__(self, unet, transformer_options, num_video_frames):
+                     super().__init__()
+                     self.unet = unet
+                     self.transformer_options = transformer_options
+                     self.num_video_frames = num_video_frames
+
+                 def forward(self, x, timesteps, context, y):
+                     return self.unet(
+                         x,
+                         timesteps,
+                         context,
+                         y,
+                         num_video_frames=self.num_video_frames,
+                         transformer_options=self.transformer_options,
+                     )
+
+             unet = SVD_UNET(unet, transformer_options, num_video_frames)
+             context_len_min = context_len = 1
+             trtlog(
+                 f"SVD adjusted batch: "
+                 f"b=[{batch_size_min},{batch_size_opt},{batch_size_max}], "
+                 f"context_len_min={context_len_min}, context_len={context_len}"
          )
-             inputs_shapes_opt = (
-                 (batch_size_opt, input_channels, height_opt // 8, width_opt // 8),
-                 (batch_size_opt,),
-                 (batch_size_opt, context_len * context_opt, context_dim),
+
+         else:
+             # Generic wrapper with named extras (y, guidance)
+             extra_keys = list(extra_input.keys())
+
+             class UNET(torch.nn.Module):
+                 def __init__(self, unet, transformer_options, y_dim, extra_keys):
+                     super().__init__()
+                     self.unet = unet
+                     self.transformer_options = transformer_options
+                     self.y_dim = y_dim
+                     self.extra_keys = extra_keys
+
+                 def forward(self, x, timesteps, context, y=None, guidance=None):
+                     extra_args = {}
+                     if self.y_dim is not None and self.y_dim > 0 and y is not None:
+                         extra_args["y"] = y
+                     if "guidance" in self.extra_keys and guidance is not None:
+                         extra_args["guidance"] = guidance
+
+                     return self.unet(
+                         x,
+                         timesteps,
+                         context,
+                         transformer_options=self.transformer_options,
+                         **extra_args,
+                     )
+
+             unet = UNET(unet, transformer_options, y_dim, extra_keys)
+
+         # -----------------------------------------------------------------
+         # Compute input shapes (min / opt / max)
+         # -----------------------------------------------------------------
+         input_channels = model.model.model_config.unet_config.get("in_channels", 4)
+
+         inputs_shapes_min = (
+             (batch_size_min, input_channels, height_min // 8, width_min // 8),
+             (batch_size_min,),
+             (batch_size_min, context_len_min * context_min, context_dim),
+         )
+         inputs_shapes_opt = (
+             (batch_size_opt, input_channels, height_opt // 8, width_opt // 8),
+             (batch_size_opt,),
+             (batch_size_opt, context_len * context_opt, context_dim),
+         )
+         inputs_shapes_max = (
+             (batch_size_max, input_channels, height_max // 8, width_max // 8),
+             (batch_size_max,),
+             (batch_size_max, context_len * context_max, context_dim),
+         )
+
+         if y_dim is not None and y_dim > 0:
+             input_names.append("y")
+             inputs_shapes_min += ((batch_size_min, y_dim),)
+             inputs_shapes_opt += ((batch_size_opt, y_dim),)
+             inputs_shapes_max += ((batch_size_max, y_dim),)
+
+         # Extra inputs (currently used for Flux guidance)
+         for k in extra_input:
+             input_names.append(k)
+             shape_suffix = extra_input[k]  # e.g. () for scalar per batch
+             inputs_shapes_min += ((batch_size_min,) + shape_suffix,)
+             inputs_shapes_opt += ((batch_size_opt,) + shape_suffix,)
+             inputs_shapes_max += ((batch_size_max,) + shape_suffix,)
+
+         # Clamp context ranges sanely if the UI somehow passed inverted min/max
+         if context_max < context_min:
+             trtlog(
+                 f"WARNING: context_max({context_max}) < context_min({context_min}), swapping."
          )
-             inputs_shapes_max = (
-                 (batch_size_max, input_channels, height_max // 8, width_max // 8),
-                 (batch_size_max,),
-                 (batch_size_max, context_len * context_max, context_dim),
+             context_min, context_max = context_max, context_min
+
+         trtlog("Input names: " + ", ".join(input_names))
+         for idx, name in enumerate(input_names):
+             trtlog(
+                 f"  {name}: "
+                 f"min={inputs_shapes_min[idx]}, "
+                 f"opt={inputs_shapes_opt[idx]}, "
+                 f"max={inputs_shapes_max[idx]}"
          )

-             if y_dim > 0:
-                 input_names.append("y")
-                 dynamic_axes["y"] = {0: "batch"}
-                 inputs_shapes_min += ((batch_size_min, y_dim),)
-                 inputs_shapes_opt += ((batch_size_opt, y_dim),)
-                 inputs_shapes_max += ((batch_size_max, y_dim),)
-
-             for k in extra_input:
-                 input_names.append(k)
-                 dynamic_axes[k] = {0: "batch"}
-                 inputs_shapes_min += ((batch_size_min,) + extra_input[k],)
-                 inputs_shapes_opt += ((batch_size_opt,) + extra_input[k],)
-                 inputs_shapes_max += ((batch_size_max,) + extra_input[k],)
-
-
-             inputs = ()
-             for shape in inputs_shapes_opt:
-                 inputs += (
-                     torch.zeros(
-                         shape,
-                         device=comfy.model_management.get_torch_device(),
-                         dtype=dtype,
-                     ),
-                 )
+         # -----------------------------------------------------------------
+         # Build dynamic_shapes spec for torch.export / dynamo=True
+         # -----------------------------------------------------------------
+         B = Dim("batch", min=batch_size_min, max=batch_size_max)
+         H = Dim("height", min=height_min // 8, max=height_max // 8)
+         W = Dim("width", min=width_min // 8, max=width_max // 8)
+         T = Dim(
+             "tokens",
+             min=context_len_min * context_min,
+             max=context_len * context_max,
+         )

-             else:
-                 print("ERROR: model not supported.")
-                 return ()
+         dynamic_shapes = {
+             "x": {0: B, 2: H, 3: W},
+             "timesteps": {0: B},
+             "context": {0: B, 1: T},
+         }
+
+         if "y" in input_names:
+             dynamic_shapes["y"] = {0: B}
+         if "guidance" in input_names:
+             dynamic_shapes["guidance"] = {0: B}
+
+         trtlog(f"dynamic_shapes spec: {dynamic_shapes}")
+
+         # -----------------------------------------------------------------
+         # Build example inputs (using OPT shapes)
+         # -----------------------------------------------------------------
+         inputs = ()
+         for shape in inputs_shapes_opt:
+             inputs += (
+                 torch.zeros(
+                     shape,
+                     device=comfy.model_management.get_torch_device(),
+                     dtype=dtype,
+                 ),
+             )

+         # -----------------------------------------------------------------
+         # ONNX export with Dynamo + dynamic_shapes
+         # -----------------------------------------------------------------
          os.makedirs(os.path.dirname(output_onnx), exist_ok=True)
-         torch.onnx.export(
-             unet,
-             inputs,
-             output_onnx,
-             verbose=False,
-             input_names=input_names,
-             output_names=output_names,
-             opset_version=17,
-             dynamic_axes=False,
-             dynamo=False,
+
+         trtlog(
+             f"Exporting UNet to ONNX with dynamo=True, "
+             f"opset_version={DEFAULT_ONNX_OPSET}, dtype={dtype}, "
+             f"output={output_onnx}"
          )

+         try:
+             torch.onnx.export(
+                 unet,
+                 inputs,
+                 output_onnx,
+                 verbose=False,
+                 input_names=input_names,
+                 output_names=output_names,
+                 opset_version=DEFAULT_ONNX_OPSET,
+                 dynamo=True,
+                 dynamic_shapes=dynamic_shapes,
+                 # NOTE:
+                 #  - We intentionally do NOT pass dynamic_axes here.
+                 #    dynamic_axes is for the legacy TorchScript exporter,
+                 #    dynamic_shapes + dynamo=True is the modern path.
+             )
+             trtlog("torch.onnx.export completed successfully.")
+         except Exception as e:
+             trtlog(f"ERROR during torch.onnx.export: {e}")
+             # Clean up GPU state before re-raising
+             comfy.model_management.unload_all_models()
+             comfy.model_management.soft_empty_cache()
+             raise
+
          comfy.model_management.unload_all_models()
          comfy.model_management.soft_empty_cache()

-         # TRT conversion starts here
+         # -----------------------------------------------------------------
+         # TensorRT conversion starts here
+         # -----------------------------------------------------------------
          logger = trt.Logger(trt.Logger.INFO)
          builder = trt.Builder(logger)
+         trtlog("Created TensorRT builder.")

          network = builder.create_network(
              1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
          )
          parser = trt.OnnxParser(network, logger)
+         trtlog(f"Parsing ONNX file: {output_onnx}")
          success = parser.parse_from_file(output_onnx)
          for idx in range(parser.num_errors):
              print(parser.get_error(idx))

          if not success:
-             print("ONNX load ERROR")
+             print("ONNX load ERROR (TensorRT parser.parse_from_file returned False).")
              return ()

          config = builder.create_builder_config()
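To make the min/opt/max shape tuples above concrete, a worked example with hypothetical node settings (the numbers are illustrative, not from the commit):

    # SDXL-like assumptions: context_dim=2048, 77-token context, 4 latent channels.
    batch_size_opt = 1
    height_opt, width_opt = 1024, 1024
    context_opt = 1
    context_len = 77
    context_dim, input_channels = 2048, 4

    inputs_shapes_opt = (
        (batch_size_opt, input_channels, height_opt // 8, width_opt // 8),
        (batch_size_opt,),
        (batch_size_opt, context_len * context_opt, context_dim),
    )
    print(inputs_shapes_opt)
    # ((1, 4, 128, 128), (1,), (1, 77, 2048))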
@@ -315,22 +502,28 @@ class TRT_MODEL_CONVERSION_BASE:
          self._setup_timing_cache(config)
          config.progress_monitor = TQDMProgressMonitor()

+         trtlog("Creating optimization profile:")
          prefix_encode = ""
          for k in range(len(input_names)):
              min_shape = inputs_shapes_min[k]
              opt_shape = inputs_shapes_opt[k]
              max_shape = inputs_shapes_max[k]
+             trtlog(
+                 f"  {input_names[k]}: min={min_shape}, opt={opt_shape}, max={max_shape}"
+             )
              profile.set_shape(input_names[k], min_shape, opt_shape, max_shape)

              # Encode shapes to filename
-             encode = lambda a: ".".join(map(lambda x: str(x), a))
+             encode = lambda a: ".".join(map(str, a))
              prefix_encode += "{}#{}#{}#{};".format(
                  input_names[k], encode(min_shape), encode(opt_shape), encode(max_shape)
              )

          if dtype == torch.float16:
+             trtlog("Enabling FP16 mode in TensorRT builder config.")
              config.set_flag(trt.BuilderFlag.FP16)
          if dtype == torch.bfloat16:
+             trtlog("Enabling BF16 mode in TensorRT builder config.")
              config.set_flag(trt.BuilderFlag.BF16)

          config.add_optimization_profile(profile)
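The encode lambda above flattens each shape triple into the engine filename prefix; for one hypothetical input it produces:

    encode = lambda a: ".".join(map(str, a))
    min_shape, opt_shape, max_shape = (1, 4, 64, 64), (1, 4, 128, 128), (4, 4, 192, 192)
    print("{}#{}#{}#{};".format("x", encode(min_shape), encode(opt_shape), encode(max_shape)))
    # x#1.4.64.64#1.4.128.128#4.4.192.192;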
@@ -372,7 +565,11 @@ class TRT_MODEL_CONVERSION_BASE:
              ),
          )

+         trtlog("Building serialized TensorRT engine. This may take a while...")
          serialized_engine = builder.build_serialized_network(network, config)
+         if serialized_engine is None:
+             trtlog("ERROR: builder.build_serialized_network returned None.")
+             return ()

          full_output_folder, filename, counter, subfolder, filename_prefix = (
              folder_paths.get_save_image_path(filename_prefix, self.output_dir)
@@ -381,14 +578,20 @@ class TRT_MODEL_CONVERSION_BASE:
              full_output_folder, f"{filename}_{counter:05}_.engine"
          )

+         trtlog(f"Writing TensorRT engine to: {output_trt_engine}")
+         os.makedirs(full_output_folder, exist_ok=True)
          with open(output_trt_engine, "wb") as f:
              f.write(serialized_engine)

          self._save_timing_cache(config)
+         trtlog("TensorRT conversion complete.")

          return ()


+ # -------------------------------------------------------------------------
+ # Dynamic / Static wrapper nodes
+ # -------------------------------------------------------------------------
  class DYNAMIC_TRT_MODEL_CONVERSION(TRT_MODEL_CONVERSION_BASE):
      def __init__(self):
          super(DYNAMIC_TRT_MODEL_CONVERSION, self).__init__()
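The saved .engine is consumed elsewhere; loading it back follows the standard TensorRT runtime pattern (a sketch with a hypothetical path, assuming the same TensorRT version that built the engine):

    import tensorrt as trt

    logger = trt.Logger(trt.Logger.INFO)
    runtime = trt.Runtime(logger)
    with open("output/tensorrt/model_00001_.engine", "rb") as f:  # hypothetical path
        engine = runtime.deserialize_cuda_engine(f.read())
    context = engine.create_execution_context()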