import torch
import torch.nn as nn
import torch.nn.functional as F


class InputPreparer(nn.Module):
    """Stem that augments the raw input with an edge-strength map and projects it to 32 channels.

    Pipeline:
      1. Smooth each input channel with a normalized 3x3 binomial kernel (weights sum to 1).
      2. Apply the horizontal and vertical 3x3 Sobel kernels to the smoothed map and
         combine them into a gradient magnitude ``sqrt(gx**2 + gy**2 + 1e-5)``.
      3. Concatenate the raw input with the magnitude map along the channel axis.
      4. Re-weight those channels with a gate: global average pool -> 1x1 conv -> sigmoid.
      5. Project to 32 channels with a 3x3 conv, BatchNorm, and SiLU.

    Args:
        in_channels: Number of input channels. The default of 1 reproduces the
            single-channel behavior exactly. Larger values apply the fixed filters
            to each channel independently (depthwise).

    Input:  4-D tensor ``(N, in_channels, H, W)``.
    Output: ``(N, 32, H, W)``. Every convolution is 3x3 with padding 1, so spatial
    size is preserved.
    """

    def __init__(self, in_channels: int = 1) -> None:
        super().__init__()
        if in_channels < 1:
            raise ValueError(f"in_channels must be >= 1, got {in_channels}")
        self.in_channels = in_channels

        # Fixed (non-trainable) filters are stored as buffers, so they move with
        # .to()/.cuda() and appear in state_dict. Their shapes stay (1, 1, 3, 3)
        # whatever in_channels is; they are replicated per channel at call time.
        smoothing = torch.tensor(
            [[1., 2., 1.], [2., 4., 2.], [1., 2., 1.]], dtype=torch.float32
        ) / 16.0
        sobel_x = torch.tensor(
            [[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]], dtype=torch.float32
        )
        sobel_y = torch.tensor(
            [[-1., -2., -1.], [0., 0., 0.], [1., 2., 1.]], dtype=torch.float32
        )
        self.register_buffer('filter_pattern_a', smoothing.view(1, 1, 3, 3))
        self.register_buffer('filter_pattern_b', sobel_x.view(1, 1, 3, 3))
        self.register_buffer('filter_pattern_c', sobel_y.view(1, 1, 3, 3))

        # Raw channels plus one gradient-magnitude channel per input channel.
        mixed_channels = 2 * in_channels
        self.gating_network = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(mixed_channels, mixed_channels, kernel_size=1),
            nn.Sigmoid(),
        )
        self.mapping_conv = nn.Conv2d(
            mixed_channels, 32, kernel_size=3, padding=1, bias=False
        )
        self.normalization = nn.BatchNorm2d(32)

    @staticmethod
    def _per_channel_filter(x: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
        """Apply one (1, 1, 3, 3) kernel to every channel of ``x`` independently."""
        channels = x.size(1)
        # groups == channels makes this a depthwise conv. With a single channel it is
        # the ordinary conv with the kernel unchanged. F.conv2d pads with zeros, so
        # border pixels are filtered against an implicit zero frame.
        return F.conv2d(x, kernel.repeat(channels, 1, 1, 1), padding=1, groups=channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        smoothed = self._per_channel_filter(x, self.filter_pattern_a)
        grad_x = self._per_channel_filter(smoothed, self.filter_pattern_b)
        grad_y = self._per_channel_filter(smoothed, self.filter_pattern_c)
        # The epsilon keeps sqrt (and its gradient) finite where both responses are zero.
        magnitude = torch.sqrt(grad_x ** 2 + grad_y ** 2 + 1e-5)

        integrated = torch.cat([x, magnitude], dim=1)
        modulated = integrated * self.gating_network(integrated)
        return F.silu(self.normalization(self.mapping_conv(modulated)))


class MagnitudeScaler(nn.Module):
    """Downsample by root-mean-square pooling of the positive part of the input.

    Computes ``sqrt(avg_pool2d(clamp(x, min=0) ** 2) + 1e-5)``. Negative activations
    are zeroed before pooling. The epsilon gives every output a floor of
    ``sqrt(1e-5)`` (about 0.0032) and keeps the sqrt gradient finite.

    Args:
        kernel_size: Pooling window, forwarded to ``F.avg_pool2d``.
        stride: Pooling stride, forwarded to ``F.avg_pool2d``.
        padding: Forwarded to ``F.avg_pool2d``. With ``padding > 0`` the padded zeros
            count toward each window's mean (``count_include_pad`` defaults to True).
    """

    def __init__(self, kernel_size=2, stride=2, padding=0):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        squared_values = torch.clamp(x, min=0.0) ** 2
        aggregated_values = F.avg_pool2d(
            squared_values, self.kernel_size, self.stride, self.padding
        )
        return torch.sqrt(aggregated_values + 1e-5)


class FeatureWeighting(nn.Module):
    """Re-weight every spatial position of ``x`` by a learned map in (0, 1).

    The map is built from the channel-wise mean and channel-wise max of ``x``,
    stacked as 2 channels, passed through a ``kernel_size`` x ``kernel_size`` conv
    (2 -> 1 channels, no bias) and a sigmoid. It is then broadcast-multiplied over
    all channels of ``x``, so the output has the same shape as the input.

    Args:
        kernel_size: Odd, positive size of the spatial conv. Padding is
            ``kernel_size // 2``, which preserves H and W only for odd sizes.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer. An even size
            would produce an (H+1, W+1) map that cannot broadcast against ``x``.
    """

    def __init__(self, kernel_size: int = 7):
        super().__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )
        self.spatial_weighting = nn.Conv2d(
            2, 1, kernel_size=kernel_size, padding=kernel_size // 2, bias=False
        )
        self.activation = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean_projection = torch.mean(x, dim=1, keepdim=True)
        max_projection, _ = torch.max(x, dim=1, keepdim=True)
        combined_projection = torch.cat([mean_projection, max_projection], dim=1)
        return x * self.activation(self.spatial_weighting(combined_projection))


class ProcessingBlock(nn.Module):
    """Conv block: 3x3 conv -> BatchNorm -> SiLU -> Dropout2d -> spatial re-weighting.

    The 3x3 conv uses padding 1, so H and W are preserved; channels go from
    ``in_c`` to ``out_c``.

    Args:
        in_c: Input channel count.
        out_c: Output channel count.
        drop: Channel-dropout probability for ``nn.Dropout2d`` (active only in
            training mode).
    """

    def __init__(self, in_c: int, out_c: int, drop: float = 0.1) -> None:
        super().__init__()
        self.core_conv = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1, bias=False)
        self.core_norm = nn.BatchNorm2d(out_c)
        self.refinement = FeatureWeighting()
        self.nonlinearity = nn.SiLU()
        self.regularization = nn.Dropout2d(p=drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.nonlinearity(self.core_norm(self.core_conv(x)))
        out = self.regularization(out)
        return self.refinement(out)


class HierarchicalNetwork(nn.Module):
    """Three-stage convolutional network with a two-head global pooling classifier.

    Data flow for input ``(N, in_channels, H, W)``:
      stem (-> 32 ch, H x W)
      -> stage_a (-> 64 ch) -> RMS downsample /2
      -> stage_b (-> 128 ch) -> RMS downsample /2
      -> stage_c (-> 256 ch, roughly H/4 x W/4)
      -> concat(global average pool, global max pool) (-> 512 features)
      -> MLP (512 -> 128 -> ``out_dims``)

    The output is the final ``nn.Linear``'s raw values; no softmax or sigmoid is applied.

    Args:
        out_dims: Size of the output vector.
        in_channels: Channels of the input image. The default of 1 matches the
            original single-channel stem.
    """

    def __init__(self, out_dims: int = 11, in_channels: int = 1):
        super().__init__()
        self.pre_processor = InputPreparer(in_channels)
        self.stage_a = ProcessingBlock(32, 64, drop=0.1)
        self.downsampler_a = MagnitudeScaler(kernel_size=2, stride=2)
        self.stage_b = ProcessingBlock(64, 128, drop=0.1)
        self.downsampler_b = MagnitudeScaler(kernel_size=2, stride=2)
        self.stage_c = ProcessingBlock(128, 256, drop=0.1)
        self.global_reducer_a = nn.AdaptiveAvgPool2d(1)
        self.global_reducer_b = nn.AdaptiveMaxPool2d(1)
        self.decision_network = nn.Sequential(
            nn.Linear(256 * 2, 128),
            nn.SiLU(),
            nn.Dropout(0.2),
            nn.Linear(128, out_dims),
        )
        self._reset_parameters()

    def _reset_parameters(self):
        """Re-initialize every Conv2d, BatchNorm2d and Linear in the network.

        - Conv2d: Kaiming-normal (fan_out) weights, zero bias.
        - BatchNorm2d: unit scale, zero shift.
        - Linear: N(0, 0.01) weights, zero bias.

        The stem's fixed filters are buffers, not parameters, so they are untouched.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # torch.nn.init.calculate_gain has no SiLU entry; the ReLU gain is
                # presumably used as the closest supported approximation.
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pre_processor(x)
        x = self.downsampler_a(self.stage_a(x))
        x = self.downsampler_b(self.stage_b(x))
        x = self.stage_c(x)
        # Each reducer yields (N, 256, 1, 1); flatten to (N, 256) before concatenating.
        reduced_a = torch.flatten(self.global_reducer_a(x), 1)
        reduced_b = torch.flatten(self.global_reducer_b(x), 1)
        return self.decision_network(torch.cat([reduced_a, reduced_b], dim=1))