fx pos sinus

- README.md +4 -0
- audiocraft/transformer.py +50 -67
README.md CHANGED

@@ -67,6 +67,10 @@ CUDA_DEVICE_ORDER=PCI_BUS_ID HF_HOME=/data/dkounadis/.hf7/ CUDA_VISIBLE_DEVICES=
 
 Following examples need `api.py` to be running. [Set this IP](https://huggingface.co/dkounadis/artificial-styletts2/blob/main/tts.py#L85) to the IP shown when starting `api.py`.
 
+### Foreign Lang TTS
+
+This will produce the following [video](https://www.youtube.com/watch?v=UeJEAsKxRZU)
+
 ```
 python tts.py --text assets/ocr.txt --image assets/ocr.jpg --soundscape "battle hero" --voice romanian
 ```
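The note above implies a two-shell sequence for every example; a minimal sketch (it assumes the IP edit in `tts.py` linked above has already been made):

```
python api.py     # shell 1: keep running, note the IP it prints
python tts.py --text assets/ocr.txt --image assets/ocr.jpg --soundscape "battle hero" --voice romanian     # shell 2
```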
audiocraft/transformer.py CHANGED

@@ -3,44 +3,30 @@ import torch.nn as nn
 from torch.nn import functional as F
 from einops import rearrange
 
-def create_sin_embedding(positions,
-                         dim,
-                         max_period = 10000,
-                         dtype = torch.float32):
-    """Create sinusoidal positional embedding, with shape `[B, T, C]`.
-
-    Args:
-        positions (torch.Tensor): LongTensor of positions.
-        dim (int): Dimension of the embedding.
-        max_period (float): Maximum period of the cosine/sine functions.
-        dtype (torch.dtype or str): dtype to use to generate the embedding.
-    Returns:
-        torch.Tensor: Sinusoidal positional embedding.
-    """
-    # We aim for BTC format
+def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000):
     assert dim % 2 == 0
     half_dim = dim // 2
-    positions = positions.to(dtype)
-    adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
-    max_period_tensor = torch.full([], max_period, device=positions.device, dtype=dtype)
+    positions = positions.to(torch.float)
+    adim = torch.arange(half_dim, device=positions.device, dtype=torch.float).view(1, 1, -1)
+    max_period_tensor = torch.full([], max_period, device=positions.device, dtype=torch.float)  # avoid sync point
     phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
-    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
+    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)  # OFFICIAL is torch.float32, HOWEVER self_attn.in_proj_weight = torch.float16
 
 
 class StreamingMultiheadAttention(nn.Module):
 
     def __init__(self,
                  embed_dim,
                  num_heads,
                  cross_attention = False,
                  ):
         super().__init__()
         self.cross_attention = cross_attention
         self.embed_dim = embed_dim
         self.k_history = None  # previous k from the tokens already seen in the current generation - only for self-attn
         self.v_history = None  # clean up IN LM after finishing GENERATION - each of the 1...47 MHA layers has a different kv history
         self.num_heads = num_heads
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
         self.register_buffer('in_proj_weight', torch.ones((3 * embed_dim, embed_dim),
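For orientation, a minimal shape check of the rewritten `create_sin_embedding`; the batch size and `dim=1536` are illustrative (they mirror the `StreamingTransformer` defaults further down), not a test from the repo:

```python
import torch
from audiocraft.transformer import create_sin_embedding

# positions of shape [B, T, 1] broadcast against the half_dim frequency axis -> output [B, T, dim]
positions = torch.zeros(2, 1, 1) + 7   # index of the token currently being generated, per batch item
pos_emb = create_sin_embedding(positions, dim=1536, max_period=10000)
print(pos_emb.shape)                   # torch.Size([2, 1, 1536])
```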
@@ -52,58 +38,56 @@ class StreamingMultiheadAttention(nn.Module):
                 value=None):
         layout = "b h t d"
         if self.cross_attention:
             # Different queries, keys, values: we have to split the in_proj_weight manually
             dim = self.in_proj_weight.shape[0] // 3
             q = nn.functional.linear(query, self.in_proj_weight[:dim])
             k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim])
             v = nn.functional.linear(value, self.in_proj_weight[2 * dim:])
             q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
             # print(q.shape, k.shape, v.shape, q.sum(), k.sum(), v.sum(), 'CROSS A5')
         else:
             # The 1st projection makes k, v (instantaneous).
             # This else branch is self-attention of audio with itself (above is cross-attention with txt).
             # HISTORY - DIFFERENT FOR EACH TRANSF LAYER
             projected = nn.functional.linear(query, self.in_proj_weight)
 
             bound_layout = "b h p t d"
             packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
             q, k, v = packed.unbind(dim=2)
 
             if self.k_history is not None:
                 # flush
                 if self.k_history.shape[2] > 71:
                     self.k_history = torch.cat([self.k_history[:, :, :4, :], self.k_history[:, :, -1:, :]], 2)
                     self.v_history = torch.cat([self.v_history[:, :, :4, :], self.v_history[:, :, -1:, :]], 2)
                 # fill new k/v
                 self.k_history = torch.cat([self.k_history, k], 2)  # if ctrl^c lands here during a live demo it is non-atomic: k != v,
                 self.v_history = torch.cat([self.v_history, v], 2)  # so it would try to continue with incompatible k/v dims!
             else:
                 # init
                 self.k_history = k
                 self.v_history = v
             # For self-attn prepare
             k = self.k_history
             v = self.v_history
 
             # KV COMPLETION ONLY ON SELF ATTENTION
 
         x = torch.nn.functional.scaled_dot_product_attention(
             q, k, v, is_causal=False, dropout_p=0
         )
-        x = x.to(q.dtype)
         x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
         x = self.out_proj(x)
         return x
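A small sketch of the flush rule in the self-attention branch above; only the `> 71` threshold and the keep-first-4 / keep-last-1 slicing come from the code, the tensor sizes are hypothetical:

```python
import torch

k_history = torch.randn(1, 24, 72, 64)   # (batch, heads, cached steps, head_dim) - illustrative sizes
if k_history.shape[2] > 71:
    # keep the 4 oldest and the single most recent cached steps, drop everything in between
    k_history = torch.cat([k_history[:, :, :4, :], k_history[:, :, -1:, :]], 2)
print(k_history.shape[2])                 # 5; the k of the current step is concatenated after this
```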
@@ -111,14 +95,14 @@ class StreamingMultiheadAttention(nn.Module):
 
 class StreamingTransformerLayer(nn.Module):
 
     def __init__(self,
                  d_model,
                  num_heads,
                  dim_feedforward):
         super().__init__()
         self.self_attn = StreamingMultiheadAttention(embed_dim=d_model,
                                                      num_heads=num_heads)
         self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False)
@@ -126,7 +110,7 @@ class StreamingTransformerLayer(nn.Module):
         self.cross_attention = StreamingMultiheadAttention(embed_dim=d_model,
                                                            num_heads=num_heads,
                                                            cross_attention=True)
         self.norm_cross = nn.LayerNorm(d_model, eps=1e-5)
         self.norm1 = nn.LayerNorm(d_model, eps=1e-5)
         self.norm2 = nn.LayerNorm(d_model, eps=1e-5)
@@ -135,30 +119,30 @@ class StreamingTransformerLayer(nn.Module):
                 src,
                 cross_attention_src=None):  # txtcond
         '''T is saved float16 weights - should we cast src to float16'''
         x = src
         x = x + self.self_attn(self.norm1(x))
         if cross_attention_src is not None:
             x = x + self.cross_attention(
                 query=self.norm_cross(x),
                 key=cross_attention_src,
                 value=cross_attention_src)  # txt condition
         x = x + self.linear2(F.gelu(self.linear1(self.norm2(x))))
         return x
 
 
 class StreamingTransformer(nn.Module):
 
     def __init__(self,
                  d_model=1536,
                  num_heads=24,
                  num_layers=48,
                  dim_feedforward=6144,
                  cross_attention = True,
                  positional_embedding: str = 'sin',
                  max_period: float = 10_000
                  ):
         super().__init__()
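A hypothetical smoke test of `StreamingTransformerLayer` as it stands after this change (the sizes follow the `StreamingTransformer` defaults above; the text-conditioning length of 12 is made up):

```python
import torch
from audiocraft.transformer import StreamingTransformerLayer

layer = StreamingTransformerLayer(d_model=1536, num_heads=24, dim_feedforward=6144)
audio = torch.randn(2, 1, 1536)   # [B, T, C] audio hidden states, one new token per step
txt = torch.randn(2, 12, 1536)    # [B, S, C] text conditioning for the cross-attention branch
out = layer(audio, cross_attention_src=txt)
print(out.shape)                  # torch.Size([2, 1, 1536])
```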
@@ -170,23 +154,22 @@ class StreamingTransformer(nn.Module):
         for idx in range(num_layers):
             self.layers.append(
                 StreamingTransformerLayer(
                     d_model=d_model,
                     num_heads=num_heads,
                     dim_feedforward=dim_feedforward
                 )
             )
 
     def forward(self,
                 x,
                 token_count=None,
                 cross_attention_src=None):
         if self.positional_embedding in ['sin', 'sin_rope']:
-
-
+            pos_emb = create_sin_embedding(torch.tensor([[[.0]], [[.0]]], device=x.device) + token_count, x.shape[2], max_period=self.max_period)
+
             x = x + pos_emb
         for j, lay in enumerate(self.layers):
-            # print(f'Transf Layer c{j} {pos_emb.sum()=} {pos_emb.shape=}{x.sum()=}___________________')
             x = lay(x, cross_attention_src=cross_attention_src)  # cross_attention_src = txt-cond x audio
             # self attn = audio x audio
             # Every layer (mha) keeps its own kv cache
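And a sketch of how the new positional-embedding line in `StreamingTransformer.forward` offsets the sinusoid by `token_count` during generation (values illustrative; the hard-coded `[[[.0]], [[.0]]]` implies a batch of two streams, as in the diff):

```python
import torch
from audiocraft.transformer import create_sin_embedding

x = torch.randn(2, 1, 1536)                 # [B, T, C]: one new token per decoding step
token_count = 7                             # how many tokens have already been generated
pos = torch.tensor([[[.0]], [[.0]]], device=x.device) + token_count   # [2, 1, 1]
pos_emb = create_sin_embedding(pos, x.shape[2], max_period=10_000)
x = x + pos_emb                             # same addition as in forward() above
```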