{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "926f340c",
   "metadata": {},
   "source": [
    "# SAM3 masks for the SMSRC turtles\n",
    "\n",
    "The notebook prepares the SMSRC data for training of the turtle detector. It uses SAM3 to detect the turtle, its head and its flippers. Then it uses a heuristic to assign the left/right and front/rear orientation of each flipper. These assignments were manually checked and corrected where necessary.\n",
    "\n",
    "The output of the notebook is the `masks.csv` file, which is then used in the `segmentation_prepare` notebook to create the training dataset for detection."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6774dc0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from wildlife_datasets.datasets import TurtlesOfSMSRC\n",
    "from turtle_detector import assign_flippers, initialize_sam3, mask_to_rle, rle_to_mask, compute_iou, mask_to_bbox"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1c3e5f7",
   "metadata": {},
   "source": [
    "## Select encounters\n",
    "\n",
    "Only images whose `encounter_id` lies in one of the inclusive ranges below are processed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c8a0449",
   "metadata": {},
   "outputs": [],
   "source": [
    "root = '/data/wildlife_datasets/TurtlesOfSMSRC'\n",
    "dataset = TurtlesOfSMSRC(root)\n",
    "\n",
    "# Inclusive (min, max) ranges of encounter ids to keep.\n",
    "idx_ranges = [\n",
    "    (333582414, 333582440),\n",
    "    (327367311, 327367335),\n",
    "]\n",
    "encounter_id = dataset.metadata['encounter_id'].to_numpy()\n",
    "idx = np.zeros(len(dataset), dtype=bool)\n",
    "for idx_min, idx_max in idx_ranges:\n",
    "    idx |= (encounter_id >= idx_min) & (encounter_id <= idx_max)\n",
    "\n",
    "dataset = dataset.get_subset(idx)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2d4f6a8",
   "metadata": {},
   "source": [
    "## Detect turtle parts with SAM3\n",
    "\n",
    "Each image is queried with one text prompt per part (`prompt_map`). Masks with at most `min_area` pixels are discarded."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a35f952",
   "metadata": {},
   "outputs": [],
   "source": [
    "model, processor = initialize_sam3()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f521710",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt_map = {\n",
    "    \"head\": \"turtle head\",\n",
    "    \"flipper\": \"turtle flipper\",\n",
    "    \"turtle\": \"turtle\",\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aef61bcb",
   "metadata": {},
   "outputs": [],
   "source": [
    "min_area = 500\n",
    "iou_threshold = 0.1\n",
    "\n",
    "masks = []\n",
    "for i in range(len(dataset)):\n",
    "    # .iloc: i is a position, and the subset's index may not be 0..n-1.\n",
    "    image_id = dataset.metadata['image_id'].iloc[i]\n",
    "    image = dataset[i]\n",
    "    inference_state = processor.set_image(image)\n",
    "\n",
    "    for label, prompt in prompt_map.items():\n",
    "        processor.reset_all_prompts(inference_state)\n",
    "        inference_state = processor.set_text_prompt(state=inference_state, prompt=prompt)\n",
    "\n",
    "        for m in inference_state[\"masks\"]:\n",
    "            m = m.cpu().numpy().astype(bool)\n",
    "            # Drop a leading singleton axis, if present.\n",
    "            if m.ndim == 3 and m.shape[0] == 1:\n",
    "                m = m[0]\n",
    "            if m.sum() > min_area:\n",
    "                masks.append({\n",
    "                    'image_id': image_id,\n",
    "                    'mask': mask_to_rle(m),\n",
    "                    'label': label,\n",
    "                })\n",
    "\n",
    "# Explicit columns keep the schema even if no mask passes the area filter.\n",
    "masks = pd.DataFrame(masks, columns=['image_id', 'mask', 'label'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3e5a7b9",
   "metadata": {},
   "source": [
    "## Remove duplicate detections\n",
    "\n",
    "Within each image, masks overlapping with IoU of at least `iou_threshold` are resolved:\n",
    "\n",
    "- masks with the same label are merged into one;\n",
    "- for an overlapping head and flipper, the flipper is dropped if the head is the only head still kept in the image, otherwise the head is dropped.\n",
    "\n",
    "Overlaps between a `turtle` mask and a part mask are left untouched."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4a052cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "masks['keep'] = True\n",
    "for _, masks_image in masks.groupby('image_id'):\n",
    "    keep = masks_image['keep'].copy()\n",
    "    is_head = masks_image['label'] == 'head'\n",
    "    # Decode every mask of the image once; merged masks are updated in this cache.\n",
    "    decoded = {j: rle_to_mask(rle) for j, rle in masks_image['mask'].items()}\n",
    "    for i, (j, mask_j) in enumerate(masks_image.iterrows()):\n",
    "        for k, mask_k in masks_image.iloc[i+1:].iterrows():\n",
    "            if not keep.loc[j] or not keep.loc[k]:\n",
    "                continue\n",
    "            if compute_iou(decoded[j], decoded[k]) < iou_threshold:\n",
    "                continue\n",
    "\n",
    "            if mask_j['label'] == mask_k['label']:\n",
    "                # The same part detected twice: merge k into j and drop k.\n",
    "                decoded[j] = decoded[j] | decoded[k]\n",
    "                masks.at[j, 'mask'] = mask_to_rle(decoded[j])\n",
    "                keep.loc[k] = False\n",
    "\n",
    "            elif {\"head\", \"flipper\"} == {mask_j['label'], mask_k['label']}:\n",
    "                # Trust the head if it is the only head still kept; otherwise drop it.\n",
    "                if (keep & is_head).sum() == 1:\n",
    "                    drop = j if mask_j['label'] == \"flipper\" else k\n",
    "                else:\n",
    "                    drop = j if mask_j['label'] == \"head\" else k\n",
    "                keep.loc[drop] = False\n",
    "    masks.loc[masks_image.index, 'keep'] = keep\n",
    "masks = masks[masks['keep']].drop(columns='keep')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4f6b8c0",
   "metadata": {},
   "source": [
    "## Bounding boxes and flipper sides\n",
    "\n",
    "Each kept mask gets an `(x, y, w, h)` bounding box, and `assign_flippers` assigns the left/right and front/rear side of the flippers (stored in `label_side`). The result is saved to `masks.csv`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5a7c9d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "bboxes = []\n",
    "for rle in masks['mask']:\n",
    "    x0, y0, x1, y1 = mask_to_bbox(rle_to_mask(rle))\n",
    "    bboxes.append((x0, y0, x1 - x0, y1 - y0))\n",
    "# Coordinates are stored as floats.\n",
    "bboxes = pd.DataFrame(\n",
    "    bboxes,\n",
    "    columns=['bbox_x', 'bbox_y', 'bbox_w', 'bbox_h'],\n",
    "    index=masks.index,\n",
    "    dtype=float,\n",
    ")\n",
    "masks = masks.join(bboxes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6b8d0e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): assumes assign_flippers returns a frame indexed like its input.\n",
    "for _, masks_image in masks.groupby('image_id'):\n",
    "    masks.loc[masks_image.index, 'label_side'] = assign_flippers(masks_image)['label']\n",
    "\n",
    "masks.to_csv('masks.csv', index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sam3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}