| import numpy as np |
| import torch |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
|
|
|
|
| |
def get_frame_id_list_from_mask(mask):
    """
    Split a 1-D boolean mask into runs of consecutive True entries.

    Args:
        mask (F,), bool tensor: True marks a frame to be processed.

    Returns:
        frame_id_list: list of 1-D torch.Tensors, one per run of True values,
            each holding that run's consecutive frame indices in increasing
            order. Empty list when ``mask`` has no True entry.
    """
    edge = torch.tensor([False], device=mask.device)
    # Surround with False so every run has a rising (+1) and a falling (-1) edge.
    transitions = torch.diff(torch.cat([edge, mask, edge]).int())
    # A +1 at position i means mask[i] starts a run; a -1 at position i means
    # the run ended at i - 1, so i is the exclusive end.
    run_starts = torch.where(transitions == 1)[0].tolist()
    run_ends = torch.where(transitions == -1)[0].tolist()
    return [torch.arange(lo, hi) for lo, hi in zip(run_starts, run_ends)]
|
|
|
|
def get_batch_frame_id_lists_from_mask_BLC(masks):
    """
    Extract the index runs of consecutive True values for every batch and
    channel of a 3-D mask.

    Args:
        masks (B, L, C), bool tensor: each element is a mask value; True marks
            a frame to be processed.

    Returns:
        batch_frame_id_lists: nested list indexed as ``[b][c]``. Each entry is
            a list of 1-D torch.Tensors, one per run of True values along L,
            holding that run's consecutive positions in increasing order.
    """
    B, L, C = masks.size()

    # Pad with False on both ends of L so every run has a rising and falling edge.
    padded_masks = torch.cat(
        [
            torch.zeros((B, 1, C), dtype=torch.bool, device=masks.device),
            masks,
            torch.zeros((B, 1, C), dtype=torch.bool, device=masks.device),
        ],
        dim=1,
    )
    diffs = torch.diff(padded_masks.int(), dim=1)  # (B, L + 1, C)

    # nonzero(as_tuple=True) yields (batch_idx, position_idx, channel_idx).
    # The run boundaries are the *position* indices (element 1). The previous
    # code read element 0 (the batch index), which produced empty ranges.
    start_b, start_l, start_c = (diffs == 1).nonzero(as_tuple=True)
    end_b, end_l, end_c = (diffs == -1).nonzero(as_tuple=True)

    # nonzero enumerates in row-major order, so within a fixed (b, c) the
    # positions arrive in increasing order; one grouping pass pairs them up.
    run_starts = [[[] for _ in range(C)] for _ in range(B)]
    run_ends = [[[] for _ in range(C)] for _ in range(B)]
    for b, pos, c in zip(start_b.tolist(), start_l.tolist(), start_c.tolist()):
        run_starts[b][c].append(pos)
    for b, pos, c in zip(end_b.tolist(), end_l.tolist(), end_c.tolist()):
        run_ends[b][c].append(pos)

    batch_frame_id_lists = [
        [
            [torch.arange(s, e) for s, e in zip(run_starts[b][c], run_ends[b][c])]
            for c in range(C)
        ]
        for b in range(B)
    ]
    return batch_frame_id_lists
|
|
|
|
def get_frame_id_list_from_frame_id(frame_id):
    """
    Group a collection of frame indices into runs of consecutive indices.

    Args:
        frame_id: 1-D tensor (or anything ``torch.as_tensor`` accepts) of
            non-negative integer frame indices; need not be sorted.

    Returns:
        frame_id_list: list of 1-D torch.Tensors, one per run of consecutive
            indices, as returned by ``get_frame_id_list_from_mask``. Empty list
            when ``frame_id`` is empty.
    """
    frame_id = torch.as_tensor(frame_id)
    if frame_id.numel() == 0:
        # No indices -> no runs (previously crashed on frame_id[-1]).
        return []
    # Size by the largest index, not the last one, so unsorted input fits.
    mask = torch.zeros(int(frame_id.max()) + 1, dtype=torch.bool)
    mask[frame_id] = True
    frame_id_list = get_frame_id_list_from_mask(mask)
    return frame_id_list
|
|
|
|
def rearrange_by_mask(x, mask):
    """
    Scatter the rows of ``x`` into the selected positions of ``mask``.

    Args:
        x: (L, *) tensor.
        mask: (M,) mask with M >= L, used as a boolean index; it must select
            exactly L positions.

    Returns:
        ``x`` itself when M == L (the mask is not inspected in that case).
        Otherwise a new (M, *) tensor with x's dtype and device, whose rows at
        the selected positions are the rows of ``x`` in order, and zero elsewhere.
    """
    total_len = mask.size(0)
    n_valid = x.size(0)
    if total_len == n_valid:
        return x
    assert total_len > n_valid
    assert mask.sum() == n_valid
    # new_zeros inherits x's dtype and device.
    scattered = x.new_zeros((total_len, *x.shape[1:]))
    scattered[mask] = x
    return scattered
|
|
|
|
def frame_id_to_mask(frame_id, max_len):
    """
    Build a boolean mask of length ``max_len`` that is True at ``frame_id``.

    The mask is created with ``torch.zeros`` and no explicit device, so it lives
    on the default device. ``frame_id`` is used directly as an index.
    """
    is_selected = torch.zeros(max_len, dtype=torch.bool)
    is_selected[frame_id] = True
    return is_selected
|
|
|
|
def mask_to_frame_id(mask):
    """Return the dim-0 indices of the nonzero (True) entries of ``mask``."""
    # torch.nonzero(..., as_tuple=True) is what torch.where(mask) delegates to;
    # element 0 holds the indices along the first dimension.
    return torch.nonzero(mask, as_tuple=True)[0]
|
|
|
|
def linear_interpolate_frame_ids(data, frame_id_list):
    """
    Fill runs of invalid frames by linear interpolation between valid neighbours.

    Args:
        data: (N, *) tensor; frames are indexed along dim 0.
        frame_id_list: iterable of 1-D index sequences (e.g. the output of
            ``get_frame_id_list_from_mask``). Each one is a run of consecutive
            frame indices to overwrite. Empty runs are skipped.

    Returns:
        A new tensor (``data`` is cloned, not modified). Each run is linearly
        interpolated between the frame just before it and the frame just after
        it. A run that starts at frame 0 is filled with the frame after it; a
        run that ends at the last frame is filled with the frame before it.

    Raises:
        IndexError: if a run covers every frame, leaving no neighbour to fill from.
    """
    data = data.clone()
    n_frames = len(data)
    for invalid_frame_ids in frame_id_list:
        n_invalid = len(invalid_frame_ids)
        if n_invalid == 0:
            continue
        first = int(invalid_frame_ids[0])
        last = int(invalid_frame_ids[-1])
        has_prev = first - 1 >= 0
        has_next = last + 1 < n_frames

        if has_prev and has_next:
            prev_frame = data[first - 1]
            next_frame = data[last + 1]
            # Interior weights only: the 0 and 1 endpoints are the neighbours.
            # Built on data's device so CUDA inputs do not hit a device mismatch.
            weights = torch.linspace(0, 1, n_invalid + 2, device=data.device)[1:-1]
            # (n_invalid, 1, ..., 1) so the weights broadcast over frames of any rank.
            weights = weights.reshape(n_invalid, *([1] * (data.dim() - 1)))
            data[invalid_frame_ids] = (
                weights * (next_frame - prev_frame)[None] + prev_frame[None]
            )
        elif has_next:
            # Run touches the start: repeat the first valid frame after it.
            data[invalid_frame_ids] = data[last + 1].clone()
        elif has_prev:
            # Run touches the end: repeat the last valid frame before it.
            data[invalid_frame_ids] = data[first - 1].clone()
        else:
            raise IndexError(
                "cannot interpolate: run covers every frame, no valid neighbour"
            )
    return data
|
|
|
|
def linear_interpolate(data, N_middle_frames):
    """
    Linearly interpolate between two frames.

    Args:
        data: (2, C) tensor (any trailing shape (2, *) also works); data[0] is
            the start frame and data[1] the end frame.
        N_middle_frames: number of evenly spaced frames to insert strictly
            between them (0 returns just the two input frames).

    Returns:
        data_interpolated: (1 + N_middle_frames + 1, *) tensor made of the start
            frame, the intermediate frames, and the end frame.
    """
    start_frame = data[0]
    end_frame = data[1]
    # Interior weights only; built on data's device to avoid CPU/CUDA mismatch.
    weights = torch.linspace(0, 1, N_middle_frames + 2, device=data.device)[1:-1]
    # Explicit length (not -1) so N_middle_frames == 0 still reshapes cleanly;
    # trailing singleton dims let the weights broadcast over frames of any rank.
    weights = weights.reshape(weights.shape[0], *([1] * (data.dim() - 1)))
    middle = weights * (end_frame - start_frame)[None] + start_frame[None]
    data_interpolated = torch.cat([start_frame[None], middle, end_frame[None]], dim=0)
    return data_interpolated
|
|
|
|
def find_top_k_span(mask, k=3):
    """
    Find the k longest runs of consecutive True (1) entries in a 1-D mask.

    Args:
        mask: (L,) tensor, numpy array, or sequence of bool / 0-1 values.
        k: maximum number of spans to return.

    Return:
        topk_span: list of (start, end) int tuples, usage: [start, end).
            Ordered by span length, longest first; spans of equal length keep
            their left-to-right order. Empty list if the mask has no True entry.
    """
    # as_tensor covers tensors (returned as-is), numpy arrays, and plain lists.
    mask = torch.as_tensor(mask)
    if mask.sum() == 0:
        return []
    mask = mask.float()  # bool -> float so edges can be found by subtraction
    pad = mask.new_zeros(1)
    mask = torch.cat([pad, mask, pad])
    diff = mask[1:] - mask[:-1]
    start = torch.where(diff == 1)[0]
    end = torch.where(diff == -1)[0]
    assert len(start) == len(end)
    # stable=True makes tie-breaking deterministic (equal lengths keep position order).
    _, order = (end - start).sort(descending=True, stable=True)
    start = start[order]
    end = end[order]
    return list(zip(start.tolist(), end.tolist()))[:k]
|
|