| """ | |
| Title: Using the Forward-Forward Algorithm for Image Classification | |
| Author: [Suvaditya Mukherjee](https://twitter.com/halcyonrayes) | |
| Date created: 2023/01/08 | |
| Last modified: 2024/09/17 | |
| Description: Training a Dense-layer model using the Forward-Forward algorithm. | |
| Accelerator: GPU | |
| """ | |
| """ | |
| ## Introduction | |
| The following example explores how to use the Forward-Forward algorithm to perform | |
| training instead of the traditionally-used method of backpropagation, as proposed by | |
| Hinton in | |
| [The Forward-Forward Algorithm: Some Preliminary Investigations](https://www.cs.toronto.edu/~hinton/FFA13.pdf) | |
| (2022). | |
| The concept was inspired by the understanding behind | |
| [Boltzmann Machines](http://www.cs.toronto.edu/~fritz/absps/dbm.pdf). Backpropagation | |
| involves calculating the difference between actual and predicted output via a cost | |
| function to adjust network weights. On the other hand, the FF Algorithm suggests the | |
| analogy of neurons which get "excited" based on looking at a certain recognized | |
| combination of an image and its correct corresponding label. | |
| This method takes certain inspiration from the biological learning process that occurs in | |
| the cortex. A significant advantage that this method brings is the fact that | |
| backpropagation through the network does not need to be performed anymore, and that | |
| weight updates are local to the layer itself. | |
| As this is yet still an experimental method, it does not yield state-of-the-art results. | |
| But with proper tuning, it is supposed to come close to the same. | |
| Through this example, we will examine a process that allows us to implement the | |
| Forward-Forward algorithm within the layers themselves, instead of the traditional method | |
| of relying on the global loss functions and optimizers. | |
| The tutorial is structured as follows: | |
| - Perform necessary imports | |
| - Load the [MNIST dataset](http://yann.lecun.com/exdb/mnist/) | |
| - Visualize Random samples from the MNIST dataset | |
| - Define a `FFDense` Layer to override `call` and implement a custom `forwardforward` | |
| method which performs weight updates. | |
| - Define a `FFNetwork` Layer to override `train_step`, `predict` and implement 2 custom | |
| functions for per-sample prediction and overlaying labels | |
| - Convert MNIST from `NumPy` arrays to `tf.data.Dataset` | |
| - Fit the network | |
| - Visualize results | |
| - Perform inference on test samples | |
| As this example requires the customization of certain core functions with | |
| `keras.layers.Layer` and `keras.models.Model`, refer to the following resources for | |
| a primer on how to do so: | |
| - [Customizing what happens in `model.fit()`](https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit) | |
| - [Making new Layers and Models via subclassing](https://www.tensorflow.org/guide/keras/custom_layers_and_models) | |
| """ | |
| """ | |
| ## Setup imports | |
| """ | |
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
import keras
from keras import ops
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
import random
from tensorflow.compiler.tf2xla.python import xla
| """ | |
| ## Load the dataset and visualize the data | |
| We use the `keras.datasets.mnist.load_data()` utility to directly pull the MNIST dataset | |
| in the form of `NumPy` arrays. We then arrange it in the form of the train and test | |
| splits. | |
| Following loading the dataset, we select 4 random samples from within the training set | |
| and visualize them using `matplotlib.pyplot`. | |
| """ | |
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

print("4 Random Training samples and labels")
idx1, idx2, idx3, idx4 = random.sample(range(0, x_train.shape[0]), 4)

img1 = (x_train[idx1], y_train[idx1])
img2 = (x_train[idx2], y_train[idx2])
img3 = (x_train[idx3], y_train[idx3])
img4 = (x_train[idx4], y_train[idx4])

imgs = [img1, img2, img3, img4]

plt.figure(figsize=(10, 10))

for idx, item in enumerate(imgs):
    image, label = item[0], item[1]
    plt.subplot(2, 2, idx + 1)
    plt.imshow(image, cmap="gray")
    plt.title(f"Label : {label}")
plt.show()
| """ | |
| ## Define `FFDense` custom layer | |
| In this custom layer, we have a base `keras.layers.Dense` object which acts as the | |
| base `Dense` layer within. Since weight updates will happen within the layer itself, we | |
| add an `keras.optimizers.Optimizer` object that is accepted from the user. Here, we | |
| use `Adam` as our optimizer with a rather higher learning rate of `0.03`. | |
| Following the algorithm's specifics, we must set a `threshold` parameter that will be | |
| used to make the positive-negative decision in each prediction. This is set to a default | |
| of 2.0. | |
| As the epochs are localized to the layer itself, we also set a `num_epochs` parameter | |
| (defaults to 50). | |
| We override the `call` method in order to perform a normalization over the complete | |
| input space followed by running it through the base `Dense` layer as would happen in a | |
| normal `Dense` layer call. | |
| We implement the Forward-Forward algorithm which accepts 2 kinds of input tensors, each | |
| representing the positive and negative samples respectively. We write a custom training | |
| loop here with the use of `tf.GradientTape()`, within which we calculate a loss per | |
| sample by taking the distance of the prediction from the threshold to understand the | |
| error and taking its mean to get a `mean_loss` metric. | |
| With the help of `tf.GradientTape()` we calculate the gradient updates for the trainable | |
| base `Dense` layer and apply them using the layer's local optimizer. | |
| Finally, we return the `call` result as the `Dense` results of the positive and negative | |
| samples while also returning the last `mean_loss` metric and all the loss values over a | |
| certain all-epoch run. | |
| """ | |
class FFDense(keras.layers.Layer):
    """
    A custom ForwardForward-enabled Dense layer. It has an implementation of the
    Forward-Forward network internally for use.
    This layer must be used in conjunction with the `FFNetwork` model.
    """

    def __init__(
        self,
        units,
        init_optimizer,
        loss_metric,
        num_epochs=50,
        use_bias=True,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.dense = keras.layers.Dense(
            units=units,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
        )
        self.relu = keras.layers.ReLU()
        self.optimizer = init_optimizer()
        self.loss_metric = loss_metric
        self.threshold = 1.5
        self.num_epochs = num_epochs

    # We perform a normalization step before we run the input through the Dense
    # layer.

    def call(self, x):
        x_norm = ops.norm(x, ord=2, axis=1, keepdims=True)
        x_norm = x_norm + 1e-4
        x_dir = x / x_norm
        res = self.dense(x_dir)
        return self.relu(res)

    # The Forward-Forward algorithm is below. We first perform the Dense-layer
    # operation and then get a Mean Square value for all positive and negative
    # samples respectively.

    # The custom loss function finds the distance between the Mean-squared
    # result and the threshold value we set (a hyperparameter) that will define
    # whether the prediction is positive or negative in nature. Once the loss is
    # calculated, we get a mean across the entire batch combined and perform a
    # gradient calculation and optimization step. This does not technically
    # qualify as backpropagation since there is no gradient being sent to any
    # previous layer; the update is completely local in nature.

    def forward_forward(self, x_pos, x_neg):
        for i in range(self.num_epochs):
            with tf.GradientTape() as tape:
                g_pos = ops.mean(ops.power(self.call(x_pos), 2), 1)
                g_neg = ops.mean(ops.power(self.call(x_neg), 2), 1)

                loss = ops.log(
                    1
                    + ops.exp(
                        ops.concatenate(
                            [-g_pos + self.threshold, g_neg - self.threshold], 0
                        )
                    )
                )
                mean_loss = ops.cast(ops.mean(loss), dtype="float32")
                self.loss_metric.update_state([mean_loss])
            gradients = tape.gradient(mean_loss, self.dense.trainable_weights)
            self.optimizer.apply_gradients(zip(gradients, self.dense.trainable_weights))
        return (
            ops.stop_gradient(self.call(x_pos)),
            ops.stop_gradient(self.call(x_neg)),
            self.loss_metric.result(),
        )
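"""
To see the layer in isolation before wiring it into a model, here is a minimal,
hypothetical smoke test: a single `FFDense` layer run for one local epoch on random
tensors standing in for positive and negative samples. The shapes and epoch count are
illustrative only; the real training below uses the full MNIST batch and the default of
50 epochs per layer.
"""

demo_layer = FFDense(
    units=8,
    init_optimizer=lambda: keras.optimizers.Adam(learning_rate=0.03),
    loss_metric=keras.metrics.Mean(),
    num_epochs=1,  # keep the sketch cheap; the model below uses the default of 50
)
demo_pos = tf.random.uniform((4, 16))  # stand-in positive samples
demo_neg = tf.random.uniform((4, 16))  # stand-in negative samples
demo_h_pos, demo_h_neg, demo_loss = demo_layer.forward_forward(demo_pos, demo_neg)
print("Output shapes:", demo_h_pos.shape, demo_h_neg.shape, "Loss:", float(demo_loss))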
| """ | |
| ## Define the `FFNetwork` Custom Model | |
| With our custom layer defined, we also need to override the `train_step` method and | |
| define a custom `keras.models.Model` that works with our `FFDense` layer. | |
| For this algorithm, we must 'embed' the labels onto the original image. To do so, we | |
| exploit the structure of MNIST images where the top-left 10 pixels are always zeros. We | |
| use that as a label space in order to visually one-hot-encode the labels within the image | |
| itself. This action is performed by the `overlay_y_on_x` function. | |
| We break down the prediction function with a per-sample prediction function which is then | |
| called over the entire test set by the overriden `predict()` function. The prediction is | |
| performed here with the help of measuring the `excitation` of the neurons per layer for | |
| each image. This is then summed over all layers to calculate a network-wide 'goodness | |
| score'. The label with the highest 'goodness score' is then chosen as the sample | |
| prediction. | |
| The `train_step` function is overriden to act as the main controlling loop for running | |
| training on each layer as per the number of epochs per layer. | |
| """ | |
class FFNetwork(keras.Model):
    """
    A `keras.Model` that supports a `FFDense` network creation. This model
    can work for any kind of classification task. It has an internal
    implementation with some details specific to the MNIST dataset which can be
    changed as per the use-case.
    """

    # Since each layer runs gradient-calculation and optimization locally, each
    # layer has its own optimizer that we pass. As a standard choice, we pass
    # the `Adam` optimizer with a default learning rate of 0.03 as that was
    # found to be the best rate after experimentation.
    # Loss is tracked using `loss_var` and `loss_count` variables.

    def __init__(
        self,
        dims,
        init_layer_optimizer=lambda: keras.optimizers.Adam(learning_rate=0.03),
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.init_layer_optimizer = init_layer_optimizer
        self.loss_var = keras.Variable(0.0, trainable=False, dtype="float32")
        self.loss_count = keras.Variable(0.0, trainable=False, dtype="float32")
        self.layer_list = [keras.Input(shape=(dims[0],))]
        self.metrics_built = False
        for d in range(len(dims) - 1):
            self.layer_list += [
                FFDense(
                    dims[d + 1],
                    init_optimizer=self.init_layer_optimizer,
                    loss_metric=keras.metrics.Mean(),
                )
            ]

    # This function makes a dynamic change to the image wherein the labels are
    # put on top of the original image (for this example, as MNIST has 10
    # unique labels, we take the top-left corner's first 10 pixels). This
    # function returns the original data tensor with the first 10 pixels being
    # a pixel-based one-hot representation of the labels.

    def overlay_y_on_x(self, data):
        X_sample, y_sample = data
        max_sample = ops.amax(X_sample, axis=0, keepdims=True)
        max_sample = ops.cast(max_sample, dtype="float64")
        X_zeros = ops.zeros([10], dtype="float64")
        X_update = xla.dynamic_update_slice(X_zeros, max_sample, [y_sample])
        X_sample = xla.dynamic_update_slice(X_sample, X_update, [0])
        return X_sample, y_sample

    # A custom `predict_one_sample` performs predictions by passing the image
    # through the network, measuring the results produced by each layer (i.e.
    # how high/low the output values are with respect to the set threshold for
    # each label) and then finding the label with the highest values. In such
    # a case, the image is tested for its 'goodness' with all labels.

    def predict_one_sample(self, x):
        goodness_per_label = []
        x = ops.reshape(x, [ops.shape(x)[0] * ops.shape(x)[1]])
        for label in range(10):
            h, label = self.overlay_y_on_x(data=(x, label))
            h = ops.reshape(h, [-1, ops.shape(h)[0]])
            goodness = []
            for layer_idx in range(1, len(self.layer_list)):
                layer = self.layer_list[layer_idx]
                h = layer(h)
                goodness += [ops.mean(ops.power(h, 2), 1)]
            goodness_per_label += [ops.expand_dims(ops.sum(goodness, keepdims=True), 1)]
        goodness_per_label = tf.concat(goodness_per_label, 1)
        return ops.cast(ops.argmax(goodness_per_label, 1), dtype="float64")

    def predict(self, data):
        x = data
        preds = ops.vectorized_map(self.predict_one_sample, x)
        return np.asarray(preds, dtype=int)

    # This custom `train_step` function overrides the internal `train_step`
    # implementation. We take all the input image tensors, flatten them and
    # subsequently produce positive and negative samples on the images.
    # A positive sample is an image that has the right label encoded on it with
    # the `overlay_y_on_x` function. A negative sample is an image that has an
    # erroneous label present on it.
    # With the samples ready, we pass them through each `FFDense` layer and
    # perform the Forward-Forward computation on them. The returned loss is the
    # final loss value over all the layers.

    def train_step(self, data):
        x, y = data

        if not self.metrics_built:
            # Build the metrics to ensure they can be queried without erroring
            # out. We can't update the metrics' state, as we would usually do,
            # since we do not perform predictions within the train step.
            for metric in self.metrics:
                if hasattr(metric, "build"):
                    metric.build(y, y)
            self.metrics_built = True

        # Flatten op
        x = ops.reshape(x, [-1, ops.shape(x)[1] * ops.shape(x)[2]])

        x_pos, y = ops.vectorized_map(self.overlay_y_on_x, (x, y))

        random_y = tf.random.shuffle(y)
        x_neg, y = tf.map_fn(self.overlay_y_on_x, (x, random_y))

        h_pos, h_neg = x_pos, x_neg

        for idx, layer in enumerate(self.layers):
            if isinstance(layer, FFDense):
                print(f"Training layer {idx + 1} now : ")
                h_pos, h_neg, loss = layer.forward_forward(h_pos, h_neg)
                self.loss_var.assign_add(loss)
                self.loss_count.assign_add(1.0)
            else:
                print(f"Passing layer {idx + 1} now : ")
                x = layer(x)
        mean_res = ops.divide(self.loss_var, self.loss_count)
        return {"FinalLoss": mean_res}
| """ | |
| ## Convert MNIST `NumPy` arrays to `tf.data.Dataset` | |
| We now perform some preliminary processing on the `NumPy` arrays and then convert them | |
| into the `tf.data.Dataset` format which allows for optimized loading. | |
| """ | |
x_train = x_train.astype(float) / 255
x_test = x_test.astype(float) / 255
y_train = y_train.astype(int)
y_test = y_test.astype(int)

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
train_dataset = train_dataset.batch(60000)
test_dataset = test_dataset.batch(10000)
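"""
Note that the full train and test splits are each packed into a single batch, so every
`model.fit()` epoch corresponds to exactly one `train_step` call over all 60,000 training
images. A quick sanity check on the pipeline confirms the shapes; this inspection is
just for illustration and not required for training.
"""

for demo_images, demo_labels in train_dataset.take(1):
    print("Batch images:", demo_images.shape, demo_images.dtype)  # (60000, 28, 28)
    print("Batch labels:", demo_labels.shape, demo_labels.dtype)  # (60000,)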
| """ | |
| ## Fit the network and visualize results | |
| Having performed all previous set-up, we are now going to run `model.fit()` and run 250 | |
| model epochs, which will perform 50*250 epochs on each layer. We get to see the plotted loss | |
| curve as each layer is trained. | |
| """ | |
model = FFNetwork(dims=[784, 500, 500])

model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.03),
    loss="mse",
    jit_compile=False,
    metrics=[],
)

epochs = 250
history = model.fit(train_dataset, epochs=epochs)
| """ | |
| ## Perform inference and testing | |
| Having trained the model to a large extent, we now see how it performs on the | |
| test set. We calculate the Accuracy Score to understand the results closely. | |
| """ | |
preds = model.predict(ops.convert_to_tensor(x_test))

preds = preds.reshape((preds.shape[0], preds.shape[1]))

results = accuracy_score(preds, y_test)

print(f"Test Accuracy score : {results*100}%")

plt.plot(range(len(history.history["FinalLoss"])), history.history["FinalLoss"])
plt.title("Loss over training")
plt.show()
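"""
To understand the results more closely than a single accuracy number allows, a per-class
breakdown can help. The following optional sketch uses `sklearn.metrics.confusion_matrix`
(an extra import beyond the tutorial's setup) to show where the network confuses one
digit for another.
"""

from sklearn.metrics import confusion_matrix

conf_mat = confusion_matrix(y_test, preds.ravel())
print("Confusion matrix (rows = true digit, columns = predicted digit):")
print(conf_mat)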
| """ | |
| ## Conclusion | |
| This example has hereby demonstrated how the Forward-Forward algorithm works using | |
| the TensorFlow and Keras packages. While the investigation results presented by Prof. Hinton | |
| in their paper are currently still limited to smaller models and datasets like MNIST and | |
| Fashion-MNIST, subsequent results on larger models like LLMs are expected in future | |
| papers. | |
| Through the paper, Prof. Hinton has reported results of 1.36% test accuracy error with a | |
| 2000-units, 4 hidden-layer, fully-connected network run over 60 epochs (while mentioning | |
| that backpropagation takes only 20 epochs to achieve similar performance). Another run of | |
| doubling the learning rate and training for 40 epochs yields a slightly worse error rate | |
| of 1.46% | |
| The current example does not yield state-of-the-art results. But with proper tuning of | |
| the Learning Rate, model architecture (number of units in `Dense` layers, kernel | |
| activations, initializations, regularization etc.), the results can be improved | |
| to match the claims of the paper. | |
| """ | |