File size: 2,872 Bytes
d670799
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# Copyright (c) OpenMMLab. All rights reserved.
import platform
from unittest.mock import MagicMock

import pytest
import torch

from mmaction.registry import MODELS
from mmaction.structures import ActionDataSample
from mmaction.testing import get_similarity_cfg
from mmaction.utils import register_all_modules


@pytest.mark.skipif(platform.system() == 'Windows', reason='Windows mem limit')
def test_clip_similarity():
    """Smoke-test the CLIP4Clip similarity model built from its config.

    Covers three things in order:

    1. ``train_step`` returns the expected loss keys and triggers exactly
       one ``update_params`` call on the optimizer wrapper.
    2. ``test_step`` returns one prediction carrying 512-d video and text
       features.
    3. ``requires_grad`` flags on ``model.clip`` parameters match the
       ``frozen_layers`` setting for the values -1, 0, 6 and 12.
    """
    register_all_modules()
    cfg = get_similarity_cfg(
        'clip4clip/'
        'clip4clip_vit-base-p32-res224-clip-pre_8xb16-u12-5e_msrvtt-9k-rgb.py')
    cfg.model.frozen_layers = -1  # no frozen layers
    model = MODELS.build(cfg.model)
    model.train()

    # A single sample: a (2, 3, 224, 224) integer image tensor and a
    # 77-element integer text tensor, both filled with random values.
    data_batch = {
        'inputs': {
            'imgs': [torch.randint(0, 256, (2, 3, 224, 224))],
            'text': [torch.randint(0, 49408, (77, ))]
        },
        'data_samples': [ActionDataSample()]
    }

    # test train_step
    # The optimizer wrapper is mocked so that only the call into
    # ``update_params`` is observed.
    optim_wrapper = MagicMock()
    loss_vars = model.train_step(data_batch, optim_wrapper)
    for loss_name in ('loss', 'sim_loss_v2t', 'sim_loss_t2v'):
        assert loss_name in loss_vars
    optim_wrapper.update_params.assert_called_once()

    # test test_step
    with torch.no_grad():
        predictions = model.test_step(data_batch)
    # Check the length before indexing so that an empty result is reported
    # as an assertion failure instead of an IndexError.
    assert len(predictions) == 1
    features = predictions[0].features
    assert features.video_feature.size() == (512, )
    assert features.text_feature.size() == (512, )

    # test frozen layers
    def check_frozen_layers(mdl, frozen_layers):
        """Assert ``mdl.clip`` parameters are (un)frozen as expected.

        Args:
            mdl: Model exposing a ``clip`` sub-module.
            frozen_layers (int): Negative means every parameter must be
                trainable. Otherwise the top layers and the transformer
                blocks with index ``>= frozen_layers`` must be trainable,
                and every remaining parameter must be frozen.
        """
        if frozen_layers < 0:
            assert all(p.requires_grad for p in mdl.clip.parameters())
            return

        # Tuples so they can be passed straight to ``str.startswith``.
        top_layers = ('ln_final', 'text_projection', 'logit_scale',
                      'visual.ln_post', 'visual.proj')
        mid_layers = ('visual.transformer.resblocks',
                      'transformer.resblocks')

        for name, param in mdl.clip.named_parameters():
            if name.startswith(top_layers):
                expected = True
            elif name.startswith(mid_layers):
                # The block index is the path component that follows
                # '.resblocks.' in the parameter name.
                layer_n = int(name.split('.resblocks.')[1].split('.')[0])
                expected = layer_n >= frozen_layers
            else:
                expected = False
            assert param.requires_grad is expected

    # The model was built with frozen_layers == -1 and is already in
    # train mode.
    check_frozen_layers(model, -1)

    # NOTE(review): presumably ``train()`` re-applies the freezing policy
    # from ``frozen_layers``; it is re-invoked after every change.
    for num_frozen in (0, 6, 12):
        model.frozen_layers = num_frozen
        model.train()
        check_frozen_layers(model, num_frozen)