import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

from layers.swin_blocks import SwinTransformer
# model_utils supplies the hyperparameters referenced below (params,
# patch_size, embed_dim, num_patch_x/y, shift_size, num_heads,
# window_size, num_mlp, qkv_bias, dropout_rate)
from utils.model_utils import *
from utils.patch import PatchEmbedding
from utils.patch import PatchExtract
from utils.patch import PatchMerging

class HybridSwinTransformer(keras.Model):
    def __init__(self, model_name, **kwargs):
        super().__init__(name=model_name, **kwargs)
        # base CNN model (EfficientNetB0, trained from scratch)
        base = keras.applications.EfficientNetB0(
            include_top=False,
            weights=None,
            input_tensor=keras.Input((params.image_size, params.image_size, 3)),
        )
        # re-wrap the base model so it also exposes an intermediate feature
        # map with a shape compatible with the transformer's input
        self.new_base = keras.Model(
            [base.inputs],
            [base.get_layer("block6a_expand_activation").output, base.output],
            name="efficientnet",
        )
        # swin transformer components: patch extraction, embedding, merging
        self.patch_extract = PatchExtract(patch_size)
        self.patch_embedds = PatchEmbedding(num_patch_x * num_patch_y, embed_dim)
        self.patch_merging = PatchMerging(
            (num_patch_x, num_patch_y), embed_dim=embed_dim
        )
        # container for the stacked swin blocks; shift_size=0 gives regular
        # windowed attention, larger values give shifted windows
        self.swin_sequences = keras.Sequential(name="swin_blocks")
        for i in range(shift_size):
            self.swin_sequences.add(
                SwinTransformer(
                    dim=embed_dim,
                    num_patch=(num_patch_x, num_patch_y),
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=i,
                    num_mlp=num_mlp,
                    qkv_bias=qkv_bias,
                    dropout_rate=dropout_rate,
                )
            )
        # head of the swin branch
        self.swin_head = keras.Sequential(
            [
                layers.GlobalAveragePooling1D(),
                layers.AlphaDropout(0.5),
                layers.BatchNormalization(),
            ],
            name="swin_head",
        )
        # head of the CNN branch
        self.conv_head = keras.Sequential(
            [
                layers.GlobalAveragePooling2D(),
                layers.AlphaDropout(0.5),
            ],
            name="conv_head",
        )
        # classifier over the concatenated branch features; float32 output
        # keeps the logits numerically stable under mixed precision
        self.classifier = layers.Dense(
            params.class_number, activation=None, dtype="float32"
        )
        # build the weights eagerly so trainable_variables is populated
        # for subclasses such as GradientAccumulation
        self.build_graph()
    def call(self, inputs, training=None, **kwargs):
        x, base_gcam_top = self.new_base(inputs)
        x = self.patch_extract(x)
        x = self.patch_embedds(x)
        x = self.swin_sequences(tf.cast(x, dtype=tf.float32))
        x, swin_gcam_top = self.patch_merging(x)
        swin_top = self.swin_head(x)
        conv_top = self.conv_head(base_gcam_top)
        preds = self.classifier(tf.concat([swin_top, conv_top], axis=-1))
        if training:  # training phase: predictions only
            return preds
        else:  # inference phase: also return feature maps for Grad-CAM
            return preds, base_gcam_top, swin_gcam_top

    def build_graph(self):
        x = keras.Input(shape=(params.image_size, params.image_size, 3))
        return keras.Model(inputs=[x], outputs=self.call(x))
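
# A minimal forward-pass sketch (hedged: it assumes `params` and the swin
# hyperparameters referenced above really are supplied by utils.model_utils):
#
#     model = HybridSwinTransformer(model_name="hybrid_swin")
#     dummy = tf.random.normal((1, params.image_size, params.image_size, 3))
#     preds, base_gcam_top, swin_gcam_top = model(dummy, training=False)
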
class GradientAccumulation(HybridSwinTransformer):
    """ref: https://gist.github.com/innat/ba6740293e7b7b227829790686f2119c"""

    def __init__(self, n_gradients, **kwargs):
        super().__init__(**kwargs)
        self.n_gradients = tf.constant(n_gradients, dtype=tf.int32)
        self.n_acum_step = tf.Variable(0, dtype=tf.int32, trainable=False)
        # one zero-initialized accumulator per trainable variable; the parent
        # builds its weights in __init__, so trainable_variables is populated
        self.gradient_accumulation = [
            tf.Variable(tf.zeros_like(v, dtype=tf.float32), trainable=False)
            for v in self.trainable_variables
        ]
    def train_step(self, data):
        # track the accumulation step
        self.n_acum_step.assign_add(1)
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # forward pass
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        # calculate batch gradients
        gradients = tape.gradient(loss, self.trainable_variables)
        # accumulate batch gradients
        for i in range(len(self.gradient_accumulation)):
            self.gradient_accumulation[i].assign_add(gradients[i])
        # once n_acum_step reaches n_gradients, apply the accumulated
        # gradients to update the variables; otherwise do nothing
        tf.cond(
            tf.equal(self.n_acum_step, self.n_gradients),
            self.apply_accu_gradients,
            lambda: None,
        )
        # Return a dict mapping metric names to current value.
        # Note that it will include the loss (tracked in self.metrics).
        self.compiled_metrics.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}
    def apply_accu_gradients(self):
        # apply the accumulated gradients to update the weights
        self.optimizer.apply_gradients(
            zip(self.gradient_accumulation, self.trainable_variables)
        )
        # reset the accumulation step and zero out the accumulators
        self.n_acum_step.assign(0)
        for i in range(len(self.gradient_accumulation)):
            self.gradient_accumulation[i].assign(
                tf.zeros_like(self.trainable_variables[i], dtype=tf.float32)
            )
    def test_step(self, data):
        # Unpack the data
        x, y = data
        # Compute predictions; inference mode also returns the Grad-CAM feature maps
        y_pred, base_gcam_top, swin_gcam_top = self(x, training=False)
        # Updates the metrics tracking the loss
        self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        # Update the metrics.
        self.compiled_metrics.update_state(y, y_pred)
        # Return a dict mapping metric names to current value.
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}
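
# A hedged end-to-end training sketch, not part of the original pipeline:
# the optimizer, loss, epoch count, and the synthetic dataset below are
# illustrative assumptions. `from_logits=True` matches the classifier,
# which emits raw logits (activation=None). With n_gradients=4, gradients
# are applied every 4 steps, for a 4x larger effective batch size.
if __name__ == "__main__":
    model = GradientAccumulation(n_gradients=4, model_name="hybrid_swin")
    model.compile(
        optimizer=keras.optimizers.Adam(1e-4),
        loss=keras.losses.CategoricalCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )
    # synthetic stand-in data, just to exercise train_step/test_step
    images = tf.random.normal((8, params.image_size, params.image_size, 3))
    labels = tf.one_hot(
        tf.random.uniform((8,), maxval=params.class_number, dtype=tf.int32),
        params.class_number,
    )
    ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(2)
    model.fit(ds, epochs=1)
    model.evaluate(ds)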