Spaces:
Runtime error
Runtime error
| import tensorflow as tf | |
| from typing import Any, Tuple | |
| import tensorflow_addons as tfda | |
class ResidualBlock(tf.keras.layers.Layer):
    """Residual block made of two convolutions and a skip connection.

    Data flow in ``call``::

        inputs -> linear activation (kept as the skip branch)
               -> [reflection pad] -> conv_1 -> ReLU -> norm_1
               -> [reflection pad] -> conv_2 -> norm_2
               -> add skip branch -> ReLU

    Args:
        filter_num: Number of filters used by both convolutions.
        filter_size: Kernel size used by both convolutions.
        seed: Stored as ``self.seed``; it is not consumed anywhere in this
            class.
        name: Layer name forwarded to the Keras base class.
        padding: ``"default"`` uses ``padding="same"`` convolutions;
            ``"reflect"`` applies explicit reflection padding followed by
            ``padding="valid"`` convolutions.
        instance_normalization: When true, use
            ``tfda.layers.InstanceNormalization`` for both normalisation
            layers; otherwise use ``tf.keras.layers.BatchNormalization``.

    Raises:
        RuntimeError: If ``padding`` is neither ``"default"`` nor
            ``"reflect"``.

    Note:
        The skip connection adds the block input to the output of the second
        convolution, so the input is expected to be shape-compatible with a
        ``filter_num``-channel feature map. In ``"reflect"`` mode the padding
        is ``filter_size // 2`` per side, which only preserves the spatial
        size for odd ``filter_size`` values.
    """

    def __init__(self, filter_num: int, filter_size: int, seed: Any = None, name=None, padding="default",
                 instance_normalization: bool = False):
        super().__init__(name=name)
        # Fail before any sub-layer is built when the padding mode is unknown.
        if padding not in ("default", "reflect"):
            raise RuntimeError("Non valid padding type.")

        self.filter_num = filter_num
        self.filter_size = filter_size
        self.seed = seed
        self.padding_type = padding

        # Sub-layers are assigned in the same order as before so that layer
        # tracking / weight ordering of existing models is not disturbed.
        self.activation_1 = tf.keras.layers.Activation("linear", trainable=False)
        if padding == "reflect":
            self.pad_1 = self._make_reflection_pad()
        self.conv_1 = self._make_conv()
        self.activation_2 = tf.keras.layers.Activation("relu")

        if instance_normalization:
            self.bn_1 = tfda.layers.InstanceNormalization(trainable=True)
            self.bn_2 = tfda.layers.InstanceNormalization(trainable=True)
        else:
            self.bn_1 = tf.keras.layers.BatchNormalization(trainable=True)
            self.bn_2 = tf.keras.layers.BatchNormalization(trainable=True)

        if padding == "reflect":
            self.pad_2 = self._make_reflection_pad()
        self.conv_2 = self._make_conv()
        self.activation_3 = tf.keras.layers.Activation("relu")

        # Built once here instead of instantiating a new layer on every call.
        self.residual_add = tf.keras.layers.Add()

    def _make_conv(self):
        """Return a Conv2D configured for the selected padding mode."""
        conv_padding = "same" if self.padding_type == "default" else "valid"
        return tf.keras.layers.Conv2D(filters=self.filter_num, kernel_size=self.filter_size,
                                      padding=conv_padding, trainable=True)

    def _make_reflection_pad(self):
        """Return reflection padding sized to offset one "valid" convolution."""
        # A "valid" convolution removes (filter_size - 1) rows/columns; padding
        # filter_size // 2 on each side restores them for odd kernel sizes.
        # For filter_size == 3 this equals the previously hard-coded (1, 1).
        pad = self.filter_size // 2
        return ReflectionPadding2D(padding=(pad, pad))

    def call(self, inputs, *args, **kwargs):
        """Apply the block to ``inputs`` and return the activated residual sum."""
        identity = self.activation_1(inputs)

        x = identity
        if self.padding_type == "reflect":
            x = self.pad_1(x)
        x = self.conv_1(x)
        x = self.activation_2(x)
        x = self.bn_1(x)

        if self.padding_type == "reflect":
            x = self.pad_2(x)
        x = self.conv_2(x)
        x = self.bn_2(x)

        residual = self.residual_add([x, identity])
        return self.activation_3(residual)
class ReflectionPadding2D(tf.keras.layers.Layer):
    """Pad the two middle axes of a rank-4 input using reflection.

    The paddings handed to ``tf.pad`` have four rows, laid out as
    (batch, height, width, channels); only the height and width rows are
    non-zero.

    Args:
        padding: ``(pad_width, pad_height)``. ``pad_height`` is added on both
            sides of axis 1 and ``pad_width`` on both sides of axis 2.
        **kwargs: Standard Keras layer keyword arguments (for example
            ``name`` or ``dtype``), forwarded to the base class. Accepting
            them also lets the default ``from_config`` rebuild the layer from
            the dictionary produced by ``get_config``.
    """

    def __init__(self, padding: Tuple[int, int], **kwargs):
        super().__init__(**kwargs)
        self.pad_width, self.pad_height = padding

    def call(self, inputs, *args, **kwargs):
        """Return ``inputs`` reflect-padded along its height and width axes."""
        paddings = tf.constant([
            [0, 0],  # Batch
            [self.pad_height, self.pad_height],  # Height
            [self.pad_width, self.pad_width],  # Width
            [0, 0],  # Channels
        ])
        return tf.pad(inputs, paddings, mode="REFLECT")

    def get_config(self):
        """Return the layer configuration, including the ``padding`` argument."""
        config = super().get_config()
        # Same (width, height) order that __init__ unpacks.
        config.update({"padding": (self.pad_width, self.pad_height)})
        return config