#!/usr/bin/env python3

# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Benchmark script for GR00T inference timing.

Measures component-wise timing for:
- Data Processing: VLAStepData preparation and collation
- Backbone (VLM): Qwen3-VL forward pass
- Action Head (DiT): Flow-matching diffusion model
- E2E: Full end-to-end inference

Supports five inference modes:
1. PyTorch Eager: Standard PyTorch execution
2. torch.compile: PyTorch 2.0+ JIT compilation with max-autotune
3. TensorRT (dit_only): DiT action head replaced with a TensorRT engine
4. TensorRT (n17_full_pipeline): all 6 engines in TRT (ViT + LLM + Action Head)
5. TensorRT (vit_llm_only): ViT + LLM in TRT, action head in PyTorch (use on Spark/sm121)

Usage:
    # Basic benchmark (Eager + torch.compile)
    python scripts/deployment/benchmark_inference.py \
        --model-path checkpoints/GR00T-N1.7-LIBERO/libero_10

    # With DiT-only TRT
    python scripts/deployment/benchmark_inference.py \
        --model-path checkpoints/GR00T-N1.7-LIBERO/libero_10 \
        --trt-engine-path ./gr00t_n1d7_onnx/dit_model_bf16.trt

    # With full-pipeline TRT (6 engines)
    python scripts/deployment/benchmark_inference.py \
        --model-path checkpoints/GR00T-N1.7-LIBERO/libero_10 \
        --trt-engine-path ./gr00t_n1d7_engines \
        --trt-mode n17_full_pipeline
"""

from dataclasses import dataclass
import gc
import os
import sys
import time
from typing import Literal

import gr00t
from gr00t.data.dataset.lerobot_episode_loader import LeRobotEpisodeLoader
from gr00t.data.dataset.sharded_single_step_dataset import extract_step_data
from gr00t.data.embodiment_tags import EmbodimentTag
from gr00t.data.types import MessageType, VLAStepData
from gr00t.policy.gr00t_policy import Gr00tPolicy
import numpy as np
import torch
import tyro


# Ensure scripts/deployment/ is on sys.path for sibling module imports
_DEPLOY_DIR = os.path.dirname(os.path.abspath(__file__))
if _DEPLOY_DIR not in sys.path:
    sys.path.insert(0, _DEPLOY_DIR)


def set_seed(seed: int = 42):
    """Set random seed for reproducibility."""
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def _rec_to_dtype(x, dtype):
    """Recursively convert all floating point tensors to the given dtype."""
    if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
        return x.to(dtype=dtype)
    elif isinstance(x, dict) or hasattr(x, "items"):
        return {k: _rec_to_dtype(v, dtype) for k, v in x.items()}
    elif isinstance(x, list):
        return [_rec_to_dtype(v, dtype) for v in x]
    else:
        return x
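

# Illustrative call (key names hypothetical): floating-point tensors are cast,
# everything else passes through unchanged.
#   _rec_to_dtype({"pixel_values": torch.rand(2, 3),
#                  "input_ids": torch.ones(2, dtype=torch.long)}, torch.bfloat16)
#   -> {"pixel_values": <bf16 tensor>, "input_ids": <int64 tensor, untouched>}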


def prepare_model_inputs(policy, observation, return_states=False):
    """
    Prepare inputs for the model, mimicking what happens inside _get_action.
    Returns collated_inputs that can be passed to model.get_action()

    Args:
        policy: The Gr00tPolicy instance
        observation: Dict with "video", "state", "language" keys
        return_states: If True, also return the states list (for action denormalization)

    Returns:
        collated_inputs if return_states=False, else (collated_inputs, states)
    """
    unbatched_obs = []
    batch_size = next(iter(observation["video"].values())).shape[0]
    for i in range(batch_size):
        unbatched_value = {
            "video": {k: v[i] for k, v in observation["video"].items()},
            "state": {k: v[i] for k, v in observation["state"].items()},
            "language": {k: v[i] for k, v in observation["language"].items()},
        }
        unbatched_obs.append(unbatched_value)

    processed_inputs = []
    states = []
    for obs in unbatched_obs:
        vla_step_data = VLAStepData(
            images=obs["video"],
            states=obs["state"],
            actions={},
            text=obs["language"][policy.language_key][0],
            embodiment=policy.embodiment_tag,
        )
        states.append(vla_step_data.states)
        messages = [{"type": MessageType.EPISODE_STEP.value, "content": vla_step_data}]
        processed_inputs.append(policy.processor(messages))

    collated_inputs = policy.collate_fn(processed_inputs)
    collated_inputs = collated_inputs["inputs"]
    collated_inputs = _rec_to_dtype(collated_inputs, dtype=torch.bfloat16)

    if return_states:
        return collated_inputs, states
    return collated_inputs
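

# For reference, the `observation` consumed above matches the dict built in
# main() (inner shapes are dataset-dependent):
#   {
#       "video":    {camera_key: np.ndarray with a leading batch dim},
#       "state":    {state_key: np.ndarray with a leading batch dim},
#       "language": {language_key: [[instruction_str]], one inner list per batch element},
#   }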


def get_device_name():
    """Get short device name for table."""
    if torch.cuda.is_available():
        name = torch.cuda.get_device_name(0)
        # Shorten common names
        if "H100" in name:
            return "H100"
        elif "A100" in name:
            return "A100"
        elif "RTX 5090" in name:
            return "RTX 5090"
        elif "RTX 4090" in name:
            return "RTX 4090"
        elif "RTX 3090" in name:
            return "RTX 3090"
        elif "Orin" in name:
            return "Jetson Orin"
        else:
            # Skip the vendor prefix (e.g. "NVIDIA") and return the next token
            return name.split()[1] if len(name.split()) > 1 else name
    return "CPU"


def compute_e2e_from_components(components):
    """Compute E2E timing as sum of components (more stable than separate measurement)."""
    return components["data_processing"] + components["backbone"] + components["action_head"]
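

# Note: the sum above is elementwise over the per-iteration timing arrays, so
# iteration i of each component is paired up. Data processing is timed in a
# separate loop, so the pairing is arbitrary, but medians over the summed
# distribution remain representative.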


def benchmark_data_processing(policy, observation, num_iterations=20, warmup=10):
    """
    Benchmark data processing separately with proper warmup.
    Data processing is CPU-bound and needs more warmup iterations.

    Args:
        policy: The Gr00tPolicy instance
        observation: Either a single observation dict OR a list of observation dicts (trajectory)
        num_iterations: Number of benchmark iterations
        warmup: Number of warmup iterations

    If observation is a list (trajectory), cycles through observations during benchmarking.
    """

    # Handle both single observation and trajectory (list of observations)
    if isinstance(observation, list):
        observations = observation
    else:
        observations = [observation]

    num_obs = len(observations)

    # Force GC before warmup to reduce variance
    gc.collect()

    # Warmup - helps with CPU caching and JIT for consistent benchmarks
    # For trajectory mode, warmup benefit is reduced since each observation is different
    if warmup > 0:
        for i in range(warmup):
            obs = observations[i % num_obs]
            _ = prepare_model_inputs(policy, obs)
        # Force GC after warmup
        gc.collect()

    # Benchmark
    times = []
    for i in range(num_iterations):
        obs = observations[i % num_obs]
        start = time.perf_counter()
        _ = prepare_model_inputs(policy, obs)
        end = time.perf_counter()
        times.append(end - start)

    return np.array(times) * 1000  # seconds -> milliseconds


def benchmark_components(policy, observation, num_iterations=20, warmup=3):
    """
    Benchmark component-wise timing.
    Returns dict with times for: data_processing, backbone, action_head

    Args:
        policy: The Gr00tPolicy instance
        observation: Either a single observation dict OR a list of observation dicts (trajectory)
        num_iterations: Number of benchmark iterations
        warmup: Number of warmup iterations

    If observation is a list (trajectory), cycles through observations during benchmarking.
    """

    # Handle both single observation and trajectory (list of observations)
    if isinstance(observation, list):
        observations = observation
    else:
        observations = [observation]

    num_obs = len(observations)

    # Warmup backbone + action head
    for i in range(warmup):
        obs = observations[i % num_obs]
        collated_inputs = prepare_model_inputs(policy, obs)
        with torch.inference_mode():
            backbone_inputs, action_inputs = policy.model.prepare_input(collated_inputs)
            backbone_outputs = policy.model.backbone(backbone_inputs)
            _ = policy.model.action_head.get_action(backbone_outputs, action_inputs)
    torch.cuda.synchronize()

    # Force GC before timing
    gc.collect()

    # Benchmark backbone and action head (GPU-bound)
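    # CUDA kernel launches are asynchronous, so each timed region below is
    # bracketed with torch.cuda.synchronize(): the pre-sync drains pending GPU
    # work and the post-sync ensures perf_counter captures actual execution
    # time rather than just launch overhead.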
    backbone_times = []
    action_head_times = []

    for i in range(num_iterations):
        obs = observations[i % num_obs]
        collated_inputs = prepare_model_inputs(policy, obs)

        # Backbone timing
        torch.cuda.synchronize()
        start = time.perf_counter()
        with torch.inference_mode():
            backbone_inputs, action_inputs = policy.model.prepare_input(collated_inputs)
            backbone_outputs = policy.model.backbone(backbone_inputs)
        torch.cuda.synchronize()
        end = time.perf_counter()
        backbone_times.append(end - start)

        # Action head timing
        torch.cuda.synchronize()
        start = time.perf_counter()
        with torch.inference_mode():
            _ = policy.model.action_head.get_action(backbone_outputs, action_inputs)
        torch.cuda.synchronize()
        end = time.perf_counter()
        action_head_times.append(end - start)

    # Benchmark data processing separately with proper warmup
    data_processing_times = benchmark_data_processing(
        policy, observation, num_iterations, warmup=10
    )

    return {
        "data_processing": data_processing_times,
        "backbone": np.array(backbone_times) * 1000,
        "action_head": np.array(action_head_times) * 1000,
    }
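

# Minimal usage sketch, mirroring how main() consumes these helpers
# (policy/observation assumed to be constructed as in main()):
#   comps = benchmark_components(policy, observation, num_iterations=20, warmup=3)
#   comps["e2e"] = compute_e2e_from_components(comps)
#   print(f"median E2E: {np.median(comps['e2e']):.1f} ms")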


def print_markdown_table(results, device_name, denoising_steps):
    """Print results as a markdown table using median for robustness."""
    print("\n" + "=" * 100)
    print("MARKDOWN TABLE (copy/paste into README)")
    print("=" * 100)
    print(f"\nGR00T N1.7 Inference Timing ({denoising_steps} denoising steps):\n")

    # Component breakdown table (using median for robustness against outliers)
    print("### Component-wise Breakdown\n")
    print("| Device | Mode | Data Processing | Backbone | Action Head | E2E | Frequency |")
    print("|--------|------|-----------------|----------|-------------|-----|-----------|")

    for mode, data in results.items():
        dp_median = np.median(data["data_processing"])
        bb_median = np.median(data["backbone"])
        ah_median = np.median(data["action_head"])
        e2e_median = np.median(data["e2e"])
        freq = 1000 / e2e_median
        print(
            f"| {device_name} | {mode} | {dp_median:.0f} ms | {bb_median:.0f} ms | {ah_median:.0f} ms | {e2e_median:.0f} ms | {freq:.1f} Hz |"
        )

    # Speedup table
    if "PyTorch Eager" in results and len(results) > 1:
        print("\n### Speedup vs PyTorch Eager\n")
        print("| Device | Mode | E2E Speedup | Action Head Speedup |")
        print("|--------|------|-------------|---------------------|")

        baseline_e2e = np.median(results["PyTorch Eager"]["e2e"])
        baseline_ah = np.median(results["PyTorch Eager"]["action_head"])

        for mode, data in results.items():
            e2e_median = np.median(data["e2e"])
            ah_median = np.median(data["action_head"])
            e2e_speedup = baseline_e2e / e2e_median
            ah_speedup = baseline_ah / ah_median
            print(f"| {device_name} | {mode} | {e2e_speedup:.2f}x | {ah_speedup:.2f}x |")

    print("\n" + "=" * 100)
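

# Example rendered component row (numbers purely illustrative):
#   | H100 | PyTorch Eager | 12 ms | 45 ms | 30 ms | 87 ms | 11.5 Hz |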


@dataclass
class BenchmarkConfig:
    """Configuration for GR00T inference benchmarking."""

    model_path: str = "checkpoints/GR00T-N1.7-LIBERO/libero_10"
    """Path to model checkpoint (local path, e.g. checkpoints/GR00T-N1.7-LIBERO/libero_10)."""

    dataset_path: str | None = None
    """Path to dataset. Defaults to demo_data/libero_demo."""

    embodiment_tag: str = "libero_sim"
    """Embodiment tag to use."""

    trt_engine_path: str | None = None
    """Path to TensorRT engine. If not provided, TensorRT benchmark is skipped."""

    num_iterations: int = 20
    """Number of benchmark iterations."""

    warmup: int = 5
    """Number of warmup iterations."""

    seed: int = 42
    """Random seed for reproducibility."""

    trt_mode: Literal["dit_only", "n17_full_pipeline", "vit_llm_only"] = "dit_only"
    """TRT mode: 'dit_only' (DiT engine only), 'n17_full_pipeline' (all 6 engines), or 'vit_llm_only' (ViT+LLM TRT, action head in PyTorch — use on Spark/sm121)."""

    skip_compile: bool = False
    """Skip torch.compile benchmark (can take a while due to JIT compilation)."""

    batch_size: int = 1
    """Batch size for TRT inference. Must match the batch size used during ONNX export."""

    use_trajectory: bool = False
    """Benchmark on full trajectory instead of single data point. This cycles through all steps in an episode for more realistic benchmarking."""


def main(args: BenchmarkConfig | None = None):
    if args is None:
        args = tyro.cli(BenchmarkConfig)

    set_seed(args.seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cpu":
        print("ERROR: No CUDA GPU detected. Benchmarking requires a GPU.")
        sys.exit(1)
    device_name = get_device_name()

    # Default dataset path
    if args.dataset_path is None:
        repo_path = os.path.dirname(os.path.dirname(gr00t.__file__))
        args.dataset_path = os.path.join(repo_path, "demo_data/libero_demo")

    print("=" * 100)
    print("GR00T INFERENCE BENCHMARK")
    print("=" * 100)
    print(f"Device: {device_name} ({torch.cuda.get_device_name(0)})")
    print(f"Model: {args.model_path}")
    print(f"Dataset: {args.dataset_path}")
    print(f"Iterations: {args.num_iterations}")
    print(f"Warmup: {args.warmup}")
    print(f"Use Trajectory: {args.use_trajectory}")
    print()

    # Load dataset and prepare observation
    print("Loading policy...")
    policy = Gr00tPolicy(
        model_path=args.model_path,
        embodiment_tag=EmbodimentTag.resolve(args.embodiment_tag),
        device=device,
        strict=True,
    )

    denoising_steps = policy.model.action_head.num_inference_timesteps
    action_horizon = policy.model.action_head.action_horizon
    print(f"Action Horizon: {action_horizon}")
    print(f"Denoising Steps: {denoising_steps}")

    modality_config = policy.get_modality_config()
    dataset = LeRobotEpisodeLoader(
        dataset_path=args.dataset_path,
        modality_configs=modality_config,
        video_backend="torchcodec",
    )

    episode_data = dataset[0]

    if args.use_trajectory:
        # Load all steps from the episode for trajectory-based benchmarking
        # episode_data is a pandas DataFrame, so len() gives the number of steps
        trajectory_length = len(episode_data)

        observations = []
        for step_idx in range(trajectory_length):
            try:
                step_data = extract_step_data(
                    episode_data,
                    step_index=step_idx,
                    modality_configs=modality_config,
                    embodiment_tag=EmbodimentTag.resolve(args.embodiment_tag),
                    allow_padding=False,
                )
                obs = {
                    "video": {k: np.stack(step_data.images[k])[None] for k in step_data.images},
                    "state": {k: step_data.states[k][None] for k in step_data.states},
                    "language": {modality_config["language"].modality_keys[0]: [[step_data.text]]},
                }
                observations.append(obs)
            except Exception:
                # Stop if we can't extract more steps (e.g., due to video frame requirements)
                break

        print(f"Loaded trajectory with {len(observations)} steps")
        observation = observations  # Pass list to benchmark functions
    else:
        step_data = extract_step_data(
            episode_data,
            step_index=0,
            modality_configs=modality_config,
            embodiment_tag=EmbodimentTag.resolve(args.embodiment_tag),
            allow_padding=False,
        )

        observation = {
            "video": {k: np.stack(step_data.images[k])[None] for k in step_data.images},
            "state": {k: step_data.states[k][None] for k in step_data.states},
            "language": {modality_config["language"].modality_keys[0]: [[step_data.text]]},
        }

    # Tile observations to match the batch size baked into TRT engines
    if args.batch_size > 1:
        from verify_n1d7_trt import _tile_observation

        print(f"Tiling observations to batch_size={args.batch_size}")
        if isinstance(observation, list):
            observation = [_tile_observation(obs, args.batch_size) for obs in observation]
        else:
            observation = _tile_observation(observation, args.batch_size)

    results = {}

    # ========================================
    # 0. Benchmark Data Processing (shared across all modes)
    # ========================================
    # Data processing is the same regardless of inference mode (PyTorch/compile/TensorRT)
    # so we benchmark it once with proper warmup to get consistent measurements
    print("\n" + "-" * 50)
    print("Benchmarking Data Processing (shared across all modes)...")
    print("-" * 50)

    shared_data_processing_times = benchmark_data_processing(
        policy, observation, args.num_iterations, warmup=10
    )
    print(
        f"  Data Processing: {np.mean(shared_data_processing_times):.2f} ± {np.std(shared_data_processing_times):.2f} ms"
    )

    # ========================================
    # 1. PyTorch Eager
    # ========================================
    print("\n" + "-" * 50)
    print("Benchmarking PyTorch Eager...")
    print("-" * 50)

    times_components = benchmark_components(policy, observation, args.num_iterations, args.warmup)

    components = {
        "data_processing": shared_data_processing_times,
        "backbone": times_components["backbone"],
        "action_head": times_components["action_head"],
    }
    components["e2e"] = compute_e2e_from_components(components)
    results["PyTorch Eager"] = components

    e2e_median = np.median(components["e2e"])
    print(f"  E2E:             {e2e_median:.0f} ms ({1000 / e2e_median:.1f} Hz)")
    print(f"  Data Processing: {np.median(components['data_processing']):.0f} ms")
    print(f"  Backbone:        {np.median(components['backbone']):.0f} ms")
    print(f"  Action Head:     {np.median(components['action_head']):.0f} ms")

    # ========================================
    # 2. torch.compile
    # ========================================
    if not args.skip_compile:
        print("\n" + "-" * 50)
        print("Benchmarking torch.compile (mode='max-autotune')...")
        print("(This may take a while due to JIT compilation on first run)")
        print("-" * 50)

        # Free the eager policy before loading the compiled instance to avoid
        # holding two full model copies in GPU memory simultaneously (OOM on Orin).
        del policy
        torch.cuda.empty_cache()
        gc.collect()

        policy_compiled = Gr00tPolicy(
            model_path=args.model_path,
            embodiment_tag=EmbodimentTag.resolve(args.embodiment_tag),
            device=device,
            strict=True,
        )
        policy_compiled.model.action_head.model.forward = torch.compile(
            policy_compiled.model.action_head.model.forward, mode="max-autotune"
        )

        # Enable cuDNN benchmark for additional optimization
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True

        # Extra warmup for torch.compile JIT
        times_components = benchmark_components(
            policy_compiled, observation, args.num_iterations, warmup=args.warmup + 2
        )

        components = {
            "data_processing": shared_data_processing_times,
            "backbone": times_components["backbone"],
            "action_head": times_components["action_head"],
        }
        components["e2e"] = compute_e2e_from_components(components)
        results["torch.compile"] = components

        e2e_median = np.median(components["e2e"])
        print(f"  E2E:             {e2e_median:.0f} ms ({1000 / e2e_median:.1f} Hz)")
        print(f"  Data Processing: {np.median(components['data_processing']):.0f} ms")
        print(f"  Backbone:        {np.median(components['backbone']):.0f} ms")
        print(f"  Action Head:     {np.median(components['action_head']):.0f} ms")

    # ========================================
    # 3. TensorRT (if available)
    # ========================================
    if args.trt_engine_path and os.path.exists(args.trt_engine_path):
        trt_label = f"TensorRT ({args.trt_mode})"
        print("\n" + "-" * 50)
        print(f"Benchmarking {trt_label}...")
        print("-" * 50)

        # Free whichever policy is still in memory before loading TRT, to avoid
        # holding two full model copies in GPU memory simultaneously (OOM on Orin).
        try:
            del policy_compiled
        except NameError:
            pass
        try:
            del policy
        except NameError:
            pass
        torch.cuda.empty_cache()
        gc.collect()

        policy_trt = Gr00tPolicy(
            model_path=args.model_path,
            embodiment_tag=EmbodimentTag.resolve(args.embodiment_tag),
            device=device,
            strict=True,
        )

        if args.trt_mode in ("n17_full_pipeline", "vit_llm_only"):
            from trt_model_forward import setup_tensorrt_engines

            setup_tensorrt_engines(policy_trt, args.trt_engine_path, mode=args.trt_mode)
        else:
            from standalone_inference_script import replace_dit_with_tensorrt

            # dit_only mode: trt_engine_path may be a directory; resolve the engine file
            dit_engine_path = args.trt_engine_path
            if os.path.isdir(dit_engine_path):
                dit_engine_path = os.path.join(dit_engine_path, "dit_bf16.engine")
            replace_dit_with_tensorrt(policy_trt, dit_engine_path)

        # TensorRT needs extra warmup for engine initialization and CUDA context setup
        trt_warmup = max(args.warmup + 5, 10)
        times_components = benchmark_components(
            policy_trt, observation, args.num_iterations, warmup=trt_warmup
        )

        components = {
            "data_processing": shared_data_processing_times,
            "backbone": times_components["backbone"],
            "action_head": times_components["action_head"],
        }
        components["e2e"] = compute_e2e_from_components(components)
        results[trt_label] = components

        e2e_median = np.median(components["e2e"])
        print(f"  E2E:             {e2e_median:.0f} ms ({1000 / e2e_median:.1f} Hz)")
        print(f"  Data Processing: {np.median(components['data_processing']):.0f} ms")
        print(f"  Backbone:        {np.median(components['backbone']):.0f} ms")
        print(f"  Action Head:     {np.median(components['action_head']):.0f} ms")
    elif args.trt_engine_path:
        print(f"\nTensorRT engine not found: {args.trt_engine_path}")
        print("To build engines for full pipeline, run:")
        print(
            "  python scripts/deployment/export_onnx_n1d7.py --model-path checkpoints/GR00T-N1.7-LIBERO/libero_10"
            " --dataset-path demo_data/libero_demo --output-dir ./gr00t_n1d7_onnx --export-mode full_pipeline"
        )
        print(
            "  python scripts/deployment/build_tensorrt_engine.py --mode full_pipeline"
            " --onnx-dir ./gr00t_n1d7_onnx --engine-dir ./gr00t_n1d7_engines --precision bf16"
        )

    # ========================================
    # Print Summary Tables
    # ========================================
    print_markdown_table(results, device_name, denoising_steps)

    # Detailed summary
    print("\n" + "=" * 100)
    print("DETAILED SUMMARY")
    print("=" * 100)
    print(f"\nHardware: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
    print(f"Model: {args.model_path}")
    print(f"Action Horizon: {action_horizon}")
    print(f"Denoising Steps: {denoising_steps}")

    for mode, data in results.items():
        print(f"\n{mode}:")
        e2e = data["e2e"]
        print(
            f"  E2E:             median={np.median(e2e):.1f} ms, mean={np.mean(e2e):.1f} ± {np.std(e2e):.1f} ms, "
            f"min={np.min(e2e):.1f}, max={np.max(e2e):.1f} ({1000 / np.median(e2e):.1f} Hz)"
        )
        print(f"  Data Processing: {np.median(data['data_processing']):.2f} ms (median)")
        print(f"  Backbone:        {np.median(data['backbone']):.2f} ms (median)")
        print(f"  Action Head:     {np.median(data['action_head']):.2f} ms (median)")

    print("\n" + "=" * 100)


if __name__ == "__main__":
    main()