# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModel
from mmengine.model.weight_init import constant_init, kaiming_init

from mmaction.registry import MODELS
from .utils import post_processing, temporal_iop


@MODELS.register_module()
class TEM(BaseModel):
    """Temporal Evaluation Model for Boundary Sensitive Network.

    Please refer to `BSN: Boundary Sensitive Network for Temporal Action
    Proposal Generation <http://arxiv.org/abs/1806.02964>`_.

    Code reference
    https://github.com/wzmsltw/BSN-boundary-sensitive-network

    Args:
        temporal_dim (int): Total frames selected for each video.
        boundary_ratio (float): Ratio of each ground-truth segment's length
            used as the width of its start and end boundary regions.
        tem_feat_dim (int): Feature dimension.
        tem_hidden_dim (int): Hidden layer dimension.
        tem_match_threshold (float): Temporal evaluation match threshold.
        loss_cls (dict): Config for building loss.
            Default: ``dict(type='BinaryLogisticRegressionLoss')``.
        loss_weight (float): Weight term for action_loss. Default: 2.
        output_dim (int): Output dimension. Default: 3.
        conv1_ratio (float): Ratio of conv1 layer output. Default: 1.0.
        conv2_ratio (float): Ratio of conv2 layer output. Default: 1.0.
        conv3_ratio (float): Ratio of conv3 layer output. Default: 0.01.
    """

    def __init__(self,
                 temporal_dim,
                 boundary_ratio,
                 tem_feat_dim,
                 tem_hidden_dim,
                 tem_match_threshold,
                 loss_cls=dict(type='BinaryLogisticRegressionLoss'),
                 loss_weight=2,
                 output_dim=3,
                 conv1_ratio=1,
                 conv2_ratio=1,
                 conv3_ratio=0.01):
        super().__init__()

        self.temporal_dim = temporal_dim
        self.boundary_ratio = boundary_ratio
        self.feat_dim = tem_feat_dim
        self.c_hidden = tem_hidden_dim
        self.match_threshold = tem_match_threshold
        self.output_dim = output_dim
        self.loss_cls = MODELS.build(loss_cls)
        self.loss_weight = loss_weight
        self.conv1_ratio = conv1_ratio
        self.conv2_ratio = conv2_ratio
        self.conv3_ratio = conv3_ratio

        self.conv1 = nn.Conv1d(
            in_channels=self.feat_dim,
            out_channels=self.c_hidden,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=1)
        self.conv2 = nn.Conv1d(
            in_channels=self.c_hidden,
            out_channels=self.c_hidden,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=1)
        self.conv3 = nn.Conv1d(
            in_channels=self.c_hidden,
            out_channels=self.output_dim,
            kernel_size=1,
            stride=1,
            padding=0)
        self.anchors_tmins, self.anchors_tmaxs = self._temporal_anchors()

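    def init_weights(self) -> None:
        """Initialize the parameters either from an existing checkpoint or
        from scratch."""
        # NOTE: the layers defined in this model are all ``nn.Conv1d``, so
        # neither branch below matches them.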
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m)
            elif isinstance(m, nn.BatchNorm2d):
                constant_init(m, 1)

    def _temporal_anchors(self, tmin_offset=0., tmax_offset=1.):
        """Generate temporal anchors.

        Args:
            tmin_offset (float): Offset for the minimum value of temporal
                anchor. Default: 0.
            tmax_offset (float): Offset for the maximum value of temporal
                anchor. Default: 1.

        Returns:
            tuple[Sequence[float]]: The minimum and maximum values of temporal
                anchors.
        """
        temporal_gap = 1. / self.temporal_dim
        anchors_tmins = []
        anchors_tmaxs = []
        for i in range(self.temporal_dim):
            anchors_tmins.append(temporal_gap * (i + tmin_offset))
            anchors_tmaxs.append(temporal_gap * (i + tmax_offset))

        return anchors_tmins, anchors_tmaxs

    def _forward(self, x):
        """Define the computation performed at every call.

        Args:
            x (torch.Tensor): The input data.

        Returns:
            torch.Tensor: The output of the module.
        """
        x = F.relu(self.conv1_ratio * self.conv1(x))
        x = F.relu(self.conv2_ratio * self.conv2(x))
        x = torch.sigmoid(self.conv3_ratio * self.conv3(x))
        return x

    def loss(self, batch_inputs, batch_data_samples, **kwargs):
        """Calculate losses from a batch of inputs and data samples.

        Args:
            batch_inputs (Tensor): Input features of the model, with shape
                (N, tem_feat_dim, temporal_dim).
            batch_data_samples (List[:obj:`ActionDataSample`]): The batch
                data samples. Each sample provides its ground-truth
                segments in ``gt_instances['gt_bbox']``.

        Returns:
            dict: A dictionary of loss components.
        """
        tem_output = self._forward(batch_inputs)

        score_action = tem_output[:, 0, :]
        score_start = tem_output[:, 1, :]
        score_end = tem_output[:, 2, :]

        gt_bbox = [
            sample.gt_instances['gt_bbox'] for sample in batch_data_samples
        ]
        label_action, label_start, label_end = self.generate_labels(gt_bbox)
        device = batch_inputs.device
        label_action = label_action.to(device)
        label_start = label_start.to(device)
        label_end = label_end.to(device)

        loss_action = self.loss_cls(score_action, label_action,
                                    self.match_threshold)
        loss_start = self.loss_cls(score_start, label_start,
                                   self.match_threshold)
        loss_end = self.loss_cls(score_end, label_end, self.match_threshold)

        loss_dict = {
            'loss_action': loss_action * self.loss_weight,
            'loss_start': loss_start,
            'loss_end': loss_end
        }

        return loss_dict

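    def predict(self, batch_inputs, batch_data_samples, **kwargs):
        """Predict action, start and end scores when testing.

        Returns:
            list[tuple]: One ``(video_name, video_result)`` pair per video,
                where ``video_result`` is an array of shape
                (temporal_dim, 5) whose columns are the action, start and
                end scores followed by the anchor ``tmin`` and ``tmax``.
        """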
        tem_output = self._forward(batch_inputs).cpu().numpy()
        batch_action = tem_output[:, 0, :]
        batch_start = tem_output[:, 1, :]
        batch_end = tem_output[:, 2, :]

        video_results = []
        for batch_idx, _ in enumerate(batch_action):
            video_name = batch_data_samples[batch_idx].metainfo['video_name']
            video_action = batch_action[batch_idx]
            video_start = batch_start[batch_idx]
            video_end = batch_end[batch_idx]
            video_result = np.stack((video_action, video_start, video_end,
                                     self.anchors_tmins, self.anchors_tmaxs),
                                    axis=1)
            video_results.append((video_name, video_result))
        return video_results

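    def generate_labels(self, gt_bbox):
        """Generate training labels.

        Args:
            gt_bbox (list[Tensor]): Ground-truth segments of each video,
                each with shape (num_gts, 2) holding ``(tmin, tmax)``.

        Returns:
            tuple[Tensor]: Action, start and end match scores, each with
                shape (N, temporal_dim).
        """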
        # TODO: do this without numpy
        match_score_action_list = []
        match_score_start_list = []
        match_score_end_list = []
        for every_gt_bbox in gt_bbox:
            gt_tmins = every_gt_bbox[:, 0].cpu().numpy()
            gt_tmaxs = every_gt_bbox[:, 1].cpu().numpy()

            gt_lens = gt_tmaxs - gt_tmins
            gt_len_pad = np.maximum(1. / self.temporal_dim,
                                    self.boundary_ratio * gt_lens)

            gt_start_bboxs = np.stack(
                (gt_tmins - gt_len_pad / 2, gt_tmins + gt_len_pad / 2), axis=1)
            gt_end_bboxs = np.stack(
                (gt_tmaxs - gt_len_pad / 2, gt_tmaxs + gt_len_pad / 2), axis=1)

            match_score_action = []
            match_score_start = []
            match_score_end = []

            for anchor_tmin, anchor_tmax in zip(self.anchors_tmins,
                                                self.anchors_tmaxs):
                match_score_action.append(
                    np.max(
                        temporal_iop(anchor_tmin, anchor_tmax, gt_tmins,
                                     gt_tmaxs)))
                match_score_start.append(
                    np.max(
                        temporal_iop(anchor_tmin, anchor_tmax,
                                     gt_start_bboxs[:, 0], gt_start_bboxs[:,
                                                                          1])))
                match_score_end.append(
                    np.max(
                        temporal_iop(anchor_tmin, anchor_tmax,
                                     gt_end_bboxs[:, 0], gt_end_bboxs[:, 1])))
            match_score_action_list.append(match_score_action)
            match_score_start_list.append(match_score_start)
            match_score_end_list.append(match_score_end)
        match_score_action_list = torch.Tensor(match_score_action_list)
        match_score_start_list = torch.Tensor(match_score_start_list)
        match_score_end_list = torch.Tensor(match_score_end_list)
        return (match_score_action_list, match_score_start_list,
                match_score_end_list)

    def forward(self, inputs, data_samples, mode, **kwargs):
        """The unified entry for a forward process in both training and test.

        The method should accept three modes:

        - ``tensor``: Forward the whole network and return tensor or tuple of
        tensor without any post-processing, same as a common nn.Module.
        - ``predict``: Forward and return the predictions, one
        ``(video_name, video_result)`` tuple per video.
        - ``loss``: Forward and return a dict of losses according to the given
        inputs and data samples.

        Note that this method handles neither back propagation nor
        optimizer updating, which are done in the :meth:`train_step`.

        Args:
            inputs (Tensor or Sequence[Tensor]): The input tensor with shape
                (N, C, ...) in general, or a sequence of per-sample tensors
                that is stacked into one.
            data_samples (List[:obj:`ActionDataSample`]): The annotation
                data of every sample.
            mode (str): Return what kind of value. One of ``tensor``,
                ``predict`` and ``loss``.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor or a tuple of tensor.
            - If ``mode="predict"``, return a list of
              ``(video_name, video_result)`` tuples.
            - If ``mode="loss"``, return a dict of tensor.
        """
        if not isinstance(inputs, torch.Tensor):
            inputs = torch.stack(inputs)

        if mode == 'tensor':
            return self._forward(inputs, **kwargs)
        if mode == 'predict':
            return self.predict(inputs, data_samples, **kwargs)
        elif mode == 'loss':
            return self.loss(inputs, data_samples, **kwargs)
        else:
            raise RuntimeError(f'Invalid mode "{mode}". '
                               'Only supports loss, predict and tensor mode')
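

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): how the anchors and
# the action/start/end match scores of ``TEM.generate_labels`` can be derived
# for a single video. ``_sketch_iop`` and ``_sketch_tem_labels`` are
# hypothetical names. ``_sketch_iop`` is a stand-in for ``temporal_iop`` from
# ``.utils`` and is assumed here to return the overlap divided by the anchor
# length; the real helper may differ.
# ---------------------------------------------------------------------------
import numpy as np


def _sketch_iop(anchor_min, anchor_max, gt_mins, gt_maxs):
    """Assumed IoP: overlap length divided by the anchor length."""
    inter = np.maximum(
        np.minimum(anchor_max, gt_maxs) - np.maximum(anchor_min, gt_mins), 0.)
    return inter / (anchor_max - anchor_min)


def _sketch_tem_labels(gt_segments, temporal_dim, boundary_ratio):
    """Return the action, start and end match scores of one video."""
    gt = np.asarray(gt_segments, dtype=float)
    gt_tmins, gt_tmaxs = gt[:, 0], gt[:, 1]
    # Boundary regions are centred on each start/end point and are at least
    # one anchor wide.
    pad = np.maximum(1. / temporal_dim,
                     boundary_ratio * (gt_tmaxs - gt_tmins))
    regions = {
        'action': (gt_tmins, gt_tmaxs),
        'start': (gt_tmins - pad / 2, gt_tmins + pad / 2),
        'end': (gt_tmaxs - pad / 2, gt_tmaxs + pad / 2),
    }
    gap = 1. / temporal_dim
    scores = {name: [] for name in regions}
    for i in range(temporal_dim):
        # Anchor i covers [i * gap, (i + 1) * gap]; keep its best match.
        for name, (lo, hi) in regions.items():
            scores[name].append(
                float(np.max(_sketch_iop(gap * i, gap * (i + 1), lo, hi))))
    return scores['action'], scores['start'], scores['end']


# Usage: one segment [0.25, 0.75] with 4 anchors and ``boundary_ratio=0.1``.
# action, start, end = _sketch_tem_labels([[0.25, 0.75]], 4, 0.1)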


@MODELS.register_module()
class PEM(BaseModel):
    """Proposals Evaluation Model for Boundary Sensitive Network.

    Please refer to `BSN: Boundary Sensitive Network for Temporal Action
    Proposal Generation <http://arxiv.org/abs/1806.02964>`_.

    Code reference
    https://github.com/wzmsltw/BSN-boundary-sensitive-network

    Args:
        pem_feat_dim (int): Feature dimension.
        pem_hidden_dim (int): Hidden layer dimension.
        pem_u_ratio_m (float): Ratio for medium score proposals to balance
            data.
        pem_u_ratio_l (float): Ratio for low score proposals to balance data.
        pem_high_temporal_iou_threshold (float): High IoU threshold.
        pem_low_temporal_iou_threshold (float): Low IoU threshold.
        soft_nms_alpha (float): Soft NMS alpha.
        soft_nms_low_threshold (float): Soft NMS low threshold.
        soft_nms_high_threshold (float): Soft NMS high threshold.
        post_process_top_k (int): Top k proposals in post process.
        feature_extraction_interval (int):
            Interval used in feature extraction. Default: 16.
        fc1_ratio (float): Ratio for fc1 layer output. Default: 0.1.
        fc2_ratio (float): Ratio for fc2 layer output. Default: 0.1.
        output_dim (int): Output dimension. Default: 1.
    """

    def __init__(self,
                 pem_feat_dim: int,
                 pem_hidden_dim: int,
                 pem_u_ratio_m: float,
                 pem_u_ratio_l: float,
                 pem_high_temporal_iou_threshold: float,
                 pem_low_temporal_iou_threshold: float,
                 soft_nms_alpha: float,
                 soft_nms_low_threshold: float,
                 soft_nms_high_threshold: float,
                 post_process_top_k: int,
                 feature_extraction_interval: int = 16,
                 fc1_ratio: float = 0.1,
                 fc2_ratio: float = 0.1,
                 output_dim: int = 1):
        super().__init__()

        self.feat_dim = pem_feat_dim
        self.hidden_dim = pem_hidden_dim
        self.u_ratio_m = pem_u_ratio_m
        self.u_ratio_l = pem_u_ratio_l
        self.pem_high_temporal_iou_threshold = pem_high_temporal_iou_threshold
        self.pem_low_temporal_iou_threshold = pem_low_temporal_iou_threshold
        self.soft_nms_alpha = soft_nms_alpha
        self.soft_nms_low_threshold = soft_nms_low_threshold
        self.soft_nms_high_threshold = soft_nms_high_threshold
        self.post_process_top_k = post_process_top_k
        self.feature_extraction_interval = feature_extraction_interval
        self.fc1_ratio = fc1_ratio
        self.fc2_ratio = fc2_ratio
        self.output_dim = output_dim

        self.fc1 = nn.Linear(
            in_features=self.feat_dim, out_features=self.hidden_dim, bias=True)
        self.fc2 = nn.Linear(
            in_features=self.hidden_dim,
            out_features=self.output_dim,
            bias=True)

    def init_weights(self) -> None:
        """Initialize the parameters either from an existing checkpoint or
        from scratch."""
        # NOTE: this model only defines ``nn.Linear`` layers, so neither
        # branch below matches them.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                kaiming_init(m)
            elif isinstance(m, nn.BatchNorm2d):
                constant_init(m, 1)

    def _forward(self, x):
        """Define the computation performed at every call.

        Args:
            x (torch.Tensor): The input data.

        Returns:
            torch.Tensor: The output of the module.
        """
        x = F.relu(self.fc1_ratio * self.fc1(x))
        x = torch.sigmoid(self.fc2_ratio * self.fc2(x))
        return x

    def loss(self, batch_inputs, batch_data_samples, **kwargs):
        """Calculate losses from a batch of inputs and data samples.

        Args:
            batch_inputs (Tensor): Not used directly; the features are read
                from ``batch_data_samples``.
            batch_data_samples (List[:obj:`ActionDataSample`]): The batch
                data samples. Each sample provides ``bsp_feature`` and
                ``reference_temporal_iou`` in its ``gt_instances``.

        Returns:
            dict: A dictionary of loss components.
        """
        device = self.fc1.weight.device

        bsp_feature = torch.cat([
            sample.gt_instances['bsp_feature'] for sample in batch_data_samples
        ]).to(device)

        reference_temporal_iou = torch.cat([
            sample.gt_instances['reference_temporal_iou']
            for sample in batch_data_samples
        ]).to(device)

        pem_output = self._forward(bsp_feature)

        anchors_temporal_iou = pem_output.view(-1)
        u_hmask = (reference_temporal_iou >
                   self.pem_high_temporal_iou_threshold).float()
        u_mmask = (
            (reference_temporal_iou <= self.pem_high_temporal_iou_threshold)
            & (reference_temporal_iou > self.pem_low_temporal_iou_threshold)
        ).float()
        u_lmask = (reference_temporal_iou <=
                   self.pem_low_temporal_iou_threshold).float()

        num_h = torch.sum(u_hmask)
        num_m = torch.sum(u_mmask)
        num_l = torch.sum(u_lmask)

        # Subsample medium-IoU proposals: keep each one with probability r_m
        # (capped at 1) to balance them against the high-IoU proposals.
        r_m = self.u_ratio_m * num_h / (num_m)
        r_m = torch.clamp(r_m, max=1.0)
        u_smmask = torch.rand(u_hmask.size()[0], device=device)
        u_smmask = u_smmask * u_mmask
        u_smmask = (u_smmask > (1. - r_m)).float()

        # Subsample low-IoU proposals in the same way with probability r_l.
        r_l = self.u_ratio_l * num_h / (num_l)
        r_l = torch.clamp(r_l, max=1.0)
        u_slmask = torch.rand(u_hmask.size()[0], device=device)
        u_slmask = u_slmask * u_lmask
        u_slmask = (u_slmask > (1. - r_l)).float()

        temporal_iou_weights = u_hmask + u_smmask + u_slmask
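        # Keep the per-proposal losses (``reduction='none'``) so that the
        # sampling weights select which proposals contribute; with the
        # default mean reduction the weights would cancel out.
        temporal_iou_loss = F.smooth_l1_loss(
            anchors_temporal_iou, reference_temporal_iou, reduction='none')
        temporal_iou_loss = torch.sum(
            temporal_iou_loss *
            temporal_iou_weights) / torch.sum(temporal_iou_weights)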
        loss_dict = dict(temporal_iou_loss=temporal_iou_loss)

        return loss_dict

    def _parse(self, gt_instances, key):
        """Gather ``key`` from all samples as a numpy column vector."""
        out = torch.cat([gt[key] for gt in gt_instances])
        out = out.view(-1).cpu().numpy().reshape(-1, 1)
        return out

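    def predict(self, batch_inputs, batch_data_samples, **kwargs):
        """Score the proposals and post-process them when testing.

        Only the meta information of the first data sample is used, so all
        proposals in the batch are treated as belonging to that video.
        """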
        device = self.fc1.weight.device

        bsp_feature = torch.cat([
            sample.gt_instances['bsp_feature'] for sample in batch_data_samples
        ]).to(device)

        pem_output = self._forward(bsp_feature).view(-1).cpu().numpy()
        pem_output = pem_output.reshape(-1, 1)

        gt_instances = [sample.gt_instances for sample in batch_data_samples]

        tmin = self._parse(gt_instances, 'tmin')
        tmax = self._parse(gt_instances, 'tmax')
        tmin_score = self._parse(gt_instances, 'tmin_score')
        tmax_score = self._parse(gt_instances, 'tmax_score')

        score = np.array(pem_output * tmin_score * tmax_score).reshape(-1, 1)
        result = np.concatenate(
            (tmin, tmax, tmin_score, tmax_score, pem_output, score), axis=1)
        result = result.reshape(-1, 6)

        video_info = batch_data_samples[0].metainfo
        proposal_list = post_processing(result, video_info,
                                        self.soft_nms_alpha,
                                        self.soft_nms_low_threshold,
                                        self.soft_nms_high_threshold,
                                        self.post_process_top_k,
                                        self.feature_extraction_interval)
        output = [
            dict(
                video_name=video_info['video_name'],
                proposal_list=proposal_list)
        ]
        return output

    def forward(self, inputs, data_samples, mode, **kwargs):
        """The unified entry for a forward process in both training and test.

        The method should accept three modes:

        - ``tensor``: Forward the whole network and return tensor or tuple of
        tensor without any post-processing, same as a common nn.Module.
        - ``predict``: Forward and return the predictions, a list of dicts
        with ``video_name`` and ``proposal_list``.
        - ``loss``: Forward and return a dict of losses according to the given
        inputs and data samples.

        Note that this method handles neither back propagation nor
        optimizer updating, which are done in the :meth:`train_step`.

        Args:
            inputs (Sequence[Tensor]): The per-sample input tensors, which
                are stacked into one tensor with shape (N, C, ...) in
                general.
            data_samples (List[:obj:`ActionDataSample`]): The annotation
                data of every sample.
            mode (str): Return what kind of value. One of ``tensor``,
                ``predict`` and ``loss``.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor or a tuple of tensor.
            - If ``mode="predict"``, return a list of dicts with
              ``video_name`` and ``proposal_list``.
            - If ``mode="loss"``, return a dict of tensor.
        """
        inputs = torch.stack(inputs)
        if mode == 'tensor':
            return self._forward(inputs, **kwargs)
        if mode == 'predict':
            return self.predict(inputs, data_samples, **kwargs)
        elif mode == 'loss':
            return self.loss(inputs, data_samples, **kwargs)
        else:
            raise RuntimeError(f'Invalid mode "{mode}". '
                               'Only supports loss, predict and tensor mode')
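

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): the idea behind the
# weighted IoU regression loss in ``PEM.loss``. ``_sketch_iou_masks`` and
# ``_sketch_weighted_iou_loss`` are hypothetical helper names. The random
# subsampling of medium and low proposals is left out so that the example is
# deterministic, and a per-proposal (unreduced) loss is assumed.
# ---------------------------------------------------------------------------
import torch
import torch.nn.functional as F


def _sketch_iou_masks(reference_iou, high_thr, low_thr):
    """Split proposals into high / medium / low IoU groups (float masks)."""
    high = (reference_iou > high_thr).float()
    medium = ((reference_iou <= high_thr) & (reference_iou > low_thr)).float()
    low = (reference_iou <= low_thr).float()
    return high, medium, low


def _sketch_weighted_iou_loss(pred_iou, reference_iou, weights):
    """Average the per-proposal smooth L1 loss over the selected proposals."""
    # ``reduction='none'`` keeps one loss value per proposal so the weights
    # can mask proposals in or out before averaging.
    per_proposal = F.smooth_l1_loss(pred_iou, reference_iou, reduction='none')
    return torch.sum(per_proposal * weights) / torch.sum(weights)


# Usage: only the high-IoU proposal contributes when its mask is the weight.
# ref = torch.tensor([0.9, 0.5, 0.1])
# high, medium, low = _sketch_iou_masks(ref, high_thr=0.7, low_thr=0.3)
# loss = _sketch_weighted_iou_loss(torch.tensor([0.8, 0.5, 0.9]), ref, high)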