File size: 3,597 Bytes
993d81c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
#!/usr/bin/env python3
"""Export RTMPose-M 256x192 ONNX from official OpenMMLab pretrained model.

Downloads the official pre-exported ONNX model from the OpenMMLab model zoo,
converts its opset if needed, and fixes the dynamic batch dimension to a
static size (1 by default; see --batch).

Model: RTMPose-M (13.58M params)
Input: 1x3x256x192 (RGB, float32)
Output: simcc_x (1,17,384), simcc_y (1,17,512)
"""

import argparse
import io
import os
import shutil
import zipfile

import numpy as np
import onnx
import onnx.checker
import onnx.version_converter
import requests

# Official OpenMMLab pre-exported RTMPose-M ONNX bundle, shipped as a zip archive.
ONNX_URL = (
    "https://download.openmmlab.com/mmpose/v1/projects/rtmposev1/onnx_sdk/"
    "rtmpose-m_simcc-body7_pt-body7_420e-256x192-e48f03d0_20230504.zip"
)
# Download cache directory, placed next to this script regardless of the CWD.
CACHE_DIR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    ".model_cache",
)


def download_onnx() -> str:
    """Download the official RTMPose-M ONNX model and cache it locally.

    Fetches the zip archive at ``ONNX_URL`` and extracts its first ``.onnx``
    member to ``CACHE_DIR/rtmpose_m_official.onnx``. If that file already
    exists, it is reused and nothing is downloaded.

    Returns:
        Path to the cached ``.onnx`` file.

    Raises:
        requests.HTTPError: If the download returns an error status.
        zipfile.BadZipFile: If the response body is not a valid zip archive.
        RuntimeError: If the archive contains no ``.onnx`` member.
    """
    cache_path = os.path.join(CACHE_DIR, "rtmpose_m_official.onnx")
    if os.path.exists(cache_path):
        print(f"Using cached ONNX: {cache_path}")
        return cache_path

    os.makedirs(CACHE_DIR, exist_ok=True)
    print("Downloading official RTMPose-M ONNX from OpenMMLab...")
    resp = requests.get(ONNX_URL, timeout=120)
    resp.raise_for_status()

    with zipfile.ZipFile(io.BytesIO(resp.content)) as zf:
        onnx_member = next(
            (name for name in zf.namelist() if name.endswith(".onnx")), None
        )
        if onnx_member is None:
            raise RuntimeError("No .onnx found in downloaded zip")

        # Extract to a temporary sibling and rename it into place only after the
        # write has finished. An interrupted run therefore never leaves a
        # truncated file that the cache check above would later trust.
        tmp_path = cache_path + ".part"
        try:
            with zf.open(onnx_member) as src, open(tmp_path, "wb") as dst:
                shutil.copyfileobj(src, dst)
            os.replace(tmp_path, cache_path)
        except BaseException:
            # Also covers KeyboardInterrupt; the exception is always re-raised.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

    print(f"Cached: {cache_path}")
    return cache_path


def convert_opset(model: onnx.ModelProto, target_opset: int) -> onnx.ModelProto:
    """Convert an ONNX model to ``target_opset`` unless it is already there.

    The current version is read from the model's default-domain opset import
    (domain ``""`` or its alias ``"ai.onnx"``). Entries for other (custom)
    domains are ignored, because ``target_opset`` refers to the default ONNX
    domain.

    Args:
        model: Model to convert.
        target_opset: Desired default-domain opset version.

    Returns:
        ``model`` itself if it is already at ``target_opset``; otherwise the
        model returned by ``onnx.version_converter.convert_version``.

    Raises:
        ValueError: If the model declares no default-domain opset import.
    """
    # opset_import is a list of (domain, version) entries, and the default
    # domain is not guaranteed to be first, so search for it explicitly.
    current_opset = next(
        (
            entry.version
            for entry in model.opset_import
            if entry.domain in ("", "ai.onnx")
        ),
        None,
    )
    if current_opset is None:
        raise ValueError("Model has no default-domain (ai.onnx) opset import")
    if current_opset == target_opset:
        return model
    print(f"Converting opset {current_opset} -> {target_opset}")
    return onnx.version_converter.convert_version(model, target_opset)


def fix_batch_dim(model: onnx.ModelProto, batch: int = 1) -> None:
    """Pin the leading (batch) dim of every graph input/output to ``batch``.

    A leading dim that is symbolic (``dim_param``) or entirely unknown
    (neither field set) is replaced with the static ``dim_value``. A leading
    dim that already has a static value is left untouched. Inputs/outputs
    with no shape information, or with rank 0, are skipped. The model is
    modified in place.

    Args:
        model: Model whose graph inputs and outputs are rewritten in place.
        batch: Static batch size to write; must be >= 1.

    Raises:
        ValueError: If ``batch`` is less than 1.
    """
    if batch < 1:
        raise ValueError(f"batch must be >= 1, got {batch}")
    for tensor in list(model.graph.input) + list(model.graph.output):
        dims = tensor.type.tensor_type.shape.dim
        if not dims:
            # No shape info (or a scalar): there is no batch axis to pin.
            continue
        dim0 = dims[0]
        # dim_value and dim_param share the "value" oneof, so assigning
        # dim_value also clears any dim_param.
        if dim0.WhichOneof("value") != "dim_value":
            dim0.dim_value = batch


def _format_dims(value_info) -> str:
    """Render a graph input/output's tensor shape, e.g. ``[1, 3, 256, 192]``.

    Static dims print as integers, symbolic dims print by name, and unknown
    dims print as ``?``. Reading ``dim_value`` alone would show both of the
    latter as ``0``.
    """
    parts = []
    for dim in value_info.type.tensor_type.shape.dim:
        kind = dim.WhichOneof("value")
        if kind == "dim_value":
            parts.append(str(dim.dim_value))
        elif kind == "dim_param":
            parts.append(dim.dim_param)
        else:
            parts.append("?")
    return "[" + ", ".join(parts) + "]"


def print_model_info(model: onnx.ModelProto) -> None:
    """Print the model's parameter count and the shapes of its graph IO.

    The parameter count is the total number of elements across all graph
    initializers, printed in millions.
    """
    total_params = sum(int(np.prod(init.dims)) for init in model.graph.initializer)
    print(f"Parameters: {total_params / 1e6:.2f}M")
    # Labels are padded so the names line up: "Input:  x" / "Output: y".
    sections = (("Input: ", model.graph.input), ("Output:", model.graph.output))
    for label, tensors in sections:
        for tensor in tensors:
            print(f"  {label} {tensor.name} {_format_dims(tensor)}")


def main():
    """Command-line entry point.

    Downloads (or reuses) the official model, converts it to ``--opset``, pins
    the batch dim to ``--batch``, validates the result, and saves it to
    ``--output``.
    """
    ap = argparse.ArgumentParser(description="Export RTMPose-M 256x192 ONNX")
    ap.add_argument("--opset", type=int, default=13, help="Target ONNX opset version")
    ap.add_argument("--output", default="rtmpose_m_256x192.onnx", help="Output path")
    ap.add_argument("--batch", type=int, default=1, help="Static batch size")
    args = ap.parse_args()
    if args.batch < 1:
        # Reject with a usage message rather than writing a zero/negative dim.
        ap.error(f"--batch must be >= 1, got {args.batch}")

    source_path = download_onnx()
    model = onnx.load(source_path)
    model = convert_opset(model, args.opset)
    fix_batch_dim(model, args.batch)
    # Fail before writing anything if the conversion produced an invalid graph.
    onnx.checker.check_model(model)

    # onnx.save opens the path directly and does not create missing parent dirs.
    out_dir = os.path.dirname(args.output)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    onnx.save(model, args.output)

    print(f"\nExported: {args.output} ({os.path.getsize(args.output) / 1e6:.2f} MB)")
    print_model_info(model)


if __name__ == "__main__":
    main()