import torch import torch.nn.functional as F from typing import Literal, Optional, Tuple @torch.no_grad() def select_tokens( obj_masks: torch.Tensor, grid_thw: Tuple[int,int,int], *, patch_size: int = 14, spatial_merge_size: int = 2, temporal_patch_size: int = 2, coverage_thresh: float = 0.7, time_reduce: Literal["mean","max","all"] = "max", device: str | torch.device = "cpu", retry_step: float = 0.1, retry_times: int = 1, ensure_at_least_one: bool = True, dtype: torch.dtype = torch.float32, ): if obj_masks.dim() == 3: obj_masks = obj_masks.unsqueeze(0) O, N, H_rz, W_rz = obj_masks.shape T, H, W = grid_thw m, g = spatial_merge_size, temporal_patch_size if N != T*g: if N < T * g: pad = T*g - N last = obj_masks[:,-1:,:,:].repeat(1, pad, 1, 1) obj_masks = torch.cat([obj_masks, last], dim=1) N = T * g else: obj_masks = obj_masks[:, :T * g, :, :] N = T * g Hm, Wm = H // m, W // m pix_h, pix_w = m * patch_size, m * patch_size assert H_rz % pix_h == 0 and W_rz % pix_w == 0, "resized // (28×28)" M = obj_masks.to(device=device, dtype=dtype).clamp(0, 1) M_flat = M.view(O*N, 1, H_rz, W_rz) cov_hw = F.avg_pool2d(M_flat, kernel_size=(pix_h, pix_w), stride=(pix_h, pix_w)) # (O*N,1,Hm,Wm) cov_hw = cov_hw.view(O, N, Hm, Wm) cov_hw = cov_hw.view(O, T, g, Hm, Wm) if time_reduce == "mean": cov_thw = cov_hw.mean(dim=2) elif time_reduce == "max": cov_thw = cov_hw.max(dim=2).values elif time_reduce == "all": cov_thw = cov_hw.min(dim=2).values else: raise ValueError("time_reduce ∈ {'mean','max','all'}") per_obj_idx = [] per_t = Hm * Wm for o in range(O): nz = torch.empty(0, 3, dtype=torch.long, device=device) tried = 0 thr = coverage_thresh while tried <= retry_times: thr_eff = max(0.0, float(thr)) sel = (cov_thw[o] >= thr_eff) nz = torch.nonzero(sel, as_tuple=False) if nz.numel() > 0: break tried += 1 thr -= retry_step if nz.numel() == 0: if ensure_at_least_one: flat = cov_thw[o].reshape(-1) arg = torch.argmax(flat) t = arg // (Hm * Wm) rem = arg % (Hm * Wm) hp = rem // Wm wp = rem % Wm idx = (t * per_t + hp * Wm + wp).view(1) per_obj_idx.append(idx.to(device=device, dtype=torch.long)) else: per_obj_idx.append(torch.empty(0, dtype=torch.long, device=device)) else: t = nz[:, 0] hp = nz[:, 1] wp = nz[:, 2] idx = t * per_t + hp * Wm + wp per_obj_idx.append(idx.to(device=device, dtype=torch.long)) if len(per_obj_idx) == 0: union_idx = torch.empty(0, dtype=torch.long, device=device) else: union_idx = torch.unique(torch.cat(per_obj_idx, dim=0)) if per_obj_idx[0].numel() else torch.empty(0, dtype=torch.long, device=device) union_idx_cpu = union_idx.cpu() per_obj_idx_cpu = [idx.cpu() for idx in per_obj_idx] cov_thw_cpu = cov_thw.cpu() del M, M_flat, cov_hw, cov_thw, per_obj_idx, union_idx if O > 0: del sel, nz return union_idx_cpu, per_obj_idx_cpu, cov_thw_cpu