| """ |
| Direct HF export of the Qwen3-VL vision encoder to ONNX. |
| |
| Instead of manually reimplementing the vision encoder (with norm fusion, |
| GELU replacement, etc.), this script wraps the HF model's own |
| Qwen3VLVisionModel and exports it via torch.onnx.export. |
| |
| This produces a Vision.onnx that is numerically identical to PyTorch |
| because it traces through the exact same code path. |
| |
| Usage: |
| python 2_export_onnx_vision.py |
| python 2_export_onnx_vision.py --model_path /path/to/model --output_dir /path/to/output |
| |
| This ONLY re-exports Vision.onnx. The Embed.onnx, Transformer.onnx, |
| and rotary_params.npz from export_embedding_onnx.py are reused as-is. |
| """ |
|
|
| import math |
| import os |
| import gc |
| import argparse |
| import torch |
| import onnx |
| import numpy as np |
| from qwen3_vl_embedding import Qwen3VLForEmbedding |
|
|
def _iter_model_tensors(graph):
    """Yield every TensorProto held by *graph*, descending into subgraphs."""
    yield from graph.initializer
    for node in graph.node:
        for attr in node.attribute:
            if attr.type == onnx.AttributeProto.TENSOR:
                yield attr.t
            elif attr.type == onnx.AttributeProto.TENSORS:
                yield from attr.tensors
            elif attr.type == onnx.AttributeProto.GRAPH:
                yield from _iter_model_tensors(attr.g)
            elif attr.type == onnx.AttributeProto.GRAPHS:
                for subgraph in attr.graphs:
                    yield from _iter_model_tensors(subgraph)


def consolidate_external_data(onnx_path: str) -> None:
    """
    Re-save an ONNX model so all external weights live in one .onnx.data file
    instead of hundreds of per-tensor files.

    Only the external-data files that the model itself references are
    removed; every other file in the directory is left untouched.

    Args:
        onnx_path: Path of the ``.onnx`` file to rewrite in place. The
            consolidated weights are written next to it as
            ``<basename>.data``.

    Raises:
        OSError: If a previous consolidated data file exists but cannot be
            removed, or if the model cannot be read or written.
    """
    from onnx.external_data_helper import (
        ExternalDataInfo,
        load_external_data_for_model,
        uses_external_data,
    )

    # dirname() is "" for a bare filename; fall back to the current directory.
    out_dir = os.path.dirname(onnx_path) or "."
    data_rel = os.path.basename(onnx_path) + ".data"
    target_data = os.path.join(out_dir, data_rel)

    # Parse the graph without weights first so the per-tensor file locations
    # can be recorded before the weights are pulled into memory.
    model = onnx.load(onnx_path, load_external_data=False)
    stale_paths = set()
    for tensor in _iter_model_tensors(model.graph):
        if not uses_external_data(tensor):
            continue
        location = ExternalDataInfo(tensor).location
        if location:
            stale_paths.add(os.path.join(out_dir, location))
    load_external_data_for_model(model, out_dir)

    # The weights are in memory now, so the per-tensor files are redundant.
    # Their removal is best-effort: a leftover file is harmless, but say so.
    for path in sorted(stale_paths - {target_data}):
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        except OSError as err:
            print(f"  warning: could not remove {path}: {err}")

    # A previous consolidated file must be gone before saving so the save
    # below starts from a fresh file; failures here are not swallowed.
    try:
        os.remove(target_data)
    except FileNotFoundError:
        pass

    onnx.save_model(
        model,
        onnx_path,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=data_rel,
        size_threshold=1024,
        convert_attribute=False,
    )
    del model
    gc.collect()
|
|
| |
| |
| |
# Model and output locations. Each can be overridden through the named
# environment variable; the literal is only the fallback.
QWEN3VLE_MODEL_HF_PATH = os.environ.get(
    "MODEL_QWEN3VLE_MODEL_HF_PATH",
    "/home/jordan/Research/Product-AI-mono/assets/model/Qwen3-VL-Embedding-2B",
)
QWEN3VLE_TRT_DIR_PATH = os.environ.get(
    "MODEL_QWEN3VLE_TRT_DIR_PATH",
    "assets/model/Qwen3-VL-Embedding-2B-onnx",
)

# Temporal extent of the export input; main() converts it into a temporal
# patch count by dividing by the temporal patch size and rounding up.
TEMPORAL_SIZE = 1

# ONNX opset version handed to torch.onnx.export.
OPSET = 17

# Defaults used by the command-line interface.
DEFAULT_MODEL_PATH = QWEN3VLE_MODEL_HF_PATH
DEFAULT_OUTPUT_DIR = QWEN3VLE_TRT_DIR_PATH
|
|
|
|
| |
| |
| |
class VisionEncoderWrapper(torch.nn.Module):
    """
    ONNX-export adapter around the HF Qwen3VLVisionModel.

    The wrapped module is invoked as ``visual(pixel_values, grid_thw=...)``
    and its result is unpacked as ``(hidden_states, deepstack_features)``.
    This adapter pins ``grid_thw`` to the tensor given at construction time,
    so the only runtime input is ``pixel_values``, and it flattens the result
    into a single tuple of tensors.
    """

    def __init__(self, visual, grid_thw: torch.Tensor):
        super().__init__()
        # Registered as a buffer so the fixed grid is part of the module state.
        self.register_buffer("grid_thw", grid_thw)
        self.visual = visual

    def forward(self, pixel_values):
        """
        Run the wrapped vision model on the fixed grid.

        pixel_values : float32 [total_patches, flatten_dim]
            From the HF processor (already CLIP-normalised).

        Returns: (deepstack_feature_0, …, vision_hidden_states)
        """
        merged, deepstack = self.visual(pixel_values, grid_thw=self.grid_thw)
        # Deepstack features come first; the merged hidden states go last.
        return (*deepstack, merged)
|
|
|
|
| |
| |
| |
def main(model_path, output_dir):
    """
    Export the HF vision encoder to ``<output_dir>/Vision.onnx``.

    The grid geometry is read from ``rotary_params.npz`` in ``output_dir``,
    the model is loaded on CPU in float32, the vision module is wrapped with
    a fixed ``grid_thw`` and traced with ``torch.onnx.export``, and the
    resulting external weights are consolidated into a single data file.

    Args:
        model_path: Path handed to ``Qwen3VLForEmbedding.from_pretrained``.
        output_dir: Directory that already contains ``rotary_params.npz``
            and receives ``Vision.onnx``. Created if missing.

    Raises:
        FileNotFoundError: If ``rotary_params.npz`` is not in ``output_dir``.
    """
    os.makedirs(output_dir, exist_ok=True)
    onnx_vision = os.path.join(output_dir, "Vision.onnx")

    rotary_path = os.path.join(output_dir, "rotary_params.npz")
    if not os.path.isfile(rotary_path):
        raise FileNotFoundError(
            f"{rotary_path} not found; run export_embedding_onnx.py first "
            "to produce rotary_params.npz"
        )

    # Read everything needed up front so the archive is closed again before
    # the (long-running) model load and export.
    with np.load(rotary_path) as rp:
        height_factor = int(rp["height_factor"])
        width_factor = int(rp["width_factor"])
        patch_size = int(rp["patch_size"])
        merge_size = int(rp["merge_size"])
        image_height = int(rp["image_height"])
        image_width = int(rp["image_width"])
        temporal_patch_size_npz = (
            int(rp["temporal_patch_size"]) if "temporal_patch_size" in rp else 2
        )

    grid_h = height_factor * merge_size
    grid_w = width_factor * merge_size
    temporal_patches = math.ceil(TEMPORAL_SIZE / temporal_patch_size_npz)
    total_patches = temporal_patches * grid_h * grid_w

    print(f"Grid: {temporal_patches}x{grid_h}x{grid_w} = {total_patches} patches")
    print(f"Image: {image_height}x{image_width} TEMPORAL_SIZE={TEMPORAL_SIZE}")

    print("Loading model …")
    model = Qwen3VLForEmbedding.from_pretrained(
        model_path,
        torch_dtype=torch.float32,
        device_map="cpu",
        low_cpu_mem_usage=True,
        _attn_implementation="eager",
    ).eval()

    visual = model.model.visual
    temporal_patch_size = visual.patch_embed.temporal_patch_size
    flatten_dim = 3 * temporal_patch_size * patch_size * patch_size

    deepstack_features_len = len(visual.deepstack_visual_indexes)
    print(f" deepstack features: {deepstack_features_len}")
    print(f" temporal_patch_size: {temporal_patch_size}")
    print(f" flatten_dim: {flatten_dim}")

    # The grid above was sized from the npz value while the input width is
    # sized from the model; flag it if the two sources disagree.
    if temporal_patch_size != temporal_patch_size_npz:
        print(
            f" warning: temporal_patch_size differs between model "
            f"({temporal_patch_size}) and rotary_params.npz "
            f"({temporal_patch_size_npz})"
        )

    grid_thw = torch.tensor([[temporal_patches, grid_h, grid_w]], dtype=torch.int64)
    wrapper = VisionEncoderWrapper(visual, grid_thw)

    # Random input is enough: it is only used to trace the graph.
    dummy_pixels = torch.randn(total_patches, flatten_dim, dtype=torch.float32)

    # Same order as VisionEncoderWrapper.forward: deepstack features first,
    # merged hidden states last.
    output_names = [
        f"deepstack_feature_{i}" for i in range(deepstack_features_len)
    ] + ["vision_hidden_states"]

    print("Exporting Vision (direct HF) …")
    with torch.inference_mode():
        torch.onnx.export(
            wrapper,
            (dummy_pixels,),
            onnx_vision,
            input_names=["pixel_values"],
            output_names=output_names,
            opset_version=OPSET,
            dynamo=False,
        )

    # Drop every reference to the PyTorch weights (including `visual`, which
    # would otherwise keep the vision tower alive) before the ONNX model is
    # loaded again for consolidation.
    del wrapper, visual, model, dummy_pixels
    gc.collect()
    print(f" → {onnx_vision}")
    print("Consolidating external weights …")
    consolidate_external_data(onnx_vision)
    print("Done.")
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: both options are optional and fall back to
    # the module-level defaults.
    cli = argparse.ArgumentParser(
        description="Export Qwen3-VL vision encoder to ONNX (direct HF)"
    )
    cli.add_argument(
        "--model_path",
        default=DEFAULT_MODEL_PATH,
        type=str,
        help=f"Path to the HF model (default: {DEFAULT_MODEL_PATH})",
    )
    cli.add_argument(
        "--output_dir",
        default=DEFAULT_OUTPUT_DIR,
        type=str,
        help=f"Directory to save ONNX files (default: {DEFAULT_OUTPUT_DIR})",
    )
    cli_args = cli.parse_args()

    main(model_path=cli_args.model_path, output_dir=cli_args.output_dir)
|
|