| import pdb |
| from typing import Callable, Tuple |
|
|
| import torch |
|
|
|
|
def init_generator(device: torch.device, fallback: torch.Generator = None):
    """
    Fork the current default random generator for the given device.

    A new ``torch.Generator`` is created and seeded with a copy of the default
    RNG state, so drawing from it does not advance the default generator.

    Args:
     - device: device whose default RNG state should be forked.
     - fallback: generator to return as-is when ``device`` is neither CPU nor
       CUDA. If None, a fork of the CPU default generator is returned instead.

    Returns:
        A ``torch.Generator`` (the fork, or ``fallback`` itself).
    """
    if device.type == "cpu":
        return torch.Generator(device="cpu").set_state(torch.get_rng_state())

    if device.type == "cuda":
        # Read the state of the *requested* device; calling get_rng_state()
        # without an argument would use whichever CUDA device is current.
        return torch.Generator(device=device).set_state(
            torch.cuda.get_rng_state(device)
        )

    # Any other backend: use the caller-supplied generator, or fall back to CPU.
    if fallback is not None:
        return fallback
    return init_generator(torch.device("cpu"))
|
|
|
|
def do_nothing(x: torch.Tensor, mode: str = None):
    """
    Identity placeholder: hand back ``x`` untouched.

    ``mode`` is accepted but never read.
    """
    return x
|
|
|
|
def mps_gather_workaround(input, dim, index):
    """
    Gather along ``dim``, special-casing inputs whose last dimension has size 1.

    For such inputs a trailing singleton dimension is appended to both ``input``
    and ``index`` before calling ``torch.gather`` and removed from the result.
    Every other input goes straight to ``torch.gather``.

    NOTE(review): judging by the name this sidesteps a gather problem on the
    MPS backend -- confirm against the upstream issue.
    """
    if input.shape[-1] != 1:
        return torch.gather(input, dim, index)

    # A negative dim counts from the end, so it must shift by one to keep
    # pointing at the same axis once the extra trailing dimension exists.
    shifted_dim = dim - 1 if dim < 0 else dim
    padded_input = input.unsqueeze(-1)
    padded_index = index.unsqueeze(-1)
    gathered = torch.gather(padded_input, shifted_dim, padded_index)
    return gathered.squeeze(-1)
|
|
|
|
def bipartite_soft_matching_random2d(
    metric: torch.Tensor,
    w: int,
    h: int,
    sx: int,
    sy: int,
    r: int,
    no_rand: bool = False,
    generator: torch.Generator = None,
) -> Tuple[Callable, Callable]:
    """
    Partitions the tokens into src and dst and merges r tokens from src to dst.
    Dst tokens are partitioned by choosing one randomly in each (sx, sy) region.

    Args:
     - metric [B, N, C]: metric to use for similarity (N is expected to be h * w)
     - w: image width in tokens
     - h: image height in tokens
     - sx: stride in the x dimension for dst; if it does not divide w, the
       leftover columns on the right are always src
     - sy: stride in the y dimension for dst; if it does not divide h, the
       leftover rows at the bottom are always src
     - r: number of tokens to remove (by merging); clamped to the number of
       src tokens. If r <= 0, nothing is merged.
     - no_rand: if true, disable randomness (use top left corner only)
     - generator: if no_rand is false, the random generator used to pick the
       dst token of each region. If None, the default generator of
       metric's device is used.

    Returns:
        A pair ``(merge, unmerge)``:
         - ``merge(x, mode="mean")`` maps [B, N, C] to [B, N - r, C], laid out
           as the unmerged src tokens followed by the dst tokens, where the r
           merged src tokens have been reduced into their dst with ``mode``.
         - ``unmerge(x)`` maps [B, N - r, C] back to [B, N, C]; each merged src
           position receives the value of the dst token it was merged into.
        When r <= 0 both callables are ``do_nothing``.
    """
    B, N, _ = metric.shape

    if r <= 0:
        return do_nothing, do_nothing

    gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather

    with torch.no_grad():
        hsy, wsx = h // sy, w // sx

        # For each sy by sx kernel, assign one token to be dst and the rest src.
        if no_rand:
            rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64)
        elif generator is None:
            # No generator supplied: draw from the default RNG of the metric's
            # device instead of dereferencing None.
            rand_idx = torch.randint(
                sy * sx, size=(hsy, wsx, 1), device=metric.device
            )
        else:
            # Sample on the generator's own device, then move to the metric's.
            rand_idx = torch.randint(
                sy * sx,
                size=(hsy, wsx, 1),
                device=generator.device,
                generator=generator,
            ).to(metric.device)

        # Mark the chosen dst token of every kernel with -1 (src stays 0), then
        # unfold the per-kernel layout into image coordinates. This only covers
        # the top-left (hsy * sy, wsx * sx) part of the image.
        idx_buffer_view = torch.zeros(
            hsy, wsx, sy * sx, device=metric.device, dtype=torch.int64
        )
        idx_buffer_view.scatter_(
            dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype)
        )
        idx_buffer_view = (
            idx_buffer_view.view(hsy, wsx, sy, sx)
            .transpose(1, 2)
            .reshape(hsy * sy, wsx * sx)
        )

        # Image is not divisible by sx or sy, so embed the view in a full-size
        # buffer; the remainder stays 0, i.e. src.
        if (hsy * sy) < h or (wsx * sx) < w:
            idx_buffer = torch.zeros(h, w, device=metric.device, dtype=torch.int64)
            idx_buffer[: (hsy * sy), : (wsx * sx)] = idx_buffer_view
        else:
            idx_buffer = idx_buffer_view

        # dst tokens are -1 and src are 0, so an argsort yields dst|src indices.
        rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1)

        # We're finished with these.
        del idx_buffer, idx_buffer_view

        # rand_idx is currently dst|src, so split them.
        num_dst = hsy * wsx
        a_idx = rand_idx[:, num_dst:, :]  # src
        b_idx = rand_idx[:, :num_dst, :]  # dst

        def split(x):
            """Gather the src and dst tokens of x ([B, N, C]) along dim 1."""
            C = x.shape[-1]
            src = gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C))
            dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
            return src, dst

        # Cosine similarity between src (a) and dst (b).
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric)
        scores = a @ b.transpose(-1, -2)

        # Can't reduce more than the number of tokens in src.
        r = min(a.shape[1], r)

        # For every src token find its most similar dst, then rank src tokens
        # by that similarity (best match first).
        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        unm_idx = edge_idx[..., r:, :]  # src tokens that stay unmerged
        src_idx = edge_idx[..., :r, :]  # src tokens that get merged
        dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        """Reduce the r selected src tokens into their dst; return unmerged|dst."""
        src, dst = split(x)
        n, t1, c = src.shape

        unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
        src = gather(src, dim=-2, index=src_idx.expand(n, r, c))
        dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)

        return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        """Scatter an unmerged|dst tensor back to the original [B, N, C] layout."""
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        _, _, c = unm.shape

        # Each merged src token is restored as a copy of the dst it merged into.
        src = gather(dst, dim=-2, index=dst_idx.expand(B, r, c))

        # Write dst, unmerged src and restored src back to their original slots.
        out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
        out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
        out.scatter_(
            dim=-2,
            index=gather(
                a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx
            ).expand(B, unm_len, c),
            src=unm,
        )
        out.scatter_(
            dim=-2,
            index=gather(
                a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx
            ).expand(B, r, c),
            src=src,
        )

        return out

    return merge, unmerge
|
|