import torch
import torch.nn as nn
import torch.nn.functional as F


class wConv2d(nn.Module):
    """2-D convolution whose kernel is modulated by a fixed separable window.

    The 1-D profile is ``alfa = [*den, 1.0, *reversed(den)]`` and the 2-D
    window is ``Phi = outer(alfa, alfa)``. On every forward pass the learnable
    weight is multiplied element-wise by ``Phi`` before ``F.conv2d`` runs, so
    the window itself is never trained.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Side length of the square kernel; must equal
            ``2 * len(den) + 1`` so that ``Phi`` matches the weight.
        den: Window coefficients on one side of the centre tap (list, tuple,
            array, tensor or scalar). Coerced to the default float dtype.
        stride: Forwarded to ``F.conv2d``.
        padding: Forwarded to ``F.conv2d``.
        groups: Forwarded to ``F.conv2d``; must divide both channel counts.
        bias: If True, add a learnable bias initialised to zeros.

    Raises:
        ValueError: If ``groups`` does not divide ``in_channels`` or
            ``out_channels``, or if the window shape does not match
            ``(kernel_size, kernel_size)``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, den, stride=1, padding=1, groups=1, bias=False):
        super().__init__()
        # Same early checks nn.Conv2d performs; otherwise the mismatch only
        # surfaces as an opaque shape error inside F.conv2d at forward time.
        if in_channels % groups != 0:
            raise ValueError("in_channels must be divisible by groups")
        if out_channels % groups != 0:
            raise ValueError("out_channels must be divisible by groups")

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.padding = padding
        self.kernel_size = kernel_size
        self.groups = groups

        # Build the window on CPU in the default float dtype. Integer or
        # float64 coefficients would otherwise leak into Phi, and
        # weight * Phi would then no longer match a float32 input.
        # as_tensor also avoids the torch.tensor(tensor) copy warning, and
        # atleast_1d lets a scalar coefficient be concatenated.
        dtype = torch.get_default_dtype()
        half = torch.atleast_1d(torch.as_tensor(den, dtype=dtype, device=torch.device('cpu')))
        alfa = torch.cat([half, half.new_ones(1), torch.flip(half, dims=[0])])
        self.register_buffer('alfa', alfa)
        self.register_buffer('Phi', torch.outer(alfa, alfa))
        if self.Phi.shape != (kernel_size, kernel_size):
            raise ValueError(f"Phi shape {self.Phi.shape} must match kernel size ({kernel_size}, {kernel_size})")

        self.weight = nn.Parameter(torch.empty(out_channels, in_channels // groups, kernel_size, kernel_size))
        nn.init.kaiming_normal_(self.weight, mode='fan_out', nonlinearity='relu')
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
        else:
            # Register explicitly so `bias` is a known (absent) parameter, as in nn.Conv2d.
            self.register_parameter('bias', None)

    def forward(self, x):
        """Convolve ``x`` with the window-modulated weight ``weight * Phi``."""
        # Buffers already follow module.to(). This cast only guarantees that
        # Phi sits on the input's device and shares the weight's dtype.
        phi = self.Phi.to(device=x.device, dtype=self.weight.dtype)
        return F.conv2d(x, self.weight * phi, bias=self.bias, stride=self.stride, padding=self.padding, groups=self.groups)

    def extra_repr(self):
        """Summarise the layer configuration for ``repr(module)``."""
        return (
            f"{self.in_channels}, {self.out_channels}, kernel_size={self.kernel_size}, "
            f"stride={self.stride}, padding={self.padding}, groups={self.groups}, "
            f"bias={self.bias is not None}"
        )