Spaces:
Runtime error
Runtime error
| """ | |
| Copyright (c) Meta Platforms, Inc. and affiliates. | |
| All rights reserved. | |
| This source code is licensed under the license found in the | |
| LICENSE file in the root directory of this source tree. | |
| """ | |
| import math | |
| from typing import Any, Callable, List, Optional, Union | |
| import torch | |
| import torch.nn as nn | |
| from einops import rearrange | |
| from torch import Tensor | |
| from torch.nn import functional as F | |
def generate_causal_mask(source_length, target_length, device="cpu"):
    """Build an additive causal attention mask of shape (target, source).

    Allowed positions hold 0.0 and blocked positions hold -inf, so the mask
    can be added directly to attention logits. For equal lengths this is the
    standard lower-triangular mask; for unequal lengths, target row i may
    attend to a proportional prefix of the source sequence.
    """
    if source_length == target_length:
        # Square case: lower triangle (including the diagonal) is visible.
        allowed = torch.tril(torch.ones(target_length, source_length, device=device)) == 1
    else:
        # Rectangular case: each target row sees a proportionally sized prefix.
        allowed = torch.zeros(target_length, source_length, device=device)
        cutoffs = torch.linspace(0, source_length, target_length + 1)[1:].round().long()
        for row in range(target_length):
            allowed[row, 0 : cutoffs[row]] = 1
    additive = allowed.float()
    additive = additive.masked_fill(allowed == 0, float("-inf"))
    return additive.masked_fill(allowed == 1, float(0.0))
class TransformerEncoderLayerRotary(nn.Module):
    """Transformer encoder layer with optional rotary position embeddings.

    Mirrors ``nn.TransformerEncoderLayer`` except that, when a ``rotary``
    module is supplied, queries and keys are rotated before self-attention
    (values are left untouched).
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
        layer_norm_eps: float = 1e-5,
        batch_first: bool = False,
        norm_first: bool = True,
        rotary=None,
    ) -> None:
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=batch_first
        )
        # Position-wise feed-forward sublayer.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm_first = norm_first
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.activation = activation
        self.rotary = rotary
        self.use_rotary = rotary is not None

    def forward(
        self,
        src: Tensor,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Apply the self-attention and feed-forward sublayers with residuals."""
        out = src
        if not self.norm_first:
            # Post-norm: normalize after each residual addition.
            out = self.norm1(out + self._sa_block(out, src_mask, src_key_padding_mask))
            out = self.norm2(out + self._ff_block(out))
        else:
            # Pre-norm: normalize the sublayer input, add the residual raw.
            out = out + self._sa_block(self.norm1(out), src_mask, src_key_padding_mask)
            out = out + self._ff_block(self.norm2(out))
        return out

    def _sa_block(
        self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]
    ) -> Tensor:
        """Self-attention sublayer; rotates queries/keys when rotary is enabled."""
        queries_keys = x if not self.use_rotary else self.rotary.rotate_queries_or_keys(x)
        attn_out, _ = self.self_attn(
            queries_keys,
            queries_keys,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        return self.dropout1(attn_out)

    def _ff_block(self, x: Tensor) -> Tensor:
        """Feed-forward sublayer: linear -> activation -> dropout -> linear."""
        hidden = self.dropout(self.activation(self.linear1(x)))
        return self.dropout2(self.linear2(hidden))
class DenseFiLM(nn.Module):
    """Feature-wise linear modulation (FiLM) generator.

    Maps a (B, C) conditioning vector through Mish + Linear to a pair of
    (B, 1, C) tensors (scale, shift), ready to broadcast over a sequence
    dimension.
    """

    def __init__(self, embed_channels):
        super().__init__()
        self.embed_channels = embed_channels
        self.block = nn.Sequential(
            nn.Mish(), nn.Linear(embed_channels, embed_channels * 2)
        )

    def forward(self, position):
        """Return a (scale, shift) tuple derived from ``position``."""
        film_params = self.block(position).unsqueeze(1)  # b c -> b 1 c
        return film_params.chunk(2, dim=-1)
def featurewise_affine(x, scale_shift):
    """Apply a FiLM affine transform: ``(1 + scale) * x + shift``.

    :param x: tensor to modulate
    :param scale_shift: (scale, shift) pair broadcastable against ``x``
    """
    scale, shift = scale_shift
    return x * (scale + 1) + shift
class FiLMTransformerDecoderLayer(nn.Module):
    """Transformer decoder layer with FiLM conditioning.

    Each sublayer output (self-attention, cross-attention, feed-forward) is
    modulated by a DenseFiLM generator driven by ``t`` before the residual
    addition. With ``use_cm=True`` a second cross-attention sublayer over an
    extra memory stream (``memory2``) is available. When ``rotary`` is given,
    queries and keys are rotated before attention.

    Bug fix: the post-norm branch (``norm_first=False``) previously called
    ``self._mha_block`` without the required ``mha`` and ``dropout``
    arguments, raising TypeError; it now passes ``self.multihead_attn`` and
    ``self.dropout2``, matching the pre-norm branch.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward=2048,
        dropout=0.1,
        activation=F.relu,
        layer_norm_eps=1e-5,
        batch_first=False,
        norm_first=True,
        rotary=None,
        use_cm=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=batch_first
        )
        self.multihead_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=batch_first
        )
        # Feedforward
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm_first = norm_first
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.activation = activation
        # One FiLM generator per sublayer.
        self.film1 = DenseFiLM(d_model)
        self.film2 = DenseFiLM(d_model)
        self.film3 = DenseFiLM(d_model)
        if use_cm:
            # Optional second cross-attention path ("2") for memory2.
            self.multihead_attn2 = nn.MultiheadAttention(  # 2
                d_model, nhead, dropout=dropout, batch_first=batch_first
            )
            self.norm2a = nn.LayerNorm(d_model, eps=layer_norm_eps)  # 2
            self.dropout2a = nn.Dropout(dropout)  # 2
            self.film2a = DenseFiLM(d_model)  # 2
        self.rotary = rotary
        self.use_rotary = rotary is not None

    # x, cond, t
    def forward(
        self,
        tgt,
        memory,
        t,
        tgt_mask=None,
        memory_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
        memory2=None,
    ):
        """Run the FiLM-modulated decoder layer.

        :param tgt: target sequence (the residual stream)
        :param memory: conditioning sequence for cross-attention
        :param t: FiLM conditioning vector (presumably B x d_model, since
            DenseFiLM unsqueezes it to B x 1 x d_model -- confirm with caller)
        :param memory2: optional second memory stream; requires ``use_cm=True``
            at construction, otherwise ``self.multihead_attn2`` does not exist
        :return: transformed target sequence, same shape as ``tgt``
        """
        x = tgt
        if self.norm_first:
            # self-attention -> film -> residual
            x_1 = self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask)
            x = x + featurewise_affine(x_1, self.film1(t))
            # cross-attention -> film -> residual
            x_2 = self._mha_block(
                self.norm2(x),
                memory,
                memory_mask,
                memory_key_padding_mask,
                self.multihead_attn,
                self.dropout2,
            )
            x = x + featurewise_affine(x_2, self.film2(t))
            if memory2 is not None:
                # cross-attention x2 -> film -> residual
                x_2a = self._mha_block(
                    self.norm2a(x),
                    memory2,
                    memory_mask,
                    memory_key_padding_mask,
                    self.multihead_attn2,
                    self.dropout2a,
                )
                x = x + featurewise_affine(x_2a, self.film2a(t))
            # feedforward -> film -> residual
            x_3 = self._ff_block(self.norm3(x))
            x = x + featurewise_affine(x_3, self.film3(t))
        else:
            # Post-norm variant. NOTE(review): memory2 is ignored on this
            # path, matching the original structure -- confirm if intended.
            x = self.norm1(
                x
                + featurewise_affine(
                    self._sa_block(x, tgt_mask, tgt_key_padding_mask), self.film1(t)
                )
            )
            # Fixed: supply the attention module and dropout explicitly
            # (previously omitted, which raised TypeError).
            x = self.norm2(
                x
                + featurewise_affine(
                    self._mha_block(
                        x,
                        memory,
                        memory_mask,
                        memory_key_padding_mask,
                        self.multihead_attn,
                        self.dropout2,
                    ),
                    self.film2(t),
                )
            )
            x = self.norm3(x + featurewise_affine(self._ff_block(x), self.film3(t)))
        return x

    # self-attention block
    # qkv
    def _sa_block(self, x, attn_mask, key_padding_mask):
        """Self-attention sublayer; rotates q/k when rotary is enabled."""
        qk = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
        x = self.self_attn(
            qk,
            qk,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )[0]
        return self.dropout1(x)

    # multihead attention block
    # qkv
    def _mha_block(self, x, mem, attn_mask, key_padding_mask, mha, dropout):
        """Cross-attention sublayer using the given attention/dropout modules."""
        q = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
        k = self.rotary.rotate_queries_or_keys(mem) if self.use_rotary else mem
        x = mha(
            q,
            k,
            mem,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )[0]
        return dropout(x)

    # feed forward block
    def _ff_block(self, x):
        """Feed-forward sublayer: linear -> activation -> dropout -> linear."""
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout3(x)
class DecoderLayerStack(nn.Module):
    """Applies a sequence of decoder layers, threading the output of each
    into the next while sharing the same conditioning and masks."""

    def __init__(self, stack):
        super().__init__()
        self.stack = stack

    def forward(self, x, cond, t, tgt_mask=None, memory2=None):
        """Return ``x`` after every layer in the stack has processed it."""
        out = x
        for decoder_layer in self.stack:
            out = decoder_layer(out, cond, t, tgt_mask=tgt_mask, memory2=memory2)
        return out
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding (sin on even indices, cos on odd)
    added to the input, followed by dropout.

    Generalized to support odd ``d_model``: the original assignment
    ``pe[:, 1::2] = cos(...)`` fails for odd ``d_model`` because there is one
    fewer odd column than frequency; the cosine table is now sliced to
    ``d_model // 2`` columns. Even ``d_model`` behavior is unchanged.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 1024):
        """
        :param d_model: feature dimension of the inputs
        :param dropout: dropout probability applied after adding the encoding
        :param max_len: maximum sequence length supported
        """
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        # Geometric frequency progression, one frequency per sin/cos pair.
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer("pe", pe)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x: torch.Tensor):
        """
        :param x: B x T x d_model tensor
        :return: B x T x d_model tensor
        """
        x = x + self.pe[None, : x.shape[1], :]
        x = self.dropout(x)
        return x
class TimestepEncoding(nn.Module):
    """Embeds scalar timesteps via fixed Fourier features plus a small MLP."""

    def __init__(self, embedding_dim: int):
        super().__init__()
        # Fixed, geometrically spaced Fourier frequencies (non-trainable buffer).
        half_dim = embedding_dim // 2
        step = math.log(10000) / (half_dim - 1)
        freqs = torch.exp(torch.arange(half_dim) * -step)
        self.register_buffer("emb", freqs)
        # MLP mapping the sin/cos features to the final encoding.
        self.encoding = nn.Sequential(
            nn.Linear(embedding_dim, 4 * embedding_dim),
            nn.Mish(),
            nn.Linear(4 * embedding_dim, embedding_dim),
        )

    def forward(self, t: torch.Tensor):
        """
        :param t: B-dimensional tensor containing timesteps in range [0, 1]
        :return: B x embedding_dim tensor containing timestep encodings
        """
        phases = t[:, None] * self.emb[None, :]
        fourier = torch.cat([torch.sin(phases), torch.cos(phases)], dim=-1)
        return self.encoding(fourier)
class FiLM(nn.Module):
    """Feature-wise linear modulation whose scale and bias are produced from
    a conditioning tensor of the same feature width as the input."""

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
        self.film = nn.Sequential(nn.Mish(), nn.Linear(dim, dim * 2))

    def forward(self, x: torch.Tensor, cond: torch.Tensor):
        """
        :param x: ... x dim tensor
        :param cond: ... x dim tensor
        :return: ... x dim tensor as scale(cond) * x + bias(cond)
        """
        scale, bias = torch.chunk(self.film(cond), chunks=2, dim=-1)
        return x * (scale + 1) + bias
class FeedforwardBlock(nn.Module):
    """Position-wise two-layer MLP: Linear -> ReLU -> Dropout -> Linear ->
    Dropout, preserving the d_model feature width."""

    def __init__(self, d_model: int, d_feedforward: int = 1024, dropout: float = 0.1):
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_feedforward),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(d_feedforward, d_model),
            nn.Dropout(p=dropout),
        )

    def forward(self, x: torch.Tensor):
        """
        :param x: ... x d_model tensor
        :return: ... x d_model tensor
        """
        out = self.ff(x)
        return out
class SelfAttention(nn.Module):
    """Batch-first multi-head self-attention sublayer followed by dropout."""

    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            d_model, num_heads, dropout=dropout, batch_first=True
        )
        self.dropout = nn.Dropout(p=dropout)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: torch.Tensor = None,
        key_padding_mask: torch.Tensor = None,
    ):
        """
        :param x: B x T x d_model input tensor
        :param attn_mask: B * num_heads x L x S mask (L=target length,
            S=source length); float masks are added to attention weights,
            boolean True blocks attention
        :param key_padding_mask: B x S mask; float masks are added to key
            values, boolean True ignores the key
        :return: B x T x d_model output tensor
        """
        attn_out, _ = self.self_attn(
            x,
            x,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        return self.dropout(attn_out)
class CrossAttention(nn.Module):
    """Batch-first multi-head cross-attention from a d_model query stream
    onto a d_cond key/value stream, followed by dropout."""

    def __init__(self, d_model: int, d_cond: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(
            d_model,
            num_heads,
            dropout=dropout,
            batch_first=True,
            kdim=d_cond,
            vdim=d_cond,
        )
        self.dropout = nn.Dropout(p=dropout)

    def forward(
        self,
        x: torch.Tensor,
        cond: torch.Tensor,
        attn_mask: torch.Tensor = None,
        key_padding_mask: torch.Tensor = None,
    ):
        """
        :param x: B x T_target x d_model query tensor
        :param cond: B x T_cond x d_cond key/value tensor
        :param attn_mask: B * num_heads x L x S mask (L=target length,
            S=source length); float masks are added to attention weights,
            boolean True blocks attention
        :param key_padding_mask: B x S mask; float masks are added to key
            values, boolean True ignores the key
        :return: B x T_target x d_model output tensor
        """
        attn_out, _ = self.cross_attn(
            x,
            cond,
            cond,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        return self.dropout(attn_out)
class TransformerEncoderLayer(nn.Module):
    """Pre-norm encoder layer: LayerNorm -> self-attention -> residual, then
    LayerNorm -> feed-forward -> residual."""

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_feedforward: int = 1024,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.self_attn = SelfAttention(d_model, num_heads, dropout)
        self.norm2 = nn.LayerNorm(d_model)
        self.feedforward = FeedforwardBlock(d_model, d_feedforward, dropout)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor = None,
        key_padding_mask: torch.Tensor = None,
    ):
        """Return x after residual self-attention and feed-forward blocks."""
        out = x + self.self_attn(self.norm1(x), mask, key_padding_mask)
        return out + self.feedforward(self.norm2(out))
class TransformerDecoderLayer(nn.Module):
    """Pre-norm decoder layer: self-attention, cross-attention over a
    conditioning sequence, then feed-forward, each with a residual."""

    def __init__(
        self,
        d_model: int,
        d_cond: int,
        num_heads: int,
        d_feedforward: int = 1024,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.self_attn = SelfAttention(d_model, num_heads, dropout)
        self.norm2 = nn.LayerNorm(d_model)
        self.cross_attn = CrossAttention(d_model, d_cond, num_heads, dropout)
        self.norm3 = nn.LayerNorm(d_model)
        self.feedforward = FeedforwardBlock(d_model, d_feedforward, dropout)

    def forward(
        self,
        x: torch.Tensor,
        cross_cond: torch.Tensor,
        target_mask: torch.Tensor = None,
        target_key_padding_mask: torch.Tensor = None,
        cross_cond_mask: torch.Tensor = None,
        cross_cond_key_padding_mask: torch.Tensor = None,
    ):
        """
        :param x: B x T x d_model tensor
        :param cross_cond: B x T x d_cond conditioning input to cross-attention
        :return: B x T x d_model tensor
        """
        out = x + self.self_attn(self.norm1(x), target_mask, target_key_padding_mask)
        out = out + self.cross_attn(
            self.norm2(out), cross_cond, cross_cond_mask, cross_cond_key_padding_mask
        )
        return out + self.feedforward(self.norm3(out))
class FilmTransformerDecoderLayer(nn.Module):
    """Pre-norm decoder layer whose three sublayers (self-attention,
    cross-attention, feed-forward) are each FiLM-modulated by ``film_cond``
    before the residual addition."""

    def __init__(
        self,
        d_model: int,
        d_cond: int,
        num_heads: int,
        d_feedforward: int = 1024,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.self_attn = SelfAttention(d_model, num_heads, dropout)
        self.film1 = FiLM(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.cross_attn = CrossAttention(d_model, d_cond, num_heads, dropout)
        self.film2 = FiLM(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.feedforward = FeedforwardBlock(d_model, d_feedforward, dropout)
        self.film3 = FiLM(d_model)

    def forward(
        self,
        x: torch.Tensor,
        cross_cond: torch.Tensor,
        film_cond: torch.Tensor,
        target_mask: torch.Tensor = None,
        target_key_padding_mask: torch.Tensor = None,
        cross_cond_mask: torch.Tensor = None,
        cross_cond_key_padding_mask: torch.Tensor = None,
    ):
        """
        :param x: B x T x d_model tensor
        :param cross_cond: B x T x d_cond conditioning input to cross-attention
        :param film_cond: B x [1 or T] x d_model conditioning input to FiLM
        :return: B x T x d_model tensor
        """
        sa_out = self.self_attn(self.norm1(x), target_mask, target_key_padding_mask)
        x = x + self.film1(sa_out, film_cond)
        ca_out = self.cross_attn(
            self.norm2(x), cross_cond, cross_cond_mask, cross_cond_key_padding_mask
        )
        x = x + self.film2(ca_out, film_cond)
        ff_out = self.feedforward(self.norm3(x))
        return x + self.film3(ff_out, film_cond)
class RegressionTransformer(nn.Module):
    """Encoder-decoder transformer mapping a conditioning sequence and an
    input sequence to an output sequence, with optional causal masking."""

    def __init__(
        self,
        transformer_encoder_layers: int = 2,
        transformer_decoder_layers: int = 4,
        d_model: int = 512,
        d_cond: int = 512,
        num_heads: int = 4,
        d_feedforward: int = 1024,
        dropout: float = 0.1,
        causal: bool = False,
    ):
        super().__init__()
        self.causal = causal
        self.cond_positional_encoding = PositionalEncoding(d_cond, dropout)
        self.target_positional_encoding = PositionalEncoding(d_model, dropout)
        self.transformer_encoder = nn.ModuleList(
            TransformerEncoderLayer(d_cond, num_heads, d_feedforward, dropout)
            for _ in range(transformer_encoder_layers)
        )
        self.transformer_decoder = nn.ModuleList(
            TransformerDecoderLayer(d_model, d_cond, num_heads, d_feedforward, dropout)
            for _ in range(transformer_decoder_layers)
        )

    def _causal_masks(self, x, cond):
        """Return (encoder, decoder-self, decoder-cross) masks, or all None
        when the model is not causal."""
        if not self.causal:
            return None, None, None
        enc = generate_causal_mask(cond.shape[1], cond.shape[1], device=cond.device)
        dec_self = generate_causal_mask(x.shape[1], x.shape[1], device=x.device)
        dec_cross = generate_causal_mask(cond.shape[1], x.shape[1], device=x.device)
        return enc, dec_self, dec_cross

    def forward(self, x: torch.Tensor, cond: torch.Tensor):
        """
        :param x: B x T x d_model input tensor
        :param cond: B x T x d_cond conditional tensor
        :return: B x T x d_model output tensor
        """
        x = self.target_positional_encoding(x)
        cond = self.cond_positional_encoding(cond)
        enc_mask, dec_self_mask, dec_cross_mask = self._causal_masks(x, cond)
        for encoder_layer in self.transformer_encoder:
            cond = encoder_layer(cond, mask=enc_mask)
        for decoder_layer in self.transformer_decoder:
            x = decoder_layer(
                x,
                cond,
                target_mask=dec_self_mask,
                cross_cond_mask=dec_cross_mask,
            )
        return x
class DiffusionTransformer(nn.Module):
    """Encoder-decoder transformer for diffusion models: the decoder layers
    are FiLM-conditioned on an encoding of the diffusion timestep."""

    def __init__(
        self,
        transformer_encoder_layers: int = 2,
        transformer_decoder_layers: int = 4,
        d_model: int = 512,
        d_cond: int = 512,
        num_heads: int = 4,
        d_feedforward: int = 1024,
        dropout: float = 0.1,
        causal: bool = False,
    ):
        super().__init__()
        self.causal = causal
        self.timestep_encoder = TimestepEncoding(d_model)
        self.cond_positional_encoding = PositionalEncoding(d_cond, dropout)
        self.target_positional_encoding = PositionalEncoding(d_model, dropout)
        self.transformer_encoder = nn.ModuleList(
            TransformerEncoderLayer(d_cond, num_heads, d_feedforward, dropout)
            for _ in range(transformer_encoder_layers)
        )
        self.transformer_decoder = nn.ModuleList(
            FilmTransformerDecoderLayer(
                d_model, d_cond, num_heads, d_feedforward, dropout
            )
            for _ in range(transformer_decoder_layers)
        )

    def _causal_masks(self, x, cond):
        """Return (encoder, decoder-self, decoder-cross) masks, or all None
        when the model is not causal."""
        if not self.causal:
            return None, None, None
        enc = generate_causal_mask(cond.shape[1], cond.shape[1], device=cond.device)
        dec_self = generate_causal_mask(x.shape[1], x.shape[1], device=x.device)
        dec_cross = generate_causal_mask(cond.shape[1], x.shape[1], device=x.device)
        return enc, dec_self, dec_cross

    def forward(self, x: torch.Tensor, cond: torch.Tensor, t: torch.Tensor):
        """
        :param x: B x T x d_model input tensor
        :param cond: B x T x d_cond conditional tensor
        :param t: B-dimensional tensor of diffusion timesteps in range [0, 1]
        :return: B x T x d_model output tensor
        """
        film_cond = self.timestep_encoder(t).unsqueeze(1)  # B x 1 x d_model
        x = self.target_positional_encoding(x)
        cond = self.cond_positional_encoding(cond)
        enc_mask, dec_self_mask, dec_cross_mask = self._causal_masks(x, cond)
        for encoder_layer in self.transformer_encoder:
            cond = encoder_layer(cond, mask=enc_mask)
        for decoder_layer in self.transformer_decoder:
            x = decoder_layer(
                x,
                cond,
                film_cond,
                target_mask=dec_self_mask,
                cross_cond_mask=dec_cross_mask,
            )
        return x