# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

#!/usr/bin/env python3

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class AdaptiveMask(nn.Module):
    """Soft masking function for adaptive size.

    It masks out the first values along the last dimension of an input and
    keeps the last ones. The kept size is a learnable proportion of
    ``max_size``; the mask value ramps linearly over ``ramp_size`` positions
    so the size can be learned with back-propagation. Ramp values that are
    not strictly greater than 0.5 are cut to exactly 0 (see ``forward``).

    Args:
        max_size: maximum size (i.e. input dimension)
        ramp_size: size of the ramp going from 0 to 1
        init_val: initial size proportion not to be masked out
        shape: learn multiple sizes independent of each other
    """

    def __init__(self, max_size, ramp_size, init_val=0, shape=(1,)):
        super().__init__()
        self._max_size = max_size
        self._ramp_size = ramp_size
        # Learnable proportion(s) of max_size left unmasked.
        self.current_val = nn.Parameter(torch.zeros(*shape) + init_val)
        # Fixed offsets 1 - max_size, ..., -1, 0 along the masked dimension.
        mask_template = torch.linspace(1 - max_size, 0, steps=max_size)
        self.register_buffer('mask_template', mask_template)

    def forward(self, x):
        """Multiply ``x`` by the current mask along its last dimension.

        Args:
            x: tensor whose last dimension is at most ``max_size`` and whose
                leading dimensions broadcast against ``shape``.

        Returns:
            ``x * mask``, where ``mask`` holds values in ``[0, 1]``.
        """
        mask = self.mask_template + self.current_val * self._max_size
        mask = mask / self._ramp_size + 1
        # NOTE(review): this hard threshold zeroes the lower half of the ramp,
        # so the mask jumps from 0 to just above 0.5 instead of rising
        # continuously from 0 -- confirm this is intended.
        # zeros_like (rather than the Python scalar 0) keeps the dtype equal to
        # the mask's and also works on torch versions without scalar support
        # in torch.where.
        mask = torch.where(mask > 0.5, mask, torch.zeros_like(mask))
        mask = mask.clamp(0, 1)
        width = x.size(-1)
        if width < self._max_size:
            # the input could have been trimmed beforehand to save computation
            # Ellipsis keeps this valid for any parameter ``shape`` (the
            # default ``(1,)`` yields a 1-D mask), and the explicit start
            # offset handles a zero-width input, for which ``-0:`` would
            # select the whole mask.
            mask = mask[..., self._max_size - width:]
        return x * mask

    def _to_size(self, proportion, include_ramp):
        """Convert a proportion of ``max_size`` into a size in [0, max_size]."""
        size = math.ceil(proportion * self._max_size)
        if include_ramp:
            size += self._ramp_size
        return max(0, min(self._max_size, size))

    def get_current_max_size(self, include_ramp=True):
        """Return the largest learned size, optionally including the ramp."""
        return self._to_size(self.current_val.max().item(), include_ramp)

    def get_current_avg_size(self, include_ramp=True):
        """Return the mean learned size, optionally including the ramp."""
        return self._to_size(self.current_val.mean().item(), include_ramp)

    def clamp_param(self):
        """Clamp the size proportion(s) to [0, 1] in place.

        This needs to be called after each update.
        """
        self.current_val.data.clamp_(0, 1)


class AdaptiveSpan(nn.Module):
    """Adaptive attention span for Transformer self-attention.

    This module learns an attention span length from data for each
    self-attention head.

    Args:
        attn_span: maximum attention span
        adapt_span_loss: loss coefficient for the span length
        adapt_span_ramp: length of the masking ramp
        adapt_span_init: initial size ratio
        adapt_span_cache: adapt cache size to reduce memory usage
        nb_heads: number of attention heads (one span is learned per head)
        **kargs: accepted and ignored
    """

    # too fine granularity might be bad for the memory management
    _TRIM_GRANULARITY = 64
    # give a buffer of 64 steps since a span might increase in future updates
    _CACHE_MARGIN = 64

    def __init__(self,
                 attn_span,
                 adapt_span_loss,
                 adapt_span_ramp,
                 adapt_span_init,
                 adapt_span_cache,
                 nb_heads,
                 **kargs):
        super().__init__()
        self._adapt_cache = adapt_span_cache
        self._max_span = attn_span
        self._loss_coeff = adapt_span_loss
        self._nb_heads = nb_heads
        self._mask = AdaptiveMask(max_size=self._max_span,
                                  ramp_size=adapt_span_ramp,
                                  init_val=adapt_span_init,
                                  shape=(nb_heads, 1, 1))

    def forward(self, attn, normalize=True):
        """Mask attention with the right span.

        Args:
            attn: attention weights with batch and head dimensions merged
                into dimension 0, block size in dimension 1.
            normalize: if True, rescale so the last dimension sums to 1.

        Returns:
            The masked attention, reshaped to ``(attn.size(0), attn.size(1), -1)``.

        Raises:
            ValueError: if ``attn.size(0)`` is not a multiple of ``nb_heads``.
        """
        # batch and head dimensions are merged together, so separate them first
        B = attn.size(0)  # batch size
        M = attn.size(1)  # block size
        if B % self._nb_heads != 0:
            # Otherwise the reshape below could silently mix heads and blocks.
            raise ValueError(
                'first dimension of attn ({}) must be a multiple of '
                'nb_heads ({})'.format(B, self._nb_heads))
        attn = attn.reshape(B // self._nb_heads, self._nb_heads, M, -1)
        attn = self._mask(attn)
        if normalize:
            # normalize so sum is 1; the epsilon avoids division by zero
            attn = attn / (attn.sum(-1, keepdim=True) + 1e-8)
        return attn.reshape(B, M, -1)

    def get_trim_len(self):
        """Return how much of the memory can be trimmed to reduce computation."""
        L = self._max_span
        trim_len = min(L - 1, L - self._mask.get_current_max_size())
        # round down to a multiple of the trim granularity
        step = self._TRIM_GRANULARITY
        return math.floor(trim_len / step) * step

    def trim_memory(self, query, key, value, key_pe):
        """Trim out unnecessary memory beforehand to reduce computation.

        Args:
            query: tensor whose dimension 1 is the block length.
            key: tensor whose dimension 1 is cache length + block length.
            value: tensor sliced/padded along dimension 1 like ``key``.
            key_pe: tensor trimmed along its last (third) dimension, or None.

        Returns:
            The tuple ``(key, value, key_pe)`` after trimming or padding.
        """
        trim_len = self.get_trim_len()
        cache_size = key.size(1) - query.size(1)
        # Part of the trim not already covered by a cache shorter than the
        # maximum span.
        trim_len_cache = trim_len - (self._max_span - cache_size)
        if trim_len_cache > 0:
            key = key[:, trim_len_cache:, :]
            value = value[:, trim_len_cache:, :]
        elif trim_len_cache < 0:
            # cache is too short! this happens when validation resumes
            # after a lot of updates.
            front_pad = [0, 0, -trim_len_cache, 0]
            key = F.pad(key, front_pad)
            value = F.pad(value, front_pad)
        if trim_len > 0 and key_pe is not None:
            key_pe = key_pe[:, :, trim_len:]
        return key, value, key_pe

    def get_cache_size(self):
        """Determine how long the cache should be."""
        if not self._adapt_cache:
            return self._max_span
        trim_len = self.get_trim_len()
        return min(self._max_span,
                   self._max_span - trim_len + self._CACHE_MARGIN)

    def get_loss(self):
        """Return a loss term for regularizing the span length."""
        return self._loss_coeff * self._max_span * self._mask.current_val.mean()

    def get_current_max_span(self):
        """Return the largest current span over all heads (ramp included)."""
        return self._mask.get_current_max_size()

    def get_current_avg_span(self):
        """Return the average current span over all heads (ramp included)."""
        return self._mask.get_current_avg_size()

    def clamp_param(self):
        """Clamp the learned span proportions; call after each update."""
        self._mask.clamp_param()