| | "Filter definitions, with pre-processing, post-processing and compilation methods." |
| |
|
| | import numpy as np |
| | import torch |
| | from torch import nn |
| | from common import AVAILABLE_FILTERS, INPUT_SHAPE |
| |
|
| | from concrete.fhe.compilation.compiler import Compiler |
| | from concrete.ml.common.utils import generate_proxy_function |
| | from concrete.ml.torch.numpy_module import NumpyModule |
| |
|
| |
|
class TorchIdentity(nn.Module):
    """Torch model that hands its input back untouched."""

    def forward(self, x):
        """Return the given image without any transformation.

        Args:
            x (torch.Tensor): The image to pass through.

        Returns:
            torch.Tensor: The very same tensor object that was received.
        """
        unchanged = x
        return unchanged
| |
|
| |
|
class TorchInverted(nn.Module):
    """Torch model that inverts the colors of an image."""

    def forward(self, x):
        """Invert every pixel value with respect to 255.

        Args:
            x (torch.Tensor): The image to invert.

        Returns:
            torch.Tensor: A new tensor holding ``255 - x`` element-wise.
        """
        # ``255 - x`` dispatches to the reversed subtraction ``rsub``; call it directly
        return torch.rsub(x, 255)
| |
|
| |
|
class TorchRotate(nn.Module):
    """Torch model that "rotates" an image by swapping its first two axes."""

    def forward(self, x):
        """Swap axes 0 and 1 of the image.

        Note that this is a transposition, i.e. a mirror about the main diagonal, rather than a
        pure 90-degree rotation.

        Args:
            x (torch.Tensor): The image to rotate.

        Returns:
            torch.Tensor: The image with its first two dimensions exchanged.
        """
        return torch.transpose(x, 0, 1)
| |
|
| |
|
class TorchConv(nn.Module):
    """Torch model with a single convolution operator."""

    def __init__(self, kernel, n_in_channels=3, n_out_channels=3, groups=1, threshold=None):
        """Initialize the filter.

        Args:
            kernel (np.ndarray): The convolution kernel to consider. Either 1D, holding one
                weight per input channel of a group (``n_in_channels // groups`` values) and
                applied as a 1x1 convolution, or 2D, a spatial kernel shared by every
                (output channel, input channel) pair. Values are cast to int64.
            n_in_channels (int): Number of channels of the input image. Defaults to 3.
            n_out_channels (int): Number of channels of the output image. Defaults to 3.
            groups (int): Number of blocked connections from input to output channels, as in
                ``torch.nn.functional.conv2d``. Defaults to 1.
            threshold (Optional[int]): If not None, this value is subtracted from every output
                value after the convolution. Defaults to None.
        """
        super().__init__()
        self.kernel = torch.tensor(kernel, dtype=torch.int64)
        self.n_out_channels = n_out_channels
        self.n_in_channels = n_in_channels
        self.groups = groups
        self.threshold = threshold

    def forward(self, x):
        """Forward pass with a single convolution using a 1D or 2D kernel.

        Args:
            x (torch.Tensor): The input image, laid out channels-last as
                (rows, columns, n_in_channels).

        Returns:
            torch.Tensor: The filtered image, laid out as (out_rows, out_columns,
                n_out_channels). No padding is applied, so a k x k kernel shrinks each spatial
                dimension by k - 1.

        Raises:
            ValueError: If the kernel is neither 1D nor 2D.
        """
        stride = 1
        kernel_shape = self.kernel.shape

        # Build the 4D weight tensor (out_channels, in_channels // groups, kH, kW) expected by
        # conv2d. It is kept in a local variable so that self.kernel is never modified and the
        # model can be called any number of times (tracing, clear inference, ...).
        if len(kernel_shape) == 1:
            # One weight per input channel, repeated for every output channel (1x1 convolution)
            kernel = self.kernel.repeat(self.n_out_channels).reshape(
                self.n_out_channels,
                self.n_in_channels // self.groups,
                1,
                1,
            )

        elif len(kernel_shape) == 2:
            # The same spatial kernel is broadcast to every (output, input) channel pair
            kernel = self.kernel.expand(
                self.n_out_channels,
                self.n_in_channels // self.groups,
                kernel_shape[0],
                kernel_shape[1],
            )

        else:
            raise ValueError(
                "Wrong kernel shape, only 1D or 2D kernels are accepted. Got kernel of shape "
                f"{kernel_shape}"
            )

        # Move the channels first and add a batch axis: (rows, cols, C) -> (1, C, cols, rows)
        x = x.transpose(2, 0).unsqueeze(axis=0)

        x = nn.functional.conv2d(x, kernel, stride=stride, groups=self.groups)

        # Undo the layout change: (1, C_out, out_cols, out_rows) -> (1, out_rows, out_cols, C_out)
        # and drop the batch axis. The spatial sizes must be taken in this order, otherwise
        # non-square images get scrambled by the reshape.
        _, _, out_cols, out_rows = x.shape
        x = x.transpose(1, 3).reshape((out_rows, out_cols, self.n_out_channels))

        if self.threshold is not None:
            x = x - self.threshold

        return x
| |
|
| |
|
class Filter:
    """Filter class used in the app."""

    def __init__(self, filter_name):
        """Initializing the filter class using a given filter.

        Most filters can be found at https://en.wikipedia.org/wiki/Kernel_(image_processing).

        Args:
            filter_name (str): The filter to consider, one of AVAILABLE_FILTERS.

        Raises:
            AssertionError: If filter_name is not in AVAILABLE_FILTERS.
        """

        # The message must be a single string (no trailing comma), otherwise it becomes a tuple
        assert filter_name in AVAILABLE_FILTERS, (
            f"Unsupported image filter or transformation. Expected one of {*AVAILABLE_FILTERS,}, "
            f"but got {filter_name}"
        )

        self.filter_name = filter_name

        # Not assigned by any method of this class; kept for callers that read or set it
        self.onnx_model = None

        # Filled by `compile`
        self.fhe_circuit = None

        # Optional integer divisor applied in `post_processing` to undo the kernel's scaling
        self.divide = None

        if filter_name == "identity":
            self.torch_model = TorchIdentity()

        elif filter_name == "inverted":
            self.torch_model = TorchInverted()

        elif filter_name == "rotate":
            self.torch_model = TorchRotate()

        elif filter_name == "black and white":
            # Luma weights 0.299, 0.587 and 0.114 (ITU-R BT.601) scaled by 1000, since the model
            # only handles integers. Every output channel gets the same weighted sum of the
            # three input channels, which yields a gray image.
            kernel = [299, 587, 114]

            self.torch_model = TorchConv(kernel)

            # Undo the x1000 scaling of the weights
            self.divide = 1000

        elif filter_name == "blur":
            # Box blur: each channel is filtered independently (groups=3)
            kernel = np.ones((3, 3))

            self.torch_model = TorchConv(kernel, groups=3)

            # The kernel sums to 9, so dividing by 9 turns the sum into an average
            self.divide = 9

        elif filter_name == "sharpen":
            kernel = [
                [0, -1, 0],
                [-1, 5, -1],
                [0, -1, 0],
            ]

            self.torch_model = TorchConv(kernel, groups=3)

        elif filter_name == "ridge detection":
            kernel = [
                [-1, -1, -1],
                [-1, 9, -1],
                [-1, -1, -1],
            ]

            # With groups=1, each output channel sums the responses of all three input channels.
            # Subtracting 900 presumably keeps only strong ridges; negative values are clipped
            # to 0 in `post_processing`.
            self.torch_model = TorchConv(kernel, threshold=900)

    def compile(self):
        """Compile the filter on a representative inputset.

        Returns:
            The compiled FHE circuit, which is also stored in self.fhe_circuit.
        """
        # 100 random images with values in [0, 255]. A dedicated RandomState seeded with 42
        # produces exactly the same images as seeding NumPy's global generator with 42, without
        # resetting the global random state as a side effect.
        rng = np.random.RandomState(42)
        inputset = tuple(
            rng.randint(0, 256, size=(INPUT_SHAPE + (3, )), dtype=np.int64) for _ in range(100)
        )

        # Convert the torch model into its NumPy equivalent
        numpy_module = NumpyModule(
            self.torch_model,
            dummy_input=torch.from_numpy(inputset[0]),
        )

        # Compiler needs a function with named parameters to assign an encryption status to
        # each one; the proxy wraps numpy_forward under the parameter name "inputs" and the
        # mapping gives the name actually used by the proxy.
        numpy_filter_proxy, parameters_mapping = generate_proxy_function(
            numpy_module.numpy_forward,
            ["inputs"]
        )

        # Compile with the image encrypted
        compiler = Compiler(
            numpy_filter_proxy,
            {parameters_mapping["inputs"]: "encrypted"},
        )
        self.fhe_circuit = compiler.compile(inputset)

        return self.fhe_circuit

    def post_processing(self, output_image):
        """Apply post-processing to the decrypted output image.

        Args:
            output_image (np.ndarray): The decrypted image to post-process. It is not modified.

        Returns:
            np.ndarray: The post-processed image, floor-divided by self.divide when it is set,
                then clipped to [0, 255].
        """
        # Undo the kernel's integer scaling; done out-of-place so the caller's array is untouched
        if self.divide is not None:
            output_image = output_image // self.divide

        # Bring the values back to the displayable [0, 255] range
        return output_image.clip(0, 255)
| |
|