| """ | |
| Title: Consistency training with supervision | |
| Author: [Sayak Paul](https://twitter.com/RisingSayak) | |
| Date created: 2021/04/13 | |
| Last modified: 2021/04/19 | |
| Description: Training with consistency regularization for robustness against data distribution shifts. | |
| Accelerator: GPU | |
| """ | |
| """ | |
| Deep learning models excel in many image recognition tasks when the data is independent | |
| and identically distributed (i.i.d.). However, they can suffer from performance | |
| degradation caused by subtle distribution shifts in the input data (such as random | |
| noise, contrast change, and blurring). So, naturally, there arises a question of | |
| why. As discussed in [A Fourier Perspective on Model Robustness in Computer Vision](https://arxiv.org/pdf/1906.08988.pdf)), | |
| there's no reason for deep learning models to be robust against such shifts. Standard | |
| model training procedures (such as standard image classification training workflows) | |
| *don't* enable a model to learn beyond what's fed to it in the form of training data. | |
| In this example, we will be training an image classification model enforcing a sense of | |
| *consistency* inside it by doing the following: | |
| * Train a standard image classification model. | |
| * Train an _equal or larger_ model on a noisy version of the dataset (augmented using | |
| [RandAugment](https://arxiv.org/abs/1909.13719)). | |
| * To do this, we will first obtain predictions of the previous model on the clean images | |
| of the dataset. | |
| * We will then use these predictions and train the second model to match these | |
| predictions on the noisy variant of the same images. This is identical to the workflow of | |
| [*Knowledge Distillation*](https://keras.io/examples/vision/knowledge_distillation/) but | |
| since the student model is equal or larger in size this process is also referred to as | |
| ***Self-Training***. | |
| This overall training workflow finds its roots in works like | |
| [FixMatch](https://arxiv.org/abs/2001.07685), [Unsupervised Data Augmentation for Consistency Training](https://arxiv.org/abs/1904.12848), | |
| and [Noisy Student Training](https://arxiv.org/abs/1911.04252). Since this training | |
| process encourages a model yield consistent predictions for clean as well as noisy | |
| images, it's often referred to as *consistency training* or *training with consistency | |
| regularization*. Although the example focuses on using consistency training to enhance | |
| the robustness of models to common corruptions this example can also serve a template | |
| for performing _weakly supervised learning_. | |
| This example requires TensorFlow 2.4 or higher, as well as TensorFlow Hub and TensorFlow | |
| Models, which can be installed using the following command: | |
| """ | |
| """shell | |
| pip install -q tf-models-official tensorflow-addons | |
| """ | |
| """ | |
| ## Imports and setup | |
| """ | |
| from official.vision.image_classification.augment import RandAugment | |
| from tensorflow.keras import layers | |
| import tensorflow as tf | |
| import tensorflow_addons as tfa | |
| import matplotlib.pyplot as plt | |
| tf.random.set_seed(42) | |
| """ | |
| ## Define hyperparameters | |
| """ | |
| AUTO = tf.data.AUTOTUNE | |
| BATCH_SIZE = 128 | |
| EPOCHS = 5 | |
| CROP_TO = 72 | |
| RESIZE_TO = 96 | |
| """ | |
| ## Load the CIFAR-10 dataset | |
| """ | |
| (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data() | |
| val_samples = 49500 | |
| new_train_x, new_y_train = x_train[: val_samples + 1], y_train[: val_samples + 1] | |
| val_x, val_y = x_train[val_samples:], y_train[val_samples:] | |
| """ | |
| ## Create TensorFlow `Dataset` objects | |
| """ | |
| # Initialize `RandAugment` object with 2 layers of | |
| # augmentation transforms and strength of 9. | |
| augmenter = RandAugment(num_layers=2, magnitude=9) | |
| """ | |
| For training the teacher model, we will only be using two geometric augmentation | |
| transforms: random horizontal flip and random crop. | |
| """ | |
def preprocess_train(image, label, noisy=True):
    image = tf.image.random_flip_left_right(image)
    # We first resize the original image to a larger dimension
    # and then we take random crops from it.
    image = tf.image.resize(image, [RESIZE_TO, RESIZE_TO])
    image = tf.image.random_crop(image, [CROP_TO, CROP_TO, 3])
    if noisy:
        image = augmenter.distort(image)
    return image, label


def preprocess_test(image, label):
    image = tf.image.resize(image, [CROP_TO, CROP_TO])
    return image, label


train_ds = tf.data.Dataset.from_tensor_slices((new_train_x, new_y_train))
validation_ds = tf.data.Dataset.from_tensor_slices((val_x, val_y))
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
| """ | |
| We make sure `train_clean_ds` and `train_noisy_ds` are shuffled using the *same* seed to | |
| ensure their orders are exactly the same. This will be helpful during training the | |
| student model. | |
| """ | |
# This dataset will be used to train the first model.
train_clean_ds = (
    train_ds.shuffle(BATCH_SIZE * 10, seed=42)
    .map(lambda x, y: (preprocess_train(x, y, noisy=False)), num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)

# This prepares the `Dataset` object to use RandAugment.
train_noisy_ds = (
    train_ds.shuffle(BATCH_SIZE * 10, seed=42)
    .map(preprocess_train, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)

validation_ds = (
    validation_ds.map(preprocess_test, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)

test_ds = (
    test_ds.map(preprocess_test, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)

# This dataset will be used to train the second model.
consistency_training_ds = tf.data.Dataset.zip((train_clean_ds, train_noisy_ds))
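"""
As a quick sanity check (a minimal sketch that relies on the shared shuffle seed
described above), the labels coming out of the zipped clean and noisy pipelines
should line up batch for batch:
"""

(clean_images, clean_labels), (noisy_images, noisy_labels) = next(
    iter(consistency_training_ds)
)
assert tf.reduce_all(clean_labels == noisy_labels)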
| """ | |
| ## Visualize the datasets | |
| """ | |
| sample_images, sample_labels = next(iter(train_clean_ds)) | |
| plt.figure(figsize=(10, 10)) | |
| for i, image in enumerate(sample_images[:9]): | |
| ax = plt.subplot(3, 3, i + 1) | |
| plt.imshow(image.numpy().astype("int")) | |
| plt.axis("off") | |
| sample_images, sample_labels = next(iter(train_noisy_ds)) | |
| plt.figure(figsize=(10, 10)) | |
| for i, image in enumerate(sample_images[:9]): | |
| ax = plt.subplot(3, 3, i + 1) | |
| plt.imshow(image.numpy().astype("int")) | |
| plt.axis("off") | |
| """ | |
| ## Define a model building utility function | |
| We now define our model building utility. Our model is based on the [ResNet50V2 architecture](https://arxiv.org/abs/1603.05027). | |
| """ | |
| def get_training_model(num_classes=10): | |
| resnet50_v2 = tf.keras.applications.ResNet50V2( | |
| weights=None, | |
| include_top=False, | |
| input_shape=(CROP_TO, CROP_TO, 3), | |
| ) | |
| model = tf.keras.Sequential( | |
| [ | |
| layers.Input((CROP_TO, CROP_TO, 3)), | |
| layers.Rescaling(scale=1.0 / 127.5, offset=-1), | |
| resnet50_v2, | |
| layers.GlobalAveragePooling2D(), | |
| layers.Dense(num_classes), | |
| ] | |
| ) | |
| return model | |
| """ | |
| In the interest of reproducibility, we serialize the initial random weights of the | |
| teacher network. | |
| """ | |
| initial_teacher_model = get_training_model() | |
| initial_teacher_model.save_weights("initial_teacher_model.h5") | |
| """ | |
| ## Train the teacher model | |
| As noted in Noisy Student Training, if the teacher model is trained with *geometric | |
| ensembling* and when the student model is forced to mimic that, it leads to better | |
| performance. The original work uses [Stochastic Depth](https://arxiv.org/abs/1603.09382) | |
| and [Dropout](https://jmlr.org/papers/v15/srivastava14a.html) to bring in the ensembling | |
| part but for this example, we will use [Stochastic Weight Averaging](https://arxiv.org/abs/1803.05407) | |
| (SWA) which also resembles geometric ensembling. | |
| """ | |
# Define the callbacks.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(patience=3)
early_stopping = tf.keras.callbacks.EarlyStopping(
    patience=10, restore_best_weights=True
)

# Initialize SWA from TensorFlow Addons.
SWA = tfa.optimizers.SWA

# Compile and train the teacher model.
teacher_model = get_training_model()
teacher_model.load_weights("initial_teacher_model.h5")
teacher_model.compile(
    # Notice that we are wrapping our optimizer within SWA.
    optimizer=SWA(tf.keras.optimizers.Adam()),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
history = teacher_model.fit(
    train_clean_ds,
    epochs=EPOCHS,
    validation_data=validation_ds,
    callbacks=[reduce_lr, early_stopping],
)

# Evaluate the teacher model on the test set.
_, acc = teacher_model.evaluate(test_ds, verbose=0)
print(f"Test accuracy: {acc * 100}%")
| """ | |
| ## Define a self-training utility | |
| For this part, we will borrow the `Distiller` class from [this Keras Example](https://keras.io/examples/vision/knowledge_distillation/). | |
| """ | |
| # Majority of the code is taken from: | |
| # https://keras.io/examples/vision/knowledge_distillation/ | |
class SelfTrainer(tf.keras.Model):
    def __init__(self, student, teacher):
        super().__init__()
        self.student = student
        self.teacher = teacher

    def compile(
        self,
        optimizer,
        metrics,
        student_loss_fn,
        distillation_loss_fn,
        temperature=3,
    ):
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.temperature = temperature

    def train_step(self, data):
        # Since our dataset is a zip of two independent datasets,
        # after initially parsing them, we segregate the
        # respective images and labels next.
        clean_ds, noisy_ds = data
        clean_images, _ = clean_ds
        noisy_images, y = noisy_ds

        # Forward pass of teacher
        teacher_predictions = self.teacher(clean_images, training=False)

        with tf.GradientTape() as tape:
            # Forward pass of student
            student_predictions = self.student(noisy_images, training=True)

            # Compute losses
            student_loss = self.student_loss_fn(y, student_predictions)
            distillation_loss = self.distillation_loss_fn(
                tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
                tf.nn.softmax(student_predictions / self.temperature, axis=1),
            )
            total_loss = (student_loss + distillation_loss) / 2

        # Compute gradients
        trainable_vars = self.student.trainable_variables
        gradients = tape.gradient(total_loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the metrics configured in `compile()`
        self.compiled_metrics.update_state(
            y, tf.nn.softmax(student_predictions, axis=1)
        )

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        results.update({"total_loss": total_loss})
        return results

    def test_step(self, data):
        # During inference, we only pass a dataset consisting of images and labels.
        x, y = data

        # Compute predictions
        y_prediction = self.student(x, training=False)

        # Update the metrics
        self.compiled_metrics.update_state(y, tf.nn.softmax(y_prediction, axis=1))

        # Return a dict of performance
        results = {m.name: m.result() for m in self.metrics}
        return results
| """ | |
| The only difference in this implementation is the way loss is being calculated. **Instead | |
| of weighted the distillation loss and student loss differently we are taking their | |
| average following Noisy Student Training**. | |
| """ | |
| """ | |
| ## Train the student model | |
| """ | |
| # Define the callbacks. | |
| # We are using a larger decay factor to stabilize the training. | |
| reduce_lr = tf.keras.callbacks.ReduceLROnPlateau( | |
| patience=3, factor=0.5, monitor="val_accuracy" | |
| ) | |
| early_stopping = tf.keras.callbacks.EarlyStopping( | |
| patience=10, restore_best_weights=True, monitor="val_accuracy" | |
| ) | |
| # Compile and train the student model. | |
| self_trainer = SelfTrainer(student=get_training_model(), teacher=teacher_model) | |
| self_trainer.compile( | |
| # Notice we are *not* using SWA here. | |
| optimizer="adam", | |
| metrics=["accuracy"], | |
| student_loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), | |
| distillation_loss_fn=tf.keras.losses.KLDivergence(), | |
| temperature=10, | |
| ) | |
| history = self_trainer.fit( | |
| consistency_training_ds, | |
| epochs=EPOCHS, | |
| validation_data=validation_ds, | |
| callbacks=[reduce_lr, early_stopping], | |
| ) | |
| # Evaluate the student model. | |
| acc = self_trainer.evaluate(test_ds, verbose=0) | |
| print(f"Test accuracy from student model: {acc*100}%") | |
| """ | |
| ## Assess the robustness of the models | |
| A standard benchmark of assessing the robustness of vision models is to record their | |
| performance on corrupted datasets like ImageNet-C and CIFAR-10-C both of which were | |
| proposed in [Benchmarking Neural Network Robustness to Common Corruptions and | |
| Perturbations](https://arxiv.org/abs/1903.12261). For this example, we will be using the | |
| CIFAR-10-C dataset which has 19 different corruptions on 5 different severity levels. To | |
| assess the robustness of the models on this dataset, we will do the following: | |
| * Run the pre-trained models on the highest level of severities and obtain the top-1 | |
| accuracies. | |
| * Compute the mean top-1 accuracy. | |
| For the purpose of this example, we won't be going through these steps. This is why we | |
| trained the models for only 5 epochs. You can check out [this | |
| repository](https://github.com/sayakpaul/Consistency-Training-with-Supervision) that | |
| demonstrates the full-scale training experiments and also the aforementioned assessment. | |
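# A minimal sketch of the first step, assuming the TFDS `cifar10_corrupted`
# dataset (its config names follow a `<corruption>_<severity>` pattern). We only
# define the helper here and never call it, since the full assessment is out of
# scope for this example.
def evaluate_on_cifar10_c(model, corruption, severity=5):
    import tensorflow_datasets as tfds

    dataset = tfds.load(
        f"cifar10_corrupted/{corruption}_{severity}", split="test", as_supervised=True
    )
    dataset = dataset.map(preprocess_test, num_parallel_calls=AUTO).batch(BATCH_SIZE)
    return model.evaluate(dataset, verbose=0)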
"""
The figure below presents an executive summary of that assessment:

![](https://i.ibb.co/vgdJjSx/image.png)

**Mean Top-1** results stand for the CIFAR-10-C dataset and **Test Top-1** results stand
for the CIFAR-10 test set. It's clear that consistency training not only enhances model
robustness but also improves standard test performance.
"""