Upload tensorrt_convert.py
Browse files- tensorrt_convert.py +80 -24
tensorrt_convert.py
CHANGED
|
@@ -29,7 +29,7 @@ def trtlog(msg: str):
|
|
| 29 |
# Opset handling:
|
| 30 |
# - If COMFY_TRT_ONNX_OPSET is set, use that integer.
|
| 31 |
# - Otherwise, leave opset_version=None so torch.onnx uses the
|
| 32 |
-
# recommended opset for this PyTorch version (e.g. 20
|
| 33 |
DEFAULT_ONNX_OPSET = None
|
| 34 |
_env_opset = os.getenv("COMFY_TRT_ONNX_OPSET")
|
| 35 |
if _env_opset is not None:
|
|
@@ -403,28 +403,78 @@ class TRT_MODEL_CONVERSION_BASE:
|
|
| 403 |
|
| 404 |
# -----------------------------------------------------------------
|
| 405 |
# Build dynamic_shapes spec for torch.export / dynamo=True
|
|
|
|
|
|
|
| 406 |
# -----------------------------------------------------------------
|
| 407 |
-
|
| 408 |
-
H = Dim("height", min=height_min // 8, max=height_max // 8)
|
| 409 |
-
W = Dim("width", min=width_min // 8, max=width_max // 8)
|
| 410 |
-
T = Dim(
|
| 411 |
-
"tokens",
|
| 412 |
-
min=context_len_min * context_min,
|
| 413 |
-
max=context_len * context_max,
|
| 414 |
-
)
|
| 415 |
-
|
| 416 |
-
dynamic_shapes = {
|
| 417 |
-
"x": {0: B, 2: H, 3: W},
|
| 418 |
-
"timesteps": {0: B},
|
| 419 |
-
"context": {0: B, 1: T},
|
| 420 |
-
}
|
| 421 |
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 428 |
|
| 429 |
# -----------------------------------------------------------------
|
| 430 |
# Build example inputs (using OPT shapes)
|
|
@@ -440,7 +490,9 @@ class TRT_MODEL_CONVERSION_BASE:
|
|
| 440 |
)
|
| 441 |
|
| 442 |
# -----------------------------------------------------------------
|
| 443 |
-
# ONNX export with Dynamo
|
|
|
|
|
|
|
| 444 |
# -----------------------------------------------------------------
|
| 445 |
os.makedirs(os.path.dirname(output_onnx), exist_ok=True)
|
| 446 |
|
|
@@ -449,6 +501,10 @@ class TRT_MODEL_CONVERSION_BASE:
|
|
| 449 |
f"opset_version={DEFAULT_ONNX_OPSET}, dtype={dtype}, "
|
| 450 |
f"output={output_onnx}"
|
| 451 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 452 |
|
| 453 |
try:
|
| 454 |
torch.onnx.export(
|
|
@@ -463,8 +519,7 @@ class TRT_MODEL_CONVERSION_BASE:
|
|
| 463 |
dynamic_shapes=dynamic_shapes,
|
| 464 |
# NOTE:
|
| 465 |
# - We intentionally do NOT pass dynamic_axes here.
|
| 466 |
-
# dynamic_axes is for the legacy TorchScript exporter
|
| 467 |
-
# dynamic_shapes + dynamo=True is the modern path.
|
| 468 |
)
|
| 469 |
trtlog("torch.onnx.export completed successfully.")
|
| 470 |
except Exception as e:
|
|
@@ -828,6 +883,7 @@ class STATIC_TRT_MODEL_CONVERSION(TRT_MODEL_CONVERSION_BASE):
|
|
| 828 |
context_opt,
|
| 829 |
num_video_frames,
|
| 830 |
):
|
|
|
|
| 831 |
return super()._convert(
|
| 832 |
model,
|
| 833 |
filename_prefix,
|
|
|
|
| 29 |
# Opset handling:
|
| 30 |
# - If COMFY_TRT_ONNX_OPSET is set, use that integer.
|
| 31 |
# - Otherwise, leave opset_version=None so torch.onnx uses the
|
| 32 |
+
# recommended opset for this PyTorch version (e.g. 20 in 2.9+).
|
| 33 |
DEFAULT_ONNX_OPSET = None
|
| 34 |
_env_opset = os.getenv("COMFY_TRT_ONNX_OPSET")
|
| 35 |
if _env_opset is not None:
|
|
|
|
| 403 |
|
| 404 |
# -----------------------------------------------------------------
|
| 405 |
# Build dynamic_shapes spec for torch.export / dynamo=True
|
| 406 |
+
# - STATIC node: no dynamic_shapes at all (fully static export)
|
| 407 |
+
# - DYNAMIC node: only create Dim if max > min
|
| 408 |
# -----------------------------------------------------------------
|
| 409 |
+
dynamic_shapes = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 410 |
|
| 411 |
+
def _maybe_dim(name: str, min_v: int, max_v: int):
|
| 412 |
+
"""Create Dim only if there is real dynamism (max > min)."""
|
| 413 |
+
if max_v < min_v:
|
| 414 |
+
trtlog(
|
| 415 |
+
f"WARNING: Dim {name} has min>{max_v}>{min_v}, swapping to fix."
|
| 416 |
+
)
|
| 417 |
+
min_v, max_v = max_v, min_v
|
| 418 |
+
if max_v > min_v:
|
| 419 |
+
trtlog(f"Dim {name}: dynamic range [{min_v}, {max_v}]")
|
| 420 |
+
return Dim(name, min=min_v, max=max_v)
|
| 421 |
+
else:
|
| 422 |
+
trtlog(f"Dim {name}: static value {min_v}, not using Dim.")
|
| 423 |
+
return None
|
| 424 |
+
|
| 425 |
+
if not is_static:
|
| 426 |
+
# Only build dynamic_shapes for the DYNAMIC node
|
| 427 |
+
B = _maybe_dim("batch", batch_size_min, batch_size_max)
|
| 428 |
+
H = _maybe_dim("height", height_min // 8, height_max // 8)
|
| 429 |
+
W = _maybe_dim("width", width_min // 8, width_max // 8)
|
| 430 |
+
tokens_min = context_len_min * context_min
|
| 431 |
+
tokens_max = context_len * context_max
|
| 432 |
+
T = _maybe_dim("tokens", tokens_min, tokens_max)
|
| 433 |
+
|
| 434 |
+
dynamic_shapes = {}
|
| 435 |
+
|
| 436 |
+
# x: [B, C, H, W]
|
| 437 |
+
x_dyn = {}
|
| 438 |
+
if B is not None:
|
| 439 |
+
x_dyn[0] = B
|
| 440 |
+
if H is not None:
|
| 441 |
+
x_dyn[2] = H
|
| 442 |
+
if W is not None:
|
| 443 |
+
x_dyn[3] = W
|
| 444 |
+
if x_dyn:
|
| 445 |
+
dynamic_shapes["x"] = x_dyn
|
| 446 |
+
|
| 447 |
+
# timesteps: [B]
|
| 448 |
+
if B is not None:
|
| 449 |
+
dynamic_shapes["timesteps"] = {0: B}
|
| 450 |
+
|
| 451 |
+
# context: [B, T, context_dim]
|
| 452 |
+
ctx_dyn = {}
|
| 453 |
+
if B is not None:
|
| 454 |
+
ctx_dyn[0] = B
|
| 455 |
+
if T is not None:
|
| 456 |
+
ctx_dyn[1] = T
|
| 457 |
+
if ctx_dyn:
|
| 458 |
+
dynamic_shapes["context"] = ctx_dyn
|
| 459 |
+
|
| 460 |
+
# y: [B, y_dim]
|
| 461 |
+
if "y" in input_names and B is not None:
|
| 462 |
+
dynamic_shapes["y"] = {0: B}
|
| 463 |
+
|
| 464 |
+
# guidance: [B, ...]
|
| 465 |
+
if "guidance" in input_names and B is not None:
|
| 466 |
+
dynamic_shapes["guidance"] = {0: B}
|
| 467 |
+
|
| 468 |
+
if not dynamic_shapes:
|
| 469 |
+
trtlog(
|
| 470 |
+
"No dimensions are actually dynamic for DYNAMIC node. "
|
| 471 |
+
"Export will effectively be static."
|
| 472 |
+
)
|
| 473 |
+
dynamic_shapes = None
|
| 474 |
+
else:
|
| 475 |
+
trtlog(f"dynamic_shapes spec: {dynamic_shapes}")
|
| 476 |
+
else:
|
| 477 |
+
trtlog("STATIC node: skipping torch.export.Dim and dynamic_shapes entirely.")
|
| 478 |
|
| 479 |
# -----------------------------------------------------------------
|
| 480 |
# Build example inputs (using OPT shapes)
|
|
|
|
| 490 |
)
|
| 491 |
|
| 492 |
# -----------------------------------------------------------------
|
| 493 |
+
# ONNX export with Dynamo (dynamo=True)
|
| 494 |
+
# - For static: dynamic_shapes=None, so shapes are fully specialized.
|
| 495 |
+
# - For dynamic: dynamic_shapes guides symbolic shapes.
|
| 496 |
# -----------------------------------------------------------------
|
| 497 |
os.makedirs(os.path.dirname(output_onnx), exist_ok=True)
|
| 498 |
|
|
|
|
| 501 |
f"opset_version={DEFAULT_ONNX_OPSET}, dtype={dtype}, "
|
| 502 |
f"output={output_onnx}"
|
| 503 |
)
|
| 504 |
+
if dynamic_shapes is None:
|
| 505 |
+
trtlog("ONNX export will be STATIC (no dynamic_shapes).")
|
| 506 |
+
else:
|
| 507 |
+
trtlog("ONNX export will use dynamic_shapes (see spec above).")
|
| 508 |
|
| 509 |
try:
|
| 510 |
torch.onnx.export(
|
|
|
|
| 519 |
dynamic_shapes=dynamic_shapes,
|
| 520 |
# NOTE:
|
| 521 |
# - We intentionally do NOT pass dynamic_axes here.
|
| 522 |
+
# dynamic_axes is for the legacy TorchScript exporter.
|
|
|
|
| 523 |
)
|
| 524 |
trtlog("torch.onnx.export completed successfully.")
|
| 525 |
except Exception as e:
|
|
|
|
| 883 |
context_opt,
|
| 884 |
num_video_frames,
|
| 885 |
):
|
| 886 |
+
# STATIC: all min/opt/max are identical
|
| 887 |
return super()._convert(
|
| 888 |
model,
|
| 889 |
filename_prefix,
|