File size: 4,057 Bytes
bb14d6a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 | {
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "56e96915",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.patches as patches\n",
"\n",
"from wildlife_datasets.datasets import TurtlesOfSMSRC\n",
"from wildlife_datasets.datasets.utils import parse_bbox_mask\n",
"from turtle_detector import assign_flippers, initialize_sam3, mask_to_rle, rle_to_mask, compute_iou, mask_to_bbox"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4c8a0449",
"metadata": {},
"outputs": [],
"source": [
"# Dataset location on disk and the directory that receives rendered figures.\n",
"root = '/data/wildlife_datasets/TurtlesOfSMSRC'\n",
"root_figures = 'figures'\n",
"\n",
"dataset = TurtlesOfSMSRC(root)\n",
"\n",
"# Read the mask table and pass every entry of its 'mask' column through\n",
"# parse_bbox_mask; the remaining columns are kept as read.\n",
"masks = pd.read_csv('masks.csv').assign(\n",
"    mask=lambda frame: frame['mask'].apply(parse_bbox_mask)\n",
")\n",
"\n",
"# Create the output directory up front; an existing one is not an error.\n",
"os.makedirs(root_figures, exist_ok=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2f521710",
"metadata": {},
"outputs": [],
"source": [
"# Mask label -> channel index along the last axis of the (height, width, 3)\n",
"# overlay array (rendered by imshow as RGB: 0 = red, 1 = green, 2 = blue).\n",
"colors_map = dict(head=0, flipper=1, turtle=2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e82f9db7",
"metadata": {},
"outputs": [],
"source": [
"# Render one annotated figure per image: the photo, a semi-transparent colour\n",
"# overlay of all its masks, and a labelled bounding box for every annotation.\n",
"for image_id, masks_image in masks.groupby('image_id'):\n",
"    # Position of the first metadata row whose image_id matches this group.\n",
"    i = np.where(dataset.metadata.image_id == image_id)[0][0]\n",
"    image = dataset[i]\n",
"    width, height = image.size\n",
"\n",
"    # Paint each mask into the overlay channel that colors_map assigns to its label.\n",
"    overlay = np.zeros((height, width, 3), dtype=np.float32)\n",
"    for _, m in masks_image.iterrows():\n",
"        mask_bool = rle_to_mask(m['mask']).astype(bool)\n",
"        overlay[mask_bool, colors_map[m['label']]] = 1.0\n",
"\n",
"    # Draw through the explicit figure/axes handles rather than pyplot's\n",
"    # implicit 'current figure' state.\n",
"    fig, ax = plt.subplots(figsize=(8, 8))\n",
"    ax.imshow(image)\n",
"    ax.imshow(overlay, alpha=0.5)\n",
"\n",
"    for _, m in masks_image.iterrows():\n",
"        rect = patches.Rectangle(\n",
"            (m['bbox_x'], m['bbox_y']),\n",
"            m['bbox_w'],\n",
"            m['bbox_h'],\n",
"            linewidth=2,\n",
"            edgecolor=\"white\",\n",
"            facecolor=\"none\"\n",
"        )\n",
"        ax.add_patch(rect)\n",
"        # Caption sits just above the top-left corner of the box.\n",
"        ax.text(\n",
"            m['bbox_x'],\n",
"            m['bbox_y'] - 3,\n",
"            m['label_side'],\n",
"            color=\"white\",\n",
"            fontsize=10,\n",
"            weight=\"bold\",\n",
"            bbox=dict(facecolor=\"black\", alpha=0.5, pad=2)\n",
"        )\n",
"\n",
"    # Annotation counts shown in the title, in the order head, flipper, turtle.\n",
"    n_head = (masks_image['label'] == 'head').sum()\n",
"    n_flipper = (masks_image['label'] == 'flipper').sum()\n",
"    # Count 'turtle' rows here; the earlier code counted 'head' a second time.\n",
"    n_turtle = (masks_image['label'] == 'turtle').sum()\n",
"\n",
"    ax.axis(\"off\")\n",
"    ax.set_title(f'{n_head}, {n_flipper}, {n_turtle}')\n",
"    fig.savefig(f'{root_figures}/{image_id}.png', bbox_inches='tight', dpi=600)\n",
"    # Close this figure explicitly so figures do not accumulate across iterations.\n",
"    plt.close(fig)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "54035a2d",
"metadata": {},
"outputs": [],
"source": [
"# Report every image in which some label_side value occurs more than once.\n",
"for image_id, masks_image in masks.groupby('image_id'):\n",
"    side_counts = masks_image['label_side'].value_counts()\n",
"    if side_counts.max() > 1:\n",
"        print(f'Image id {image_id} has multiple annotations.')\n",
"        display(masks_image)\n",
"\n",
"# Overall frequency of each value in the two label columns.\n",
"for column in ('label', 'label_side'):\n",
"    display(masks[column].value_counts())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "sam3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
|