File size: 3,597 Bytes
993d81c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 | #!/usr/bin/env python3
"""Export RTMPose-M 256x192 ONNX from official OpenMMLab pretrained model.
Downloads the official pre-exported ONNX from the OpenMMLab model zoo,
converts the opset if needed, and fixes the batch dimension to a static
value (1 by default; see --batch).
Model: RTMPose-M (13.58M params)
Input: 1x3x256x192 (RGB, float32)
Output: simcc_x (1,17,384), simcc_y (1,17,512)
"""
import argparse
import io
import os
import shutil
import zipfile

import numpy as np
import onnx
import onnx.version_converter
import requests
# Official OpenMMLab SDK bundle (a zip archive) that ships the
# pre-exported RTMPose-M 256x192 ONNX model.
ONNX_URL = (
    "https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/"
    "onnx_sdk/rtmpose-m_simcc-body7_pt-body7_420e-256x192-"
    "e48f03d0_20230504.zip"
)

# Downloads are cached in a hidden directory next to this script so that
# repeated runs do not hit the network again.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
CACHE_DIR = os.path.join(_SCRIPT_DIR, ".model_cache")
def download_onnx() -> str:
    """Download the official RTMPose-M ONNX model and cache it locally.

    The bundle at ``ONNX_URL`` is a zip archive; the first member whose name
    ends in ``.onnx`` is extracted to ``CACHE_DIR/rtmpose_m_official.onnx``.
    If that file already exists it is reused and nothing is downloaded.

    Returns:
        Path to the cached ONNX file.

    Raises:
        requests.HTTPError: If the HTTP request returns an error status.
        zipfile.BadZipFile: If the response body is not a zip archive.
        RuntimeError: If the archive contains no ``.onnx`` file.
    """
    cache_path = os.path.join(CACHE_DIR, "rtmpose_m_official.onnx")
    if os.path.exists(cache_path):
        print(f"Using cached ONNX: {cache_path}")
        return cache_path

    os.makedirs(CACHE_DIR, exist_ok=True)
    print("Downloading official RTMPose-M ONNX from OpenMMLab...")
    resp = requests.get(ONNX_URL, timeout=120)
    resp.raise_for_status()

    with zipfile.ZipFile(io.BytesIO(resp.content)) as zf:
        member = next((n for n in zf.namelist() if n.endswith(".onnx")), None)
        if member is None:
            raise RuntimeError("No .onnx found in downloaded zip")

        # Extract to a temporary sibling and rename only after the copy has
        # finished. Writing straight to cache_path would let an interrupted
        # run leave a truncated file that the existence check above would
        # then treat as a valid cache hit on every later run.
        tmp_path = cache_path + ".part"
        try:
            with zf.open(member) as src, open(tmp_path, "wb") as dst:
                shutil.copyfileobj(src, dst)  # stream instead of read()-ing it all
            os.replace(tmp_path, cache_path)  # atomic on the same filesystem
        except BaseException:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

    print(f"Cached: {cache_path}")
    return cache_path
def convert_opset(model: onnx.ModelProto, target_opset: int) -> onnx.ModelProto:
    """Convert *model* to ONNX opset *target_opset* if it is not already there.

    Only the default ONNX operator-set domain (``""`` or ``"ai.onnx"``) is
    inspected. ``opset_import`` may also list custom domains, in any order,
    so blindly reading ``opset_import[0]`` can pick up the wrong version.

    Args:
        model: Loaded ONNX model.
        target_opset: Desired opset version for the default domain.

    Returns:
        *model* itself when it is already at *target_opset*; otherwise the
        result of ``onnx.version_converter.convert_version``.

    Raises:
        ValueError: If the model declares no default-domain opset.
    """
    current_opset = next(
        (imp.version for imp in model.opset_import if imp.domain in ("", "ai.onnx")),
        None,
    )
    if current_opset is None:
        raise ValueError("Model has no default-domain ('ai.onnx') opset_import entry")
    if current_opset == target_opset:
        return model
    print(f"Converting opset {current_opset} -> {target_opset}")
    return onnx.version_converter.convert_version(model, target_opset)
def fix_batch_dim(model: onnx.ModelProto, batch: int = 1) -> None:
    """Pin the leading (batch) dimension of graph inputs/outputs, in place.

    Every graph input and output whose first dimension is not already a
    concrete ``dim_value`` (i.e. it is symbolic via ``dim_param`` or left
    unset) gets ``dim_value = batch``. The following are left untouched:
    dimensions that are already static, values that are not tensor-typed,
    and tensors that declare no dimensions at all.

    Args:
        model: ONNX model; modified in place.
        batch: Static batch size to write. Must be >= 1.

    Raises:
        ValueError: If *batch* is smaller than 1.
    """
    if batch < 1:
        raise ValueError(f"batch must be >= 1, got {batch}")
    for tensor in [*model.graph.input, *model.graph.output]:
        if not tensor.type.HasField("tensor_type"):
            continue  # sequence/map/optional values have no tensor shape to pin
        dims = tensor.type.tensor_type.shape.dim
        if not dims:
            continue  # scalar or unknown rank: there is no batch axis
        dim0 = dims[0]
        if dim0.WhichOneof("value") != "dim_value":
            # dim_value and dim_param share a protobuf oneof, so assigning
            # dim_value also clears any symbolic dim_param.
            dim0.dim_value = batch
def _shape_of(value_info) -> list:
    """Return the declared shape of *value_info*, naming symbolic dims by their dim_param."""
    return [d.dim_param or d.dim_value for d in value_info.type.tensor_type.shape.dim]


def print_model_info(model: "onnx.ModelProto") -> None:
    """Print the model's parameter count and its graph input/output shapes.

    The parameter count is the total number of elements across all graph
    initializers. Graph inputs that share a name with an initializer are
    skipped, because older IR versions list weights in ``graph.input`` too
    and they are not real model inputs. Symbolic dimensions are shown by
    name instead of the placeholder value ``0``.
    """
    initializers = model.graph.initializer
    total_params = sum(int(np.prod(init.dims)) for init in initializers)
    print(f"Parameters: {total_params / 1e6:.2f}M")

    weight_names = {init.name for init in initializers}
    for inp in model.graph.input:
        if inp.name not in weight_names:
            print(f" Input: {inp.name} {_shape_of(inp)}")
    for out in model.graph.output:
        print(f" Output: {out.name} {_shape_of(out)}")
def main():
    """Command-line entry point: download, convert opset, pin batch, save, report."""
    ap = argparse.ArgumentParser(description="Export RTMPose-M 256x192 ONNX")
    ap.add_argument("--opset", type=int, default=13, help="Target ONNX opset version")
    ap.add_argument("--output", default="rtmpose_m_256x192.onnx", help="Output path")
    ap.add_argument("--batch", type=int, default=1, help="Static batch size")
    args = ap.parse_args()
    # Reject a bad batch size before spending time on the download.
    if args.batch < 1:
        ap.error(f"--batch must be >= 1, got {args.batch}")

    source_path = download_onnx()
    model = onnx.load(source_path)
    model = convert_opset(model, args.opset)
    fix_batch_dim(model, args.batch)

    # onnx.save does not create missing parent directories itself.
    os.makedirs(os.path.dirname(os.path.abspath(args.output)), exist_ok=True)
    onnx.save(model, args.output)
    print(f"\nExported: {args.output} ({os.path.getsize(args.output) / 1e6:.2f} MB)")
    print_model_info(model)


if __name__ == "__main__":
    main()
|