import torch
import torch.nn.functional as F
from typing import Literal, Optional, Tuple
@torch.no_grad()
def select_tokens(
obj_masks: torch.Tensor,
grid_thw: Tuple[int,int,int],
*,
patch_size: int = 14,
spatial_merge_size: int = 2,
temporal_patch_size: int = 2,
coverage_thresh: float = 0.7,
time_reduce: Literal["mean","max","all"] = "max",
device: str | torch.device = "cpu",
retry_step: float = 0.1,
retry_times: int = 1,
ensure_at_least_one: bool = True,
dtype: torch.dtype = torch.float32,
):
if obj_masks.dim() == 3:
obj_masks = obj_masks.unsqueeze(0)
O, N, H_rz, W_rz = obj_masks.shape
T, H, W = grid_thw
m, g = spatial_merge_size, temporal_patch_size
if N != T*g:
if N < T * g:
pad = T*g - N
last = obj_masks[:,-1:,:,:].repeat(1, pad, 1, 1)
obj_masks = torch.cat([obj_masks, last], dim=1)
N = T * g
else:
obj_masks = obj_masks[:, :T * g, :, :]
N = T * g
Hm, Wm = H // m, W // m
pix_h, pix_w = m * patch_size, m * patch_size
assert H_rz % pix_h == 0 and W_rz % pix_w == 0, "resized // (28×28)"
M = obj_masks.to(device=device, dtype=dtype).clamp(0, 1)
M_flat = M.view(O*N, 1, H_rz, W_rz)
cov_hw = F.avg_pool2d(M_flat, kernel_size=(pix_h, pix_w), stride=(pix_h, pix_w)) # (O*N,1,Hm,Wm)
cov_hw = cov_hw.view(O, N, Hm, Wm)
cov_hw = cov_hw.view(O, T, g, Hm, Wm)
if time_reduce == "mean":
cov_thw = cov_hw.mean(dim=2)
elif time_reduce == "max":
cov_thw = cov_hw.max(dim=2).values
elif time_reduce == "all":
cov_thw = cov_hw.min(dim=2).values
else:
raise ValueError("time_reduce ∈ {'mean','max','all'}")
per_obj_idx = []
per_t = Hm * Wm
for o in range(O):
nz = torch.empty(0, 3, dtype=torch.long, device=device)
tried = 0
thr = coverage_thresh
while tried <= retry_times:
thr_eff = max(0.0, float(thr))
sel = (cov_thw[o] >= thr_eff)
nz = torch.nonzero(sel, as_tuple=False)
if nz.numel() > 0:
break
tried += 1
thr -= retry_step
if nz.numel() == 0:
if ensure_at_least_one:
flat = cov_thw[o].reshape(-1)
arg = torch.argmax(flat)
t = arg // (Hm * Wm)
rem = arg % (Hm * Wm)
hp = rem // Wm
wp = rem % Wm
idx = (t * per_t + hp * Wm + wp).view(1)
per_obj_idx.append(idx.to(device=device, dtype=torch.long))
else:
per_obj_idx.append(torch.empty(0, dtype=torch.long, device=device))
else:
t = nz[:, 0]
hp = nz[:, 1]
wp = nz[:, 2]
idx = t * per_t + hp * Wm + wp
per_obj_idx.append(idx.to(device=device, dtype=torch.long))
if len(per_obj_idx) == 0:
union_idx = torch.empty(0, dtype=torch.long, device=device)
else:
union_idx = torch.unique(torch.cat(per_obj_idx, dim=0)) if per_obj_idx[0].numel() else torch.empty(0, dtype=torch.long, device=device)
union_idx_cpu = union_idx.cpu()
per_obj_idx_cpu = [idx.cpu() for idx in per_obj_idx]
cov_thw_cpu = cov_thw.cpu()
del M, M_flat, cov_hw, cov_thw, per_obj_idx, union_idx
if O > 0:
del sel, nz
return union_idx_cpu, per_obj_idx_cpu, cov_thw_cpu
|