| import tensorflow as tf |
| import matplotlib.pyplot as plt |
|
|
|
|
class Augment(tf.keras.layers.Layer):
    """Randomly flip images and their label masks horizontally.

    Inputs and labels each get their own ``RandomFlip`` layer. Both layers are
    built from the same configuration, including the same ``seed``. This is
    presumably so that both draw matching flip decisions. That relies on the
    two layers' random generators advancing in lockstep (same number of calls),
    so verify it against the installed Keras version.
    """

    def __init__(self, seed=42):
        super().__init__()
        # One shared configuration; the common seed is what ties the flips together.
        flip_config = dict(mode="horizontal", seed=seed)
        self.augment_inputs = tf.keras.layers.RandomFlip(**flip_config)
        self.augment_labels = tf.keras.layers.RandomFlip(**flip_config)

    def call(self, inputs, labels):
        """Return ``(inputs, labels)``, each passed through its own flip layer."""
        return self.augment_inputs(inputs), self.augment_labels(labels)
| |
def load_image(datapoint):
    """Resize one dataset element to 128x128 and normalize it.

    Args:
        datapoint: mapping with ``'image'`` and ``'segmentation_mask'``
            entries. These are presumably (H, W, C) image tensors; confirm
            against the dataset in use.

    Returns:
        The ``(image, mask)`` tuple produced by ``normalize_image``.
    """
    target_size = (128, 128)
    # The image uses tf.image.resize's default (bilinear) method.
    resized_image = tf.image.resize(datapoint['image'], target_size)
    # Nearest-neighbour keeps mask values as the original discrete labels
    # instead of interpolating between neighbouring classes.
    resized_mask = tf.image.resize(
        datapoint['segmentation_mask'],
        target_size,
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
    )
    return normalize_image(resized_image, resized_mask)
|
|
def normalize_image(input_image, input_mask):
    """Scale the image by 1/255 as float32 and shift the mask values down by one.

    Args:
        input_image: image tensor/array. It is cast to float32 and divided by
            255, which maps 0-255 pixel values into [0, 1].
        input_mask: mask tensor/array. Presumably it holds 1-based class ids
            that become 0-based after the shift; confirm against the dataset.

    Returns:
        Tuple ``(image, mask)``.
    """
    scaled_image = tf.cast(input_image, tf.float32) / 255.0
    # Augmented assignment: a tf.Tensor has no in-place subtraction, so this
    # rebinds to a new tensor, but a NumPy array would be modified in place.
    input_mask -= 1
    return scaled_image, input_mask
|
|
def create_mask(pred_mask):
    """Turn per-pixel class scores into a class-index mask for the first example.

    Args:
        pred_mask: model output whose last axis holds per-class scores.
            Presumably batched as (batch, H, W, classes); confirm with the model.

    Returns:
        Element 0 of the arg-max class indices, carrying a trailing axis of size 1.
    """
    class_ids = tf.math.argmax(pred_mask, axis=-1)
    # Restore a channel axis before dropping the batch axis.
    return tf.expand_dims(class_ids, axis=-1)[0]
|
|
def U_net_model(output_channels: int, down_stack, up_stack):
    """Assemble a U-Net style Keras model from an encoder and upsampling blocks.

    Args:
        output_channels: number of filters in the final transposed convolution.
        down_stack: callable encoder. Its result is indexed and sliced, so it
            is expected to return a sequence of feature maps, presumably ordered
            shallowest to deepest; confirm against the caller.
        up_stack: iterable of upsampling layers, applied deepest-first.

    Returns:
        A ``tf.keras.Model`` from a (128, 128, 3) input to the output of a
        final stride-2 ``Conv2DTranspose``, which has no activation argument
        and so uses Keras' default.
    """
    model_input = tf.keras.layers.Input(shape=[128, 128, 3])
    features = down_stack(model_input)
    x = features[-1]
    # Walk the remaining encoder outputs from deepest to shallowest, pairing
    # each with one upsampler; zip stops at the shorter of the two.
    for upsample, skip in zip(up_stack, reversed(features[:-1])):
        upsampled = upsample(x)
        x = tf.keras.layers.Concatenate()([upsampled, skip])
    head = tf.keras.layers.Conv2DTranspose(
        filters=output_channels, kernel_size=3, strides=2, padding='same'
    )
    return tf.keras.Model(inputs=model_input, outputs=head(x))
|
|
def display(display_list, titles=None):
    """Plot images side by side in a single row of one matplotlib figure.

    Args:
        display_list: sequence of image-like arrays/tensors accepted by
            ``tf.keras.utils.array_to_img``.
        titles: optional per-panel titles. Defaults to
            ``['Input Image', 'Predicted Mask']``. Panels beyond the number
            of titles are drawn without a title instead of raising IndexError
            (e.g. when an input, a true mask and a prediction are shown).
    """
    if titles is None:
        titles = ['Input Image', 'Predicted Mask']
    n_panels = len(display_list)
    plt.figure(figsize=(15, 15))
    for index, item in enumerate(display_list):
        plt.subplot(1, n_panels, index + 1)
        if index < len(titles):
            plt.title(titles[index])
        plt.imshow(tf.keras.utils.array_to_img(item))
        plt.axis('off')
    plt.show()
|
|
def show_predictions(image_url, model):
    """Download an image, run the model on it and plot the input with its mask.

    The image is resized to 128x128 and divided by 255 as float32, matching
    the preprocessing in ``load_image``/``normalize_image``. It is then fed to
    the model as a batch of one.

    Args:
        image_url: URL of the image to download.
        model: Keras model that accepts a (1, 128, 128, 3) float32 batch.
    """
    import hashlib

    # get_file caches downloads under ``fname``. Without it, the URL's
    # basename is used, so two different URLs ending in e.g. "image.jpg"
    # would silently reuse one cached file. Hashing the full URL avoids that
    # and also covers URLs whose path has no basename.
    cache_name = hashlib.sha256(image_url.encode('utf-8')).hexdigest()
    image_path = tf.keras.utils.get_file(fname=cache_name, origin=image_url)
    pil_image = tf.keras.utils.load_img(image_path)
    pixels = tf.keras.utils.img_to_array(pil_image)
    pixels = tf.image.resize(pixels, (128, 128))
    batch = tf.expand_dims(tf.cast(pixels, tf.float32) / 255.0, axis=0)
    pred_mask = model.predict(batch)
    display([batch[0], create_mask(pred_mask)])