import torch
import torch.nn.functional as F
from typing import Literal, Tuple

@torch.no_grad()
def select_tokens(
    obj_masks: torch.Tensor,                 # (O, N, H_rz, W_rz) or (N, H_rz, W_rz) binary object masks
    grid_thw: Tuple[int, int, int],          # (T, H, W) vision-token grid before spatial merging
    *,
    patch_size: int = 14,
    spatial_merge_size: int = 2,             # m: patches merged per token along each spatial side
    temporal_patch_size: int = 2,            # g: frames merged into each temporal grid step
    coverage_thresh: float = 0.7,
    time_reduce: Literal["mean", "max", "all"] = "max",
    device: str | torch.device = "cpu",
    retry_step: float = 0.1,                 # threshold decrement per retry when nothing qualifies
    retry_times: int = 1,
    ensure_at_least_one: bool = True,        # fall back to the best-covered token on total failure
    dtype: torch.dtype = torch.float32,
):
    """Select merged vision-token indices whose mask coverage meets a threshold.

    Returns ``(union_idx, per_obj_idx, cov_thw)``, all on CPU: the unique
    token indices covered by any object, a per-object list of token index
    tensors, and the ``(O, T, Hm, Wm)`` coverage fraction per merged token.
    """
    if obj_masks.dim() == 3:
        obj_masks = obj_masks.unsqueeze(0)   # single object -> (1, N, H_rz, W_rz)
    O, N, H_rz, W_rz = obj_masks.shape
    T, H, W = grid_thw
    m, g = spatial_merge_size, temporal_patch_size
    if N != T * g:
        if N < T * g:
            # Too few frames: pad by repeating the last one.
            pad = T * g - N
            last = obj_masks[:, -1:, :, :].repeat(1, pad, 1, 1)
            obj_masks = torch.cat([obj_masks, last], dim=1)
        else:
            # Too many frames: truncate.
            obj_masks = obj_masks[:, :T * g, :, :]
        N = T * g
    Hm, Wm = H // m, W // m                  # merged-token grid size
    pix_h, pix_w = m * patch_size, m * patch_size
    assert H_rz % pix_h == 0 and W_rz % pix_w == 0, (
        f"resized mask ({H_rz}x{W_rz}) must be divisible by the merged patch size ({pix_h}x{pix_w})"
    )

    M = obj_masks.to(device=device, dtype=dtype).clamp(0, 1)

    # Average-pooling the binary mask over each merged-token cell yields the
    # fraction of that cell covered by the object.
    M_flat = M.view(O * N, 1, H_rz, W_rz)
    cov_hw = F.avg_pool2d(M_flat, kernel_size=(pix_h, pix_w), stride=(pix_h, pix_w))  # (O*N, 1, Hm, Wm)
    cov_hw = cov_hw.view(O, T, g, Hm, Wm)

    # Reduce over the g frames merged into each temporal grid step.
    if time_reduce == "mean":
        cov_thw = cov_hw.mean(dim=2)
    elif time_reduce == "max":
        cov_thw = cov_hw.max(dim=2).values
    elif time_reduce == "all":
        # "all": every merged frame must clear the threshold, hence min.
        cov_thw = cov_hw.min(dim=2).values
    else:
        raise ValueError("time_reduce must be one of 'mean', 'max', 'all'")

    per_obj_idx = []
    per_t = Hm * Wm                           # tokens per temporal grid step
    for o in range(O):
        nz = torch.empty(0, 3, dtype=torch.long, device=device)
        tried = 0
        thr = coverage_thresh
        # Relax the threshold by retry_step up to retry_times if nothing passes.
        while tried <= retry_times:
            thr_eff = max(0.0, float(thr))
            sel = cov_thw[o] >= thr_eff
            nz = torch.nonzero(sel, as_tuple=False)   # (K, 3) rows of (t, hp, wp)
            if nz.numel() > 0:
                break
            tried += 1
            thr -= retry_step
        if nz.numel() == 0:
            if ensure_at_least_one:
                # Last resort: keep the single best-covered token.
                flat = cov_thw[o].reshape(-1)
                arg = torch.argmax(flat)
                t = arg // per_t
                rem = arg % per_t
                hp = rem // Wm
                wp = rem % Wm
                idx = (t * per_t + hp * Wm + wp).view(1)
                per_obj_idx.append(idx.to(device=device, dtype=torch.long))
            else:
                per_obj_idx.append(torch.empty(0, dtype=torch.long, device=device))
        else:
            t, hp, wp = nz[:, 0], nz[:, 1], nz[:, 2]
            # Flatten (t, hp, wp) into a single merged-token index.
            idx = t * per_t + hp * Wm + wp
            per_obj_idx.append(idx.to(device=device, dtype=torch.long))

    if len(per_obj_idx) == 0:
        union_idx = torch.empty(0, dtype=torch.long, device=device)
    else:
        # Concatenate everything, then deduplicate; torch.unique of an empty
        # concat is itself empty, so empty per-object tensors need no special
        # case (checking only the first object would wrongly empty the union).
        union_idx = torch.unique(torch.cat(per_obj_idx, dim=0))

    union_idx_cpu = union_idx.cpu()
    per_obj_idx_cpu = [idx.cpu() for idx in per_obj_idx]
    cov_thw_cpu = cov_thw.cpu()

    # Drop device-side intermediates before returning the CPU copies.
    del M, M_flat, cov_hw, cov_thw, per_obj_idx, union_idx
    if O > 0:
        del sel, nz

    return union_idx_cpu, per_obj_idx_cpu, cov_thw_cpu
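

if __name__ == "__main__":
    # Illustrative smoke test (not part of the original module): one synthetic
    # object fully covering the top-left quadrant of a 2-frame 448x448 clip.
    # With the defaults patch_size=14 and spatial_merge_size=2, each merged
    # token spans a 28x28 pixel cell, so 448x448 maps to a 16x16 token grid
    # per temporal step and the quadrant to an 8x8 block of fully covered
    # tokens.
    T, H, W = 1, 32, 32
    obj_masks = torch.zeros(1, T * 2, 448, 448)   # temporal_patch_size=2 frames
    obj_masks[:, :, :224, :224] = 1.0

    union_idx, per_obj_idx, cov = select_tokens(obj_masks, (T, H, W))
    assert cov.shape == (1, T, 16, 16)
    assert union_idx.numel() == 64                # the 8x8 fully covered block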