| |
| |
| |
| |
| |
|
|
| """ |
| Transformer model, with streaming support, xformer attention support |
| and easy causal attention with a potentially finite receptive field. |
| |
| See `StreamingTransformer` for more information. |
| |
| Unlike regular PyTorch Transformer, we make the hard choice that batches are first. |
| """ |
|
|
| import typing as tp |
|
|
| from einops import rearrange |
| import torch |
| import torch.nn as nn |
| from torch.nn import functional as F |
| from torch.utils.checkpoint import checkpoint as torch_checkpoint |
| from xformers import ops |
|
|
| from .rope import RotaryEmbedding |
| from .streaming import StreamingModule |
|
|
# Module-wide attention backend, either 'torch' or 'xformers'. It is changed through
# `set_efficient_attention_backend` and decides the tensor layout used by the attention code.
_efficient_attention_backend: str = 'torch'
|
|
|
|
def set_efficient_attention_backend(backend: str = 'torch'):
    """Select the module-wide backend used by the attention implementation.

    The backend decides which tensor layout the attention code works with, and which
    kernel is called for memory efficient attention.

    Args:
        backend (str): Either `'torch'` or `'xformers'`.
    Raises:
        AssertionError: If `backend` is not one of the supported values.
    """
    global _efficient_attention_backend
    # Validate the requested value: checking the current global instead always passes
    # and would let any string through.
    assert backend in ['xformers', 'torch'], f"Unsupported attention backend: {backend!r}"
    _efficient_attention_backend = backend
|
|
|
|
def _get_attention_time_dimension() -> int:
    """Return the index of the time axis for the currently selected attention backend.

    The `torch` backend uses axis 2 (the attention code lays tensors out as `b h t d`),
    any other backend uses axis 1 (`b t h d`).
    """
    return 2 if _efficient_attention_backend == 'torch' else 1
|
|
|
|
| def _is_profiled() -> bool: |
| |
| try: |
| from xformers.profiler import profiler |
| except ImportError: |
| return False |
| return profiler._Profiler._CURRENT_PROFILER is not None |
|
|
|
|
def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module:
    """Build the normalization module used by a transformer layer.

    Args:
        norm_type (str): Normalization method, only `'layer_norm'` is supported.
        dim (int): Dimension of the normalized layer.
        **kwargs (dict): Additional parameters forwarded to the normalization layer.
    Returns:
        nn.Module: Normalization module.
    Raises:
        ValueError: If `norm_type` is not a known normalization method.
    """
    if norm_type != 'layer_norm':
        raise ValueError(f"Unknown norm type: {norm_type}")
    return nn.LayerNorm(dim, eps=1e-5, **kwargs)
|
|
|
|
def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000,
                         dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Create sinusoidal positional embedding, with shape `[B, T, C]`.

    The first half of the channels holds the cosines and the second half the sines.
    The periods are spread geometrically from 1 to `max_period`.

    Args:
        positions (torch.Tensor): LongTensor of positions, broadcastable against
            `[1, 1, dim // 2]` (e.g. `[B, T, 1]`).
        dim (int): Dimension of the embedding, must be even.
        max_period (float): Maximum period of the cosine/sine functions.
        dtype (torch.dtype or str): dtype to use to generate the embedding.
    Returns:
        torch.Tensor: Sinusoidal positional embedding.
    """
    # We aim for BTC format
    assert dim % 2 == 0
    half_dim = dim // 2
    positions = positions.to(dtype)
    adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
    max_period_tensor = torch.full([], max_period, device=positions.device, dtype=dtype)  # avoid sync point
    # With a single frequency (`dim == 2`), `half_dim - 1` is 0 and `0 / 0` would turn the
    # whole embedding into NaN; clamp the denominator so that the single period is 1.
    exponent = adim / max(half_dim - 1, 1)
    phase = positions / (max_period_tensor ** exponent)
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
|
|
|
|
def expand_repeated_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head `n_rep` times along the heads axis.

    Equivalent to `torch.repeat_interleave` over the heads axis (adapted from xlformers).
    The heads axis is 1 with the `torch` backend (`[B, H, T, D]` input) and 2 otherwise
    (`[B, T, H, D]` input). The input is returned untouched when `n_rep == 1`.
    """
    if n_rep == 1:
        return x
    if _efficient_attention_backend == 'torch':
        batch, kv_heads, steps, width = x.shape
        repeated = x.unsqueeze(2).expand(batch, kv_heads, n_rep, steps, width)
        return repeated.reshape(batch, kv_heads * n_rep, steps, width)
    batch, steps, kv_heads, width = x.shape
    repeated = x.unsqueeze(3).expand(batch, steps, kv_heads, n_rep, width)
    return repeated.reshape(batch, steps, kv_heads * n_rep, width)
|
|
|
|
class LayerScale(nn.Module):
    """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
    This rescales diagonally the residual outputs close to 0, with a learnt scale.

    Args:
        channels (int): Number of channels.
        init (float): Initial scale.
        channel_last (bool): If True, expect `[*, C]` shaped tensors, otherwise, `[*, C, T]`.
        device (torch.device or None): Device on which to initialize the module.
        dtype (torch.dtype or None): dtype to use to initialize the module.
    """
    def __init__(self, channels: int, init: float = 1e-4, channel_last: bool = True,
                 device=None, dtype=None):
        super().__init__()
        self.channel_last = channel_last
        initial_scale = torch.full(
            (channels,), init, requires_grad=True, device=device, dtype=dtype)
        self.scale = nn.Parameter(initial_scale)

    def forward(self, x: torch.Tensor):
        """Multiply `x` by the learnt per-channel scale."""
        # When channels are not last, add a trailing axis so the scale broadcasts over time.
        scale = self.scale if self.channel_last else self.scale[:, None]
        return scale * x
|
|
|
|
class StreamingMultiheadAttention(StreamingModule):
    """Similar to `nn.MultiheadAttention` but with support for streaming, causal evaluation.

    Args:
        embed_dim (int): Dimension to project to.
        num_heads (int): Number of heads.
        dropout (float): Dropout level.
        bias (bool): Use bias in projections.
        causal (bool): Causal mask applied automatically.
        past_context (int or None): Receptive field for the causal mask, infinite if None.
        custom (bool): Use custom MHA implementation, for testing / benchmarking.
        memory_efficient (bool): Use xformers based memory efficient attention.
        attention_as_float32 (bool): Perform the attention as float32
            (especially important with memory_efficient as autocast won't do this automatically).
        rope (`RotaryEmbedding` or None): Rope embedding to use.
        cross_attention: Should be true when used as a cross attention.
            All keys and values must be available at once, streaming is only for the queries.
            Cannot be used with `causal` or `rope` (as it wouldn't make sense to
            interpret the time steps in the keys relative to those in the queries).
        safe_streaming (bool): Bug fix, will go away with xformers update.
        qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product.
        kv_repeat (int): If > 1, will repeat keys and values multiple times (need to divide num_heads).
            This will lead to faster decoding time on A100 or other GPUs with tensorcore.
        device (torch.device or None): Device on which to initialize.
        dtype (torch.dtype or None): dtype to use.
    """
    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True,
                 causal: bool = False, past_context: tp.Optional[int] = None, custom: bool = False,
                 memory_efficient: bool = False, attention_as_float32: bool = False,
                 rope: tp.Optional[RotaryEmbedding] = None, cross_attention: bool = False,
                 safe_streaming: bool = True, qk_layer_norm: bool = False, kv_repeat: int = 1,
                 device=None, dtype=None):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        if past_context is not None:
            # A finite receptive field is only defined for the causal mask.
            assert causal

        self.embed_dim = embed_dim
        self.causal = causal
        self.past_context = past_context
        self.memory_efficient = memory_efficient
        self.attention_as_float32 = attention_as_float32
        self.rope = rope
        self.cross_attention = cross_attention
        self.safe_streaming = safe_streaming
        self.num_heads = num_heads
        self.dropout = dropout
        self.kv_repeat = kv_repeat
        if cross_attention:
            assert not causal, "Causal cannot work with cross attention."
            assert rope is None, "Rope cannot work with cross attention."

        if memory_efficient:
            _verify_xformers_memory_efficient_compat()

        self.custom = _is_custom(custom, memory_efficient)
        if self.custom:
            # Single input projection holding queries (`embed_dim` outputs) followed by the
            # keys and values (`kv_dim` outputs each, smaller than `embed_dim` when kv_repeat > 1).
            out_dim = embed_dim
            assert num_heads % kv_repeat == 0
            assert not cross_attention or kv_repeat == 1
            num_kv = num_heads // kv_repeat
            kv_dim = (embed_dim // num_heads) * num_kv
            out_dim += 2 * kv_dim
            in_proj = nn.Linear(embed_dim, out_dim, bias=bias, **factory_kwargs)
            # Parameters are exposed under the same names as in `nn.MultiheadAttention`.
            self.in_proj_weight = in_proj.weight
            self.in_proj_bias = in_proj.bias
            if bias:
                self.in_proj_bias.data.zero_()
            self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
            if bias:
                self.out_proj.bias.data.zero_()
        else:
            assert not qk_layer_norm
            assert kv_repeat == 1
            self.mha = nn.MultiheadAttention(
                embed_dim, num_heads, dropout=dropout, bias=bias, batch_first=True,
                **factory_kwargs)
        self.qk_layer_norm = qk_layer_norm
        if qk_layer_norm:
            assert self.custom
            assert kv_repeat == 1
            ln_dim = embed_dim
            self.q_layer_norm = nn.LayerNorm(ln_dim)
            self.k_layer_norm = nn.LayerNorm(ln_dim)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        """Load parameters, accepting checkpoints that lack the `mha.` prefix.

        When wrapping `nn.MultiheadAttention`, keys stored directly under `prefix`
        are renamed in place to `prefix + "mha." + key` before delegating to the parent.
        """
        if not self.custom:
            keys = [n for n, _ in self.mha.named_parameters()]
            for key in keys:
                if prefix + key in state_dict:
                    state_dict[prefix + "mha." + key] = state_dict.pop(prefix + key)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def _get_mask(self, current_steps: int, device: torch.device, dtype: torch.dtype):
        """Return the causal mask, accounting for keys stored in the streaming state.

        For the memory efficient implementation this is either `None` (single step) or an
        xformers `LowerTriangularMask`. Otherwise, this is an additive bias of shape
        `[current_steps, past_steps + current_steps]` with 0 on allowed positions
        and `-inf` on masked ones.

        Raises:
            RuntimeError: With memory efficient attention, when several steps are
                provided while past keys are already stored.
        """
        time_dim = _get_attention_time_dimension()
        if self.memory_efficient:
            from xformers.ops import LowerTriangularMask
            if current_steps == 1:
                # A single query step can attend to every available key: no mask needed.
                return None
            elif 'past_keys' in self._streaming_state:
                raise RuntimeError('Not supported at the moment')
            else:
                # No past keys: queries and keys are aligned, lower triangular is enough.
                return LowerTriangularMask()
        if self._streaming_state:
            past_keys = self._streaming_state['past_keys']
            past_steps = past_keys.shape[time_dim]
        else:
            past_steps = 0

        # Queries are located after the stored keys on the time axis.
        queries_pos = torch.arange(
            past_steps, current_steps + past_steps, device=device).view(-1, 1)
        keys_pos = torch.arange(past_steps + current_steps, device=device).view(1, -1)
        delta = queries_pos - keys_pos
        valid = delta >= 0
        if self.past_context is not None:
            valid &= (delta <= self.past_context)
        return torch.where(
            valid,
            torch.zeros([], device=device, dtype=dtype),
            torch.full([], float('-inf'), device=device, dtype=dtype))

    def _complete_kv(self, k, v):
        """Prepend the keys / values stored in the streaming state to `k` and `v`.

        When streaming, the completed keys / values are stored for the next call,
        keeping at most the last `past_context` steps if `past_context` is set. The
        `'offset'` entry of the streaming state is set to 0 on the first streaming
        call, and then accumulates the number of steps dropped on the following ones.
        For cross attention `k` and `v` are returned untouched.

        NOTE(review): the time axis is taken from the attention backend layout. On the
        non custom path the inputs look to be `[B, T, C]` (time on axis 1), which would
        only match the `xformers` layout -- to be confirmed.

        Returns:
            tuple: the completed keys and values.
        """
        time_dim = _get_attention_time_dimension()
        if self.cross_attention:
            # All keys and values are assumed to be provided at once,
            # streaming is with respect to the queries only.
            return k, v
        if self._streaming_state:
            pk = self._streaming_state['past_keys']
            nk = torch.cat([pk, k], dim=time_dim)
            if v is k:
                nv = nk
            else:
                pv = self._streaming_state['past_values']
                nv = torch.cat([pv, v], dim=time_dim)
        else:
            nk = k
            nv = v

        assert nk.shape[time_dim] == nv.shape[time_dim]
        offset = 0
        if self.past_context is not None:
            offset = max(0, nk.shape[time_dim] - self.past_context)
        if self._is_streaming:
            # Trim along the time axis, whose index depends on the backend layout.
            # A plain `[:, offset:]` always cuts axis 1, which is the heads axis with
            # the `torch` backend.
            time_slice = (slice(None),) * time_dim + (slice(offset, None),)
            self._streaming_state['past_keys'] = nk[time_slice]
            if v is not k:
                self._streaming_state['past_values'] = nv[time_slice]
            if 'offset' in self._streaming_state:
                self._streaming_state['offset'] += offset
            else:
                self._streaming_state['offset'] = torch.tensor(0)
        return nk, nv

    def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
        """Apply the rotary embedding to `query` and `key`.

        The rotation starts at the number of steps already stored in the streaming
        state, plus the number of steps dropped so far because of `past_context`.
        Only supported with the `xformers` backend, for which time is on axis 1.
        """
        assert _efficient_attention_backend == 'xformers', 'Rope not supported with torch attn.'
        assert self.rope is not None
        if 'past_keys' in self._streaming_state:
            past_keys_offset = self._streaming_state['past_keys'].shape[1]
        else:
            past_keys_offset = 0
        if 'offset' in self._streaming_state:
            past_context_offset = int(self._streaming_state['offset'].item())
        else:
            past_context_offset = 0
        streaming_offset = past_context_offset + past_keys_offset
        return self.rope.rotate_qk(query, key, start=streaming_offset)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                key_padding_mask=None, need_weights=False, attn_mask=None,
                average_attn_weights=True, is_causal=False):
        """Compute the attention, the signature follows `nn.MultiheadAttention.forward`.

        Args:
            query (torch.Tensor): Queries, batch first.
            key (torch.Tensor): Keys, must be the same object as `query` for the custom
                self-attention implementation.
            value (torch.Tensor): Values, same constraint as `key`.
            key_padding_mask: Only supported by the non custom implementation.
            need_weights (bool): Only supported by the non custom implementation.
            attn_mask: Must be None, the causal mask is built internally.
            average_attn_weights (bool): Forwarded to `nn.MultiheadAttention`.
            is_causal (bool): Must be False, use `causal` in the constructor instead.
        Returns:
            tuple: the attention output cast back to the dtype of `query`, and None
                (attention weights are never returned).
        """
        assert attn_mask is None
        assert not is_causal, ("new param added in torch 2.0.1 not supported, "
                               "use the causal args in the constructor.")

        time_dim = _get_attention_time_dimension()
        if time_dim == 2:
            layout = "b h t d"
        else:
            layout = "b t h d"
        dtype = query.dtype
        if self._is_streaming:
            assert self.causal or self.cross_attention, \
                "Streaming only available for causal or cross attention"

        if self.causal:
            # Only the self-attention case is handled, with matching lengths.
            assert query.shape[1] == key.shape[1], "Causal only for same length query / key / value"
            assert value.shape[1] == key.shape[1], "Causal only for same length query / key / value"
            attn_mask = self._get_mask(query.shape[1], query.device, query.dtype)

        if self.custom:
            assert need_weights is False
            assert key_padding_mask is None
            if self.cross_attention:
                # Queries, keys and values come from different tensors: the packed
                # projection is split in three before being applied to each of them.
                dim = self.in_proj_weight.shape[0] // 3
                if self.in_proj_bias is None:
                    bias_q, bias_k, bias_v = None, None, None
                else:
                    bias_q = self.in_proj_bias[:dim]
                    bias_k = self.in_proj_bias[dim: 2 * dim]
                    bias_v = self.in_proj_bias[2 * dim:]
                q = nn.functional.linear(query, self.in_proj_weight[:dim], bias_q)
                k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim], bias_k)
                v = nn.functional.linear(value, self.in_proj_weight[2 * dim:], bias_v)
                if self.qk_layer_norm is True:
                    q = self.q_layer_norm(q)
                    k = self.k_layer_norm(k)
                q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
            else:
                if not _is_profiled():
                    # The identity checks are skipped when an xformers profiler is active.
                    assert query is key, "specialized implementation"
                    assert value is key, "specialized implementation"
                # Self-attention: one matmul gives the packed queries, keys and values.
                projected = nn.functional.linear(query, self.in_proj_weight, self.in_proj_bias)
                if self.kv_repeat == 1:
                    if time_dim == 2:
                        bound_layout = "b h p t d"
                    else:
                        bound_layout = "b t p h d"
                    packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
                    q, k, v = ops.unbind(packed, dim=2)
                else:
                    # Keys and values have fewer heads than the queries, split manually.
                    embed_dim = self.embed_dim
                    per_head_dim = (embed_dim // self.num_heads)
                    kv_heads = self.num_heads // self.kv_repeat
                    q = projected[:, :, :embed_dim]
                    start = embed_dim
                    end = start + per_head_dim * kv_heads
                    k = projected[:, :, start: end]
                    v = projected[:, :, end:]
                    q = rearrange(q, f"b t (h d) -> {layout}", h=self.num_heads)
                    k = rearrange(k, f"b t (h d) -> {layout}", h=kv_heads)
                    v = rearrange(v, f"b t (h d) -> {layout}", h=kv_heads)

                if self.qk_layer_norm is True:
                    assert self.kv_repeat == 1
                    # The layer norms operate on the flattened `(h d)` channels.
                    q, k = [rearrange(x, f"{layout} -> b t (h d)") for x in [q, k]]
                    q = self.q_layer_norm(q)
                    k = self.k_layer_norm(k)
                    q, k = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k]]
                if self.rope:
                    q, k = self._apply_rope(q, k)
                k, v = self._complete_kv(k, v)
                if self.kv_repeat > 1:
                    k = expand_repeated_kv(k, self.kv_repeat)
                    v = expand_repeated_kv(v, self.kv_repeat)
            if self.attention_as_float32:
                q, k, v = [x.float() for x in [q, k, v]]
            if self.memory_efficient:
                p = self.dropout if self.training else 0
                if _efficient_attention_backend == 'torch':
                    x = torch.nn.functional.scaled_dot_product_attention(
                        q, k, v, is_causal=attn_mask is not None, dropout_p=p)
                else:
                    x = ops.memory_efficient_attention(q, k, v, attn_mask, p=p)
            else:
                # Plain attention: scaled dot product, optional additive mask, softmax,
                # dropout and weighted sum of the values.
                q = q / q.shape[-1] ** 0.5
                key_layout = layout.replace('t', 'k')
                query_layout = layout
                if self._is_streaming and self.safe_streaming and q.device.type == 'cuda':
                    with torch.autocast(device_type=q.device.type, dtype=torch.float32):
                        pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k)
                else:
                    pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k)
                if attn_mask is not None:
                    pre_w = pre_w + attn_mask
                w = torch.softmax(pre_w, dim=-1)
                w = F.dropout(w, self.dropout, training=self.training).to(v)
                # Keys and values share the same layout.
                x = torch.einsum(f"b h t k, {key_layout} -> {layout}", w, v)
            x = x.to(dtype)
            x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
            x = self.out_proj(x)
        else:
            key, value = self._complete_kv(key, value)
            if self.attention_as_float32:
                query, key, value = [x.float() for x in [query, key, value]]
            x, _ = self.mha(
                query, key, value, key_padding_mask,
                need_weights, attn_mask, average_attn_weights)
            x = x.to(dtype)

        return x, None
|
|
|
|
class StreamingTransformerLayer(nn.TransformerEncoderLayer):
    """TransformerLayer with Streaming / Causal support.
    This also integrates cross_attention, when passing `cross_attention=True`,
    rather than having two separate classes like in PyTorch.

    Args:
        d_model (int): Dimension of the data.
        num_heads (int): Number of heads.
        dim_feedforward (int): Intermediate dimension of FF module.
        dropout (float): Dropout both for MHA and FF.
        bias_ff (bool): Use bias for FF.
        bias_attn (bool): Use bias for MHA.
        causal (bool): Causal mask applied automatically.
        past_context (int or None): Receptive field for the causal mask, infinite if None.
        custom (bool): Use custom MHA implementation, for testing / benchmarking.
        memory_efficient (bool): Use xformers based memory efficient attention.
        attention_as_float32 (bool): Perform the attention as float32
            (especially important with memory_efficient as autocast won't do this automatically).
        qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product in attention.
        qk_layer_norm_cross (bool): Same for the cross attention.
        cross_attention (bool): If True, expect to get secondary input for cross-attention.
        layer_scale (float or None): If not None, LayerScale will be used with
            the given value as initial scale.
        rope (`RotaryEmbedding` or None): Rope embedding to use.
        attention_dropout (float or None): If not None, separate the value of the dimension dropout
            in FFN and of the attention dropout.
        kv_repeat (int): If > 1, will repeat keys and values multiple times (need to divide num_heads).
            This will lead to faster decoding time on A100 or other GPUs with tensorcore.
        norm (str): Normalization method for the self-attention and FF blocks, see `create_norm_fn`.
        device (torch.device or None): Device on which to initialize.
        dtype (torch.dtype or None): dtype to use.
        **kwargs: See `nn.TransformerEncoderLayer`.
    """
    def __init__(self, d_model: int, num_heads: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 bias_ff: bool = True, bias_attn: bool = True, causal: bool = False,
                 past_context: tp.Optional[int] = None, custom: bool = False,
                 memory_efficient: bool = False, attention_as_float32: bool = False,
                 qk_layer_norm: bool = False, qk_layer_norm_cross: bool = False,
                 cross_attention: bool = False, layer_scale: tp.Optional[float] = None,
                 rope: tp.Optional[RotaryEmbedding] = None, attention_dropout: tp.Optional[float] = None,
                 kv_repeat: int = 1, norm: str = 'layer_norm', device=None, dtype=None, **kwargs):
        super().__init__(d_model, num_heads, dim_feedforward, dropout,
                         device=device, dtype=dtype, batch_first=True, **kwargs)
        factory_kwargs = {'device': device, 'dtype': dtype}

        def _make_layer_scale() -> nn.Module:
            # Identity unless an initial layer scale value was requested.
            if layer_scale is None:
                return nn.Identity()
            return LayerScale(d_model, layer_scale, **factory_kwargs)

        # Arguments shared by the self-attention and the cross-attention modules.
        attn_dropout = attention_dropout if attention_dropout is not None else dropout
        attn_kwargs: tp.Dict[str, tp.Any] = dict(
            embed_dim=d_model,
            num_heads=num_heads,
            dropout=attn_dropout,
            bias=bias_attn,
            custom=custom,
            memory_efficient=memory_efficient,
            attention_as_float32=attention_as_float32,
        )

        # The modules created by the parent constructor are replaced by our own versions.
        # Modules are created in a fixed order: self attention, feed forward, cross attention, norms.
        self.self_attn: StreamingMultiheadAttention = StreamingMultiheadAttention(
            causal=causal, past_context=past_context, rope=rope, qk_layer_norm=qk_layer_norm,
            kv_repeat=kv_repeat, **attn_kwargs, **factory_kwargs)
        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias_ff, **factory_kwargs)
        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias_ff, **factory_kwargs)

        self.layer_scale_1: nn.Module = _make_layer_scale()
        self.layer_scale_2: nn.Module = _make_layer_scale()

        self.cross_attention: tp.Optional[nn.Module] = None
        if cross_attention:
            self.cross_attention = StreamingMultiheadAttention(
                cross_attention=True, qk_layer_norm=qk_layer_norm_cross,
                **attn_kwargs, **factory_kwargs)
            # Dropout, normalization and layer scale dedicated to the cross-attention branch.
            self.dropout_cross = nn.Dropout(dropout)
            self.norm_cross = nn.LayerNorm(d_model, eps=1e-5, **factory_kwargs)
            self.layer_scale_cross: nn.Module = _make_layer_scale()
        self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs)
        self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs)

    def _cross_attention_block(self, src: torch.Tensor,
                               cross_attention_src: torch.Tensor) -> torch.Tensor:
        """Attend from `src` to `cross_attention_src`, then apply the cross-attention dropout."""
        assert self.cross_attention is not None
        # Queries come from `src`, keys and values from `cross_attention_src`.
        attended, _ = self.cross_attention(
            src, cross_attention_src, cross_attention_src, need_weights=False)
        return self.dropout_cross(attended)

    def forward(self, src: torch.Tensor, src_mask: tp.Optional[torch.Tensor] = None,
                src_key_padding_mask: tp.Optional[torch.Tensor] = None,
                cross_attention_src: tp.Optional[torch.Tensor] = None):
        """Run self-attention, optional cross-attention and feed forward, each as a residual.

        Args:
            src (torch.Tensor): Input of the layer.
            src_mask (torch.Tensor or None): Passed to the inherited self-attention block.
            src_key_padding_mask (torch.Tensor or None): Passed to the inherited self-attention block.
            cross_attention_src (torch.Tensor or None): Source of the keys / values for the
                cross-attention, required if and only if the layer has cross-attention.
        Returns:
            torch.Tensor: Output of the layer.
        """
        has_cross_attention = self.cross_attention is not None
        assert (cross_attention_src is not None) == has_cross_attention

        x = src
        if self.norm_first:
            # Pre-norm: each block sees a normalized input, residuals stay unnormalized.
            attended = self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
            x = x + self.layer_scale_1(attended)
            if cross_attention_src is not None:
                crossed = self._cross_attention_block(self.norm_cross(x), cross_attention_src)
                x = x + self.layer_scale_cross(crossed)
            x = x + self.layer_scale_2(self._ff_block(self.norm2(x)))
            return x

        # Post-norm: the normalization is applied after each residual sum.
        attended = self._sa_block(x, src_mask, src_key_padding_mask)
        x = self.norm1(x + self.layer_scale_1(attended))
        if cross_attention_src is not None:
            # NOTE(review): the queries are the raw layer input `src` and not `x`, unlike
            # the pre-norm branch -- kept as is, confirm this is intended.
            crossed = self._cross_attention_block(src, cross_attention_src)
            x = self.norm_cross(x + self.layer_scale_cross(crossed))
        x = self.norm2(x + self.layer_scale_2(self._ff_block(x)))
        return x
|
|
|
|
class StreamingTransformer(StreamingModule):
    """Transformer with Streaming / Causal support.

    Args:
        d_model (int): Dimension of the data.
        num_heads (int): Number of heads.
        num_layers (int): Number of layers.
        dim_feedforward (int): Intermediate dimension of FF module.
        dropout (float): Dropout both for MHA and FF.
        bias_ff (bool): Use bias for FF.
        bias_attn (bool): Use bias for MHA.
        causal (bool): Causal mask applied automatically.
        past_context (int or None): Receptive field for the causal mask, infinite if None.
        custom (bool): Use custom MHA implementation, for testing / benchmarking.
        memory_efficient (bool): Use xformers based memory efficient attention.
        attention_as_float32 (bool): Perform the attention as float32
            (especially important with memory_efficient as autocast won't do this automatically).
        cross_attention (bool): If True, expect to get secondary input for cross-attention.
        layer_scale (float or None): If not None, LayerScale will be used
            with the given value as initial scale.
        positional_embedding (str): Positional embedding strategy (sin, rope, or sin_rope).
        max_period (float): Maximum period of the time embedding.
        positional_scale (float): Scale of positional embedding, set to 0 to deactivate.
        xpos (bool): Apply xpos exponential decay to positional embedding (rope only).
        lr (float or None): learning rate override through the `make_optim_group` API.
        weight_decay (float or None): Weight_decay override through the `make_optim_group` API.
        layer_class: (subclass of `StreamingTransformerLayer`): class to use
            to initialize the layers, allowing further customization outside of Audiocraft.
        checkpointing (str): Checkpointing strategy to reduce memory usage.
            No checkpointing if set to 'none'. Per layer checkpointing using PyTorch
            if set to 'torch' (entire layer checkpointed, i.e. linears are evaluated twice,
            minimal memory usage, but maximal runtime). Finally, `xformers_default` provide
            a policy for opting-out some operations of the checkpointing like
            linear layers and attention, providing a middle ground between speed and memory.
            `xformers_mm` only opts out the matrix multiplications.
        device (torch.device or None): Device on which to initialize.
        dtype (torch.dtype or None): dtype to use.
        **kwargs: See `nn.TransformerEncoderLayer`.
    """
    def __init__(self, d_model: int, num_heads: int, num_layers: int, dim_feedforward: int = 2048,
                 dropout: float = 0.1, bias_ff: bool = True, bias_attn: bool = True,
                 causal: bool = False, past_context: tp.Optional[int] = None,
                 custom: bool = False, memory_efficient: bool = False, attention_as_float32: bool = False,
                 cross_attention: bool = False, layer_scale: tp.Optional[float] = None,
                 positional_embedding: str = 'sin', max_period: float = 10_000, positional_scale: float = 1.,
                 xpos: bool = False, lr: tp.Optional[float] = None, weight_decay: tp.Optional[float] = None,
                 layer_class: tp.Type[StreamingTransformerLayer] = StreamingTransformerLayer,
                 checkpointing: str = 'none', device=None, dtype=None, **kwargs):
        super().__init__()
        assert d_model % num_heads == 0

        self.positional_embedding = positional_embedding
        self.max_period = max_period
        self.positional_scale = positional_scale
        self.weight_decay = weight_decay
        self.lr = lr

        assert positional_embedding in ['sin', 'rope', 'sin_rope']
        self.rope: tp.Optional[RotaryEmbedding] = None
        if self.positional_embedding in ['rope', 'sin_rope']:
            # Rope is only implemented by the custom attention.
            assert _is_custom(custom, memory_efficient)
            self.rope = RotaryEmbedding(d_model // num_heads, max_period=max_period,
                                        xpos=xpos, scale=positional_scale, device=device)

        self.checkpointing = checkpointing

        assert checkpointing in ['none', 'torch', 'xformers_default', 'xformers_mm']
        if self.checkpointing.startswith('xformers'):
            _verify_xformers_internal_compat()

        self.layers = nn.ModuleList()
        for idx in range(num_layers):
            self.layers.append(
                layer_class(
                    d_model=d_model, num_heads=num_heads, dim_feedforward=dim_feedforward,
                    dropout=dropout, bias_ff=bias_ff, bias_attn=bias_attn,
                    causal=causal, past_context=past_context, custom=custom,
                    memory_efficient=memory_efficient, attention_as_float32=attention_as_float32,
                    cross_attention=cross_attention, layer_scale=layer_scale, rope=self.rope,
                    device=device, dtype=dtype, **kwargs))

        if self.checkpointing != 'none':
            for layer in self.layers:
                # Flag read outside of this module to mark checkpointed layers.
                layer._magma_checkpointed = True
                # Layer classes without a `layer_drop` attribute (such as the default
                # `StreamingTransformerLayer`) have no layer drop: a direct attribute access
                # would raise AttributeError and make checkpointing unusable with them.
                assert getattr(layer, 'layer_drop', 0.) == 0., "Need further checking"

    def _apply_layer(self, layer, *args, **kwargs):
        """Call `layer` on the given arguments, using the configured checkpointing strategy.

        Raises:
            ValueError: If the checkpointing strategy is unknown.
        """
        method = self.checkpointing
        if method == 'none':
            return layer(*args, **kwargs)
        elif method == 'torch':
            return torch_checkpoint(layer, *args, use_reentrant=False, **kwargs)
        elif method.startswith('xformers'):
            from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy
            if method == 'xformers_default':
                # Operations listed here are handed to the xformers checkpointing policy:
                # attention kernels and matrix multiplications.
                allow_list = [
                    "xformers.efficient_attention_forward_cutlass.default",
                    "xformers_flash.flash_fwd.default",
                    "aten.addmm.default",
                    "aten.mm.default",
                ]
            elif method == 'xformers_mm':
                # Same, restricted to the matrix multiplications.
                allow_list = [
                    "aten.addmm.default",
                    "aten.mm.default",
                ]
            else:
                raise ValueError(f"xformers checkpointing xformers policy {method} is not known.")
            policy_fn = _get_default_policy(allow_list)
            return checkpoint(layer, *args, policy_fn=policy_fn, **kwargs)
        else:
            raise ValueError(f"Checkpointing method {method} is unknown.")

    def forward(self, x: torch.Tensor, *args, **kwargs):
        """Apply the positional embedding (if any) and all the layers to `x`.

        Args:
            x (torch.Tensor): Input of shape `[B, T, C]`.
            *args, **kwargs: Forwarded to each layer.
        Returns:
            torch.Tensor: Output of the last layer.
        """
        B, T, C = x.shape

        # When streaming, positions continue from the number of steps already processed.
        if 'offsets' in self._streaming_state:
            offsets = self._streaming_state['offsets']
        else:
            offsets = torch.zeros(B, dtype=torch.long, device=x.device)

        if self.positional_embedding in ['sin', 'sin_rope']:
            positions = torch.arange(T, device=x.device).view(1, -1, 1)
            positions = positions + offsets.view(-1, 1, 1)
            pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
            x = x + self.positional_scale * pos_emb

        for layer in self.layers:
            x = self._apply_layer(layer, x, *args, **kwargs)

        if self._is_streaming:
            self._streaming_state['offsets'] = offsets + T

        return x

    def make_optim_group(self):
        """Return an optimizer parameter group with all the parameters of this module.

        The group carries the `lr` and `weight_decay` overrides when they are set.
        """
        group = {"params": list(self.parameters())}
        if self.lr is not None:
            group["lr"] = self.lr
        if self.weight_decay is not None:
            group["weight_decay"] = self.weight_decay
        return group
|
|
|
|
| |
|
|
| def _verify_xformers_memory_efficient_compat(): |
| try: |
| from xformers.ops import memory_efficient_attention, LowerTriangularMask |
| except ImportError: |
| raise ImportError( |
| "xformers is not installed. Please install it and try again.\n" |
| "To install on AWS and Azure, run \n" |
| "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='8.0'\\\n" |
| "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n" |
| "To install on FAIR Cluster, run \n" |
| "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='6.0;7.0'\\\n" |
| "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n") |
|
|
|
|
| def _verify_xformers_internal_compat(): |
| try: |
| from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy |
| except ImportError: |
| raise ImportError( |
| "Francisco's fairinternal xformers is not installed. Please install it and try again.\n" |
| "To install on AWS and Azure, run \n" |
| "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='8.0'\\\n" |
| "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n" |
| "To install on FAIR Cluster, run \n" |
| "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='6.0;7.0'\\\n" |
| "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n") |
|
|
|
|
| def _is_custom(custom: bool, memory_efficient: bool): |
| return custom or memory_efficient |
|
|