# Qwen3-VL-Embedding-2B-onnx / export_script / b_export_onnx_vision.py
# Uploaded by: jordaniel29
# Commit message: Upload folder using huggingface_hub
# Commit: df65075 (verified)
"""
Direct HF export of the Qwen3-VL vision encoder to ONNX.
Instead of manually reimplementing the vision encoder (with norm fusion,
GELU replacement, etc.), this script wraps the HF model's own
Qwen3VLVisionModel and exports it via torch.onnx.export.
This produces a Vision.onnx that should match PyTorch up to floating-point
tolerance, because it traces through the exact same code path (the ONNX
backend's kernels can still differ in the last few bits).
Usage:
python 2_export_onnx_vision.py
python 2_export_onnx_vision.py --model_path /path/to/model --output_dir /path/to/output
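This ONLY re-exports Vision.onnx. The Embed.onnx, Transformer.onnx,
and rotary_params.npz produced by export_embedding_onnx.py are reused as-is;
rotary_params.npz must already exist in the output directory, because the
image grid baked into Vision.onnx is read from it.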
"""
import math
import os
import gc
import argparse
import torch
import onnx
import numpy as np
from qwen3_vl_embedding import Qwen3VLForEmbedding
def _iter_graph_tensors(graph):
    """
    Yield every TensorProto held by *graph*, recursing into subgraphs.

    Covers graph initializers and tensor-valued node attributes (e.g. the
    value of a ``Constant`` node), including those nested inside subgraph
    attributes (``If``/``Loop`` bodies). Also accepts a FunctionProto, which
    has nodes but no initializers.
    """
    yield from getattr(graph, "initializer", ())
    for node in graph.node:
        for attr in node.attribute:
            if attr.HasField("t"):
                yield attr.t
            yield from attr.tensors
            if attr.HasField("g"):
                yield from _iter_graph_tensors(attr.g)
            for subgraph in attr.graphs:
                yield from _iter_graph_tensors(subgraph)


def _external_data_locations(model):
    """Return the set of relative file paths *model* reads external tensor data from."""
    locations = set()
    for graph in (model.graph, *model.functions):
        for tensor in _iter_graph_tensors(graph):
            if tensor.data_location != onnx.TensorProto.EXTERNAL:
                continue
            for entry in tensor.external_data:
                if entry.key == "location":
                    locations.add(entry.value)
    return locations


def consolidate_external_data(onnx_path: str):
    """
    Re-save an ONNX model so all external weights live in one ``<name>.onnx.data``
    file instead of hundreds of per-tensor files.

    Only the files this model itself references as external data (plus a
    stale ``<name>.onnx.data``) are deleted. Every other file in the directory
    is left untouched, including other models' weights, tokenizer/config
    files, README and .gitattributes. Consequently, orphaned files left over
    from unrelated earlier exports are NOT cleaned up here.

    Args:
        onnx_path: Path to the ``.onnx`` graph file; it is overwritten in place
            and its weights are written to ``<onnx_path>.data`` next to it.
    """
    from onnx.external_data_helper import load_external_data_for_model

    out_dir = os.path.dirname(onnx_path) or "."
    base = os.path.basename(onnx_path)
    data_rel = base + ".data"  # sibling file, e.g. Vision.onnx.data

    # Read the graph WITHOUT weights first so we can record exactly which
    # files it points at, then pull all weights into memory before any of
    # those files are removed.
    model = onnx.load(onnx_path, load_external_data=False)
    referenced = _external_data_locations(model)
    load_external_data_for_model(model, out_dir)

    out_root = os.path.realpath(out_dir)
    graph_file = os.path.realpath(onnx_path)
    for rel in sorted(referenced):
        fpath = os.path.realpath(os.path.join(out_dir, rel))
        # Never delete anything outside the model directory, nor the graph itself.
        if fpath == graph_file or os.path.commonpath([out_root, fpath]) != out_root:
            continue
        try:
            os.remove(fpath)
        except FileNotFoundError:
            pass
        except OSError as err:
            # Best effort: a leftover per-tensor file wastes disk but is harmless.
            print(f"  warning: could not remove {fpath}: {err}")

    # Remove the target .data file if it still exists (stale): onnx's external
    # data writer appends to an existing file rather than truncating it.
    target_data = os.path.join(out_dir, data_rel)
    if os.path.exists(target_data):
        os.remove(target_data)

    onnx.save_model(
        model, onnx_path,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=data_rel,
        size_threshold=1024,
        convert_attribute=False,
    )
    del model
    gc.collect()
# ══════════════════════════════════════════════════════════════════════════════
# Configuration (must match export_embedding_onnx.py)
# ══════════════════════════════════════════════════════════════════════════════
# Source checkpoint and output directory. Each can be overridden through an
# environment variable, and again on the command line (--model_path /
# --output_dir).
QWEN3VLE_MODEL_HF_PATH = os.environ.get(
    "MODEL_QWEN3VLE_MODEL_HF_PATH",
    "/home/jordan/Research/Product-AI-mono/assets/model/Qwen3-VL-Embedding-2B",
)
QWEN3VLE_TRT_DIR_PATH = os.environ.get(
    "MODEL_QWEN3VLE_TRT_DIR_PATH",
    "assets/model/Qwen3-VL-Embedding-2B-onnx",
)

# Temporal length of the input (presumably the frame count; 1 = still image).
# main() divides it by the temporal patch size, rounding up, to get the
# temporal extent of the grid baked into the exported graph.
TEMPORAL_SIZE = 1

# CLI defaults mirror the (possibly env-overridden) paths above.
DEFAULT_MODEL_PATH, DEFAULT_OUTPUT_DIR = QWEN3VLE_MODEL_HF_PATH, QWEN3VLE_TRT_DIR_PATH

# ONNX opset version handed to torch.onnx.export.
OPSET = 17
# ══════════════════════════════════════════════════════════════════════════════
# Vision Wrapper — thin shell around the HF model
# ══════════════════════════════════════════════════════════════════════════════
class VisionEncoderWrapper(torch.nn.Module):
    """
    ONNX-export shell around the HF Qwen3-VL vision tower.

    The wrapped ``visual`` module is invoked as
    ``visual(pixel_values, grid_thw=...)`` and its result is unpacked as
    ``(merged_features, deepstack_features)``. Because the export uses one
    fixed image grid, ``grid_thw`` is stored on this module as a buffer
    (so it is captured in the traced graph) rather than being a model input.
    The outputs are flattened into a single tuple.

    Args:
        visual: the vision sub-module (``model.model.visual`` in this script).
        grid_thw: int64 grid tensor; main() builds it as
            ``[[temporal_patches, grid_h, grid_w]]``.
    """

    def __init__(self, visual, grid_thw: torch.Tensor):
        super().__init__()
        self.visual = visual
        self.register_buffer("grid_thw", grid_thw)

    def forward(self, pixel_values):
        """
        Run the vision tower on a fixed-size batch of pixel patches.

        pixel_values : float32 [total_patches, flatten_dim], from the HF
            processor (already CLIP-normalised).

        Returns a flat tuple ``(deepstack_feature_0, …, vision_hidden_states)``:
        all deepstack features first, the merged vision features last.
        """
        merged, deepstack = self.visual(pixel_values, grid_thw=self.grid_thw)
        return (*deepstack, merged)
# ══════════════════════════════════════════════════════════════════════════════
# Export
# ══════════════════════════════════════════════════════════════════════════════
def _load_grid_config(output_dir):
    """
    Read the image-grid settings from ``<output_dir>/rotary_params.npz``
    (produced by export_embedding_onnx.py).

    Returns a dict of ints with keys height_factor, width_factor, patch_size,
    merge_size, image_height and image_width, plus temporal_patch_size
    (``None`` when that key is absent from the file).

    Raises:
        FileNotFoundError: if ``rotary_params.npz`` is missing.
    """
    npz_path = os.path.join(output_dir, "rotary_params.npz")
    if not os.path.isfile(npz_path):
        raise FileNotFoundError(
            f"{npz_path} not found; run export_embedding_onnx.py with the same "
            "output directory first."
        )
    keys = (
        "height_factor", "width_factor", "patch_size",
        "merge_size", "image_height", "image_width",
    )
    # The context manager closes the zip handle that np.load opens for .npz files.
    with np.load(npz_path) as rp:
        cfg = {key: int(rp[key]) for key in keys}
        cfg["temporal_patch_size"] = (
            int(rp["temporal_patch_size"]) if "temporal_patch_size" in rp else None
        )
    return cfg


def main(model_path: str, output_dir: str) -> None:
    """
    Export the Qwen3-VL vision encoder to ``<output_dir>/Vision.onnx``.

    The image grid is baked into the graph. It is derived from the
    ``rotary_params.npz`` that must already exist in *output_dir*, so the
    exported model takes exactly ``total_patches`` rows of pixel patches.

    Args:
        model_path: HF checkpoint loaded with ``Qwen3VLForEmbedding``
            (float32, CPU, eager attention).
        output_dir: Directory holding ``rotary_params.npz``. ``Vision.onnx``
            and its consolidated ``Vision.onnx.data`` are written here.

    Raises:
        FileNotFoundError: ``rotary_params.npz`` is missing.
        ValueError: the npz records a ``temporal_patch_size`` that differs
            from the loaded model's.
    """
    os.makedirs(output_dir, exist_ok=True)
    onnx_vision = os.path.join(output_dir, "Vision.onnx")

    # ── Load rotary_params to get the grid config ────────────────────────
    grid_cfg = _load_grid_config(output_dir)
    patch_size = grid_cfg["patch_size"]
    merge_size = grid_cfg["merge_size"]
    grid_h = grid_cfg["height_factor"] * merge_size
    grid_w = grid_cfg["width_factor"] * merge_size

    # ── Load model ──────────────────────────────────────────────────────
    print("Loading model …")
    model = Qwen3VLForEmbedding.from_pretrained(
        model_path,
        torch_dtype=torch.float32,
        device_map="cpu",
        low_cpu_mem_usage=True,
        _attn_implementation="eager",
    ).eval()
    visual = model.model.visual

    # Use the model's temporal patch size for BOTH the grid and the patch
    # width. Previously the grid used the npz value (defaulting to 2) and
    # flatten_dim used the model's value. If the npz records a different
    # value, it was produced for another configuration, so fail loudly.
    temporal_patch_size = visual.patch_embed.temporal_patch_size
    npz_temporal_patch_size = grid_cfg["temporal_patch_size"]
    if npz_temporal_patch_size is not None and npz_temporal_patch_size != temporal_patch_size:
        raise ValueError(
            f"temporal_patch_size mismatch: rotary_params.npz has "
            f"{npz_temporal_patch_size}, model has {temporal_patch_size}"
        )
    temporal_patches = math.ceil(TEMPORAL_SIZE / temporal_patch_size)
    total_patches = temporal_patches * grid_h * grid_w
    flatten_dim = 3 * temporal_patch_size * patch_size * patch_size
    deepstack_features_len = len(visual.deepstack_visual_indexes)

    print(f"Grid: {temporal_patches}x{grid_h}x{grid_w} = {total_patches} patches")
    print(f"Image: {grid_cfg['image_height']}x{grid_cfg['image_width']} TEMPORAL_SIZE={TEMPORAL_SIZE}")
    print(f" deepstack features: {deepstack_features_len}")
    print(f" temporal_patch_size: {temporal_patch_size}")
    print(f" flatten_dim: {flatten_dim}")

    # ── Build wrapper ───────────────────────────────────────────────────
    grid_thw = torch.tensor([[temporal_patches, grid_h, grid_w]], dtype=torch.int64)
    wrapper = VisionEncoderWrapper(visual, grid_thw)

    # ── Build dummy input ───────────────────────────────────────────────
    # Random data is fine for tracing. No dynamic_axes are passed below, so
    # this [total_patches, flatten_dim] shape is fixed in the exported graph.
    dummy_pixels = torch.randn(total_patches, flatten_dim, dtype=torch.float32)

    # ── Output names (must follow the wrapper's return order) ───────────
    output_names = [f"deepstack_feature_{i}" for i in range(deepstack_features_len)]
    output_names.append("vision_hidden_states")

    # ── Export ──────────────────────────────────────────────────────────
    print("Exporting Vision (direct HF) …")
    with torch.inference_mode():
        torch.onnx.export(
            wrapper,
            (dummy_pixels,),
            onnx_vision,
            input_names=["pixel_values"],
            output_names=output_names,
            opset_version=OPSET,
            dynamo=False,
        )

    # Drop every reference to the torch weights before consolidation loads the
    # ONNX weights into memory. This includes ``visual``, which would otherwise
    # keep the whole vision tower alive after ``model`` is deleted.
    del wrapper, visual, model
    gc.collect()
    print(f" → {onnx_vision}")
    print("Consolidating external weights …")
    consolidate_external_data(onnx_vision)
    print("Done.")
if __name__ == "__main__":
    cli = argparse.ArgumentParser(
        description="Export Qwen3-VL vision encoder to ONNX (direct HF)"
    )
    # (flag, default, help prefix); each help text ends with its default value.
    for flag, default, help_prefix in (
        ("--model_path", DEFAULT_MODEL_PATH, "Path to the HF model"),
        ("--output_dir", DEFAULT_OUTPUT_DIR, "Directory to save ONNX files"),
    ):
        cli.add_argument(
            flag,
            type=str,
            default=default,
            help=f"{help_prefix} (default: {default})",
        )
    opts = cli.parse_args()
    main(opts.model_path, opts.output_dir)