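The `SwinTransformer` layer below implements a single Swin Transformer block: pre-norm window-based multi-head self-attention followed by a pre-norm MLP, each wrapped in a residual connection. First, the imports; `WindowAttention`, `DropPath`, `window_partition`, and `window_reverse` are helper modules from this repository's `layers` and `utils` packages.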
```python
import numpy as np  # plain NumPy: the attention mask below is built with
                    # in-place assignment, which jax.numpy arrays do not support
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

from layers.window_attention import WindowAttention
from utils.drop_path import DropPath
from utils.swin_window import window_partition, window_reverse
```
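The block takes tokens of shape `(batch, height * width, dim)`. When `shift_size > 0` the feature map is cyclically shifted before windowing (the SW-MSA variant), and `build()` precomputes an additive attention mask so that tokens wrapped around by the shift cannot attend to tokens they were not adjacent to in the original image.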
```python
class SwinTransformer(layers.Layer):
    def __init__(
        self,
        dim,
        num_patch,
        num_heads,
        window_size=7,
        shift_size=0,
        num_mlp=1024,
        qkv_bias=True,
        dropout_rate=0.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.dim = dim  # number of input channels
        self.num_patch = num_patch  # (height, width) of the patch grid
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size  # 0 for W-MSA, window_size // 2 for SW-MSA
        self.num_mlp = num_mlp  # hidden units in the MLP

        # If the patch grid is smaller than the window, shrink the window to
        # cover the whole grid and disable shifting. Do this before building
        # WindowAttention so its relative-position bias table matches the
        # effective window size.
        if min(self.num_patch) < self.window_size:
            self.shift_size = 0
            self.window_size = min(self.num_patch)

        self.norm1 = layers.LayerNormalization(epsilon=1e-5)
        self.attn = WindowAttention(
            dim,
            window_size=(self.window_size, self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            dropout_rate=dropout_rate,
        )
        self.drop_path = DropPath(dropout_rate) if dropout_rate > 0.0 else tf.identity
        self.norm2 = layers.LayerNormalization(epsilon=1e-5)
        self.mlp = keras.Sequential(
            [
                layers.Dense(num_mlp),
                layers.Activation(keras.activations.gelu),
                layers.Dropout(dropout_rate),
                layers.Dense(dim),
                layers.Dropout(dropout_rate),
            ]
        )
    def build(self, input_shape):
        if self.shift_size == 0:
            # Unshifted windows never mix tokens from different windows,
            # so no attention mask is needed.
            self.attn_mask = None
        else:
            height, width = self.num_patch
            h_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            w_slices = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            # Give each of the nine regions created by the cyclic shift a
            # distinct id; tokens with different ids were not neighbors in
            # the unshifted image and must not attend to each other.
            mask_array = np.zeros((1, height, width, 1))
            count = 0
            for h in h_slices:
                for w in w_slices:
                    mask_array[:, h, w, :] = count
                    count += 1
            mask_array = tf.convert_to_tensor(mask_array)

            # Partition the mask into windows and turn region-id differences
            # into additive attention biases: 0 within a region, -100 across
            # regions.
            mask_windows = window_partition(mask_array, self.window_size)
            mask_windows = tf.reshape(
                mask_windows, shape=[-1, self.window_size * self.window_size]
            )
            attn_mask = tf.expand_dims(mask_windows, axis=1) - tf.expand_dims(
                mask_windows, axis=2
            )
            attn_mask = tf.where(attn_mask != 0, -100.0, attn_mask)
            attn_mask = tf.where(attn_mask == 0, 0.0, attn_mask)
            self.attn_mask = tf.Variable(initial_value=attn_mask, trainable=False)
    def call(self, x):
        height, width = self.num_patch
        _, num_patches_before, channels = x.shape
        x_skip = x
        x = self.norm1(x)
        x = tf.reshape(x, shape=(-1, height, width, channels))

        # Cyclically shift the feature map so the new window partition
        # straddles the boundaries of the previous one.
        if self.shift_size > 0:
            shifted_x = tf.roll(
                x, shift=[-self.shift_size, -self.shift_size], axis=[1, 2]
            )
        else:
            shifted_x = x

        # Window the feature map and run attention inside each window.
        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = tf.reshape(
            x_windows, shape=(-1, self.window_size * self.window_size, channels)
        )
        attn_windows = self.attn(x_windows, mask=self.attn_mask)

        # Merge the windows back and undo the cyclic shift.
        attn_windows = tf.reshape(
            attn_windows, shape=(-1, self.window_size, self.window_size, channels)
        )
        shifted_x = window_reverse(
            attn_windows, self.window_size, height, width, channels
        )
        if self.shift_size > 0:
            x = tf.roll(
                shifted_x, shift=[self.shift_size, self.shift_size], axis=[1, 2]
            )
        else:
            x = shifted_x

        # Residual connection around the attention branch, then around the MLP.
        x = tf.reshape(x, shape=(-1, height * width, channels))
        x = self.drop_path(x)
        x = tf.cast(x_skip, dtype=tf.float32) + tf.cast(x, dtype=tf.float32)

        x_skip = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = self.drop_path(x)
        x = tf.cast(x_skip, dtype=tf.float32) + tf.cast(x, dtype=tf.float32)
        return x
```
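As a quick sanity check, the layer can be run on random tokens. This is a minimal sketch assuming the repository's `WindowAttention`, `window_partition`, and `window_reverse` helpers behave like the ones in the standard Keras Swin Transformer example; the sizes (an 8×8 patch grid, 64 channels, 2×2 windows) are illustrative only.

```python
# Minimal smoke test with illustrative sizes (assumes the helpers above).
patch_grid = (8, 8)  # height, width of the patch grid
tokens = tf.random.normal((4, patch_grid[0] * patch_grid[1], 64))

block = SwinTransformer(
    dim=64,
    num_patch=patch_grid,
    num_heads=4,
    window_size=2,
    shift_size=1,  # > 0 selects the shifted-window (SW-MSA) variant
    num_mlp=256,
)
out = block(tokens)
print(out.shape)  # (4, 64, 64): same shape as the input
```

In a full model these blocks are stacked in pairs, one with `shift_size=0` and one with `shift_size=window_size // 2`, so that information can flow across window boundaries from one block to the next.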