# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TensorRT forward functions for GR00T N1.7 inference.
This module provides TRT-accelerated forward functions that replace the
PyTorch backbone and action head during inference.
Architecture (n17_full_pipeline mode):
Backbone: ViT (TRT) → embed_tokens + masked_scatter + get_rope_index (PyTorch)
→ LLM (TRT, with deepstack injection)
Action Head: VLLN (PyTorch) → State Encoder (TRT) → denoising loop:
[ Action Encoder (TRT) → DiT (TRT) → Action Decoder (TRT) ]
Architecture (vit_llm_only mode):
Backbone: ViT (TRT) → embed_tokens + masked_scatter + get_rope_index (PyTorch)
→ LLM (TRT, with deepstack injection)
Action Head: stays in PyTorch
Use when DiT cannot be exported with dynamic vl_seq_len (e.g. torch 2.10 / sm121).
Architecture (action_head mode):
Backbone: stays in PyTorch (Qwen3-VL)
Action Head: VLLN (PyTorch) → State Encoder (TRT) → denoising loop:
[ Action Encoder (TRT) → DiT (TRT) → Action Decoder (TRT) ]
"""
import logging
import os
import sys
from functools import partial

import torch
from transformers.feature_extraction_utils import BatchFeature

logger = logging.getLogger(__name__)
# Ensure sibling modules are importable (scripts/deployment is not a package)
_deploy_dir = os.path.dirname(os.path.abspath(__file__))
if _deploy_dir not in sys.path:
sys.path.insert(0, _deploy_dir)
from trt_torch import Engine # noqa: E402
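
# Engine (defined in trt_torch.py) is used with a single convention throughout
# this file: engine.set_runtime_tensor_shape(name, shape) declares each dynamic
# input shape, engine(*tensors, **named_tensors) executes and returns a dict of
# output tensors, and engine.dtype_of(name) reports the dtype a binding was
# built with.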


# ============================================================
# N1.7 Backbone TRT Forward (ViT TRT + LLM TRT)
# ============================================================
def _qwen3_vit_and_scatter(self, vl_input):
"""Shared logic: ViT TRT + embed_tokens + scatter + get_rope_index.
Returns all inputs needed by either PyTorch LLM or LLM TRT engine.
These ops stay in PyTorch because they involve dynamic Python logic
(get_rope_index, masked_scatter, get_placeholder_mask).
"""
qwen_model = self.model # Qwen3VLForConditionalGeneration
inner_model = qwen_model.model # Qwen3VLModel
pixel_values = vl_input["pixel_values"]
grid_thw = vl_input["image_grid_thw"]
engine_dtype = torch.bfloat16
# --- ViT TRT Engine ---
# Detect ViT engine dtype (FP32 for accuracy or BF16 for speed)
vit_dtype = self.vit_engine.dtype_of("pixel_values")
if isinstance(pixel_values, (list, tuple)):
pv = torch.cat(pixel_values, dim=0)
else:
pv = pixel_values
if pv.dtype != vit_dtype:
pv = pv.to(vit_dtype)
self.vit_engine.set_runtime_tensor_shape("pixel_values", pv.shape)
vit_result = self.vit_engine(pv)
image_embeds = vit_result["image_embeds"]
deepstack_features = vit_result.get("deepstack_features")
# Unpack deepstack: [num_layers, N, D] → list of [N, D]
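    # A tensor with numel() <= 1 is treated as "no deepstack features".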
deepstack_list = []
if deepstack_features is not None and deepstack_features.numel() > 1:
deepstack_list = list(deepstack_features.unbind(0))
# --- PyTorch: embed_tokens + scatter ---
input_ids = vl_input["input_ids"]
inputs_embeds = self._embedding_layer(input_ids)
if inputs_embeds.dtype != engine_dtype:
inputs_embeds = inputs_embeds.to(engine_dtype)
if image_embeds.dtype != engine_dtype:
image_embeds = image_embeds.to(engine_dtype)
    # The ViT engine already returns all image embeddings in a single tensor,
    # so no per-image concatenation is needed here.
    image_embeds_cat = image_embeds
image_mask, _ = inner_model.get_placeholder_mask(
input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds_cat
)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds_cat)
visual_pos_masks = image_mask[..., 0] if image_mask is not None else None
# Compute 3D position IDs (stays in PyTorch — complex Python logic)
attention_mask = vl_input["attention_mask"]
position_ids, rope_deltas = inner_model.get_rope_index(
input_ids, grid_thw, video_grid_thw=None, attention_mask=attention_mask
)
inner_model.rope_deltas = rope_deltas
image_mask_out = input_ids == self._image_token_id
backbone_attention_mask = attention_mask == 1
# transformers 4.57+ strips padding tokens before calling language_model internally.
# Apply the same stripping so TRT engine inputs match export-time captured shapes.
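    # For example, attention_mask[0] == [1, 1, 1, 0, 0] keeps only the first
    # three positions of every per-token tensor below.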
valid_mask = attention_mask[0] == 1 # [full_seq_len]
if not valid_mask.all():
inputs_embeds = inputs_embeds[:, valid_mask, :]
attention_mask = attention_mask[:, valid_mask]
position_ids = position_ids[:, :, valid_mask]
if visual_pos_masks is not None:
visual_pos_masks = visual_pos_masks[:, valid_mask]
image_mask_out = image_mask_out[:, valid_mask]
backbone_attention_mask = backbone_attention_mask[:, valid_mask]
return {
"inputs_embeds": inputs_embeds,
"attention_mask": attention_mask,
"position_ids": position_ids,
"visual_pos_masks": visual_pos_masks,
"deepstack_list": deepstack_list,
"image_mask_out": image_mask_out,
"backbone_attention_mask": backbone_attention_mask,
}


def qwen3_backbone_tensorrt_forward(self, vl_input):
"""Replace Qwen3Backbone.forward() with ViT TRT + PyTorch LLM.
ViT is replaced with a TRT engine. The LLM stays in PyTorch.
Used when LLM TRT engine is not available.
Args:
self: Qwen3Backbone instance (monkey-patched)
vl_input: BatchFeature with keys: input_ids, attention_mask, pixel_values, image_grid_thw
"""
self.set_frozen_modules_to_eval_mode()
keys_to_use = ["input_ids", "attention_mask", "pixel_values", "image_grid_thw"]
vl_input = {k: vl_input[k] for k in keys_to_use}
prepared = _qwen3_vit_and_scatter(self, vl_input)
qwen_model = self.model
inner_model = qwen_model.model
# LLM forward (PyTorch)
outputs = inner_model.language_model(
input_ids=None,
position_ids=prepared["position_ids"],
attention_mask=prepared["attention_mask"],
inputs_embeds=prepared["inputs_embeds"],
visual_pos_masks=prepared["visual_pos_masks"],
deepstack_visual_embeds=prepared["deepstack_list"] or None,
output_hidden_states=True,
)
return BatchFeature(
data={
"backbone_features": outputs.last_hidden_state,
"backbone_attention_mask": prepared["backbone_attention_mask"],
"image_mask": prepared["image_mask_out"],
}
)


def qwen3_backbone_llm_trt_forward(self, vl_input):
"""Replace Qwen3Backbone.forward() with PyTorch ViT + LLM TRT.
ViT stays in PyTorch. LLM is replaced with a TRT engine.
Used when ViT TRT has accuracy issues but LLM TRT is accurate.
"""
self.set_frozen_modules_to_eval_mode()
keys_to_use = ["input_ids", "attention_mask", "pixel_values", "image_grid_thw"]
vl_input = {k: vl_input[k] for k in keys_to_use}
# Run PyTorch ViT + scatter + rope (original backbone logic up to LLM)
qwen_model = self.model
inner_model = qwen_model.model
# ViT forward (PyTorch — kept for accuracy)
pixel_values = vl_input["pixel_values"]
grid_thw = vl_input["image_grid_thw"]
image_embeds_split, deepstack_image_embeds = inner_model.get_image_features(
pixel_values, grid_thw
)
# get_image_features returns a tuple of per-image tensors; concat for scatter
image_embeds = torch.cat(list(image_embeds_split), dim=0)
# Scatter image embeddings into text embeddings
input_ids = vl_input["input_ids"]
inputs_embeds = qwen_model.get_input_embeddings()(input_ids)
image_mask, _ = inner_model.get_placeholder_mask(
input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
)
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
visual_pos_masks = image_mask[..., 0] if image_mask is not None else None
deepstack_list = list(deepstack_image_embeds) if deepstack_image_embeds else []
# Compute position IDs
attention_mask = vl_input["attention_mask"]
position_ids, rope_deltas = inner_model.get_rope_index(
input_ids, grid_thw, video_grid_thw=None, attention_mask=attention_mask
)
inner_model.rope_deltas = rope_deltas
image_mask_out = input_ids == qwen_model.config.image_token_id
backbone_attention_mask = attention_mask == 1
# Strip padding tokens (transformers 4.57+)
valid_mask = attention_mask[0] == 1
if not valid_mask.all():
inputs_embeds = inputs_embeds[:, valid_mask, :]
attention_mask = attention_mask[:, valid_mask]
position_ids = position_ids[:, :, valid_mask]
if visual_pos_masks is not None:
visual_pos_masks = visual_pos_masks[:, valid_mask]
image_mask_out = image_mask_out[:, valid_mask]
backbone_attention_mask = backbone_attention_mask[:, valid_mask]
# LLM forward (TRT)
llm_float_dtype = self.llm_engine.dtype_of("inputs_embeds")
if inputs_embeds.dtype != llm_float_dtype:
inputs_embeds = inputs_embeds.to(llm_float_dtype)
if attention_mask.dtype != torch.int64:
attention_mask = attention_mask.to(torch.int64)
if position_ids.dtype != torch.int64:
position_ids = position_ids.to(torch.int64)
self.llm_engine.set_runtime_tensor_shape("inputs_embeds", inputs_embeds.shape)
self.llm_engine.set_runtime_tensor_shape("attention_mask", attention_mask.shape)
self.llm_engine.set_runtime_tensor_shape("position_ids", position_ids.shape)
llm_kwargs = {}
if visual_pos_masks is not None and deepstack_list:
self.llm_engine.set_runtime_tensor_shape("visual_pos_masks", visual_pos_masks.shape)
llm_kwargs["visual_pos_masks"] = visual_pos_masks
for i, ds in enumerate(deepstack_list):
name = f"deepstack_{i}"
if ds.dtype != llm_float_dtype:
ds = ds.to(llm_float_dtype)
self.llm_engine.set_runtime_tensor_shape(name, ds.shape)
llm_kwargs[name] = ds
backbone_features = self.llm_engine(inputs_embeds, attention_mask, position_ids, **llm_kwargs)[
"embeddings"
]
if backbone_features.dtype != torch.bfloat16:
backbone_features = backbone_features.to(torch.bfloat16)
return BatchFeature(
data={
"backbone_features": backbone_features,
"backbone_attention_mask": backbone_attention_mask,
"image_mask": image_mask_out,
}
)


def qwen3_backbone_full_trt_forward(self, vl_input):
"""Replace Qwen3Backbone.forward() with ViT TRT + LLM TRT.
Both ViT and LLM are replaced with TRT engines.
PyTorch ops kept: embed_tokens, masked_scatter, get_rope_index (lightweight).
Args:
self: Qwen3Backbone instance (monkey-patched)
vl_input: BatchFeature with keys: input_ids, attention_mask, pixel_values, image_grid_thw
"""
self.set_frozen_modules_to_eval_mode()
keys_to_use = ["input_ids", "attention_mask", "pixel_values", "image_grid_thw"]
vl_input = {k: vl_input[k] for k in keys_to_use}
prepared = _qwen3_vit_and_scatter(self, vl_input)
inputs_embeds = prepared["inputs_embeds"]
attention_mask = prepared["attention_mask"]
position_ids = prepared["position_ids"]
# Detect LLM engine's expected float dtype from its first input binding.
# Handles both BF16 engines (default) and FP32 engines gracefully.
llm_float_dtype = self.llm_engine.dtype_of("inputs_embeds")
if inputs_embeds.dtype != llm_float_dtype:
inputs_embeds = inputs_embeds.to(llm_float_dtype)
if attention_mask.dtype != torch.int64:
attention_mask = attention_mask.to(torch.int64)
if position_ids.dtype != torch.int64:
position_ids = position_ids.to(torch.int64)
# Set LLM engine input shapes
self.llm_engine.set_runtime_tensor_shape("inputs_embeds", inputs_embeds.shape)
self.llm_engine.set_runtime_tensor_shape("attention_mask", attention_mask.shape)
self.llm_engine.set_runtime_tensor_shape("position_ids", position_ids.shape)
llm_kwargs = {}
# Visual pos masks and deepstack features
visual_pos_masks = prepared["visual_pos_masks"]
deepstack_list = prepared["deepstack_list"]
if visual_pos_masks is not None and deepstack_list:
self.llm_engine.set_runtime_tensor_shape("visual_pos_masks", visual_pos_masks.shape)
llm_kwargs["visual_pos_masks"] = visual_pos_masks
for i, ds in enumerate(deepstack_list):
name = f"deepstack_{i}"
if ds.dtype != llm_float_dtype:
ds = ds.to(llm_float_dtype)
self.llm_engine.set_runtime_tensor_shape(name, ds.shape)
llm_kwargs[name] = ds
backbone_features = self.llm_engine(inputs_embeds, attention_mask, position_ids, **llm_kwargs)[
"embeddings"
]
    # Cast LLM output back to BF16; downstream modules (vl_self_attention, DiT) expect BF16.
if backbone_features.dtype != torch.bfloat16:
backbone_features = backbone_features.to(torch.bfloat16)
return BatchFeature(
data={
"backbone_features": backbone_features,
"backbone_attention_mask": prepared["backbone_attention_mask"],
"image_mask": prepared["image_mask_out"],
}
)


# ============================================================
# Action Head TRT Forward
# ============================================================
def action_head_tensorrt_forward(self, backbone_output, action_input, options=None):
"""Replace ActionHead.get_action() with TRT-accelerated inference.
VLLN (LayerNorm) stays in PyTorch. State Encoder, Action Encoder,
DiT, and Action Decoder are replaced with TRT engines.
N1.7 change: state is reshaped from [B, state_history_length, max_state_dim]
to [B, 1, state_history_length * max_state_dim] before the state encoder.
Args:
self: ActionHead instance (monkey-patched)
backbone_output: BatchFeature with backbone_features, backbone_attention_mask, image_mask
action_input: BatchFeature with state, embodiment_id
"""
# --- VLLN (PyTorch) + vl_self_attention (TRT if available, else PyTorch) ---
backbone_features = backbone_output.backbone_features
backbone_features = self.vlln(backbone_features)
if hasattr(self, "vl_sa_engine") and self.vl_sa_engine is not None:
engine_dtype = torch.bfloat16
if backbone_features.dtype != engine_dtype:
backbone_features = backbone_features.to(engine_dtype)
self.vl_sa_engine.set_runtime_tensor_shape("hidden_states", backbone_features.shape)
backbone_features = self.vl_sa_engine(backbone_features)["output"]
else:
backbone_features = self.vl_self_attention(backbone_features)
vl_embs = backbone_features
embodiment_id = action_input.embodiment_id
batch_size = vl_embs.shape[0]
device = vl_embs.device
engine_dtype = torch.bfloat16
# Ensure consistent dtypes
if vl_embs.dtype != engine_dtype:
vl_embs = vl_embs.to(engine_dtype)
if action_input.state.dtype != engine_dtype:
action_input.state = action_input.state.to(engine_dtype)
if embodiment_id.dtype != torch.int64:
embodiment_id = embodiment_id.to(torch.int64)
# --- State history reshape (N1.7) ---
# N1.7: state comes as [B, state_history_length, max_state_dim]
# Flatten to [B, 1, state_history_length * max_state_dim] for the encoder
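    # For example (hypothetical dims): [2, 8, 64] → [2, 1, 512]
    # when state_history_length=8 and max_state_dim=64.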
state = action_input.state
if state.ndim == 3 and state.shape[1] > 1:
state = state.view(state.shape[0], 1, -1)
elif state.ndim == 3 and state.shape[1] == 1:
# Already [B, 1, dim] — state_history_length=1
pass
else:
# Unexpected shape, pass through
logger.warning(f"Unexpected state shape: {state.shape}")
# --- State Encoder TRT ---
self.state_encoder_engine.set_runtime_tensor_shape("state", state.shape)
self.state_encoder_engine.set_runtime_tensor_shape("embodiment_id", embodiment_id.shape)
state_features = self.state_encoder_engine(state, embodiment_id)["output"]
# --- Initialize actions as random noise ---
if hasattr(self, "init_actions"):
actions = self.init_actions.expand((batch_size, -1, -1))
else:
actions = torch.randn(
size=(batch_size, self.config.action_horizon, self.action_dim),
dtype=engine_dtype,
device=device,
)
num_steps = self.num_inference_timesteps
dt = 1.0 / num_steps
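    # The loop below integrates the predicted velocity field with forward Euler:
    #     actions <- actions + dt * velocity(actions, t),  t = 0, dt, ..., 1 - dt
    # transporting the initial noise into an action sample over num_steps steps.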
# --- Denoising loop ---
for t in range(num_steps):
t_cont = t / float(num_steps)
t_discretized = int(t_cont * self.num_timestep_buckets)
timesteps_tensor = torch.full(
size=(batch_size,), fill_value=t_discretized, device=device, dtype=torch.int64
)
# Action Encoder TRT
self.action_encoder_engine.set_runtime_tensor_shape("actions", actions.shape)
self.action_encoder_engine.set_runtime_tensor_shape("timesteps", timesteps_tensor.shape)
self.action_encoder_engine.set_runtime_tensor_shape("embodiment_id", embodiment_id.shape)
action_features = self.action_encoder_engine(
actions.to(engine_dtype), timesteps_tensor, embodiment_id
)["output"]
# Maybe add position embedding (stays in PyTorch)
if self.config.add_pos_embed:
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
pos_embs = self.position_embedding(pos_ids).unsqueeze(0).to(engine_dtype)
action_features = action_features + pos_embs
# Concatenate state + action embeddings
sa_embs = torch.cat((state_features, action_features), dim=1).to(engine_dtype)
# DiT TRT
self.dit_engine.set_runtime_tensor_shape("sa_embs", sa_embs.shape)
self.dit_engine.set_runtime_tensor_shape("vl_embs", vl_embs.shape)
self.dit_engine.set_runtime_tensor_shape("timestep", timesteps_tensor.shape)
dit_kwargs = {}
if hasattr(backbone_output, "image_mask") and backbone_output.image_mask is not None:
image_mask = backbone_output.image_mask
self.dit_engine.set_runtime_tensor_shape("image_mask", image_mask.shape)
dit_kwargs["image_mask"] = image_mask
if (
hasattr(backbone_output, "backbone_attention_mask")
and backbone_output.backbone_attention_mask is not None
):
bb_mask = backbone_output.backbone_attention_mask
self.dit_engine.set_runtime_tensor_shape("backbone_attention_mask", bb_mask.shape)
dit_kwargs["backbone_attention_mask"] = bb_mask
model_output = self.dit_engine(sa_embs, vl_embs, timesteps_tensor, **dit_kwargs)["output"]
# Action Decoder TRT
self.action_decoder_engine.set_runtime_tensor_shape("model_output", model_output.shape)
self.action_decoder_engine.set_runtime_tensor_shape("embodiment_id", embodiment_id.shape)
pred = self.action_decoder_engine(model_output, embodiment_id)["output"]
pred_velocity = pred[:, -self.action_horizon :]
# Euler integration
actions = actions + dt * pred_velocity
return BatchFeature(data={"action_pred": actions})


# ============================================================
# Engine Setup
# ============================================================
def setup_tensorrt_engines(policy, trt_engine_path, mode="n17_full_pipeline"):
"""Load TRT engines, delete PyTorch modules, and monkey-patch forward methods.
Args:
policy: Gr00tPolicy instance
trt_engine_path: Path to directory containing TRT engine files
mode: 'n17_full_pipeline' (ViT TRT + LLM TRT + Action Head TRT),
'vit_llm_only' (ViT TRT + LLM TRT, Action Head in PyTorch),
'action_head' (Action Head TRT only), or 'dit_only'
"""
if mode == "n17_full_pipeline":
_setup_n17_full_pipeline(policy, trt_engine_path)
elif mode == "vit_llm_only":
_setup_vit_llm_only(policy, trt_engine_path)
elif mode == "action_head":
_setup_action_head(policy, trt_engine_path)
elif mode == "dit_only":
_setup_dit_only(policy, trt_engine_path)
else:
raise ValueError(
f"Unknown mode: {mode}. Expected 'n17_full_pipeline', 'vit_llm_only', "
f"'action_head', or 'dit_only'."
)
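
# For reference, the engine filenames the setup helpers below look for under
# trt_engine_path: vit_bf16.engine, llm_bf16.engine, vl_self_attention.engine,
# state_encoder.engine, action_encoder.engine, dit_bf16.engine, and
# action_decoder.engine (dit_only mode also accepts the legacy names
# dit_model_bf16.engine / dit_model_bf16.trt).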


def _setup_n17_full_pipeline(policy, trt_engine_path):
"""Set up TRT engines for N1.7: ViT TRT + LLM TRT + Action Head TRT.
The Qwen3-VL backbone's vision encoder and text model are both replaced
with TRT engines. PyTorch ops kept: embed_tokens, masked_scatter,
get_rope_index (lightweight, <1ms).
Falls back to PyTorch LLM if llm_bf16.engine is not found.
"""
backbone = policy.model.backbone
qwen_model = backbone.model # Qwen3VLForConditionalGeneration
action_head = policy.model.action_head
# --- Backbone setup ---
# Save references needed by the TRT forward
backbone._embedding_layer = qwen_model.model.language_model.get_input_embeddings()
backbone._image_token_id = qwen_model.config.image_token_id
# Load ViT TRT engine (optional — PyTorch ViT used as fallback for accuracy)
vit_engine_path = os.path.join(trt_engine_path, "vit_bf16.engine")
use_vit_trt = os.path.exists(vit_engine_path)
if use_vit_trt:
print(f"Loading ViT engine: {vit_engine_path}")
backbone.vit_engine = Engine(vit_engine_path)
del qwen_model.model.visual
torch.cuda.empty_cache()
print(" Deleted PyTorch ViT (replaced by TRT engine)")
else:
backbone.vit_engine = None
print(f" ViT engine not found at {vit_engine_path}, keeping PyTorch ViT")
# Load LLM TRT engine (if available)
llm_engine_path = os.path.join(trt_engine_path, "llm_bf16.engine")
use_llm_trt = os.path.exists(llm_engine_path)
if use_llm_trt:
print(f"Loading LLM engine: {llm_engine_path}")
backbone.llm_engine = Engine(llm_engine_path)
# Delete PyTorch LLM layers to free GPU memory
# Keep embed_tokens (needed for token embedding before TRT)
# Keep get_rope_index via inner_model (needed for position IDs)
del qwen_model.model.language_model.layers
del qwen_model.model.language_model.norm
torch.cuda.empty_cache()
print(" Deleted PyTorch LLM layers (replaced by TRT engine)")
else:
backbone.llm_engine = None
print(f" LLM engine not found at {llm_engine_path}, using PyTorch LLM")
# Monkey-patch backbone forward
if use_vit_trt and use_llm_trt:
backbone.forward = partial(qwen3_backbone_full_trt_forward, backbone)
elif use_vit_trt and not use_llm_trt:
backbone.forward = partial(qwen3_backbone_tensorrt_forward, backbone)
elif not use_vit_trt and use_llm_trt:
# PyTorch ViT + LLM TRT (best accuracy when ViT TRT has issues)
backbone.forward = partial(qwen3_backbone_llm_trt_forward, backbone)
else:
print(" No backbone TRT engines loaded, backbone remains in PyTorch")
# --- Action head setup ---
# Load vl_self_attention TRT engine (if available)
vl_sa_engine_path = os.path.join(trt_engine_path, "vl_self_attention.engine")
if os.path.exists(vl_sa_engine_path):
print(f"Loading VL Self-Attention engine: {vl_sa_engine_path}")
action_head.vl_sa_engine = Engine(vl_sa_engine_path)
# Delete PyTorch module — TRT engine replaces it
if hasattr(action_head, "vl_self_attention"):
del action_head.vl_self_attention
torch.cuda.empty_cache()
print(" Deleted PyTorch vl_self_attention (replaced by TRT engine)")
else:
action_head.vl_sa_engine = None
print(f" VL Self-Attention engine not found at {vl_sa_engine_path}, using PyTorch")
if hasattr(action_head, "model"):
del action_head.model
if hasattr(action_head, "state_encoder"):
del action_head.state_encoder
if hasattr(action_head, "action_encoder"):
del action_head.action_encoder
if hasattr(action_head, "action_decoder"):
del action_head.action_decoder
torch.cuda.empty_cache()
    assert action_head.action_dim == action_head.config.max_action_dim, (
        f"action_dim mismatch: action_head.action_dim={action_head.action_dim} "
        f"!= config.max_action_dim={action_head.config.max_action_dim}"
    )
print(f"Loading action head engines from: {trt_engine_path}")
action_head.state_encoder_engine = Engine(os.path.join(trt_engine_path, "state_encoder.engine"))
action_head.action_encoder_engine = Engine(
os.path.join(trt_engine_path, "action_encoder.engine")
)
action_head.dit_engine = Engine(os.path.join(trt_engine_path, "dit_bf16.engine"))
action_head.action_decoder_engine = Engine(
os.path.join(trt_engine_path, "action_decoder.engine")
)
action_head.get_action = partial(action_head_tensorrt_forward, action_head)
llm_status = "TRT" if use_llm_trt else "PyTorch"
    vit_status = "TRT" if use_vit_trt else "PyTorch"
print("N1.7 full-pipeline TRT engines loaded.")
print(f" ViT: {vit_status} | LLM: {llm_status} | Action Head: TRT")


def _setup_vit_llm_only(policy, trt_engine_path):
"""Set up TRT engines for ViT + LLM only; action head stays in PyTorch.
    Use this on platforms where the DiT cannot be exported with dynamic
    vl_seq_len (e.g. DGX Spark, where the torch 2.10 dynamo exporter bakes
    seq_len in as a static dimension).
The backbone (ViT TRT + LLM TRT) still gets TRT acceleration; the PyTorch
action head receives the LLM embeddings at the actual runtime seq_len
without any shape constraint.
"""
backbone = policy.model.backbone
qwen_model = backbone.model # Qwen3VLForConditionalGeneration
# Save references needed by the TRT forward
backbone._embedding_layer = qwen_model.model.language_model.get_input_embeddings()
backbone._image_token_id = qwen_model.config.image_token_id
# Load ViT TRT engine
vit_engine_path = os.path.join(trt_engine_path, "vit_bf16.engine")
if not os.path.exists(vit_engine_path):
raise FileNotFoundError(
f"ViT TRT engine not found: {vit_engine_path}\n"
f"Run export_onnx_n1d7.py + build_tensorrt_engine.py first."
)
print(f"Loading ViT engine: {vit_engine_path}")
backbone.vit_engine = Engine(vit_engine_path)
del qwen_model.model.visual
torch.cuda.empty_cache()
print(" Deleted PyTorch ViT (replaced by TRT engine)")
# Load LLM TRT engine
llm_engine_path = os.path.join(trt_engine_path, "llm_bf16.engine")
if not os.path.exists(llm_engine_path):
raise FileNotFoundError(
f"LLM TRT engine not found: {llm_engine_path}\n"
f"Run export_onnx_n1d7.py + build_tensorrt_engine.py first."
)
print(f"Loading LLM engine: {llm_engine_path}")
backbone.llm_engine = Engine(llm_engine_path)
del qwen_model.model.language_model.layers
del qwen_model.model.language_model.norm
torch.cuda.empty_cache()
print(" Deleted PyTorch LLM layers (replaced by TRT engine)")
# Patch backbone forward to use ViT TRT + LLM TRT
backbone.forward = partial(qwen3_backbone_full_trt_forward, backbone)
print("vit_llm_only TRT engines loaded.")
print(" ViT: TRT | LLM: TRT | Action Head: PyTorch")


def _setup_action_head(policy, trt_engine_path):
"""Set up TRT engines for action head only (N1.7 mode).
Backbone (Qwen3-VL) stays in PyTorch. Only the 4 action head components
(State Encoder, Action Encoder, DiT, Action Decoder) are replaced with
TRT engines.
"""
action_head = policy.model.action_head
# Delete PyTorch modules that are replaced by TRT
if hasattr(action_head, "model"):
del action_head.model
if hasattr(action_head, "state_encoder"):
del action_head.state_encoder
if hasattr(action_head, "action_encoder"):
del action_head.action_encoder
if hasattr(action_head, "action_decoder"):
del action_head.action_decoder
torch.cuda.empty_cache()
# Verify action_dim consistency
assert action_head.action_dim == action_head.config.max_action_dim, (
f"action_dim mismatch: action_head.action_dim={action_head.action_dim} "
f"!= config.max_action_dim={action_head.config.max_action_dim}"
)
# Load action head TRT engines
print(f"Loading action head engines from: {trt_engine_path}")
action_head.state_encoder_engine = Engine(os.path.join(trt_engine_path, "state_encoder.engine"))
action_head.action_encoder_engine = Engine(
os.path.join(trt_engine_path, "action_encoder.engine")
)
action_head.dit_engine = Engine(os.path.join(trt_engine_path, "dit_bf16.engine"))
action_head.action_decoder_engine = Engine(
os.path.join(trt_engine_path, "action_decoder.engine")
)
# Monkey-patch: backbone.forward stays original, only action head is replaced
action_head.get_action = partial(action_head_tensorrt_forward, action_head)
print("Action head TRT engines loaded and forward method patched.")
print(" Backbone remains in PyTorch (Qwen3-VL).")


def _setup_dit_only(policy, trt_engine_path):
"""Set up TRT engine for DiT-only acceleration (backward compatible).
Only replaces the DiT model in the action head. The backbone and other
action head components remain in PyTorch.
"""
action_head = policy.model.action_head
# Delete the PyTorch DiT model
if hasattr(action_head, "model"):
del action_head.model
torch.cuda.empty_cache()
# Load DiT TRT engine
    # Support both naming conventions, falling back to the legacy .trt suffix.
    candidates = [
        os.path.join(trt_engine_path, name)
        for name in ("dit_bf16.engine", "dit_model_bf16.engine", "dit_model_bf16.trt")
    ]
    dit_path = next((p for p in candidates if os.path.exists(p)), None)
    if dit_path is None:
        raise FileNotFoundError(
            f"No DiT engine found in {trt_engine_path} (tried: {', '.join(candidates)})"
        )
print(f"Loading DiT engine: {dit_path}")
action_head.dit_engine = Engine(dit_path)
# Monkey-patch only the get_action method
# We need a simpler forward that only replaces the DiT call
@torch.no_grad()
def dit_only_get_action_with_features(
backbone_features, state_features, embodiment_id, backbone_output
):
"""get_action_with_features with DiT replaced by TRT."""
vl_embs = backbone_features
batch_size = vl_embs.shape[0]
device = vl_embs.device
engine_dtype = torch.bfloat16
actions = torch.randn(
size=(batch_size, action_head.config.action_horizon, action_head.action_dim),
dtype=vl_embs.dtype,
device=device,
)
dt = 1.0 / action_head.num_inference_timesteps
for t in range(action_head.num_inference_timesteps):
t_cont = t / float(action_head.num_inference_timesteps)
t_discretized = int(t_cont * action_head.num_timestep_buckets)
timesteps_tensor = torch.full(
size=(batch_size,), fill_value=t_discretized, device=device
)
action_features = action_head.action_encoder(actions, timesteps_tensor, embodiment_id)
if action_head.config.add_pos_embed:
pos_ids = torch.arange(action_features.shape[1], dtype=torch.long, device=device)
pos_embs = action_head.position_embedding(pos_ids).unsqueeze(0)
action_features = action_features + pos_embs
sa_embs = torch.cat((state_features, action_features), dim=1).to(engine_dtype)
# Use TRT for DiT
vl_embs_trt = vl_embs.to(engine_dtype)
timesteps_trt = timesteps_tensor.to(torch.int64)
action_head.dit_engine.set_runtime_tensor_shape("sa_embs", sa_embs.shape)
action_head.dit_engine.set_runtime_tensor_shape("vl_embs", vl_embs_trt.shape)
action_head.dit_engine.set_runtime_tensor_shape("timestep", timesteps_trt.shape)
dit_kwargs = {}
if hasattr(backbone_output, "image_mask") and backbone_output.image_mask is not None:
image_mask = backbone_output.image_mask
action_head.dit_engine.set_runtime_tensor_shape("image_mask", image_mask.shape)
dit_kwargs["image_mask"] = image_mask
if (
hasattr(backbone_output, "backbone_attention_mask")
and backbone_output.backbone_attention_mask is not None
):
bb_mask = backbone_output.backbone_attention_mask
action_head.dit_engine.set_runtime_tensor_shape(
"backbone_attention_mask", bb_mask.shape
)
dit_kwargs["backbone_attention_mask"] = bb_mask
model_output = action_head.dit_engine(
sa_embs, vl_embs_trt, timesteps_trt, **dit_kwargs
)["output"]
pred = action_head.action_decoder(model_output, embodiment_id)
pred_velocity = pred[:, -action_head.action_horizon :]
actions = actions + dt * pred_velocity
return BatchFeature(
data={
"action_pred": actions,
"backbone_features": vl_embs,
"state_features": state_features,
}
)
action_head.get_action_with_features = dit_only_get_action_with_features
print("DiT-only TRT engine loaded and forward method patched.")