| import os |
| import numpy as np |
| import sam3 |
| from sam3 import build_sam3_image_model |
| from sam3.model.sam3_image_processor import Sam3Processor |
| from .masks import rle_to_mask |
|
|
def get_index(dataset, image_id):
    """Return the ``dataset.metadata`` index label of the row matching ``image_id``.

    Args:
        dataset: Object whose ``metadata`` attribute is a DataFrame with an
            ``'image_id'`` column.
        image_id: Value to look up in that column.

    Returns:
        The index label of the unique matching row.

    Raises:
        ValueError: If no row, or more than one row, matches.
    """
    metadata = dataset.metadata
    hits = metadata[metadata['image_id'] == image_id]
    if len(hits) == 1:
        return hits.index[0]
    raise ValueError('image_id not found or found multiple times.')
|
|
def mask_centroid(mask):
    """Return the mean ``(x, y)`` position of the nonzero pixels of a 2-D mask.

    ``x`` is the column index and ``y`` the row index, i.e. image
    coordinates. An all-zero mask yields ``[nan, nan]`` (NumPy also emits a
    "Mean of empty slice" RuntimeWarning).
    """
    row_idx, col_idx = np.nonzero(mask)
    centroid_x = col_idx.mean()
    centroid_y = row_idx.mean()
    return np.array([centroid_x, centroid_y])
|
|
def rle_centroid(rle):
    """Decode ``rle`` with ``rle_to_mask`` and return its ``(x, y)`` centroid.

    The result comes from ``mask_centroid``: ``x`` is the mean column index
    and ``y`` the mean row index of the decoded mask's nonzero pixels.
    """
    decoded = rle_to_mask(rle)
    return mask_centroid(decoded)
|
|
def assign_flippers(df, *, centroid_fn=None):
    """Relabel generic ``'flipper'`` rows as front/rear, left/right flippers.

    The body axis runs from the mean of all flipper centroids to the centroid
    of the single ``'head'`` mask. Flippers are split front/rear by their
    projection onto that axis and left/right by their signed offset across
    it. Assigned labels are ``'flipper_fl'``, ``'flipper_fr'``,
    ``'flipper_rl'`` and ``'flipper_rr'``:

    * 2 flippers: both are treated as the front pair (``fl`` / ``fr``).
    * 3 flippers: the two most forward are the front pair; the remaining one
      becomes ``rl`` or ``rr`` depending on which side of the axis it lies.
    * 4 flippers: the two most forward are front, the other two rear.

    Args:
        df: DataFrame with ``'label'`` and ``'mask'`` columns. It is not
            modified; a relabelled copy is returned.
        centroid_fn: Callable mapping one ``'mask'`` cell to an ``(x, y)``
            centroid in image coordinates (x = column, y = row). Defaults to
            ``rle_centroid``; pass ``mask_centroid`` when the column holds
            decoded binary masks.

    Returns:
        A copy of ``df``. Labels are left untouched when there is not exactly
        one head, fewer than two or more than four flippers, or the body axis
        is undefined (non-finite centroids, or the head centroid coincides
        with the flipper mean). A single flipper is not relabelled because
        the axis passes through it, so its side cannot be decided.
    """
    if centroid_fn is None:
        centroid_fn = rle_centroid

    df = df.copy()

    head_rows = df[df['label'] == 'head']
    if len(head_rows) != 1:
        return df

    flippers = df[df['label'] == 'flipper']
    n_flippers = len(flippers)
    if not 2 <= n_flippers <= 4:
        return df

    head_center = np.asarray(centroid_fn(head_rows.iloc[0]['mask']), dtype=float)
    flipper_centers = np.vstack(
        [np.asarray(centroid_fn(m), dtype=float) for m in flippers['mask']]
    )

    # The flipper mean is the axis origin. Projecting offsets from it (not
    # raw pixel coordinates) makes the sign of a lateral projection mean
    # "which side of the body axis".
    turtle_center = flipper_centers.mean(axis=0)
    forward_vec = head_center - turtle_center
    axis_len = np.linalg.norm(forward_vec)
    if not np.isfinite(axis_len) or axis_len == 0:
        # Empty masks give NaN centroids; a head centroid on top of the
        # flipper mean gives no direction. Either way, don't guess.
        return df
    forward_vec = forward_vec / axis_len

    # Heading rotated by +90 degrees in (x, y) = (column, row) coordinates.
    # Image rows grow downward, so this vector points to the image-right of
    # the heading: a smaller projection means further to the left.
    # NOTE(review): that is the turtle's own left only for a top-down
    # (dorsal) view — confirm against the imaging setup.
    lateral_vec = np.array([-forward_vec[1], forward_vec[0]])

    offsets = flipper_centers - turtle_center
    forward_proj = offsets @ forward_vec
    lateral_proj = offsets @ lateral_vec

    def split_left_right(positions):
        # Order a pair of flipper positions as (left, right); argsort keeps
        # the two distinct even when their lateral offsets tie.
        left, right = positions[np.argsort(lateral_proj[positions])]
        return left, right

    def set_label(pos, label):
        df.loc[flippers.index[pos], 'label'] = label

    if n_flippers == 2:
        left, right = split_left_right(np.arange(2))
        set_label(left, 'flipper_fl')
        set_label(right, 'flipper_fr')
        return df

    order_fwd = np.argsort(forward_proj)
    front_idxs = order_fwd[-2:]
    # Everything behind the front pair is rear: one flipper when n == 3, two
    # when n == 4. (order_fwd[:2] would overlap the front pair for n == 3 and
    # overwrite one of its labels.)
    rear_idxs = order_fwd[:-2]

    front_l, front_r = split_left_right(front_idxs)
    set_label(front_l, 'flipper_fl')
    set_label(front_r, 'flipper_fr')

    if len(rear_idxs) == 2:
        rear_l, rear_r = split_left_right(rear_idxs)
        set_label(rear_l, 'flipper_rl')
        set_label(rear_r, 'flipper_rr')
    else:
        idx = rear_idxs[0]
        side = 'l' if lateral_proj[idx] < 0 else 'r'
        set_label(idx, f'flipper_r{side}')

    return df
|
|
def initialize_sam3(*, confidence_threshold=0.5):
    """Build the SAM3 image model and a ``Sam3Processor`` wrapping it.

    The BPE vocabulary file is taken from the ``assets`` directory inside the
    installed ``sam3`` package.

    Args:
        confidence_threshold: Passed to ``Sam3Processor`` as its
            ``confidence_threshold`` argument. Defaults to 0.5.

    Returns:
        ``(model, processor)``.
    """
    package_dir = os.path.dirname(sam3.__file__)
    # Point straight at the asset inside the package. Stepping up to the
    # parent directory and back down into "sam3/" is equivalent for a plain
    # install, but depends on the parent layout (e.g. breaks across a
    # symlinked package directory).
    bpe_path = os.path.join(package_dir, "assets", "bpe_simple_vocab_16e6.txt.gz")
    model = build_sam3_image_model(bpe_path=bpe_path)
    processor = Sam3Processor(model, confidence_threshold=confidence_threshold)
    return model, processor