MultivexAI commited on
Commit
ac9a5d2
·
verified ·
1 Parent(s): 764be8c

Delete model.py

Browse files
Files changed (1) hide show
  1. model.py +0 -127
model.py DELETED
@@ -1,127 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
-
5
- class InputPreparer(nn.Module):
6
- def __init__(self):
7
- super().__init__()
8
- # smoothing and diff filters
9
- matrix_a = torch.tensor([[1., 2., 1.],
10
- [2., 4., 2.],
11
- [1., 2., 1.]], dtype=torch.float32) / 16.0
12
- self.register_buffer('filter_pattern_a', matrix_a.view(1, 1, 3, 3))
13
-
14
- matrix_b = torch.tensor([[-1., 0., 1.],[-2., 0., 2.],[-1., 0., 1.]], dtype=torch.float32).view(1, 1, 3, 3)
15
- matrix_c = torch.tensor([[-1., -2., -1.],
16
- [ 0., 0., 0.],
17
- [ 1., 2., 1.]], dtype=torch.float32).view(1, 1, 3, 3)
18
- self.register_buffer('filter_pattern_b', matrix_b)
19
- self.register_buffer('filter_pattern_c',matrix_c)
20
-
21
- self.gating_network = nn.Sequential(
22
- nn.AdaptiveAvgPool2d(1),
23
- nn.Conv2d(2,2, kernel_size=1),
24
- nn.Sigmoid()
25
- )
26
- self.mapping_conv = nn.Conv2d(2, 32, kernel_size=3, padding=1, bias=False)
27
- self.normalization = nn.BatchNorm2d(32)
28
-
29
- def forward(self, x: torch.Tensor) -> torch.Tensor:
30
- filtered_input = F.conv2d(x, self.filter_pattern_a, padding=1)
31
- response_b = F.conv2d(filtered_input, self.filter_pattern_b, padding=1)
32
- response_c = F.conv2d(filtered_input, self.filter_pattern_c, padding=1)
33
- combined_response = torch.sqrt(response_b**2 + response_c**2+1e-5)
34
-
35
- integrated_features = torch.cat([x, combined_response], dim=1)
36
- modulated_features = integrated_features * self.gating_network(integrated_features)
37
- return F.silu(self.normalization(self.mapping_conv(modulated_features)))
38
-
39
-
40
- class MagnitudeScaler(nn.Module):
41
- def __init__(self, kernel_size=2, stride=2, padding=0):
42
- super().__init__()
43
- self.kernel_size = kernel_size
44
- self.stride = stride
45
- self.padding = padding
46
-
47
- def forward(self, x: torch.Tensor) -> torch.Tensor:
48
- squared_values = torch.clamp(x, min=0.0)**2
49
- aggregated_values = F.avg_pool2d(squared_values, self.kernel_size, self.stride, self.padding)
50
- return torch.sqrt(aggregated_values + 1e-5)
51
-
52
-
53
- class FeatureWeighting(nn.Module):
54
- def __init__(self, kernel_size: int = 7):
55
- super().__init__()
56
- self.spatial_weighting = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
57
- self.activation = nn.Sigmoid()
58
-
59
- def forward(self, x: torch.Tensor) -> torch.Tensor:
60
- mean_projection = torch.mean(x, dim=1, keepdim=True)
61
- max_projection, _ = torch.max(x, dim=1, keepdim=True)
62
- combined_projection = torch.cat([mean_projection, max_projection], dim=1)
63
- return x * self.activation(self.spatial_weighting(combined_projection))
64
-
65
-
66
- class ProcessingBlock(nn.Module):
67
- def __init__(self, in_c: int, out_c: int, drop: float = 0.1) -> None:
68
- super().__init__()
69
- self.core_conv = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1, bias=False)
70
- self.core_norm = nn.BatchNorm2d(out_c)
71
- self.refinement = FeatureWeighting()
72
- self.nonlinearity = nn.SiLU()
73
- self.regularization = nn.Dropout2d(p=drop)
74
-
75
- def forward(self, x: torch.Tensor) -> torch.Tensor:
76
- out = self.nonlinearity(self.core_norm(self.core_conv(x)))
77
- out = self.regularization(out)
78
- return self.refinement(out)
79
-
80
-
81
- class HierarchicalNetwork(nn.Module):
82
- def __init__(self, out_dims: int = 11):
83
- super().__init__()
84
- self.pre_processor = InputPreparer()
85
-
86
- self.stage_a = ProcessingBlock(32, 64, drop=0.1)
87
- self.downsampler_a = MagnitudeScaler(kernel_size=2, stride=2)
88
-
89
- self.stage_b = ProcessingBlock(64, 128, drop=0.1)
90
- self.downsampler_b = MagnitudeScaler(kernel_size=2, stride=2)
91
-
92
- self.stage_c = ProcessingBlock(128, 256, drop=0.1)
93
- self.global_reducer_a = nn.AdaptiveAvgPool2d(1)
94
- self.global_reducer_b = nn.AdaptiveMaxPool2d(1)
95
-
96
- self.decision_network = nn.Sequential(
97
- nn.Linear(256 * 2, 128),
98
- nn.SiLU(),
99
- nn.Dropout(0.2),
100
- nn.Linear(128, out_dims)
101
- )
102
- self._reset_parameters()
103
-
104
-
105
- def _reset_parameters(self):
106
- for m in self.modules():
107
- if isinstance(m, nn.Conv2d):
108
- nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
109
- if m.bias is not None:
110
- nn.init.zeros_(m.bias)
111
- elif isinstance(m, nn.BatchNorm2d):
112
- nn.init.ones_(m.weight)
113
- nn.init.zeros_(m.bias)
114
- elif isinstance(m, nn.Linear):
115
- nn.init.normal_(m.weight, 0, 0.01)
116
- nn.init.zeros_(m.bias)
117
-
118
- def forward(self, x: torch.Tensor) -> torch.Tensor:
119
- x = self.pre_processor(x)
120
- x = self.downsampler_a(self.stage_a(x))
121
- x = self.downsampler_b(self.stage_b(x))
122
- x = self.stage_c(x)
123
-
124
- reduced_a = self.global_reducer_a(x).view(x.size(0), -1)
125
- reduced_b = self.global_reducer_b(x).view(x.size(0), -1)
126
-
127
- return self.decision_network(torch.cat([reduced_a, reduced_b], dim=1))