saliacoel commited on
Commit
724b368
·
verified ·
1 Parent(s): 7c128cf

Upload tensorrt_convert.py

Browse files
Files changed (1) hide show
  1. tensorrt_convert.py +80 -24
tensorrt_convert.py CHANGED
@@ -29,7 +29,7 @@ def trtlog(msg: str):
29
  # Opset handling:
30
  # - If COMFY_TRT_ONNX_OPSET is set, use that integer.
31
  # - Otherwise, leave opset_version=None so torch.onnx uses the
32
- # recommended opset for this PyTorch version (e.g. 20 on 2.9).
33
  DEFAULT_ONNX_OPSET = None
34
  _env_opset = os.getenv("COMFY_TRT_ONNX_OPSET")
35
  if _env_opset is not None:
@@ -403,28 +403,78 @@ class TRT_MODEL_CONVERSION_BASE:
403
 
404
  # -----------------------------------------------------------------
405
  # Build dynamic_shapes spec for torch.export / dynamo=True
 
 
406
  # -----------------------------------------------------------------
407
- B = Dim("batch", min=batch_size_min, max=batch_size_max)
408
- H = Dim("height", min=height_min // 8, max=height_max // 8)
409
- W = Dim("width", min=width_min // 8, max=width_max // 8)
410
- T = Dim(
411
- "tokens",
412
- min=context_len_min * context_min,
413
- max=context_len * context_max,
414
- )
415
-
416
- dynamic_shapes = {
417
- "x": {0: B, 2: H, 3: W},
418
- "timesteps": {0: B},
419
- "context": {0: B, 1: T},
420
- }
421
 
422
- if "y" in input_names:
423
- dynamic_shapes["y"] = {0: B}
424
- if "guidance" in input_names:
425
- dynamic_shapes["guidance"] = {0: B}
426
-
427
- trtlog(f"dynamic_shapes spec: {dynamic_shapes}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
428
 
429
  # -----------------------------------------------------------------
430
  # Build example inputs (using OPT shapes)
@@ -440,7 +490,9 @@ class TRT_MODEL_CONVERSION_BASE:
440
  )
441
 
442
  # -----------------------------------------------------------------
443
- # ONNX export with Dynamo + dynamic_shapes
 
 
444
  # -----------------------------------------------------------------
445
  os.makedirs(os.path.dirname(output_onnx), exist_ok=True)
446
 
@@ -449,6 +501,10 @@ class TRT_MODEL_CONVERSION_BASE:
449
  f"opset_version={DEFAULT_ONNX_OPSET}, dtype={dtype}, "
450
  f"output={output_onnx}"
451
  )
 
 
 
 
452
 
453
  try:
454
  torch.onnx.export(
@@ -463,8 +519,7 @@ class TRT_MODEL_CONVERSION_BASE:
463
  dynamic_shapes=dynamic_shapes,
464
  # NOTE:
465
  # - We intentionally do NOT pass dynamic_axes here.
466
- # dynamic_axes is for the legacy TorchScript exporter,
467
- # dynamic_shapes + dynamo=True is the modern path.
468
  )
469
  trtlog("torch.onnx.export completed successfully.")
470
  except Exception as e:
@@ -828,6 +883,7 @@ class STATIC_TRT_MODEL_CONVERSION(TRT_MODEL_CONVERSION_BASE):
828
  context_opt,
829
  num_video_frames,
830
  ):
 
831
  return super()._convert(
832
  model,
833
  filename_prefix,
 
29
  # Opset handling:
30
  # - If COMFY_TRT_ONNX_OPSET is set, use that integer.
31
  # - Otherwise, leave opset_version=None so torch.onnx uses the
32
+ # recommended opset for this PyTorch version (e.g. 20 in 2.9+).
33
  DEFAULT_ONNX_OPSET = None
34
  _env_opset = os.getenv("COMFY_TRT_ONNX_OPSET")
35
  if _env_opset is not None:
 
403
 
404
  # -----------------------------------------------------------------
405
  # Build dynamic_shapes spec for torch.export / dynamo=True
406
+ # - STATIC node: no dynamic_shapes at all (fully static export)
407
+ # - DYNAMIC node: only create Dim if max > min
408
  # -----------------------------------------------------------------
409
+ dynamic_shapes = None
 
 
 
 
 
 
 
 
 
 
 
 
 
410
 
411
+ def _maybe_dim(name: str, min_v: int, max_v: int):
412
+ """Create Dim only if there is real dynamism (max > min)."""
413
+ if max_v < min_v:
414
+ trtlog(
415
+ f"WARNING: Dim {name} has min>{max_v}>{min_v}, swapping to fix."
416
+ )
417
+ min_v, max_v = max_v, min_v
418
+ if max_v > min_v:
419
+ trtlog(f"Dim {name}: dynamic range [{min_v}, {max_v}]")
420
+ return Dim(name, min=min_v, max=max_v)
421
+ else:
422
+ trtlog(f"Dim {name}: static value {min_v}, not using Dim.")
423
+ return None
424
+
425
+ if not is_static:
426
+ # Only build dynamic_shapes for the DYNAMIC node
427
+ B = _maybe_dim("batch", batch_size_min, batch_size_max)
428
+ H = _maybe_dim("height", height_min // 8, height_max // 8)
429
+ W = _maybe_dim("width", width_min // 8, width_max // 8)
430
+ tokens_min = context_len_min * context_min
431
+ tokens_max = context_len * context_max
432
+ T = _maybe_dim("tokens", tokens_min, tokens_max)
433
+
434
+ dynamic_shapes = {}
435
+
436
+ # x: [B, C, H, W]
437
+ x_dyn = {}
438
+ if B is not None:
439
+ x_dyn[0] = B
440
+ if H is not None:
441
+ x_dyn[2] = H
442
+ if W is not None:
443
+ x_dyn[3] = W
444
+ if x_dyn:
445
+ dynamic_shapes["x"] = x_dyn
446
+
447
+ # timesteps: [B]
448
+ if B is not None:
449
+ dynamic_shapes["timesteps"] = {0: B}
450
+
451
+ # context: [B, T, context_dim]
452
+ ctx_dyn = {}
453
+ if B is not None:
454
+ ctx_dyn[0] = B
455
+ if T is not None:
456
+ ctx_dyn[1] = T
457
+ if ctx_dyn:
458
+ dynamic_shapes["context"] = ctx_dyn
459
+
460
+ # y: [B, y_dim]
461
+ if "y" in input_names and B is not None:
462
+ dynamic_shapes["y"] = {0: B}
463
+
464
+ # guidance: [B, ...]
465
+ if "guidance" in input_names and B is not None:
466
+ dynamic_shapes["guidance"] = {0: B}
467
+
468
+ if not dynamic_shapes:
469
+ trtlog(
470
+ "No dimensions are actually dynamic for DYNAMIC node. "
471
+ "Export will effectively be static."
472
+ )
473
+ dynamic_shapes = None
474
+ else:
475
+ trtlog(f"dynamic_shapes spec: {dynamic_shapes}")
476
+ else:
477
+ trtlog("STATIC node: skipping torch.export.Dim and dynamic_shapes entirely.")
478
 
479
  # -----------------------------------------------------------------
480
  # Build example inputs (using OPT shapes)
 
490
  )
491
 
492
  # -----------------------------------------------------------------
493
+ # ONNX export with Dynamo (dynamo=True)
494
+ # - For static: dynamic_shapes=None, so shapes are fully specialized.
495
+ # - For dynamic: dynamic_shapes guides symbolic shapes.
496
  # -----------------------------------------------------------------
497
  os.makedirs(os.path.dirname(output_onnx), exist_ok=True)
498
 
 
501
  f"opset_version={DEFAULT_ONNX_OPSET}, dtype={dtype}, "
502
  f"output={output_onnx}"
503
  )
504
+ if dynamic_shapes is None:
505
+ trtlog("ONNX export will be STATIC (no dynamic_shapes).")
506
+ else:
507
+ trtlog("ONNX export will use dynamic_shapes (see spec above).")
508
 
509
  try:
510
  torch.onnx.export(
 
519
  dynamic_shapes=dynamic_shapes,
520
  # NOTE:
521
  # - We intentionally do NOT pass dynamic_axes here.
522
+ # dynamic_axes is for the legacy TorchScript exporter.
 
523
  )
524
  trtlog("torch.onnx.export completed successfully.")
525
  except Exception as e:
 
883
  context_opt,
884
  num_video_frames,
885
  ):
886
+ # STATIC: all min/opt/max are identical
887
  return super()._convert(
888
  model,
889
  filename_prefix,