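"""Differentiable top-k token selection over (B, C, T, H, W) video tensors.

Selection is either global (one budget over all T * H * W positions) or
structured (k_t frames, then k_hw positions per frame), using plain top-k at
inference and a straight-through Gumbel-Softmax relaxation during training.
"""
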
import torch
import torch.nn.functional as F
from typing import List, Tuple, Optional


def sample_gumbel(shape: torch.Size, eps: float = 1e-6, device=None, dtype=None) -> torch.Tensor:
    """Sample standard Gumbel(0, 1) noise as -log(-log(U)) with U ~ Uniform(0, 1)."""
    U = torch.rand(shape, device=device, dtype=dtype)
    # Clamp U away from 0 and 1 so both logarithms stay finite.
    return -torch.log(-torch.log(U.clamp(min=eps, max=1 - eps)))


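# Background on the noise model: by the Gumbel-Max trick, argmax(logits + g)
# with g ~ Gumbel(0, 1) is an exact sample from Categorical(softmax(logits)),
# and softmax((logits + g) / temperature) is its differentiable relaxation.

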
def select_topk(
    logits: torch.Tensor,
    k: int,
    method: str,
    temperature: float,
    hard: bool,
    eps: float,
) -> torch.Tensor:
    """Return a (B, N) mask selecting k of the N entries in each row of logits."""
    B, N = logits.shape

    if method == 'topk':
        # Deterministic, non-differentiable hard selection.
        _, topk_idx = torch.topk(logits, k, dim=-1)
        mask = torch.zeros_like(logits).scatter(-1, topk_idx, 1.0)
    elif method == 'softmax':
        # Gumbel-Softmax: perturb the logits with Gumbel noise, then soften.
        gumbel_noise = sample_gumbel(logits.shape, eps=eps, device=logits.device, dtype=logits.dtype)
        y = (logits + gumbel_noise) / temperature
        y_soft = F.softmax(y, dim=-1)

        if hard:
            # Straight-through estimator: the forward pass sees an exact k-hot
            # mask, while gradients flow through the soft distribution.
            topk_idx = y_soft.topk(k, dim=-1).indices
            hard_mask = torch.zeros_like(y_soft).scatter(-1, topk_idx, 1.0)
            mask = hard_mask - y_soft.detach() + y_soft
        else:
            mask = y_soft
    else:
        raise ValueError(f"Unknown method: {method}")

    return mask


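# Note: with hard=True the mask is exactly k-hot in the forward pass, so its
# entries can be gathered or thresholded downstream; with hard=False the soft
# mask is generally dense and sums to 1 per row, so "k selected entries" only
# holds approximately.

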
def global_selection(
    mask_logits: torch.Tensor,
    total_k: int,
    method: str,
    temperature: float,
    hard: bool,
    eps: float,
) -> torch.Tensor:
    """Select total_k tokens jointly over all T * H * W positions."""
    B, T, H, W = mask_logits.shape
    N = T * H * W
    logits_flat = mask_logits.reshape(B, N)
    mask_flat = select_topk(logits_flat, total_k, method, temperature, hard, eps)
    mask = mask_flat.reshape(B, T, H, W)
    return mask


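# Global selection lets the budget concentrate wherever the logits are largest
# (e.g. a few busy frames); structured_selection below instead spends a fixed
# k_hw budget inside every kept frame.

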
def structured_selection(
    mask_logits: torch.Tensor,
    k_t: int,
    k_hw: int,
    method: str,
    temperature: float,
    hard: bool,
    eps: float,
) -> torch.Tensor:
    """Select k_t frames, then k_hw positions per frame (k_t * k_hw tokens total)."""
    B, T, H, W = mask_logits.shape

    # Temporal step: score each frame by its mean logit and keep k_t frames.
    logits_t = mask_logits.mean(dim=[2, 3])
    mask_t = select_topk(logits_t, k_t, method, temperature, hard, eps)

    # Spatial step: keep k_hw positions independently within every frame.
    mask_spatial = []
    for b in range(B):
        mask_b = []
        for t in range(T):
            logits_hw = mask_logits[b, t].reshape(-1)
            mask_hw = select_topk(logits_hw.unsqueeze(0), k_hw, method, temperature, hard, eps)
            mask_b.append(mask_hw.reshape(H, W))
        mask_spatial.append(torch.stack(mask_b, dim=0))
    mask_spatial = torch.stack(mask_spatial, dim=0)

    # Combine: a token survives only if both its frame and its position are kept.
    mask = mask_spatial * mask_t.unsqueeze(-1).unsqueeze(-1)
    return mask


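# Performance note: the per-(b, t) Python loop above is simple but slow for
# large B * T. Since select_topk operates row-wise, the spatial step could
# equivalently be batched as select_topk(mask_logits.reshape(B * T, H * W),
# k_hw, ...) followed by a reshape back to (B, T, H, W).

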
def apply_mask_and_select(
    tokens: torch.Tensor,
    other_tensors: List[torch.Tensor],
    mask: torch.Tensor,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
    """Gather the masked-in tokens, plus the matching entries of other tensors."""
    B, C, T, H, W = tokens.shape
    N = T * H * W

    tokens_flat = tokens.reshape(B, C, N)
    mask_flat = mask.reshape(B, N)

    selected_tokens = []
    selected_others = [[] for _ in other_tensors]

    for b in range(B):
        idx = mask_flat[b].nonzero(as_tuple=False).squeeze(-1)
        selected_tokens.append(tokens_flat[b, :, idx])

        for i, other in enumerate(other_tensors):
            other_flat = other.reshape(B, -1, N)
            selected_others[i].append(other_flat[b, :, idx])

    # Stacking assumes every sample keeps the same number of tokens, which
    # hard (k-hot) selection guarantees.
    tokens_out = torch.stack(selected_tokens, dim=0)
    others_out = [torch.stack(x, dim=0) for x in selected_others]

    return tokens_out, others_out


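# Gradient caveat: nonzero() + indexing discards the mask's gradient path, so
# this gather alone does not train the mask logits. One option (an assumption
# about the intended training setup, not part of this file) is to multiply the
# gathered tokens by their straight-through mask values, e.g.
# tokens_flat[b, :, idx] * mask_flat[b, idx], so gradients reach the logits.

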
def process_tensors(
    tokens: torch.Tensor,
    mask_logits: torch.Tensor,
    other_tensors: List[torch.Tensor],
    total_k: Optional[int] = None,
    k_t: Optional[int] = None,
    k_hw: Optional[int] = None,
    temperature: float = 1.0,
    eps: float = 1e-6,
    training: bool = True,
    soft_inference: bool = True,
) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]:
    """
    Select tokens from a (B, C, T, H, W) tensor and subsample other_tensors
    consistently.

    If training=True (or soft_inference=True), selection uses the
    straight-through Gumbel-Softmax trick and stays differentiable; otherwise
    it falls back to plain top-k, which is not differentiable.
    """
    B, C, T, H, W = tokens.shape
    mask_logits = mask_logits.squeeze(1)  # (B, 1, T, H, W) -> (B, T, H, W)

    if training or soft_inference:
        method = 'softmax'
        hard = True
    else:
        method = 'topk'
        hard = False

    if total_k is not None:
        mask = global_selection(mask_logits, total_k, method, temperature, hard, eps)
    elif k_t is not None and k_hw is not None:
        mask = structured_selection(mask_logits, k_t, k_hw, method, temperature, hard, eps)
    else:
        raise ValueError("Provide either total_k or both k_t and k_hw.")

    tokens_out, others_out = apply_mask_and_select(tokens, other_tensors, mask)
    return tokens_out, others_out, mask


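# Note: soft_inference defaults to True, so plain top-k is only reached when
# both training and soft_inference are False. With method='softmax' the Gumbel
# noise makes evaluation stochastic; pass soft_inference=False for
# deterministic inference.

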
if __name__ == '__main__':
    temperature = 1.0
    training = False
    k_t = 9
    k_h = 90
    k_w = 160

    B, C, T, H, W = 2, 3, 17, 180, 320
    tokens = torch.randn(B, C, T, H, W)
    mask_logits = torch.randn(B, 1, T, H, W)

    # Auxiliary tensors that must be subsampled consistently with the tokens.
    other1 = torch.randn(B, 6, T, H, W)
    other2 = torch.randn(B, 9, T, H, W)

    # Structured selection: k_t frames, then k_h * k_w positions per frame.
    tokens_out, others_out, mask = process_tensors(
        tokens=tokens,
        mask_logits=mask_logits,
        other_tensors=[other1, other2],
        k_t=k_t,
        k_hw=k_h * k_w,
        temperature=temperature,
        training=training,
    )

    # Global selection: the same token budget chosen jointly over all positions.
    tokens_out, others_out, mask = process_tensors(
        tokens=tokens,
        mask_logits=mask_logits,
        other_tensors=[other1, other2],
        total_k=k_t * k_h * k_w,
        temperature=temperature,
        training=training,
    )
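
    # Quick shape check: both variants keep k_t * k_h * k_w = 129600 tokens
    # per sample (expected values below follow from the settings above).
    print(tokens_out.shape)                # torch.Size([2, 3, 129600])
    print([o.shape for o in others_out])   # [torch.Size([2, 6, 129600]), torch.Size([2, 9, 129600])]
    print(mask.sum(dim=[1, 2, 3]))         # ~129600 per sample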