File size: 18,341 Bytes
af83d87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
#!/usr/bin/env python3

# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Build TensorRT engines from exported ONNX models.

Supports two modes:
- single: Build an engine for a single ONNX model
- full_pipeline: Build engines for all pipeline components
  (ViT, LLM, VL Self-Attention, State Encoder, Action Encoder, DiT, Action Decoder)

Shape profiles are automatically derived from the ONNX models.

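Usage:
    # Single model (--onnx and --engine are required):
    python scripts/deployment/build_tensorrt_engine.py \
        --mode single \
        --onnx ./gr00t_n1d7_onnx/dit_bf16.onnx \
        --engine ./gr00t_n1d7_engines/dit_bf16.engine \
        --precision bf16

    # Full pipeline:
    python scripts/deployment/build_tensorrt_engine.py \
        --mode full_pipeline \
        --onnx-dir ./gr00t_n1d7_onnx \
        --engine-dir ./gr00t_n1d7_engines \
        --precision bf16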
"""

from dataclasses import dataclass
import json
import logging
import os
import time
from typing import Literal

import onnx
import tensorrt as trt
import tyro


# Configure root logging once, at import time, so every build step is
# reported with a timestamp and its severity level.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


# ============================================================
# Auto Shape Profile from ONNX
# ============================================================


def derive_shapes_from_onnx(onnx_path, max_batch=8, max_seq_len=512):
    """Read an ONNX model's graph inputs and derive min/opt/max shape profiles.

    Only the graph is loaded (``load_external_data=False``), so external weight
    files are not read.

    For each graph input:
    - Fixed dimensions (``dim_value > 0``) are kept as-is across min/opt/max.
    - A dynamic dimension at index 0 is treated as the batch dimension:
      min=1, opt=1, max=``max_batch``.
    - Every other dynamic dimension gets min=1, opt=1, max=``max_seq_len``.
      The ONNX graph alone carries no "typical" size for these, so the
      optimal shape is the smallest one. Use ``derive_shapes_with_hint`` when
      realistic optimal sizes are known.

    Args:
        onnx_path: Path to the ONNX model.
        max_batch: Upper bound for a dynamic leading (batch) dimension.
        max_seq_len: Upper bound for any other dynamic dimension. The default
            of 512 matches the value this function always used.

    Returns:
        ``(min_shapes, opt_shapes, max_shapes)``: three dicts mapping each
        graph-input name to a shape tuple.
    """
    model = onnx.load(onnx_path, load_external_data=False)

    min_shapes, opt_shapes, max_shapes = {}, {}, {}

    for inp in model.graph.input:
        min_shape, opt_shape, max_shape = [], [], []
        for i, d in enumerate(inp.type.tensor_type.shape.dim):
            if d.dim_value > 0:
                # Fixed dimension: use as-is.
                lo = opt = hi = d.dim_value
            elif i == 0:
                # Dynamic leading dimension, assumed to be the batch.
                lo, opt, hi = 1, 1, max_batch
            else:
                # Dynamic sequence/spatial dimension. There is no size hint
                # available here, so optimize for the smallest shape and cap
                # the range at max_seq_len.
                lo, opt, hi = 1, 1, max_seq_len
            min_shape.append(lo)
            opt_shape.append(opt)
            max_shape.append(hi)

        min_shapes[inp.name] = tuple(min_shape)
        opt_shapes[inp.name] = tuple(opt_shape)
        max_shapes[inp.name] = tuple(max_shape)

    return min_shapes, opt_shapes, max_shapes


def derive_shapes_with_hint(onnx_path, opt_seq_lens=None, max_batch=8):
    """Derive min/opt/max shape profiles for every ONNX graph input.

    Each dynamic dimension is resolved by its symbolic name (``dim_param``),
    or by ``"dim_<axis>"`` when the dimension has no name:

    - fixed dimension (``dim_value > 0``): used unchanged for min/opt/max;
    - ``"batch_size"`` (at any axis): 1 / 1 / ``max_batch``;
    - name present in ``opt_seq_lens``: 1 / hint / ``max(2 * hint, hint + 64)``;
    - any other dynamic dimension: 1 / 256 / 512.

    Args:
        onnx_path: Path to the ONNX model (external weight data is not loaded).
        opt_seq_lens: Optional mapping of dynamic dim names to optimal sizes,
            e.g. {"sa_seq_len": 51, "vl_seq_len": 280, "sequence_length": 280}.
        max_batch: Maximum batch size.

    Returns:
        ``(min_shapes, opt_shapes, max_shapes)``: dicts mapping each graph-input
        name to a shape tuple.
    """
    graph = onnx.load(onnx_path, load_external_data=False).graph
    hints = opt_seq_lens or {}

    def bounds_for(axis, dim):
        # Returns the (min, opt, max) triple for one dimension.
        if dim.dim_value > 0:
            return dim.dim_value, dim.dim_value, dim.dim_value
        label = dim.dim_param or f"dim_{axis}"
        if label == "batch_size":
            return 1, 1, max_batch
        if label in hints:
            best = hints[label]
            return 1, best, max(best * 2, best + 64)
        # Unknown dynamic dimension: fall back to a wide range.
        return 1, 256, 512

    min_shapes, opt_shapes, max_shapes = {}, {}, {}
    for graph_input in graph.input:
        per_axis = [
            bounds_for(axis, dim)
            for axis, dim in enumerate(graph_input.type.tensor_type.shape.dim)
        ]
        min_shapes[graph_input.name] = tuple(triple[0] for triple in per_axis)
        opt_shapes[graph_input.name] = tuple(triple[1] for triple in per_axis)
        max_shapes[graph_input.name] = tuple(triple[2] for triple in per_axis)

    return min_shapes, opt_shapes, max_shapes


# ============================================================
# Engine Builder
# ============================================================


def build_engine(
    onnx_path: str,
    engine_path: str,
    precision: str = "bf16",
    workspace_mb: int = 8192,
    min_shapes: dict = None,
    opt_shapes: dict = None,
    max_shapes: dict = None,
    trt_severity=None,
):
    """Build TensorRT engine from ONNX model.

    Args:
        onnx_path: Path to ONNX model
        engine_path: Path to save TensorRT engine
        precision: Precision mode ('fp32', 'fp16', 'bf16', 'fp8')
        workspace_mb: Workspace size in MB
        min_shapes: Minimum input shapes (dict: name -> shape tuple)
        opt_shapes: Optimal input shapes (dict: name -> shape tuple)
        max_shapes: Maximum input shapes (dict: name -> shape tuple)
    """
    logger.info("=" * 80)
    logger.info("TensorRT Engine Builder")
    logger.info("=" * 80)
    logger.info(f"ONNX model: {onnx_path}")
    logger.info(f"Engine output: {engine_path}")
    logger.info(f"Precision: {precision.upper()}")
    logger.info(f"Workspace: {workspace_mb} MB")
    logger.info("=" * 80)

    TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE if trt_severity is None else trt_severity)

    # Create builder and network
    logger.info("\n[Step 1/5] Creating TensorRT builder...")
    builder = trt.Builder(TRT_LOGGER)

    # Use STRONGLY_TYPED network when available (TRT 10.x+).
    # With STRONGLY_TYPED, tensor types are inferred from the ONNX model and
    # TRT won't silently change precision. EXPLICIT_BATCH is deprecated in TRT 10.x.
    use_strongly_typed = hasattr(trt.NetworkDefinitionCreationFlag, "STRONGLY_TYPED")
    if use_strongly_typed:
        network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
        logger.info("Using STRONGLY_TYPED network (TRT 10.x+)")
    else:
        network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
        logger.info("Using EXPLICIT_BATCH network (TRT 9.x fallback)")
    network = builder.create_network(network_flags)
    parser = trt.OnnxParser(network, TRT_LOGGER)

    # Parse ONNX model
    logger.info("\n[Step 2/5] Parsing ONNX model...")
    if not parser.parse_from_file(onnx_path):
        logger.error("Failed to parse ONNX file")
        for error in range(parser.num_errors):
            logger.error(parser.get_error(error))
        raise RuntimeError("ONNX parsing failed")

    logger.info(f"Network inputs: {network.num_inputs}")
    for i in range(network.num_inputs):
        inp = network.get_input(i)
        logger.info(f"  Input {i}: {inp.name} {inp.shape}")

    logger.info(f"Network outputs: {network.num_outputs}")
    for i in range(network.num_outputs):
        out = network.get_output(i)
        logger.info(f"  Output {i}: {out.name} {out.shape}")

    # Create builder config
    logger.info("\n[Step 3/5] Configuring builder...")
    config = builder.create_builder_config()

    config.profiling_verbosity = trt.ProfilingVerbosity.DETAILED
    logger.info("Enabled DETAILED profiling verbosity for engine inspection")

    config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_mb * (1024**2))

    if use_strongly_typed:
        # With STRONGLY_TYPED, precision comes from the ONNX model's tensor types.
        # No need to set BF16/FP16 builder flags — they're implicit in the model.
        # For FP8, the Q/DQ nodes in the ONNX model dictate FP8 layers.
        logger.info(
            f"Precision '{precision}' enforced by STRONGLY_TYPED network (types from ONNX model)"
        )
    else:
        # Weak-typed fallback: explicitly set precision flags
        if precision == "fp16":
            config.set_flag(trt.BuilderFlag.FP16)
            logger.info("Enabled FP16 mode")
        elif precision == "bf16":
            config.set_flag(trt.BuilderFlag.BF16)
            logger.info("Enabled BF16 mode")
        elif precision == "fp8":
            config.set_flag(trt.BuilderFlag.FP8)
            config.set_flag(trt.BuilderFlag.BF16)
            logger.info("Enabled FP8 + BF16 mode")
        elif precision == "fp32":
            logger.info("Using FP32 (default precision)")
        else:
            raise ValueError(f"Unknown precision: {precision}")

    # Set optimization profiles for dynamic shapes
    if min_shapes and opt_shapes and max_shapes:
        logger.info("\n[Step 4/5] Setting optimization profiles...")
        profile = builder.create_optimization_profile()

        for i in range(network.num_inputs):
            inp = network.get_input(i)
            input_name = inp.name

            if input_name in min_shapes:
                min_shape = min_shapes[input_name]
                opt_shape = opt_shapes[input_name]
                max_shape = max_shapes[input_name]

                profile.set_shape(input_name, min_shape, opt_shape, max_shape)
                logger.info(f"  {input_name}:")
                logger.info(f"    min: {min_shape}")
                logger.info(f"    opt: {opt_shape}")
                logger.info(f"    max: {max_shape}")

        config.add_optimization_profile(profile)
    else:
        raise RuntimeError("Provide min/max and opt shapes for dynamic axes")

    # Build engine
    logger.info("\n[Step 5/5] Building TensorRT engine...")

    start_time = time.time()
    serialized_engine = builder.build_serialized_network(network, config)
    build_time = time.time() - start_time

    if serialized_engine is None:
        raise RuntimeError("Failed to build TensorRT engine")

    logger.info(f"Engine built in {build_time:.1f} seconds ({build_time / 60:.1f} minutes)")

    # Save engine
    logger.info(f"\nSaving engine to {engine_path}...")
    os.makedirs(os.path.dirname(engine_path) or ".", exist_ok=True)
    with open(engine_path, "wb") as f:
        f.write(serialized_engine)

    engine_size_mb = os.path.getsize(engine_path) / (1024**2)
    logger.info(f"Engine saved! Size: {engine_size_mb:.2f} MB")

    logger.info("\n" + "=" * 80)
    logger.info("ENGINE BUILD COMPLETE!")
    logger.info("=" * 80)
    logger.info(f"Engine file: {engine_path}")
    logger.info(f"Size: {engine_size_mb:.2f} MB")
    logger.info(f"Build time: {build_time:.1f}s")
    logger.info(f"Precision: {precision.upper()}")
    logger.info("=" * 80)

    return engine_path


# ============================================================
# Full Pipeline Builder
# ============================================================


def _load_seq_hints(onnx_dir):
    """Return dynamic-dimension hints (dim name -> optimal size) for ``onnx_dir``.

    Reads ``export_metadata.json`` from ``onnx_dir`` when present. Otherwise
    falls back to hard-coded defaults for GR1 single-view.

    Raises:
        KeyError: If the metadata file lacks a required key.
    """
    metadata_path = os.path.join(onnx_dir, "export_metadata.json")
    if not os.path.exists(metadata_path):
        seq_hints = {
            "sa_seq_len": 51,  # 1 state + action_horizon
            "vl_seq_len": 280,  # typical backbone output seq_len
            "sequence_length": 280,  # LLM seq_len
            # The same LLM length under the N1.7 dim name, mirroring the
            # metadata branch below, so an N1.7 LLM export doesn't fall back to
            # the generic unhinted range.
            "seq_len": 280,
        }
        logger.warning(
            f"No export_metadata.json found in {onnx_dir}, using default hints: {seq_hints}"
        )
        return seq_hints

    with open(metadata_path, encoding="utf-8") as f:
        metadata = json.load(f)

    missing = [key for key in ("sa_seq_len", "vl_seq_len", "llm_seq_len") if key not in metadata]
    if missing:
        raise KeyError(f"{metadata_path} is missing required keys: {missing}")

    seq_hints = {
        "sa_seq_len": metadata["sa_seq_len"],
        "vl_seq_len": metadata["vl_seq_len"],
        "sequence_length": metadata["llm_seq_len"],
        "seq_len": metadata["llm_seq_len"],  # N1.7 LLM dynamic dim name
        "num_patches": metadata.get("num_patches", 256),
        "num_merged_patches": metadata.get("num_merged_patches", 64),
        "num_vis_tokens": metadata.get("num_vis_tokens", 64),  # N1.7 deepstack
    }
    logger.info(f"Loaded shape hints from {metadata_path}: {seq_hints}")
    return seq_hints


def build_full_pipeline(
    onnx_dir, engine_dir, precision="bf16", workspace_mb=8192, trt_severity=None
):
    """Build all TRT engines for the full pipeline.

    Shape profiles are derived automatically from each ONNX model, using
    sequence-length hints from ``export_metadata.json`` when available.
    A component whose ONNX file is missing is skipped. A component that fails
    to build is logged with its traceback, and the remaining components are
    still attempted.

    Args:
        onnx_dir: Directory containing the exported ONNX models.
        engine_dir: Directory to save TRT engines (created if missing).
        precision: Precision mode forwarded to ``build_engine``.
        workspace_mb: Workspace size in MB.
        trt_severity: Optional TensorRT logger severity forwarded to ``build_engine``.

    Returns:
        A list of ``(component_name, engine_path, status)`` tuples, one per
        component. ``status`` is "SUCCESS", "SKIPPED: ..." or "FAILED: ...".

    Raises:
        KeyError: If ``export_metadata.json`` exists but lacks a required key.
    """
    os.makedirs(engine_dir, exist_ok=True)

    seq_hints = _load_seq_hints(onnx_dir)

    # FP32 ViT is preferred for accuracy; fall back to BF16 if only bf16 was
    # exported. The engine keeps the name vit_bf16.engine either way
    # (presumably what the runtime loads; confirm before renaming).
    vit_onnx = (
        "vit_fp32.onnx"
        if os.path.exists(os.path.join(onnx_dir, "vit_fp32.onnx"))
        else "vit_bf16.onnx"
    )

    # Components: (name, onnx_file, engine_file)
    components = [
        ("ViT", vit_onnx, "vit_bf16.engine"),
        ("LLM", "llm_bf16.onnx", "llm_bf16.engine"),
        ("VL Self-Attention", "vl_self_attention.onnx", "vl_self_attention.engine"),
        ("State Encoder", "state_encoder.onnx", "state_encoder.engine"),
        ("Action Encoder", "action_encoder.onnx", "action_encoder.engine"),
        ("DiT", "dit_bf16.onnx", "dit_bf16.engine"),
        ("Action Decoder", "action_decoder.onnx", "action_decoder.engine"),
    ]

    results = []

    for name, onnx_file, engine_file in components:
        onnx_path = os.path.join(onnx_dir, onnx_file)
        engine_path = os.path.join(engine_dir, engine_file)

        if not os.path.exists(onnx_path):
            logger.warning(f"Skipping {name}: ONNX file not found at {onnx_path}")
            # Record the skip so the summary accounts for every component.
            results.append((name, engine_path, "SKIPPED: ONNX not found"))
            continue

        logger.info(f"\n{'#' * 80}")
        logger.info(f"# Building {name} engine")
        logger.info(f"{'#' * 80}")

        try:
            # Derive shapes from the ONNX model itself
            min_shapes, opt_shapes, max_shapes = derive_shapes_with_hint(
                onnx_path, opt_seq_lens=seq_hints
            )

            logger.info(f"  Auto-derived shape profiles for {name}:")
            for input_name in opt_shapes:
                logger.info(
                    f"    {input_name}: min={min_shapes[input_name]} "
                    f"opt={opt_shapes[input_name]} max={max_shapes[input_name]}"
                )

            build_engine(
                onnx_path=onnx_path,
                engine_path=engine_path,
                precision=precision,
                workspace_mb=workspace_mb,
                min_shapes=min_shapes,
                opt_shapes=opt_shapes,
                max_shapes=max_shapes,
                trt_severity=trt_severity,
            )
            results.append((name, engine_path, "SUCCESS"))
        except Exception as e:
            # Deliberately broad: one broken component must not block the
            # others. logger.exception keeps the traceback for diagnosis.
            logger.exception(f"Failed to build {name} engine: {e}")
            results.append((name, engine_path, f"FAILED: {e}"))

    # Print summary
    logger.info("\n" + "=" * 80)
    logger.info("FULL PIPELINE BUILD SUMMARY")
    logger.info("=" * 80)
    for name, _, status in results:
        logger.info(f"  {name:20s} -> {status}")
    logger.info("=" * 80)

    return results


# ============================================================
# Main
# ============================================================


# CLI schema for this script. main() parses it with tyro.cli(BuildConfig), so
# each field becomes a command-line flag (e.g. onnx_dir -> --onnx-dir, as in the
# module usage example). Field order, defaults and the string literal under each
# field are visible to tyro at runtime (presumably as --help text), so edit
# them with care.
@dataclass
class BuildConfig:
    """Configuration for building TensorRT engines from ONNX models."""

    # main() calls build_full_pipeline() for "full_pipeline" and takes the
    # single-engine path otherwise.
    mode: Literal["single", "full_pipeline"] = "single"
    """Build mode: 'single' (one engine) or 'full_pipeline' (all engines)."""

    # onnx and engine are both required in single mode; main() raises
    # ValueError if either is empty.
    onnx: str | None = None
    """Path to ONNX model (single mode)."""

    engine: str | None = None
    """Path to save TensorRT engine (single mode)."""

    # onnx_dir / engine_dir are only read in full_pipeline mode.
    onnx_dir: str = "./gr00t_n1d7_onnx"
    """Directory with ONNX models (full_pipeline mode)."""

    engine_dir: str = "./gr00t_n1d7_engines"
    """Directory to save engines (full_pipeline mode)."""

    # Forwarded to build_engine(). On a STRONGLY_TYPED network the ONNX tensor
    # types decide the actual precision; otherwise this selects builder flags.
    precision: Literal["fp32", "fp16", "bf16", "fp8"] = "bf16"
    """Precision mode (default: bf16)."""

    # Forwarded as build_engine(workspace_mb=...), the TRT workspace
    # memory-pool limit.
    workspace: int = 8192
    """Workspace size in MB (default: 8192)."""


def main(args: BuildConfig | None = None, trt_severity=None):
    """Build one engine or the whole pipeline, as selected by ``args.mode``.

    Args:
        args: Build configuration. When ``None`` it is parsed from the command
            line with tyro.
        trt_severity: Optional TensorRT logger severity forwarded to the builders.

    Raises:
        ValueError: In single mode, when ``--onnx`` or ``--engine`` is missing.
    """
    cfg = args if args is not None else tyro.cli(BuildConfig)

    if cfg.mode != "full_pipeline":
        # Single mode: one ONNX model in, one engine out.
        if not (cfg.onnx and cfg.engine):
            raise ValueError("--onnx and --engine are required in single mode")

        # No sequence hints are available here, so shapes come purely from
        # the ONNX model's own dimensions.
        lo_shapes, best_shapes, hi_shapes = derive_shapes_with_hint(cfg.onnx)
        build_engine(
            onnx_path=cfg.onnx,
            engine_path=cfg.engine,
            precision=cfg.precision,
            workspace_mb=cfg.workspace,
            min_shapes=lo_shapes,
            opt_shapes=best_shapes,
            max_shapes=hi_shapes,
            trt_severity=trt_severity,
        )
        return

    build_full_pipeline(
        onnx_dir=cfg.onnx_dir,
        engine_dir=cfg.engine_dir,
        precision=cfg.precision,
        workspace_mb=cfg.workspace,
        trt_severity=trt_severity,
    )


if __name__ == "__main__":
    # Parse the command line into a BuildConfig and run the requested build.
    main(tyro.cli(BuildConfig))