MultivexAI commited on
Commit
0f5b4cb
·
verified ·
1 Parent(s): 51a15c5

Upload 2 files

Browse files
Files changed (2) hide show
  1. model.pt +3 -0
  2. model.py +127 -0
model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0740d68419d8bcf879803726364ac1a49b7c42fa6db171444828f85a74dd98ab
3
+ size 1837731
model.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class InputPreparer(nn.Module):
6
+ def __init__(self):
7
+ super().__init__()
8
+ # smoothing and diff filters
9
+ matrix_a = torch.tensor([[1., 2., 1.],
10
+ [2., 4., 2.],
11
+ [1., 2., 1.]], dtype=torch.float32) / 16.0
12
+ self.register_buffer('filter_pattern_a', matrix_a.view(1, 1, 3, 3))
13
+
14
+ matrix_b = torch.tensor([[-1., 0., 1.],[-2., 0., 2.],[-1., 0., 1.]], dtype=torch.float32).view(1, 1, 3, 3)
15
+ matrix_c = torch.tensor([[-1., -2., -1.],
16
+ [ 0., 0., 0.],
17
+ [ 1., 2., 1.]], dtype=torch.float32).view(1, 1, 3, 3)
18
+ self.register_buffer('filter_pattern_b', matrix_b)
19
+ self.register_buffer('filter_pattern_c',matrix_c)
20
+
21
+ self.gating_network = nn.Sequential(
22
+ nn.AdaptiveAvgPool2d(1),
23
+ nn.Conv2d(2,2, kernel_size=1),
24
+ nn.Sigmoid()
25
+ )
26
+ self.mapping_conv = nn.Conv2d(2, 32, kernel_size=3, padding=1, bias=False)
27
+ self.normalization = nn.BatchNorm2d(32)
28
+
29
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
30
+ filtered_input = F.conv2d(x, self.filter_pattern_a, padding=1)
31
+ response_b = F.conv2d(filtered_input, self.filter_pattern_b, padding=1)
32
+ response_c = F.conv2d(filtered_input, self.filter_pattern_c, padding=1)
33
+ combined_response = torch.sqrt(response_b**2 + response_c**2+1e-5)
34
+
35
+ integrated_features = torch.cat([x, combined_response], dim=1)
36
+ modulated_features = integrated_features * self.gating_network(integrated_features)
37
+ return F.silu(self.normalization(self.mapping_conv(modulated_features)))
38
+
39
+
40
+ class MagnitudeScaler(nn.Module):
41
+ def __init__(self, kernel_size=2, stride=2, padding=0):
42
+ super().__init__()
43
+ self.kernel_size = kernel_size
44
+ self.stride = stride
45
+ self.padding = padding
46
+
47
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
48
+ squared_values = torch.clamp(x, min=0.0)**2
49
+ aggregated_values = F.avg_pool2d(squared_values, self.kernel_size, self.stride, self.padding)
50
+ return torch.sqrt(aggregated_values + 1e-5)
51
+
52
+
53
+ class FeatureWeighting(nn.Module):
54
+ def __init__(self, kernel_size: int = 7):
55
+ super().__init__()
56
+ self.spatial_weighting = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
57
+ self.activation = nn.Sigmoid()
58
+
59
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
60
+ mean_projection = torch.mean(x, dim=1, keepdim=True)
61
+ max_projection, _ = torch.max(x, dim=1, keepdim=True)
62
+ combined_projection = torch.cat([mean_projection, max_projection], dim=1)
63
+ return x * self.activation(self.spatial_weighting(combined_projection))
64
+
65
+
66
+ class ProcessingBlock(nn.Module):
67
+ def __init__(self, in_c: int, out_c: int, drop: float = 0.1) -> None:
68
+ super().__init__()
69
+ self.core_conv = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1, bias=False)
70
+ self.core_norm = nn.BatchNorm2d(out_c)
71
+ self.refinement = FeatureWeighting()
72
+ self.nonlinearity = nn.SiLU()
73
+ self.regularization = nn.Dropout2d(p=drop)
74
+
75
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
76
+ out = self.nonlinearity(self.core_norm(self.core_conv(x)))
77
+ out = self.regularization(out)
78
+ return self.refinement(out)
79
+
80
+
81
+ class HierarchicalNetwork(nn.Module):
82
+ def __init__(self, out_dims: int = 11):
83
+ super().__init__()
84
+ self.pre_processor = InputPreparer()
85
+
86
+ self.stage_a = ProcessingBlock(32, 64, drop=0.1)
87
+ self.downsampler_a = MagnitudeScaler(kernel_size=2, stride=2)
88
+
89
+ self.stage_b = ProcessingBlock(64, 128, drop=0.1)
90
+ self.downsampler_b = MagnitudeScaler(kernel_size=2, stride=2)
91
+
92
+ self.stage_c = ProcessingBlock(128, 256, drop=0.1)
93
+ self.global_reducer_a = nn.AdaptiveAvgPool2d(1)
94
+ self.global_reducer_b = nn.AdaptiveMaxPool2d(1)
95
+
96
+ self.decision_network = nn.Sequential(
97
+ nn.Linear(256 * 2, 128),
98
+ nn.SiLU(),
99
+ nn.Dropout(0.2),
100
+ nn.Linear(128, out_dims)
101
+ )
102
+ self._reset_parameters()
103
+
104
+
105
+ def _reset_parameters(self):
106
+ for m in self.modules():
107
+ if isinstance(m, nn.Conv2d):
108
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
109
+ if m.bias is not None:
110
+ nn.init.zeros_(m.bias)
111
+ elif isinstance(m, nn.BatchNorm2d):
112
+ nn.init.ones_(m.weight)
113
+ nn.init.zeros_(m.bias)
114
+ elif isinstance(m, nn.Linear):
115
+ nn.init.normal_(m.weight, 0, 0.01)
116
+ nn.init.zeros_(m.bias)
117
+
118
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
119
+ x = self.pre_processor(x)
120
+ x = self.downsampler_a(self.stage_a(x))
121
+ x = self.downsampler_b(self.stage_b(x))
122
+ x = self.stage_c(x)
123
+
124
+ reduced_a = self.global_reducer_a(x).view(x.size(0), -1)
125
+ reduced_b = self.global_reducer_b(x).view(x.size(0), -1)
126
+
127
+ return self.decision_network(torch.cat([reduced_a, reduced_b], dim=1))