{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0b7c1e2a",
   "metadata": {},
   "source": [
    "# TurtlesOfSMSRC mask visualization\n",
    "\n",
    "Overlays the head / flipper / turtle masks from `masks.csv` on each TurtlesOfSMSRC image, draws every annotation's bounding box with its `label_side`, and saves one figure per image into `root_figures`. The last cell lists images in which some `label_side` value occurs more than once and shows the overall label counts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56e96915",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.patches as patches\n",
    "\n",
    "from wildlife_datasets.datasets import TurtlesOfSMSRC\n",
    "from wildlife_datasets.datasets.utils import parse_bbox_mask\n",
    "from turtle_detector import assign_flippers, initialize_sam3, mask_to_rle, rle_to_mask, compute_iou, mask_to_bbox"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c8a0449",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset location; set TURTLES_SMSRC_ROOT to override the default path.\n",
    "root = os.environ.get('TURTLES_SMSRC_ROOT', '/data/wildlife_datasets/TurtlesOfSMSRC')\n",
    "root_figures = 'figures'\n",
    "dataset = TurtlesOfSMSRC(root)\n",
    "masks = pd.read_csv('masks.csv')\n",
    "# The CSV stores masks as text; parse them so rle_to_mask can decode them later.\n",
    "masks['mask'] = masks['mask'].apply(parse_bbox_mask)\n",
    "\n",
    "os.makedirs(root_figures, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f521710",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Maps each mask label to the RGB channel it is painted into in the overlay\n",
    "# (0 = red, 1 = green, 2 = blue).\n",
    "colors_map = {\n",
    "    \"head\": 0,\n",
    "    \"flipper\": 1,\n",
    "    \"turtle\": 2,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e82f9db7",
   "metadata": {},
   "outputs": [],
   "source": [
    "for image_id, masks_image in masks.groupby('image_id'):\n",
    "    # Positional index of this image in the dataset (first match).\n",
    "    i = np.where(dataset.metadata.image_id == image_id)[0][0]\n",
    "    image = dataset[i]\n",
    "    # NOTE(review): assumes a PIL image, whose .size is (width, height).\n",
    "    width, height = image.size\n",
    "\n",
    "    # Paint every mask into the RGB channel assigned to its label (colors_map).\n",
    "    overlay = np.zeros((height, width, 3), dtype=np.float32)\n",
    "    for _, m in masks_image.iterrows():\n",
    "        mask_bool = rle_to_mask(m['mask']).astype(bool)\n",
    "        overlay[mask_bool, colors_map[m['label']]] = 1.0\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(8, 8))\n",
    "    ax.imshow(image)\n",
    "    ax.imshow(overlay, alpha=0.5)\n",
    "\n",
    "    # Bounding box and side label for every annotation.\n",
    "    for _, m in masks_image.iterrows():\n",
    "        rect = patches.Rectangle(\n",
    "            (m['bbox_x'], m['bbox_y']),\n",
    "            m['bbox_w'],\n",
    "            m['bbox_h'],\n",
    "            linewidth=2,\n",
    "            edgecolor=\"white\",\n",
    "            facecolor=\"none\"\n",
    "        )\n",
    "        ax.add_patch(rect)\n",
    "        ax.text(\n",
    "            m['bbox_x'],\n",
    "            m['bbox_y'] - 3,\n",
    "            m['label_side'],\n",
    "            color=\"white\",\n",
    "            fontsize=10,\n",
    "            weight=\"bold\",\n",
    "            bbox=dict(facecolor=\"black\", alpha=0.5, pad=2)\n",
    "        )\n",
    "\n",
    "    # Annotation counts per label; the title reads \"head, flipper, turtle\".\n",
    "    n_head = (masks_image['label'] == 'head').sum()\n",
    "    n_flipper = (masks_image['label'] == 'flipper').sum()\n",
    "    n_turtle = (masks_image['label'] == 'turtle').sum()\n",
    "\n",
    "    ax.axis(\"off\")\n",
    "    ax.set_title(f'{n_head}, {n_flipper}, {n_turtle}')\n",
    "    fig.savefig(f'{root_figures}/{image_id}.png', bbox_inches='tight', dpi=600)\n",
    "    # Close this figure explicitly so figures do not pile up across iterations.\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54035a2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: report images where the same label_side value appears more than once,\n",
    "# then show the overall label and label_side counts.\n",
    "for image_id, masks_image in masks.groupby('image_id'):\n",
    "    if masks_image['label_side'].value_counts().max() > 1:\n",
    "        print(f'Image id {image_id} has multiple annotations.')\n",
    "        display(masks_image)\n",
    "display(masks['label'].value_counts())\n",
    "display(masks['label_side'].value_counts())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sam3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}