|
|
import torch |
|
|
import math |
|
|
from typing import Dict, Any, Tuple, Callable |
|
|
|
|
|
|
|
|
""" |
|
|
Copied and adapted from https://github.com/dbolya/tomesd/tree/main |
|
|
Relevant files: |
|
|
- https://github.com/dbolya/tomesd/blob/main/tomesd/merge.py |
|
|
- https://github.com/dbolya/tomesd/blob/main/tomesd/patching.py |
|
|
""" |
|
|
|
|
|
def init_generator(device: torch.device, fallback: torch.Generator = None, seed: int = 42):
    """
    Creates a freshly seeded random generator suited to the given device.

    Args:
     - device: device the random numbers will be consumed on. Only its ``type`` is
       inspected ("cpu", "cuda", or anything else).
     - fallback: generator returned as-is when ``device`` is neither cpu nor cuda.
       If None, a seeded CPU generator is created instead.
     - seed: seed applied to any generator created here.

    Returns:
        A ``torch.Generator``. For cpu/cuda devices it lives on that device; for any
        other device type it is ``fallback`` or, failing that, a CPU generator.
    """
    if device.type == "cpu":
        return torch.Generator(device="cpu").manual_seed(seed)

    if device.type == "cuda":
        return torch.Generator(device=device).manual_seed(seed)

    if fallback is not None:
        return fallback

    # No generator is created directly for this device type, so fall back to a CPU
    # one. The seed must be forwarded, otherwise the caller's seed is silently
    # replaced by the default.
    return init_generator(torch.device("cpu"), seed=seed)
|
|
|
|
|
def do_nothing(x: torch.Tensor, mode: str = None):
    """
    Identity stand-in for the merge / unmerge callables.

    Returns ``x`` itself, untouched. ``mode`` is accepted only so the call
    signature is compatible with a merge function; it is ignored.
    """
    return x
|
|
|
|
|
|
|
|
def mps_gather_workaround(input, dim, index):
    """
    Drop-in replacement for ``torch.gather`` that special-cases a trailing dim of size 1.

    When the last dimension of ``input`` has size 1, both ``input`` and ``index``
    get an extra trailing axis before gathering, and that axis is removed from the
    result. Otherwise this is a plain ``torch.gather`` call.
    NOTE(review): presumably this sidesteps a gather problem on the MPS backend for
    such shapes -- confirm against the PyTorch issue tracker.
    """
    if input.shape[-1] != 1:
        return torch.gather(input, dim, index)

    padded_input = input.unsqueeze(-1)
    # A negative dim counts from the end, so it must move one step further back
    # once the extra trailing axis has been appended.
    shifted_dim = dim - 1 if dim < 0 else dim
    padded_index = index.unsqueeze(-1)

    gathered = torch.gather(padded_input, shifted_dim, padded_index)
    return gathered.squeeze(-1)
|
|
|
|
|
|
|
|
def bipartite_soft_matching_random2d(metric: torch.Tensor,
                                     w: int, h: int, sx: int, sy: int, r: int,
                                     no_rand: bool = False,
                                     generator: torch.Generator = None) -> Tuple[Callable, Callable]:
    """
    Partitions the tokens into src and dst and merges r tokens from src to dst.
    Dst tokens are partitioned by choosing one randomly in each (sx, sy) region.

    Args:
     - metric [B, N, C]: metric to use for similarity
     - w: image width in tokens
     - h: image height in tokens
     - sx: stride in the x dimension for dst, must divide w
     - sy: stride in the y dimension for dst, must divide h
     - r: number of tokens to remove (by merging)
     - no_rand: if true, disable randomness (use top left corner only)
     - generator: if no_rand is false, the random generator used to pick the dst
       token of each region. If None, torch's default generator is used.

    Returns:
        A ``(merge, unmerge)`` pair of callables. ``merge(x, mode="mean")`` maps a
        [B, N, C] tensor to [B, N - r, C]; ``unmerge(x)`` maps it back to [B, N, C].
        When ``r <= 0`` both are ``do_nothing``.
    """
    B, N, _ = metric.shape

    if r <= 0:
        return do_nothing, do_nothing

    gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather

    with torch.no_grad():
        hsy, wsx = h // sy, w // sx

        # For each sy by sx region, pick the position (within the region) of the
        # single token that becomes dst; every other token in the region is src.
        if no_rand:
            rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64)
        elif generator is None:
            # No generator supplied: draw from torch's default RNG rather than
            # failing on ``generator.device``.
            rand_idx = torch.randint(sy * sx, size=(hsy, wsx, 1), device=metric.device)
        else:
            # Sample on the generator's own device, then move to the metric's device.
            rand_idx = torch.randint(sy * sx, size=(hsy, wsx, 1), device=generator.device,
                                     generator=generator).to(metric.device)

        # Mark the chosen dst position of each region with -1 (src stays 0), then
        # lay the regions back out as a (hsy * sy, wsx * sx) token grid.
        idx_buffer_view = torch.zeros(hsy, wsx, sy * sx, device=metric.device, dtype=torch.int64)
        idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
        idx_buffer_view = idx_buffer_view.view(hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx)

        # If the grid is not divisible by the strides, the leftover rows/columns
        # are all src (0): embed the marked grid in the top-left of a full buffer.
        if (hsy * sy) < h or (wsx * sx) < w:
            idx_buffer = torch.zeros(h, w, device=metric.device, dtype=torch.int64)
            idx_buffer[:(hsy * sy), :(wsx * sx)] = idx_buffer_view
        else:
            idx_buffer = idx_buffer_view

        # dst entries are -1 and src entries are 0, so an argsort lists the dst
        # token indices first, followed by the src token indices.
        rand_idx = idx_buffer.reshape(1, -1, 1).argsort(dim=1)

        # The buffers are no longer needed.
        del idx_buffer, idx_buffer_view

        # One dst token per region; everything after that in rand_idx is src.
        num_dst = hsy * wsx
        a_idx = rand_idx[:, num_dst:, :]  # src
        b_idx = rand_idx[:, :num_dst, :]  # dst

        def split(x):
            """Gathers the src and dst tokens of a [B, N, C] tensor."""
            C = x.shape[-1]
            src = gather(x, dim=1, index=a_idx.expand(B, N - num_dst, C))
            dst = gather(x, dim=1, index=b_idx.expand(B, num_dst, C))
            return src, dst

        # Cosine similarity between every src and every dst token.
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = split(metric)
        scores = a @ b.transpose(-1, -2)

        # Cannot merge away more tokens than there are in src.
        r = min(a.shape[1], r)

        # For each src token find its most similar dst token, then rank the src
        # tokens by that best similarity (most similar first).
        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        unm_idx = edge_idx[..., r:, :]  # src tokens that stay unmerged
        src_idx = edge_idx[..., :r, :]  # src tokens that get merged
        dst_idx = gather(node_idx[..., None], dim=-2, index=src_idx)  # their dst targets

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        """Merges the r selected src tokens into their dst tokens; returns unmerged src followed by dst."""
        src, dst = split(x)
        n, t1, c = src.shape

        unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
        src = gather(src, dim=-2, index=src_idx.expand(n, r, c))
        dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)

        return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        """Restores the [B, N, C] layout; every merged src token receives a copy of its dst token."""
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        _, _, c = unm.shape

        # Each merged src token takes the value of the dst token it was merged into.
        src = gather(dst, dim=-2, index=dst_idx.expand(B, r, c))

        # Write dst, unmerged src and merged src back to their original positions.
        out = torch.zeros(B, N, c, device=x.device, dtype=x.dtype)
        out.scatter_(dim=-2, index=b_idx.expand(B, num_dst, c), src=dst)
        out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=unm_idx).expand(B, unm_len, c), src=unm)
        out.scatter_(dim=-2, index=gather(a_idx.expand(B, a_idx.shape[1], 1), dim=1, index=src_idx).expand(B, r, c), src=src)

        return out

    return merge, unmerge
|
|
|
|
|
|
|
|
def compute_merge(
    x: torch.Tensor,
    args: Dict[str, Any],
    size: Tuple[int, int],
    max_tokens: int = None,
    ratio: float = None,
) -> Tuple[Callable, ...]:
    """
    Builds the (merge, unmerge) callables for the token tensor ``x``.

    Args:
     - x: token tensor; ``x.shape[0]`` is used as the batch size and ``x.shape[1]``
       as the current number of tokens.
     - args: settings dict. Keys read here: "enabled", "generator", "seed",
       "use_rand", "sx", "sy". ``args["generator"]`` is (re)assigned in place when
       it is None or lives on a different device than ``x``.
     - size: (height, width) of the original token grid.
     - max_tokens: upper bound on the number of tokens kept; ignored unless > 0.
     - ratio: fraction of the current tokens to remove.

    Returns:
        ``(do_nothing, do_nothing)`` when merging is disabled or no token would be
        removed, otherwise the pair from ``bipartite_soft_matching_random2d``.

    Raises:
        ValueError: if merging is enabled and both ``max_tokens`` and ``ratio`` are None.
    """
    if not args["enabled"]:
        return do_nothing, do_nothing

    if max_tokens is None and ratio is None:
        raise ValueError("Must specify either max_tokens or ratio")

    full_h, full_w = size
    num_tokens = x.shape[1]

    # Factor by which each side of the original grid has shrunk to give
    # ``num_tokens`` tokens.
    scale = int(math.ceil(math.sqrt((full_h * full_w) // num_tokens)))

    # Number of tokens to keep: driven by ``ratio`` and then capped by ``max_tokens``.
    keep = num_tokens if ratio is None else int(num_tokens * (1 - ratio))
    if max_tokens is not None and max_tokens > 0:
        keep = min(keep, max_tokens)

    r = num_tokens - keep
    if r <= 0:
        return do_nothing, do_nothing

    w = int(math.ceil(full_w / scale))
    h = int(math.ceil(full_h / scale))

    # Create the generator on first use, and re-initialise it if it sits on a
    # different device than the tokens.
    current_gen = args["generator"]
    if current_gen is None:
        args["generator"] = init_generator(x.device, seed=args["seed"])
    elif current_gen.device != x.device:
        args["generator"] = init_generator(x.device, fallback=current_gen, seed=args["seed"])

    # Randomness is forced off for odd batch sizes.
    # NOTE(review): presumably because an odd batch cannot hold paired
    # conditional/unconditional samples -- confirm with the caller.
    use_rand = args["use_rand"] if x.shape[0] % 2 == 0 else False

    return bipartite_soft_matching_random2d(
        x, w, h, args["sx"], args["sy"], r,
        no_rand=not use_rand, generator=args["generator"],
    )
|
|
|