| """ | |
| Title: Multiclass semantic segmentation using DeepLabV3+ | |
| Author: [Soumik Rakshit](http://github.com/soumik12345) | |
| Date created: 2021/08/31 | |
| Last modified: 2024/01/05 | |
| Description: Implement DeepLabV3+ architecture for Multi-class Semantic Segmentation. | |
| Accelerator: GPU | |
| Converted to Keras 3: [Muhammad Anas Raza](https://anasrz.com) | |
| """ | |
| """ | |
| ## Introduction | |
| Semantic segmentation, with the goal to assign semantic labels to every pixel in an image, | |
| is an essential computer vision task. In this example, we implement | |
| the **DeepLabV3+** model for multi-class semantic segmentation, a fully-convolutional | |
| architecture that performs well on semantic segmentation benchmarks. | |
| ### References: | |
| - [Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1802.02611) | |
| - [Rethinking Atrous Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1706.05587) | |
| - [DeepLab: Semantic Image Segmentation with Deep Convolutional Nets, Atrous Convolution, and Fully Connected CRFs](https://arxiv.org/abs/1606.00915) | |
| """ | |
| """ | |
| ## Downloading the data | |
| We will use the [Crowd Instance-level Human Parsing Dataset](https://arxiv.org/abs/1811.12596) | |
| for training our model. The Crowd Instance-level Human Parsing (CIHP) dataset has 38,280 diverse human images. | |
| Each image in CIHP is labeled with pixel-wise annotations for 20 categories, as well as instance-level identification. | |
| This dataset can be used for the "human part segmentation" task. | |
| """ | |
import keras
from keras import layers
from keras import ops

import os
import numpy as np
from glob import glob
import cv2
from scipy.io import loadmat
import matplotlib.pyplot as plt

# For data preprocessing
from tensorflow import image as tf_image
from tensorflow import data as tf_data
from tensorflow import io as tf_io
| """shell | |
| gdown "1B9A9UCJYMwTL4oBEo4RZfbMZMaZhKJaz&confirm=t" | |
| unzip -q instance-level-human-parsing.zip | |
| """ | |
| """ | |
| ## Creating a TensorFlow Dataset | |
| Training on the entire CIHP dataset with 38,280 images takes a lot of time, hence we will be using | |
| a smaller subset of 200 images for training our model in this example. | |
| """ | |
IMAGE_SIZE = 512
BATCH_SIZE = 4
NUM_CLASSES = 20
DATA_DIR = "./instance-level_human_parsing/instance-level_human_parsing/Training"
NUM_TRAIN_IMAGES = 1000
NUM_VAL_IMAGES = 50

train_images = sorted(glob(os.path.join(DATA_DIR, "Images/*")))[:NUM_TRAIN_IMAGES]
train_masks = sorted(glob(os.path.join(DATA_DIR, "Category_ids/*")))[:NUM_TRAIN_IMAGES]
val_images = sorted(glob(os.path.join(DATA_DIR, "Images/*")))[
    NUM_TRAIN_IMAGES : NUM_VAL_IMAGES + NUM_TRAIN_IMAGES
]
val_masks = sorted(glob(os.path.join(DATA_DIR, "Category_ids/*")))[
    NUM_TRAIN_IMAGES : NUM_VAL_IMAGES + NUM_TRAIN_IMAGES
]
def read_image(image_path, mask=False):
    image = tf_io.read_file(image_path)
    if mask:
        # Masks store integer class ids, so resize with nearest-neighbor
        # interpolation to avoid mixing neighboring label values.
        image = tf_image.decode_png(image, channels=1)
        image.set_shape([None, None, 1])
        image = tf_image.resize(
            images=image, size=[IMAGE_SIZE, IMAGE_SIZE], method="nearest"
        )
    else:
        image = tf_image.decode_png(image, channels=3)
        image.set_shape([None, None, 3])
        image = tf_image.resize(images=image, size=[IMAGE_SIZE, IMAGE_SIZE])
    return image
def load_data(image_list, mask_list):
    image = read_image(image_list)
    mask = read_image(mask_list, mask=True)
    return image, mask


def data_generator(image_list, mask_list):
    dataset = tf_data.Dataset.from_tensor_slices((image_list, mask_list))
    dataset = dataset.map(load_data, num_parallel_calls=tf_data.AUTOTUNE)
    dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
    return dataset


train_dataset = data_generator(train_images, train_masks)
val_dataset = data_generator(val_images, val_masks)

print("Train Dataset:", train_dataset)
print("Val Dataset:", val_dataset)
| """ | |
| ## Building the DeepLabV3+ model | |
| DeepLabv3+ extends DeepLabv3 by adding an encoder-decoder structure. The encoder module | |
| processes multiscale contextual information by applying dilated convolution at multiple | |
| scales, while the decoder module refines the segmentation results along object boundaries. | |
|  | |
| **Dilated convolution:** With dilated convolution, as we go deeper in the network, we can keep the | |
| stride constant but with larger field-of-view without increasing the number of parameters | |
| or the amount of computation. Besides, it enables larger output feature maps, which is | |
| useful for semantic segmentation. | |
| The reason for using **Dilated Spatial Pyramid Pooling** is that it was shown that as the | |
| sampling rate becomes larger, the number of valid filter weights (i.e., weights that | |
| are applied to the valid feature region, instead of padded zeros) becomes smaller. | |
| """ | |
def convolution_block(
    block_input,
    num_filters=256,
    kernel_size=3,
    dilation_rate=1,
    use_bias=False,
):
    x = layers.Conv2D(
        num_filters,
        kernel_size=kernel_size,
        dilation_rate=dilation_rate,
        padding="same",
        use_bias=use_bias,
        kernel_initializer=keras.initializers.HeNormal(),
    )(block_input)
    x = layers.BatchNormalization()(x)
    return ops.nn.relu(x)


def DilatedSpatialPyramidPooling(dspp_input):
    dims = dspp_input.shape
    x = layers.AveragePooling2D(pool_size=(dims[-3], dims[-2]))(dspp_input)
    x = convolution_block(x, kernel_size=1, use_bias=True)
    out_pool = layers.UpSampling2D(
        size=(dims[-3] // x.shape[1], dims[-2] // x.shape[2]),
        interpolation="bilinear",
    )(x)

    out_1 = convolution_block(dspp_input, kernel_size=1, dilation_rate=1)
    out_6 = convolution_block(dspp_input, kernel_size=3, dilation_rate=6)
    out_12 = convolution_block(dspp_input, kernel_size=3, dilation_rate=12)
    out_18 = convolution_block(dspp_input, kernel_size=3, dilation_rate=18)

    x = layers.Concatenate(axis=-1)([out_pool, out_1, out_6, out_12, out_18])
    output = convolution_block(x, kernel_size=1)
    return output
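"""
A quick way to sanity-check this module (a throwaway probe, not part of the final
model): apply it to a dummy feature map at the backbone's output resolution. With a
512x512 input, the `conv4_block6_2_relu` features used below are 32x32 with 256
channels, so DSPP should map `(32, 32, 256)` to `(32, 32, 256)`.
"""

dspp_probe = keras.Input(shape=(32, 32, 256))
print(DilatedSpatialPyramidPooling(dspp_probe).shape)  # (None, 32, 32, 256)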
| """ | |
| The encoder features are first bilinearly upsampled by a factor 4, and then | |
| concatenated with the corresponding low-level features from the network backbone that | |
| have the same spatial resolution. For this example, we | |
| use a ResNet50 pretrained on ImageNet as the backbone model, and we use | |
| the low-level features from the `conv4_block6_2_relu` block of the backbone. | |
| """ | |
def DeeplabV3Plus(image_size, num_classes):
    model_input = keras.Input(shape=(image_size, image_size, 3))
    preprocessed = keras.applications.resnet50.preprocess_input(model_input)
    resnet50 = keras.applications.ResNet50(
        weights="imagenet", include_top=False, input_tensor=preprocessed
    )
    x = resnet50.get_layer("conv4_block6_2_relu").output
    x = DilatedSpatialPyramidPooling(x)

    input_a = layers.UpSampling2D(
        size=(image_size // 4 // x.shape[1], image_size // 4 // x.shape[2]),
        interpolation="bilinear",
    )(x)
    input_b = resnet50.get_layer("conv2_block3_2_relu").output
    input_b = convolution_block(input_b, num_filters=48, kernel_size=1)

    x = layers.Concatenate(axis=-1)([input_a, input_b])
    x = convolution_block(x)
    x = convolution_block(x)
    x = layers.UpSampling2D(
        size=(image_size // x.shape[1], image_size // x.shape[2]),
        interpolation="bilinear",
    )(x)
    model_output = layers.Conv2D(num_classes, kernel_size=(1, 1), padding="same")(x)
    return keras.Model(inputs=model_input, outputs=model_output)


model = DeeplabV3Plus(image_size=IMAGE_SIZE, num_classes=NUM_CLASSES)
model.summary()
| """ | |
| ## Training | |
| We train the model using sparse categorical crossentropy as the loss function, and | |
| Adam as the optimizer. | |
| """ | |
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    loss=loss,
    metrics=["accuracy"],
)

history = model.fit(train_dataset, validation_data=val_dataset, epochs=25)

plt.plot(history.history["loss"])
plt.title("Training Loss")
plt.ylabel("loss")
plt.xlabel("epoch")
plt.show()

plt.plot(history.history["accuracy"])
plt.title("Training Accuracy")
plt.ylabel("accuracy")
plt.xlabel("epoch")
plt.show()

plt.plot(history.history["val_loss"])
plt.title("Validation Loss")
plt.ylabel("val_loss")
plt.xlabel("epoch")
plt.show()

plt.plot(history.history["val_accuracy"])
plt.title("Validation Accuracy")
plt.ylabel("val_accuracy")
plt.xlabel("epoch")
plt.show()
| """ | |
| ## Inference using Colormap Overlay | |
| The raw predictions from the model represent a one-hot encoded tensor of shape `(N, 512, 512, 20)` | |
| where each one of the 20 channels is a binary mask corresponding to a predicted label. | |
| In order to visualize the results, we plot them as RGB segmentation masks where each pixel | |
| is represented by a unique color corresponding to the particular label predicted. We can easily | |
| find the color corresponding to each label from the `human_colormap.mat` file provided as part | |
| of the dataset. We would also plot an overlay of the RGB segmentation mask on the input image as | |
| this further helps us to identify the different categories present in the image more intuitively. | |
| """ | |
# Loading the Colormap
colormap = loadmat(
    "./instance-level_human_parsing/instance-level_human_parsing/human_colormap.mat"
)["colormap"]
colormap = colormap * 100
colormap = colormap.astype(np.uint8)


def infer(model, image_tensor):
    predictions = model.predict(np.expand_dims((image_tensor), axis=0))
    predictions = np.squeeze(predictions)
    predictions = np.argmax(predictions, axis=2)
    return predictions
def decode_segmentation_masks(mask, colormap, n_classes):
    r = np.zeros_like(mask).astype(np.uint8)
    g = np.zeros_like(mask).astype(np.uint8)
    b = np.zeros_like(mask).astype(np.uint8)
    for l in range(0, n_classes):
        idx = mask == l
        r[idx] = colormap[l, 0]
        g[idx] = colormap[l, 1]
        b[idx] = colormap[l, 2]
    rgb = np.stack([r, g, b], axis=2)
    return rgb
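"""
Since `colormap` is just an `(N, 3)` array indexed by class id, the loop above can
equivalently be written as a single NumPy fancy-indexing lookup. A one-line
alternative sketch (this function name is ours, not part of the original example):
"""

def decode_segmentation_masks_vectorized(mask, colormap):
    # Each pixel's class id selects its RGB row directly.
    return colormap[mask.astype(np.int32)]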
def get_overlay(image, colored_mask):
    image = keras.utils.array_to_img(image)
    image = np.array(image).astype(np.uint8)
    overlay = cv2.addWeighted(image, 0.35, colored_mask, 0.65, 0)
    return overlay


def plot_samples_matplotlib(display_list, figsize=(5, 3)):
    _, axes = plt.subplots(nrows=1, ncols=len(display_list), figsize=figsize)
    for i in range(len(display_list)):
        if display_list[i].shape[-1] == 3:
            axes[i].imshow(keras.utils.array_to_img(display_list[i]))
        else:
            axes[i].imshow(display_list[i])
    plt.show()


def plot_predictions(images_list, colormap, model):
    for image_file in images_list:
        image_tensor = read_image(image_file)
        prediction_mask = infer(image_tensor=image_tensor, model=model)
        prediction_colormap = decode_segmentation_masks(prediction_mask, colormap, 20)
        overlay = get_overlay(image_tensor, prediction_colormap)
        plot_samples_matplotlib(
            [image_tensor, overlay, prediction_colormap], figsize=(18, 14)
        )
| """ | |
| ### Inference on Train Images | |
| """ | |
| plot_predictions(train_images[:4], colormap, model=model) | |
| """ | |
| ### Inference on Validation Images | |
| You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/deeplabv3p-resnet50) | |
| and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/Human-Part-Segmentation). | |
| """ | |
| plot_predictions(val_images[:4], colormap, model=model) | |