| import os |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torchvision.ops import StochasticDepth |
| import math |
| import warnings |
# Optional dependency: the variable-length attention kernels. The import is
# attempted once at module load; failure only disables the fast path.
try:
    import torch.nn.attention.varlen as varlen
except ImportError:
    warnings.warn(
        "Could not import torch.nn.attention.varlen, variable length Flash Attention is disabled.",
        category=UserWarning,
        stacklevel=2,
    )
    HAS_VARLEN_FLASH_ATTENTION = False
else:
    HAS_VARLEN_FLASH_ATTENTION = True


# The DISABLE_FA environment variable (case-insensitive "1", "true", "yes",
# "y" or "on") switches the flash path off even when the import succeeded.
enable_fa = os.environ.get('DISABLE_FA', '0').lower() not in {"1", "true", "yes", "y", "on"}
if not enable_fa:
    HAS_VARLEN_FLASH_ATTENTION = False
|
|
class SoftMaskedMultiheadAttention(nn.Module):
    """Multi-head attention whose logits are biased by a soft per-key mask.

    For a mask value ``m`` attached to a key, ``scale * log(max(m, 1e-6))`` is
    added to every attention logit that targets this key. In eval mode keys
    whose mask is exactly zero receive ``-inf`` instead and are therefore
    excluded from the softmax.

    Two implementations are provided: ``eager_forward`` (batched, plain
    matmul/softmax) and ``flash_forward`` (packed sequences through
    ``varlen.varlen_attn``); ``forward`` dispatches on ``method``.

    Args:
        embed_dim: Width of the query input and of the output.
        num_heads: Number of heads; must divide ``embed_dim``.
        dropout: Dropout probability applied to the attention weights in the
            eager path.
        bias: Whether ``q_proj`` has a bias (and k/v, see ``add_bias_kv``).
        add_bias_kv: k/v projections get a bias only if both this and
            ``bias`` are true.
        kdim, vdim: Input widths of key / value; default to ``embed_dim``.
        scale: Multiplier applied to the log-mask before it is added to the
            logits.
        device, dtype: Forwarded to every linear layer.
    """

    # Number of extra channels appended to q/k/v in the flash path to carry
    # the mask bias. NOTE(review): presumably 8 keeps the padded head_dim a
    # multiple of 8 for the kernel - confirm against kernel requirements.
    _FLASH_PAD = 8

    def __init__(self, embed_dim, num_heads, dropout=0.0, bias=True,
                 add_bias_kv=True, kdim=None, vdim=None,
                 scale=8., device=None, dtype=None):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.scale = scale

        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = embed_dim // num_heads

        kv_bias = bias and add_bias_kv
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
        self.k_proj = nn.Linear(self.kdim, embed_dim, bias=kv_bias, **factory_kwargs)
        self.v_proj = nn.Linear(self.vdim, embed_dim, bias=kv_bias, **factory_kwargs)

        self.dropout_layer = nn.Dropout(dropout)

        # The output projection always carries a bias (kept so existing
        # checkpoints still load); it now honours device/dtype like q/k/v.
        self.out_proj = nn.Linear(embed_dim, embed_dim, **factory_kwargs)

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialise all projection weights and zero every bias."""
        for proj in (self.q_proj, self.k_proj, self.v_proj, self.out_proj):
            nn.init.xavier_uniform_(proj.weight)
            # Each layer is checked on its own bias, so out_proj.bias is
            # zeroed even when the k/v projections were built without one.
            if proj.bias is not None:
                nn.init.constant_(proj.bias, 0.)

    def _log_soft_mask(self, attn_mask, eps=1e-6):
        """Return ``log(clip(mask, eps))``; in eval mode zeros map to ``-inf``."""
        log_mask = attn_mask.clip(min=eps).log()
        if not self.training:
            log_mask = log_mask.masked_fill(attn_mask == 0., float('-inf'))
        return log_mask

    def eager_forward(self, query, key, value, key_padding_mask=None,
                      attn_mask=None, average_attn_weights=True):
        """Batched attention using explicit matmul + softmax.

        Args:
            query: ``(batch, tgt_len, embed_dim)`` tensor.
            key, value: ``(batch, src_len, kdim)`` / ``(batch, src_len, vdim)``.
            key_padding_mask: Optional mask with ``batch * src_len`` elements;
                positions that are true are filled with ``-inf``.
            attn_mask: Optional soft mask; it gets two singleton axes inserted
                at dim 1 and is broadcast against the
                ``(batch, heads, tgt_len, src_len)`` logits - presumably
                shaped ``(batch, src_len)``.
            average_attn_weights: Accepted for signature compatibility; unused.

        Returns:
            ``(batch, tgt_len, embed_dim)`` tensor.
        """
        batch_size, tgt_len, embed_dim = query.size()
        batch_size, src_len, _ = key.size()

        def split_heads(t, length):
            # (B, L, E) -> (B, H, L, head_dim)
            return t.view(batch_size, length, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.q_proj(query), tgt_len)
        k = split_heads(self.k_proj(key), src_len)
        v = split_heads(self.v_proj(value), src_len)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if attn_mask is not None:
            log_mask = self._log_soft_mask(attn_mask.unsqueeze(1).unsqueeze(1))
            scores = scores + self.scale * log_mask

        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask.view(batch_size, 1, 1, src_len)
            scores = scores.masked_fill(key_padding_mask, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout_layer(attn_weights)

        attn_output = torch.matmul(attn_weights, v)
        # (B, H, T, head_dim) -> (B, T, E)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, tgt_len, embed_dim)

        return self.out_proj(attn_output)

    def flash_forward(
        self,
        query, key, value,
        cu_seq_q, cu_seq_k,
        max_q, max_k,
        attn_mask=None,
    ):
        """FlashAttention-compatible soft-masked attention using varlen_attn.

        The kernel takes no additive bias, so the per-key bias is folded into
        the dot product: ``_FLASH_PAD`` extra channels are appended to q/k/v,
        with q holding ones, the first extra k channel holding
        ``scale * log_mask / softmax_scale`` and the rest zeros. After the
        kernel multiplies by ``softmax_scale`` the logit gains exactly
        ``scale * log_mask``. The padded value channels are zero and are
        sliced off the output.

        Args:
            query: Packed ``(total_q, embed_dim)`` tensor.
            key, value: Packed ``(total_k, kdim)`` / ``(total_k, vdim)``.
            cu_seq_q, cu_seq_k, max_q, max_k: Forwarded unchanged to
                ``varlen.varlen_attn``.
            attn_mask: Optional soft mask with ``total_k`` elements.

        Returns:
            ``(total_q, embed_dim)`` tensor.
        """
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)

        total_q = q.shape[0]
        total_k = k.shape[0]

        q = q.view(total_q, self.num_heads, self.head_dim)
        k = k.view(total_k, self.num_heads, self.head_dim)
        v = v.view(total_k, self.num_heads, self.head_dim)

        # Passed explicitly so padding the head dim does not change the scale.
        scale_attn = 1.0 / math.sqrt(self.head_dim)

        if attn_mask is not None:
            pad = self._FLASH_PAD
            log_m = self._log_soft_mask(attn_mask)
            log_m = log_m.view(total_k, 1, 1).expand(-1, self.num_heads, 1)

            # Pre-divide by the softmax scale; the kernel multiplies it back.
            k_extra = log_m * (self.scale / scale_attn)
            k_zeros = k_extra.new_zeros(total_k, self.num_heads, pad - 1)
            k = torch.cat([k, k_extra, k_zeros], dim=-1)

            v = torch.cat([v, v.new_zeros(total_k, self.num_heads, pad)], dim=-1)
            q = torch.cat([q, q.new_ones(total_q, self.num_heads, pad)], dim=-1)

        out = varlen.varlen_attn(
            query=q,
            key=k,
            value=v,
            cu_seq_q=cu_seq_q,
            cu_seq_k=cu_seq_k,
            max_q=max_q,
            max_k=max_k,
            scale=scale_attn,
        )

        # Drop the padded channels (no-op when no mask was supplied).
        out = out[..., :self.head_dim]
        out = out.reshape(total_q, self.num_heads * self.head_dim)
        return self.out_proj(out)

    def forward(self, query, key, value, method="eager", **kwargs):
        """Dispatch to ``eager_forward`` (``"eager"``) or ``flash_forward`` (``"fa"``).

        Raises:
            ValueError: If ``method`` is neither ``"eager"`` nor ``"fa"``.
        """
        if method == 'eager':
            out = self.eager_forward(query, key, value, **kwargs)
        elif method == "fa":
            out = self.flash_forward(query, key, value, **kwargs)
        else:
            raise ValueError(f"No attention method named {method}.")
        return out
|
|
|
|
def get_ffn(input_dim, output_dim, middle_dim, dropout=0.1):
    """Build the feed-forward sub-network used inside an encoder block.

    The returned ``nn.Sequential`` holds, in order: Linear(input_dim ->
    middle_dim), GELU, Dropout, Linear(middle_dim -> output_dim), Dropout and
    a trailing Identity.
    """
    expand = nn.Linear(input_dim, middle_dim)
    contract = nn.Linear(middle_dim, output_dim)
    layers = [
        expand,
        nn.GELU(),
        nn.Dropout(dropout),
        contract,
        nn.Dropout(dropout),
        nn.Identity(),  # parameter-free placeholder kept at index 5
    ]
    return nn.Sequential(*layers)
|
|
| |
class EncoderBlock(nn.Module):
    """Transformer encoder block with a learned soft token mask.

    ``forward`` predicts a per-token mask in (0, 1) with a linear layer plus
    sigmoid (only when ``attention_scale > 0``). The mask biases the attention
    logits and gates the block's residual update. In eval mode tokens whose
    mask falls below ``mask_threshold`` are left out of the computation, either
    through the packed flash-attention path or the grouped eager path.

    Args:
        input_dim: Width of the incoming tokens.
        embed_dim: Width used inside the block; must equal ``input_dim``.
        num_heads: Number of attention heads.
        mlp_dim: Hidden width of the feed-forward network.
        dropout: Dropout used in attention and the feed-forward network.
        drop_path: Stochastic-depth probability (row mode).
        patch_drop: Dropout probability applied to the predicted mask.
        attention_scale: Scale of the log-mask attention bias; a value
            ``<= 0`` disables mask prediction altogether.
        mask_threshold: Eval-time cut-off below which tokens are skipped; a
            negative value disables the sparse eval paths.

    Raises:
        ValueError: If ``input_dim != embed_dim``.
    """

    # Flash path heuristics: above this RMS keep-ratio all tokens are
    # processed densely; at or below this max length a fixed top-k is used.
    _DENSE_RATIO = 0.90
    _TOPK_MAX_LEN = 32

    def __init__(self, input_dim, embed_dim, num_heads, mlp_dim, dropout=0.1, drop_path=0.0, patch_drop=0.0, attention_scale=2., mask_threshold=0.05):
        super().__init__()
        # Validate before building sub-modules; the message is interpolated.
        if input_dim != embed_dim:
            raise ValueError(f"embed_dim must equal atten_dim but {input_dim}!={embed_dim}")
        self.mask_threshold = mask_threshold
        self.self_attn = SoftMaskedMultiheadAttention(
            embed_dim, num_heads, dropout=dropout, scale=attention_scale
        )
        if attention_scale > 0:
            self.linear_mask = nn.Linear(input_dim, 1)
            self.patch_drop = nn.Dropout(patch_drop)
        else:
            self.linear_mask = None
        self.embed = nn.Identity()
        self.project = nn.Identity()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

        self.mlp = get_ffn(embed_dim, embed_dim, mlp_dim, dropout=dropout)
        self.path_drop = StochasticDepth(drop_path, mode='row')
        self.norm3 = nn.LayerNorm(input_dim)
        # NOTE(review): _reset_parameters() is not invoked here, so the block
        # keeps PyTorch's default initialisation unless a caller runs it.

    def _reset_parameters(self):
        """Re-initialise all linear layers outside ``self_attn`` and the norms.

        Linear weights get a truncated normal (std 0.02) and zero bias.
        ``norm3`` is zero-initialised, which makes the block's residual update
        zero immediately after this call.
        """
        for n, m in self.named_modules():
            if n.startswith('self_attn'):
                continue
            # Only nn.Linear is checked; the previous reference to an
            # undefined ``GroupedLinear`` raised NameError on first use.
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight.data, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias.data)
        nn.init.ones_(self.norm1.weight)
        nn.init.zeros_(self.norm1.bias)
        nn.init.ones_(self.norm2.weight)
        nn.init.zeros_(self.norm2.bias)
        nn.init.zeros_(self.norm3.weight)
        nn.init.zeros_(self.norm3.bias)

    def forward_common(self, x, mask, skip_masks=False):
        """Dense block computation over every token of ``x``.

        Args:
            x: ``(batch, tokens, dim)`` tensor.
            mask: Optional ``(batch, tokens)`` soft mask.
            skip_masks: If true the mask is not passed to attention.
                NOTE(review): the output is still gated by ``mask`` here,
                whereas ``flash_forward`` skips the gating too - confirm
                which behaviour is intended.

        Returns:
            ``x`` plus the (mask-gated) block update, same shape as ``x``.
        """
        x1 = x
        x = self.embed(x)
        x = self.norm1(x)

        attn_output = self.self_attn(x, x, x, attn_mask=mask if not skip_masks else None, method="eager")

        x = x + self.path_drop(attn_output)
        x = self.norm2(x)

        mlp_output = self.mlp(x)

        x = self.path_drop(self.project(x + mlp_output))
        x = self.norm3(x)
        if mask is not None:
            x = x * mask.unsqueeze(-1)
        x = x1 + x
        return x

    def flash_forward(self, x, mask, skip_masks=False):
        """Sparse block computation on packed tokens via varlen attention.

        Tokens are selected in one of three ways:

        * dense - when the RMS of the per-sample kept counts exceeds 90% of
          the sequence length, every token is processed;
        * nonzero - when the largest kept count is above 32, exactly the
          tokens with ``mask >= mask_threshold`` are packed;
        * top-k - otherwise each sample contributes its ``max_len``
          highest-mask tokens, giving equal-length sequences.

        Args:
            x: ``(batch, tokens, dim)`` tensor.
            mask: ``(batch, tokens)`` soft mask.
            skip_masks: If true the mask is used only for token selection.

        Returns:
            Tensor shaped like ``x``. It is the input object itself when no
            token passes the threshold, otherwise a new tensor; the input is
            never modified.
        """
        B, N, C = x.shape
        x_res = x

        binary_mask = mask >= self.mask_threshold
        seq_lengths = binary_mask.sum(dim=1, dtype=torch.int32)
        max_len = seq_lengths.amax().item()

        # Nothing selected anywhere: the block is the identity.
        if max_len == 0:
            return x

        rms_len = seq_lengths.float().square().mean().sqrt().item()

        if (rms_len / N) > self._DENSE_RATIO:
            x_sel = x.flatten(0, 1)
            flat_idx = None
            sel_mask = None if skip_masks else mask.flatten(0, 1)
            cu_seqlens = torch.arange(0, (B + 1) * N, step=N, dtype=torch.int32, device=x.device)
            # Every packed sequence has N tokens here, so the kernel must be
            # given N; max_len (largest selected count) may be smaller.
            attn_max_len = N
        elif max_len > self._TOPK_MAX_LEN:
            # nonzero() is row-major, so tokens stay grouped by sample.
            idx = binary_mask.nonzero(as_tuple=False)
            b_idx = idx[:, 0]
            t_idx = idx[:, 1]
            flat_idx = b_idx * N + t_idx

            x_sel = x[b_idx, t_idx]
            sel_mask = None if skip_masks else mask[b_idx, t_idx]

            cu_seqlens = torch.zeros(B + 1, dtype=torch.int32, device=binary_mask.device)
            cu_seqlens[1:] = seq_lengths.cumsum(-1)
            attn_max_len = max_len
        else:
            k = max_len
            top_vals, top_idx = mask.topk(k, dim=1, largest=True, sorted=False)
            b_idx = torch.arange(B, device=mask.device)[:, None].expand_as(top_idx)
            flat_idx = (b_idx * N + top_idx).reshape(-1)

            gather_idx = top_idx.unsqueeze(-1).expand(-1, -1, C)
            x_sel = x.gather(1, gather_idx).flatten(0, 1)
            sel_mask = None if skip_masks else top_vals.flatten(0, 1)

            cu_seqlens = torch.arange(0, (B + 1) * k, step=k, dtype=torch.int32, device=x.device)
            attn_max_len = max_len

        cu_seqlens = cu_seqlens.to(torch.int32)

        x_sel = self.embed(x_sel)
        x_sel = self.norm1(x_sel)

        attn_output = self.self_attn(
            x_sel, x_sel, x_sel,
            cu_seq_q=cu_seqlens,
            cu_seq_k=cu_seqlens,
            max_q=attn_max_len,
            max_k=attn_max_len,
            attn_mask=sel_mask,
            method="fa",
        )

        x_sel = x_sel + self.path_drop(attn_output)
        x_sel = self.norm2(x_sel)

        mlp_output = self.mlp(x_sel)
        x_sel = self.path_drop(self.project(x_sel + mlp_output))
        x_sel = self.norm3(x_sel)

        if sel_mask is not None:
            x_sel = x_sel * sel_mask.unsqueeze(-1)

        if flat_idx is None:
            return x_res + x_sel.view(B, N, C)

        # Out-of-place scatter-add: an in-place index_add_ on a view of the
        # input would overwrite the caller's tensor whenever autograd is off.
        flat_out = x_res.reshape(B * N, C).index_add(0, flat_idx, x_sel)
        return flat_out.view(B, N, C)

    def get_groups(self, mask, full=False):
        """Bucket samples by their number of non-zero mask entries.

        Samples are visited in descending order of count. A new bucket is
        opened whenever the bucket's first count divided by the current
        count exceeds 1.2 (1.0 when ``full``). Samples with a count of zero
        are dropped.

        Returns:
            List of ``(sample_indices, n_keep)`` tuples, where ``n_keep`` is
            the largest count in the bucket.
        """
        n_items, index = (mask != 0.0).sum(-1).cpu().sort(descending=True)
        n_items, index = n_items.tolist(), index.tolist()
        groups = []
        t = 1.0 if full else 1.2
        for ni, ii in zip(n_items, index):
            if ni == 0:
                break  # sorted descending: everything after is empty too
            if not groups or groups[-1][1] / ni > t:
                groups.append(([], ni))
            groups[-1][0].append(ii)
        return groups

    def eager_forward(self, x, mask, full=False, skip_masks=False):
        """Sparse block computation without flash attention.

        Mask values below ``mask_threshold`` are zeroed, samples are bucketed
        with ``get_groups`` and each bucket is run through ``forward_common``
        on its ``n_keep`` highest-mask tokens (kept in positional order).
        Results are scattered back into a copy of ``x``.

        Returns:
            New tensor shaped like ``x``.
        """
        mask_thresholded = mask * (mask >= self.mask_threshold)
        x_out = x.clone()

        for batch_indices, n_keep in self.get_groups(mask_thresholded, full):
            x_sel = x[batch_indices]
            mask_sel = mask_thresholded[batch_indices]

            _, topk_idx_unsorted = torch.topk(mask_sel, k=n_keep, dim=1, sorted=False)
            topk_idx_sorted, _ = topk_idx_unsorted.sort(dim=1)

            idx_expanded = topk_idx_sorted.unsqueeze(-1).expand(-1, -1, x_sel.size(-1))
            x_topk = torch.gather(x_sel, dim=1, index=idx_expanded)
            mask_topk = torch.gather(mask_sel, dim=1, index=topk_idx_sorted)

            results = self.forward_common(x_topk, mask_topk, skip_masks)

            # scatter() is out-of-place, so x_sel itself is left untouched.
            x_out[batch_indices] = x_sel.scatter(1, idx_expanded, results)
        return x_out

    def forward(self, x, full=False, skip_masks=False):
        """Predict the token mask and apply the block.

        In training mode (or when no mask is predicted, or
        ``mask_threshold < 0``) the dense ``forward_common`` path is used. In
        eval mode the flash path is taken when varlen attention is available,
        ``x`` lives on a CUDA device and is fp16/bf16; otherwise a warning is
        issued and the grouped eager path is used.

        Returns:
            ``(x, attn_mask)``; ``attn_mask`` is ``None`` when mask
            prediction is disabled.
        """
        if self.linear_mask is not None:
            attn_mask = self.patch_drop(self.linear_mask(x).sigmoid().squeeze(-1))
        else:
            attn_mask = None

        use_sparse = (
            not self.training
            and attn_mask is not None
            and self.mask_threshold >= 0
        )
        if not use_sparse:
            return self.forward_common(x, attn_mask, skip_masks), attn_mask

        if (
            HAS_VARLEN_FLASH_ATTENTION and
            'cuda' in x.device.type and
            x.dtype in (torch.bfloat16, torch.float16)
        ):
            x = self.flash_forward(x, attn_mask, skip_masks)
        else:
            warnings.warn(
                "Flash Attention requirements not met, falling back to eager attention.",
                category=UserWarning,
                stacklevel=2,
            )
            x = self.eager_forward(x, attn_mask, full, skip_masks)
        return x, attn_mask
|
|
|
|
class VisionTransformer(nn.Module):
    """Vision Transformer built from soft-masked ``EncoderBlock`` layers.

    Images are split into non-overlapping patches by a strided convolution,
    prefixed with a class token (and optionally a distillation token), given
    learned positional embeddings and passed through ``depth`` encoder
    blocks. Classification uses the class token; with ``use_distil_token`` a
    second head reads the distillation token.

    Args:
        image_size: Input side length; must be divisible by ``patch_size``.
        patch_size: Side length of each square patch.
        num_classes: Output size of the classification head(s).
        embed_dim: Token width.
        atten_dim: Passed as ``embed_dim`` to each ``EncoderBlock``.
            NOTE(review): ``EncoderBlock`` raises ``ValueError`` unless this
            equals ``embed_dim``, so the defaults (768 vs 192) cannot build a
            model with ``depth > 0`` - confirm the intended default.
        depth: Number of encoder blocks.
        num_heads: Attention heads per block.
        mlp_dim: Hidden width of each block's feed-forward network.
        channels: Number of input image channels.
        dropout: Dropout after the positional embedding and inside blocks.
        drop_path: Maximum stochastic-depth rate; grows linearly from 0 at
            the first block to this value at the last.
        patch_drop: Dropout applied to each block's predicted mask.
        attention_scale: Forwarded to each ``EncoderBlock``.
        mask_threshold: Forwarded to each ``EncoderBlock``.
        use_distil_token: Add a distillation token and a second head.
    """

    def __init__(
        self,
        image_size=256,
        patch_size=16,
        num_classes=1000,
        embed_dim=768,
        atten_dim=192,
        depth=12,
        num_heads=3,
        mlp_dim=768,
        channels=3,
        dropout=0.1,
        drop_path=0.1,
        patch_drop=0.1,
        attention_scale=2.,
        mask_threshold=0.05,
        use_distil_token=False
    ):
        super().__init__()
        assert image_size % patch_size == 0, "Image dimensions must be divisible by the patch size."
        num_patches = (image_size // patch_size) ** 2

        # Patch embedding: one conv step per patch.
        self.patch_embed = nn.Conv2d(
            in_channels=channels,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size
        )

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # One position per patch, plus the class and optional distil token.
        num_prefix = 1 + (1 if use_distil_token else 0)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + num_prefix, embed_dim))

        self.dropout = nn.Dropout(dropout)

        # Linear stochastic-depth schedule. The denominator is clamped so a
        # single-block model (depth == 1) no longer divides by zero.
        denom = max(depth - 1, 1)
        self.encoder_layers = nn.ModuleList([
            EncoderBlock(
                embed_dim, atten_dim,
                num_heads, mlp_dim,
                dropout, drop_path * i / denom,
                patch_drop=patch_drop,
                attention_scale=attention_scale,
                mask_threshold=mask_threshold,
            )
            for i in range(depth)
        ])

        self.post_norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)
        if use_distil_token:
            self.dis_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
            self.dis_head = nn.Linear(embed_dim, num_classes)
        else:
            self.dis_token = None

        self._init_weights()

    def _init_weights(self):
        """Initialise everything outside ``encoder_layers``.

        Linear/conv weights and the tokens / positional embedding get a
        truncated normal (std 0.02); biases are zeroed; LayerNorms are reset
        to weight 1, bias 0. Encoder blocks are skipped.
        """
        for n, m in self.named_modules():
            if n.startswith('encoder_layers'):
                continue
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                nn.init.trunc_normal_(m.weight.data, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias.data)
            if isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight.data)
                nn.init.zeros_(m.bias.data)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        if self.cls_token is not None:
            nn.init.trunc_normal_(self.cls_token, std=0.02)
        if self.dis_token is not None:
            nn.init.trunc_normal_(self.dis_token, std=0.02)

    def forward_features(
        self,
        pixel_values,
        full=False,
        output_hidden_states=False,
        skip_masks=False
    ):
        """Run the backbone up to and including the final LayerNorm.

        Args:
            pixel_values: ``(batch, channels, height, width)`` images.
            full: Forwarded to every encoder block.
            output_hidden_states: Collect each block's output.
            skip_masks: Forwarded to every encoder block.

        Returns:
            ``(tokens, hidden_states, masks)``. ``tokens`` is
            ``(batch, prefix + patches, embed_dim)`` with the class token at
            index 0 and the distillation token (if any) at index 1.
            ``hidden_states`` is a tuple of per-block outputs or ``None``;
            ``masks`` is a tuple of the per-block masks or ``None`` when no
            block produced one.
        """
        batch_size = pixel_values.size(0)

        # (B, E, H', W') -> (B, H'*W', E)
        x = self.patch_embed(pixel_values)
        x = x.flatten(2).transpose(1, 2)

        # Distil token is prepended first so the class token lands at index 0.
        if self.dis_token is not None:
            dis_tokens = self.dis_token.expand(batch_size, -1, -1)
            x = torch.cat((dis_tokens, x), dim=1)

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        x = x + self.pos_embed
        x = self.dropout(x)

        hidden_states = []
        masks = []
        for layer in self.encoder_layers:
            x, mask = layer(x, full, skip_masks=skip_masks)
            if output_hidden_states:
                hidden_states.append(x)
            if mask is not None:
                masks.append(mask)

        x = self.post_norm(x)

        hidden_states = tuple(hidden_states) if output_hidden_states else None
        masks = tuple(masks) if masks else None

        return x, hidden_states, masks

    def forward_classifier(self, hidden_states):
        """Map final tokens to logits.

        Returns:
            ``(logits, dis_logits)``. Without a distillation token
            ``dis_logits`` is ``None``. With one, ``logits`` is the average of
            both heads in eval mode and the class head alone in training.
        """
        cls_token = hidden_states[:, 0]
        logits = self.head(cls_token)

        dis_logits = None
        if self.dis_token is not None:
            dis_cls_token = hidden_states[:, 1]
            dis_logits = self.dis_head(dis_cls_token)
            if not self.training:
                logits = (logits + dis_logits) / 2

        return logits, dis_logits

    def forward(self, x, full=False, skip_masks=False):
        """Return ``(logits, dis_logits, masks)`` for a batch of images."""
        last_hidden_states, hidden_states, masks = self.forward_features(x, full, skip_masks=skip_masks)
        logits, dis_logits = self.forward_classifier(last_hidden_states)
        return logits, dis_logits, masks
|
|