from __future__ import annotations

import argparse
import json
import tempfile
from pathlib import Path
from typing import Dict, Tuple

import torch
from optimum.exporters.onnx import main_export
from safetensors.torch import load_file
from torch import nn
from transformers import LlamaConfig, LlamaForCausalLM, SiglipVisionConfig, SiglipVisionModel
from transformers.cache_utils import DynamicCache


class ProjectionHead(nn.Module):
    """Simple projection that maps SigLIP hidden states into the LM embedding space."""

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.projection = nn.Linear(in_dim, out_dim)

    def forward(self, vision_hidden_states: torch.Tensor) -> torch.Tensor:
        return self.projection(vision_hidden_states)


def normalize_lm_key(raw_key: str) -> str:
    """
    Strip PEFT/wrapper prefixes from a language-model weight key so it matches
    the naming that ``LlamaForCausalLM.load_state_dict`` expects.
    """
    key = raw_key.replace("base_model.model.model.", "model.")
    key = key.replace("base_model.model.", "model.")
    key = key.replace("language_model.", "")
    if key.startswith("model.lm_head."):
        key = key.replace("model.lm_head.", "lm_head.")
    return key

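# Illustrative expectations for the normalizer (the exact incoming keys depend
# on how the PEFT checkpoint was saved; these two shapes are typical):
assert normalize_lm_key("base_model.model.model.layers.0.self_attn.q_proj.weight") == (
    "model.layers.0.self_attn.q_proj.weight"
)
assert normalize_lm_key("model.lm_head.weight") == "lm_head.weight"

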
def split_state_dict(full_state: Dict[str, torch.Tensor]) -> Tuple[Dict, Dict, Dict]:
    vision_sd = {k[len("vision_encoder.") :]: v for k, v in full_state.items() if k.startswith("vision_encoder.")}
    projection_sd = {k[len("projection.") :]: v for k, v in full_state.items() if k.startswith("projection.")}
    lm_sd = {k[len("language_model.") :]: v for k, v in full_state.items() if k.startswith("language_model.")}
    return vision_sd, projection_sd, lm_sd

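# The combined model.safetensors is flat; split_state_dict relies on these
# three key families (prefixes are stripped before loading):
#   vision_encoder.*  -> SiglipVisionModel weights
#   projection.*      -> ProjectionHead weights
#   language_model.*  -> Llama weights, possibly carrying PEFT/LoRA decorations

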
def merge_lora(lm_state: Dict[str, torch.Tensor], lora_alpha: int, lora_r: int) -> Dict[str, torch.Tensor]:
    """
    Merge LoRA weights into the base LM weights: W' = W + (alpha / r) * (B @ A).
    """
    merged: Dict[str, torch.Tensor] = {}

    # First pass: copy every non-LoRA tensor under its normalized key.
    for key, tensor in lm_state.items():
        if "lora_" in key:
            continue
        clean_key = normalize_lm_key(key.replace(".base_layer", ""))
        merged[clean_key] = tensor.clone()

    # Second pass: fold each (A, B) adapter pair into its base weight.
    scale = lora_alpha / float(lora_r)
    for a_key, a_tensor in lm_state.items():
        if not a_key.endswith("lora_A.default.weight"):
            continue
        b_key = a_key.replace("lora_A.default.weight", "lora_B.default.weight")
        base_key = normalize_lm_key(a_key.replace(".lora_A.default.weight", ".weight").replace(".base_layer", ""))

        if base_key not in merged:
            raise KeyError(f"Base weight missing for LoRA merge: {base_key}")

        b_tensor = lm_state[b_key]
        merged[base_key] += (b_tensor @ a_tensor) * scale

    return merged

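# A minimal, self-contained check of the merge math above (hypothetical key
# names in the PEFT style this script expects; not part of the export flow):
# for a Linear layer, the merged weight should equal W + (alpha / r) * (B @ A).
def _check_lora_merge_math() -> None:
    torch.manual_seed(0)
    out_dim, in_dim, rank, alpha = 8, 16, 4, 8
    base = torch.randn(out_dim, in_dim)
    lora_a = torch.randn(rank, in_dim)
    lora_b = torch.randn(out_dim, rank)
    state = {
        "model.layers.0.mlp.up_proj.base_layer.weight": base,
        "model.layers.0.mlp.up_proj.lora_A.default.weight": lora_a,
        "model.layers.0.mlp.up_proj.lora_B.default.weight": lora_b,
    }
    merged = merge_lora(state, lora_alpha=alpha, lora_r=rank)
    expected = base + (alpha / rank) * (lora_b @ lora_a)
    assert torch.allclose(merged["model.layers.0.mlp.up_proj.weight"], expected)

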
def build_models(model_dir: Path) -> Tuple[SiglipVisionModel, ProjectionHead, LlamaForCausalLM]:
    with open(model_dir / "config.json", "r") as f:
        config = json.load(f)

    full_state = load_file(model_dir / "model.safetensors", device="cpu")
    vision_sd, projection_sd, lm_sd = split_state_dict(full_state)

    vision_config = SiglipVisionConfig(**config["vision_config"])
    text_config = LlamaConfig(**config["text_config"])

    vision_encoder = SiglipVisionModel(vision_config)
    vision_encoder.load_state_dict(vision_sd, strict=True)
    vision_encoder.eval()

    projection = ProjectionHead(
        in_dim=vision_config.hidden_size,
        out_dim=projection_sd["projection.weight"].shape[0],
    )
    projection.load_state_dict(projection_sd, strict=True)
    projection.eval()

    merged_lm_state = merge_lora(lm_sd, lora_alpha=config["lora_alpha"], lora_r=config["lora_r"])
    language_model = LlamaForCausalLM(text_config)
    language_model.load_state_dict(merged_lm_state, strict=True)
    # Eager attention keeps the traced graph free of SDPA/FlashAttention ops
    # that the ONNX exporters handle poorly.
    language_model.config._attn_implementation = "eager"
    language_model.eval()

    return vision_encoder, projection, language_model

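# For reference, config.json must provide at least the keys read above
# (a minimal, illustrative skeleton; real checkpoints carry more fields):
#   {
#     "vision_config": {...},   # kwargs for SiglipVisionConfig
#     "text_config": {...},     # kwargs for LlamaConfig
#     "lora_alpha": 16,         # LoRA scaling numerator (value illustrative)
#     "lora_r": 8               # LoRA rank (value illustrative)
#   }

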
class PrefixDecoderInit(nn.Module):
    """
    ONNX-friendly wrapper for the first pass: runs the projected prefix plus the
    text tokens through the LM and returns logits and the KV cache.
    """

    def __init__(self, language_model: LlamaForCausalLM):
        super().__init__()
        self.language_model = language_model

    def forward(self, prefix_embeddings: torch.Tensor, input_ids: torch.Tensor):
        text_embeds = self.language_model.model.embed_tokens(input_ids)
        inputs_embeds = torch.cat([prefix_embeddings, text_embeds], dim=1)

        outputs = self.language_model(
            inputs_embeds=inputs_embeds,
            use_cache=True,
        )
        # Flatten the cache to (key_0, value_0, key_1, value_1, ...) so every
        # tensor becomes a named ONNX output.
        pkv = outputs.past_key_values.to_legacy_cache()
        return (outputs.logits, *sum(pkv, ()))


class PrefixDecoderWithPast(nn.Module):
    """
    ONNX-friendly wrapper for decoding steps: consumes past key/values and the
    newly sampled text tokens.
    """

    def __init__(self, language_model: LlamaForCausalLM):
        super().__init__()
        self.language_model = language_model

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        *past_key_values: torch.Tensor,
    ):
        text_embeds = self.language_model.model.embed_tokens(input_ids)
        # Re-pair the flat (key_0, value_0, ...) inputs into the legacy layout
        # before rebuilding the cache object.
        legacy = tuple((past_key_values[i], past_key_values[i + 1]) for i in range(0, len(past_key_values), 2))
        past = DynamicCache.from_legacy_cache(legacy)
        outputs = self.language_model(
            inputs_embeds=text_embeds,
            position_ids=position_ids,
            past_key_values=past,
            use_cache=True,
        )
        pkv = outputs.past_key_values.to_legacy_cache()
        return (outputs.logits, *sum(pkv, ()))

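# A minimal eager-mode smoke test for the two wrappers above (illustrative
# only: the prefix, token ids, and lengths are made up). It mirrors the loop
# an ONNX runtime would drive: one init pass, then one step with the cache.
def _smoke_test_prefix_decoders(language_model: LlamaForCausalLM) -> None:
    init = PrefixDecoderInit(language_model)
    step = PrefixDecoderWithPast(language_model)
    hidden_size = language_model.config.hidden_size
    prefix = torch.zeros(1, 3, hidden_size)
    prompt_ids = torch.tensor([[1, 2]], dtype=torch.long)
    with torch.no_grad():
        logits, *cache = init(prefix, prompt_ids)
        next_id = logits[:, -1:].argmax(dim=-1)
        # The next position follows the 3 prefix + 2 prompt tokens already seen.
        position_ids = torch.tensor([[5]], dtype=torch.long)
        step_logits, *cache = step(next_id, position_ids, *cache)
    assert step_logits.shape[:2] == (1, 1)

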
def export_submodels(
    vision_encoder: SiglipVisionModel,
    projection: ProjectionHead,
    language_model: LlamaForCausalLM,
    output_dir: Path,
    opset: int,
) -> None:
    output_dir.mkdir(parents=True, exist_ok=True)

    # optimum's main_export works from on-disk checkpoints, so save the two
    # transformers models to a temporary directory first.
    with tempfile.TemporaryDirectory() as tmp:
        tmp_path = Path(tmp)
        vision_path = tmp_path / "vision_encoder"
        lm_path = tmp_path / "language_model"
        vision_encoder.save_pretrained(vision_path)
        language_model.save_pretrained(lm_path)

        main_export(
            model_name_or_path=str(vision_path),
            output=str(output_dir / "vision_encoder"),
            task="feature-extraction",
            opset=opset,
            device="cpu",
            trust_remote_code=True,
        )

        main_export(
            model_name_or_path=str(lm_path),
            output=str(output_dir / "language_model"),
            task="text-generation-with-past",
            opset=opset,
            device="cpu",
            trust_remote_code=True,
        )

    # SigLIP has no class token, so the hidden-state sequence length is just
    # the patch count; the axis is dynamic, so the dummy length is not binding.
    seq_len = (vision_encoder.config.image_size // vision_encoder.config.patch_size) ** 2
    sample = torch.zeros(1, seq_len, vision_encoder.config.hidden_size, dtype=torch.float32)
    torch.onnx.export(
        projection,
        sample,
        output_dir / "projection.onnx",
        input_names=["vision_hidden_states"],
        output_names=["projected_prefix"],
        opset_version=opset,
        dynamic_axes={
            "vision_hidden_states": {0: "batch", 1: "sequence"},
            "projected_prefix": {0: "batch", 1: "sequence"},
        },
    )

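# Optional sanity check for the standalone projection graph. Assumes the
# `onnxruntime` package is installed; it is not required by the export itself.
def _check_projection_onnx(output_dir: Path, vision_hidden_size: int) -> None:
    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession(str(output_dir / "projection.onnx"))
    hidden = np.zeros((1, 5, vision_hidden_size), dtype=np.float32)
    (projected,) = session.run(None, {"vision_hidden_states": hidden})
    print("projected prefix shape:", projected.shape)

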
def export_full_graph(
    vision_encoder: SiglipVisionModel,
    projection: ProjectionHead,
    language_model: LlamaForCausalLM,
    work_dir: Path,
    opset: int,
) -> None:
    """
    Export the combined first-pass graph (projected prefix + text tokens in,
    logits + KV cache out). A matching decode-step graph is sketched below.
    """
    full_dir = work_dir
    full_dir.mkdir(parents=True, exist_ok=True)

    decoder_init = PrefixDecoderInit(language_model)

    batch = 1
    text_len = 4
    # SigLIP prefix length is the patch count (no class token); like the text
    # length, it is exported as a dynamic axis, so the dummy value is arbitrary.
    prefix_len = (vision_encoder.config.image_size // vision_encoder.config.patch_size) ** 2

    prefix_dummy = torch.zeros(batch, prefix_len, language_model.config.hidden_size, dtype=torch.float32)
    input_ids = torch.zeros(batch, text_len, dtype=torch.long)
    present_names = [
        name
        for i in range(language_model.config.num_hidden_layers)
        for name in (f"present.{i}.key", f"present.{i}.value")
    ]
    init_dynamic_shapes = {
        "prefix_embeddings": {0: "batch", 1: "prefix_sequence"},
        "input_ids": {0: "batch", 1: "text_sequence"},
    }

    torch.onnx.export(
        decoder_init,
        (prefix_dummy, input_ids),
        full_dir / "decoder_prefix_init.onnx",
        input_names=["prefix_embeddings", "input_ids"],
        output_names=["logits", *present_names],
        opset_version=opset,
        dynamo=True,
        dynamic_shapes=init_dynamic_shapes,
    )

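# The matching decode-step graph (text token + cache in, logits + updated
# cache out) is not exported above. A sketch follows, using the legacy
# TorchScript-based exporter with `dynamic_axes`; it assumes the usual Llama
# cache layout of (batch, num_kv_heads, sequence, head_dim) and may need
# adjustment for a given transformers version.
def export_decoder_with_past(language_model: LlamaForCausalLM, work_dir: Path, opset: int) -> None:
    config = language_model.config
    num_layers = config.num_hidden_layers
    num_kv_heads = config.num_key_value_heads
    head_dim = config.hidden_size // config.num_attention_heads

    decoder_step = PrefixDecoderWithPast(language_model)
    input_ids = torch.zeros(1, 1, dtype=torch.long)
    position_ids = torch.tensor([[8]], dtype=torch.long)  # dummy: 8 tokens already cached
    past = [torch.zeros(1, num_kv_heads, 8, head_dim) for _ in range(2 * num_layers)]

    past_names = [n for i in range(num_layers) for n in (f"past.{i}.key", f"past.{i}.value")]
    present_names = [n for i in range(num_layers) for n in (f"present.{i}.key", f"present.{i}.value")]
    dynamic_axes: Dict[str, Dict[int, str]] = {"input_ids": {0: "batch", 1: "text_sequence"}}
    dynamic_axes.update({n: {0: "batch", 2: "past_sequence"} for n in past_names})

    torch.onnx.export(
        decoder_step,
        (input_ids, position_ids, *past),
        str(work_dir / "decoder_prefix_with_past.onnx"),
        input_names=["input_ids", "position_ids", *past_names],
        output_names=["logits", *present_names],
        opset_version=opset,
        dynamic_axes=dynamic_axes,
    )

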
def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Prepare and export DistilViT2 submodules to ONNX.")
    parser.add_argument("--model-dir", type=Path, default=Path("."), help="Folder containing config.json and model.safetensors.")
    parser.add_argument("--output-dir", type=Path, default=Path("onnx"), help="Where to store ONNX files.")
    parser.add_argument("--opset", type=int, default=17, help="ONNX opset to use.")
    parser.add_argument("--skip-onnx", action="store_true", help="Only build the PyTorch models from the checkpoint, skip ONNX export.")
    parser.add_argument(
        "--full-onnx",
        action="store_true",
        help="Also export the combined first-pass graph (decoder_prefix_init.onnx).",
    )
    return parser.parse_args()


def main() -> None:
    args = parse_args()
    vision_encoder, projection, language_model = build_models(args.model_dir)

    if args.skip_onnx:
        print("Models built successfully; skipping ONNX export (--skip-onnx).")
        return

    args.output_dir.mkdir(parents=True, exist_ok=True)

    export_submodels(vision_encoder, projection, language_model, args.output_dir, args.opset)
    if args.full_onnx:
        export_full_graph(vision_encoder, projection, language_model, args.output_dir, args.opset)
    print(f"ONNX models written to {args.output_dir}.")


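# Example invocation (script name and paths are illustrative):
#   python export_onnx.py --model-dir ./checkpoint --output-dir ./onnx --full-onnx

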
if __name__ == "__main__":
    main()