# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TensorRT forward functions for GR00T N1.7 inference.

This module provides TRT-accelerated forward functions that replace the
PyTorch backbone and action head during inference.

Architecture (n17_full_pipeline mode):
    Backbone: ViT (TRT) → embed_tokens + masked_scatter + get_rope_index (PyTorch)
              → LLM (TRT, with deepstack injection)
    Action Head: VLLN (PyTorch) → State Encoder (TRT) → denoising loop:
                 [ Action Encoder (TRT) → DiT (TRT) → Action Decoder (TRT) ]
    If the ViT or LLM engine file is missing, that component stays in PyTorch.

Architecture (vit_llm_only mode):
    Backbone: ViT (TRT) → embed_tokens + masked_scatter + get_rope_index (PyTorch)
              → LLM (TRT, with deepstack injection)
    Action Head: stays in PyTorch
    Use when DiT cannot be exported with dynamic vl_seq_len (e.g. torch 2.10 / sm121).

Architecture (action_head mode):
    Backbone: stays in PyTorch (Qwen3-VL)
    Action Head: VLLN (PyTorch) → State Encoder (TRT) → denoising loop:
                 [ Action Encoder (TRT) → DiT (TRT) → Action Decoder (TRT) ]

Architecture (dit_only mode):
    Backbone: stays in PyTorch
    Action Head: only the DiT runs in TRT; Action Encoder and Action Decoder
                 stay in PyTorch.
"""
from functools import partial
import logging
import os
import sys
import torch
from transformers.feature_extraction_utils import BatchFeature
logger = logging.getLogger(__name__)

# scripts/deployment is a plain directory, not a package. Prepend this file's
# own directory to sys.path so sibling modules (such as trt_torch, imported
# right below) resolve. The membership test keeps repeated imports from
# stacking duplicate entries.
_deploy_dir = os.path.dirname(os.path.abspath(__file__))
if _deploy_dir not in sys.path:
    sys.path.insert(0, _deploy_dir)
from trt_torch import Engine # noqa: E402
# ============================================================
# N1.7 Backbone TRT Forward (ViT TRT + LLM TRT)
# ============================================================
def _qwen3_vit_and_scatter(self, vl_input):
    """Run the ViT TRT engine and assemble the LLM inputs in PyTorch.

    Shared by the backbone forwards that use the ViT TRT engine, whether the
    LLM afterwards runs in PyTorch or as a TRT engine. Only the ViT runs in
    TRT. embed_tokens, masked_scatter, get_placeholder_mask and
    get_rope_index stay in PyTorch because they involve dynamic Python logic.

    Args:
        self: Qwen3Backbone instance carrying ``vit_engine``,
            ``_embedding_layer`` and ``_image_token_id`` (attached by the
            setup functions in this module).
        vl_input: Mapping with ``input_ids``, ``attention_mask``,
            ``pixel_values`` (a tensor, or a list/tuple of tensors that is
            concatenated along dim 0) and ``image_grid_thw``.

    Returns:
        dict with ``inputs_embeds`` (BF16), ``attention_mask``,
        ``position_ids``, ``visual_pos_masks``, ``deepstack_list`` (list of
        tensors, possibly empty), ``image_mask_out`` and
        ``backbone_attention_mask``. Sequence positions where row 0 of the
        attention mask is not 1 are removed from every per-token output.
    """
    inner_model = self.model.model  # Qwen3VLForConditionalGeneration -> Qwen3VLModel
    grid_thw = vl_input["image_grid_thw"]
    engine_dtype = torch.bfloat16

    # --- ViT TRT engine ---
    # The ViT engine may be built in FP32 (accuracy) or BF16 (speed); follow its binding.
    vit_dtype = self.vit_engine.dtype_of("pixel_values")
    pixel_values = vl_input["pixel_values"]
    if isinstance(pixel_values, (list, tuple)):
        pixel_values = torch.cat(pixel_values, dim=0)
    if pixel_values.dtype != vit_dtype:
        pixel_values = pixel_values.to(vit_dtype)
    self.vit_engine.set_runtime_tensor_shape("pixel_values", pixel_values.shape)
    vit_result = self.vit_engine(pixel_values)
    image_embeds = vit_result["image_embeds"]

    # Unpack deepstack features [num_layers, N, D] into a list of [N, D].
    # NOTE(review): a tensor with <= 1 element is treated as "no deepstack
    # output" (presumably a placeholder binding) — confirm against the exporter.
    deepstack_features = vit_result.get("deepstack_features")
    deepstack_list = []
    if deepstack_features is not None and deepstack_features.numel() > 1:
        deepstack_list = list(deepstack_features.unbind(0))

    # --- PyTorch: embed_tokens + scatter image embeddings into placeholders ---
    input_ids = vl_input["input_ids"]
    inputs_embeds = self._embedding_layer(input_ids)
    if inputs_embeds.dtype != engine_dtype:
        inputs_embeds = inputs_embeds.to(engine_dtype)
    if image_embeds.dtype != engine_dtype:
        image_embeds = image_embeds.to(engine_dtype)
    # The engine already returns a single tensor for all images, so it is
    # scattered directly; wrapping it in torch.cat([...]) would only copy it.
    image_mask, _ = inner_model.get_placeholder_mask(
        input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
    )
    inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
    visual_pos_masks = image_mask[..., 0] if image_mask is not None else None

    # 3D position IDs stay in PyTorch (complex Python logic).
    attention_mask = vl_input["attention_mask"]
    position_ids, rope_deltas = inner_model.get_rope_index(
        input_ids, grid_thw, video_grid_thw=None, attention_mask=attention_mask
    )
    inner_model.rope_deltas = rope_deltas

    image_mask_out = input_ids == self._image_token_id
    backbone_attention_mask = attention_mask == 1

    # transformers 4.57+ strips padding tokens before calling language_model
    # internally; do the same so TRT engine inputs match export-time shapes.
    # NOTE(review): only row 0 of attention_mask selects the kept positions, so
    # this assumes batch size 1 or identical padding across the batch — confirm.
    valid_mask = attention_mask[0] == 1  # [full_seq_len]
    if not valid_mask.all():
        inputs_embeds = inputs_embeds[:, valid_mask, :]
        attention_mask = attention_mask[:, valid_mask]
        position_ids = position_ids[:, :, valid_mask]
        if visual_pos_masks is not None:
            visual_pos_masks = visual_pos_masks[:, valid_mask]
        image_mask_out = image_mask_out[:, valid_mask]
        backbone_attention_mask = backbone_attention_mask[:, valid_mask]

    return {
        "inputs_embeds": inputs_embeds,
        "attention_mask": attention_mask,
        "position_ids": position_ids,
        "visual_pos_masks": visual_pos_masks,
        "deepstack_list": deepstack_list,
        "image_mask_out": image_mask_out,
        "backbone_attention_mask": backbone_attention_mask,
    }
def qwen3_backbone_tensorrt_forward(self, vl_input):
    """Backbone forward with the ViT as a TRT engine and the LLM in PyTorch.

    Installed in place of ``Qwen3Backbone.forward()`` when the ViT engine is
    available but the LLM engine is not.

    Args:
        self: Qwen3Backbone instance (monkey-patched).
        vl_input: BatchFeature holding input_ids, attention_mask,
            pixel_values and image_grid_thw (other keys are ignored).

    Returns:
        BatchFeature with ``backbone_features`` (last hidden state of the
        PyTorch LLM), ``backbone_attention_mask`` and ``image_mask``.
    """
    self.set_frozen_modules_to_eval_mode()
    wanted = ("input_ids", "attention_mask", "pixel_values", "image_grid_thw")
    prepared = _qwen3_vit_and_scatter(self, {name: vl_input[name] for name in wanted})

    # The LLM stays in PyTorch; deepstack features are injected when present.
    language_model = self.model.model.language_model
    llm_outputs = language_model(
        input_ids=None,
        position_ids=prepared["position_ids"],
        attention_mask=prepared["attention_mask"],
        inputs_embeds=prepared["inputs_embeds"],
        visual_pos_masks=prepared["visual_pos_masks"],
        deepstack_visual_embeds=prepared["deepstack_list"] or None,
        output_hidden_states=True,
    )

    features = {
        "backbone_features": llm_outputs.last_hidden_state,
        "backbone_attention_mask": prepared["backbone_attention_mask"],
        "image_mask": prepared["image_mask_out"],
    }
    return BatchFeature(data=features)
def qwen3_backbone_llm_trt_forward(self, vl_input):
    """Backbone forward with the ViT in PyTorch and the LLM as a TRT engine.

    Installed in place of ``Qwen3Backbone.forward()`` when only the LLM engine
    is available (e.g. when the ViT TRT engine has accuracy issues).

    Args:
        self: Qwen3Backbone instance (monkey-patched) with ``llm_engine``.
        vl_input: BatchFeature holding input_ids, attention_mask,
            pixel_values and image_grid_thw (other keys are ignored).

    Returns:
        BatchFeature with ``backbone_features`` (LLM engine output cast to
        BF16), ``backbone_attention_mask`` and ``image_mask``.
    """

    def _as(tensor, dtype):
        # Cast only when needed so tensors already in `dtype` pass through as-is.
        return tensor if tensor.dtype == dtype else tensor.to(dtype)

    self.set_frozen_modules_to_eval_mode()
    wanted = ("input_ids", "attention_mask", "pixel_values", "image_grid_thw")
    vl_input = {name: vl_input[name] for name in wanted}
    qwen_model = self.model
    inner_model = qwen_model.model

    # --- ViT (PyTorch, kept for accuracy) ---
    grid_thw = vl_input["image_grid_thw"]
    per_image_embeds, deepstack_embeds = inner_model.get_image_features(
        vl_input["pixel_values"], grid_thw
    )
    # get_image_features yields one tensor per image; concatenate for the scatter.
    image_embeds = torch.cat(list(per_image_embeds), dim=0)

    # --- Scatter image embeddings into the text embeddings ---
    input_ids = vl_input["input_ids"]
    text_embeds = qwen_model.get_input_embeddings()(input_ids)
    image_mask, _ = inner_model.get_placeholder_mask(
        input_ids, inputs_embeds=text_embeds, image_features=image_embeds
    )
    inputs_embeds = text_embeds.masked_scatter(image_mask, image_embeds)
    visual_pos_masks = None if image_mask is None else image_mask[..., 0]
    deepstack_list = list(deepstack_embeds) if deepstack_embeds else []

    # --- Position IDs ---
    attention_mask = vl_input["attention_mask"]
    position_ids, inner_model.rope_deltas = inner_model.get_rope_index(
        input_ids, grid_thw, video_grid_thw=None, attention_mask=attention_mask
    )
    image_mask_out = input_ids == qwen_model.config.image_token_id
    backbone_attention_mask = attention_mask == 1

    # Strip padding tokens, as transformers 4.57+ does before the LLM.
    keep = attention_mask[0] == 1
    if not keep.all():
        inputs_embeds = inputs_embeds[:, keep, :]
        attention_mask = attention_mask[:, keep]
        position_ids = position_ids[:, :, keep]
        if visual_pos_masks is not None:
            visual_pos_masks = visual_pos_masks[:, keep]
        image_mask_out = image_mask_out[:, keep]
        backbone_attention_mask = backbone_attention_mask[:, keep]

    # --- LLM (TRT) ---
    engine = self.llm_engine
    float_dtype = engine.dtype_of("inputs_embeds")
    inputs_embeds = _as(inputs_embeds, float_dtype)
    attention_mask = _as(attention_mask, torch.int64)
    position_ids = _as(position_ids, torch.int64)
    for name, tensor in (
        ("inputs_embeds", inputs_embeds),
        ("attention_mask", attention_mask),
        ("position_ids", position_ids),
    ):
        engine.set_runtime_tensor_shape(name, tensor.shape)

    extra_inputs = {}
    if visual_pos_masks is not None and deepstack_list:
        engine.set_runtime_tensor_shape("visual_pos_masks", visual_pos_masks.shape)
        extra_inputs["visual_pos_masks"] = visual_pos_masks
        for idx, layer_feats in enumerate(deepstack_list):
            input_name = f"deepstack_{idx}"
            layer_feats = _as(layer_feats, float_dtype)
            engine.set_runtime_tensor_shape(input_name, layer_feats.shape)
            extra_inputs[input_name] = layer_feats

    llm_outputs = engine(inputs_embeds, attention_mask, position_ids, **extra_inputs)
    backbone_features = _as(llm_outputs["embeddings"], torch.bfloat16)

    return BatchFeature(
        data={
            "backbone_features": backbone_features,
            "backbone_attention_mask": backbone_attention_mask,
            "image_mask": image_mask_out,
        }
    )
def qwen3_backbone_full_trt_forward(self, vl_input):
    """Backbone forward with both the ViT and the LLM as TRT engines.

    Only the lightweight glue (embed_tokens, masked_scatter, get_rope_index)
    stays in PyTorch, via ``_qwen3_vit_and_scatter``.

    Args:
        self: Qwen3Backbone instance (monkey-patched) with ``vit_engine`` and
            ``llm_engine``.
        vl_input: BatchFeature holding input_ids, attention_mask,
            pixel_values and image_grid_thw (other keys are ignored).

    Returns:
        BatchFeature with ``backbone_features`` (cast to BF16),
        ``backbone_attention_mask`` and ``image_mask``.
    """
    self.set_frozen_modules_to_eval_mode()
    wanted = ("input_ids", "attention_mask", "pixel_values", "image_grid_thw")
    prepared = _qwen3_vit_and_scatter(self, {name: vl_input[name] for name in wanted})
    engine = self.llm_engine

    # Follow the engine's float input binding so BF16 and FP32 engines both
    # work; the integer inputs are always fed as int64.
    float_dtype = engine.dtype_of("inputs_embeds")
    core_inputs = {}
    for name, dtype in (
        ("inputs_embeds", float_dtype),
        ("attention_mask", torch.int64),
        ("position_ids", torch.int64),
    ):
        tensor = prepared[name]
        core_inputs[name] = tensor if tensor.dtype == dtype else tensor.to(dtype)
    for name, tensor in core_inputs.items():
        engine.set_runtime_tensor_shape(name, tensor.shape)

    # Deepstack injection: fed only when both visual positions and features exist.
    extra_inputs = {}
    visual_pos_masks = prepared["visual_pos_masks"]
    deepstack_list = prepared["deepstack_list"]
    if visual_pos_masks is not None and deepstack_list:
        engine.set_runtime_tensor_shape("visual_pos_masks", visual_pos_masks.shape)
        extra_inputs["visual_pos_masks"] = visual_pos_masks
        for idx, layer_feats in enumerate(deepstack_list):
            if layer_feats.dtype != float_dtype:
                layer_feats = layer_feats.to(float_dtype)
            input_name = f"deepstack_{idx}"
            engine.set_runtime_tensor_shape(input_name, layer_feats.shape)
            extra_inputs[input_name] = layer_feats

    backbone_features = engine(
        core_inputs["inputs_embeds"],
        core_inputs["attention_mask"],
        core_inputs["position_ids"],
        **extra_inputs,
    )["embeddings"]
    # Downstream consumers (vl_self_attention, DiT) expect BF16.
    if backbone_features.dtype != torch.bfloat16:
        backbone_features = backbone_features.to(torch.bfloat16)

    return BatchFeature(
        data={
            "backbone_features": backbone_features,
            "backbone_attention_mask": prepared["backbone_attention_mask"],
            "image_mask": prepared["image_mask_out"],
        }
    )
# ============================================================
# Action Head TRT Forward
# ============================================================
def action_head_tensorrt_forward(self, backbone_output, action_input, options=None):
    """Replace ``ActionHead.get_action()`` with TRT-accelerated inference.

    VLLN (LayerNorm) stays in PyTorch. vl_self_attention runs as a TRT engine
    when ``self.vl_sa_engine`` is set, and in PyTorch otherwise. The State
    Encoder, Action Encoder, DiT and Action Decoder are TRT engines. Actions
    start from ``self.init_actions`` (if present) or Gaussian noise and are
    refined with ``num_inference_timesteps`` explicit Euler steps.

    N1.7 change: a state of shape [B, state_history_length, max_state_dim] is
    flattened to [B, 1, state_history_length * max_state_dim] before the
    state encoder.

    Args:
        self: ActionHead instance (monkey-patched).
        backbone_output: BatchFeature with ``backbone_features`` and,
            optionally, ``image_mask`` / ``backbone_attention_mask``. These
            are forwarded to the DiT engine when present and not None.
        action_input: BatchFeature with ``state`` and ``embodiment_id``. It is
            not modified.
        options: Unused by this implementation.

    Returns:
        BatchFeature with ``action_pred``.
    """
    engine_dtype = torch.bfloat16

    # --- VLLN (PyTorch) + vl_self_attention (TRT if available, else PyTorch) ---
    vl_embs = self.vlln(backbone_output.backbone_features)
    vl_sa_engine = getattr(self, "vl_sa_engine", None)
    if vl_sa_engine is not None:
        if vl_embs.dtype != engine_dtype:
            vl_embs = vl_embs.to(engine_dtype)
        vl_sa_engine.set_runtime_tensor_shape("hidden_states", vl_embs.shape)
        vl_embs = vl_sa_engine(vl_embs)["output"]
    else:
        vl_embs = self.vl_self_attention(vl_embs)
    if vl_embs.dtype != engine_dtype:
        vl_embs = vl_embs.to(engine_dtype)
    batch_size = vl_embs.shape[0]
    device = vl_embs.device

    embodiment_id = action_input.embodiment_id
    if embodiment_id.dtype != torch.int64:
        embodiment_id = embodiment_id.to(torch.int64)

    # --- State: cast locally so the caller's action_input is left untouched ---
    state = action_input.state
    if state.dtype != engine_dtype:
        state = state.to(engine_dtype)
    # N1.7: flatten [B, history, dim] -> [B, 1, history * dim]. reshape (rather
    # than view) also accepts non-contiguous inputs.
    if state.ndim == 3 and state.shape[1] > 1:
        state = state.reshape(state.shape[0], 1, -1)
    elif not (state.ndim == 3 and state.shape[1] == 1):
        # Unexpected layout: pass it through unchanged, but make it visible.
        logger.warning("Unexpected state shape: %s", state.shape)

    # --- State Encoder TRT ---
    self.state_encoder_engine.set_runtime_tensor_shape("state", state.shape)
    self.state_encoder_engine.set_runtime_tensor_shape("embodiment_id", embodiment_id.shape)
    state_features = self.state_encoder_engine(state, embodiment_id)["output"]

    # --- Initial actions: fixed init_actions if provided, else random noise ---
    if hasattr(self, "init_actions"):
        actions = self.init_actions.expand((batch_size, -1, -1))
    else:
        actions = torch.randn(
            size=(batch_size, self.config.action_horizon, self.action_dim),
            dtype=engine_dtype,
            device=device,
        )

    # Optional DiT masks are identical for every denoising step; look them up once.
    image_mask = getattr(backbone_output, "image_mask", None)
    bb_mask = getattr(backbone_output, "backbone_attention_mask", None)

    num_steps = self.num_inference_timesteps
    dt = 1.0 / num_steps

    # --- Denoising loop ---
    for t in range(num_steps):
        t_cont = t / float(num_steps)
        t_discretized = int(t_cont * self.num_timestep_buckets)
        timesteps_tensor = torch.full(
            size=(batch_size,), fill_value=t_discretized, device=device, dtype=torch.int64
        )

        # Action Encoder TRT
        self.action_encoder_engine.set_runtime_tensor_shape("actions", actions.shape)
        self.action_encoder_engine.set_runtime_tensor_shape("timesteps", timesteps_tensor.shape)
        self.action_encoder_engine.set_runtime_tensor_shape("embodiment_id", embodiment_id.shape)
        action_features = self.action_encoder_engine(
            actions.to(engine_dtype), timesteps_tensor, embodiment_id
        )["output"]

        # Optional position embedding (stays in PyTorch)
        if self.config.add_pos_embed:
            pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
            pos_embs = self.position_embedding(pos_ids).unsqueeze(0).to(engine_dtype)
            action_features = action_features + pos_embs

        # Concatenate state + action embeddings along the token axis
        sa_embs = torch.cat((state_features, action_features), dim=1).to(engine_dtype)

        # DiT TRT
        self.dit_engine.set_runtime_tensor_shape("sa_embs", sa_embs.shape)
        self.dit_engine.set_runtime_tensor_shape("vl_embs", vl_embs.shape)
        self.dit_engine.set_runtime_tensor_shape("timestep", timesteps_tensor.shape)
        dit_kwargs = {}
        if image_mask is not None:
            self.dit_engine.set_runtime_tensor_shape("image_mask", image_mask.shape)
            dit_kwargs["image_mask"] = image_mask
        if bb_mask is not None:
            self.dit_engine.set_runtime_tensor_shape("backbone_attention_mask", bb_mask.shape)
            dit_kwargs["backbone_attention_mask"] = bb_mask
        model_output = self.dit_engine(sa_embs, vl_embs, timesteps_tensor, **dit_kwargs)["output"]

        # Action Decoder TRT; keep only the trailing action-horizon tokens
        self.action_decoder_engine.set_runtime_tensor_shape("model_output", model_output.shape)
        self.action_decoder_engine.set_runtime_tensor_shape("embodiment_id", embodiment_id.shape)
        pred = self.action_decoder_engine(model_output, embodiment_id)["output"]
        pred_velocity = pred[:, -self.action_horizon :]

        # Euler integration
        actions = actions + dt * pred_velocity

    return BatchFeature(data={"action_pred": actions})
# ============================================================
# Engine Setup
# ============================================================
def setup_tensorrt_engines(policy, trt_engine_path, mode="n17_full_pipeline"):
    """Load TRT engines into ``policy``, free replaced PyTorch modules and patch forwards.

    Args:
        policy: Gr00tPolicy instance.
        trt_engine_path: Path to the directory containing the TRT engine files.
        mode: One of
            - 'n17_full_pipeline': ViT TRT + LLM TRT + Action Head TRT
            - 'vit_llm_only': ViT TRT + LLM TRT, Action Head in PyTorch
            - 'action_head': Action Head TRT only
            - 'dit_only': DiT TRT only

    Raises:
        ValueError: If ``mode`` is not one of the values above.
    """
    supported = ("n17_full_pipeline", "vit_llm_only", "action_head", "dit_only")
    if mode not in supported:
        raise ValueError(
            f"Unknown mode: {mode}. Expected 'n17_full_pipeline', 'vit_llm_only', "
            f"'action_head', or 'dit_only'."
        )

    setup_by_mode = {
        "n17_full_pipeline": _setup_n17_full_pipeline,
        "vit_llm_only": _setup_vit_llm_only,
        "action_head": _setup_action_head,
        "dit_only": _setup_dit_only,
    }
    setup_by_mode[mode](policy, trt_engine_path)
def _setup_n17_full_pipeline(policy, trt_engine_path):
"""Set up TRT engines for N1.7: ViT TRT + LLM TRT + Action Head TRT.
The Qwen3-VL backbone's vision encoder and text model are both replaced
with TRT engines. PyTorch ops kept: embed_tokens, masked_scatter,
get_rope_index (lightweight, <1ms).
Falls back to PyTorch LLM if llm_bf16.engine is not found.
"""
backbone = policy.model.backbone
qwen_model = backbone.model # Qwen3VLForConditionalGeneration
action_head = policy.model.action_head
# --- Backbone setup ---
# Save references needed by the TRT forward
backbone._embedding_layer = qwen_model.model.language_model.get_input_embeddings()
backbone._image_token_id = qwen_model.config.image_token_id
# Load ViT TRT engine (optional — PyTorch ViT used as fallback for accuracy)
vit_engine_path = os.path.join(trt_engine_path, "vit_bf16.engine")
use_vit_trt = os.path.exists(vit_engine_path)
if use_vit_trt:
print(f"Loading ViT engine: {vit_engine_path}")
backbone.vit_engine = Engine(vit_engine_path)
del qwen_model.model.visual
torch.cuda.empty_cache()
print(" Deleted PyTorch ViT (replaced by TRT engine)")
else:
backbone.vit_engine = None
print(f" ViT engine not found at {vit_engine_path}, keeping PyTorch ViT")
# Load LLM TRT engine (if available)
llm_engine_path = os.path.join(trt_engine_path, "llm_bf16.engine")
use_llm_trt = os.path.exists(llm_engine_path)
if use_llm_trt:
print(f"Loading LLM engine: {llm_engine_path}")
backbone.llm_engine = Engine(llm_engine_path)
# Delete PyTorch LLM layers to free GPU memory
# Keep embed_tokens (needed for token embedding before TRT)
# Keep get_rope_index via inner_model (needed for position IDs)
del qwen_model.model.language_model.layers
del qwen_model.model.language_model.norm
torch.cuda.empty_cache()
print(" Deleted PyTorch LLM layers (replaced by TRT engine)")
else:
backbone.llm_engine = None
print(f" LLM engine not found at {llm_engine_path}, using PyTorch LLM")
# Monkey-patch backbone forward
if use_vit_trt and use_llm_trt:
backbone.forward = partial(qwen3_backbone_full_trt_forward, backbone)
elif use_vit_trt and not use_llm_trt:
backbone.forward = partial(qwen3_backbone_tensorrt_forward, backbone)
elif not use_vit_trt and use_llm_trt:
# PyTorch ViT + LLM TRT (best accuracy when ViT TRT has issues)
backbone.forward = partial(qwen3_backbone_llm_trt_forward, backbone)
else:
print(" No backbone TRT engines loaded, backbone remains in PyTorch")
# --- Action head setup ---
# Load vl_self_attention TRT engine (if available)
vl_sa_engine_path = os.path.join(trt_engine_path, "vl_self_attention.engine")
if os.path.exists(vl_sa_engine_path):
print(f"Loading VL Self-Attention engine: {vl_sa_engine_path}")
action_head.vl_sa_engine = Engine(vl_sa_engine_path)
# Delete PyTorch module — TRT engine replaces it
if hasattr(action_head, "vl_self_attention"):
del action_head.vl_self_attention
torch.cuda.empty_cache()
print(" Deleted PyTorch vl_self_attention (replaced by TRT engine)")
else:
action_head.vl_sa_engine = None
print(f" VL Self-Attention engine not found at {vl_sa_engine_path}, using PyTorch")
if hasattr(action_head, "model"):
del action_head.model
if hasattr(action_head, "state_encoder"):
del action_head.state_encoder
if hasattr(action_head, "action_encoder"):
del action_head.action_encoder
if hasattr(action_head, "action_decoder"):
del action_head.action_decoder
torch.cuda.empty_cache()
assert action_head.action_dim == action_head.config.max_action_dim
print(f"Loading action head engines from: {trt_engine_path}")
action_head.state_encoder_engine = Engine(os.path.join(trt_engine_path, "state_encoder.engine"))
action_head.action_encoder_engine = Engine(
os.path.join(trt_engine_path, "action_encoder.engine")
)
action_head.dit_engine = Engine(os.path.join(trt_engine_path, "dit_bf16.engine"))
action_head.action_decoder_engine = Engine(
os.path.join(trt_engine_path, "action_decoder.engine")
)
action_head.get_action = partial(action_head_tensorrt_forward, action_head)
llm_status = "TRT" if use_llm_trt else "PyTorch"
vit_status = "TRT" if backbone.vit_engine else "PyTorch"
print("N1.7 full-pipeline TRT engines loaded.")
print(f" ViT: {vit_status} | LLM: {llm_status} | Action Head: TRT")
def _setup_vit_llm_only(policy, trt_engine_path):
"""Set up TRT engines for ViT + LLM only; action head stays in PyTorch.
Use this on platforms where DiT cannot be exported with dynamic vl_seq_len
(e.g. DGX Spark / torch 2.10 dynamo exporter bakes seq_len as static).
The backbone (ViT TRT + LLM TRT) still gets TRT acceleration; the PyTorch
action head receives the LLM embeddings at the actual runtime seq_len
without any shape constraint.
"""
backbone = policy.model.backbone
qwen_model = backbone.model # Qwen3VLForConditionalGeneration
# Save references needed by the TRT forward
backbone._embedding_layer = qwen_model.model.language_model.get_input_embeddings()
backbone._image_token_id = qwen_model.config.image_token_id
# Load ViT TRT engine
vit_engine_path = os.path.join(trt_engine_path, "vit_bf16.engine")
if not os.path.exists(vit_engine_path):
raise FileNotFoundError(
f"ViT TRT engine not found: {vit_engine_path}\n"
f"Run export_onnx_n1d7.py + build_tensorrt_engine.py first."
)
print(f"Loading ViT engine: {vit_engine_path}")
backbone.vit_engine = Engine(vit_engine_path)
del qwen_model.model.visual
torch.cuda.empty_cache()
print(" Deleted PyTorch ViT (replaced by TRT engine)")
# Load LLM TRT engine
llm_engine_path = os.path.join(trt_engine_path, "llm_bf16.engine")
if not os.path.exists(llm_engine_path):
raise FileNotFoundError(
f"LLM TRT engine not found: {llm_engine_path}\n"
f"Run export_onnx_n1d7.py + build_tensorrt_engine.py first."
)
print(f"Loading LLM engine: {llm_engine_path}")
backbone.llm_engine = Engine(llm_engine_path)
del qwen_model.model.language_model.layers
del qwen_model.model.language_model.norm
torch.cuda.empty_cache()
print(" Deleted PyTorch LLM layers (replaced by TRT engine)")
# Patch backbone forward to use ViT TRT + LLM TRT
backbone.forward = partial(qwen3_backbone_full_trt_forward, backbone)
print("vit_llm_only TRT engines loaded.")
print(" ViT: TRT | LLM: TRT | Action Head: PyTorch")
def _setup_action_head(policy, trt_engine_path):
"""Set up TRT engines for action head only (N1.7 mode).
Backbone (Qwen3-VL) stays in PyTorch. Only the 4 action head components
(State Encoder, Action Encoder, DiT, Action Decoder) are replaced with
TRT engines.
"""
action_head = policy.model.action_head
# Delete PyTorch modules that are replaced by TRT
if hasattr(action_head, "model"):
del action_head.model
if hasattr(action_head, "state_encoder"):
del action_head.state_encoder
if hasattr(action_head, "action_encoder"):
del action_head.action_encoder
if hasattr(action_head, "action_decoder"):
del action_head.action_decoder
torch.cuda.empty_cache()
# Verify action_dim consistency
assert action_head.action_dim == action_head.config.max_action_dim, (
f"action_dim mismatch: action_head.action_dim={action_head.action_dim} "
f"!= config.max_action_dim={action_head.config.max_action_dim}"
)
# Load action head TRT engines
print(f"Loading action head engines from: {trt_engine_path}")
action_head.state_encoder_engine = Engine(os.path.join(trt_engine_path, "state_encoder.engine"))
action_head.action_encoder_engine = Engine(
os.path.join(trt_engine_path, "action_encoder.engine")
)
action_head.dit_engine = Engine(os.path.join(trt_engine_path, "dit_bf16.engine"))
action_head.action_decoder_engine = Engine(
os.path.join(trt_engine_path, "action_decoder.engine")
)
# Monkey-patch: backbone.forward stays original, only action head is replaced
action_head.get_action = partial(action_head_tensorrt_forward, action_head)
print("Action head TRT engines loaded and forward method patched.")
print(" Backbone remains in PyTorch (Qwen3-VL).")
def _setup_dit_only(policy, trt_engine_path):
"""Set up TRT engine for DiT-only acceleration (backward compatible).
Only replaces the DiT model in the action head. The backbone and other
action head components remain in PyTorch.
"""
action_head = policy.model.action_head
# Delete the PyTorch DiT model
if hasattr(action_head, "model"):
del action_head.model
torch.cuda.empty_cache()
# Load DiT TRT engine
# Support both naming conventions
dit_path = os.path.join(trt_engine_path, "dit_bf16.engine")
if not os.path.exists(dit_path):
dit_path = os.path.join(trt_engine_path, "dit_model_bf16.engine")
if not os.path.exists(dit_path):
# Try the old naming convention
dit_path = os.path.join(trt_engine_path, "dit_model_bf16.trt")
print(f"Loading DiT engine: {dit_path}")
action_head.dit_engine = Engine(dit_path)
# Monkey-patch only the get_action method
# We need a simpler forward that only replaces the DiT call
@torch.no_grad()
def dit_only_get_action_with_features(
backbone_features, state_features, embodiment_id, backbone_output
):
"""get_action_with_features with DiT replaced by TRT."""
vl_embs = backbone_features
batch_size = vl_embs.shape[0]
device = vl_embs.device
engine_dtype = torch.bfloat16
actions = torch.randn(
size=(batch_size, action_head.config.action_horizon, action_head.action_dim),
dtype=vl_embs.dtype,
device=device,
)
dt = 1.0 / action_head.num_inference_timesteps
for t in range(action_head.num_inference_timesteps):
t_cont = t / float(action_head.num_inference_timesteps)
t_discretized = int(t_cont * action_head.num_timestep_buckets)
timesteps_tensor = torch.full(
size=(batch_size,), fill_value=t_discretized, device=device
)
action_features = action_head.action_encoder(actions, timesteps_tensor, embodiment_id)
if action_head.config.add_pos_embed:
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
pos_embs = action_head.position_embedding(pos_ids).unsqueeze(0)
action_features = action_features + pos_embs
sa_embs = torch.cat((state_features, action_features), dim=1).to(engine_dtype)
# Use TRT for DiT
vl_embs_trt = vl_embs.to(engine_dtype)
timesteps_trt = timesteps_tensor.to(torch.int64)
action_head.dit_engine.set_runtime_tensor_shape("sa_embs", sa_embs.shape)
action_head.dit_engine.set_runtime_tensor_shape("vl_embs", vl_embs_trt.shape)
action_head.dit_engine.set_runtime_tensor_shape("timestep", timesteps_trt.shape)
dit_kwargs = {}
if hasattr(backbone_output, "image_mask") and backbone_output.image_mask is not None:
image_mask = backbone_output.image_mask
action_head.dit_engine.set_runtime_tensor_shape("image_mask", image_mask.shape)
dit_kwargs["image_mask"] = image_mask
if (
hasattr(backbone_output, "backbone_attention_mask")
and backbone_output.backbone_attention_mask is not None
):
bb_mask = backbone_output.backbone_attention_mask
action_head.dit_engine.set_runtime_tensor_shape(
"backbone_attention_mask", bb_mask.shape
)
dit_kwargs["backbone_attention_mask"] = bb_mask
model_output = action_head.dit_engine(
sa_embs, vl_embs_trt, timesteps_trt, **dit_kwargs
)["output"]
pred = action_head.action_decoder(model_output, embodiment_id)
pred_velocity = pred[:, -action_head.action_horizon :]
actions = actions + dt * pred_velocity
return BatchFeature(
data={
"action_pred": actions,
"backbone_features": vl_embs,
"state_features": state_features,
}
)
action_head.get_action_with_features = dit_only_get_action_with_features
print("DiT-only TRT engine loaded and forward method patched.")
|