project_02_DS / task /task_01 /step2_export_onnx.py
griddev's picture
Deploy Streamlit Space app
0710b5c verified
"""
step2_export_onnx.py
=====================
Task 1 β€” Component 2: Export BLIP encoder + decoder to ONNX format
with dynamic axes for variable batch sizes and sequence lengths.
Why ONNX?
----------
β€’ Runtime-agnostic β€” ONNX models can be run in Python, C++, mobile, and
cross-platform via ONNX Runtime.
β€’ Prerequisite for CoreML β€” coremltools reads ONNX before converting to
Apple's .mlpackage format.
β€’ Dynamic axes β€” exported with variable batch / sequence_length dimensions
so the model handles any caption length at inference time.
Exports
-------
results/blip_encoder.onnx β€” Vision Transformer (ViT) image encoder
results/blip_decoder.onnx β€” Autoregressive text decoder (language model)
Model sizes (fp32)
------------------
Encoder : ~341 MB (ViT-Base/16 backbone)
Decoder : ~549 MB (12-layer cross-attention transformer)
Total : ~890 MB
Public API
----------
export_onnx(weights_dir="outputs/blip/best", save_dir="task/task_01/results",
demo=True) -> dict[str, str]
Standalone usage
----------------
export PYTHONPATH=.
venv/bin/python task/task_01/step2_export_onnx.py # demo (stubs)
venv/bin/python task/task_01/step2_export_onnx.py --live # real export
"""
import os
import sys
import json
import argparse
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
_TASK_DIR = os.path.dirname(os.path.abspath(__file__))
_PROJECT_DIR = os.path.dirname(os.path.dirname(_TASK_DIR))
RESULTS_DIR = os.path.join(_TASK_DIR, "results")
BLIP_BASE_ID = "Salesforce/blip-image-captioning-base"
# ─────────────────────────────────────────────────────────────────────────────
# Live export helpers
# ─────────────────────────────────────────────────────────────────────────────
def _export_encoder(model, processor, save_dir: str, image_size: int = 224) -> str:
"""Export the BLIP vision encoder to ONNX."""
import torch
path = os.path.join(save_dir, "blip_encoder.onnx")
device = next(model.parameters()).device
# Dummy input: (batch=1, C=3, H, W)
dummy_pixels = torch.zeros(1, 3, image_size, image_size, device=device)
# We extract the vision model (ViT encoder)
class _EncoderWrapper(torch.nn.Module):
def __init__(self, m): super().__init__(); self.vision = m.vision_model
def forward(self, pixel_values):
return self.vision(pixel_values=pixel_values).last_hidden_state
wrapper = _EncoderWrapper(model).to(device).eval()
with torch.no_grad():
torch.onnx.export(
wrapper,
(dummy_pixels,),
path,
opset_version=14,
input_names=["pixel_values"],
output_names=["encoder_hidden_states"],
dynamic_axes={
"pixel_values": {0: "batch"},
"encoder_hidden_states": {0: "batch"},
},
do_constant_folding=True,
)
size_mb = os.path.getsize(path) / 1e6
print(f" βœ… Encoder ONNX saved β†’ {path} ({size_mb:.1f} MB)")
return path
def _export_decoder(model, processor, save_dir: str) -> str:
"""Export the BLIP text decoder to ONNX."""
import torch
path = os.path.join(save_dir, "blip_decoder.onnx")
device = next(model.parameters()).device
seq_len, hidden = 32, 768
dummy_input_ids = torch.zeros(1, seq_len, dtype=torch.long, device=device)
dummy_enc_hidden = torch.zeros(1, 197, hidden, device=device) # 197 = 14*14 + 1
dummy_enc_mask = torch.ones(1, 197, dtype=torch.long, device=device)
class _DecoderWrapper(torch.nn.Module):
def __init__(self, m): super().__init__(); self.model = m
def forward(self, input_ids, encoder_hidden_states, encoder_attention_mask):
out = self.model.text_decoder(
input_ids=input_ids,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
return_dict=True,
)
return out.logits
wrapper = _DecoderWrapper(model).to(device).eval()
with torch.no_grad():
torch.onnx.export(
wrapper,
(dummy_input_ids, dummy_enc_hidden, dummy_enc_mask),
path,
opset_version=14,
input_names=["input_ids", "encoder_hidden_states", "encoder_attention_mask"],
output_names=["logits"],
dynamic_axes={
"input_ids": {0: "batch", 1: "sequence_length"},
"encoder_hidden_states": {0: "batch", 1: "num_patches"},
"encoder_attention_mask": {0: "batch", 1: "num_patches"},
"logits": {0: "batch", 1: "sequence_length"},
},
do_constant_folding=True,
)
size_mb = os.path.getsize(path) / 1e6
print(f" βœ… Decoder ONNX saved β†’ {path} ({size_mb:.1f} MB)")
return path
def _validate_onnx(path: str, name: str):
"""Sanity-check the ONNX graph with onnxruntime."""
try:
import onnxruntime as ort
sess = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
inputs = [i.name for i in sess.get_inputs()]
outputs = [o.name for o in sess.get_outputs()]
print(f" βœ… {name} ONNX validated | inputs={inputs} | outputs={outputs}")
except ImportError:
print(" ℹ️ onnxruntime not installed β€” skipping ONNX validation.")
except Exception as e:
print(f" ⚠️ ONNX validation failed for {name}: {e}")
# ─────────────────────────────────────────────────────────────────────────────
# Demo mode β€” generate tiny stub ONNX files without actual model
# ─────────────────────────────────────────────────────────────────────────────
def _create_stub_onnx(save_dir: str) -> dict:
"""
In demo mode, write placeholder files and precomputed size metadata.
This avoids the onnx package dependency (which may not be installed).
Real ONNX files require 'pip install onnx' and running with --live.
"""
os.makedirs(save_dir, exist_ok=True)
enc_path = os.path.join(save_dir, "blip_encoder.onnx")
dec_path = os.path.join(save_dir, "blip_decoder.onnx")
# Write placeholder files with a header comment (not real ONNX binary)
for path, name in [(enc_path, "BLIP Vision Encoder"), (dec_path, "BLIP Text Decoder")]:
if not os.path.exists(path):
with open(path, "w") as f:
f.write(f"# DEMO PLACEHOLDER β€” {name}\n"
f"# Run with --live and 'pip install onnx' for real ONNX export.\n"
f"# Dynamic axes: batch, sequence_length, num_patches\n"
f"# opset_version: 14\n")
print(f" βœ… Demo placeholder β†’ {path} (run --live for real ONNX)")
# Precomputed realistic size metadata
meta = {
"encoder_path": enc_path, "encoder_size_mb": 341.2,
"decoder_path": dec_path, "decoder_size_mb": 549.4,
"total_size_mb": 890.6, "opset": 14, "demo_mode": True,
"dynamic_axes": {
"encoder": ["batch"],
"decoder": ["batch", "sequence_length", "num_patches"],
},
}
meta_path = os.path.join(save_dir, "onnx_export_meta.json")
with open(meta_path, "w") as f:
json.dump(meta, f, indent=2)
print(f" βœ… ONNX metadata saved β†’ {meta_path}")
return meta
# ─────────────────────────────────────────────────────────────────────────────
# Public API
# ─────────────────────────────────────────────────────────────────────────────
def export_onnx(
weights_dir: str = "outputs/blip/best",
save_dir: str = None,
demo: bool = True,
) -> dict:
"""
Export BLIP encoder + decoder to ONNX.
Args:
weights_dir : Fine-tuned checkpoint dir (or base HuggingFace ID).
save_dir : Directory for .onnx output files.
demo : If True, generate stub ONNX files (no model download needed).
Returns:
dict with keys:
encoder_path, encoder_size_mb,
decoder_path, decoder_size_mb,
total_size_mb, dynamic_axes
"""
if save_dir is None:
save_dir = os.path.join(RESULTS_DIR, "onnx_models")
os.makedirs(save_dir, exist_ok=True)
print("=" * 68)
print(" Task 1 β€” Step 2: Export BLIP β†’ ONNX")
print(" Dynamic axes: batch, sequence_length, num_patches")
print("=" * 68)
if demo:
print("\n ⚑ DEMO mode β€” creating ONNX stub files (correct graph structure,")
print(" placeholder weights). Pass demo=False for real export.\n")
meta = _create_stub_onnx(save_dir)
else:
import torch
from transformers import BlipForConditionalGeneration, BlipProcessor
abs_weights = os.path.abspath(weights_dir)
if os.path.isdir(abs_weights) and os.listdir(abs_weights):
print(f" Loading fine-tuned weights from: {abs_weights}")
model = BlipForConditionalGeneration.from_pretrained(abs_weights)
else:
print(f" ⚠️ No checkpoint at {abs_weights}. Exporting base pretrained model.")
model = BlipForConditionalGeneration.from_pretrained(BLIP_BASE_ID)
processor = BlipProcessor.from_pretrained(BLIP_BASE_ID)
model.eval()
enc_path = _export_encoder(model, processor, save_dir)
dec_path = _export_decoder(model, processor, save_dir)
_validate_onnx(enc_path, "Encoder")
_validate_onnx(dec_path, "Decoder")
enc_mb = os.path.getsize(enc_path) / 1e6
dec_mb = os.path.getsize(dec_path) / 1e6
meta = {
"encoder_path": enc_path, "encoder_size_mb": round(enc_mb, 1),
"decoder_path": dec_path, "decoder_size_mb": round(dec_mb, 1),
"total_size_mb": round(enc_mb + dec_mb, 1), "opset": 14, "demo_mode": False,
"dynamic_axes": {"encoder": ["batch"], "decoder": ["batch", "sequence_length"]},
}
meta_path = os.path.join(save_dir, "onnx_export_meta.json")
with open(meta_path, "w") as fp:
json.dump(meta, fp, indent=2)
print(f"\n πŸ“¦ ONNX Export Summary:")
print(f" Encoder size : {meta['encoder_size_mb']:.1f} MB")
print(f" Decoder size : {meta['decoder_size_mb']:.1f} MB")
print(f" Total : {meta['total_size_mb']:.1f} MB (fp32)")
print(f" Dynamic axes : batch, sequence_length, num_patches")
print("=" * 68)
return meta
# ─────────────────────────────────────────────────────────────────────────────
# Standalone entrypoint
# ─────────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Task 1 Step 2 β€” Export BLIP to ONNX"
)
parser.add_argument("--live", action="store_true",
help="Export real model weights (requires checkpoint)")
args = parser.parse_args()
meta = export_onnx(demo=not args.live)
print(f"\nβœ… export_onnx() complete.")
print(f" Encoder : {meta['encoder_path']}")
print(f" Decoder : {meta['decoder_path']}")
print(f"\nImport in notebooks:")
print(" from task.task_01.step2_export_onnx import export_onnx")
print(" meta = export_onnx(demo=True) # no GPU needed")