{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "deepmac_colab.ipynb", "provenance": [], "collapsed_sections": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "P-esW81yhfCN" }, "source": [ "# Novel class segmentation demo with Deep-MAC\n", "\n", "Welcome to the Novel class segmentation (with Deep-MAC) demo --- this colab loads a Deep-MAC model and tests it interactively with user-specified boxes. Deep-MAC was only trained to detect and segment COCO classes, but generalizes well when segmenting within user-specified boxes of unseen classes.\n", "\n", "Estimated time to run through this colab (with GPU): 10-15 minutes.\n", "Note that the bulk of this time is spent installing TensorFlow, downloading\n", "the checkpoint, and then running inference for the first time. Once you've done\n", "all that, running on new images is very fast." ] }, { "cell_type": "markdown", "metadata": { "id": "Kq1eGNssiW31" }, "source": [ "# Prerequisites\n", "\n", "Please change the runtime type to GPU before running this colab." ] }, { "cell_type": "markdown", "metadata": { "id": "UT7N0HJhiRKr" }, "source": [ "# Installation and Imports\n", "\n", "This takes 3-4 minutes."
] },
{
 "cell_type": "code",
 "metadata": { "id": "nNdls0Pe0UPK" },
 "source": [
  "import os\n",
  "import pathlib\n",
  "\n",
  "# Clone the tensorflow models repository if it doesn't already exist.\n",
  "# When the kernel is already somewhere inside a `models` checkout, walk back\n",
  "# up until we are outside of it instead of cloning a nested copy.\n",
  "if \"models\" in pathlib.Path.cwd().parts:\n",
  "  while \"models\" in pathlib.Path.cwd().parts:\n",
  "    os.chdir('..')\n",
  "elif not pathlib.Path('models').exists():\n",
  "  !git clone --depth 1 https://github.com/tensorflow/models\n"
 ],
 "execution_count": null,
 "outputs": []
},
{
 "cell_type": "code",
 "metadata": { "id": "WwjV9clX0n7S" },
 "source": [
  "%%bash\n",
  "# Install the Object Detection API\n",
  "# (`%%bash` has to be the very first line of the cell for IPython to treat it\n",
  "# as a cell magic, so this comment lives below it.)\n",
  "cd models/research/\n",
  "protoc object_detection/protos/*.proto --python_out=.\n",
  "cp object_detection/packages/tf2/setup.py .\n",
  "\n",
  "# The latest tf-models-official installs tensorflow 2.9 which has an\n",
  "# incompatible CuDNN dependency. Here we restrict ourselves to versions 2.8 and\n",
  "# below.\n",
  "python -m pip install \"tf-models-official<=2.8\" ."
 ],
 "execution_count": null,
 "outputs": []
},
{
 "cell_type": "code",
 "metadata": { "id": "sfrrno2L0sRR" },
 "source": [
  "import glob\n",
  "import io\n",
  "import logging\n",
  "import os\n",
  "import random\n",
  "import warnings\n",
  "\n",
  "import imageio\n",
  "from IPython.display import display, Javascript\n",
  "from IPython.display import Image as IPyImage\n",
  "import matplotlib\n",
  "from matplotlib import patches\n",
  "import matplotlib.pyplot as plt\n",
  "import numpy as np\n",
  "from object_detection.utils import colab_utils\n",
  "from object_detection.utils import ops\n",
  "from object_detection.utils import visualization_utils as viz_utils\n",
  "from PIL import Image, ImageDraw, ImageFont\n",
  "import scipy.misc\n",
  "from six import BytesIO\n",
  "from skimage import color\n",
  "from skimage import transform\n",
  "from skimage import util\n",
  "from skimage.color import rgb_colors\n",
  "import tensorflow as tf\n",
  "\n",
  "%matplotlib inline\n",
  "\n",
  "# Palette used to draw boxes and masks: a few hand-picked colors followed by\n",
  "# every color in skimage's color dictionary, shuffled once per kernel session.\n",
  "COLORS = ([rgb_colors.cyan, rgb_colors.orange, rgb_colors.pink,\n",
  "           rgb_colors.purple, rgb_colors.limegreen, rgb_colors.crimson] +\n",
  "          list(color.color_dict.values()))\n",
  "random.shuffle(COLORS)\n",
  "\n",
  "logging.disable(logging.WARNING)\n",
  "\n",
  "\n",
  "def read_image(path):\n",
  "  \"\"\"Read the image at `path` and return it as an RGB uint8 numpy array.\n",
  "\n",
  "  Args:\n",
  "    path: Location of the image file; it is opened with `tf.io.gfile.GFile`.\n",
  "\n",
  "  Returns:\n",
  "    A [height, width, 3] uint8 numpy array.\n",
  "  \"\"\"\n",
  "  with tf.io.gfile.GFile(path, 'rb') as f:\n",
  "    # Force three channels: the helpers below unpack a 3-D shape and blend RGB\n",
  "    # colors, so grayscale, palette or RGBA files would otherwise break them.\n",
  "    img = Image.open(f).convert('RGB')\n",
  "    return np.array(img, dtype=np.uint8)\n",
  "\n",
  "\n",
  "def resize_for_display(image, max_height=600):\n",
  "  \"\"\"Resize `image` to `max_height` pixels tall, keeping its aspect ratio.\n",
  "\n",
  "  Args:\n",
  "    image: A [height, width, channels] numpy array.\n",
  "    max_height: Height, in pixels, of the returned image.\n",
  "\n",
  "  Returns:\n",
  "    The resized image as a uint8 numpy array.\n",
  "  \"\"\"\n",
  "  height, width, _ = image.shape\n",
  "  new_width = int(width * max_height / height)\n",
  "  with warnings.catch_warnings():\n",
  "    warnings.simplefilter(\"ignore\", UserWarning)\n",
  "    # The target shape is (rows, cols) = (max_height, new_width). Passing the\n",
  "    # original `height` here would rescale only the width and distort the image.\n",
  "    return util.img_as_ubyte(\n",
  "        transform.resize(image, (max_height, new_width)))\n",
  "\n",
  "\n",
  "def get_mask_prediction_function(model):\n",
  "  \"\"\"Get single image mask prediction function using a model.\n",
  "\n",
  "  Args:\n",
  "    model: Callable invoked as `model(images, boxes)` on a batch of size one;\n",
  "      the dict it returns must contain the key 'detection_masks'.\n",
  "\n",
  "  Returns:\n",
  "    A `tf.function` taking `(image, boxes)` for a single image and returning\n",
  "    the output of `ops.reframe_box_masks_to_image_masks` for that image.\n",
  "  \"\"\"\n",
  "\n",
  "  @tf.function\n",
  "  def predict_masks(image, boxes):\n",
  "    height, width, _ = image.shape.as_list()\n",
  "    # The model is called on batches, so add a leading batch axis of size one.\n",
  "    batch = image[tf.newaxis]\n",
  "    boxes = boxes[tf.newaxis]\n",
  "\n",
  "    detections = model(batch, boxes)\n",
  "    masks = detections['detection_masks']\n",
  "\n",
  "    # Drop the batch axis again before reframing the masks to image size.\n",
  "    return ops.reframe_box_masks_to_image_masks(masks[0], boxes[0],\n",
  "                                                height, width)\n",
  "\n",
  "  return predict_masks\n",
  "\n",
  "\n",
  "def plot_image_annotations(image, boxes, masks, darken_image=0.5):\n",
  "  \"\"\"Plot `image` with every box outline and its mask overlaid in color.\n",
  "\n",
  "  Args:\n",
  "    image: A [height, width, 3] uint8 numpy array.\n",
  "    boxes: Iterable of [ymin, xmin, ymax, xmax] boxes; the coordinates are\n",
  "      multiplied by the image height / width before drawing.\n",
  "    masks: Iterable of [height, width] arrays, one per box. Values above 0.5\n",
  "      are drawn as part of the mask.\n",
  "    darken_image: Factor the image is multiplied by before it is shown.\n",
  "\n",
  "  Returns:\n",
  "    The matplotlib axes the annotations were drawn on.\n",
  "  \"\"\"\n",
  "  fig, ax = plt.subplots(figsize=(16, 12))\n",
  "  ax.set_axis_off()\n",
  "  image = (image * darken_image).astype(np.uint8)\n",
  "  ax.imshow(image)\n",
  "\n",
  "  height, width, _ = image.shape\n",
  "  num_colors = len(COLORS)\n",
  "\n",
  "  for index, (box, mask) in enumerate(zip(boxes, masks)):\n",
  "    ymin, xmin, ymax, xmax = box\n",
  "    ymin, ymax = ymin * height, ymax * height\n",
  "    xmin, xmax = xmin * width, xmax * width\n",
  "\n",
  "    # Wrap around the palette when there are more boxes than colors. The local\n",
  "    # is named `box_color` so it does not shadow the imported `color` module.\n",
  "    box_color = np.array(COLORS[index % num_colors])\n",
  "    rect = patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,\n",
  "                             linewidth=2.5, edgecolor=box_color,\n",
  "                             facecolor='none')\n",
  "    ax.add_patch(rect)\n",
  "\n",
  "    # Build an RGBA overlay: a constant color image with the binarized mask\n",
  "    # as its alpha channel.\n",
  "    mask = (mask > 0.5).astype(np.float32)\n",
  "    color_image = np.ones_like(image) * box_color[np.newaxis, np.newaxis, :]\n",
  "    color_and_mask = np.concatenate(\n",
  "        [color_image, mask[:, :, np.newaxis]], axis=2)\n",
  "\n",
  "    ax.imshow(color_and_mask, alpha=0.5)\n",
  "\n",
  "  return ax"
 ],
 "execution_count": null,
 "outputs": []
},
{
 "cell_type": "markdown",
 "metadata": { "id": "ry9yq8zsi0Gg" },
 "source": [
  "# Load Deep-MAC Model\n",
  "\n",
  "This can take up to 5 minutes."
 ]
},
{
 "cell_type": "code",
 "metadata": { "id": "PZ-wnbYu05K8" },
 "source": [
  "print('Downloading and untarring model')\n",
  "# `-nc` skips the download when the tarball is already present and `-C`\n",
  "# extracts straight into test_data/, so this cell can be re-run without errors.\n",
  "!wget -nc http://download.tensorflow.org/models/object_detection/tf2/20210329/deepmac_1024x1024_coco17.tar.gz\n",
  "!cp deepmac_1024x1024_coco17.tar.gz models/research/object_detection/test_data/\n",
  "!tar -xzf models/research/object_detection/test_data/deepmac_1024x1024_coco17.tar.gz -C models/research/object_detection/test_data/\n",
  "model_path = 'models/research/object_detection/test_data/deepmac_1024x1024_coco17/saved_model'\n",
  "\n",
  "print('Loading SavedModel')\n",
  "model = tf.keras.models.load_model(model_path)\n",
  "prediction_function = get_mask_prediction_function(model)"
 ],
 "execution_count": null,
 "outputs": []
},
{
 "cell_type": "markdown",
 "metadata": { "id": "ilXkYOB_NUSc" },
 "source": [
  "# Load image"
 ]
},
{
 "cell_type": "code",
 "metadata": { "id": "txj4UkoDNaOq" },
 "source": [
  "image_path = 'models/research/object_detection/test_images/image3.jpg'\n",
  "image = read_image(image_path)"
 ],
 "execution_count": null,
 "outputs": []
},
{
 "cell_type": "markdown",
 "metadata": { "id": "zyhudgYUjcvE" },
 "source": [
  "# Annotate an image with one or more boxes\n",
  "\n",
  "This model is trained on COCO categories, but we encourage you to try segmenting\n",
  "anything you want!\n",
  "\n",
  "Don't forget to hit **submit** when done."
] },
{
 "cell_type": "code",
 "metadata": { "id": "aZvY4At0074j" },
 "source": [
  "display_image = resize_for_display(image)\n",
  "\n",
  "# `colab_utils.annotate` receives this list; the cell below reads the boxes\n",
  "# for the (single) image from `boxes_list[0]`.\n",
  "boxes_list = []\n",
  "colab_utils.annotate([display_image], boxes_list)"
 ],
 "execution_count": null,
 "outputs": []
},
{
 "cell_type": "markdown",
 "metadata": { "id": "gUUG7NPBJMoa" },
 "source": [
  "# In case you didn't want to label...\n",
  "\n",
  "Run this cell only if you didn't annotate anything above and would prefer to just use our preannotated boxes. Don't forget to uncomment.\n",
  "\n"
 ]
},
{
 "cell_type": "code",
 "metadata": { "id": "lupqTv1HJK5K" },
 "source": [
  "# boxes_list = [np.array([[0.000, 0.160, 0.362, 0.812],\n",
  "#                         [0.340, 0.286, 0.472, 0.619],\n",
  "#                         [0.437, 0.008, 0.650, 0.263],\n",
  "#                         [0.382, 0.003, 0.538, 0.594],\n",
  "#                         [0.518, 0.444, 0.625, 0.554]], dtype=np.float32)]"
 ],
 "execution_count": null,
 "outputs": []
},
{
 "cell_type": "markdown",
 "metadata": { "id": "Ak1WO93NjvN-" },
 "source": [
  "# Visualize mask predictions"
 ]
},
{
 "cell_type": "code",
 "metadata": { "id": "vdzuKnpj1A3L" },
 "source": [
  "\n",
  "%matplotlib inline\n",
  "\n",
  "# `boxes_list` starts out empty and holds boxes only after they were submitted\n",
  "# in the annotation cell or the preannotated cell was uncommented and run.\n",
  "# Fail with an actionable message instead of an opaque IndexError.\n",
  "# NOTE(review): the `None` check covers an image stored without any boxes --\n",
  "# confirm against `colab_utils.annotate`.\n",
  "if not boxes_list or boxes_list[0] is None:\n",
  "  raise ValueError(\n",
  "      'No boxes to segment: draw at least one box and hit submit above, '\n",
  "      'or uncomment and run the preannotated boxes cell.')\n",
  "\n",
  "boxes = boxes_list[0]\n",
  "masks = prediction_function(tf.convert_to_tensor(image),\n",
  "                            tf.convert_to_tensor(boxes, dtype=tf.float32))\n",
  "plot_image_annotations(image, boxes, masks.numpy())\n",
  "plt.show()"
 ],
 "execution_count": null,
 "outputs": []
}
] }