ML-Starter / knowledge_base /vision /oxford_pets_image_segmentation.py
emreatilgan's picture
feat: Initialize mcp_server with embedding and loader modules
9ce984a
"""
Title: Image segmentation with a U-Net-like architecture
Author: [fchollet](https://twitter.com/fchollet)
Date created: 2019/03/20
Last modified: 2020/04/20
Description: Image segmentation model trained from scratch on the Oxford Pets dataset.
Accelerator: GPU
"""
"""
## Download the data
"""
"""shell
!wget https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz
!wget https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz
curl -O https://thor.robots.ox.ac.uk/datasets/pets/images.tar.gz
curl -O https://thor.robots.ox.ac.uk/datasets/pets/annotations.tar.gz
tar -xf images.tar.gz
tar -xf annotations.tar.gz
"""
"""
## Prepare paths of input images and target segmentation masks
"""
import os
# Dataset locations and the hyperparameters shared by the rest of the script.
input_dir = "images/"
target_dir = "annotations/trimaps/"
img_size = (160, 160)
num_classes = 3
batch_size = 32

# Photos: every ".jpg" entry of the image directory, in sorted order.
input_img_paths = sorted(
    os.path.join(input_dir, entry)
    for entry in os.listdir(input_dir)
    if entry.endswith(".jpg")
)

# Masks: every ".png" entry of the trimap directory, skipping dot-files.
# Both lists are sorted so that index k of one is meant to pair with index k
# of the other (this relies on the two directories using matching basenames).
target_img_paths = sorted(
    os.path.join(target_dir, entry)
    for entry in os.listdir(target_dir)
    if entry.endswith(".png") and not entry.startswith(".")
)

print("Number of samples:", len(input_img_paths))

# Print the first ten pairs as a quick visual check of the pairing.
for input_path, target_path in zip(input_img_paths[:10], target_img_paths[:10]):
    print(input_path, "|", target_path)
"""
## What does one input image and corresponding segmentation mask look like?
"""
from IPython.display import Image, display
from keras.utils import load_img
from PIL import ImageOps
# Display the input image at index 9 of the sorted path list
display(Image(filename=input_img_paths[9]))
# Display an auto-contrast version of the target mask at the same index
# (per-pixel categories); auto-contrast stretches the small label values so
# the classes are distinguishable when rendered.
img = ImageOps.autocontrast(load_img(target_img_paths[9]))
display(img)
"""
## Prepare dataset to load & vectorize batches of data
"""
import keras
import numpy as np
from tensorflow import data as tf_data
from tensorflow import image as tf_image
from tensorflow import io as tf_io
def get_dataset(
    batch_size,
    img_size,
    input_img_paths,
    target_img_paths,
    max_dataset_len=None,
):
    """Build a batched `tf.data.Dataset` of (image, mask) pairs.

    Args:
        batch_size: Number of pairs per batch.
        img_size: Target `(height, width)` every image and mask is resized to.
        input_img_paths: Paths of the input images.
        target_img_paths: Paths of the segmentation masks, index-aligned with
            `input_img_paths`.
        max_dataset_len: When truthy, only the first `max_dataset_len` pairs
            are used (handy for quick debugging runs).

    Returns:
        A dataset whose elements are `(images, masks)` batches: images are
        float32 with 3 channels, masks are uint8 with 1 channel and labels
        shifted down by one.
    """

    def _decode_pair(image_path, mask_path):
        # Image: read, decode to 3 channels, resize, cast to float32.
        image = tf_io.decode_png(tf_io.read_file(image_path), channels=3)
        image = tf_image.resize(image, img_size)
        image = tf_image.convert_image_dtype(image, "float32")

        # Mask: nearest-neighbour resize so label values are never blended.
        mask = tf_io.decode_png(tf_io.read_file(mask_path), channels=1)
        mask = tf_image.resize(mask, img_size, method="nearest")
        mask = tf_image.convert_image_dtype(mask, "uint8")

        # Ground truth labels are 1, 2, 3; shift them to 0, 1, 2.
        return image, mask - 1

    # Optionally truncate both path lists to speed up debugging.
    image_paths, mask_paths = input_img_paths, target_img_paths
    if max_dataset_len:
        image_paths = image_paths[:max_dataset_len]
        mask_paths = mask_paths[:max_dataset_len]

    return (
        tf_data.Dataset.from_tensor_slices((image_paths, mask_paths))
        .map(_decode_pair, num_parallel_calls=tf_data.AUTOTUNE)
        .batch(batch_size)
    )
"""
## Prepare U-Net Xception-style model
"""
from keras import layers
def get_model(img_size, num_classes):
    """Build the U-Net-like, Xception-style segmentation network.

    The encoder halves the spatial resolution four times (one strided
    convolution plus three max-pooled residual blocks); the decoder doubles
    it four times. The output therefore matches the input resolution when
    each side of `img_size` is divisible by 16.

    Args:
        img_size: `(height, width)` tuple of the input images.
        num_classes: Number of per-pixel classes predicted by the final layer.

    Returns:
        A `keras.Model` mapping a 3-channel image to a per-pixel softmax over
        `num_classes`.
    """
    inputs = keras.Input(shape=img_size + (3,))

    # ---- Encoder: downsample the inputs ----

    # Stem
    features = layers.Conv2D(32, 3, strides=2, padding="same")(inputs)
    features = layers.BatchNormalization()(features)
    features = layers.Activation("relu")(features)

    # Tensor fed to the next block's residual shortcut.
    skip = features

    # Three identical residual blocks, differing only in feature depth.
    for depth in (64, 128, 256):
        features = layers.Activation("relu")(features)
        features = layers.SeparableConv2D(depth, 3, padding="same")(features)
        features = layers.BatchNormalization()(features)

        features = layers.Activation("relu")(features)
        features = layers.SeparableConv2D(depth, 3, padding="same")(features)
        features = layers.BatchNormalization()(features)

        features = layers.MaxPooling2D(3, strides=2, padding="same")(features)

        # Strided 1x1 convolution brings the shortcut to the same shape.
        shortcut = layers.Conv2D(depth, 1, strides=2, padding="same")(skip)
        features = layers.add([features, shortcut])
        skip = features

    # ---- Decoder: upsample back towards the input resolution ----

    for depth in (256, 128, 64, 32):
        features = layers.Activation("relu")(features)
        features = layers.Conv2DTranspose(depth, 3, padding="same")(features)
        features = layers.BatchNormalization()(features)

        features = layers.Activation("relu")(features)
        features = layers.Conv2DTranspose(depth, 3, padding="same")(features)
        features = layers.BatchNormalization()(features)

        features = layers.UpSampling2D(2)(features)

        # Upsample the shortcut, then match its channel count with a 1x1 conv.
        shortcut = layers.UpSampling2D(2)(skip)
        shortcut = layers.Conv2D(depth, 1, padding="same")(shortcut)
        features = layers.add([features, shortcut])
        skip = features

    # Per-pixel classification head.
    outputs = layers.Conv2D(
        num_classes, 3, activation="softmax", padding="same"
    )(features)

    return keras.Model(inputs, outputs)
# Build the model for the configured image size and class count, then print
# its layer-by-layer summary.
model = get_model(img_size, num_classes)
model.summary()
"""
## Set aside a validation split
"""
import random
# Number of (image, mask) pairs held out for validation.
val_samples = 1000

# Shuffle both lists in place with identically seeded generators so that the
# image/mask pairing by index is preserved (this relies on both lists having
# the same length).
random.Random(1337).shuffle(input_img_paths)
random.Random(1337).shuffle(target_img_paths)

# Everything except the last `val_samples` pairs trains; the tail validates.
train_input_img_paths, val_input_img_paths = (
    input_img_paths[:-val_samples],
    input_img_paths[-val_samples:],
)
train_target_img_paths, val_target_img_paths = (
    target_img_paths[:-val_samples],
    target_img_paths[-val_samples:],
)

# Instantiate a dataset for each split.
# `max_dataset_len` caps the training files for faster epochs; remove the
# argument when running with the full dataset.
train_dataset = get_dataset(
    batch_size=batch_size,
    img_size=img_size,
    input_img_paths=train_input_img_paths,
    target_img_paths=train_target_img_paths,
    max_dataset_len=1000,
)
valid_dataset = get_dataset(
    batch_size=batch_size,
    img_size=img_size,
    input_img_paths=val_input_img_paths,
    target_img_paths=val_target_img_paths,
)
"""
## Train the model
"""
# Compile for training. The "sparse" variant of categorical cross-entropy is
# used because the targets are integer class indices, not one-hot vectors.
model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer=keras.optimizers.Adam(1e-4),
)

# Checkpoint to disk; with `save_best_only=True` the file is only rewritten
# when the monitored quantity improves.
callbacks = [
    keras.callbacks.ModelCheckpoint(
        filepath="oxford_segmentation.keras",
        save_best_only=True,
    )
]

# Fit for a fixed number of epochs, validating at the end of each one.
epochs = 50
model.fit(
    train_dataset,
    validation_data=valid_dataset,
    epochs=epochs,
    callbacks=callbacks,
    verbose=2,
)
"""
## Visualize predictions
"""
# Rebuild the validation dataset and run the model over every image in it.
val_dataset = get_dataset(
    batch_size=batch_size,
    img_size=img_size,
    input_img_paths=val_input_img_paths,
    target_img_paths=val_target_img_paths,
)
val_preds = model.predict(val_dataset)
def display_mask(i):
    """Render the predicted mask for validation sample `i`.

    Takes the most likely class per pixel from the module-level `val_preds`,
    converts it to an image and displays an auto-contrasted version.
    """
    # Most likely class per pixel, with a trailing channel axis restored so
    # the array can be converted to an image.
    class_map = np.argmax(val_preds[i], axis=-1)[..., np.newaxis]
    display(ImageOps.autocontrast(keras.utils.array_to_img(class_map)))
# Display results for the validation sample at index 10
i = 10
# Display input image
display(Image(filename=val_input_img_paths[i]))
# Display ground-truth target mask (auto-contrasted so the classes are visible)
img = ImageOps.autocontrast(load_img(val_target_img_paths[i]))
display(img)
# Display mask predicted by our model
display_mask(i) # Note that the model only sees inputs at 160x160 (`img_size`).