|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_frame_id_list_from_mask(mask):
    """
    Split a boolean mask into its runs of consecutive True entries.

    Args:
        mask (F,), bool tensor: `True` marks a frame to be processed.

    Returns:
        frame_id_list: List of torch.Tensors, one per run, each holding the
            consecutive indices of that run. Empty list when no entry is True.
    """
    # A False sentinel on both ends makes every run produce exactly one
    # rising (+1) and one falling (-1) transition, even at the borders.
    sentinel = torch.zeros(1, dtype=torch.bool, device=mask.device)
    transitions = torch.diff(torch.cat([sentinel, mask, sentinel]).int())

    # flatten() yields a 1-D index vector for zero, one or many runs alike.
    run_starts = torch.nonzero(transitions == 1).flatten().tolist()
    run_ends = torch.nonzero(transitions == -1).flatten().tolist()

    # A rising edge at i is the first True index; a falling edge at j is the
    # exclusive end of the run.
    return [torch.arange(lo, hi) for lo, hi in zip(run_starts, run_ends)]
|
|
|
|
|
|
|
|
def get_batch_frame_id_lists_from_mask_BLC(masks):
    """
    Extract the runs of consecutive True frames for every batch item and channel.

    Args:
        masks (B, L, C), bool tensor: `True` marks a frame to be processed.

    Returns:
        batch_frame_id_lists: Nested list indexed as [b][c]; each entry is a
            list of torch.Tensors, one per run, holding the consecutive frame
            indices (along L) of that run.
    """
    B, _, C = masks.size()

    # Work in (B, C, L) layout so that nonzero() enumerates transitions grouped
    # by (b, c) with frame indices ascending inside each group.
    masks_bcl = masks.permute(0, 2, 1)

    # Pad one False on both ends of the frame axis so runs touching the borders
    # still produce a rising and a falling transition.
    pad = torch.zeros((B, C, 1), dtype=torch.bool, device=masks.device)
    padded_masks = torch.cat([pad, masks_bcl, pad], dim=2)

    # +1 marks the first frame of a run, -1 marks its exclusive end.
    diffs = torch.diff(padded_masks.int(), dim=2)
    start_b, start_c, start_frame = (diffs == 1).nonzero(as_tuple=True)
    end_frame = (diffs == -1).nonzero(as_tuple=True)[2]

    # Every (b, c) row has as many rising as falling edges and both are listed
    # in the same row-major order, so the k-th start pairs with the k-th end.
    batch_frame_id_lists = [[[] for _ in range(C)] for _ in range(B)]
    for b, c, start, end in zip(
        start_b.tolist(), start_c.tolist(), start_frame.tolist(), end_frame.tolist()
    ):
        batch_frame_id_lists[b][c].append(torch.arange(start, end))

    return batch_frame_id_lists
|
|
|
|
|
|
|
|
def get_frame_id_list_from_frame_id(frame_id):
    """
    Group frame ids into runs of consecutive ids.

    Args:
        frame_id: 1-D collection of non-negative integer frame ids
            (tensor or anything `torch.as_tensor` accepts).

    Returns:
        frame_id_list: Result of `get_frame_id_list_from_mask` applied to a
            mask that is True exactly at the given ids; `[]` for empty input.
    """
    ids = torch.as_tensor(frame_id)
    if ids.numel() == 0:
        # No ids means no runs; also avoids indexing an empty sequence.
        return []

    # Size the mask from the largest id rather than the last one so the ids
    # do not have to be sorted.
    mask = torch.zeros(int(ids.max()) + 1, dtype=torch.bool)
    mask[ids] = True
    return get_frame_id_list_from_mask(mask)
|
|
|
|
|
|
|
|
def rearrange_by_mask(x, mask):
    """
    Scatter the rows of `x` into the True positions of `mask`.

    Args:
        x (L, *): rows to place.
        mask (M,), M >= L: boolean mask with exactly L True entries when M > L.

    Returns:
        `x` itself when M == L; otherwise a new (M, *) tensor with the same
        dtype and device as `x`, holding the rows of `x` at the True positions
        and zeros elsewhere.
    """
    num_slots, num_rows = mask.size(0), x.size(0)

    # Same length: nothing to scatter, hand the input back untouched.
    if num_slots == num_rows:
        return x

    assert num_slots > num_rows
    assert mask.sum() == num_rows

    scattered = x.new_zeros((num_slots, *x.shape[1:]))
    scattered[mask] = x
    return scattered
|
|
|
|
|
|
|
|
def frame_id_to_mask(frame_id, max_len):
    """
    Build a boolean mask of length `max_len` that is True at `frame_id`.

    Args:
        frame_id: index (or indices) into the mask to mark as True.
        max_len: length of the returned mask.

    Returns:
        (max_len,) bool tensor, created on the default device.
    """
    selected = torch.zeros(max_len, dtype=torch.bool)
    selected[frame_id] = True
    return selected
|
|
|
|
|
|
|
|
def mask_to_frame_id(mask):
    """
    Return the indices along the first dimension where `mask` is nonzero.

    Args:
        mask: tensor, typically a (F,) bool mask.

    Returns:
        1-D integer tensor of indices into the first dimension of `mask`.
    """
    # Single-argument torch.where is the same as nonzero(as_tuple=True);
    # element 0 holds the coordinates along the first dimension.
    return torch.nonzero(mask, as_tuple=True)[0]
|
|
|
|
|
|
|
|
def linear_interpolate_frame_ids(data, frame_id_list):
    """
    Fill spans of invalid frames by interpolating between their valid neighbors.

    Args:
        data (L, *): per-frame values; the input tensor is not modified.
        frame_id_list: iterable of spans, each a sequence of consecutive frame
            ids to overwrite (e.g. the output of get_frame_id_list_from_mask).

    Returns:
        A copy of `data` where every span is replaced as follows:
          - neighbors on both sides: linear interpolation between the frame
            just before and the frame just after the span;
          - span touches the start: copy of the frame just after the span;
          - span touches the end: copy of the frame just before the span;
          - span covers every frame (no neighbor at all) or is empty: left
            unchanged.
    """
    data = data.clone()
    num_frames = len(data)
    # Lets the per-frame weights broadcast over any number of trailing dims.
    trailing_ones = (1,) * (data.dim() - 1)

    for invalid_frame_ids in frame_id_list:
        span_len = len(invalid_frame_ids)
        if span_len == 0:
            continue

        first = int(invalid_frame_ids[0])
        last = int(invalid_frame_ids[-1])
        has_before = first - 1 >= 0
        has_after = last + 1 < num_frames

        if has_before and has_after:
            before = data[first - 1]
            after = data[last + 1]
            # Interior points of linspace(0, 1) give the blend factor for each
            # invalid frame; built on data's device to avoid a device mismatch.
            weights = torch.linspace(0, 1, span_len + 2, device=data.device)[1:-1]
            weights = weights.reshape(span_len, *trailing_ones)
            data[invalid_frame_ids] = weights * (after - before)[None] + before[None]
        elif has_after:
            # Span starts at frame 0: hold the first valid value that follows.
            data[invalid_frame_ids] = data[last + 1].clone()
        elif has_before:
            # Span runs to the last frame: hold the last valid value before it.
            data[invalid_frame_ids] = data[first - 1].clone()
        # else: the span covers everything, there is nothing to copy from.

    return data
|
|
|
|
|
|
|
|
def linear_interpolate(data, N_middle_frames):
    """
    Linearly interpolate `N_middle_frames` frames between two keyframes.

    Args:
        data: (2, *) tensor; data[0] is the first keyframe, data[1] the last.
            Typically (2, C).
        N_middle_frames: number of evenly spaced frames to insert in between.

    Returns:
        data_interpolated: (1+N+1, *) tensor made of the first keyframe, the
            N interpolated frames and the last keyframe.
    """
    first = data[0]
    last = data[1]

    # Interior points of linspace(0, 1) are the blend factors of the middle
    # frames; built on data's device to avoid a device mismatch.
    weights = torch.linspace(0, 1, N_middle_frames + 2, device=data.device)[1:-1]
    # One singleton axis per keyframe dim so the weights broadcast over any
    # trailing shape. The explicit leading size keeps N == 0 reshapeable.
    weights = weights.reshape(weights.shape[0], *((1,) * first.dim()))

    middle = weights * (last - first)[None] + first[None]
    return torch.cat([first[None], middle, last[None]], dim=0)
|
|
|
|
|
|
|
|
def find_top_k_span(mask, k=3):
    """
    Find the k longest runs of consecutive nonzero entries in a 1-D mask.

    Args:
        mask: (L,) binary mask; torch tensor, numpy array or Python sequence.
        k: maximum number of spans to return.

    Return:
        topk_span: List of (start, end) int tuples, usage: [start, end),
            ordered from longest to shortest. Spans of equal length keep
            their order of appearance. Empty list when the mask sums to zero.
    """
    mask = torch.as_tensor(mask)
    if mask.sum() == 0:
        return []

    # Zero sentinels on both ends turn every run into one +1 and one -1 step.
    mask = mask.float()
    sentinel = mask.new_zeros(1)
    padded = torch.cat([sentinel, mask, sentinel])
    diff = padded[1:] - padded[:-1]

    starts = torch.where(diff == 1)[0].tolist()
    ends = torch.where(diff == -1)[0].tolist()
    assert len(starts) == len(ends)

    # Sorting ascending on (start - end) == -length puts the longest first;
    # Python's sort is stable, so ties stay in order of appearance.
    spans = sorted(zip(starts, ends), key=lambda span: span[0] - span[1])
    return spans[:k]
|
|
|