Commit 4898f81 · 1 Parent(s): be18bf8
Dionyssos committed

fx pos sinus (fix the sinusoidal positional embedding)

Files changed (2)
  1. README.md +4 -0
  2. audiocraft/transformer.py +50 -67
README.md CHANGED
@@ -67,6 +67,10 @@ CUDA_DEVICE_ORDER=PCI_BUS_ID HF_HOME=/data/dkounadis/.hf7/ CUDA_VISIBLE_DEVICES=
 
 Following examples need `api.py` to be running. [Set this IP](https://huggingface.co/dkounadis/artificial-styletts2/blob/main/tts.py#L85) to the IP shown when starting `api.py`.
 
+### Foreign Lang TTS
+
+This will produce the following [video](https://www.youtube.com/watch?v=UeJEAsKxRZU)
+
 ```
 python tts.py --text assets/ocr.txt --image assets/ocr.jpg --soundscape "battle hero" --voice romanian
 ```
audiocraft/transformer.py CHANGED

Note: several hunks below appear to change only whitespace; lines shown as removed and re-added with identical text differ only in indentation or trailing spaces.
@@ -3,44 +3,30 @@ import torch.nn as nn
 from torch.nn import functional as F
 from einops import rearrange
 
-def create_sin_embedding(positions,
-                         dim,
-                         max_period = 10000,
-                         dtype = torch.float32):
-    """Create sinusoidal positional embedding, with shape `[B, T, C]`.
-
-    Args:
-        positions (torch.Tensor): LongTensor of positions.
-        dim (int): Dimension of the embedding.
-        max_period (float): Maximum period of the cosine/sine functions.
-        dtype (torch.dtype or str): dtype to use to generate the embedding.
-    Returns:
-        torch.Tensor: Sinusoidal positional embedding.
-    """
-    # We aim for BTC format
+def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000):
     assert dim % 2 == 0
     half_dim = dim // 2
-    positions = positions.to(dtype)
-    adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
-    max_period_tensor = torch.full([], max_period, device=positions.device, dtype=dtype)  # avoid sync point
+    positions = positions.to(torch.float)
+    adim = torch.arange(half_dim, device=positions.device, dtype=torch.float).view(1, 1, -1)
+    max_period_tensor = torch.full([], max_period, device=positions.device, dtype=torch.float)  # avoid sync point
     phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
-    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
+    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)  # OFFICIAL is torch.float32, HOWEVER self_attn.in_proj_weight is torch.float16
 
 
 class StreamingMultiheadAttention(nn.Module):
 
-    def __init__(self,
-                 embed_dim,
+    def __init__(self,
+                 embed_dim,
                  num_heads,
                  cross_attention = False,
                  ):
-
+
         super().__init__()
-
+
         self.cross_attention = cross_attention
         self.embed_dim = embed_dim
         self.k_history = None  # previous k from the previous tokens seen in the current generation - only for self-attn
-        self.v_history = None  # clean up IN LM after finishing GENERATION - each of the 1...47 MHA layers has a different kv history
+        self.v_history = None  # clean up IN LM after finishing GENERATION - each of the 1...47 MHA layers has a different kv history
         self.num_heads = num_heads
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
         self.register_buffer('in_proj_weight', torch.ones((3 * embed_dim, embed_dim),
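
The rewritten helper drops the docstring and the `dtype` argument and always computes in float32 (as its trailing comment notes). As a quick illustration of its contract — a minimal sketch, assuming the function is importable from this repo's `audiocraft/transformer.py` — positions of shape `[B, T, 1]` map to embeddings of shape `[B, T, dim]`:

```
import torch
from audiocraft.transformer import create_sin_embedding

positions = torch.arange(5).view(1, -1, 1)    # [B=1, T=5, 1] integer positions
emb = create_sin_embedding(positions, dim=8)  # cos|sin phases concatenated on the last axis
print(emb.shape, emb.dtype)                   # torch.Size([1, 5, 8]) torch.float32
# Adding a scalar offset to `positions` shifts every phase, which is what the
# streaming forward() below relies on via `token_count`.
```
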
@@ -52,58 +38,56 @@ class StreamingMultiheadAttention(nn.Module):
                 value=None):
         layout = "b h t d"
         if self.cross_attention:
-
+
             # Different queries, keys, values: we have to split the in_proj_weight manually
-
+
             dim = self.in_proj_weight.shape[0] // 3
-
+
             q = nn.functional.linear(query, self.in_proj_weight[:dim])
             k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim])
             v = nn.functional.linear(value, self.in_proj_weight[2 * dim:])
-
+
             q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
             # print(q.shape, k.shape, v.shape, q.sum(), k.sum(), v.sum(), 'CROSS A5')
         else:
             # a single projection makes q, k, v (instantaneous)
             # this else branch is self-attention of audio with itself (above is cross-attention with txt)
-
-
+
+
             # HISTORY - DIFFERENT FOR EACH TRANSFORMER LAYER
-
+
             projected = nn.functional.linear(query, self.in_proj_weight)
 
             bound_layout = "b h p t d"
             packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
             q, k, v = packed.unbind(dim=2)
 
-
+
             if self.k_history is not None:
-                # flush
+                # flush
                 if self.k_history.shape[2] > 71:
 
                     self.k_history = torch.cat([self.k_history[:, :, :4, :], self.k_history[:, :, -1:, :]], 2)
                     self.v_history = torch.cat([self.v_history[:, :, :4, :], self.v_history[:, :, -1:, :]], 2)
-                # fill new k/v
+                # fill new k/v
                 self.k_history = torch.cat([self.k_history, k], 2)  # IF ctrl^c lands here during a live demo it is non-atomic: k != v,
                 self.v_history = torch.cat([self.v_history, v], 2)  # thus it will try to continue with incompatible k/v dims!
-
-            else:
+
+            else:
                 # init
                 self.k_history = k
-                self.v_history = v
+                self.v_history = v
             # For self-attn, prepare
             k = self.k_history
             v = self.v_history
 
 
-
+
         # KV COMPLETION ONLY ON SELF ATTENTION
 
         x = torch.nn.functional.scaled_dot_product_attention(
             q, k, v, is_causal=False, dropout_p=0
         )
-
-        x = x.to(q.dtype)
         x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
         x = self.out_proj(x)
         return x
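
The self-attention branch above appends each step's k/v to a per-layer history and, once the cached length exceeds 71 steps, prunes it to the first 4 entries plus the most recent one. A standalone sketch of that pruning rule — shapes follow the `b h t d` layout from the diff; the sizes are illustrative, not taken from the checkpoint:

```
import torch

k_history = torch.randn(1, 24, 72, 64)  # b=1 batch, h=24 heads, t=72 cached steps, d=64
if k_history.shape[2] > 71:
    # keep the first 4 cached steps and the most recent one
    k_history = torch.cat([k_history[:, :, :4, :], k_history[:, :, -1:, :]], 2)
print(k_history.shape)                    # torch.Size([1, 24, 5, 64])

k = torch.randn(1, 24, 1, 64)             # k of the current generation step
k_history = torch.cat([k_history, k], 2)  # new cache length: t = 6
```
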
@@ -111,14 +95,14 @@ class StreamingMultiheadAttention(nn.Module):
 
 class StreamingTransformerLayer(nn.Module):
 
-    def __init__(self,
-                 d_model,
-                 num_heads,
+    def __init__(self,
+                 d_model,
+                 num_heads,
                  dim_feedforward):
-
-
+
+
         super().__init__()
-
+
         self.self_attn = StreamingMultiheadAttention(embed_dim=d_model,
                                                      num_heads=num_heads)
         self.linear1 = nn.Linear(d_model, dim_feedforward, bias=False)
@@ -126,7 +110,7 @@ class StreamingTransformerLayer(nn.Module):
         self.cross_attention = StreamingMultiheadAttention(embed_dim=d_model,
                                                            num_heads=num_heads,
                                                            cross_attention=True)
-        self.norm_cross = nn.LayerNorm(d_model, eps=1e-5)
+        self.norm_cross = nn.LayerNorm(d_model, eps=1e-5)
         self.norm1 = nn.LayerNorm(d_model, eps=1e-5)
         self.norm2 = nn.LayerNorm(d_model, eps=1e-5)
 
@@ -135,30 +119,30 @@ class StreamingTransformerLayer(nn.Module):
                 src,
                 cross_attention_src=None):  # txt cond
         '''T is saved float16 weights - should we cast src to float16?'''
-
+
         x = src
-
+
         x = x + self.self_attn(self.norm1(x))
-
+
         if cross_attention_src is not None:
             x = x + self.cross_attention(
-                query = self.norm_cross(x),
-                key = cross_attention_src,
+                query = self.norm_cross(x),
+                key = cross_attention_src,
                 value = cross_attention_src)  # txt condition
-
+
         x = x + self.linear2(F.gelu(self.linear1(self.norm2(x))))
         return x
 
 
 class StreamingTransformer(nn.Module):
 
-    def __init__(self,
-                 d_model=1536,
-                 num_heads=24,
-                 num_layers=48,
+    def __init__(self,
+                 d_model=1536,
+                 num_heads=24,
+                 num_layers=48,
                  dim_feedforward=6144,
                  cross_attention = True,
-                 positional_embedding: str = 'sin',
+                 positional_embedding: str = 'sin',
                  max_period: float = 10_000
                  ):
         super().__init__()
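
The whitespace-only hunks above leave the layer's interface unchanged: pre-norm self-attention, optional cross-attention on the text conditioning, then a GELU MLP, each with a residual connection. A minimal smoke test — hypothetical small sizes; the real model loads checkpoint weights over the ones-initialised `in_proj_weight` buffers:

```
import torch
from audiocraft.transformer import StreamingTransformerLayer

layer = StreamingTransformerLayer(d_model=64, num_heads=4, dim_feedforward=256)
audio = torch.randn(2, 1, 64)  # [B, T, C]: one decode step of audio tokens
txt = torch.randn(2, 10, 64)   # text conditioning for cross-attention
out = layer(audio, cross_attention_src=txt)
print(out.shape)               # torch.Size([2, 1, 64])
# Note: layer.self_attn.k_history now holds this step's k; it must be reset
# to None between generations, as the v_history comment above warns.
```
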
@@ -170,23 +154,22 @@ class StreamingTransformer(nn.Module):
         for idx in range(num_layers):
             self.layers.append(
                 StreamingTransformerLayer(
-                    d_model=d_model,
-                    num_heads=num_heads,
+                    d_model=d_model,
+                    num_heads=num_heads,
                     dim_feedforward=dim_feedforward
                 )
             )
 
-    def forward(self,
-                x,
-                token_count=None,
+    def forward(self,
+                x,
+                token_count=None,
                 cross_attention_src=None):
-        B, T, C = x.shape
+
         if self.positional_embedding in ['sin', 'sin_rope']:
-            positions = torch.arange(T, device=x.device).view(1, -1, 1)
-            pos_emb = create_sin_embedding(positions + token_count, C, max_period=self.max_period, dtype=x.dtype)
+            pos_emb = create_sin_embedding(torch.tensor([[[.0]], [[.0]]], device=x.device) + token_count, x.shape[2], max_period=self.max_period)
+
             x = x + pos_emb
         for j, lay in enumerate(self.layers):
-            # print(f'Transf Layer c{j} {pos_emb.sum()=} {pos_emb.shape=}{x.sum()=}___________________')
             x = lay(x, cross_attention_src=cross_attention_src)  # cross_attention_src = txt-cond x audio
                                                                  # self attn = audio x audio
                                                                  # Every layer (MHA) keeps its own KV cache
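
This last hunk is the "fx pos sinus" change itself: instead of building positions with `arange(T)` over the whole input, `forward` now evaluates the sinusoid at the single offset `token_count`, using a hard-coded batch of two zero positions (shape `[2, 1, 1]`, matching the model's fixed batch of 2 during streaming decode). A sketch of why the two formulations agree for single-step decoding (T == 1), again assuming imports from this repo:

```
import torch
from audiocraft.transformer import create_sin_embedding

token_count = 7  # tokens generated so far in this stream
C = 1536         # d_model

# old: integer positions from arange(T=1), shifted by token_count
old = create_sin_embedding(torch.arange(1).view(1, -1, 1) + token_count, C)
# new: a fixed [2, 1, 1] zero tensor, shifted by the same token_count
new = create_sin_embedding(torch.tensor([[[.0]], [[.0]]]) + token_count, C)

print(torch.allclose(old, new[:1]))  # True: same phase, now pre-expanded to the batch of 2
```
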
 