| "Filter definitions, with pre-processing, post-processing and compilation methods." |
|
|
| import numpy as np |
| import torch |
| from common import AVAILABLE_FILTERS, INPUT_SHAPE |
| from concrete.numpy.compilation.compiler import Compiler |
| from torch import nn |
|
|
| from concrete.ml.common.debugging.custom_assert import assert_true |
| from concrete.ml.common.utils import generate_proxy_function |
| from concrete.ml.onnx.convert import get_equivalent_numpy_forward |
| from concrete.ml.torch.numpy_module import NumpyModule |
|
|
|
|
class _TorchIdentity(nn.Module):
    """Torch identity model."""

    def forward(self, x):
        """Identity forward pass.

        Args:
            x (torch.Tensor): The input image.

        Returns:
            x (torch.Tensor): The input image.
        """
        return x


class _TorchInverted(nn.Module):
    """Torch model for inverting an image's colors."""

    def forward(self, x):
        """Forward pass for inverting an image's colors.

        Args:
            x (torch.Tensor): The input image.

        Returns:
            torch.Tensor: The (color) inverted image.
        """
        # Invert the colors by subtracting each 8-bit pixel value from 255
        return 255 - x


class _TorchRotate(nn.Module):
    """Torch model for rotating an image."""

    def forward(self, x):
        """Forward pass for rotating an image.

        Args:
            x (torch.Tensor): The input image.

        Returns:
            torch.Tensor: The rotated image.
        """
        # Rotate the image by swapping its height and width axes (NCHW format)
        return x.transpose(2, 3)


class _TorchConv2D(nn.Module):
    """Torch model for applying a single 2D convolution operator on images."""

    def __init__(self, kernel, n_in_channels=3, n_out_channels=3, groups=1, threshold=None):
        """Initialize the filter.

        Args:
            kernel (np.ndarray): The convolution kernel to consider.
            n_in_channels (int): The number of input channels. Defaults to 3.
            n_out_channels (int): The number of output channels. Defaults to 3.
            groups (int): The number of blocked connections from input channels to output
                channels. Defaults to 1.
            threshold (int): An optional value to subtract from the convolution's outputs.
                Defaults to None.
        """
        super().__init__()
        self.kernel = torch.tensor(kernel, dtype=torch.int64)
        self.n_out_channels = n_out_channels
        self.n_in_channels = n_in_channels
        self.groups = groups
        self.threshold = threshold

    def forward(self, x):
        """Forward pass for filtering the image using a 2D kernel.

        Args:
            x (torch.Tensor): The input image.

        Returns:
            torch.Tensor: The filtered image.
        """
        # Define the convolution parameters
        stride = 1
        kernel_shape = self.kernel.shape

        # If the kernel is 1D, build a (1, 1) kernel for each (output, input) channel pair
        if len(kernel_shape) == 1:
            kernel = self.kernel.reshape(
                self.n_out_channels,
                self.n_in_channels // self.groups,
                1,
                1,
            )

        # Else, if the kernel is 2D, apply the same (Kh, Kw) kernel to all channel pairs
        elif len(kernel_shape) == 2:
            kernel = self.kernel.expand(
                self.n_out_channels,
                self.n_in_channels // self.groups,
                kernel_shape[0],
                kernel_shape[1],
            )
        else:
            raise ValueError(
                "Wrong kernel shape, only 1D or 2D kernels are accepted. Got kernel of shape "
                f"{kernel_shape}"
            )

        # Apply the convolution
        x = nn.functional.conv2d(x, kernel, stride=stride, groups=self.groups)

        # Subtract the threshold from the convolution's outputs, if one is given
        if self.threshold is not None:
            x -= self.threshold

        return x
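
    # Shape illustration, using the filters defined below: the 1D "black and white"
    # kernel [299, 587, 114] with n_out_channels=1, n_in_channels=3, groups=1 is
    # reshaped to (1, 3, 1, 1), a per-pixel weighted sum of the RGB channels, while the
    # 2D (3, 3) "blur" kernel with n_out_channels=3, groups=3 is expanded to
    # (3, 1, 3, 3), the same 3x3 kernel applied to each channel independently (a
    # depthwise convolution).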


class Filter:
    """Filter class used in the app."""

    def __init__(self, filter_name):
        """Initialize the filter instance using the given filter name.

        Most filters can be found at https://en.wikipedia.org/wiki/Kernel_(image_processing).

        Args:
            filter_name (str): The filter to consider.
        """
        assert_true(
            filter_name in AVAILABLE_FILTERS,
            f"Unsupported image filter or transformation. Expected one of {*AVAILABLE_FILTERS,}, "
            f"but got {filter_name}",
        )

        # Define the parameters and quantizers expected by the Concrete-ML client-server
        # interface, even though images are not actually quantized here
        self.post_processing_params = {"filter_name": filter_name}
        self.input_quantizers = []
        self.output_quantizers = []

        # Define the filter's attributes
        self.filter = filter_name
        self.onnx_model = None
        self.fhe_circuit = None
        self.divide = None
        self.repeat_out_channels = False

        # Instantiate the torch model associated to the given filter name
        if filter_name == "identity":
            self.torch_model = _TorchIdentity()

        elif filter_name == "inverted":
            self.torch_model = _TorchInverted()

        elif filter_name == "rotate":
            self.torch_model = _TorchRotate()

        elif filter_name == "black and white":
            # Grayscale conversion as a convolution using the common luma weights
            # (0.299, 0.587, 0.114) for the R, G and B channels, scaled by 1000 so that
            # the kernel only contains integers, as required by FHE
            kernel = [299, 587, 114]

            self.torch_model = _TorchConv2D(kernel, n_out_channels=1, groups=1)

            # Value used to divide the output in post-processing, undoing the above scaling
            self.divide = 1000

            # The output is a single-channel (grayscale) image, which therefore needs to be
            # repeated over 3 channels for display
            self.repeat_out_channels = True

        elif filter_name == "blur":
            # Box blur: average each pixel with its 3x3 neighborhood
            kernel = np.ones((3, 3))

            # Apply the kernel to each channel independently
            self.torch_model = _TorchConv2D(kernel, n_out_channels=3, groups=3)

            # Divide the output by the kernel's size (9) in post-processing to get the mean
            self.divide = 9

        elif filter_name == "sharpen":
            kernel = [
                [0, -1, 0],
                [-1, 5, -1],
                [0, -1, 0],
            ]

            # Apply the kernel to each channel independently
            self.torch_model = _TorchConv2D(kernel, n_out_channels=3, groups=3)

        elif filter_name == "ridge detection":
            kernel = [
                [-1, -1, -1],
                [-1, 9, -1],
                [-1, -1, -1],
            ]

            # Combine all 3 channels into a single output and subtract a threshold, so that
            # once the output is clipped to [0, 255] in post-processing, mainly the strong
            # ridge responses remain visible
            self.torch_model = _TorchConv2D(kernel, n_out_channels=1, groups=1, threshold=900)

            # The output is a single-channel image, which therefore needs to be repeated
            # over 3 channels for display
            self.repeat_out_channels = True
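
    # Worked example for the "black and white" filter: a pixel (R, G, B) = (100, 150, 200)
    # becomes (299*100 + 587*150 + 114*200) // 1000 = 140750 // 1000 = 140 once the
    # post-processing division below has been applied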

    def compile(self, onnx_model=None):
        """Compile the model on a representative inputset.

        Args:
            onnx_model (onnx.ModelProto): The loaded onnx model to consider. If None, it will
                be generated automatically using a NumpyModule. Defaults to None.

        Returns:
            Circuit: The compiled FHE circuit.
        """
        # Generate a random representative set of images used for compilation. Fixing the
        # seed beforehand makes the generation deterministic. Pixel values span [0, 255],
        # hence the exclusive upper bound of 256
        np.random.seed(42)
        inputset = tuple(
            np.random.randint(0, 256, size=((1, 3) + INPUT_SHAPE), dtype=np.int64)
            for _ in range(10)
        )

        # If no ONNX model was given, generate one from the torch model using a NumpyModule
        if onnx_model is None:
            numpy_module = NumpyModule(
                self.torch_model,
                dummy_input=torch.from_numpy(inputset[0]),
            )

            onnx_model = numpy_module.onnx_model

        # Retrieve the ONNX model's equivalent numpy forward function
        self.onnx_model = onnx_model
        numpy_filter = get_equivalent_numpy_forward(onnx_model)

        # Wrap the forward function in a proxy whose input parameter is properly named
        numpy_filter_proxy, parameters_mapping = generate_proxy_function(numpy_filter, ["inputs"])

        compiler = Compiler(
            numpy_filter_proxy,
            {parameters_mapping["inputs"]: "encrypted"},
        )

        # Compile the filter on the representative inputset, with an encrypted input
        self.fhe_circuit = compiler.compile(inputset)

        return self.fhe_circuit

    def quantize_input(self, input_image):
        """Quantize the input.

        Images are already quantized in this case. However, this method needs to be defined
        in order to prevent the Concrete-ML client-server interface from breaking.

        Args:
            input_image (np.ndarray): The input to quantize.

        Returns:
            np.ndarray: The quantized input.
        """
        return input_image

    def pre_processing(self, input_image):
        """Apply pre-processing to the input images.

        Args:
            input_image (np.ndarray): The image to pre-process.

        Returns:
            input_image (np.ndarray): The pre-processed image.
        """
        # Reshape the image from (height, width, channel) to (1, channel, height, width)
        # and make sure its values are 64-bit integers
        input_image = np.expand_dims(input_image.transpose(2, 0, 1), axis=0).astype(np.int64)

        return input_image

    def post_processing(self, output_image):
        """Apply post-processing to the decrypted output images.

        Args:
            output_image (np.ndarray): The decrypted image to post-process.

        Returns:
            output_image (np.ndarray): The post-processed image.
        """
        # Divide the output by the filter's scaling factor, if any
        if self.divide is not None:
            output_image //= self.divide

        # Clip the image's values to the 8-bit range [0, 255]
        output_image = output_image.clip(0, 255)

        # Reshape the image back from (1, channel, height, width) to (height, width, channel)
        output_image = output_image.transpose(0, 2, 3, 1).squeeze(0)

        # Repeat the single grayscale channel over 3 channels if needed, as the app
        # expects RGB images
        if self.repeat_out_channels:
            output_image = output_image.repeat(3, axis=2)

        return output_image
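

# The following is a minimal usage sketch, not part of the app itself: it compiles the
# "blur" filter and runs it homomorphically on a random image through concrete-numpy's
# Circuit.encrypt_run_decrypt. It assumes INPUT_SHAPE is a (height, width) tuple.
if __name__ == "__main__":
    image_filter = Filter("blur")
    image_filter.compile()

    # Generate a random RGB image of shape (height, width, 3)
    image = np.random.randint(0, 256, size=INPUT_SHAPE + (3,), dtype=np.int64)

    # Pre-process, then encrypt, run in FHE and decrypt the image
    pre_processed = image_filter.pre_processing(image)
    decrypted_output = image_filter.fhe_circuit.encrypt_run_decrypt(pre_processed)

    # Post-process the decrypted output
    output_image = image_filter.post_processing(decrypted_output)

    print("Filtered image shape:", output_image.shape)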