File size: 4,102 Bytes
1c4c77a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
""" "This module contains an implementation of C3D model for video

processing."""

import itertools

import torch
from torch import Tensor, nn


class C3D(nn.Module):
    """The C3D network, truncated after fc6.

    3D-convolutional feature extractor following the C3D architecture
    (conv1..conv5b + fc6). For a clip batch of shape
    (batch, 3, 16, 112, 112) it produces a (batch, 4096) feature tensor.

    Args:
        pretrained: Optional path to a C3D checkpoint. When truthy, the
            matching conv/fc6 weights are loaded; fc7/fc8 entries in the
            checkpoint are skipped (this model stops at fc6).
    """

    def __init__(self, pretrained=None):
        super().__init__()

        self.pretrained = pretrained

        self.conv1 = nn.Conv3d(3, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        # pool1 keeps the temporal extent (stride 1 in time).
        self.pool1 = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))

        self.conv2 = nn.Conv3d(64, 128, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.pool2 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))

        self.conv3a = nn.Conv3d(128, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.conv3b = nn.Conv3d(256, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.pool3 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))

        self.conv4a = nn.Conv3d(256, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.conv4b = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.pool4 = nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2))

        self.conv5a = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.conv5b = nn.Conv3d(512, 512, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        # Spatial padding so that, for 112x112 inputs, the 7x7 map pools to
        # 4x4 and the flattened size is 512 * 1 * 4 * 4 = 8192 (fc6's input).
        self.pool5 = nn.MaxPool3d(
            kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 1, 1)
        )

        self.fc6 = nn.Linear(8192, 4096)
        self.relu = nn.ReLU()
        self.__init_weight()

        if pretrained:
            self.__load_pretrained_weights()

    def forward(self, x: Tensor) -> Tensor:
        """Extract fc6 features for a batch of clips.

        Args:
            x: Clip batch; expected shape (batch, 3, 16, 112, 112) so the
                flattened conv features match fc6's 8192 inputs.

        Returns:
            Tensor of shape (batch, 4096).
        """
        x = self.relu(self.conv1(x))
        x = self.pool1(x)
        x = self.relu(self.conv2(x))
        x = self.pool2(x)
        x = self.relu(self.conv3a(x))
        x = self.relu(self.conv3b(x))
        x = self.pool3(x)
        x = self.relu(self.conv4a(x))
        x = self.relu(self.conv4b(x))
        x = self.pool4(x)
        x = self.relu(self.conv5a(x))
        x = self.relu(self.conv5b(x))
        x = self.pool5(x)
        # Flatten per sample, keeping the batch dimension intact.
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc6(x))

        return x

    def __load_pretrained_weights(self):
        """Load conv/fc6 weights from the checkpoint at ``self.pretrained``.

        fc7/fc8 entries are skipped silently (this model is truncated at
        fc6); any other unmatched key is reported and skipped.
        """
        layers = (
            "conv1", "conv2", "conv3a", "conv3b",
            "conv4a", "conv4b", "conv5a", "conv5b", "fc6",
        )
        # Sets give O(1) membership tests inside the loop below.
        corresp_name = {
            f"{layer}.{param}"
            for layer, param in itertools.product(layers, ("weight", "bias"))
        }
        ignored_weights = {
            f"{layer}.{type_}"
            for layer, type_ in itertools.product(["fc7", "fc8"], ["bias", "weight"])
        }

        # map_location lets a GPU-saved checkpoint load on CPU-only machines;
        # load_state_dict moves the tensors onto the model's device afterwards.
        p_dict = torch.load(self.pretrained, map_location="cpu")
        s_dict = self.state_dict()
        for name in p_dict:
            if name not in corresp_name:
                if name in ignored_weights:
                    continue
                print("no corresponding::", name)
                continue
            s_dict[name] = p_dict[name]
        self.load_state_dict(s_dict)

    def __init_weight(self):
        """Initialize weights: Kaiming-normal convs, unit/zero batchnorm."""
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm3d):
                # No BatchNorm3d layers are registered above; this branch
                # only matters for subclasses that add them.
                m.weight.data.fill_(1)
                m.bias.data.zero_()

if __name__ == "__main__":
    inputs = torch.ones((1, 3, 16, 112, 112))
    net = C3D(pretrained=False)

    outputs = net.forward(inputs)
    print(outputs.size())