Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| from torch import Tensor | |
| import numpy as np | |
| from einops import rearrange | |
| from typing import Optional, List | |
| from torchtyping import TensorType | |
| from einops._torch_specific import allow_ops_in_compiled_graph # requires einops>=0.6.1 | |
# Invoke the einops hook imported above (that import notes it needs
# einops>=0.6.1) so einops operations can be used inside compiled graphs.
allow_ops_in_compiled_graph()
# Module-level placeholders. Presumably they exist so the names used inside the
# TensorType["batch_size", "num_cond_feats"] annotations resolve for linters
# -- TODO confirm.
batch_size, num_cond_feats = None, None
class FusedMLP(nn.Sequential):
    """Feed-forward stack: Linear -> activation -> Dropout -> Linear.

    The hidden width is ``dim_model * hidden_layer_multiplier`` and the output
    width is ``dim_model``. ``activation`` is called with no arguments to build
    the activation layer, so callers pass a module class such as ``nn.GELU``.
    """

    def __init__(
        self,
        dim_model: int,
        dropout: float,
        activation: nn.Module,
        hidden_layer_multiplier: int = 4,
        bias: bool = True,
    ):
        hidden_dim = dim_model * hidden_layer_multiplier
        # Layers are instantiated in application order, so parameters are
        # initialised in the same sequence as a single inline construction.
        expand = nn.Linear(dim_model, hidden_dim, bias=bias)
        nonlinearity = activation()
        drop = nn.Dropout(dropout)
        contract = nn.Linear(hidden_dim, dim_model, bias=bias)
        super().__init__(expand, nonlinearity, drop, contract)
def _cast_if_autocast_enabled(tensor):
    """Return ``tensor`` cast to the active autocast dtype, if autocast is on.

    When ``torch.is_autocast_enabled()`` is false the input is returned
    untouched. Otherwise the target dtype is looked up from the tensor's device
    type ("cuda" or "cpu"); any other device type raises NotImplementedError.
    """
    if not torch.is_autocast_enabled():
        return tensor

    device_kind = tensor.device.type
    if device_kind == "cuda":
        target_dtype = torch.get_autocast_gpu_dtype()
    elif device_kind == "cpu":
        target_dtype = torch.get_autocast_cpu_dtype()
    else:
        raise NotImplementedError()
    return tensor.to(dtype=target_dtype)
class LayerNorm16Bits(torch.nn.LayerNorm):
    """
    16-bit friendly version of torch.nn.LayerNorm.

    The input and the affine parameters (when present) are passed through
    ``_cast_if_autocast_enabled`` and the normalisation itself is executed
    inside a region where autocast is disabled.
    """

    def __init__(
        self,
        normalized_shape,
        eps=1e-06,
        elementwise_affine=True,
        device=None,
        dtype=None,
    ):
        init_kwargs = {
            "normalized_shape": normalized_shape,
            "eps": eps,
            "elementwise_affine": elementwise_affine,
            "device": device,
            "dtype": dtype,
        }
        super().__init__(**init_kwargs)

    def forward(self, x):
        device_type = x.device.type
        inputs = _cast_if_autocast_enabled(x)

        # weight/bias are None when elementwise_affine is False.
        weight = self.weight
        if weight is not None:
            weight = _cast_if_autocast_enabled(weight)
        bias = self.bias
        if bias is not None:
            bias = _cast_if_autocast_enabled(bias)

        # The casts above already chose the dtype; keep autocast from
        # re-deciding it for the layer_norm call.
        with torch.autocast(enabled=False, device_type=device_type):
            return nn.functional.layer_norm(
                inputs, self.normalized_shape, weight, bias, self.eps
            )
class StochatichDepth(nn.Module):
    """Per-sample stochastic depth.

    In training mode each sample along dim 0 is either zeroed or kept; kept
    samples are divided by the survival probability ``1 - p``. In eval mode, or
    when ``p`` is 0, the input is returned unchanged.

    Args:
        p: probability of dropping a sample.
    """

    def __init__(self, p: float):
        super().__init__()
        self.survival_prob = 1.0 - p

    def forward(self, x: Tensor) -> Tensor:
        if not (self.training and self.survival_prob < 1):
            return x

        # One draw per sample, with a singleton for every trailing dimension so
        # the mask broadcasts correctly for inputs of any rank. A fixed
        # (B, 1, 1) shape would broadcast a 2-D input to (B, B, D).
        mask_shape = (x.shape[0],) + (1,) * (x.dim() - 1)
        mask = torch.empty(mask_shape, device=x.device).uniform_()
        # uniform_ samples [0, 1), so floor(u + survival_prob) is 1 with
        # probability survival_prob and 0 otherwise.
        mask = (mask + self.survival_prob).floor()
        if self.survival_prob > 0:
            mask = mask / self.survival_prob
        return x * mask
class CrossAttentionOp(nn.Module):
    """Multi-head attention built on ``scaled_dot_product_attention``.

    With ``is_sa=True`` one fused ``qkv`` projection of ``x_to`` is used.
    Otherwise queries are projected from ``x_to`` and keys/values from
    ``x_from``. The attended result is projected back to ``dim_q`` features.
    """

    def __init__(
        self, attention_dim, num_heads, dim_q, dim_kv, use_biases=True, is_sa=False
    ):
        super().__init__()
        self.dim_q, self.dim_kv = dim_q, dim_kv
        self.attention_dim = attention_dim
        self.num_heads = num_heads
        self.use_biases = use_biases
        self.is_sa = is_sa
        if is_sa:
            self.qkv = nn.Linear(dim_q, attention_dim * 3, bias=use_biases)
        else:
            self.q = nn.Linear(dim_q, attention_dim, bias=use_biases)
            self.kv = nn.Linear(dim_kv, attention_dim * 2, bias=use_biases)
        self.out = nn.Linear(attention_dim, dim_q, bias=use_biases)

    def forward(self, x_to, x_from=None, attention_mask=None):
        if self.is_sa:
            q, k, v = self.qkv(x_to).chunk(3, dim=-1)
        else:
            # Without an explicit source, keys/values come from x_to itself.
            source = x_to if x_from is None else x_from
            q = self.q(x_to)
            k, v = self.kv(source).chunk(2, dim=-1)

        # Split the feature axis into heads: (b, n, h*d) -> (b, h, n, d).
        q, k, v = (
            rearrange(t, "b n (h d) -> b h n d", h=self.num_heads) for t in (q, k, v)
        )

        # A singleton head axis lets one mask broadcast over all heads.
        mask = None if attention_mask is None else attention_mask.unsqueeze(1)
        attended = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=mask
        )
        merged = rearrange(attended, "b h n d -> b n (h d)")
        return self.out(merged)
class CrossAttentionBlock(nn.Module):
    """Residual cross-attention layer followed by a residual feed-forward layer.

    ``to_tokens`` are layer-normalised and attend to ``from_tokens``; the result
    is added back to ``to_tokens``. A second layer norm and a ``FusedMLP`` form
    the feed-forward residual. Both residual branches pass through stochastic
    depth.
    """

    def __init__(
        self,
        dim_q: int,
        dim_kv: int,
        num_heads: int,
        attention_dim: int = 0,
        mlp_multiplier: int = 4,
        dropout: float = 0.0,
        stochastic_depth: float = 0.0,
        use_biases: bool = True,
        retrieve_attention_scores: bool = False,
        use_layernorm16: bool = True,
    ):
        super().__init__()
        # The 16-bit layer norm is used only when it is requested and attention
        # scores are not being retrieved.
        if use_layernorm16 and not retrieve_attention_scores:
            norm_cls = LayerNorm16Bits
        else:
            norm_cls = nn.LayerNorm
        self.retrieve_attention_scores = retrieve_attention_scores
        self.initial_to_ln = norm_cls(dim_q, eps=1e-6)

        # attention_dim == 0 means "use the narrower of the two streams".
        if attention_dim == 0:
            attention_dim = min(dim_q, dim_kv)
        self.ca = CrossAttentionOp(
            attention_dim, num_heads, dim_q, dim_kv, is_sa=False, use_biases=use_biases
        )
        self.ca_stochastic_depth = StochatichDepth(stochastic_depth)
        self.middle_ln = norm_cls(dim_q, eps=1e-6)
        self.ffn = FusedMLP(
            dim_model=dim_q,
            dropout=dropout,
            activation=nn.GELU,
            hidden_layer_multiplier=mlp_multiplier,
            bias=use_biases,
        )
        self.ffn_stochastic_depth = StochatichDepth(stochastic_depth)

    def forward(
        self,
        to_tokens: Tensor,
        from_tokens: Tensor,
        to_token_mask: Optional[Tensor] = None,
        from_token_mask: Optional[Tensor] = None,
    ) -> Tensor:
        attention_mask = None
        if to_token_mask is not None or from_token_mask is not None:
            # A missing mask is replaced by an all-True mask of shape (B, N).
            if to_token_mask is None:
                to_token_mask = torch.ones(
                    to_tokens.shape[0],
                    to_tokens.shape[1],
                    dtype=torch.bool,
                    device=to_tokens.device,
                )
            if from_token_mask is None:
                from_token_mask = torch.ones(
                    from_tokens.shape[0],
                    from_tokens.shape[1],
                    dtype=torch.bool,
                    device=from_tokens.device,
                )
            # Outer product of the two masks: (B, N_to, N_from).
            key_part = from_token_mask.unsqueeze(1)
            query_part = to_token_mask.unsqueeze(2)
            attention_mask = key_part * query_part

        normed = self.initial_to_ln(to_tokens)
        attended = self.ca(normed, from_tokens, attention_mask=attention_mask)
        to_tokens = to_tokens + self.ca_stochastic_depth(attended)

        feed_forward = self.ffn(self.middle_ln(to_tokens))
        to_tokens = to_tokens + self.ffn_stochastic_depth(feed_forward)
        return to_tokens
class SelfAttentionBlock(nn.Module):
    """Pre-norm self-attention layer followed by a feed-forward layer.

    Each residual branch passes through stochastic depth. When
    ``use_layer_scale`` is set, each branch is also multiplied by a learned
    per-channel vector initialised to ``layer_scale_value``.
    """

    def __init__(
        self,
        dim_qkv: int,
        num_heads: int,
        attention_dim: int = 0,
        mlp_multiplier: int = 4,
        dropout: float = 0.0,
        stochastic_depth: float = 0.0,
        use_biases: bool = True,
        use_layer_scale: bool = False,
        layer_scale_value: float = 0.0,
        use_layernorm16: bool = True,
    ):
        super().__init__()
        norm_cls = LayerNorm16Bits if use_layernorm16 else nn.LayerNorm
        self.initial_ln = norm_cls(dim_qkv, eps=1e-6)

        # attention_dim == 0 means "same width as the tokens".
        if attention_dim == 0:
            attention_dim = dim_qkv
        self.sa = CrossAttentionOp(
            attention_dim,
            num_heads,
            dim_qkv,
            dim_qkv,
            is_sa=True,
            use_biases=use_biases,
        )
        self.sa_stochastic_depth = StochatichDepth(stochastic_depth)
        self.middle_ln = norm_cls(dim_qkv, eps=1e-6)
        self.ffn = FusedMLP(
            dim_model=dim_qkv,
            dropout=dropout,
            activation=nn.GELU,
            hidden_layer_multiplier=mlp_multiplier,
            bias=use_biases,
        )
        self.ffn_stochastic_depth = StochatichDepth(stochastic_depth)

        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale_1 = nn.Parameter(
                torch.ones(dim_qkv) * layer_scale_value, requires_grad=True
            )
            self.layer_scale_2 = nn.Parameter(
                torch.ones(dim_qkv) * layer_scale_value, requires_grad=True
            )

    def forward(
        self,
        tokens: torch.Tensor,
        token_mask: Optional[torch.Tensor] = None,
    ):
        attention_mask = None
        if token_mask is not None:
            # Repeat the key mask for every query row: (B, 1, N) * (B, N, 1).
            all_queries = torch.ones(
                tokens.shape[0],
                tokens.shape[1],
                1,
                dtype=torch.bool,
                device=tokens.device,
            )
            attention_mask = token_mask.unsqueeze(1) * all_queries

        attended = self.sa(self.initial_ln(tokens), attention_mask=attention_mask)
        if self.use_layer_scale:
            attended = self.layer_scale_1 * attended
        tokens = tokens + self.sa_stochastic_depth(attended)

        feed_forward = self.ffn(self.middle_ln(tokens))
        if self.use_layer_scale:
            feed_forward = self.layer_scale_2 * feed_forward
        tokens = tokens + self.ffn_stochastic_depth(feed_forward)
        return tokens
class AdaLNSABlock(nn.Module):
    """Self-attention + feed-forward layer modulated by a conditioning vector.

    ``cond`` is mapped (SiLU then Linear) to six chunks along the last axis:
    shift/scale/gate for the attention branch and shift/scale/gate for the
    feed-forward branch. The modulation Linear is zero-initialised, so every
    gate starts at zero.

    Args:
        dim_qkv: feature width of the tokens.
        dim_cond: feature width of the conditioning vector.
        num_heads: number of attention heads.
        attention_dim: inner attention width; 0 means ``dim_qkv``.
        mlp_multiplier: hidden-width multiplier of the feed-forward MLP.
        dropout: dropout rate inside the MLP.
        stochastic_depth: drop probability of each residual branch.
        use_biases: whether the Linear layers carry biases.
        use_layer_scale: multiply each branch by a learned per-channel scale.
        layer_scale_value: initial value of those scales.
        use_layernorm16: use ``LayerNorm16Bits`` instead of ``nn.LayerNorm``.
    """

    def __init__(
        self,
        dim_qkv: int,
        dim_cond: int,
        num_heads: int,
        attention_dim: int = 0,
        mlp_multiplier: int = 4,
        dropout: float = 0.0,
        stochastic_depth: float = 0.0,
        use_biases: bool = True,
        use_layer_scale: bool = False,
        layer_scale_value: float = 0.1,
        use_layernorm16: bool = True,
    ):
        super().__init__()
        layer_norm = LayerNorm16Bits if use_layernorm16 else nn.LayerNorm
        # No affine parameters: shift and scale come from the modulation.
        self.initial_ln = layer_norm(dim_qkv, eps=1e-6, elementwise_affine=False)
        attention_dim = dim_qkv if attention_dim == 0 else attention_dim
        self.adaln_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim_cond, dim_qkv * 6, bias=use_biases),
        )
        # Zero init. The bias does not exist when use_biases is False, and
        # nn.init.zeros_ cannot be applied to None.
        modulation_linear = self.adaln_modulation[1]
        nn.init.zeros_(modulation_linear.weight)
        if modulation_linear.bias is not None:
            nn.init.zeros_(modulation_linear.bias)
        self.sa = CrossAttentionOp(
            attention_dim,
            num_heads,
            dim_qkv,
            dim_qkv,
            is_sa=True,
            use_biases=use_biases,
        )
        self.sa_stochastic_depth = StochatichDepth(stochastic_depth)
        self.middle_ln = layer_norm(dim_qkv, eps=1e-6, elementwise_affine=False)
        self.ffn = FusedMLP(
            dim_model=dim_qkv,
            dropout=dropout,
            activation=nn.GELU,
            hidden_layer_multiplier=mlp_multiplier,
            bias=use_biases,
        )
        self.ffn_stochastic_depth = StochatichDepth(stochastic_depth)
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale_1 = nn.Parameter(
                torch.ones(dim_qkv) * layer_scale_value, requires_grad=True
            )
            self.layer_scale_2 = nn.Parameter(
                torch.ones(dim_qkv) * layer_scale_value, requires_grad=True
            )

    def forward(
        self,
        tokens: torch.Tensor,
        cond: torch.Tensor,
        token_mask: Optional[torch.Tensor] = None,
    ):
        if token_mask is None:
            attention_mask = None
        else:
            # Repeat the key mask for every query row: (B, 1, N) * (B, N, 1).
            attention_mask = token_mask.unsqueeze(1) * torch.ones(
                tokens.shape[0],
                tokens.shape[1],
                1,
                dtype=torch.bool,
                device=tokens.device,
            )
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.adaln_modulation(cond).chunk(6, dim=-1)
        )
        attention_output = self.sa(
            modulate_shift_and_scale(self.initial_ln(tokens), shift_msa, scale_msa),
            attention_mask=attention_mask,
        )
        # The gates get a singleton token axis so they broadcast over tokens.
        if self.use_layer_scale:
            tokens = tokens + self.sa_stochastic_depth(
                gate_msa.unsqueeze(1) * self.layer_scale_1 * attention_output
            )
            tokens = tokens + self.ffn_stochastic_depth(
                gate_mlp.unsqueeze(1)
                * self.layer_scale_2
                * self.ffn(
                    modulate_shift_and_scale(
                        self.middle_ln(tokens), shift_mlp, scale_mlp
                    )
                )
            )
        else:
            tokens = tokens + gate_msa.unsqueeze(1) * self.sa_stochastic_depth(
                attention_output
            )
            tokens = tokens + self.ffn_stochastic_depth(
                gate_mlp.unsqueeze(1)
                * self.ffn(
                    modulate_shift_and_scale(
                        self.middle_ln(tokens), shift_mlp, scale_mlp
                    )
                )
            )
        return tokens
class CrossAttentionSABlock(nn.Module):
    """Cross-attention to ``cond``, then self-attention, then a feed-forward layer.

    All three stages are pre-norm residual branches with stochastic depth. The
    self-attention and feed-forward branches can additionally be scaled by
    learned per-channel vectors (``use_layer_scale``).

    Args:
        dim_qkv: feature width of the tokens.
        dim_cond: feature width of the conditioning tokens.
        num_heads: number of attention heads.
        attention_dim: inner attention width; 0 means ``dim_qkv``.
        mlp_multiplier: hidden-width multiplier of the feed-forward MLP.
        dropout: dropout rate inside the MLP.
        stochastic_depth: drop probability of each residual branch.
        use_biases: whether the Linear layers carry biases.
        use_layer_scale: multiply the SA/FFN branches by a learned scale.
        layer_scale_value: initial value of those scales.
        use_layernorm16: use ``LayerNorm16Bits`` instead of ``nn.LayerNorm``.
    """

    def __init__(
        self,
        dim_qkv: int,
        dim_cond: int,
        num_heads: int,
        attention_dim: int = 0,
        mlp_multiplier: int = 4,
        dropout: float = 0.0,
        stochastic_depth: float = 0.0,
        use_biases: bool = True,
        use_layer_scale: bool = False,
        layer_scale_value: float = 0.0,
        use_layernorm16: bool = True,
    ):
        super().__init__()
        layer_norm = LayerNorm16Bits if use_layernorm16 else nn.LayerNorm
        # Resolved once; both attention ops share the same inner width.
        attention_dim = dim_qkv if attention_dim == 0 else attention_dim
        self.ca = CrossAttentionOp(
            attention_dim,
            num_heads,
            dim_qkv,
            dim_cond,
            is_sa=False,
            use_biases=use_biases,
        )
        self.ca_stochastic_depth = StochatichDepth(stochastic_depth)
        self.ca_ln = layer_norm(dim_qkv, eps=1e-6)
        self.initial_ln = layer_norm(dim_qkv, eps=1e-6)
        self.sa = CrossAttentionOp(
            attention_dim,
            num_heads,
            dim_qkv,
            dim_qkv,
            is_sa=True,
            use_biases=use_biases,
        )
        self.sa_stochastic_depth = StochatichDepth(stochastic_depth)
        self.middle_ln = layer_norm(dim_qkv, eps=1e-6)
        self.ffn = FusedMLP(
            dim_model=dim_qkv,
            dropout=dropout,
            activation=nn.GELU,
            hidden_layer_multiplier=mlp_multiplier,
            bias=use_biases,
        )
        self.ffn_stochastic_depth = StochatichDepth(stochastic_depth)
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale_1 = nn.Parameter(
                torch.ones(dim_qkv) * layer_scale_value, requires_grad=True
            )
            self.layer_scale_2 = nn.Parameter(
                torch.ones(dim_qkv) * layer_scale_value, requires_grad=True
            )

    def forward(
        self,
        tokens: torch.Tensor,
        cond: torch.Tensor,
        token_mask: Optional[torch.Tensor] = None,
        cond_mask: Optional[torch.Tensor] = None,
    ):
        # Cross-attention mask of shape (B, N_tokens, N_cond). Without a
        # cond_mask no cross-attention masking is applied. With one, it masks
        # the condition keys, combined with token_mask on the query side when
        # that is available. The previous code ignored cond_mask's values and
        # failed when token_mask was None.
        if cond_mask is None:
            cond_attention_mask = None
        elif token_mask is None:
            cond_attention_mask = cond_mask.unsqueeze(1) * torch.ones(
                tokens.shape[0],
                tokens.shape[1],
                1,
                dtype=torch.bool,
                device=tokens.device,
            )
        else:
            cond_attention_mask = cond_mask.unsqueeze(1) * token_mask.unsqueeze(2)

        if token_mask is None:
            attention_mask = None
        else:
            # Repeat the key mask for every query row: (B, 1, N) * (B, N, 1).
            attention_mask = token_mask.unsqueeze(1) * torch.ones(
                tokens.shape[0],
                tokens.shape[1],
                1,
                dtype=torch.bool,
                device=tokens.device,
            )
        ca_output = self.ca(
            self.ca_ln(tokens),
            cond,
            attention_mask=cond_attention_mask,
        )
        ca_output = torch.nan_to_num(
            ca_output, nan=0.0, posinf=0.0, neginf=0.0
        )  # Needed as some tokens get attention from no token so Nan
        tokens = tokens + self.ca_stochastic_depth(ca_output)
        attention_output = self.sa(
            self.initial_ln(tokens),
            attention_mask=attention_mask,
        )
        if self.use_layer_scale:
            tokens = tokens + self.sa_stochastic_depth(
                self.layer_scale_1 * attention_output
            )
            tokens = tokens + self.ffn_stochastic_depth(
                self.layer_scale_2 * self.ffn(self.middle_ln(tokens))
            )
        else:
            tokens = tokens + self.sa_stochastic_depth(attention_output)
            tokens = tokens + self.ffn_stochastic_depth(
                self.ffn(self.middle_ln(tokens))
            )
        return tokens
class CAAdaLNSABlock(nn.Module):
    """Cross-attention to ``cond_1``, then an AdaLN-modulated SA/FFN layer driven by ``cond_2``.

    ``cond_2`` is mapped (SiLU then Linear) to shift/scale/gate triples for the
    self-attention and feed-forward branches. The modulation Linear is
    zero-initialised, so every gate starts at zero.

    Args:
        dim_qkv: feature width of the tokens.
        dim_cond: feature width of both conditioning inputs.
        num_heads: number of attention heads.
        attention_dim: inner attention width; 0 means ``dim_qkv``.
        mlp_multiplier: hidden-width multiplier of the feed-forward MLP.
        dropout: dropout rate inside the MLP.
        stochastic_depth: drop probability of each residual branch.
        use_biases: whether the Linear layers carry biases.
        use_layer_scale: multiply the SA/FFN branches by a learned scale.
        layer_scale_value: initial value of those scales.
        use_layernorm16: use ``LayerNorm16Bits`` instead of ``nn.LayerNorm``.
    """

    def __init__(
        self,
        dim_qkv: int,
        dim_cond: int,
        num_heads: int,
        attention_dim: int = 0,
        mlp_multiplier: int = 4,
        dropout: float = 0.0,
        stochastic_depth: float = 0.0,
        use_biases: bool = True,
        use_layer_scale: bool = False,
        layer_scale_value: float = 0.1,
        use_layernorm16: bool = True,
    ):
        super().__init__()
        layer_norm = LayerNorm16Bits if use_layernorm16 else nn.LayerNorm
        # Resolve the default before building any attention op. Doing it after
        # the cross-attention op gave that op a zero-width projection whenever
        # attention_dim was left at 0.
        attention_dim = dim_qkv if attention_dim == 0 else attention_dim
        self.ca = CrossAttentionOp(
            attention_dim,
            num_heads,
            dim_qkv,
            dim_cond,
            is_sa=False,
            use_biases=use_biases,
        )
        self.ca_stochastic_depth = StochatichDepth(stochastic_depth)
        self.ca_ln = layer_norm(dim_qkv, eps=1e-6)
        self.initial_ln = layer_norm(dim_qkv, eps=1e-6)
        self.adaln_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(dim_cond, dim_qkv * 6, bias=use_biases),
        )
        # Zero init. The bias does not exist when use_biases is False, and
        # nn.init.zeros_ cannot be applied to None.
        modulation_linear = self.adaln_modulation[1]
        nn.init.zeros_(modulation_linear.weight)
        if modulation_linear.bias is not None:
            nn.init.zeros_(modulation_linear.bias)
        self.sa = CrossAttentionOp(
            attention_dim,
            num_heads,
            dim_qkv,
            dim_qkv,
            is_sa=True,
            use_biases=use_biases,
        )
        self.sa_stochastic_depth = StochatichDepth(stochastic_depth)
        self.middle_ln = layer_norm(dim_qkv, eps=1e-6)
        self.ffn = FusedMLP(
            dim_model=dim_qkv,
            dropout=dropout,
            activation=nn.GELU,
            hidden_layer_multiplier=mlp_multiplier,
            bias=use_biases,
        )
        self.ffn_stochastic_depth = StochatichDepth(stochastic_depth)
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale_1 = nn.Parameter(
                torch.ones(dim_qkv) * layer_scale_value, requires_grad=True
            )
            self.layer_scale_2 = nn.Parameter(
                torch.ones(dim_qkv) * layer_scale_value, requires_grad=True
            )

    def forward(
        self,
        tokens: torch.Tensor,
        cond_1: torch.Tensor,
        cond_2: torch.Tensor,
        cond_1_mask: Optional[torch.Tensor] = None,
        token_mask: Optional[torch.Tensor] = None,
    ):
        # Cross-attention mask must be (B, N_tokens, N_cond): queries are the
        # tokens, keys are cond_1. A missing side is filled with ones sized
        # from the matching sequence.
        if token_mask is None and cond_1_mask is None:
            cond_attention_mask = None
        elif token_mask is None:
            # (B, 1, N_cond) * (B, N_tokens, 1)
            cond_attention_mask = cond_1_mask.unsqueeze(1) * torch.ones(
                tokens.shape[0],
                tokens.shape[1],
                1,
                dtype=torch.bool,
                device=cond_1.device,
            )
        elif cond_1_mask is None:
            # (B, 1, N_cond) * (B, N_tokens, 1)
            cond_attention_mask = torch.ones(
                cond_1.shape[0],
                1,
                cond_1.shape[1],
                dtype=torch.bool,
                device=tokens.device,
            ) * token_mask.unsqueeze(2)
        else:
            cond_attention_mask = cond_1_mask.unsqueeze(1) * token_mask.unsqueeze(2)

        if token_mask is None:
            attention_mask = None
        else:
            # Repeat the key mask for every query row: (B, 1, N) * (B, N, 1).
            attention_mask = token_mask.unsqueeze(1) * torch.ones(
                tokens.shape[0],
                tokens.shape[1],
                1,
                dtype=torch.bool,
                device=tokens.device,
            )
        ca_output = self.ca(
            self.ca_ln(tokens),
            cond_1,
            attention_mask=cond_attention_mask,
        )
        # Non-finite attention outputs are replaced by zeros before the residual add.
        ca_output = torch.nan_to_num(ca_output, nan=0.0, posinf=0.0, neginf=0.0)
        tokens = tokens + self.ca_stochastic_depth(ca_output)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.adaln_modulation(cond_2).chunk(6, dim=-1)
        )
        attention_output = self.sa(
            modulate_shift_and_scale(self.initial_ln(tokens), shift_msa, scale_msa),
            attention_mask=attention_mask,
        )
        # The gates get a singleton token axis so they broadcast over tokens.
        if self.use_layer_scale:
            tokens = tokens + self.sa_stochastic_depth(
                gate_msa.unsqueeze(1) * self.layer_scale_1 * attention_output
            )
            tokens = tokens + self.ffn_stochastic_depth(
                gate_mlp.unsqueeze(1)
                * self.layer_scale_2
                * self.ffn(
                    modulate_shift_and_scale(
                        self.middle_ln(tokens), shift_mlp, scale_mlp
                    )
                )
            )
        else:
            tokens = tokens + gate_msa.unsqueeze(1) * self.sa_stochastic_depth(
                attention_output
            )
            tokens = tokens + self.ffn_stochastic_depth(
                gate_mlp.unsqueeze(1)
                * self.ffn(
                    modulate_shift_and_scale(
                        self.middle_ln(tokens), shift_mlp, scale_mlp
                    )
                )
            )
        return tokens
class PositionalEmbedding(nn.Module):
    """
    Sinusoidal embedding of a 1-D tensor of positions.

    The output concatenates the cosines and then the sines of
    ``outer(x, freqs)`` along dim 1.

    Taken from https://github.com/NVlabs/edm
    """

    def __init__(self, num_channels, max_positions=10000, endpoint=False):
        super().__init__()
        self.num_channels = num_channels
        self.max_positions = max_positions
        # Stored only; the frequency computation below never reads it.
        self.endpoint = endpoint

        half_channels = self.num_channels // 2
        exponents = torch.arange(start=0, end=half_channels, dtype=torch.float32)
        exponents = 2 * exponents / self.num_channels
        self.register_buffer("freqs", (1 / self.max_positions) ** exponents)

    def forward(self, x):
        angles = torch.outer(x, self.freqs)
        embedding = torch.cat([angles.cos(), angles.sin()], dim=1)
        return embedding.to(angles.dtype)
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to a sequence, then dropout.

    A ``(1, max_len, d_model)`` table is precomputed and registered as the
    buffer ``pe``: even channels hold sines and odd channels hold cosines.

    Args:
        d_model: feature width of the sequence (even or odd).
        dropout: dropout probability applied after the addition.
        max_len: number of positions in the precomputed table.
    """

    def __init__(self, d_model, dropout=0.0, max_len=10000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
        )
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd (cosine) channel than even
        # (sine) channels, so the extra frequency is dropped to keep the shapes
        # compatible. For an even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        # x is indexed as (batch, sequence, feature); the table is cropped to
        # the sequence length and broadcast over the batch.
        x = x + self.pe[:, : x.shape[1], :]
        return self.dropout(x)
class TimeEmbedder(nn.Module):
    """Embeds timesteps: sinusoidal encoding, standardisation, then a small MLP.

    The scaled timestep is encoded with ``PositionalEmbedding``, standardised
    along its last axis (mean subtracted, divided by the standard deviation) and
    mapped by Linear -> SiLU -> Linear to ``dim * expansion`` features.
    """

    def __init__(
        self,
        dim: int,
        time_scaling: float,
        expansion: int = 4,
    ):
        super().__init__()
        hidden_dim = dim * expansion
        self.encode_time = PositionalEmbedding(num_channels=dim, endpoint=True)
        self.time_scaling = time_scaling
        self.map_time = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, t: Tensor) -> Tensor:
        encoded = self.encode_time(t * self.time_scaling)
        mean = encoded.mean(dim=-1, keepdim=True)
        std = encoded.std(dim=-1, keepdim=True)
        standardised = (encoded - mean) / std
        return self.map_time(standardised)
def modulate_shift_and_scale(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor:
    """Return ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` each get a new axis inserted at dim 1 so they
    broadcast along dim 1 of ``x``.
    """
    gain = (1 + scale).unsqueeze(1)
    offset = shift.unsqueeze(1)
    return x * gain + offset
| # ------------------------------------------------------------------------------------- # | |
class BaseDirector(nn.Module):
    """Common scaffolding for the director models.

    ``__init__`` builds the input projection and time embedding, then calls the
    ``init_*`` hooks. ``forward`` runs input projection, time embedding,
    conditioning dropout and mapping, the backbone and the output projection.
    Subclasses must implement every hook that raises ``NotImplementedError``
    here.
    """

    def __init__(
        self,
        name: str,
        num_feats: int,
        num_cond_feats: int,
        num_cams: int,
        latent_dim: int,
        mlp_multiplier: int,
        num_layers: int,
        num_heads: int,
        dropout: float,
        stochastic_depth: float,
        label_dropout: float,
        num_rawfeats: int,
        clip_sequential: bool = False,
        cond_sequential: bool = False,
        device: str = "cuda",
        **kwargs,
    ):
        super().__init__()
        self.name = name
        self.label_dropout = label_dropout
        self.num_rawfeats = num_rawfeats
        self.num_feats = num_feats
        self.num_cams = num_cams
        self.clip_sequential = clip_sequential
        self.cond_sequential = cond_sequential
        # The 16-bit layer norm is selected only for the "cuda" device string.
        self.use_layernorm16 = device == "cuda"
        self.input_projection = nn.Sequential(
            nn.Linear(num_feats, latent_dim),
            PositionalEncoding(latent_dim),
        )
        # TimeEmbedder expands by 4 by default, so the output is latent_dim wide.
        self.time_embedding = TimeEmbedder(latent_dim // 4, time_scaling=1000)
        self.init_conds_mappings(num_cond_feats, latent_dim)
        self.init_backbone(
            num_layers, latent_dim, mlp_multiplier, num_heads, dropout, stochastic_depth
        )
        self.init_output_projection(num_feats, latent_dim)

    def forward(
        self,
        x: Tensor,
        timesteps: Tensor,
        y: Optional[List[Tensor]] = None,
        mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Run the model on ``x`` laid out as (batch, channels, length).

        ``mask`` is logically negated before use. When ``y`` is None the
        conditioning steps are skipped and the hooks receive ``None``.
        """
        mask = mask.logical_not() if mask is not None else None
        x = rearrange(x, "b c n -> b n c")
        x = self.input_projection(x)
        t = self.time_embedding(timesteps)
        if y is not None:
            y = self.mask_cond(y)
            y = self.cond_mapping(y, mask, t)
        x = self.backbone(x, y, mask)
        x = self.output_projection(x, y)
        return rearrange(x, "b n c -> b c n")

    def init_conds_mappings(self, num_cond_feats, latent_dim):
        raise NotImplementedError(
            "This method should be implemented in the derived class"
        )

    def init_backbone(
        self,
        num_layers,
        latent_dim,
        mlp_multiplier,
        num_heads,
        dropout,
        stochastic_depth,
    ):
        # The signature mirrors the call in __init__, so a subclass that forgets
        # to override this gets the NotImplementedError below, not a TypeError.
        raise NotImplementedError(
            "This method should be implemented in the derived class"
        )

    def cond_mapping(self, cond: List[Tensor], mask: Tensor, t: Tensor) -> Tensor:
        raise NotImplementedError(
            "This method should be implemented in the derived class"
        )

    def backbone(self, x: Tensor, y: Tensor, mask: Tensor) -> Tensor:
        raise NotImplementedError(
            "This method should be implemented in the derived class"
        )

    def mask_cond(
        self, cond: List[TensorType["batch_size", "num_cond_feats"]]
    ) -> List[TensorType["batch_size", "num_cond_feats"]]:
        """Randomly zero whole conditions per sample during training.

        A sample's condition is zeroed when either a draw shared by all
        modalities or a per-modality draw fires, each with probability
        ``label_dropout``. Outside training, or with ``label_dropout`` of 0,
        the input list is returned unchanged.
        """
        bs = cond[0].shape[0]
        if self.training and self.label_dropout > 0.0:
            # 1-> use null_cond, 0-> use real cond
            prob = torch.ones(bs, device=cond[0].device) * self.label_dropout
            masked_cond = []
            common_mask = torch.bernoulli(prob)  # Common to all modalities
            for _cond in cond:
                modality_mask = torch.bernoulli(prob)  # Modality only
                mask = torch.clip(common_mask + modality_mask, 0, 1)
                # Reshape so the per-sample mask broadcasts over 2-D or 3-D conditions.
                mask = mask.view(bs, 1, 1) if _cond.dim() == 3 else mask.view(bs, 1)
                masked_cond.append(_cond * (1.0 - mask))
            return masked_cond
        else:
            return cond

    def init_output_projection(self, num_feats, latent_dim):
        raise NotImplementedError(
            "This method should be implemented in the derived class"
        )

    def output_projection(self, x: Tensor, y: Tensor) -> Tensor:
        raise NotImplementedError(
            "This method should be implemented in the derived class"
        )
class AdaLNDirector(BaseDirector):
    """Director that conditions through AdaLN modulation.

    All conditions are concatenated on the last axis, projected to the latent
    width and summed with the time embedding. That single tensor modulates every
    ``AdaLNSABlock`` and the final normalisation.
    """

    def __init__(
        self,
        name: str,
        num_feats: int,
        num_cond_feats: int,
        num_cams: int,
        latent_dim: int,
        mlp_multiplier: int,
        num_layers: int,
        num_heads: int,
        dropout: float,
        stochastic_depth: float,
        label_dropout: float,
        num_rawfeats: int,
        clip_sequential: bool = False,
        cond_sequential: bool = False,
        device: str = "cuda",
        **kwargs,
    ):
        # Extra **kwargs are accepted but not forwarded to the base class.
        base_kwargs = dict(
            name=name,
            num_feats=num_feats,
            num_cond_feats=num_cond_feats,
            num_cams=num_cams,
            latent_dim=latent_dim,
            mlp_multiplier=mlp_multiplier,
            num_layers=num_layers,
            num_heads=num_heads,
            dropout=dropout,
            stochastic_depth=stochastic_depth,
            label_dropout=label_dropout,
            num_rawfeats=num_rawfeats,
            clip_sequential=clip_sequential,
            cond_sequential=cond_sequential,
            device=device,
        )
        super().__init__(**base_kwargs)
        assert not (clip_sequential and cond_sequential)

    def init_conds_mappings(self, num_cond_feats, latent_dim):
        # num_cond_feats is summed, so it is an iterable of per-condition widths.
        joint_width = sum(num_cond_feats)
        self.joint_cond_projection = nn.Linear(joint_width, latent_dim)

    def cond_mapping(self, cond: List[Tensor], mask: Tensor, t: Tensor) -> Tensor:
        joint = torch.cat(cond, dim=-1)
        projected = self.joint_cond_projection(joint)
        return projected + t

    def init_backbone(
        self,
        num_layers,
        latent_dim,
        mlp_multiplier,
        num_heads,
        dropout,
        stochastic_depth,
    ):
        self.backbone_module = nn.ModuleList(
            AdaLNSABlock(
                dim_qkv=latent_dim,
                dim_cond=latent_dim,
                num_heads=num_heads,
                mlp_multiplier=mlp_multiplier,
                dropout=dropout,
                stochastic_depth=stochastic_depth,
                use_layernorm16=self.use_layernorm16,
            )
            for _ in range(num_layers)
        )

    def backbone(self, x: Tensor, y: Tensor, mask: Tensor) -> Tensor:
        hidden = x
        for layer in self.backbone_module:
            hidden = layer(hidden, y, mask)
        return hidden

    def init_output_projection(self, num_feats, latent_dim):
        norm_cls = LayerNorm16Bits if self.use_layernorm16 else nn.LayerNorm
        self.final_norm = norm_cls(latent_dim, eps=1e-6, elementwise_affine=False)
        self.final_linear = nn.Linear(latent_dim, num_feats, bias=True)
        self.final_adaln = nn.Sequential(
            nn.SiLU(),
            nn.Linear(latent_dim, latent_dim * 2, bias=True),
        )
        # Zero init
        modulation_linear = self.final_adaln[1]
        nn.init.zeros_(modulation_linear.weight)
        nn.init.zeros_(modulation_linear.bias)

    def output_projection(self, x: Tensor, y: Tensor) -> Tensor:
        shift, scale = self.final_adaln(y).chunk(2, dim=-1)
        modulated = modulate_shift_and_scale(self.final_norm(x), shift, scale)
        return self.final_linear(modulated)
class CrossAttentionDirector(BaseDirector):
    """Director that conditions through cross-attention.

    Each condition is projected to the latent width and concatenated along the
    sequence axis after learned register tokens and the time embedding. That
    sequence is refined by two self-attention blocks, and the backbone tokens
    then cross-attend to it.
    """

    def __init__(
        self,
        name: str,
        num_feats: int,
        num_cond_feats: int,
        num_cams: int,
        latent_dim: int,
        mlp_multiplier: int,
        num_layers: int,
        num_heads: int,
        dropout: float,
        stochastic_depth: float,
        label_dropout: float,
        num_rawfeats: int,
        num_text_registers: int,
        clip_sequential: bool = True,
        cond_sequential: bool = True,
        device: str = "cuda",
        **kwargs,
    ):
        # These are read by init_conds_mappings, which the base constructor
        # calls, so they must be set before super().__init__().
        self.num_text_registers = num_text_registers
        self.num_heads = num_heads
        self.dropout = dropout
        self.mlp_multiplier = mlp_multiplier
        self.stochastic_depth = stochastic_depth
        base_kwargs = dict(
            name=name,
            num_feats=num_feats,
            num_cond_feats=num_cond_feats,
            num_cams=num_cams,
            latent_dim=latent_dim,
            mlp_multiplier=mlp_multiplier,
            num_layers=num_layers,
            num_heads=num_heads,
            dropout=dropout,
            stochastic_depth=stochastic_depth,
            label_dropout=label_dropout,
            num_rawfeats=num_rawfeats,
            clip_sequential=clip_sequential,
            cond_sequential=cond_sequential,
            device=device,
        )
        super().__init__(**base_kwargs)
        assert clip_sequential and cond_sequential

    def init_conds_mappings(self, num_cond_feats, latent_dim):
        self.cond_projection = nn.ModuleList(
            nn.Linear(width, latent_dim) for width in num_cond_feats
        )
        self.cond_registers = nn.Parameter(
            torch.randn(self.num_text_registers, latent_dim), requires_grad=True
        )
        # Re-initialise the registers: std 0.02, truncated at two std either side.
        nn.init.trunc_normal_(self.cond_registers, std=0.02, a=-2 * 0.02, b=2 * 0.02)
        self.cond_sa = nn.ModuleList(
            SelfAttentionBlock(
                dim_qkv=latent_dim,
                num_heads=self.num_heads,
                mlp_multiplier=self.mlp_multiplier,
                dropout=self.dropout,
                stochastic_depth=self.stochastic_depth,
                use_layernorm16=self.use_layernorm16,
            )
            for _ in range(2)
        )
        self.cond_positional_embedding = PositionalEncoding(latent_dim, max_len=10000)

    def cond_mapping(self, cond: List[Tensor], mask: Tensor, t: Tensor) -> Tensor:
        bs = cond[0].shape[0]
        registers = self.cond_registers.unsqueeze(0).expand(bs, -1, -1)
        # Conditions arrive channel-first and are moved to (b, n, c) before projection.
        projected = [
            proj(rearrange(c, "b c n -> b n c"))
            for proj, c in zip(self.cond_projection, cond)
        ]
        # Sequence layout: registers, one time token, then every condition.
        sequence = torch.cat([registers, t.unsqueeze(1)] + projected, dim=1)
        sequence = self.cond_positional_embedding(sequence)
        for layer in self.cond_sa:
            sequence = layer(sequence)
        return sequence

    def init_backbone(
        self,
        num_layers,
        latent_dim,
        mlp_multiplier,
        num_heads,
        dropout,
        stochastic_depth,
    ):
        self.backbone_module = nn.ModuleList(
            CrossAttentionSABlock(
                dim_qkv=latent_dim,
                dim_cond=latent_dim,
                num_heads=num_heads,
                mlp_multiplier=mlp_multiplier,
                dropout=dropout,
                stochastic_depth=stochastic_depth,
                use_layernorm16=self.use_layernorm16,
            )
            for _ in range(num_layers)
        )

    def backbone(self, x: Tensor, y: Tensor, mask: Tensor) -> Tensor:
        hidden = x
        for layer in self.backbone_module:
            # No mask is passed for the condition sequence.
            hidden = layer(hidden, y, mask, None)
        return hidden

    def init_output_projection(self, num_feats, latent_dim):
        norm_cls = LayerNorm16Bits if self.use_layernorm16 else nn.LayerNorm
        self.final_norm = norm_cls(latent_dim, eps=1e-6)
        self.final_linear = nn.Linear(latent_dim, num_feats, bias=True)

    def output_projection(self, x: Tensor, y: Tensor) -> Tensor:
        normed = self.final_norm(x)
        return self.final_linear(normed)
class InContextDirector(BaseDirector):
    """Director that conditions in-context.

    The time embedding and every projected condition are prepended to the token
    sequence as extra tokens. The joint sequence is processed by plain
    self-attention blocks, and the prepended positions are dropped again before
    the output projection.
    """

    def __init__(
        self,
        name: str,
        num_feats: int,
        num_cond_feats: int,
        num_cams: int,
        latent_dim: int,
        mlp_multiplier: int,
        num_layers: int,
        num_heads: int,
        dropout: float,
        stochastic_depth: float,
        label_dropout: float,
        num_rawfeats: int,
        clip_sequential: bool = False,
        cond_sequential: bool = False,
        device: str = "cuda",
        **kwargs,
    ):
        super().__init__(
            name=name,
            num_feats=num_feats,
            num_cond_feats=num_cond_feats,
            num_cams=num_cams,
            latent_dim=latent_dim,
            mlp_multiplier=mlp_multiplier,
            num_layers=num_layers,
            num_heads=num_heads,
            dropout=dropout,
            stochastic_depth=stochastic_depth,
            label_dropout=label_dropout,
            num_rawfeats=num_rawfeats,
            clip_sequential=clip_sequential,
            cond_sequential=cond_sequential,
            device=device,
        )

    def init_conds_mappings(self, num_cond_feats, latent_dim):
        self.cond_projection = nn.ModuleList(
            [nn.Linear(num_cond_feat, latent_dim) for num_cond_feat in num_cond_feats]
        )

    def cond_mapping(self, cond: List[Tensor], mask: Tensor, t: Tensor) -> Tensor:
        """Project every condition and stack them, after the time token, on dim 1."""
        # Build a new list instead of rewriting the caller's list in place.
        # 3-D conditions arrive channel-first and are moved to (b, n, c).
        cond = [rearrange(c, "b c n -> b n c") if c.dim() == 3 else c for c in cond]
        cond_emb = [cond_proj(c) for cond_proj, c in zip(self.cond_projection, cond)]
        # 2-D embeddings become a single token; 3-D embeddings are kept as they
        # are. The original inserted the whole list here, which breaks torch.cat.
        cond_emb = [c.unsqueeze(1) if c.dim() == 2 else c for c in cond_emb]
        cond_emb = torch.cat([t.unsqueeze(1)] + cond_emb, dim=1)
        return cond_emb

    def init_backbone(
        self,
        num_layers,
        latent_dim,
        mlp_multiplier,
        num_heads,
        dropout,
        stochastic_depth,
    ):
        self.backbone_module = nn.ModuleList(
            [
                SelfAttentionBlock(
                    dim_qkv=latent_dim,
                    num_heads=num_heads,
                    mlp_multiplier=mlp_multiplier,
                    dropout=dropout,
                    stochastic_depth=stochastic_depth,
                    use_layernorm16=self.use_layernorm16,
                )
                for _ in range(num_layers)
            ]
        )

    def backbone(self, x: Tensor, y: Tensor, mask: Tensor) -> Tensor:
        bs, n_y, _ = y.shape
        if mask is not None:
            # Conditioning tokens are always attendable. The prefix must be
            # boolean: concatenating a float prefix promotes the whole mask to
            # float, which attention treats as an additive bias, not a mask.
            prefix = torch.ones(bs, n_y, dtype=torch.bool, device=y.device)
            mask = torch.cat([prefix, mask.to(torch.bool)], dim=1)
        x = torch.cat([y, x], dim=1)
        for block in self.backbone_module:
            x = block(x, mask)
        return x

    def init_output_projection(self, num_feats, latent_dim):
        layer_norm = LayerNorm16Bits if self.use_layernorm16 else nn.LayerNorm
        self.final_norm = layer_norm(latent_dim, eps=1e-6)
        self.final_linear = nn.Linear(latent_dim, num_feats, bias=True)

    def output_projection(self, x: Tensor, y: Tensor) -> Tensor:
        # Drop the conditioning tokens that backbone() prepended.
        num_y = y.shape[1]
        x = x[:, num_y:]
        return self.final_linear(self.final_norm(x))