import math import cv2 import numpy as np import torch from torch import nn from ..image_processor import convert_to_gray, image_binarize, min_max_normalize class GaussianFilter(nn.Module): def __init__(self, module_size: int, filter_thres: float = 0.1) -> None: super().__init__() self.module_size = module_size self.filter_thres = filter_thres self.conv = nn.Conv2d( in_channels=1, out_channels=1, kernel_size=module_size, stride=module_size, padding=0, bias=False, groups=1, ) self._setup_filter_weights() def _setup_filter_weights(self) -> None: filter_1d = cv2.getGaussianKernel( ksize=self.module_size, sigma=1.5, ktype=cv2.CV_32F ) filter_2d = filter_1d * filter_1d.T filter_2d = min_max_normalize(filter_2d) filter_2d[filter_2d < self.filter_thres] = .0 gaussian_filter_init = torch.tensor(filter_2d, dtype=torch.float32) self.conv.weight = nn.Parameter( gaussian_filter_init.reshape(1, 1, *gaussian_filter_init.shape), requires_grad=False, ) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.conv(x) class RegionMeanFilter(nn.Module): def __init__(self, module_size: int) -> None: super().__init__() self.module_size = module_size self.conv = nn.Conv2d( in_channels=1, out_channels=1, kernel_size=module_size, stride=module_size, padding=0, bias=None, groups=1, ) self._setup_kernel_weights() def _setup_kernel_weights(self) -> None: module_center = int(self.module_size / 2) radius = math.ceil(self.module_size / 6) center_filter = torch.zeros((1, 1, self.module_size, self.module_size)) center_filter[ :, :, module_center-radius : module_center+radius, module_center-radius : module_center+radius, ] = 1.0 self.conv.weight = nn.Parameter( center_filter / center_filter.sum(), requires_grad=False, ) @torch.no_grad() def forward(self, x: torch.Tensor) -> torch.Tensor: return self.conv(x) class CenterPixelExtractor(nn.Module): def __init__(self, module_size: int) -> None: super().__init__() self.module_size = module_size self.conv = nn.Conv2d( in_channels=1, out_channels=1, kernel_size=module_size, stride=module_size, padding=0, bias=None, groups=1, ) self._setup_kernel_weights() def _setup_kernel_weights(self) -> None: module_center = int(self.module_size / 2) + 1 center_filter = torch.zeros((1, 1, self.module_size, self.module_size)) center_filter[:, :, module_center, module_center] = 1.0 self.conv.weight = nn.Parameter(center_filter, requires_grad=False) @torch.no_grad() def forward(self, x: torch.Tensor) -> torch.Tensor: return self.conv(x) class QRCodeErrorExtractor(nn.Module): def __init__(self, module_size: int) -> None: super().__init__() self.module_size = module_size self.region_mean_filter = RegionMeanFilter(module_size) self.center_pixel_extractor = CenterPixelExtractor(module_size=module_size) def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: x_center_mean = self.region_mean_filter(x) y_center_pixel = self.center_pixel_extractor(y) error_mask = (y_center_pixel == 0) & (x_center_mean > 0.45) | \ (y_center_pixel == 1) & (x_center_mean < 0.65) return error_mask.float() class ScanningRobustLoss(nn.Module): def __init__(self, module_size: int) -> None: super().__init__() self.gaussian_filter = GaussianFilter(module_size=module_size) self.center_filter = RegionMeanFilter(module_size=module_size) self.module_error_extractor = QRCodeErrorExtractor(module_size=module_size) def _compute_error(self, image: torch.Tensor, qrcode: torch.Tensor) -> torch.Tensor: gray_image = convert_to_gray(image) error0 = 2 * torch.relu(gray_image - 0.45) * (1 - qrcode) error1 = 2 * torch.relu(0.65 - gray_image) * qrcode return error0 + error1 def _compute_ealy_stopping_mask(self, image: torch.Tensor, qrcode: torch.Tensor) -> torch.Tensor: return self.module_error_extractor( convert_to_gray(image.clone().detach()), image_binarize(qrcode), ) def forward(self, image: torch.Tensor, qrcode: torch.Tensor) -> torch.Tensor: error = self._compute_error(image, qrcode) sample_error = self.gaussian_filter(error) mask = self._compute_ealy_stopping_mask(image, qrcode) return torch.mean(sample_error * mask)