| |
| """Export RTMPose-M 256x192 ONNX from official OpenMMLab pretrained model. |
| |
| Downloads the official pre-exported ONNX model from the OpenMMLab model zoo, |
| converts its opset if needed, and fixes the batch dimension to a static size (1 by default). |
| |
| Model: RTMPose-M (13.58M params) |
| Input: 1x3x256x192 (RGB, float32) |
| Output: simcc_x (1,17,384), simcc_y (1,17,512) |
| """ |
|
|
| import argparse |
| import io |
| import os |
| import zipfile |
|
|
| import numpy as np |
| import onnx |
| import onnx.version_converter |
| import requests |
|
|
# Official pre-exported RTMPose-M (body7, 256x192) ONNX SDK bundle from the
# OpenMMLab model zoo. It is a zip archive; download_onnx() extracts the
# .onnx member from it.
ONNX_URL = (
    "https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/onnx_sdk/"
    "rtmpose-m_simcc-body7_pt-body7_420e-256x192-e48f03d0_20230504.zip"
)

# Local cache directory, placed next to this script, so that repeated runs
# can reuse the extracted model instead of downloading it again.
CACHE_DIR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    ".model_cache",
)
|
|
|
|
def download_onnx() -> str:
    """Download the official RTMPose-M ONNX and cache it locally.

    The bundle at ``ONNX_URL`` is a zip archive. The first member whose name
    ends in ``.onnx`` is extracted to ``CACHE_DIR/rtmpose_m_official.onnx``.
    If that file already exists, it is reused and no download happens.

    Returns:
        Path to the cached ``.onnx`` file.

    Raises:
        requests.HTTPError: If the download responds with an error status.
        zipfile.BadZipFile: If the response body is not a valid zip archive.
        RuntimeError: If the archive contains no ``.onnx`` member.
    """
    cache_path = os.path.join(CACHE_DIR, "rtmpose_m_official.onnx")
    if os.path.exists(cache_path):
        print(f"Using cached ONNX: {cache_path}")
        return cache_path

    os.makedirs(CACHE_DIR, exist_ok=True)
    print("Downloading official RTMPose-M ONNX from OpenMMLab...")
    resp = requests.get(ONNX_URL, timeout=120)
    resp.raise_for_status()

    with zipfile.ZipFile(io.BytesIO(resp.content)) as zf:
        onnx_name = next(
            (name for name in zf.namelist() if name.endswith(".onnx")), None
        )
        if onnx_name is None:
            raise RuntimeError("No .onnx found in downloaded zip")

        # Write to a temporary sibling first, then rename it into place.
        # A crash or interrupt mid-write must not leave a truncated file at
        # cache_path: the existence check above would reuse it forever.
        tmp_path = cache_path + ".part"
        try:
            with zf.open(onnx_name) as src, open(tmp_path, "wb") as dst:
                dst.write(src.read())
            os.replace(tmp_path, cache_path)
        except BaseException:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

    print(f"Cached: {cache_path}")
    return cache_path
|
|
|
|
def convert_opset(model: onnx.ModelProto, target_opset: int) -> onnx.ModelProto:
    """Convert an ONNX model to ``target_opset`` if it is not already there.

    The current version is read from the default-domain entry (domain ``""``
    or ``"ai.onnx"``) of ``model.opset_import``. Entries for other (custom)
    domains are ignored, so their position in the list does not matter.

    Args:
        model: Model to inspect and possibly convert.
        target_opset: Desired default-domain opset version.

    Returns:
        ``model`` itself when it is already at ``target_opset``. Otherwise,
        the model returned by ``onnx.version_converter.convert_version``.

    Raises:
        ValueError: If the model declares no default-domain opset import.
    """
    current_opset = next(
        (
            entry.version
            for entry in model.opset_import
            if entry.domain in ("", "ai.onnx")
        ),
        None,
    )
    if current_opset is None:
        raise ValueError("Model has no default-domain (ai.onnx) opset import")
    if current_opset == target_opset:
        return model
    print(f"Converting opset {current_opset} -> {target_opset}")
    return onnx.version_converter.convert_version(model, target_opset)
|
|
|
|
def fix_batch_dim(model: onnx.ModelProto, batch: int = 1) -> None:
    """Pin the leading (batch) dimension of graph inputs/outputs to ``batch``.

    A leading dimension is rewritten when it is symbolic (``dim_param``) or
    has no value at all. A dimension that already carries a concrete
    ``dim_value`` is left unchanged. Graph values without a declared shape,
    or of rank 0, have no batch dimension and are skipped. ``model`` is
    modified in place.
    """
    for tensor in [*model.graph.input, *model.graph.output]:
        dims = tensor.type.tensor_type.shape.dim
        if not dims:
            # Scalar or shapeless value: nothing to pin, and dims[0] would raise.
            continue
        dim0 = dims[0]
        if dim0.WhichOneof("value") != "dim_value":
            # dim_value and dim_param share a oneof, so assigning dim_value
            # also drops dim_param. It is cleared explicitly for readability.
            dim0.ClearField("dim_param")
            dim0.dim_value = batch
|
|
|
|
def _shape_of(value_info: onnx.ValueInfoProto) -> list:
    """Return a value's dims: ints if static, names if symbolic, "?" if unset."""
    shape = []
    for d in value_info.type.tensor_type.shape.dim:
        kind = d.WhichOneof("value")
        if kind == "dim_value":
            shape.append(d.dim_value)
        elif kind == "dim_param":
            shape.append(d.dim_param)
        else:
            shape.append("?")
    return shape


def print_model_info(model: onnx.ModelProto) -> None:
    """Print the model's parameter count and graph input/output shapes.

    The parameter count is the total number of elements across all graph
    initializers. Symbolic dimensions are shown by name rather than as 0.
    """
    graph = model.graph
    total_params = sum(int(np.prod(init.dims)) for init in graph.initializer)
    print(f"Parameters: {total_params / 1e6:.2f}M")
    # Some exporters also list initializers (weights) in graph.input.
    # Skip those so only real feeds are reported as inputs.
    initializer_names = {init.name for init in graph.initializer}
    for inp in graph.input:
        if inp.name in initializer_names:
            continue
        print(f" Input: {inp.name} {_shape_of(inp)}")
    for out in graph.output:
        print(f" Output: {out.name} {_shape_of(out)}")
|
|
|
|
def main():
    """CLI entry point: fetch the model, convert its opset, pin batch, save, report."""
    ap = argparse.ArgumentParser(description="Export RTMPose-M 256x192 ONNX")
    ap.add_argument("--opset", type=int, default=13, help="Target ONNX opset version")
    ap.add_argument("--output", default="rtmpose_m_256x192.onnx", help="Output path")
    ap.add_argument("--batch", type=int, default=1, help="Static batch size")
    args = ap.parse_args()
    # A non-positive batch would otherwise be written verbatim as an invalid
    # dim_value in the exported model.
    if args.batch < 1:
        ap.error("--batch must be >= 1")

    source_path = download_onnx()
    model = onnx.load(source_path)
    model = convert_opset(model, args.opset)
    fix_batch_dim(model, args.batch)
    # Surface a broken opset conversion now rather than at inference time.
    onnx.checker.check_model(model)

    # Allow --output to point into a directory that does not exist yet.
    out_dir = os.path.dirname(args.output)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    onnx.save(model, args.output)

    print(f"\nExported: {args.output} ({os.path.getsize(args.output) / 1e6:.2f} MB)")
    print_model_info(model)


if __name__ == "__main__":
    main()
|
|