| | import torch |
| | import torch.nn.functional as F |
| | from typing import Literal, Optional, Tuple |
| |
|
| | @torch.no_grad() |
| | def select_tokens( |
| | obj_masks: torch.Tensor, |
| | grid_thw: Tuple[int,int,int], |
| | *, |
| | patch_size: int = 14, |
| | spatial_merge_size: int = 2, |
| | temporal_patch_size: int = 2, |
| | coverage_thresh: float = 0.7, |
| | time_reduce: Literal["mean","max","all"] = "max", |
| | device: str | torch.device = "cpu", |
| | retry_step: float = 0.1, |
| | retry_times: int = 1, |
| | ensure_at_least_one: bool = True, |
| | dtype: torch.dtype = torch.float32, |
| | ): |
| | if obj_masks.dim() == 3: |
| | obj_masks = obj_masks.unsqueeze(0) |
| | O, N, H_rz, W_rz = obj_masks.shape |
| | T, H, W = grid_thw |
| | m, g = spatial_merge_size, temporal_patch_size |
| | if N != T*g: |
| | if N < T * g: |
| | pad = T*g - N |
| | last = obj_masks[:,-1:,:,:].repeat(1, pad, 1, 1) |
| | obj_masks = torch.cat([obj_masks, last], dim=1) |
| | N = T * g |
| | else: |
| | obj_masks = obj_masks[:, :T * g, :, :] |
| | N = T * g |
| | Hm, Wm = H // m, W // m |
| | pix_h, pix_w = m * patch_size, m * patch_size |
| | assert H_rz % pix_h == 0 and W_rz % pix_w == 0, "resized // (28×28)" |
| |
|
| | M = obj_masks.to(device=device, dtype=dtype).clamp(0, 1) |
| |
|
| | M_flat = M.view(O*N, 1, H_rz, W_rz) |
| | cov_hw = F.avg_pool2d(M_flat, kernel_size=(pix_h, pix_w), stride=(pix_h, pix_w)) |
| | cov_hw = cov_hw.view(O, N, Hm, Wm) |
| |
|
| | cov_hw = cov_hw.view(O, T, g, Hm, Wm) |
| | if time_reduce == "mean": |
| | cov_thw = cov_hw.mean(dim=2) |
| | elif time_reduce == "max": |
| | cov_thw = cov_hw.max(dim=2).values |
| | elif time_reduce == "all": |
| | cov_thw = cov_hw.min(dim=2).values |
| | else: |
| | raise ValueError("time_reduce ∈ {'mean','max','all'}") |
| |
|
| | per_obj_idx = [] |
| | per_t = Hm * Wm |
| | for o in range(O): |
| | nz = torch.empty(0, 3, dtype=torch.long, device=device) |
| | tried = 0 |
| | thr = coverage_thresh |
| | while tried <= retry_times: |
| | thr_eff = max(0.0, float(thr)) |
| | sel = (cov_thw[o] >= thr_eff) |
| | nz = torch.nonzero(sel, as_tuple=False) |
| | if nz.numel() > 0: |
| | break |
| | tried += 1 |
| | thr -= retry_step |
| | if nz.numel() == 0: |
| | if ensure_at_least_one: |
| | flat = cov_thw[o].reshape(-1) |
| | arg = torch.argmax(flat) |
| | t = arg // (Hm * Wm) |
| | rem = arg % (Hm * Wm) |
| | hp = rem // Wm |
| | wp = rem % Wm |
| | idx = (t * per_t + hp * Wm + wp).view(1) |
| | per_obj_idx.append(idx.to(device=device, dtype=torch.long)) |
| | else: |
| | per_obj_idx.append(torch.empty(0, dtype=torch.long, device=device)) |
| | else: |
| | t = nz[:, 0] |
| | hp = nz[:, 1] |
| | wp = nz[:, 2] |
| | idx = t * per_t + hp * Wm + wp |
| | per_obj_idx.append(idx.to(device=device, dtype=torch.long)) |
| |
|
| | if len(per_obj_idx) == 0: |
| | union_idx = torch.empty(0, dtype=torch.long, device=device) |
| | else: |
| | union_idx = torch.unique(torch.cat(per_obj_idx, dim=0)) if per_obj_idx[0].numel() else torch.empty(0, dtype=torch.long, device=device) |
| |
|
| | union_idx_cpu = union_idx.cpu() |
| | per_obj_idx_cpu = [idx.cpu() for idx in per_obj_idx] |
| | cov_thw_cpu = cov_thw.cpu() |
| |
|
| | del M, M_flat, cov_hw, cov_thw, per_obj_idx, union_idx |
| | if O > 0: |
| | del sel, nz |
| |
|
| | return union_idx_cpu, per_obj_idx_cpu, cov_thw_cpu |
| |
|