# Uploaded by Seemanth: "Upload Chiluka TTS model" (commit f28049f, verified).
"""Diffusion transformer modules."""
from math import log, pi
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange
from torch import Tensor, einsum
from .utils import exists, default, rand_bool
class AdaLayerNorm(nn.Module):
    """Layer normalisation whose scale and shift are predicted from a style vector.

    A linear layer maps the style vector ``s`` (2-D, ``(batch, style_dim)``) to
    ``2 * channels`` values that are split into ``gamma`` and ``beta``; the
    normalised input is then modulated as ``(1 + gamma) * x_norm + beta``.

    Args:
        style_dim: Size of the last dimension of ``s``.
        channels: Size of the dimension of ``x`` that is normalised.
        eps: Numerical-stability constant passed to ``F.layer_norm``.
    """

    def __init__(self, style_dim, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps
        self.fc = nn.Linear(style_dim, channels * 2)

    def forward(self, x, s):
        # Bring the channel axis last for layer_norm; for a 3-D input the two
        # transposes cancel out and x keeps its (batch, seq, channels) layout.
        swapped = x.transpose(-1, -2).transpose(1, -1)

        params = self.fc(s)
        params = params.view(params.size(0), params.size(1), 1)
        # Each half becomes (batch, 1, channels) so it broadcasts over the sequence.
        gamma, beta = (p.transpose(1, -1) for p in torch.chunk(params, chunks=2, dim=1))

        normed = F.layer_norm(swapped, (self.channels,), eps=self.eps)
        modulated = (1 + gamma) * normed + beta
        # Reverse the axis swaps performed above.
        return modulated.transpose(1, -1).transpose(-1, -2)
class LearnedPositionalEmbedding(nn.Module):
    """Embed a 1-D batch of scalars with sinusoidal features at learned frequencies.

    Each scalar ``t`` becomes ``[t, sin(2*pi*t*w), cos(2*pi*t*w)]``, where ``w``
    is a learned vector of ``dim // 2`` frequencies, so the output has
    ``dim + 1`` features per batch element.

    Args:
        dim: Number of sinusoidal features; must be even.
    """

    def __init__(self, dim: int):
        super().__init__()
        assert (dim % 2) == 0
        self.weights = nn.Parameter(torch.randn(dim // 2))

    def forward(self, x: Tensor) -> Tensor:
        column = rearrange(x, "b -> b 1")
        angles = column * rearrange(self.weights, "d -> 1 d") * 2 * pi
        # Raw input first, then the sine block, then the cosine block.
        return torch.cat((column, angles.sin(), angles.cos()), dim=-1)
def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
    """Build a learned sinusoidal embedding followed by a linear projection.

    ``LearnedPositionalEmbedding(dim)`` emits ``dim + 1`` features (the raw input
    plus ``dim`` sinusoidal features), which are projected to ``out_features``.
    """
    fourier = LearnedPositionalEmbedding(dim)
    projection = nn.Linear(in_features=dim + 1, out_features=out_features)
    return nn.Sequential(fourier, projection)
class FixedEmbedding(nn.Module):
    """Learned absolute-position table, tiled across the batch.

    ``forward`` reads only the first two dimensions of ``x`` (batch size and
    sequence length) and its device; the values of ``x`` are ignored.

    Args:
        max_length: Number of positions in the table (longest accepted sequence).
        features: Embedding size per position.
    """

    def __init__(self, max_length: int, features: int):
        super().__init__()
        self.max_length = max_length
        self.embedding = nn.Embedding(max_length, features)

    def forward(self, x: Tensor) -> Tensor:
        batch_size, length = x.shape[:2]
        assert length <= self.max_length, "Input sequence length must be <= max_length"
        positions = torch.arange(length, device=x.device)
        # One (length, features) table copied for every batch element.
        return repeat(self.embedding(positions), "n d -> b n d", b=batch_size)
class RelativePositionBias(nn.Module):
    """Learned per-head attention bias indexed by bucketed relative position.

    Args:
        num_buckets: Total bucket count; one half is used for offsets
            ``key - query >= 0`` and the other half for negative offsets.
        max_distance: Offset magnitude at which the logarithmic buckets reach
            the last bucket of their half.
        num_heads: Number of attention heads (one bias value per head per bucket).
    """

    def __init__(self, num_buckets: int, max_distance: int, num_heads: int):
        super().__init__()
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.num_heads = num_heads
        self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)

    @staticmethod
    def _relative_position_bucket(relative_position: Tensor, num_buckets: int, max_distance: int):
        """Map integer relative positions to bucket indices in ``[0, num_buckets)``.

        Magnitudes below ``(num_buckets // 2) // 2`` get one bucket each; larger
        magnitudes are spaced logarithmically and clamped to the last bucket of
        their half.
        """
        half = num_buckets // 2
        exact = half // 2
        distance = torch.abs(relative_position)
        # Non-negative offsets occupy the upper half of the bucket range.
        offset = (relative_position >= 0).long() * half
        # Log-spaced buckets. Entries with distance < exact (including 0, where
        # log yields -inf) are discarded by the where() below.
        log_ratio = torch.log(distance.float() / exact) / log(max_distance / exact)
        far_bucket = exact + (log_ratio * (half - exact)).long()
        far_bucket = far_bucket.clamp(max=half - 1)
        return offset + torch.where(distance < exact, distance, far_bucket)

    def forward(self, num_queries: int, num_keys: int) -> Tensor:
        """Return the bias shaped ``(1, num_heads, num_queries, num_keys)``.

        Query ``i`` is placed at key position ``num_keys - num_queries + i``,
        i.e. the queries are aligned with the last ``num_queries`` keys.
        """
        device = self.relative_attention_bias.weight.device
        query_positions = torch.arange(num_keys - num_queries, num_keys, dtype=torch.long, device=device)
        key_positions = torch.arange(num_keys, dtype=torch.long, device=device)
        relative = rearrange(key_positions, "j -> 1 j") - rearrange(query_positions, "i -> i 1")
        buckets = self._relative_position_bucket(
            relative, num_buckets=self.num_buckets, max_distance=self.max_distance
        )
        return rearrange(self.relative_attention_bias(buckets), "m n h -> 1 h m n")
def FeedForward(features: int, multiplier: int) -> nn.Module:
    """Position-wise MLP: expand by ``multiplier``, apply GELU, project back.

    Returns:
        An ``nn.Sequential`` mapping ``(..., features)`` to ``(..., features)``.
    """
    hidden = features * multiplier
    layers = [
        nn.Linear(in_features=features, out_features=hidden),
        nn.GELU(),
        nn.Linear(in_features=hidden, out_features=features),
    ]
    return nn.Sequential(*layers)
class AttentionBase(nn.Module):
    """Multi-head dot-product attention over already-projected q, k and v.

    ``q``, ``k`` and ``v`` have ``num_heads * head_features`` features in their
    last dimension; heads are split, attended independently, merged back and
    passed through ``to_out``.

    Args:
        features: Output size of ``to_out`` when ``out_features`` is None.
        head_features: Size of each head; logits are scaled by its inverse sqrt.
        num_heads: Number of heads.
        use_rel_pos: If True, add a ``RelativePositionBias`` to the logits.
        out_features: Explicit output size of ``to_out``.
        rel_pos_num_buckets: Bucket count for the relative bias (required if ``use_rel_pos``).
        rel_pos_max_distance: Max distance for the relative bias (required if ``use_rel_pos``).
    """

    def __init__(self, features: int, *, head_features: int, num_heads: int, use_rel_pos: bool,
                 out_features: Optional[int] = None, rel_pos_num_buckets: Optional[int] = None,
                 rel_pos_max_distance: Optional[int] = None):
        super().__init__()
        self.scale = head_features ** -0.5
        self.num_heads = num_heads
        self.use_rel_pos = use_rel_pos
        if use_rel_pos:
            assert exists(rel_pos_num_buckets) and exists(rel_pos_max_distance)
            self.rel_pos = RelativePositionBias(
                num_buckets=rel_pos_num_buckets,
                max_distance=rel_pos_max_distance,
                num_heads=num_heads,
            )
        self.to_out = nn.Linear(
            in_features=head_features * num_heads,
            out_features=features if out_features is None else out_features,
        )

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        def split_heads(t: Tensor) -> Tensor:
            return rearrange(t, "b n (h d) -> b h n d", h=self.num_heads)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)
        logits = einsum("b h n d, b h m d -> b h n m", q, k)
        if self.use_rel_pos:
            # NOTE: the bias is added before scaling, so it is scaled as well.
            logits = logits + self.rel_pos(*logits.shape[-2:])
        weights = (logits * self.scale).softmax(dim=-1)
        attended = einsum("b h n m, b h m d -> b h n d", weights, v)
        return self.to_out(rearrange(attended, "b h n d -> b n (h d)"))
class StyleAttention(nn.Module):
    """Attention whose query input and context are normalised by ``AdaLayerNorm``.

    Both normalisations are conditioned on the same style vector ``s``. When
    ``context`` is omitted in ``forward``, ``x`` is used as the context
    (assumes ``default`` returns its first argument unless it is None).

    Args:
        features: Size of the query input ``x``.
        style_dim: Size of the style vector ``s``.
        head_features: Size of each attention head.
        num_heads: Number of attention heads.
        context_features: Size of ``context``; ``features`` is used when None.
        use_rel_pos: Whether to add a relative-position bias to the logits.
        rel_pos_num_buckets: Bucket count for the relative bias.
        rel_pos_max_distance: Max distance for the relative bias.
    """

    def __init__(self, features: int, *, style_dim: int, head_features: int, num_heads: int,
                 context_features: Optional[int] = None, use_rel_pos: bool,
                 rel_pos_num_buckets: Optional[int] = None, rel_pos_max_distance: Optional[int] = None):
        super().__init__()
        # Recorded before defaulting, so this is None when no context size was given.
        self.context_features = context_features
        kv_in = default(context_features, features)
        inner = head_features * num_heads
        self.norm = AdaLayerNorm(style_dim, features)
        self.norm_context = AdaLayerNorm(style_dim, kv_in)
        self.to_q = nn.Linear(in_features=features, out_features=inner, bias=False)
        self.to_kv = nn.Linear(in_features=kv_in, out_features=inner * 2, bias=False)
        self.attention = AttentionBase(
            features,
            num_heads=num_heads,
            head_features=head_features,
            use_rel_pos=use_rel_pos,
            rel_pos_num_buckets=rel_pos_num_buckets,
            rel_pos_max_distance=rel_pos_max_distance,
        )

    def forward(self, x: Tensor, s: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
        source = self.norm_context(default(context, x), s)
        queries = self.to_q(self.norm(x, s))
        # A single projection produces keys and values side by side.
        keys, values = self.to_kv(source).chunk(2, dim=-1)
        return self.attention(queries, keys, values)
class Attention(nn.Module):
    """Attention with ``nn.LayerNorm`` applied to the query input and to the context.

    When ``context`` is omitted in ``forward``, ``x`` is used as the context
    (assumes ``default`` returns its first argument unless it is None).

    Args:
        features: Size of the query input ``x``.
        head_features: Size of each attention head.
        num_heads: Number of attention heads.
        out_features: Output size, forwarded to ``AttentionBase``.
        context_features: Size of ``context``; ``features`` is used when None.
        use_rel_pos: Whether to add a relative-position bias to the logits.
        rel_pos_num_buckets: Bucket count for the relative bias.
        rel_pos_max_distance: Max distance for the relative bias.
    """

    def __init__(self, features: int, *, head_features: int, num_heads: int, out_features: Optional[int] = None,
                 context_features: Optional[int] = None, use_rel_pos: bool,
                 rel_pos_num_buckets: Optional[int] = None, rel_pos_max_distance: Optional[int] = None):
        super().__init__()
        # Recorded before defaulting, so this is None when no context size was given.
        self.context_features = context_features
        kv_in = default(context_features, features)
        inner = head_features * num_heads
        self.norm = nn.LayerNorm(features)
        self.norm_context = nn.LayerNorm(kv_in)
        self.to_q = nn.Linear(in_features=features, out_features=inner, bias=False)
        self.to_kv = nn.Linear(in_features=kv_in, out_features=inner * 2, bias=False)
        self.attention = AttentionBase(
            features,
            out_features=out_features,
            num_heads=num_heads,
            head_features=head_features,
            use_rel_pos=use_rel_pos,
            rel_pos_num_buckets=rel_pos_num_buckets,
            rel_pos_max_distance=rel_pos_max_distance,
        )

    def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
        source = self.norm_context(default(context, x))
        queries = self.to_q(self.norm(x))
        # A single projection produces keys and values side by side.
        keys, values = self.to_kv(source).chunk(2, dim=-1)
        return self.attention(queries, keys, values)
class StyleTransformerBlock(nn.Module):
    """Transformer block built from style-conditioned attention.

    Applies residual self-attention, then residual cross-attention when
    ``context_features`` is a positive int, then a residual feed-forward MLP.
    Both attention layers receive the style vector ``s``; the MLP does not.
    """

    def __init__(self, features: int, num_heads: int, head_features: int, style_dim: int, multiplier: int,
                 use_rel_pos: bool, rel_pos_num_buckets: Optional[int] = None,
                 rel_pos_max_distance: Optional[int] = None, context_features: Optional[int] = None):
        super().__init__()
        self.use_cross_attention = exists(context_features) and context_features > 0
        shared = dict(
            features=features,
            style_dim=style_dim,
            num_heads=num_heads,
            head_features=head_features,
            use_rel_pos=use_rel_pos,
            rel_pos_num_buckets=rel_pos_num_buckets,
            rel_pos_max_distance=rel_pos_max_distance,
        )
        self.attention = StyleAttention(**shared)
        if self.use_cross_attention:
            self.cross_attention = StyleAttention(context_features=context_features, **shared)
        self.feed_forward = FeedForward(features=features, multiplier=multiplier)

    def forward(self, x: Tensor, s: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
        x = x + self.attention(x, s)
        if self.use_cross_attention:
            x = x + self.cross_attention(x, s, context=context)
        return x + self.feed_forward(x)
class TransformerBlock(nn.Module):
    """Transformer block: residual self-attention, optional cross-attention, residual MLP.

    Cross-attention is built and used only when ``context_features`` is a
    positive int.
    """

    def __init__(self, features: int, num_heads: int, head_features: int, multiplier: int, use_rel_pos: bool,
                 rel_pos_num_buckets: Optional[int] = None, rel_pos_max_distance: Optional[int] = None,
                 context_features: Optional[int] = None):
        super().__init__()
        self.use_cross_attention = exists(context_features) and context_features > 0
        shared = dict(
            features=features,
            num_heads=num_heads,
            head_features=head_features,
            use_rel_pos=use_rel_pos,
            rel_pos_num_buckets=rel_pos_num_buckets,
            rel_pos_max_distance=rel_pos_max_distance,
        )
        self.attention = Attention(**shared)
        if self.use_cross_attention:
            self.cross_attention = Attention(context_features=context_features, **shared)
        self.feed_forward = FeedForward(features=features, multiplier=multiplier)

    def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
        x = x + self.attention(x)
        if self.use_cross_attention:
            x = x + self.cross_attention(x, context=context)
        return x + self.feed_forward(x)
class StyleTransformer1d(nn.Module):
    """Stack of ``StyleTransformerBlock``s over a conditioning embedding sequence.

    ``x`` is broadcast along the embedding length and concatenated with
    ``embedding``; a mapping vector built from ``time`` and ``features`` is added
    before every block; blocks are conditioned on ``features`` as their style
    vector; the result is mean-pooled over the sequence and projected back to
    ``channels`` with a 1x1 convolution, giving ``(batch, 1, channels)``.

    Args:
        num_layers: Number of blocks.
        channels: Feature size of ``x`` and of the output.
        num_heads: Attention heads per block.
        head_features: Size of each attention head.
        multiplier: Hidden-size multiplier of each block's MLP.
        use_context_time: If True, embed ``time`` into the mapping vector.
        use_rel_pos: Whether blocks add a relative-position bias.
        context_features_multiplier: Accepted for interface compatibility; unused.
        rel_pos_num_buckets: Bucket count for the relative bias.
        rel_pos_max_distance: Max distance for the relative bias.
        context_features: Size of ``features``; also the blocks' style size, so
            it must be provided for the blocks to be constructible.
        context_embedding_features: Feature size of ``embedding``; required
            (an int) despite the ``None`` default.
        embedding_max_length: Longest ``embedding`` supported by the fixed
            embedding used for masking and guidance.
    """

    def __init__(self, num_layers: int, channels: int, num_heads: int, head_features: int, multiplier: int,
                 use_context_time: bool = True, use_rel_pos: bool = False, context_features_multiplier: int = 1,
                 rel_pos_num_buckets: Optional[int] = None, rel_pos_max_distance: Optional[int] = None,
                 context_features: Optional[int] = None, context_embedding_features: Optional[int] = None,
                 embedding_max_length: int = 512):
        super().__init__()
        # Width of the concatenated [x, embedding] sequence seen by every block.
        width = channels + context_embedding_features
        self.blocks = nn.ModuleList([
            StyleTransformerBlock(features=width, head_features=head_features, num_heads=num_heads,
                                  multiplier=multiplier, style_dim=context_features, use_rel_pos=use_rel_pos,
                                  rel_pos_num_buckets=rel_pos_num_buckets, rel_pos_max_distance=rel_pos_max_distance)
            for _ in range(num_layers)
        ])
        self.to_out = nn.Sequential(
            Rearrange("b t c -> b c t"),
            nn.Conv1d(in_channels=width, out_channels=channels, kernel_size=1),
        )
        use_context_features = exists(context_features)
        self.use_context_features = use_context_features
        self.use_context_time = use_context_time
        if use_context_time or use_context_features:
            context_mapping_features = width
            self.to_mapping = nn.Sequential(nn.Linear(context_mapping_features, context_mapping_features), nn.GELU(),
                                            nn.Linear(context_mapping_features, context_mapping_features), nn.GELU())
        if use_context_time:
            self.to_time = nn.Sequential(TimePositionalEmbedding(dim=channels, out_features=context_mapping_features),
                                         nn.GELU())
        if use_context_features:
            self.to_features = nn.Sequential(nn.Linear(in_features=context_features,
                                                       out_features=context_mapping_features), nn.GELU())
        self.fixed_embedding = FixedEmbedding(max_length=embedding_max_length, features=context_embedding_features)

    def get_mapping(self, time: Optional[Tensor] = None, features: Optional[Tensor] = None) -> Optional[Tensor]:
        """Sum the enabled time/feature projections and pass them through ``to_mapping``.

        Returns None when neither time nor feature conditioning is enabled.
        """
        items, mapping = [], None
        if self.use_context_time:
            items += [self.to_time(time)]
        if self.use_context_features:
            items += [self.to_features(features)]
        if self.use_context_time or self.use_context_features:
            mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
            mapping = self.to_mapping(mapping)
        return mapping

    def run(self, x, time, embedding, features):
        """Run one conditioned pass; returns ``(batch, 1, channels)``.

        ``x`` is broadcast with ``expand`` along dim 1, so it must be 3-D with
        size 1 (or the embedding length) in that dimension.
        """
        mapping = self.get_mapping(time, features)
        length = embedding.size(1)
        x = torch.cat([x.expand(-1, length, -1), embedding], dim=-1)
        if mapping is not None:
            mapping = mapping.unsqueeze(1).expand(-1, length, -1)
        for block in self.blocks:
            # The same mapping vector is re-injected before every block.
            if mapping is not None:
                x = x + mapping
            x = block(x, features)
        x = x.mean(dim=1).unsqueeze(1)
        x = self.to_out(x)
        return x.transpose(-1, -2)

    def forward(self, x: Tensor, time: Tensor, embedding_mask_proba: float = 0.0, embedding: Optional[Tensor] = None,
                features: Optional[Tensor] = None, embedding_scale: float = 1.0) -> Tensor:
        """Run the model, optionally masking and/or guiding with the fixed embedding.

        Args:
            embedding_mask_proba: If > 0, batch elements selected by ``rand_bool``
                (presumably with this probability) use the fixed embedding instead.
            embedding: Conditioning sequence; required.
            embedding_scale: If != 1, the output is extrapolated from the
                fixed-embedding output towards the conditioned output
                (classifier-free-guidance style).

        Raises:
            ValueError: If ``embedding`` is None.
        """
        if embedding is None:
            raise ValueError("StyleTransformer1d.forward requires an `embedding` tensor")
        b, device = embedding.shape[0], embedding.device
        fixed_embedding = self.fixed_embedding(embedding)
        if embedding_mask_proba > 0.0:
            batch_mask = rand_bool(shape=(b, 1, 1), proba=embedding_mask_proba, device=device)
            embedding = torch.where(batch_mask, fixed_embedding, embedding)
        out = self.run(x, time, embedding=embedding, features=features)
        if embedding_scale == 1.0:
            return out
        out_masked = self.run(x, time, embedding=fixed_embedding, features=features)
        return out_masked + (out - out_masked) * embedding_scale
class Transformer1d(nn.Module):
    """Stack of ``TransformerBlock``s over a conditioning embedding sequence.

    ``x`` is broadcast along the embedding length and concatenated with
    ``embedding``; when time and/or feature conditioning is enabled, a mapping
    vector built from them is added before every block; the result is
    mean-pooled over the sequence and projected back to ``channels`` with a
    1x1 convolution, giving ``(batch, 1, channels)``.

    Args:
        num_layers: Number of blocks.
        channels: Feature size of ``x`` and of the output.
        num_heads: Attention heads per block.
        head_features: Size of each attention head.
        multiplier: Hidden-size multiplier of each block's MLP.
        use_context_time: If True, embed ``time`` into the mapping vector.
        use_rel_pos: Whether blocks add a relative-position bias.
        context_features_multiplier: Accepted for interface compatibility; unused.
        rel_pos_num_buckets: Bucket count for the relative bias.
        rel_pos_max_distance: Max distance for the relative bias.
        context_features: Size of ``features``; feature conditioning is enabled
            only when this is given.
        context_embedding_features: Feature size of ``embedding``; required
            (an int) despite the ``None`` default.
        embedding_max_length: Longest ``embedding`` supported by the fixed
            embedding used for masking and guidance.
    """

    def __init__(self, num_layers: int, channels: int, num_heads: int, head_features: int, multiplier: int,
                 use_context_time: bool = True, use_rel_pos: bool = False, context_features_multiplier: int = 1,
                 rel_pos_num_buckets: Optional[int] = None, rel_pos_max_distance: Optional[int] = None,
                 context_features: Optional[int] = None, context_embedding_features: Optional[int] = None,
                 embedding_max_length: int = 512):
        super().__init__()
        # Width of the concatenated [x, embedding] sequence seen by every block.
        width = channels + context_embedding_features
        self.blocks = nn.ModuleList([
            TransformerBlock(features=width, head_features=head_features, num_heads=num_heads,
                             multiplier=multiplier, use_rel_pos=use_rel_pos, rel_pos_num_buckets=rel_pos_num_buckets,
                             rel_pos_max_distance=rel_pos_max_distance)
            for _ in range(num_layers)
        ])
        self.to_out = nn.Sequential(
            Rearrange("b t c -> b c t"),
            nn.Conv1d(in_channels=width, out_channels=channels, kernel_size=1),
        )
        use_context_features = exists(context_features)
        self.use_context_features = use_context_features
        self.use_context_time = use_context_time
        if use_context_time or use_context_features:
            context_mapping_features = width
            self.to_mapping = nn.Sequential(nn.Linear(context_mapping_features, context_mapping_features), nn.GELU(),
                                            nn.Linear(context_mapping_features, context_mapping_features), nn.GELU())
        if use_context_time:
            self.to_time = nn.Sequential(TimePositionalEmbedding(dim=channels, out_features=context_mapping_features),
                                         nn.GELU())
        if use_context_features:
            self.to_features = nn.Sequential(nn.Linear(in_features=context_features,
                                                       out_features=context_mapping_features), nn.GELU())
        self.fixed_embedding = FixedEmbedding(max_length=embedding_max_length, features=context_embedding_features)

    def get_mapping(self, time: Optional[Tensor] = None, features: Optional[Tensor] = None) -> Optional[Tensor]:
        """Sum the enabled time/feature projections and pass them through ``to_mapping``.

        Returns None when neither time nor feature conditioning is enabled.
        """
        items, mapping = [], None
        if self.use_context_time:
            items += [self.to_time(time)]
        if self.use_context_features:
            items += [self.to_features(features)]
        if self.use_context_time or self.use_context_features:
            mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
            mapping = self.to_mapping(mapping)
        return mapping

    def run(self, x, time, embedding, features):
        """Run one conditioned pass; returns ``(batch, 1, channels)``.

        ``x`` is broadcast with ``expand`` along dim 1, so it must be 3-D with
        size 1 (or the embedding length) in that dimension. When no mapping is
        configured (``get_mapping`` returns None), blocks run unconditioned.
        """
        mapping = self.get_mapping(time, features)
        length = embedding.size(1)
        x = torch.cat([x.expand(-1, length, -1), embedding], dim=-1)
        if mapping is not None:
            mapping = mapping.unsqueeze(1).expand(-1, length, -1)
        for block in self.blocks:
            # The same mapping vector is re-injected before every block.
            if mapping is not None:
                x = x + mapping
            x = block(x)
        x = x.mean(dim=1).unsqueeze(1)
        x = self.to_out(x)
        return x.transpose(-1, -2)

    def forward(self, x: Tensor, time: Tensor, embedding_mask_proba: float = 0.0, embedding: Optional[Tensor] = None,
                features: Optional[Tensor] = None, embedding_scale: float = 1.0) -> Tensor:
        """Run the model, optionally masking and/or guiding with the fixed embedding.

        Args:
            embedding_mask_proba: If > 0, batch elements selected by ``rand_bool``
                (presumably with this probability) use the fixed embedding instead.
            embedding: Conditioning sequence; required.
            embedding_scale: If != 1, the output is extrapolated from the
                fixed-embedding output towards the conditioned output
                (classifier-free-guidance style).

        Raises:
            ValueError: If ``embedding`` is None.
        """
        if embedding is None:
            raise ValueError("Transformer1d.forward requires an `embedding` tensor")
        b, device = embedding.shape[0], embedding.device
        fixed_embedding = self.fixed_embedding(embedding)
        if embedding_mask_proba > 0.0:
            batch_mask = rand_bool(shape=(b, 1, 1), proba=embedding_mask_proba, device=device)
            embedding = torch.where(batch_mask, fixed_embedding, embedding)
        out = self.run(x, time, embedding=embedding, features=features)
        if embedding_scale == 1.0:
            return out
        out_masked = self.run(x, time, embedding=fixed_embedding, features=features)
        return out_masked + (out - out_masked) * embedding_scale