File size: 6,564 Bytes
d670799
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
# Copyright (c) OpenMMLab. All rights reserved.
# This script serves the sole purpose of converting spatial-temporal detection
# models supported in MMAction2 to ONNX files. Please note that attempting to
# convert other models using this script may not yield successful results.
import argparse

import onnxruntime
import torch
import torch.nn as nn
from mmdet.structures.bbox import bbox2roi
from mmengine import Config
from mmengine.registry import init_default_scope
from mmengine.runner import load_checkpoint

from mmaction.registry import MODELS


def parse_args(argv=None):
    """Parse the command-line options of the ONNX export script.

    Args:
        argv (list[str] | None): Arguments to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, i.e. the real command line.

    Returns:
        argparse.Namespace: Parsed options with attributes ``config``,
        ``checkpoint``, ``num_frames``, ``shape``, ``device`` and
        ``output_file``.
    """
    parser = argparse.ArgumentParser(
        description='Export an MMAction2 spatial-temporal detection model '
        'to ONNX')
    parser.add_argument('config', help='config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument(
        '--num_frames', type=int, default=8, help='number of input frames.')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[256, 455],
        help='input image size')
    parser.add_argument(
        '--device', type=str, default='cpu', help='CPU/CUDA device option')
    parser.add_argument(
        '--output_file',
        type=str,
        default='stdet.onnx',
        help='file name of the output onnx file')
    return parser.parse_args(argv)


class SpatialMaxPool3d(nn.Module):
    """Global max pooling over the two trailing (spatial) dimensions.

    Used as a drop-in replacement for the bbox head's spatial pool. Both
    reduced dimensions are kept with size 1.
    """

    def forward(self, x):
        # Reduce the last axis first, then the second-to-last one.
        for axis in (-1, -2):
            x = x.max(dim=axis, keepdim=True).values
        return x


class SpatialAvgPool(nn.Module):
    """Global average pooling over the two trailing (spatial) dimensions.

    Used as a drop-in replacement for the bbox head's spatial pool. Both
    reduced dimensions are kept with size 1.
    """

    def forward(self, x):
        # ``keepdim`` is PyTorch's canonical keyword; the previous ``keepdims``
        # only worked through the NumPy-compatibility alias.
        return x.mean(dim=(-1, -2), keepdim=True)


class TemporalMaxPool3d(nn.Module):
    """Max pooling over the third-from-last (temporal) dimension.

    The reduced dimension is kept with size 1.
    """

    def forward(self, x):
        pooled, _ = torch.max(x, dim=-3, keepdim=True)
        return pooled


class TemporalAvgPool3d(nn.Module):
    """Average pooling over the third-from-last (temporal) dimension.

    The reduced dimension is kept with size 1.
    """

    def forward(self, x):
        return torch.mean(x, dim=-3, keepdim=True)


class GlobalPool2d(nn.Module):
    """Fixed-kernel stand-in for the RoI extractor's global pooling.

    First average-pools the input with a ``pool_size`` window. The result
    is then reduced to a single value per channel (max or mean over the
    remaining spatial map). That value is broadcast back to an
    ``output_size x output_size`` map.

    Args:
        pool_size (int): Kernel size (and stride) of the ``nn.AvgPool2d``.
        output_size (int): Side length of the square output map.
        later_max (bool): Reduce with ``max`` when True, ``mean`` otherwise.
    """

    def __init__(self, pool_size, output_size, later_max=True):
        super().__init__()
        self.pool = nn.AvgPool2d(pool_size)
        # Boolean flag (not a function); name kept for compatibility.
        self.max = later_max
        self.output_size = output_size

    def forward(self, x):
        # The 4-argument expand() below requires a 4-D input (N, C, H, W).
        x = self.pool(x)
        if self.max:
            x = x.max(dim=-1, keepdim=True)[0]
            x = x.max(dim=-2, keepdim=True)[0]
        else:
            # Canonical ``keepdim`` instead of the NumPy alias ``keepdims``.
            x = x.mean(dim=(-1, -2), keepdim=True)
        # expand() returns a broadcast view of the (N, C, 1, 1) tensor.
        return x.expand(-1, -1, self.output_size, self.output_size)


class STDet(nn.Module):
    """Export wrapper around an MMAction2 spatial-temporal detector.

    Keeps only the backbone, RoI extractor and bbox head of ``base_model``.
    Their adaptive pooling layers are swapped for the fixed-reduction
    modules defined in this script.

    Note:
        The pooling layers are replaced by attribute assignment on the
        sub-modules shared with ``base_model``, so ``base_model`` is
        modified as well.

    Args:
        base_model (nn.Module): Detector exposing ``backbone`` and
            ``roi_head.bbox_roi_extractor`` / ``roi_head.bbox_head``.
        input_tensor (torch.Tensor): Example input; only its last two
            dimensions are used, to size the global pooling kernel.
    """

    def __init__(self, base_model, input_tensor):
        super().__init__()
        self.backbone = base_model.backbone
        self.bbox_roi_extractor = base_model.roi_head.bbox_roi_extractor
        self.bbox_head = base_model.roi_head.bbox_head

        # NOTE(review): assumes ``output_size`` is an int; a tuple would
        # break both the division below and GlobalPool2d.expand().
        output_size = self.bbox_roi_extractor.global_pool.output_size
        # Downsampling factor between input and feature map. The extractor's
        # ``featmap_stride`` is used when available; otherwise the previous
        # hard-coded value 16 is kept.
        featmap_stride = getattr(self.bbox_roi_extractor, 'featmap_stride',
                                 16)
        # Clamp to 1: for small inputs the quotient is 0 and
        # nn.AvgPool2d(0) fails at forward time.
        pool_size = max(
            1,
            int(min(input_tensor.shape[-2:]) // featmap_stride //
                output_size))

        # Match each adaptive pool's type (avg vs. max) with an explicit
        # reduction module defined in this script.
        if isinstance(self.bbox_head.temporal_pool, nn.AdaptiveAvgPool3d):
            self.bbox_head.temporal_pool = TemporalAvgPool3d()
        else:
            self.bbox_head.temporal_pool = TemporalMaxPool3d()
        if isinstance(self.bbox_head.spatial_pool, nn.AdaptiveAvgPool3d):
            self.bbox_head.spatial_pool = SpatialAvgPool()
            self.bbox_roi_extractor.global_pool = GlobalPool2d(
                pool_size, output_size, later_max=False)
        else:
            self.bbox_head.spatial_pool = SpatialMaxPool3d()
            self.bbox_roi_extractor.global_pool = GlobalPool2d(
                pool_size, output_size, later_max=True)

    def forward(self, input_tensor, rois):
        """Run backbone -> RoI extractor -> bbox head and return the scores.

        The RoI extractor returns a pair; only its first element (the RoI
        features) is passed on to the bbox head.
        """
        feat = self.backbone(input_tensor)
        bbox_feats, _ = self.bbox_roi_extractor(feat, rois)
        cls_score = self.bbox_head(bbox_feats)
        return cls_score


def main():
    """Build the detector, export it to ONNX and check the exported file.

    Steps:
        1. Load the config and checkpoint.
        2. Wrap the model in ``STDet`` and run one PyTorch forward pass on a
           random clip with three fixed proposals.
        3. Export the wrapper to ``args.output_file``.
        4. Compare ONNX Runtime's output with the PyTorch output.
    """
    args = parse_args()
    config = Config.fromfile(args.config)

    if config.model.type != 'FastRCNN':
        print('This script serves the sole purpose of converting spatial '
              'temporal detection models in MMAction2 to ONNX files. Please '
              'note that attempting to convert other models using this script '
              'may not yield successful results.\n\n')

    init_default_scope(config.get('default_scope', 'mmaction'))

    base_model = MODELS.build(config.model)
    load_checkpoint(base_model, args.checkpoint, map_location='cpu')
    base_model.to(args.device)

    if len(args.shape) == 1:
        input_shape = (args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    # Random clip of shape (1, 3, num_frames, H, W), clamped to [-3, 3].
    input_tensor = torch.randn(1, 3, args.num_frames, *input_shape)
    input_tensor = input_tensor.clamp(-3, 3).to(args.device)
    # Fixed example boxes, presumably (x1, y1, x2, y2) in input pixels.
    proposal = torch.Tensor([[22., 59., 67., 157.], [186., 73., 217., 159.],
                             [407., 95., 431., 168.]])

    rois = bbox2roi([proposal]).to(args.device)

    model = STDet(base_model, input_tensor).to(args.device)
    model.eval()
    # Reference output for the ONNX comparison; no autograd graph needed.
    with torch.no_grad():
        cls_score = model(input_tensor, rois)
    print(f'Model output shape: {cls_score.shape}')

    torch.onnx.export(
        model, (input_tensor, rois),
        args.output_file,
        input_names=['input_tensor', 'rois'],
        output_names=['cls_score'],
        export_params=True,
        do_constant_folding=True,
        verbose=False,
        opset_version=11,
        dynamic_axes={
            'input_tensor': {
                0: 'batch_size',
                3: 'height',
                4: 'width'
            },
            'rois': {
                0: 'total_num_bbox_for_the_batch'
            },
            'cls_score': {
                0: 'total_num_bbox_for_the_batch'
            }
        })

    print(f'Successfully export the onnx file to {args.output_file}')

    # Test exported file. Inputs are fed as CPU arrays, so the CPU provider
    # is enough. Naming it explicitly is required by ONNX Runtime >= 1.9
    # builds that ship several execution providers.
    session = onnxruntime.InferenceSession(
        args.output_file, providers=['CPUExecutionProvider'])
    input_feed = {
        'input_tensor': input_tensor.cpu().numpy(),
        'rois': rois.cpu().numpy()
    }
    outputs = session.run(['cls_score'], input_feed=input_feed)[0]
    diff = abs(cls_score.cpu().numpy() - outputs).max()
    if diff < 1e-5:
        print('The output difference is smaller than 1e-5.')
    else:
        # Previously a mismatch was silently ignored.
        print(f'Warning: the maximum difference between PyTorch and ONNX '
              f'Runtime outputs is {diff}, which is not smaller than 1e-5.')


if __name__ == '__main__':
    # Run the export only when this file is executed as a script.
    main()