File size: 2,805 Bytes
278bf2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import numpy as np
import torch.nn as nn
import torch
from model.modules import ResidualAttentionNet, conv_block


class AVRA_rnn(nn.Module):
    """CNN + LSTM scorer: a ResidualAttentionNet encodes each 2D slice of a
    sequence, a 2-layer LSTM aggregates the per-slice features over time, and
    a linear head maps the last time step's output to a single scalar.
    """

    def __init__(self, input_dims):
        """
        Args:
            input_dims: (H, W) spatial size of each single-channel input
                slice; used only to probe the CNN's flattened feature size.
        """
        super(AVRA_rnn, self).__init__()
        self.features = ResidualAttentionNet()
        # Probe with a dummy (1, 1, H, W) input to size the LSTM input layer.
        input_size = [1, input_dims[0], input_dims[1]]
        self.l = self.get_flat_fts(input_size, self.features)
        self.hs = 256  # LSTM hidden size
        self.rnn = nn.LSTM(
            input_size=self.l,
            hidden_size=self.hs,
            num_layers=2,
            batch_first=True,
            bidirectional=False
        )
        # A bidirectional LSTM emits 2*hidden_size features per step.
        f = 2 if self.rnn.bidirectional else 1
        self.linear = nn.Linear(self.hs * f, 1)

    def get_flat_fts(self, in_size, fts):
        """Return the flattened per-sample feature count of `fts` for one
        input of shape `in_size`, measured with a dummy forward pass.

        Fix: the original wrapped the probe in `torch.Tensor(torch.ones(...))`,
        which makes a redundant copy — `torch.ones` already returns the tensor
        we need. The probe also runs under no_grad: it is a size measurement
        at construction time and must not track gradients.
        """
        with torch.no_grad():
            f = fts(torch.ones(1, *in_size))
        return int(np.prod(f.size()[1:]))

    def forward(self, x, return_r_out=False):
        """Score a batch of slice sequences.

        Args:
            x: tensor of shape (batch, timesteps, C, H, W).
            return_r_out: if True, also return the flattened per-slice CNN
                features, shape (batch, timesteps * feature_len).

        Returns:
            (batch, 1) scores, optionally with the flattened CNN features.
        """
        batch_size, timesteps, C, H, W = x.size()
        # Fold time into the batch axis so the CNN sees plain 2D images.
        c_in = x.view(batch_size * timesteps, C, H, W)
        c_out = self.features(c_in)
        # Restore the time axis for the LSTM: (batch, timesteps, features).
        r_in = c_out.view(batch_size, timesteps, -1)
        # nn.LSTM's second output is (h_n, c_n): final hidden and cell state
        # (the original mislabeled the cell state as `h_c`).
        r_out, (h_n, c_n) = self.rnn(r_in)
        # Only the last time step's output feeds the regression head.
        r_out_last = self.linear(r_out[:, -1, :])
        if return_r_out:
            return r_out_last, r_in.view(batch_size, -1)
        return r_out_last


class VGG_bl(nn.Module):
    """VGG-style 2D CNN baseline: five conv blocks (each ending in a
    max-pool that halves the spatial dims) followed by a three-layer MLP
    regressing a single scalar.
    """

    def __init__(self, input_dims):
        """
        Args:
            input_dims: (x, y, z) = input height, width, and channel count.
        """
        super().__init__()
        x, y, z = input_dims
        self.num_filters = [64, 128, 256, 512, 512]
        self.convxd = nn.Conv2d
        self.pooling = nn.MaxPool2d
        self.norm = nn.BatchNorm2d
        self.relu = nn.LeakyReLU
        # First two stages are shallow (False), last three deep (True),
        # mirroring VGG's 2-conv vs. 3-conv stages.
        self.features = nn.Sequential(
            conv_block(z, self.num_filters[0], False, self.convxd, self.norm, self.pooling, relu=self.relu),
            conv_block(self.num_filters[0], self.num_filters[1], False, self.convxd, self.norm, self.pooling, relu=self.relu),
            conv_block(self.num_filters[1], self.num_filters[2], True, self.convxd, self.norm, self.pooling, relu=self.relu),
            conv_block(self.num_filters[2], self.num_filters[3], True, self.convxd, self.norm, self.pooling, relu=self.relu),
            conv_block(self.num_filters[3], self.num_filters[4], True, self.convxd, self.norm, self.pooling, relu=self.relu),
        )
        # Each of the len(num_filters) pooling stages halves H and W, so the
        # flattened size is (x // 2^5) * (y // 2^5) * last_channel_count.
        # Fixes: plain len() instead of np.shape(...)[0] on a Python list,
        # the downscale factor hoisted out (it was computed twice), and the
        # redundant int() dropped — integer floor division already yields int.
        downscale = 2 ** len(self.num_filters)
        a = (x // downscale) * (y // downscale) * self.num_filters[-1]
        N = 4096  # classic VGG fully-connected width
        self.fc1 = nn.Sequential(
            nn.Linear(a, N),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(N, N),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(N, 1))

    def forward(self, x):
        """Run the conv stack, flatten per sample, and regress to (batch, 1)."""
        x = self.features(x)
        x = x.view(x.size()[0], -1)
        x = self.fc1(x)
        return x