|
|
import tensorflow as tf |
|
|
from tensorflow.keras.__internal__.layers import BaseRandomLayer |
|
|
from tensorflow.keras.layers import ( |
|
|
Dense, Flatten, Conv2D, Activation, BatchNormalization, |
|
|
MaxPooling2D, AveragePooling2D, GlobalAveragePooling2D, |
|
|
Dropout, Input, concatenate, add, Conv2DTranspose, Lambda, |
|
|
SpatialDropout2D, Cropping2D, UpSampling2D, LeakyReLU, |
|
|
ZeroPadding2D, Reshape, Concatenate, Multiply, Permute, Add |
|
|
) |
|
|
from keras import backend as K |
|
|
|
|
|
from .utils import normalize_tuple |
|
|
|
|
|
|
|
|
class MultipleTrackers:
    """Fan out method calls to a list of tracker/callback objects.

    Any non-dunder attribute that is not found on this object resolves to a
    forwarding function. Calling that function invokes the attribute of the
    same name on every object in ``callbacks_list``, in list order, with the
    same positional and keyword arguments. The forwarding function itself
    returns ``None``. A callback lacking the attribute raises
    ``AttributeError`` at call time.

    Args:
        callback_lists: Objects to forward calls to (stored as
            ``self.callbacks_list``).
    """

    def __init__(self, callback_lists: list):
        self.callbacks_list = callback_lists

    def __getattr__(self, attr):
        # ``__getattr__`` only runs after normal attribute lookup failed.
        # Dunder names are refused so that protocol probes such as
        # ``copy.deepcopy`` (``__deepcopy__``, ``__setstate__``) and
        # ``hasattr`` see a genuine miss instead of a fake forwarder.
        # ``callbacks_list`` is refused so that a half-constructed instance
        # (created via ``__new__``) cannot recurse into a bogus forwarder.
        if attr == "callbacks_list" or (attr.startswith("__") and attr.endswith("__")):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr!r}"
            )

        def helper(*args, **kwargs):
            for cb in self.callbacks_list:
                getattr(cb, attr)(*args, **kwargs)

        return helper
|
|
|
|
|
|
|
|
class DropBlockNoise(BaseRandomLayer):
    """Randomly zero out contiguous spatial blocks of a 4-D feature map.

    Uses the DropBlock masking scheme. Block "seeds" are sampled with
    probability ``gamma``, but only at positions where a whole block fits
    inside the map. Each seed is then grown into a ``block_size`` patch of
    zeros with a min-pool. The random noise has the full shape of ``x``, so
    the mask can differ per channel. Surviving activations are *not*
    rescaled: the output is simply ``x * mask``.

    The input is treated as ``(batch, height, width, channels)``: axes 1
    and 2 are read as height and width.

    Args:
        rate: Drop rate in ``[0, 1]``. ``0`` makes the layer an identity.
        block_size: Int or ``(height, width)`` pair, normalized with
            ``normalize_tuple`` (zero not allowed).
        seed: Optional seed, forwarded to ``BaseRandomLayer``.
        **kwargs: Forwarded to ``BaseRandomLayer``.

    Raises:
        ValueError: If ``rate`` is outside ``[0, 1]``.
    """

    def __init__(self, rate, block_size, seed=None, **kwargs):
        super().__init__(seed=seed, **kwargs)
        if not 0.0 <= rate <= 1.0:
            raise ValueError(
                f"rate must be a number between 0 and 1. Received: {rate}"
            )
        self._rate = rate
        self._dropblock_height, self._dropblock_width = normalize_tuple(
            value=block_size, n=2, name="block_size", allow_zero=False
        )
        self.seed = seed

    def call(self, x, training=None):
        """Apply the block mask when ``training`` is truthy; otherwise return ``x``."""
        if not training or self._rate == 0.0:
            return x

        input_shape = tf.shape(x)
        height = input_shape[1]
        width = input_shape[2]

        # Never use a block larger than the feature map itself.
        dropblock_height = tf.math.minimum(self._dropblock_height, height)
        dropblock_width = tf.math.minimum(self._dropblock_width, width)

        # Per-position seed probability. The clipped block size is used in
        # every term. With the raw size, the last factor is <= 0 whenever
        # the block exceeds the map, which makes gamma infinite or negative.
        gamma = (
            self._rate
            * tf.cast(width * height, dtype=tf.float32)
            / tf.cast(dropblock_height * dropblock_width, dtype=tf.float32)
            / tf.cast(
                (width - dropblock_width + 1) * (height - dropblock_height + 1),
                tf.float32,
            )
        )

        # Seeds may only fall where a full block fits inside the map.
        # Tensor arithmetic only (no int()) so this also traces in graph mode.
        w_i, h_i = tf.meshgrid(tf.range(width), tf.range(height))
        valid_block = tf.logical_and(
            tf.logical_and(
                w_i >= dropblock_width // 2,
                w_i < width - (dropblock_width - 1) // 2,
            ),
            tf.logical_and(
                h_i >= dropblock_height // 2,
                # Row bound uses the height (it previously used the width,
                # which is wrong for non-square inputs).
                h_i < height - (dropblock_height - 1) // 2,
            ),
        )
        valid_block = tf.reshape(valid_block, [1, height, width, 1])

        random_noise = self._random_generator.random_uniform(
            tf.shape(x), dtype=tf.float32
        )
        valid_block = tf.cast(valid_block, dtype=tf.float32)
        seed_keep_rate = tf.cast(1 - gamma, dtype=tf.float32)
        # 1 = keep, 0 = block seed (a valid position whose noise is < gamma).
        block_pattern = (1 - valid_block + seed_keep_rate + random_noise) >= 1
        block_pattern = tf.cast(block_pattern, dtype=tf.float32)

        # Min-pool (the negated max-pool of the negation) spreads each 0 seed
        # over a block_size window.
        window_size = [1, self._dropblock_height, self._dropblock_width, 1]
        block_pattern = -tf.nn.max_pool(
            -block_pattern,
            ksize=window_size,
            strides=[1, 1, 1, 1],
            padding="SAME",
        )

        return x * tf.cast(block_pattern, x.dtype)

    def get_config(self):
        """Return a config from which the layer can be re-created."""
        config = {
            "rate": self._rate,
            "block_size": (self._dropblock_height, self._dropblock_width),
            "seed": self.seed,
        }
        base_config = super().get_config()
        return {**base_config, **config}
|
|
|
|
|
|
|
|
def squeeze_excite_block(input, ratio=16):
    """Create a channel-wise squeeze-and-excite block.

    The input is global-average-pooled, passed through a ReLU bottleneck
    Dense and a sigmoid Dense, and the resulting per-channel gate multiplies
    the input. The channel axis is axis 1 when ``K.image_data_format()`` is
    ``"channels_first"``, otherwise the last axis.

    Args:
        input: 4-D input Keras tensor (pooled by ``GlobalAveragePooling2D``).
        ratio: Reduction ratio of the bottleneck. The bottleneck has
            ``max(1, filters // ratio)`` units, where ``filters`` is the
            channel count of ``input``.

    Returns:
        A Keras tensor: ``input`` scaled channel-wise by the learned gate.

    References:
        - [Squeeze and Excitation Networks](https://arxiv.org/abs/1709.01507)
    """
    init = input
    channels_first = K.image_data_format() == "channels_first"
    channel_axis = 1 if channels_first else -1
    filters = int(init.shape[channel_axis])
    se_shape = (1, 1, filters)
    # With filters < ratio, integer division gives a zero-width bottleneck,
    # which leaves the gate with no learnable signal. Keep at least one unit.
    bottleneck_units = max(1, filters // ratio)

    se = GlobalAveragePooling2D()(init)
    se = Reshape(se_shape)(se)
    se = Dense(bottleneck_units, activation='relu', kernel_initializer='he_normal', use_bias=False)(se)
    se = Dense(filters, activation='sigmoid', kernel_initializer='he_normal', use_bias=False)(se)

    if channels_first:
        # Move the gate from (1, 1, C) to (C, 1, 1) to match the input layout.
        se = Permute((3, 1, 2))(se)

    return Multiply()([init, se])
|
|
|
|
|
|
|
|
def spatial_squeeze_excite_block(input):
    """Gate every spatial position of ``input`` with a learned sigmoid map.

    A bias-free 1x1 ``Conv2D`` with a single output channel and sigmoid
    activation produces the attention map. The input is multiplied by that
    map (broadcast across channels by ``Multiply``).

    Args:
        input: Input Keras tensor.

    Returns:
        A Keras tensor: ``input`` reweighted per spatial location.

    References:
        - [Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks](https://arxiv.org/abs/1803.02579)
    """
    attention_map = Conv2D(
        1,
        (1, 1),
        activation='sigmoid',
        use_bias=False,
        kernel_initializer='he_normal',
    )(input)
    return Multiply()([input, attention_map])
|
|
|
|
|
|
|
|
def channel_spatial_squeeze_excite(input, ratio=16):
    """Create a concurrent channel + spatial squeeze-and-excite (scSE) block.

    The channel SE path (``squeeze_excite_block``) and the spatial SE path
    (``spatial_squeeze_excite_block``) are both applied to ``input``, and
    their outputs are summed element-wise with ``Add``.

    Args:
        input: Input Keras tensor.
        ratio: Reduction ratio forwarded to ``squeeze_excite_block`` for its
            channel bottleneck.

    Returns:
        A Keras tensor: the sum of the channel-recalibrated and the
        spatially-recalibrated versions of ``input``.

    References:
        - [Squeeze and Excitation Networks](https://arxiv.org/abs/1709.01507)
        - [Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks](https://arxiv.org/abs/1803.02579)
    """
    cse = squeeze_excite_block(input, ratio)
    sse = spatial_squeeze_excite_block(input)

    x = Add()([cse, sse])
    return x
|
|
|
|
|
|
|
|
def DoubleConv(filters, kernel_size, initializer='glorot_uniform'):
    """Return a callable applying (Conv2D -> BatchNormalization -> swish) twice.

    Both convolutions use ``filters`` output channels, ``kernel_size``,
    ``'same'`` padding and ``initializer`` for their kernels.
    """
    def layer(x):
        out = x
        for _ in range(2):
            out = Conv2D(filters, kernel_size, padding='same', kernel_initializer=initializer)(out)
            out = BatchNormalization()(out)
            out = Activation('swish')(out)
        return out

    return layer
|
|
|
|
|
|
|
|
def UpSampling2D_block(filters, kernel_size=(3, 3), upsample_rate=(2, 2), interpolation='bilinear',
                       initializer='glorot_uniform', skip=None):
    """Build a decoder stage based on ``UpSampling2D``.

    The returned callable ``layer(input_tensor)`` does the following:
      1. upsamples ``input_tensor`` with
         ``UpSampling2D(size=upsample_rate, interpolation=interpolation)``;
      2. concatenates ``skip`` (via ``Concatenate()``) when it is not ``None``;
      3. applies ``DoubleConv(filters, kernel_size, initializer=initializer)``;
      4. applies ``channel_spatial_squeeze_excite`` with its default ratio.
    """
    def layer(input_tensor):
        upsampled = UpSampling2D(size=upsample_rate, interpolation=interpolation)(input_tensor)
        merged = upsampled if skip is None else Concatenate()([upsampled, skip])
        refined = DoubleConv(filters, kernel_size, initializer=initializer)(merged)
        return channel_spatial_squeeze_excite(refined)

    return layer
|
|
|
|
|
|
|
|
def Conv2DTranspose_block(filters, transpose_kernel_size=(3, 3), upsample_rate=(2, 2),
                          initializer='glorot_uniform', skip=None, met_input=None, sat_input=None):
    """Build a decoder stage based on a learned ``Conv2DTranspose`` upsampling.

    The returned callable ``layer(input_tensor)`` does the following:
      1. upsamples with ``Conv2DTranspose(filters, transpose_kernel_size,
         strides=upsample_rate, padding='same')`` (default kernel initializer);
      2. concatenates ``skip`` (via ``Concatenate()``) when it is not ``None``;
      3. applies ``DoubleConv(filters, transpose_kernel_size,
         initializer=initializer)``;
      4. applies ``channel_spatial_squeeze_excite`` with its default ratio.

    ``met_input`` and ``sat_input`` are accepted but not used by this block.
    """
    def layer(input_tensor):
        upsampled = Conv2DTranspose(filters, transpose_kernel_size, strides=upsample_rate,
                                    padding='same')(input_tensor)
        merged = upsampled if skip is None else Concatenate()([upsampled, skip])
        refined = DoubleConv(filters, transpose_kernel_size, initializer=initializer)(merged)
        return channel_spatial_squeeze_excite(refined)

    return layer
|
|
|
|
|
|
|
|
def PixelShuffle_block(filters, kernel_size=(3, 3), upsample_rate=2,
                       initializer='glorot_uniform', skip=None, met_input=None, sat_input=None):
    """Build a decoder stage based on sub-pixel (depth-to-space) upsampling.

    The returned callable ``layer(input_tensor)`` does the following:
      1. expands channels to ``filters * upsample_rate ** 2`` with a swish
         ``Conv2D`` ('same' padding, 'Orthogonal' kernel initializer);
      2. rearranges them into space with
         ``tf.nn.depth_to_space(..., upsample_rate)``;
      3. concatenates ``skip`` (via ``Concatenate()``) when it is not ``None``;
      4. applies ``DoubleConv(filters, kernel_size, initializer=initializer)``;
      5. applies ``channel_spatial_squeeze_excite`` with its default ratio.

    ``initializer`` is used only by the ``DoubleConv`` stage.
    ``met_input`` and ``sat_input`` are accepted but not used by this block.
    """
    def layer(input_tensor):
        expanded = Conv2D(filters * upsample_rate ** 2, kernel_size, padding="same",
                          activation="swish", kernel_initializer='Orthogonal')(input_tensor)
        shuffled = tf.nn.depth_to_space(expanded, upsample_rate)
        merged = shuffled if skip is None else Concatenate()([shuffled, skip])
        refined = DoubleConv(filters, kernel_size, initializer=initializer)(merged)
        return channel_spatial_squeeze_excite(refined)

    return layer