| | |
| |
|
| | from comfy.ldm.modules.attention import optimized_attention |
| | import typing as tp |
| |
|
| | import torch |
| |
|
| | from einops import rearrange |
| | from torch import nn |
| | from torch.nn import functional as F |
| | import math |
| | import comfy.ops |
| |
|
class FourierFeatures(nn.Module):
    """Map inputs to concatenated cosine/sine features of a linear projection.

    The projection weight is allocated with ``torch.empty`` and is therefore
    uninitialized here; its values have to be filled in before use
    (presumably by loading a checkpoint - confirm against the loader).
    """

    def __init__(self, in_features, out_features, std=1., dtype=None, device=None):
        """Allocate the ``(out_features // 2, in_features)`` projection weight.

        ``out_features`` must be even (enforced with an assert). ``std`` is
        accepted but not used by this implementation.
        """
        super().__init__()
        assert out_features % 2 == 0
        half_features = out_features // 2
        projection = torch.empty(
            [half_features, in_features], dtype=dtype, device=device)
        self.weight = nn.Parameter(projection)

    def forward(self, input):
        """Return ``cat([cos(f), sin(f)], dim=-1)`` for ``f = 2*pi*input @ W.T``."""
        # The weight is cast to match ``input`` via comfy.ops.cast_to_input.
        weight_t = comfy.ops.cast_to_input(self.weight.T, input)
        angles = 2 * math.pi * input @ weight_t
        return torch.cat([angles.cos(), angles.sin()], dim=-1)
| |
|
| | |
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with an optional bias.

    Bias-less layernorm has been shown to be more stable; most newer models
    have moved towards rmsnorm, which is also bias-less.
    """

    def __init__(self, dim, bias=False, fix_scale=False, dtype=None, device=None):
        """Allocate ``gamma`` (always) and ``beta`` (only when ``bias`` is truthy).

        Both parameters are created with ``torch.empty`` and are therefore
        uninitialized here. ``fix_scale`` is accepted but not used.
        """
        super().__init__()

        self.gamma = nn.Parameter(torch.empty(dim, dtype=dtype, device=device))
        self.beta = nn.Parameter(torch.empty(dim, dtype=dtype, device=device)) if bias else None

    def forward(self, x):
        """Apply ``F.layer_norm`` over ``x.shape[-1:]`` with parameters cast to ``x``."""
        bias = None if self.beta is None else comfy.ops.cast_to_input(self.beta, x)
        weight = comfy.ops.cast_to_input(self.gamma, x)
        return F.layer_norm(x, x.shape[-1:], weight=weight, bias=bias)
| |
|
class GLU(nn.Module):
    """Gated linear unit: project to ``2 * dim_out`` channels, then gate.

    The projection output is split in two along the last dimension and the
    first half is multiplied by ``activation`` applied to the second half.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        activation,
        use_conv = False,
        conv_kernel_size = 3,
        dtype=None,
        device=None,
        operations=None,
    ):
        """Build the projection from ``operations`` (Linear, or Conv1d if ``use_conv``).

        ``operations`` must provide ``Linear`` / ``Conv1d`` factories; the
        default of ``None`` cannot be used to construct the layer.
        """
        super().__init__()
        self.act = activation
        if use_conv:
            # "Same" padding for odd kernel sizes keeps the sequence length.
            self.proj = operations.Conv1d(
                dim_in,
                dim_out * 2,
                conv_kernel_size,
                padding = (conv_kernel_size // 2),
                dtype=dtype,
                device=device,
            )
        else:
            self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device)
        self.use_conv = use_conv

    def forward(self, x):
        """Return ``value * act(gate)`` for a ``(b, n, d)`` input."""
        if not self.use_conv:
            projected = self.proj(x)
        else:
            # Conv1d works channels-first, so swap axes around the projection.
            channels_first = rearrange(x, 'b n d -> b d n')
            channels_first = self.proj(channels_first)
            projected = rearrange(channels_first, 'b d n -> b n d')

        value, gate = projected.chunk(2, dim = -1)
        return value * self.act(gate)
| |
|
class AbsolutePositionalEmbedding(nn.Module):
    """Learned absolute positional embedding scaled by ``dim ** -0.5``."""

    def __init__(self, dim, max_seq_len):
        """Create an ``nn.Embedding`` table with ``max_seq_len`` rows of width ``dim``."""
        super().__init__()
        self.scale = dim ** -0.5
        self.max_seq_len = max_seq_len
        self.emb = nn.Embedding(max_seq_len, dim)

    def forward(self, x, pos = None, seq_start_pos = None):
        """Return the scaled embeddings for the positions of ``x``.

        Only ``x.shape[1]`` (sequence length) and ``x.device`` are read from
        ``x``. When ``pos`` is omitted, positions ``0 .. seq_len - 1`` are
        used. When ``seq_start_pos`` is given it is subtracted (broadcast over
        a new trailing axis) and the result is clamped at zero.
        """
        seq_len = x.shape[1]
        assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'

        positions = pos
        if positions is None:
            positions = torch.arange(seq_len, device = x.device)

        if seq_start_pos is not None:
            shifted = positions - seq_start_pos[..., None]
            positions = shifted.clamp(min = 0)

        return self.emb(positions) * self.scale
| |
|
class ScaledSinusoidalEmbedding(nn.Module):
    """Sinusoidal positional embedding multiplied by a learnable scalar."""

    def __init__(self, dim, theta = 10000):
        """Set up the learnable scale and the fixed inverse-frequency table.

        ``dim`` must be even. The scale starts at ``dim ** -0.5``. The
        ``inv_freq`` buffer holds ``theta ** -(i / half_dim)`` for
        ``i in range(half_dim)`` and is registered as non-persistent, so it is
        not part of the state dict.
        """
        super().__init__()
        assert (dim % 2) == 0, 'dimension must be divisible by 2'
        self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)

        half_dim = dim // 2
        exponents = torch.arange(half_dim).float() / half_dim
        self.register_buffer('inv_freq', theta ** -exponents, persistent = False)

    def forward(self, x, pos = None, seq_start_pos = None):
        """Return ``cat(sin, cos)`` of ``pos x inv_freq``, times the scale.

        Only ``x.shape[1]`` and ``x.device`` are read from ``x``; they are
        used to build default positions ``0 .. seq_len - 1`` when ``pos`` is
        omitted. ``seq_start_pos`` is subtracted from the positions if given.
        """
        if pos is None:
            positions = torch.arange(x.shape[1], device = x.device)
        else:
            positions = pos

        if seq_start_pos is not None:
            positions = positions - seq_start_pos[..., None]

        # Outer product of positions and inverse frequencies.
        angles = torch.einsum('i, j -> i j', positions, self.inv_freq)
        sin_cos = torch.cat((angles.sin(), angles.cos()), dim = -1)
        return sin_cos * self.scale
| |
|
class RotaryEmbedding(nn.Module):
    """Rotary positional embedding frequencies with optional xpos scaling.

    ``inv_freq`` is registered as an uninitialized buffer (``torch.empty``);
    it has to be populated externally before ``forward`` yields meaningful
    values (presumably from a checkpoint - confirm against the loader).
    """

    def __init__(
        self,
        dim,
        use_xpos = False,
        scale_base = 512,
        interpolation_factor = 1.,
        base = 10000,
        base_rescale_factor = 1.,
        dtype=None,
        device=None,
    ):
        """Register the ``inv_freq`` buffer and, for xpos, the ``scale`` buffer.

        Args:
            dim: rotary dimension; ``inv_freq`` has ``dim // 2`` entries.
            use_xpos: when true, a per-channel xpos ``scale`` buffer is built
                and ``scale_base`` is stored; otherwise ``scale`` is ``None``.
            scale_base: divisor applied to the centred positions in ``forward``.
            interpolation_factor: positions are divided by this; must be >= 1.
            base, base_rescale_factor: accepted for interface compatibility
                only. The original code rescaled ``base`` here but never used
                the result (``inv_freq`` is not computed in this class), and
                the dead expression divided by ``dim - 2``, failing for
                ``dim == 2``; it has been dropped.
        """
        super().__init__()

        self.register_buffer('inv_freq', torch.empty((dim // 2,), device=device, dtype=dtype))

        assert interpolation_factor >= 1.
        self.interpolation_factor = interpolation_factor

        if not use_xpos:
            self.register_buffer('scale', None)
            return

        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)

        self.scale_base = scale_base
        self.register_buffer('scale', scale)

    def forward_from_seq_len(self, seq_len, device, dtype):
        """Run ``forward`` on positions ``0 .. seq_len - 1`` of the given dtype/device."""
        t = torch.arange(seq_len, device=device, dtype=dtype)
        return self.forward(t)

    def forward(self, t):
        """Return ``(freqs, scale)`` for the 1-D position tensor ``t``.

        ``freqs`` is the outer product of ``t / interpolation_factor`` and
        ``inv_freq``, duplicated along the last dimension. ``scale`` is the
        float ``1.`` unless xpos is enabled, in which case it is a tensor
        duplicated along the last dimension in the same way.
        """
        device = t.device
        # The einsum below requires ``t`` to be 1-D, so its length is the
        # sequence length. The original xpos branch read an undefined
        # ``seq_len`` and raised NameError.
        seq_len = t.shape[0]

        t = t / self.interpolation_factor

        freqs = torch.einsum('i , j -> i j', t, comfy.ops.cast_to_input(self.inv_freq, t))
        freqs = torch.cat((freqs, freqs), dim = -1)

        if self.scale is None:
            return freqs, 1.

        # xpos: exponent is the position centred on the middle of the sequence.
        power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
        scale = comfy.ops.cast_to_input(self.scale, t) ** rearrange(power, 'n -> n 1')
        scale = torch.cat((scale, scale), dim = -1)

        return freqs, scale
| |
|
def rotate_half(x):
    """Split the last dimension into halves ``(a, b)`` and return ``cat(-b, a)``."""
    halves = rearrange(x, '... (j d) -> ... j d', j = 2)
    first, second = halves.unbind(dim = -2)
    return torch.cat((second.neg(), first), dim = -1)
| |
|
def apply_rotary_pos_emb(t, freqs, scale = 1):
    """Rotate the leading ``freqs.shape[-1]`` channels of ``t`` by ``freqs``.

    Channels of ``t`` beyond the rotary width are passed through unchanged.
    ``freqs`` is cast to the dtype of ``t`` and trimmed to the last
    ``t.shape[-2]`` entries of its first axis; the result is returned in the
    dtype ``t`` had on entry.
    """
    result_dtype = t.dtype
    compute_dtype = t.dtype

    rot_dim = freqs.shape[-1]
    seq_len = t.shape[-2]

    freqs = freqs.to(compute_dtype)
    t = t.to(compute_dtype)
    freqs = freqs[-seq_len:, :]

    # Insert a singleton head axis so batched freqs broadcast against (b h n d).
    if t.ndim == 4 and freqs.ndim == 3:
        freqs = rearrange(freqs, 'b n d -> b 1 n d')

    to_rotate = t[..., :rot_dim]
    passthrough = t[..., rot_dim:]

    cos_term = to_rotate * freqs.cos() * scale
    sin_term = rotate_half(to_rotate) * freqs.sin() * scale
    rotated = cos_term + sin_term

    rotated = rotated.to(result_dtype)
    passthrough = passthrough.to(result_dtype)

    return torch.cat((rotated, passthrough), dim = -1)
| |
|
class FeedForward(nn.Module):
    """Feed-forward block: input projection (GLU or plain + SiLU), then output projection.

    The layers are stored in ``self.ff`` as a four-element ``nn.Sequential``
    ``(linear_in, swap, linear_out, swap)`` where the ``swap`` entries are
    identity modules unless ``use_conv`` is set.
    """

    def __init__(
        self,
        dim,
        dim_out = None,
        mult = 4,
        no_bias = False,
        glu = True,
        use_conv = False,
        conv_kernel_size = 3,
        zero_init_output = True,
        dtype=None,
        device=None,
        operations=None,
    ):
        """Build the block.

        Args:
            dim: input width.
            dim_out: output width; defaults to ``dim``.
            mult: the hidden width is ``int(dim * mult)``.
            no_bias: disable the bias of the plain (non-GLU) projections.
            glu: use a ``GLU`` input projection instead of projection + SiLU.
            use_conv: use ``Conv1d`` projections wrapped in axis swaps.
            conv_kernel_size: kernel size for the ``Conv1d`` projections.
            zero_init_output: accepted but not used by this implementation.
            dtype, device: forwarded to the projection layers.
            operations: namespace providing ``Linear`` / ``Conv1d`` factories.
        """
        super().__init__()
        inner_dim = int(dim * mult)

        activation = nn.SiLU()

        dim_out = dim if dim_out is None else dim_out
        use_bias = not no_bias

        if use_conv:
            # ``Rearrange`` is not imported at module level, so the original
            # raised NameError on this path; import it where it is needed.
            from einops.layers.torch import Rearrange

            def axis_swap(pattern):
                return Rearrange(pattern)

            def projection(features_in, features_out):
                return operations.Conv1d(features_in, features_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = use_bias, dtype=dtype, device=device)
        else:
            def axis_swap(pattern):
                return nn.Identity()

            def projection(features_in, features_out):
                return operations.Linear(features_in, features_out, bias = use_bias, dtype=dtype, device=device)

        if glu:
            linear_in = GLU(dim, inner_dim, activation, dtype=dtype, device=device, operations=operations)
        else:
            # Both patterns just exchange the last two axes; the axis names in
            # the second one are kept exactly as in the original.
            linear_in = nn.Sequential(
                axis_swap('b n d -> b d n'),
                projection(dim, inner_dim),
                axis_swap('b n d -> b d n'),
                activation
            )

        linear_out = projection(inner_dim, dim_out)

        self.ff = nn.Sequential(
            linear_in,
            axis_swap('b d n -> b n d'),
            linear_out,
            axis_swap('b n d -> b d n'),
        )

    def forward(self, x):
        """Apply the sequential feed-forward stack to ``x``."""
        return self.ff(x)
| |
|
class Attention(nn.Module):
    """Multi-head self- or cross-attention built on ``optimized_attention``.

    When ``dim_context`` is given the module owns separate ``to_q`` / ``to_kv``
    projections (cross-attention); otherwise a fused ``to_qkv`` projection is
    used (self-attention).
    """

    def __init__(
        self,
        dim,
        dim_heads = 64,
        dim_context = None,
        causal = False,
        zero_init_output=True,
        qk_norm = False,
        natten_kernel_size = None,
        dtype=None,
        device=None,
        operations=None,
    ):
        """Create the projection layers.

        Args:
            dim: model width of the query/output.
            dim_heads: width of one head; ``dim // dim_heads`` query heads.
            dim_context: width of the key/value source; ``None`` selects the
                fused self-attention projection.
            causal: stored on the module; not applied by ``forward``.
            zero_init_output, natten_kernel_size: accepted but not used here.
            qk_norm: L2-normalize queries and keys along the head dimension.
            dtype, device: forwarded to the projection layers.
            operations: namespace providing the ``Linear`` factory.
        """
        super().__init__()
        self.dim = dim
        self.dim_heads = dim_heads
        self.causal = causal

        dim_kv = dim_context if dim_context is not None else dim

        self.num_heads = dim // dim_heads
        self.kv_heads = dim_kv // dim_heads

        if dim_context is not None:
            self.to_q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
            self.to_kv = operations.Linear(dim_kv, dim_kv * 2, bias=False, dtype=dtype, device=device)
        else:
            self.to_qkv = operations.Linear(dim, dim * 3, bias=False, dtype=dtype, device=device)

        self.to_out = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)

        self.qk_norm = qk_norm

    def forward(
        self,
        x,
        context = None,
        mask = None,
        context_mask = None,
        rotary_pos_emb = None,
        causal = None
    ):
        """Attend from ``x`` to ``context`` (or to ``x`` itself).

        Args:
            x: query input, ``(b, n, dim)`` per the rearrange patterns below.
            context: key/value input for cross-attention, or ``None``.
            mask: boolean ``(b, n)`` mask; positions where it is false are
                zeroed in the output.
            context_mask: ``(b, j)`` mask over the key/value sequence.
            rotary_pos_emb: ``(freqs, scale)`` tuple; only ``freqs`` is used
                and only when no ``context`` is given.
            causal: accepted for interface compatibility; not used.

        Returns:
            The attention output after the ``to_out`` projection.
        """
        h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None

        kv_input = context if has_context else x

        if hasattr(self, 'to_q'):
            # Separate projections for queries and keys/values.
            q = self.to_q(x)
            q = rearrange(q, 'b n (h d) -> b h n d', h = h)

            k, v = self.to_kv(kv_input).chunk(2, dim=-1)

            k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v))
        else:
            # Fused projection for self-attention.
            q, k, v = self.to_qkv(x).chunk(3, dim=-1)
            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        # Normalize q and k for cosine-similarity attention.
        if self.qk_norm:
            q = F.normalize(q, dim=-1)
            k = F.normalize(k, dim=-1)

        if rotary_pos_emb is not None and not has_context:
            freqs, _ = rotary_pos_emb

            q_dtype = q.dtype
            k_dtype = k.dtype

            # Apply the rotation in float32, then restore the original dtypes.
            q = q.to(torch.float32)
            k = k.to(torch.float32)
            freqs = freqs.to(torch.float32)

            q = apply_rotary_pos_emb(q, freqs)
            k = apply_rotary_pos_emb(k, freqs)

            q = q.to(q_dtype)
            k = k.to(k_dtype)

        input_mask = context_mask

        if input_mask is None and not has_context:
            input_mask = mask

        # The original inverted the mask, reduced a one-element list with an
        # undefined ``or_reduce`` helper (NameError whenever a mask was
        # supplied) and inverted the result again; for a single mask that is
        # just the reshaped mask itself.
        # NOTE(review): this mask is not forwarded to optimized_attention (the
        # original call did not pass it either) - confirm whether the
        # attention scores should be masked as well.
        final_attn_mask = None

        if input_mask is not None:
            final_attn_mask = rearrange(input_mask, 'b j -> b 1 1 j')

        if h != kv_h:
            # Repeat-interleave the key/value heads to match the query heads.
            heads_per_kv_head = h // kv_h
            k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))

        out = optimized_attention(q, k, v, h, skip_reshape=True)
        out = self.to_out(out)

        if mask is not None:
            mask = rearrange(mask, 'b n -> b n 1')
            out = out.masked_fill(~mask, 0.)

        return out
| |
|
class ConformerModule(nn.Module):
    """Convolution module: norm, pointwise conv, GLU, depthwise conv, norm, SiLU, pointwise conv."""

    def __init__(
        self,
        dim,
        norm_kwargs = {},
    ):
        """Build the layers for width ``dim``.

        ``norm_kwargs`` is forwarded to both ``LayerNorm`` layers; it is only
        read, never mutated.
        """
        super().__init__()

        self.dim = dim

        self.in_norm = LayerNorm(dim, **norm_kwargs)
        self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
        # GLU builds its projection from ``operations``; the original left it
        # at the default ``None``, so construction failed with AttributeError.
        # ``torch.nn`` is passed to match the plain nn.Conv1d layers used here.
        self.glu = GLU(dim, dim, nn.SiLU(), operations=nn)
        self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False)
        self.mid_norm = LayerNorm(dim, **norm_kwargs)
        self.swish = nn.SiLU()
        self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False)

    def forward(self, x):
        """Apply the module to a ``(b, n, d)`` input and return the same layout.

        The convolutions operate channels-first, so each one is wrapped in a
        pair of axis swaps.
        """
        x = self.in_norm(x)
        x = rearrange(x, 'b n d -> b d n')
        x = self.pointwise_conv(x)
        x = rearrange(x, 'b d n -> b n d')
        x = self.glu(x)
        x = rearrange(x, 'b n d -> b d n')
        x = self.depthwise_conv(x)
        x = rearrange(x, 'b d n -> b n d')
        x = self.mid_norm(x)
        x = self.swish(x)
        x = rearrange(x, 'b n d -> b d n')
        x = self.pointwise_conv_2(x)
        x = rearrange(x, 'b d n -> b n d')

        return x
| |
|
class TransformerBlock(nn.Module):
    """Pre-norm transformer block with residual branches.

    The branches are, in order: self-attention, optional cross-attention,
    optional conformer module and feed-forward. When a global conditioning
    width is configured and a ``global_cond`` tensor is passed to ``forward``,
    the self-attention and feed-forward branches are modulated by scale,
    shift and gate values derived from it.
    """

    def __init__(
        self,
        dim,
        dim_heads = 64,
        cross_attend = False,
        dim_context = None,
        global_cond_dim = None,
        causal = False,
        zero_init_branch_outputs = True,
        conformer = False,
        layer_ix = -1,
        remove_norms = False,
        attn_kwargs = {},
        ff_kwargs = {},
        norm_kwargs = {},
        dtype=None,
        device=None,
        operations=None,
    ):
        """Build the sub-layers.

        ``attn_kwargs``, ``ff_kwargs`` and ``norm_kwargs`` are only read and
        forwarded to the attention, feed-forward and norm layers. With
        ``remove_norms`` every norm is replaced by ``nn.Identity``.
        """
        super().__init__()
        self.dim = dim
        self.dim_heads = dim_heads
        self.cross_attend = cross_attend
        self.dim_context = dim_context
        self.causal = causal

        def build_norm():
            if remove_norms:
                return nn.Identity()
            return LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs)

        # Arguments shared by the self- and cross-attention layers.
        shared_attn_args = dict(
            dim_heads = dim_heads,
            causal = causal,
            zero_init_output=zero_init_branch_outputs,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        self.pre_norm = build_norm()
        self.self_attn = Attention(dim, **shared_attn_args, **attn_kwargs)

        if cross_attend:
            self.cross_attend_norm = build_norm()
            self.cross_attn = Attention(dim, dim_context=dim_context, **shared_attn_args, **attn_kwargs)

        self.ff_norm = build_norm()
        self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, dtype=dtype, device=device, operations=operations,**ff_kwargs)

        self.layer_ix = layer_ix

        if conformer:
            self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs)
        else:
            self.conformer = None

        self.global_cond_dim = global_cond_dim

        if global_cond_dim is not None:
            # One projection yields six chunks: scale/shift/gate for the
            # attention branch and for the feed-forward branch. Its weight
            # starts at zero.
            modulation = nn.Linear(global_cond_dim, dim * 6, bias=False)
            nn.init.zeros_(modulation.weight)
            self.to_scale_shift_gate = nn.Sequential(nn.SiLU(), modulation)

    def forward(
        self,
        x,
        context = None,
        global_cond=None,
        mask = None,
        context_mask = None,
        rotary_pos_emb = None
    ):
        """Run the block on ``x`` and return the updated tensor.

        Cross-attention runs only when ``context`` is given, the conformer
        only when one was built. ``mask`` and ``rotary_pos_emb`` go to the
        self-attention layer, ``context_mask`` to the cross-attention layer.
        """
        modulated = (
            self.global_cond_dim is not None
            and self.global_cond_dim > 0
            and global_cond is not None
        )

        # Self-attention branch.
        if modulated:
            chunks = self.to_scale_shift_gate(global_cond).unsqueeze(1).chunk(6, dim = -1)
            scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = chunks

            attn_in = self.pre_norm(x)
            attn_in = attn_in * (1 + scale_self) + shift_self
            attn_out = self.self_attn(attn_in, mask = mask, rotary_pos_emb = rotary_pos_emb)
            attn_out = attn_out * torch.sigmoid(1 - gate_self)
            x = attn_out + x
        else:
            x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb)

        # Optional branches, identical for both modes.
        if context is not None:
            x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)

        if self.conformer is not None:
            x = x + self.conformer(x)

        # Feed-forward branch.
        if modulated:
            ff_in = self.ff_norm(x)
            ff_in = ff_in * (1 + scale_ff) + shift_ff
            ff_out = self.ff(ff_in)
            ff_out = ff_out * torch.sigmoid(1 - gate_ff)
            x = ff_out + x
        else:
            x = x + self.ff(self.ff_norm(x))

        return x
| |
|
class ContinuousTransformer(nn.Module):
    """Stack of ``TransformerBlock`` layers over continuous (non-token) inputs.

    Optional linear projections map the input into and the output out of the
    model width, and positions can be encoded with rotary, sinusoidal or
    learned absolute embeddings.
    """

    def __init__(
        self,
        dim,
        depth,
        *,
        dim_in = None,
        dim_out = None,
        dim_heads = 64,
        cross_attend=False,
        cond_token_dim=None,
        global_cond_dim=None,
        causal=False,
        rotary_pos_emb=True,
        zero_init_branch_outputs=True,
        conformer=False,
        use_sinusoidal_emb=False,
        use_abs_pos_emb=False,
        abs_pos_emb_max_length=10000,
        dtype=None,
        device=None,
        operations=None,
        **kwargs
    ):
        """Build projections, positional embeddings and ``depth`` blocks.

        Extra ``kwargs`` are forwarded to every ``TransformerBlock``. If both
        ``use_sinusoidal_emb`` and ``use_abs_pos_emb`` are set, the absolute
        embedding is the one kept in ``self.pos_emb``.
        """
        super().__init__()

        self.dim = dim
        self.depth = depth
        self.causal = causal
        self.layers = nn.ModuleList([])

        if dim_in is not None:
            self.project_in = operations.Linear(dim_in, dim, bias=False, dtype=dtype, device=device)
        else:
            self.project_in = nn.Identity()

        if dim_out is not None:
            self.project_out = operations.Linear(dim, dim_out, bias=False, dtype=dtype, device=device)
        else:
            self.project_out = nn.Identity()

        # The rotary width is half the head width, but at least 32.
        self.rotary_pos_emb = (
            RotaryEmbedding(max(dim_heads // 2, 32), device=device, dtype=dtype)
            if rotary_pos_emb
            else None
        )

        self.use_sinusoidal_emb = use_sinusoidal_emb
        if use_sinusoidal_emb:
            self.pos_emb = ScaledSinusoidalEmbedding(dim)

        self.use_abs_pos_emb = use_abs_pos_emb
        if use_abs_pos_emb:
            self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length)

        for layer_index in range(depth):
            block = TransformerBlock(
                dim,
                dim_heads = dim_heads,
                cross_attend = cross_attend,
                dim_context = cond_token_dim,
                global_cond_dim = global_cond_dim,
                causal = causal,
                zero_init_branch_outputs = zero_init_branch_outputs,
                conformer=conformer,
                layer_ix=layer_index,
                dtype=dtype,
                device=device,
                operations=operations,
                **kwargs
            )
            self.layers.append(block)

    def forward(
        self,
        x,
        mask = None,
        prepend_embeds = None,
        prepend_mask = None,
        global_cond = None,
        return_info = False,
        **kwargs
    ):
        """Run the stack on ``x``.

        ``prepend_embeds`` is concatenated in front of the projected input
        along the sequence axis. ``mask`` and ``prepend_mask`` are merged in
        that case, but the merged mask is not passed on to the layers by this
        method. Extra ``kwargs`` are forwarded to every layer.

        Returns:
            The projected output, or ``(output, info)`` when ``return_info``
            is set; ``info["hidden_states"]`` then holds each layer's output.
        """
        batch, seq = x.shape[:2]
        device = x.device

        info = {
            "hidden_states": [],
        }

        x = self.project_in(x)

        if prepend_embeds is not None:
            prepend_length, prepend_dim = prepend_embeds.shape[1:]

            assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension'

            x = torch.cat((prepend_embeds, x), dim = -2)

            if not (prepend_mask is None and mask is None):
                # Whichever mask is missing defaults to all-true.
                if mask is None:
                    mask = torch.ones((batch, seq), device = device, dtype = torch.bool)
                if prepend_mask is None:
                    prepend_mask = torch.ones((batch, prepend_length), device = device, dtype = torch.bool)

                mask = torch.cat((prepend_mask, mask), dim = -1)

        rotary_pos_emb = None
        if self.rotary_pos_emb is not None:
            rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=x.dtype, device=x.device)

        if self.use_sinusoidal_emb or self.use_abs_pos_emb:
            x = x + self.pos_emb(x)

        hidden_states = info["hidden_states"]
        for layer in self.layers:
            x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)

            if return_info:
                hidden_states.append(x)

        x = self.project_out(x)

        return (x, info) if return_info else x
| |
|
class AudioDiffusionTransformer(nn.Module):
    """Diffusion transformer wrapper around ``ContinuousTransformer``.

    It embeds the timestep and the optional conditioning signals, wraps the
    transformer with residual 1x1 convolutions on its input and output, and
    removes any prepended tokens from the result.
    """

    def __init__(self,
        io_channels=64,
        patch_size=1,
        embed_dim=1536,
        cond_token_dim=768,
        project_cond_tokens=False,
        global_cond_dim=1536,
        project_global_cond=True,
        input_concat_dim=0,
        prepend_cond_dim=0,
        depth=24,
        num_heads=24,
        transformer_type: tp.Literal["continuous_transformer"] = "continuous_transformer",
        global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend",
        audio_model="",
        dtype=None,
        device=None,
        operations=None,
        **kwargs):
        """Build the embedders, the transformer and the pre/post convolutions.

        Conditioning embedders are only created when the matching ``*_dim``
        argument is positive. ``audio_model`` is accepted but not used. Extra
        ``kwargs`` are forwarded to ``ContinuousTransformer``.

        Raises:
            ValueError: if ``transformer_type`` is not "continuous_transformer".
        """
        super().__init__()

        self.dtype = dtype
        self.cond_token_dim = cond_token_dim

        # Timestep embeddings
        timestep_features_dim = 256

        self.timestep_features = FourierFeatures(1, timestep_features_dim, dtype=dtype, device=device)

        self.to_timestep_embed = nn.Sequential(
            operations.Linear(timestep_features_dim, embed_dim, bias=True, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device),
        )

        if cond_token_dim > 0:
            # Cross-attention conditioning tokens
            cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim
            self.to_cond_embed = nn.Sequential(
                operations.Linear(cond_token_dim, cond_embed_dim, bias=False, dtype=dtype, device=device),
                nn.SiLU(),
                operations.Linear(cond_embed_dim, cond_embed_dim, bias=False, dtype=dtype, device=device)
            )
        else:
            cond_embed_dim = 0

        if global_cond_dim > 0:
            # Global conditioning
            global_embed_dim = global_cond_dim if not project_global_cond else embed_dim
            self.to_global_embed = nn.Sequential(
                operations.Linear(global_cond_dim, global_embed_dim, bias=False, dtype=dtype, device=device),
                nn.SiLU(),
                operations.Linear(global_embed_dim, global_embed_dim, bias=False, dtype=dtype, device=device)
            )

        if prepend_cond_dim > 0:
            # Prepend conditioning
            self.to_prepend_embed = nn.Sequential(
                operations.Linear(prepend_cond_dim, embed_dim, bias=False, dtype=dtype, device=device),
                nn.SiLU(),
                operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
            )

        self.input_concat_dim = input_concat_dim

        dim_in = io_channels + self.input_concat_dim

        self.patch_size = patch_size

        self.transformer_type = transformer_type

        self.global_cond_type = global_cond_type

        if self.transformer_type == "continuous_transformer":

            global_dim = None

            if self.global_cond_type == "adaLN":
                # The global conditioning is projected to embed_dim already.
                global_dim = embed_dim

            self.transformer = ContinuousTransformer(
                dim=embed_dim,
                depth=depth,
                dim_heads=embed_dim // num_heads,
                dim_in=dim_in * patch_size,
                dim_out=io_channels * patch_size,
                cross_attend = cond_token_dim > 0,
                cond_token_dim = cond_embed_dim,
                global_cond_dim=global_dim,
                dtype=dtype,
                device=device,
                operations=operations,
                **kwargs
            )
        else:
            raise ValueError(f"Unknown transformer type: {self.transformer_type}")

        self.preprocess_conv = operations.Conv1d(dim_in, dim_in, 1, bias=False, dtype=dtype, device=device)
        self.postprocess_conv = operations.Conv1d(io_channels, io_channels, 1, bias=False, dtype=dtype, device=device)

    def _forward(
        self,
        x,
        t,
        mask=None,
        cross_attn_cond=None,
        cross_attn_cond_mask=None,
        input_concat_cond=None,
        global_embed=None,
        prepend_cond=None,
        prepend_cond_mask=None,
        return_info=False,
        **kwargs):
        """Run the model on ``x`` (``b c t`` layout) at timestep ``t``.

        Args:
            x: input, channels-first per the rearrange patterns below.
            t: timestep values; indexed as ``t[:, None]`` (presumably one
                value per batch element - confirm against the caller).
            mask: forwarded to the transformer.
            cross_attn_cond, cross_attn_cond_mask: cross-attention context and
                its mask.
            input_concat_cond: extra channels concatenated to ``x``; resampled
                with nearest interpolation if its length differs.
            global_embed: global conditioning; the timestep embedding is added
                to it (or used alone when it is ``None``).
            prepend_cond, prepend_cond_mask: tokens prepended to the sequence
                and their optional mask.
            return_info: also return the transformer's ``info`` dict.

        Returns:
            The output in ``b c t`` layout with prepended tokens removed, or
            ``(output, info)`` when ``return_info`` is set.
        """
        if cross_attn_cond is not None:
            cross_attn_cond = self.to_cond_embed(cross_attn_cond)

        if global_embed is not None:
            # Project the global conditioning to the embedding dimension.
            global_embed = self.to_global_embed(global_embed)

        prepend_inputs = None
        prepend_mask = None
        prepend_length = 0
        if prepend_cond is not None:
            # Project the prepend conditioning to the embedding dimension.
            prepend_cond = self.to_prepend_embed(prepend_cond)

            prepend_inputs = prepend_cond
            if prepend_cond_mask is not None:
                prepend_mask = prepend_cond_mask

        if input_concat_cond is not None:

            # Interpolate input_concat_cond to the same length as x.
            if input_concat_cond.shape[2] != x.shape[2]:
                input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')

            x = torch.cat([x, input_concat_cond], dim=1)

        # Timestep embedding, shaped (b, embed_dim).
        timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None]).to(x.dtype))

        # The timestep embedding is folded into the global conditioning.
        if global_embed is not None:
            global_embed = global_embed + timestep_embed
        else:
            global_embed = timestep_embed

        # Add the global embedding to the prepend inputs if configured so.
        if self.global_cond_type == "prepend":
            global_token = global_embed.unsqueeze(1)
            global_token_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)

            if prepend_inputs is None:
                # Prepend inputs are just the global embed; the mask is all ones.
                prepend_inputs = global_token
                prepend_mask = global_token_mask
            else:
                # prepend_cond was given without a mask: treat every prepended
                # token as valid. The original concatenated ``None`` here and
                # raised TypeError.
                if prepend_mask is None:
                    prepend_mask = torch.ones(prepend_inputs.shape[:2], device=x.device, dtype=torch.bool)

                # Prepend inputs are the prepend conditioning + the global embed.
                prepend_inputs = torch.cat([prepend_inputs, global_token], dim=1)
                prepend_mask = torch.cat([prepend_mask, global_token_mask], dim=1)

        # Always record how many tokens were prepended so they are stripped
        # from the output. The original only did so for the "prepend"
        # conditioning type, leaving prepend_cond tokens in the output under
        # "adaLN".
        if prepend_inputs is not None:
            prepend_length = prepend_inputs.shape[1]

        x = self.preprocess_conv(x) + x

        x = rearrange(x, "b c t -> b t c")

        extra_args = {}

        if self.global_cond_type == "adaLN":
            extra_args["global_cond"] = global_embed

        if self.patch_size > 1:
            x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size)

        # Only "continuous_transformer" passes the constructor check; the
        # other branches are kept as in the original.
        if self.transformer_type == "x-transformers":
            output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, **extra_args, **kwargs)
        elif self.transformer_type == "continuous_transformer":
            output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, return_info=return_info, **extra_args, **kwargs)

            if return_info:
                output, info = output
        elif self.transformer_type == "mm_transformer":
            output = self.transformer(x, context=cross_attn_cond, mask=mask, context_mask=cross_attn_cond_mask, **extra_args, **kwargs)

        # Back to channels-first and drop the prepended tokens.
        output = rearrange(output, "b t c -> b c t")[:,:,prepend_length:]

        if self.patch_size > 1:
            output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size)

        output = self.postprocess_conv(output) + output

        if return_info:
            return output, info

        return output

    def forward(
        self,
        x,
        timestep,
        context=None,
        context_mask=None,
        input_concat_cond=None,
        global_embed=None,
        negative_global_embed=None,
        prepend_cond=None,
        prepend_cond_mask=None,
        mask=None,
        return_info=False,
        control=None,
        transformer_options={},
        **kwargs):
        """Public entry point; maps its arguments onto ``_forward``.

        ``context`` / ``context_mask`` become the cross-attention
        conditioning. ``negative_global_embed``, ``control`` and
        ``transformer_options`` are accepted but not forwarded.
        """
        return self._forward(
            x,
            timestep,
            cross_attn_cond=context,
            cross_attn_cond_mask=context_mask,
            input_concat_cond=input_concat_cond,
            global_embed=global_embed,
            prepend_cond=prepend_cond,
            prepend_cond_mask=prepend_cond_mask,
            mask=mask,
            return_info=return_info,
            **kwargs
        )
| |
|