"""
Title: Self-supervised contrastive learning with NNCLR
Author: [Rishit Dagli](https://twitter.com/rishit_dagli)
Date created: 2021/09/13
Last modified: 2024/01/22
Description: Implementation of NNCLR, a self-supervised learning method for computer vision.
Accelerator: GPU
"""
"""
## Introduction
### Self-supervised learning
Self-supervised representation learning aims to obtain robust representations of samples
from raw data without expensive labels or annotations. Early methods in this field
focused on defining pretext tasks: surrogate tasks on a domain with ample weakly
supervised labels. Encoders trained to solve such tasks are expected to
learn general features that might be useful for other downstream tasks requiring
expensive annotations like image classification.
### Contrastive Learning
A broad category of self-supervised learning techniques are those that use *contrastive
losses*, which have been used in a wide range of computer vision applications like
[image similarity](https://www.jmlr.org/papers/v11/chechik10a.html),
[dimensionality reduction (DrLIM)](http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf)
and [face verification/identification](https://openaccess.thecvf.com/content_cvpr_2015/html/Schroff_FaceNet_A_Unified_2015_CVPR_paper.html).
These methods learn a latent space that clusters positive samples together while
pushing apart negative samples.
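For intuition: a canonical contrastive objective is the InfoNCE loss. Given an embedding
$z_i$, its positive counterpart $z_i^+$, and a batch of $n$ candidates, it can be written
as (a standard formulation with temperature $\tau$, shown here for reference):

$$\mathcal{L}_i = -\log \frac{\exp(z_i \cdot z_i^{+} / \tau)}{\sum_{k=1}^{n} \exp(z_i \cdot z_k^{+} / \tau)}$$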
### NNCLR
In this example, we implement NNCLR as proposed in the paper
[With a Little Help from My Friends: Nearest-Neighbor Contrastive Learning of Visual Representations](https://arxiv.org/abs/2104.14548),
by Google Research and DeepMind.
NNCLR learns self-supervised representations that go beyond single-instance positives, which
allows for learning better features that are invariant to different viewpoints, deformations,
and even intra-class variations.
Clustering-based methods offer a great approach to go beyond single-instance positives,
but assuming the entire cluster to be positives could hurt performance due to early
over-generalization. Instead, NNCLR uses nearest neighbors in the learned representation
space as positives.
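Following the paper, NNCLR keeps the usual contrastive objective but replaces $z_i$ with
its nearest neighbour from a support set $Q$ of past embeddings:

$$\mathcal{L}_i^{\mathrm{NNCLR}} = -\log \frac{\exp(\mathrm{NN}(z_i, Q) \cdot z_i^{+} / \tau)}{\sum_{k=1}^{n} \exp(\mathrm{NN}(z_i, Q) \cdot z_k^{+} / \tau)}$$

where $\mathrm{NN}(z, Q) = \arg\min_{q \in Q} \lVert z - q \rVert_2$ is the nearest
neighbour of $z$ in the support set.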
In addition, NNCLR increases the performance of existing contrastive learning methods like
[SimCLR](https://arxiv.org/abs/2002.05709) ([Keras Example](https://keras.io/examples/vision/semisupervised_simclr))
and reduces the reliance of self-supervised methods on data augmentation strategies.
Here is a great visualization by the paper authors showing how NNCLR builds on ideas from
SimCLR:

We can see that SimCLR uses two views of the same image as the positive pair. These two
views, produced using random data augmentations, are fed through an encoder to obtain the
positive embedding pair. NNCLR instead keeps a _support set_ of embeddings representing
the full data distribution, and forms the positive pairs using nearest neighbours. The
support set acts as memory during training, similar to the first-in-first-out queue used
in [MoCo](https://arxiv.org/abs/1911.05722).
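As a minimal sketch of the support-set mechanics (the actual update lives in
`NNCLR.contrastive_loss` below; the array sizes here are made up for illustration):

```python
import numpy as np

queue = np.random.normal(size=(8, 4))  # (queue_size, feature_dim)
new_projections = np.random.normal(size=(2, 4))  # newest batch of embeddings

# First-in-first-out update: enqueue the newest embeddings at the front,
# drop the oldest ones off the back, keeping the size constant.
queue = np.concatenate([new_projections, queue[: -len(new_projections)]], axis=0)
assert queue.shape == (8, 4)
```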
This example requires `tensorflow_datasets`, which can
be installed with this command:
"""
"""shell
pip install tensorflow-datasets
"""
"""
## Setup
"""
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
import keras_cv
from keras import ops
from keras import layers
"""
## Hyperparameters
A greater `queue_size` most likely means better performance as shown in the original
paper, but introduces significant computational overhead. The authors show that the best
results of NNCLR are achieved with a queue size of 98,304 (the largest `queue_size` they
experimented on). We here use 10,000 to show a working example.
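As a rough back-of-the-envelope estimate of what the queue costs (the batch size of
~2,100 follows from the dataset sizes and `steps_per_epoch` below; all numbers are
approximations, not measurements):

```python
# Cost sketch for the nearest-neighbour search against the support set.
batch, dim, queue = 2100, 128, 10_000  # sizes as configured in this example
queue_mb = queue * dim * 4 / 1e6  # float32 queue memory, ~5 MB
sim_matmul_flops = 2 * batch * dim * queue  # per similarity matrix, ~5.4 GFLOPs
# At the paper's queue size of 98,304, the same matmul is ~10x more expensive.
```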
"""
AUTOTUNE = tf.data.AUTOTUNE
shuffle_buffer = 5000
# The below two values are taken from https://www.tensorflow.org/datasets/catalog/stl10
labelled_train_images = 5000
unlabelled_images = 100000
temperature = 0.1
queue_size = 10000
contrastive_augmenter = {
    "brightness": 0.5,
    "name": "contrastive_augmenter",
    "scale": (0.2, 1.0),
}
classification_augmenter = {
    "brightness": 0.2,
    "name": "classification_augmenter",
    "scale": (0.5, 1.0),
}
input_shape = (96, 96, 3)
width = 128
num_epochs = 5 # Use 25 for better results
steps_per_epoch = 50 # Use 200 for better results
"""
## Load the Dataset
We load the [STL-10](http://ai.stanford.edu/~acoates/stl10/) dataset from
TensorFlow Datasets, an image recognition dataset for developing unsupervised
feature learning, deep learning, and self-taught learning algorithms. It is inspired by the
CIFAR-10 dataset, with some modifications.
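If you want to double-check the split sizes hard-coded in the hyperparameters above, you
can read them off the dataset metadata (a quick sanity check; `tfds` may need to fetch
the metadata first):

```python
info = tfds.builder("stl10").info
print(info.splits["train"].num_examples)  # 5000 labeled training images
print(info.splits["unlabelled"].num_examples)  # 100000 unlabeled images
```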
"""
dataset_name = "stl10"
def prepare_dataset():
    unlabeled_batch_size = unlabelled_images // steps_per_epoch
    labeled_batch_size = labelled_train_images // steps_per_epoch
    batch_size = unlabeled_batch_size + labeled_batch_size

    unlabeled_train_dataset = (
        tfds.load(
            dataset_name, split="unlabelled", as_supervised=True, shuffle_files=True
        )
        .shuffle(buffer_size=shuffle_buffer)
        .batch(unlabeled_batch_size, drop_remainder=True)
    )
    labeled_train_dataset = (
        tfds.load(dataset_name, split="train", as_supervised=True, shuffle_files=True)
        .shuffle(buffer_size=shuffle_buffer)
        .batch(labeled_batch_size, drop_remainder=True)
    )
    test_dataset = (
        tfds.load(dataset_name, split="test", as_supervised=True)
        .batch(batch_size)
        .prefetch(buffer_size=AUTOTUNE)
    )
    train_dataset = tf.data.Dataset.zip(
        (unlabeled_train_dataset, labeled_train_dataset)
    ).prefetch(buffer_size=AUTOTUNE)

    return batch_size, train_dataset, labeled_train_dataset, test_dataset


batch_size, train_dataset, labeled_train_dataset, test_dataset = prepare_dataset()
"""
## Augmentations
Other self-supervised techniques like [SimCLR](https://arxiv.org/abs/2002.05709),
[BYOL](https://arxiv.org/abs/2006.07733), [SwAV](https://arxiv.org/abs/2006.09882) etc.
rely heavily on a well-designed data augmentation pipeline to get the best performance.
However, NNCLR is _less_ dependent on complex augmentations as nearest-neighbors already
provide richness in sample variations. A few common techniques often included
in augmentation pipelines are:
- Random resized crops
- Multiple color distortions
- Gaussian blur
Since NNCLR is less dependent on complex augmentations, we will only use random
crops and random brightness for augmenting the input images.
"""
"""
### Prepare augmentation module
"""
def augmenter(brightness, name, scale):
    return keras.Sequential(
        [
            layers.Input(shape=input_shape),
            layers.Rescaling(1 / 255),
            layers.RandomFlip("horizontal"),
            keras_cv.layers.RandomCropAndResize(
                target_size=(input_shape[0], input_shape[1]),
                crop_area_factor=scale,
                aspect_ratio_factor=(3 / 4, 4 / 3),
            ),
            keras_cv.layers.RandomBrightness(factor=brightness, value_range=(0.0, 1.0)),
        ],
        name=name,
    )
"""
### Encoder architecture
Using ResNet-50 as the encoder architecture is standard in the literature; the original
paper uses ResNet-50 and spatially averages its outputs. However, keep in
mind that more powerful models will not only increase training time but will also
require more memory and will limit the maximal batch size you can use. For the purpose of
this example, we just use four convolutional layers.
"""
def encoder():
    return keras.Sequential(
        [
            layers.Input(shape=input_shape),
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            layers.Flatten(),
            layers.Dense(width, activation="relu"),
        ],
        name="encoder",
    )
"""
## The NNCLR model for contrastive pre-training
We train an encoder on unlabeled images with a contrastive loss. A nonlinear projection
head is attached to the top of the encoder, as it improves the quality of representations
of the encoder.
"""
class NNCLR(keras.Model):
    def __init__(
        self,
        temperature,
        queue_size,
    ):
        super().__init__()
        self.probe_accuracy = keras.metrics.SparseCategoricalAccuracy()
        self.correlation_accuracy = keras.metrics.SparseCategoricalAccuracy()
        self.contrastive_accuracy = keras.metrics.SparseCategoricalAccuracy()
        self.probe_loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        self.contrastive_augmenter = augmenter(**contrastive_augmenter)
        self.classification_augmenter = augmenter(**classification_augmenter)
        self.encoder = encoder()
        self.projection_head = keras.Sequential(
            [
                layers.Input(shape=(width,)),
                layers.Dense(width, activation="relu"),
                layers.Dense(width),
            ],
            name="projection_head",
        )
        self.linear_probe = keras.Sequential(
            [layers.Input(shape=(width,)), layers.Dense(10)], name="linear_probe"
        )
        self.temperature = temperature

        feature_dimensions = self.encoder.output_shape[1]
        self.feature_queue = keras.Variable(
            keras.utils.normalize(
                keras.random.normal(shape=(queue_size, feature_dimensions)),
                axis=1,
                order=2,
            ),
            trainable=False,
        )

    def compile(self, contrastive_optimizer, probe_optimizer, **kwargs):
        super().compile(**kwargs)
        self.contrastive_optimizer = contrastive_optimizer
        self.probe_optimizer = probe_optimizer
    def nearest_neighbour(self, projections):
        support_similarities = ops.matmul(
            projections, ops.transpose(self.feature_queue)
        )
        nn_projections = ops.take(
            self.feature_queue, ops.argmax(support_similarities, axis=1), axis=0
        )
        # Straight-through estimator: the forward pass uses the nearest
        # neighbour's value, while gradients flow back through `projections`.
        return projections + ops.stop_gradient(nn_projections - projections)
    def update_contrastive_accuracy(self, features_1, features_2):
        features_1 = keras.utils.normalize(features_1, axis=1, order=2)
        features_2 = keras.utils.normalize(features_2, axis=1, order=2)
        similarities = ops.matmul(features_1, ops.transpose(features_2))
        batch_size = ops.shape(features_1)[0]
        contrastive_labels = ops.arange(batch_size)
        self.contrastive_accuracy.update_state(
            ops.concatenate([contrastive_labels, contrastive_labels], axis=0),
            ops.concatenate([similarities, ops.transpose(similarities)], axis=0),
        )

    def update_correlation_accuracy(self, features_1, features_2):
        features_1 = (features_1 - ops.mean(features_1, axis=0)) / ops.std(
            features_1, axis=0
        )
        features_2 = (features_2 - ops.mean(features_2, axis=0)) / ops.std(
            features_2, axis=0
        )
        batch_size = ops.shape(features_1)[0]
        cross_correlation = (
            ops.matmul(ops.transpose(features_1), features_2) / batch_size
        )
        feature_dim = ops.shape(features_1)[1]
        correlation_labels = ops.arange(feature_dim)
        self.correlation_accuracy.update_state(
            ops.concatenate([correlation_labels, correlation_labels], axis=0),
            ops.concatenate(
                [cross_correlation, ops.transpose(cross_correlation)], axis=0
            ),
        )
    def contrastive_loss(self, projections_1, projections_2):
        # l2-normalize the projections so that dot products are cosine
        # similarities.
        projections_1 = keras.utils.normalize(projections_1, axis=1, order=2)
        projections_2 = keras.utils.normalize(projections_2, axis=1, order=2)

        similarities_1_2_1 = (
            ops.matmul(
                self.nearest_neighbour(projections_1), ops.transpose(projections_2)
            )
            / self.temperature
        )
        similarities_1_2_2 = (
            ops.matmul(
                projections_2, ops.transpose(self.nearest_neighbour(projections_1))
            )
            / self.temperature
        )
        similarities_2_1_1 = (
            ops.matmul(
                self.nearest_neighbour(projections_2), ops.transpose(projections_1)
            )
            / self.temperature
        )
        similarities_2_1_2 = (
            ops.matmul(
                projections_1, ops.transpose(self.nearest_neighbour(projections_2))
            )
            / self.temperature
        )

        batch_size = ops.shape(projections_1)[0]
        contrastive_labels = ops.arange(batch_size)
        loss = keras.losses.sparse_categorical_crossentropy(
            ops.concatenate(
                [
                    contrastive_labels,
                    contrastive_labels,
                    contrastive_labels,
                    contrastive_labels,
                ],
                axis=0,
            ),
            ops.concatenate(
                [
                    similarities_1_2_1,
                    similarities_1_2_2,
                    similarities_2_1_1,
                    similarities_2_1_2,
                ],
                axis=0,
            ),
            from_logits=True,
        )

        # Update the support set: enqueue the newest projections and dequeue
        # the oldest ones (first-in-first-out, as in MoCo).
        self.feature_queue.assign(
            ops.concatenate([projections_1, self.feature_queue[:-batch_size]], axis=0)
        )
        return loss
    def train_step(self, data):
        (unlabeled_images, _), (labeled_images, labels) = data
        images = ops.concatenate((unlabeled_images, labeled_images), axis=0)
        augmented_images_1 = self.contrastive_augmenter(images)
        augmented_images_2 = self.contrastive_augmenter(images)

        with tf.GradientTape() as tape:
            features_1 = self.encoder(augmented_images_1)
            features_2 = self.encoder(augmented_images_2)
            projections_1 = self.projection_head(features_1)
            projections_2 = self.projection_head(features_2)
            contrastive_loss = self.contrastive_loss(projections_1, projections_2)
        gradients = tape.gradient(
            contrastive_loss,
            self.encoder.trainable_weights + self.projection_head.trainable_weights,
        )
        self.contrastive_optimizer.apply_gradients(
            zip(
                gradients,
                self.encoder.trainable_weights + self.projection_head.trainable_weights,
            )
        )
        self.update_contrastive_accuracy(features_1, features_2)
        self.update_correlation_accuracy(features_1, features_2)

        preprocessed_images = self.classification_augmenter(labeled_images)
        with tf.GradientTape() as tape:
            features = self.encoder(preprocessed_images)
            class_logits = self.linear_probe(features)
            probe_loss = self.probe_loss(labels, class_logits)
        gradients = tape.gradient(probe_loss, self.linear_probe.trainable_weights)
        self.probe_optimizer.apply_gradients(
            zip(gradients, self.linear_probe.trainable_weights)
        )
        self.probe_accuracy.update_state(labels, class_logits)

        return {
            "c_loss": contrastive_loss,
            "c_acc": self.contrastive_accuracy.result(),
            "r_acc": self.correlation_accuracy.result(),
            "p_loss": probe_loss,
            "p_acc": self.probe_accuracy.result(),
        }

    def test_step(self, data):
        labeled_images, labels = data
        preprocessed_images = self.classification_augmenter(
            labeled_images, training=False
        )
        features = self.encoder(preprocessed_images, training=False)
        class_logits = self.linear_probe(features, training=False)
        probe_loss = self.probe_loss(labels, class_logits)
        self.probe_accuracy.update_state(labels, class_logits)
        return {"p_loss": probe_loss, "p_acc": self.probe_accuracy.result()}
"""
## Pre-train NNCLR
We train the network using a `temperature` of 0.1 as suggested in the paper and
a `queue_size` of 10,000 as explained earlier. We use Adam as our contrastive and probe
optimizer. For this example we train the model for only a few epochs, but it should be
trained for more epochs for better performance.
The following two metrics can be used for monitoring the pretraining performance
which we also log (taken from
[this Keras example](https://keras.io/examples/vision/semisupervised_simclr/#selfsupervised-model-for-contrastive-pretraining)):
- Contrastive accuracy: a self-supervised metric, the ratio of cases in which the
representation of an image is more similar to the representation of its differently
augmented version than to the representation of any other image in the current batch.
Self-supervised metrics can be used for hyperparameter tuning even when there are no
labeled examples.
- Linear probing accuracy: linear probing is a popular metric for evaluating
self-supervised classifiers. It is computed as the accuracy of a logistic regression
classifier trained on top of the encoder's features. In our case, this is done by
training a single dense layer on top of the frozen encoder. Note that contrary to the
traditional approach, where the classifier is trained after the pretraining phase, in
this example we train it during pretraining. This might slightly decrease its accuracy,
but that way we can monitor its value during training, which helps with experimentation
and debugging.
"""
model = NNCLR(temperature=temperature, queue_size=queue_size)
model.compile(
    contrastive_optimizer=keras.optimizers.Adam(),
    probe_optimizer=keras.optimizers.Adam(),
    jit_compile=False,
)
pretrain_history = model.fit(
    train_dataset, epochs=num_epochs, validation_data=test_dataset
)
"""
## Evaluate our model
A popular way to evaluate an SSL method in computer vision (or, for that matter, any
other pre-training method) is to learn a linear classifier on the frozen features of the
trained backbone model and evaluate the classifier on unseen images. Other evaluation
protocols fine-tune on the source dataset, or even on a target dataset with only 5% or
10% of the labels present. You can use the backbone we just trained for any downstream
task such as image classification (as we do here), segmentation, or detection, tasks for
which backbone models are usually pre-trained with supervised learning.
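For reference, here is a minimal sketch of that frozen-backbone linear evaluation,
assuming the pre-trained `model` from above (not executed here, since the fine-tuning run
below reuses the same encoder):

```python
model.encoder.trainable = False  # freeze the backbone
linear_eval_model = keras.Sequential(
    [
        layers.Input(shape=input_shape),
        layers.Rescaling(1 / 255),  # the encoder expects inputs in [0, 1]
        model.encoder,
        layers.Dense(10),  # only this classifier head is trained
    ],
    name="linear_eval_model",
)
linear_eval_model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],
)
linear_eval_model.fit(
    labeled_train_dataset, epochs=num_epochs, validation_data=test_dataset
)
# If you run this before the fine-tuning below, set
# `model.encoder.trainable = True` again afterwards.
```

Below, we instead fine-tune the whole network on the labeled data: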
"""
finetuning_model = keras.Sequential(
    [
        layers.Input(shape=input_shape),
        augmenter(**classification_augmenter),
        model.encoder,
        layers.Dense(10),
    ],
    name="finetuning_model",
)
finetuning_model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],
    jit_compile=False,
)
finetuning_history = finetuning_model.fit(
    labeled_train_dataset, epochs=num_epochs, validation_data=test_dataset
)
"""
Self-supervised learning is particularly helpful when you only have access to very
limited labeled training data but can manage to build a large corpus of unlabeled data,
as shown by previous methods like [SEER](https://arxiv.org/abs/2103.01988),
[SimCLR](https://arxiv.org/abs/2002.05709), [SwAV](https://arxiv.org/abs/2006.09882) and
more.
You should also take a look at the blog posts for these papers which neatly show that it is
possible to achieve good results with few class labels by first pretraining on a large
unlabeled dataset and then fine-tuning on a smaller labeled dataset:
- [Advancing Self-Supervised and Semi-Supervised Learning with SimCLR](https://ai.googleblog.com/2020/04/advancing-self-supervised-and-semi.html)
- [High-performance self-supervised image classification with contrastive clustering](https://ai.facebook.com/blog/high-performance-self-supervised-image-classification-with-contrastive-clustering/)
- [Self-supervised learning: The dark matter of intelligence](https://ai.facebook.com/blog/self-supervised-learning-the-dark-matter-of-intelligence/)
You are also advised to check out the [original paper](https://arxiv.org/abs/2104.14548).
*Many thanks to [Debidatta Dwibedi](https://twitter.com/debidatta) (Google Research),
primary author of the NNCLR paper, for his super-insightful reviews of this example.
This example also takes inspiration from the [SimCLR Keras Example](https://keras.io/examples/vision/semisupervised_simclr/).*
"""