# export_prefix_vlm.py
# Prepare and export DistilViT2 submodules (SigLIP vision encoder, projection
# head, LoRA-merged Llama LM) to ONNX.

from __future__ import annotations

import argparse
import json
import tempfile
from pathlib import Path
from typing import Dict, Tuple

import torch
from optimum.exporters.onnx import main_export
from safetensors.torch import load_file
from torch import nn
from transformers import LlamaConfig, LlamaForCausalLM, SiglipVisionConfig, SiglipVisionModel
from transformers.cache_utils import DynamicCache


class ProjectionHead(nn.Module):
"""
Simple projection used to map SigLIP hidden states into the LM embedding space.
"""

    def __init__(self, in_dim: int, out_dim: int):
super().__init__()
self.projection = nn.Linear(in_dim, out_dim)

    def forward(self, vision_hidden_states: torch.Tensor) -> torch.Tensor:
return self.projection(vision_hidden_states)


def normalize_lm_key(raw_key: str) -> str:
    """
    Map PEFT/LoRA checkpoint key names onto plain LlamaForCausalLM key names.
    """
key = raw_key.replace("base_model.model.model.", "model.")
key = key.replace("base_model.model.", "model.")
key = key.replace("language_model.", "")
if key.startswith("model.lm_head."):
key = key.replace("model.lm_head.", "lm_head.")
return key
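

# Illustration of the key normalization above (example keys are representative
# of a PEFT checkpoint, not read from a real file):
#   base_model.model.model.layers.0.self_attn.q_proj.weight
#       -> model.layers.0.self_attn.q_proj.weight
#   base_model.model.lm_head.weight -> lm_head.weight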


def split_state_dict(
    full_state: Dict[str, torch.Tensor],
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
    """
    Split the combined checkpoint into vision-encoder, projection, and LM state dicts.
    """
vision_sd = {k[len("vision_encoder.") :]: v for k, v in full_state.items() if k.startswith("vision_encoder.")}
projection_sd = {k[len("projection.") :]: v for k, v in full_state.items() if k.startswith("projection.")}
lm_sd = {k[len("language_model.") :]: v for k, v in full_state.items() if k.startswith("language_model.")}
return vision_sd, projection_sd, lm_sd


def merge_lora(lm_state: Dict[str, torch.Tensor], lora_alpha: int, lora_r: int) -> Dict[str, torch.Tensor]:
    """
    Merge LoRA adapters into the base LM weights: W' = W + (alpha / r) * (B @ A).
    """
    merged: Dict[str, torch.Tensor] = {}
    # First pass: copy every non-LoRA tensor under its normalized key.
    for key, tensor in lm_state.items():
        if "lora_" in key:
            continue
        clean_key = normalize_lm_key(key.replace(".base_layer", ""))
        merged[clean_key] = tensor.clone()
    scale = lora_alpha / float(lora_r)
    # Second pass: fold each (lora_A, lora_B) pair into its base weight.
    for a_key, a_tensor in lm_state.items():
if not a_key.endswith("lora_A.default.weight"):
continue
b_key = a_key.replace("lora_A.default.weight", "lora_B.default.weight")
base_key = normalize_lm_key(a_key.replace(".lora_A.default.weight", ".weight").replace(".base_layer", ""))
if base_key not in merged:
raise KeyError(f"Base weight missing for LoRA merge: {base_key}")
b_tensor = lm_state[b_key]
merged[base_key] += (b_tensor @ a_tensor) * scale
return merged
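

def _merge_rule_example() -> torch.Tensor:
    """
    Minimal sketch of the merge rule used by merge_lora, on toy tensors; the
    shapes and alpha/r values here are illustrative, not taken from a checkpoint.
    """
    r, alpha, in_dim, out_dim = 4, 8, 16, 16
    w = torch.randn(out_dim, in_dim)  # base weight W
    a = torch.randn(r, in_dim)        # lora_A
    b = torch.randn(out_dim, r)       # lora_B
    return w + (b @ a) * (alpha / float(r))  # same update applied to merged[base_key]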


def build_models(model_dir: Path) -> Tuple[SiglipVisionModel, ProjectionHead, LlamaForCausalLM]:
with open(model_dir / "config.json", "r") as f:
config = json.load(f)
full_state = load_file(model_dir / "model.safetensors", device="cpu")
vision_sd, projection_sd, lm_sd = split_state_dict(full_state)
vision_config = SiglipVisionConfig(**config["vision_config"])
text_config = LlamaConfig(**config["text_config"])
vision_encoder = SiglipVisionModel(vision_config)
vision_encoder.load_state_dict(vision_sd, strict=True)
vision_encoder.eval()
projection = ProjectionHead(
in_dim=vision_config.hidden_size,
out_dim=projection_sd["projection.weight"].shape[0],
)
projection.load_state_dict(projection_sd, strict=True)
projection.eval()
merged_lm_state = merge_lora(lm_sd, lora_alpha=config["lora_alpha"], lora_r=config["lora_r"])
language_model = LlamaForCausalLM(text_config)
language_model.load_state_dict(merged_lm_state, strict=True)
# Force eager attention to simplify ONNX export paths.
language_model.config._attn_implementation = "eager"
language_model.eval()
return vision_encoder, projection, language_model


class PrefixDecoderInit(nn.Module):
"""
ONNX-friendly wrapper for the first pass: runs prefix embeddings + text tokens through LM and returns logits + cache.
"""

    def __init__(self, language_model: LlamaForCausalLM):
super().__init__()
self.language_model = language_model

    def forward(self, prefix_embeddings: torch.Tensor, input_ids: torch.Tensor):
text_embeds = self.language_model.model.embed_tokens(input_ids)
inputs_embeds = torch.cat([prefix_embeddings, text_embeds], dim=1)
outputs = self.language_model(
inputs_embeds=inputs_embeds,
use_cache=True,
)
        pkv = outputs.past_key_values.to_legacy_cache()
        # Flatten ((k0, v0), (k1, v1), ...) into (k0, v0, k1, v1, ...) so each
        # cache tensor becomes its own ONNX output.
        return (outputs.logits, *sum(pkv, ()))


class PrefixDecoderWithPast(nn.Module):
"""
ONNX-friendly wrapper for decoding steps: consumes past key values and new text tokens.
"""

    def __init__(self, language_model: LlamaForCausalLM):
super().__init__()
self.language_model = language_model

    def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
*past_key_values: torch.Tensor,
):
        text_embeds = self.language_model.model.embed_tokens(input_ids)
        # Re-pair the flat (k0, v0, k1, v1, ...) inputs into per-layer
        # (key, value) tuples before rebuilding the cache object.
        legacy = tuple((past_key_values[i], past_key_values[i + 1]) for i in range(0, len(past_key_values), 2))
        past = DynamicCache.from_legacy_cache(legacy)
outputs = self.language_model(
inputs_embeds=text_embeds,
position_ids=position_ids,
past_key_values=past,
use_cache=True,
)
pkv = outputs.past_key_values.to_legacy_cache()
return (outputs.logits, *sum(pkv, ()))
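

# Hedged sketch of how the two wrappers above are meant to be driven at
# inference time (pure PyTorch, mirroring the intended ONNX runtime loop).
# Greedy decoding, batch size 1, and `max_new_tokens` are illustrative choices,
# not part of the export itself.
@torch.no_grad()
def _greedy_decode_example(
    init: PrefixDecoderInit,
    step: PrefixDecoderWithPast,
    prefix_embeddings: torch.Tensor,
    input_ids: torch.Tensor,
    max_new_tokens: int = 8,
) -> torch.Tensor:
    logits, *past = init(prefix_embeddings, input_ids)
    tokens = [logits[:, -1:].argmax(-1)]
    seen = prefix_embeddings.shape[1] + input_ids.shape[1]
    for _ in range(max_new_tokens - 1):
        position_ids = torch.tensor([[seen]], dtype=torch.long)  # batch size 1 assumed
        logits, *past = step(tokens[-1], position_ids, *past)
        tokens.append(logits[:, -1:].argmax(-1))
        seen += 1
    return torch.cat(tokens, dim=1)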


def export_submodels(
vision_encoder: SiglipVisionModel,
projection: ProjectionHead,
language_model: LlamaForCausalLM,
output_dir: Path,
opset: int,
) -> None:
output_dir.mkdir(parents=True, exist_ok=True)
    # optimum's exporter consumes an on-disk checkpoint, so save each submodel
    # to a temporary directory first.
    with tempfile.TemporaryDirectory() as tmp:
        tmp_path = Path(tmp)
vision_path = tmp_path / "vision_encoder"
lm_path = tmp_path / "language_model"
vision_encoder.save_pretrained(vision_path)
language_model.save_pretrained(lm_path)
main_export(
model_name_or_path=str(vision_path),
output=str(output_dir / "vision_encoder"),
task="feature-extraction",
opset=opset,
device="cpu",
trust_remote_code=True,
)
main_export(
model_name_or_path=str(lm_path),
output=str(output_dir / "language_model"),
task="text-generation-with-past",
opset=opset,
device="cpu",
trust_remote_code=True,
)
    # The sequence axis is dynamic below, so the dummy length only needs to be
    # plausible (SigLIP uses no class token, so the real token count is
    # (image_size // patch_size) ** 2; the extra +1 is harmless here).
    seq_len = (vision_encoder.config.image_size // vision_encoder.config.patch_size) ** 2 + 1
    sample = torch.zeros(1, seq_len, vision_encoder.config.hidden_size, dtype=torch.float32)
torch.onnx.export(
projection,
sample,
output_dir / "projection.onnx",
input_names=["vision_hidden_states"],
output_names=["projected_prefix"],
opset_version=opset,
dynamic_axes={
"vision_hidden_states": {0: "batch", 1: "sequence"},
"projected_prefix": {0: "batch", 1: "sequence"},
},
)
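

# Optional smoke test for the exported projection graph. Hedged sketch: it
# assumes `onnxruntime` is installed, which the export itself does not need,
# so the import stays local to the function.
def _check_projection_onnx(onnx_path: Path, hidden_size: int, seq_len: int = 16) -> None:
    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
    dummy = np.zeros((1, seq_len, hidden_size), dtype=np.float32)
    (projected,) = session.run(None, {"vision_hidden_states": dummy})
    print(f"projection.onnx OK, output shape: {projected.shape}")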


def export_full_graph(
vision_encoder: SiglipVisionModel,
projection: ProjectionHead,
language_model: LlamaForCausalLM,
work_dir: Path,
opset: int,
) -> None:
"""
Export a single-step graph (vision+projection+LM) and a decode graph (text-only with past).
"""
full_dir = work_dir
full_dir.mkdir(parents=True, exist_ok=True)
decoder_init = PrefixDecoderInit(language_model)
batch = 1
text_len = 4
prefix_len = (vision_encoder.config.image_size // vision_encoder.config.patch_size) ** 2
prefix_dummy = torch.zeros(batch, prefix_len, language_model.config.hidden_size, dtype=torch.float32)
input_ids = torch.zeros(batch, text_len, dtype=torch.long)
present_names = [
name
for i in range(language_model.config.num_hidden_layers)
for name in (f"present.{i}.key", f"present.{i}.value")
]
init_dynamic_shapes = {
"prefix_embeddings": {0: "batch", 1: "prefix_sequence"},
"input_ids": {0: "batch", 1: "text_sequence"},
}
torch.onnx.export(
decoder_init,
(prefix_dummy, input_ids),
full_dir / "decoder_prefix_init.onnx",
input_names=["prefix_embeddings", "input_ids"],
output_names=["logits", *present_names],
opset_version=opset,
dynamo=True,
dynamic_shapes=init_dynamic_shapes,
)
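    # Hedged sketch of the matching decode-step export. It mirrors the init
    # export above but uses the legacy (non-dynamo) exporter, where flattened
    # `*past_key_values` varargs can be named through `input_names`; depending
    # on the transformers version, tracing through DynamicCache may require
    # optimum's patched exporters instead.
    decoder_step = PrefixDecoderWithPast(language_model)
    lm_config = language_model.config
    num_kv_heads = getattr(lm_config, "num_key_value_heads", lm_config.num_attention_heads)
    head_dim = getattr(lm_config, "head_dim", None) or lm_config.hidden_size // lm_config.num_attention_heads
    past_len = prefix_len + text_len
    past_dummy = tuple(
        torch.zeros(batch, num_kv_heads, past_len, head_dim, dtype=torch.float32)
        for _ in range(2 * lm_config.num_hidden_layers)
    )
    past_names = [
        name
        for i in range(lm_config.num_hidden_layers)
        for name in (f"past_key_values.{i}.key", f"past_key_values.{i}.value")
    ]
    step_dynamic_axes = {
        "input_ids": {0: "batch", 1: "text_sequence"},
        "position_ids": {0: "batch", 1: "text_sequence"},
        **{name: {0: "batch", 2: "past_sequence"} for name in past_names},
        "logits": {0: "batch", 1: "text_sequence"},
        **{name: {0: "batch", 2: "total_sequence"} for name in present_names},
    }
    step_ids = torch.zeros(batch, 1, dtype=torch.long)
    step_positions = torch.full((batch, 1), past_len, dtype=torch.long)
    torch.onnx.export(
        decoder_step,
        (step_ids, step_positions, *past_dummy),
        work_dir / "decoder_prefix_with_past.onnx",
        input_names=["input_ids", "position_ids", *past_names],
        output_names=["logits", *present_names],
        opset_version=opset,
        dynamic_axes=step_dynamic_axes,
    )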


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Prepare and export DistilViT2 submodules to ONNX.")
parser.add_argument("--model-dir", type=Path, default=Path("."), help="Folder containing config.json and model.safetensors.")
parser.add_argument("--output-dir", type=Path, default=Path("onnx"), help="Where to store ONNX files.")
parser.add_argument("--opset", type=int, default=17, help="ONNX opset to use.")
parser.add_argument("--skip-onnx", action="store_true", help="Only prepare HF checkpoints, skip ONNX export.")
parser.add_argument(
"--full-onnx",
action="store_true",
help="Also export a full graph (vision+projection+LM) and a decoder_with_past graph.",
)
return parser.parse_args()


def main() -> None:
    args = parse_args()
    vision_encoder, projection, language_model = build_models(args.model_dir)
    if args.skip_onnx:
        # Stop after loading and merging the checkpoint; no ONNX files are written.
        print("Models prepared; skipping ONNX export (--skip-onnx).")
        return
    args.output_dir.mkdir(parents=True, exist_ok=True)
    export_submodels(vision_encoder, projection, language_model, args.output_dir, args.opset)
    if args.full_onnx:
        export_full_graph(vision_encoder, projection, language_model, args.output_dir, args.opset)
        print(f"Prefix decoder graphs written to {args.output_dir}.")
    print(f"ONNX models written to {args.output_dir}.")
if __name__ == "__main__":
main()
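

# Example invocation (paths are illustrative):
#   python export_prefix_vlm.py --model-dir ./distilvit2 --output-dir ./onnx --full-onnx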