# Copyright (c) Meta Platforms, Inc. and affiliates.

import torch
from torch.nn.attention.flex_attention import create_block_mask
from xformers.ops import fmha
def patch_reduce(h, max_num_patches, reduction, patch_ids):
    """
    Reduce variable-length patches to a single embedding per patch.
    Note: this works with a variable number of patches for different sequences in the batch.
    It handles variable-length patches by assuming that patch_lengths will be 0 for any
    extra patches on the *right*. Since there can be a variable number of patches,
    the output is padded to max_num_patches patches for every sequence in the batch.
    Any embeddings on the right that are not allocated to a patch
    (i.e. if sum(patch_lengths[i]) < seq_len for any i)
    will be sent to a dummy patch, which is trimmed before returning.
    """
    bs, seq_len, emb_dim = h.shape

    # Broadcast patch ids across the embedding dimension so they can index into h.
    patch_ids = patch_ids.unsqueeze(-1).expand(-1, -1, h.shape[-1])

    reduced_embs = torch.zeros(
        (bs, max_num_patches, emb_dim), dtype=h.dtype, device=h.device
    )
    reduced_embs = reduced_embs.scatter_reduce(
        src=h,
        dim=1,
        index=patch_ids,
        reduce=reduction,
        include_self=False,
    )
    reduced_embs = reduced_embs[:, :max_num_patches, :]

    return reduced_embs
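
# Illustrative usage sketch (not part of the original module): shows patch_reduce
# mean-pooling a toy batch where patch_ids assigns each byte position to its patch.
# The helper name and all shapes/values below are made up for demonstration.
def _example_patch_reduce():
    bs, seq_len, emb_dim, max_num_patches = 1, 6, 4, 3
    h = torch.arange(bs * seq_len * emb_dim, dtype=torch.float32).view(bs, seq_len, emb_dim)
    # First two positions form patch 0, next three form patch 1, last one forms patch 2.
    patch_ids = torch.tensor([[0, 0, 1, 1, 1, 2]])
    reduced = patch_reduce(h, max_num_patches, "mean", patch_ids)
    assert reduced.shape == (bs, max_num_patches, emb_dim)
    return reduced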
def concat_downsample(h, patch_lengths, patch_size):
    # The assumption in this function is that seq_len = patch_size * num_patches.
    bs, seq_len, emb_dim = h.shape
    patch_end_ids = torch.cumsum(patch_lengths, dim=1)
    # For each patch, index its last patch_size positions: [end - patch_size, ..., end - 1].
    patch_ids = patch_end_ids.unsqueeze(-1) - torch.arange(patch_size, 0, -1).to(
        patch_end_ids.device
    )
    # Is clamp ok here?
    patch_ids = patch_ids.clamp(min=0).unsqueeze(-1).expand(-1, -1, -1, h.shape[-1])
    patch_ids = patch_ids.view(bs, -1, emb_dim)
    # after gather h.shape = [batch_size, seq_len, dim]
    h = torch.gather(h, 1, patch_ids)
    # Concatenate the patch_size embeddings of each patch along the feature dimension.
    h = h.reshape(bs, patch_lengths.shape[1], patch_size * h.size(-1))
    return h
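
# Illustrative usage sketch (not part of the original module): with fixed-size patches of
# patch_size bytes, concat_downsample stacks the patch_size embeddings of each patch into
# a single vector of width patch_size * emb_dim. The helper name and numbers are made up.
def _example_concat_downsample():
    bs, num_patches, patch_size, emb_dim = 1, 3, 4, 8
    seq_len = num_patches * patch_size
    h = torch.randn(bs, seq_len, emb_dim)
    patch_lengths = torch.full((bs, num_patches), patch_size, dtype=torch.long)
    out = concat_downsample(h, patch_lengths, patch_size)
    assert out.shape == (bs, num_patches, patch_size * emb_dim)
    return out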
def pooling_downsample(h, max_num_patches, pooling_mode, patch_ids):
    cat = []
    if "avg" in pooling_mode or "mean" in pooling_mode:
        cat.append(patch_reduce(h, max_num_patches, "mean", patch_ids))
    if "min" in pooling_mode:
        cat.append(patch_reduce(h, max_num_patches, "amin", patch_ids))
    if "max" in pooling_mode:
        cat.append(patch_reduce(h, max_num_patches, "amax", patch_ids))
    assert len(cat) > 0
    h = torch.cat(cat, dim=-1)
    return h
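
# Illustrative usage sketch (not part of the original module): the pooling_mode string
# selects which reductions are concatenated along the feature dimension, so a mode such
# as "avg_min_max" (a made-up value containing all three keywords) triples the embedding
# width. The helper name and shapes below are assumptions for demonstration only.
def _example_pooling_downsample():
    bs, seq_len, emb_dim, max_num_patches = 2, 8, 16, 4
    h = torch.randn(bs, seq_len, emb_dim)
    # Two positions per patch, with the same patching for both sequences in this toy case.
    patch_ids = torch.arange(seq_len).unsqueeze(0).expand(bs, -1) // 2
    out = pooling_downsample(h, max_num_patches, "avg_min_max", patch_ids)
    assert out.shape == (bs, max_num_patches, 3 * emb_dim)
    return out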
def downsample(
    h,
    num_patches,
    patch_lengths=None,
    patch_ids=None,
    downsampling_by_pooling=None,
    patch_size=4,
):
    """
    Downsampling:
        a. concatenating embeddings in the patch
            Note: with dynamic patching, patch the last patch_size tokens.
        b. pooling embeddings in the patch
    """
    # input: h.shape = [batch_size, seq_len, dim]
    # output (pooled): h.shape = [batch_size, seq_len / patch_size, dim]
    # If we don't use cross_attn, we pool so that we convert the byte representation to a patch representation.
    if downsampling_by_pooling is not None and len(downsampling_by_pooling) > 0:
        # By pooling
        max_num_patches = num_patches
        assert patch_ids is not None
        h = pooling_downsample(h, max_num_patches, downsampling_by_pooling, patch_ids)
    else:
        # TODO: remove this condition
        # By concatenating (fixed-length patching)
        assert patch_lengths is not None
        h = concat_downsample(h, patch_lengths, patch_size)
    return h
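
# Illustrative usage sketch (not part of the original module): downsample dispatches to
# pooling when downsampling_by_pooling is a non-empty mode string and patch_ids are given,
# and otherwise falls back to fixed-length concatenation. The helper name and values below
# are made up for demonstration.
def _example_downsample():
    bs, seq_len, emb_dim, patch_size = 1, 8, 16, 4
    num_patches = seq_len // patch_size
    h = torch.randn(bs, seq_len, emb_dim)
    # Pooling path: needs per-position patch ids.
    patch_ids = torch.arange(seq_len).unsqueeze(0) // patch_size
    pooled = downsample(h, num_patches, patch_ids=patch_ids, downsampling_by_pooling="max")
    # Concatenation path: needs per-patch lengths (fixed patch_size here).
    patch_lengths = torch.full((bs, num_patches), patch_size, dtype=torch.long)
    concatenated = downsample(h, num_patches, patch_lengths=patch_lengths, patch_size=patch_size)
    return pooled, concatenated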
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


def create_causal_mask(seqlen, attn_impl, sliding_window):
    if sliding_window is not None and attn_impl == "xformers":
        return fmha.attn_bias.LocalAttentionFromBottomRightMask(
            window_left=sliding_window - 1, window_right=0
        )
    elif attn_impl == "xformers":
        return fmha.attn_bias.LowerTriangularMask()
    elif attn_impl == "sdpa":
        return "causal"
    elif attn_impl == "flex_attention":
        return create_block_mask(causal_mask, None, None, seqlen, seqlen)
    elif attn_impl == "fmha":
        return None
    else:
        raise NotImplementedError(
            f"Attention {attn_impl} with {sliding_window} sliding window not implemented"
        )
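
# Illustrative usage sketch (not part of the original module): the return type depends on
# the attention backend, so callers interpret it accordingly. The "sdpa" branch returns the
# sentinel string "causal", which downstream code would typically translate into
# is_causal=True for torch.nn.functional.scaled_dot_product_attention (an assumption, not
# something this file shows). The helper name below is made up.
def _example_create_causal_mask():
    seqlen = 128
    mask = create_causal_mask(seqlen, attn_impl="sdpa", sliding_window=None)
    assert mask == "causal"
    # The xformers branch returns an attention-bias object instead of a string; with a
    # sliding window it uses a local (windowed) causal bias.
    xformers_mask = create_causal_mask(seqlen, attn_impl="xformers", sliding_window=32)
    return mask, xformers_mask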