File size: 5,777 Bytes
f2688f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import torch

class CNN2D(torch.nn.Module):
    """Plain stacked-convolution 2D classifier.

    Architecture: N conv blocks (each a stack of Conv2d -> BatchNorm2d -> ReLU),
    a 2x2 max pool after each block (while pooling layers remain), global
    average pooling, and a final linear classification head.

    Args:
        channels: list of lists; ``channels[i]`` holds the output channel count
            of each stacked conv layer inside block ``i``.
        conv_kernels: per-block Conv2d kernel size.
        conv_strides: per-block Conv2d stride.
        conv_padding: per-block Conv2d padding.
        pool_padding: padding for each 2x2 MaxPool2d placed after a block.
        num_classes: size of the output logit vector (default 15).

    Raises:
        ValueError: if the per-block configuration lists differ in length.
    """

    def __init__(self, channels, conv_kernels, conv_strides, conv_padding, pool_padding, num_classes=15):
        # Explicit exception instead of `assert`: asserts are stripped under -O.
        if not (len(conv_kernels) == len(channels) == len(conv_strides) == len(conv_padding)):
            raise ValueError(
                "channels, conv_kernels, conv_strides and conv_padding must all have the same length"
            )
        super().__init__()

        # Create convolutional blocks
        self.conv_blocks = torch.nn.ModuleList()
        prev_channel = 1  # Assuming the input is a grayscale image, modify if using more channels

        for i in range(len(channels)):
            # Add stacked conv layers: Conv2d -> BatchNorm2d -> ReLU per entry.
            block = []
            for conv_channel in channels[i]:
                block.append(torch.nn.Conv2d(
                    in_channels=prev_channel,
                    out_channels=conv_channel,
                    kernel_size=conv_kernels[i],
                    stride=conv_strides[i],
                    padding=conv_padding[i],
                ))
                prev_channel = conv_channel
                # Batch normalization then ReLU activation after every conv.
                block.append(torch.nn.BatchNorm2d(prev_channel))
                block.append(torch.nn.ReLU())
            self.conv_blocks.append(torch.nn.Sequential(*block))

        # Create pooling blocks (2x2 max pool halves each spatial dimension).
        self.pool_blocks = torch.nn.ModuleList(
            torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=pad) for pad in pool_padding
        )

        # Global pooling collapses the remaining spatial dims to 1x1.
        self.global_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.linear = torch.nn.Linear(prev_channel, num_classes)

    def forward(self, inwav):
        """Return class logits of shape [batch_size, num_classes]."""
        for i, conv_block in enumerate(self.conv_blocks):
            # Apply convolutional block, then max pooling when one is available.
            inwav = conv_block(inwav)
            if i < len(self.pool_blocks):
                inwav = self.pool_blocks[i](inwav)
        # Apply global pooling: [B, C, H, W] -> [B, C, 1, 1] -> [B, C].
        # BUG FIX: the original `.squeeze()` also removed the batch dimension
        # when batch_size == 1, breaking the linear layer; flatten(1) keeps
        # [B, C] for every batch size.
        out = torch.flatten(self.global_pool(inwav), 1)
        return self.linear(out)  # [batch_size, num_classes]
    
class ResBlock2D(torch.nn.Module):
    """Two-layer 2D residual block: Conv-BN-ReLU-Conv-BN plus a skip
    connection, followed by a final BatchNorm + ReLU.

    If the branch output and the input have different channel counts, the two
    are reconciled by channel tiling/folding, which requires the larger count
    to be divisible by the smaller.

    Args:
        prev_channel: channel count of the incoming tensor.
        channel: channel count produced by the block.
        conv_kernel / conv_stride / conv_pad: settings for both Conv2d layers.

    Raises:
        RuntimeError: when the channel counts are not divisible.
    """

    def __init__(self, prev_channel, channel, conv_kernel, conv_stride, conv_pad):
        super().__init__()
        # Residual branch: two conv layers with BN, ReLU between them.
        self.res = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=prev_channel, out_channels=channel, kernel_size=conv_kernel, stride=conv_stride, padding=conv_pad),
            torch.nn.BatchNorm2d(channel),
            torch.nn.ReLU(),
            torch.nn.Conv2d(in_channels=channel, out_channels=channel, kernel_size=conv_kernel, stride=conv_stride, padding=conv_pad),
            torch.nn.BatchNorm2d(channel),
        )
        # Post-sum normalization and activation.
        self.bn = torch.nn.BatchNorm2d(channel)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        identity = x
        out = self.res(x)
        in_c, out_c = identity.shape[1], out.shape[1]
        if out_c == in_c:
            # Out-of-place add: never mutate the module's input tensor.
            out = out + identity
        elif out_c > in_c:
            if out_c % in_c != 0:
                raise RuntimeError("Dims in ResBlock needs to be divisible on the previous dims!!")
            # Tile the identity across channels to match the wider branch output.
            out = out + identity.repeat(1, out_c // in_c, 1, 1)
        else:
            if in_c % out_c != 0:
                raise RuntimeError("Dims in ResBlock needs to be divisible on the previous dims!!")
            # BUG FIX: the original did `identity += x.repeat(...)` here, which
            # (a) mutated the caller's input tensor in-place and (b) produced a
            # prev_channel-wide tensor that was then fed to BatchNorm2d(channel)
            # and crashed on the shape mismatch. Instead, fold the identity down
            # to `out_c` channels by summing channel groups (inverse of the
            # repeat used in the widening branch) and add it out-of-place.
            # NOTE(review): group-sum is an assumed projection for the broken
            # branch — confirm against training behavior.
            n, _, h, w = identity.shape
            folded = identity.view(n, in_c // out_c, out_c, h, w).sum(dim=1)
            out = out + folded
        return self.relu(self.bn(out))
    
class CNNRes2D(torch.nn.Module):
    """Residual 2D CNN classifier.

    Architecture: an initial Conv-BN-ReLU-MaxPool stem built from the first
    entry of each configuration list, then stacks of ``ResBlock2D`` (one stack
    per remaining entry) with an interleaved 2x2 max pool after each stack
    while pooling layers remain, global average pooling, and a linear head.

    Args:
        channels: list of lists; ``channels[0][0]`` sets the stem width,
            ``channels[i]`` (i >= 1) holds the output channels of each
            residual block in stack ``i``.
        conv_kernels / conv_strides / conv_padding: per-stack conv settings.
        pool_padding: padding for the stem pool (index 0) and each later pool.
        num_classes: size of the output logit vector (default 15).

    Raises:
        ValueError: if the per-stack configuration lists differ in length.
    """

    def __init__(self, channels, conv_kernels, conv_strides, conv_padding, pool_padding, num_classes=15):
        # Explicit exception instead of `assert`: asserts are stripped under -O.
        if not (len(conv_kernels) == len(channels) == len(conv_strides) == len(conv_padding)):
            raise ValueError(
                "channels, conv_kernels, conv_strides and conv_padding must all have the same length"
            )
        super().__init__()

        # Create initial convolutional block (the stem)
        prev_channel = 1  # Assuming input has 1 channel, modify if needed
        self.conv_block = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=prev_channel, out_channels=channels[0][0], kernel_size=conv_kernels[0], stride=conv_strides[0], padding=conv_padding[0]),
            torch.nn.BatchNorm2d(channels[0][0]),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=pool_padding[0]),
        )

        # Create residual stacks from the remaining configuration entries.
        prev_channel = channels[0][0]
        self.res_blocks = torch.nn.ModuleList()
        for i in range(1, len(channels)):
            block = []
            for conv_channel in channels[i]:
                block.append(ResBlock2D(prev_channel, conv_channel, conv_kernels[i], conv_strides[i], conv_padding[i]))
                prev_channel = conv_channel
            self.res_blocks.append(torch.nn.Sequential(*block))

        # Create pooling blocks (index 0 was consumed by the stem above).
        self.pool_blocks = torch.nn.ModuleList(
            torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=pad) for pad in pool_padding[1:]
        )

        # Global pooling collapses the remaining spatial dims to 1x1.
        self.global_pool = torch.nn.AdaptiveAvgPool2d((1, 1))
        self.linear = torch.nn.Linear(prev_channel, num_classes)

    def forward(self, inwav):
        """Return class logits of shape [batch_size, num_classes]."""
        inwav = self.conv_block(inwav)
        for i, res_block in enumerate(self.res_blocks):
            inwav = res_block(inwav)
            if i < len(self.pool_blocks):
                inwav = self.pool_blocks[i](inwav)
        # BUG FIX: the original `.squeeze()` also removed the batch dimension
        # when batch_size == 1, breaking the linear layer; flatten(1) keeps
        # [B, C] for every batch size.
        out = torch.flatten(self.global_pool(inwav), 1)
        return self.linear(out)


# # Example instantiation of the network
# cnn2d_res = CNNRes2D(
#     channels=[[128], [128]*2],
#     conv_kernels=[(3, 3), (3, 3)],
#     conv_strides=[(1, 1), (1, 1)],
#     conv_padding=[(1, 1), (1, 1)],
#     pool_padding=[(0, 0), (0, 0)]
# )