| | import tensorflow as tf |
| | import matplotlib.pyplot as plt |
| |
|
| |
|
class Augment(tf.keras.layers.Layer):
    """Randomly flip an input batch and its label batch horizontally.

    Two separate ``RandomFlip`` layers are created with the same ``seed``.
    The intent is that the image and its label receive the same flip
    decision. NOTE(review): that pairing depends on the two layers' random
    generators staying in lockstep. Confirm this on the TF/Keras version in
    use, or flip a channel-concatenated image+mask with a single layer.
    """

    def __init__(self, seed=42):
        super().__init__()
        # Shared configuration, so both flips are built identically.
        flip_config = dict(mode="horizontal", seed=seed)
        self.augment_inputs = tf.keras.layers.RandomFlip(**flip_config)
        self.augment_labels = tf.keras.layers.RandomFlip(**flip_config)

    def call(self, inputs, labels):
        """Return ``(inputs, labels)``, each passed through its own flip layer."""
        # The tuple elements are evaluated left to right, so the inputs are
        # flipped before the labels, matching the original call order.
        return self.augment_inputs(inputs), self.augment_labels(labels)
| | |
def load_image(datapoint):
    """Resize an image/mask pair to 128x128 and normalize it.

    Args:
        datapoint: Mapping with ``'image'`` and ``'segmentation_mask'``
            entries. Presumably a dataset example dict; confirm against
            the caller.

    Returns:
        The ``(image, mask)`` tuple produced by ``normalize_image`` on the
        resized pair.
    """
    target_size = (128, 128)
    resized_image = tf.image.resize(datapoint['image'], target_size)
    # Nearest-neighbour keeps mask values as existing label ids. A bilinear
    # resize would blend neighbouring labels into meaningless in-between values.
    resized_mask = tf.image.resize(
        datapoint['segmentation_mask'],
        target_size,
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
    )
    return normalize_image(resized_image, resized_mask)
| |
|
def normalize_image(input_image, input_mask):
    """Scale image pixels to [0, 1] and shift mask labels down by one.

    Args:
        input_image: Image tensor with pixel values presumably in [0, 255].
        input_mask: Label mask. Presumably its labels start at 1, so the
            shift makes them 0-based; confirm against the dataset.

    Returns:
        ``(float32 image / 255.0, mask - 1)``.
    """
    float_image = tf.cast(input_image, tf.float32)
    scaled_image = float_image / 255.0
    # NOTE(review): for a tf.Tensor, ``-=`` builds a new tensor. For a NumPy
    # array it would modify the caller's array in place.
    input_mask -= 1
    return scaled_image, input_mask
| |
|
def create_mask(pred_mask):
    """Turn per-class model scores into a label map for the first batch item.

    Args:
        pred_mask: Model output with class scores on the last axis and the
            batch on axis 0. Presumably shaped (batch, H, W, num_classes);
            confirm.

    Returns:
        The arg-max class index per position for batch item 0, with a
        trailing size-1 channel axis (e.g. (H, W, 1) for a 4-D input).
    """
    class_ids = tf.argmax(pred_mask, axis=-1)
    with_channel = tf.expand_dims(class_ids, axis=-1)
    first_item = with_channel[0]
    return first_item
| |
|
def U_net_model(output_channels: int, down_stack, up_stack):
    """Assemble a U-Net from an encoder and a list of upsampling blocks.

    Args:
        output_channels: Number of channels in the final layer (one per
            class for segmentation).
        down_stack: Callable (presumably a Keras model) that maps the
            128x128x3 input to a sequence of feature maps. The last one is
            used as the bottleneck; the others become skip connections.
        up_stack: Upsampling blocks. They are paired with the skip features
            from deepest to shallowest.

    Returns:
        A ``tf.keras.Model`` whose output comes from a stride-2
        ``Conv2DTranspose`` with no activation argument, i.e. raw per-channel
        scores.
    """
    inputs = tf.keras.layers.Input(shape=[128, 128, 3])
    *skip_features, x = down_stack(inputs)

    # Walk the skips from deepest to shallowest, upsampling then concatenating.
    for upsample, skip in zip(up_stack, reversed(skip_features)):
        upsampled = upsample(x)
        x = tf.keras.layers.Concatenate()([upsampled, skip])

    # Stride 2 with 'same' padding doubles the spatial size for the final map.
    head = tf.keras.layers.Conv2DTranspose(
        filters=output_channels,
        kernel_size=3,
        strides=2,
        padding='same',
    )
    return tf.keras.Model(inputs=inputs, outputs=head(x))
| |
|
def display(display_list, titles=None):
    """Plot images side by side in a single row.

    Args:
        display_list: Images to show, left to right. Each is converted with
            ``tf.keras.utils.array_to_img``, so presumably (H, W, C)
            arrays/tensors.
        titles: Optional per-panel titles. Defaults to
            ``['Input Image', 'Predicted Mask']``. Panels past the end of
            ``titles`` are drawn without a title instead of raising
            IndexError (the previous hard-coded 2-entry list crashed on a
            third panel).
    """
    if titles is None:
        titles = ['Input Image', 'Predicted Mask']
    n_panels = len(display_list)
    plt.figure(figsize=(15, 15))
    for i, img in enumerate(display_list):
        plt.subplot(1, n_panels, i + 1)
        if i < len(titles):
            plt.title(titles[i])
        plt.imshow(tf.keras.utils.array_to_img(img))
        plt.axis('off')
    plt.show()
| |
|
def show_predictions(image_url, model):
    """Download an image, run the model on it, and plot the predicted mask.

    The image is resized to 128x128 and scaled to [0, 1] before prediction.
    That matches the image path of ``load_image``/``normalize_image``.

    Args:
        image_url: URL of the image to fetch.
        model: Model exposing ``predict`` over a batch of 128x128 RGB images.
    """
    # NOTE(review): get_file caches downloads by file name. A different URL
    # with the same base name may reuse a previously cached file.
    local_path = tf.keras.utils.get_file(origin=image_url)
    pixels = tf.keras.utils.img_to_array(tf.keras.utils.load_img(local_path))
    pixels = tf.cast(tf.image.resize(pixels, (128, 128)), tf.float32) / 255.0

    batch = tf.expand_dims(pixels, axis=0)
    prediction = model.predict(batch)
    display([batch[0], create_mask(prediction)])