#!/usr/bin/env python3
"""Export RTMPose-M 256x192 ONNX from the official OpenMMLab pretrained model.

Downloads the official pre-exported ONNX from the OpenMMLab model zoo, converts
the opset if needed, and fixes the batch dimension to a static size
(default 1, configurable with --batch).

Model:  RTMPose-M (13.58M params)
Input:  1x3x256x192 (RGB, float32)
Output: simcc_x (1,17,384), simcc_y (1,17,512)
"""

import argparse
import io
import os
import shutil
import zipfile

import numpy as np
import onnx
import onnx.version_converter
import requests

ONNX_URL = (
    "https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/"
    "onnx_sdk/rtmpose-m_simcc-body7_pt-body7_420e-256x192-"
    "e48f03d0_20230504.zip"
)
CACHE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".model_cache")

# opset_import domain names that denote the standard ONNX operator set.
_DEFAULT_DOMAINS = ("", "ai.onnx")


def download_onnx() -> str:
    """Download the official RTMPose-M ONNX and cache it locally.

    Returns:
        Path to the cached ``.onnx`` file (reused on later calls).

    Raises:
        requests.HTTPError: if the download request fails.
        RuntimeError: if the downloaded zip contains no ``.onnx`` member.
    """
    cache_path = os.path.join(CACHE_DIR, "rtmpose_m_official.onnx")
    if os.path.exists(cache_path):
        print(f"Using cached ONNX: {cache_path}")
        return cache_path

    os.makedirs(CACHE_DIR, exist_ok=True)
    print("Downloading official RTMPose-M ONNX from OpenMMLab...")
    resp = requests.get(ONNX_URL, timeout=120)
    resp.raise_for_status()

    with zipfile.ZipFile(io.BytesIO(resp.content)) as zf:
        # First .onnx member in archive order, as before.
        name = next((n for n in zf.namelist() if n.endswith(".onnx")), None)
        if name is None:
            raise RuntimeError("No .onnx found in downloaded zip")

        # Extract to a temporary file and rename atomically: an interrupted
        # write must never leave a truncated file at cache_path, because the
        # existence check above would accept it on every later run.
        tmp_path = cache_path + ".part"
        try:
            with zf.open(name) as src, open(tmp_path, "wb") as dst:
                shutil.copyfileobj(src, dst)
            os.replace(tmp_path, cache_path)
        finally:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    print(f"Cached: {cache_path}")
    return cache_path


def _default_opset_version(model: onnx.ModelProto) -> int:
    """Return the opset version declared for the default (ai.onnx) domain.

    ``opset_import`` may also list custom domains, in any order, so the first
    entry is not necessarily the standard operator set.

    Raises:
        ValueError: if the model declares no default-domain opset.
    """
    for entry in model.opset_import:
        if entry.domain in _DEFAULT_DOMAINS:
            return entry.version
    raise ValueError("Model declares no default-domain (ai.onnx) opset_import")


def convert_opset(model: onnx.ModelProto, target_opset: int) -> onnx.ModelProto:
    """Convert the model to ``target_opset`` if its default-domain opset differs.

    Returns the original model object unchanged when no conversion is needed,
    otherwise the converted model returned by the ONNX version converter.
    """
    current_opset = _default_opset_version(model)
    if current_opset == target_opset:
        return model
    print(f"Converting opset {current_opset} -> {target_opset}")
    return onnx.version_converter.convert_version(model, target_opset)


def fix_batch_dim(model: onnx.ModelProto, batch: int = 1) -> None:
    """Make the leading (batch) dim of graph inputs/outputs static, in place.

    Any leading dim that is not already a concrete ``dim_value`` (a symbolic
    ``dim_param`` or an unset, anonymous dim) is replaced by ``batch``.
    Concrete leading dims are left as they are. Values with no shape or rank 0
    are skipped, as are initializers that older IR versions also list in
    ``graph.input``.

    Raises:
        ValueError: if ``batch`` is smaller than 1.
    """
    if batch < 1:
        raise ValueError(f"batch must be >= 1, got {batch}")

    initializer_names = {init.name for init in model.graph.initializer}
    for tensor in [*model.graph.input, *model.graph.output]:
        if tensor.name in initializer_names:
            continue
        dims = tensor.type.tensor_type.shape.dim
        if not dims:
            continue  # unknown rank or scalar: there is no batch axis to fix
        dim0 = dims[0]
        if dim0.WhichOneof("value") != "dim_value":
            # dim_value and dim_param share a protobuf oneof, so assigning
            # dim_value also clears any dim_param.
            dim0.dim_value = batch


def _format_dims(value_info: onnx.ValueInfoProto) -> list:
    """Return the dims of ``value_info``, using the symbol or '?' for dynamic ones."""
    return [
        d.dim_value if d.WhichOneof("value") == "dim_value" else (d.dim_param or "?")
        for d in value_info.type.tensor_type.shape.dim
    ]


def print_model_info(model: onnx.ModelProto) -> None:
    """Print the parameter count and the graph's input/output shapes."""
    total_params = sum(int(np.prod(init.dims)) for init in model.graph.initializer)
    print(f"Parameters: {total_params / 1e6:.2f}M")

    initializer_names = {init.name for init in model.graph.initializer}
    for inp in model.graph.input:
        if inp.name in initializer_names:
            continue  # weights listed as inputs by older IR versions
        print(f"  Input:  {inp.name} {_format_dims(inp)}")
    for out in model.graph.output:
        print(f"  Output: {out.name} {_format_dims(out)}")


def main():
    """CLI entry point: download, convert opset, fix batch dim, and save."""
    ap = argparse.ArgumentParser(description="Export RTMPose-M 256x192 ONNX")
    ap.add_argument("--opset", type=int, default=13, help="Target ONNX opset version")
    ap.add_argument("--output", default="rtmpose_m_256x192.onnx", help="Output path")
    ap.add_argument("--batch", type=int, default=1, help="Static batch size")
    args = ap.parse_args()
    if args.batch < 1:
        ap.error("--batch must be >= 1")

    source_path = download_onnx()
    model = onnx.load(source_path)
    model = convert_opset(model, args.opset)
    fix_batch_dim(model, args.batch)

    # Allow writing into a directory that does not exist yet.
    os.makedirs(os.path.dirname(os.path.abspath(args.output)), exist_ok=True)
    onnx.save(model, args.output)
    print(f"\nExported: {args.output} ({os.path.getsize(args.output) / 1e6:.2f} MB)")
    print_model_info(model)


if __name__ == "__main__":
    main()