noblebarkrr committed on
Commit
09d1dfb
·
verified ·
1 Parent(s): 51c079b

Upload 15 files

Browse files
models/bs_roformer/__init__.py CHANGED
@@ -1,11 +1,11 @@
1
- from .bs_roformer import BSRoformer
2
- from .bs_conformer import BSConformer
3
- from .bs_roformer_sw import BSRoformer_SW
4
- from .bs_roformer_fno import BSRoformer_FNO
5
- from .bs_roformer_hyperace import BSRoformerHyperACE
6
- from .bs_roformer_hyperace2 import BSRoformerHyperACE_2
7
- from .bs_roformer_conditional import BSRoformer_Conditional
8
- from .bs_roformer_unwa_inst_large_2 import BSRoformer_2
9
- from .bs_siamese_roformer import BSSiameseRoformer
10
- from .mel_band_roformer import MelBandRoformer
11
- from .mel_band_conformer import MelBandConformer
 
1
+ from .bs_roformer import BSRoformer
2
+ from .bs_conformer import BSConformer
3
+ from .bs_roformer_sw import BSRoformer_SW
4
+ from .bs_roformer_fno import BSRoformer_FNO
5
+ from .bs_roformer_hyperace import BSRoformerHyperACE
6
+ from .bs_roformer_hyperace2 import BSRoformerHyperACE_2
7
+ from .bs_roformer_conditional import BSRoformer_Conditional
8
+ from .bs_roformer_unwa_inst_large_2 import BSRoformer_2
9
+ from .bs_siamese_roformer import BSSiameseRoformer
10
+ from .mel_band_roformer import MelBandRoformer
11
+ from .mel_band_conformer import MelBandConformer
models/bs_roformer/attend.py CHANGED
@@ -1,128 +1,151 @@
1
- from functools import wraps
2
- from packaging import version
3
- from collections import namedtuple
4
-
5
- import os
6
- import torch
7
- from torch import nn, einsum
8
- import torch.nn.functional as F
9
-
10
- from einops import rearrange, reduce
11
-
12
-
13
- FlashAttentionConfig = namedtuple(
14
- "FlashAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"]
15
- )
16
-
17
-
18
- def exists(val):
19
- return val is not None
20
-
21
-
22
- def default(v, d):
23
- return v if exists(v) else d
24
-
25
-
26
- def once(fn):
27
- called = False
28
-
29
- @wraps(fn)
30
- def inner(x):
31
- nonlocal called
32
- if called:
33
- return
34
- called = True
35
- return fn(x)
36
-
37
- return inner
38
-
39
-
40
- print_once = once(print)
41
-
42
-
43
- class Attend(nn.Module):
44
- def __init__(self, dropout=0.0, flash=False, scale=None):
45
- super().__init__()
46
- self.scale = scale
47
- self.dropout = dropout
48
- self.attn_dropout = nn.Dropout(dropout)
49
-
50
- self.flash = flash
51
- self.use_torch_2_sdpa = False
52
- self._config_checked = False
53
-
54
- # Проверяем версию PyTorch при первом вызове
55
- if flash and not self._config_checked:
56
- if version.parse(torch.__version__) >= version.parse("2.0.0"):
57
- print_once("PyTorch >= 2.0 detected, will use SDPA if available.")
58
- self.use_torch_2_sdpa = True
59
-
60
- # Настройки для PyTorch >= 2.0
61
- self.cpu_config = FlashAttentionConfig(True, True, True)
62
- self.cuda_config = None
63
-
64
- if torch.cuda.is_available():
65
- device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
66
- device_version = version.parse(
67
- f"{device_properties.major}.{device_properties.minor}"
68
- )
69
-
70
- if device_version >= version.parse("8.0"):
71
- if os.name == "nt":
72
- print_once(
73
- "Windows OS detected, using math or mem efficient attention if input tensor is on cuda"
74
- )
75
- self.cuda_config = FlashAttentionConfig(False, True, True)
76
- else:
77
- print_once(
78
- "GPU Compute Capability equal or above 8.0, using flash attention if input tensor is on cuda"
79
- )
80
- self.cuda_config = FlashAttentionConfig(True, False, False)
81
- else:
82
- print_once(
83
- "GPU Compute Capability below 8.0, using math or mem efficient attention if input tensor is on cuda"
84
- )
85
- self.cuda_config = FlashAttentionConfig(False, True, True)
86
- else:
87
- print_once("PyTorch < 2.0 detected, flash attention will use einsum fallback.")
88
- self.use_torch_2_sdpa = False
89
-
90
- self._config_checked = True
91
-
92
- def flash_attn_torch2(self, q, k, v):
93
- """SDPA для PyTorch >= 2.0"""
94
- if exists(self.scale):
95
- default_scale = q.shape[-1] ** -0.5
96
- q = q * (self.scale / default_scale)
97
-
98
- is_cuda = q.is_cuda
99
- config = self.cuda_config if is_cuda else self.cpu_config
100
-
101
- with torch.backends.cuda.sdp_kernel(**config._asdict()):
102
- out = F.scaled_dot_product_attention(
103
- q, k, v, dropout_p=self.dropout if self.training else 0.0
104
- )
105
-
106
- return out
107
-
108
- def forward(self, q, k, v):
109
- q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
110
-
111
- scale = default(self.scale, q.shape[-1] ** -0.5)
112
-
113
- if self.flash and self.use_torch_2_sdpa:
114
- try:
115
- return self.flash_attn_torch2(q, k, v)
116
- except Exception as e:
117
- print(f"Flash attention failed: {e}. Falling back to einsum.")
118
- self.use_torch_2_sdpa = False
119
-
120
- # Fallback для PyTorch < 2.0 или если flash отключен
121
- sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale
122
-
123
- attn = sim.softmax(dim=-1)
124
- attn = self.attn_dropout(attn)
125
-
126
- out = einsum(f"b h i j, b h j d -> b h i d", attn, v)
127
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  return out
 
1
+ from functools import wraps
2
+ from packaging import version
3
+ from collections import namedtuple
4
+
5
+ import os
6
+ import torch
7
+ from torch import nn, einsum
8
+ import torch.nn.functional as F
9
+
10
+ from einops import rearrange, reduce
11
+
12
+
13
+ FlashAttentionConfig = namedtuple(
14
+ "FlashAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"]
15
+ )
16
+
17
+
18
+ def exists(val):
19
+ return val is not None
20
+
21
+
22
+ def default(v, d):
23
+ return v if exists(v) else d
24
+
25
+
26
+ def once(fn):
27
+ called = False
28
+
29
+ @wraps(fn)
30
+ def inner(x):
31
+ nonlocal called
32
+ if called:
33
+ return
34
+ called = True
35
+ return fn(x)
36
+
37
+ return inner
38
+
39
+
40
+ print_once = once(print)
41
+
42
+
43
+ class Attend(nn.Module):
44
+ def __init__(self, dropout=0.0, flash=False, scale=None):
45
+ super().__init__()
46
+ self.scale = scale
47
+ self.dropout = dropout
48
+ self.attn_dropout = nn.Dropout(dropout)
49
+
50
+ self.flash = flash
51
+ self.use_torch_2_sdpa = False
52
+ self._config_checked = False
53
+
54
+ # Проверяем версию PyTorch при первом вызове
55
+ if flash and not self._config_checked:
56
+ if version.parse(torch.__version__) >= version.parse("2.0.0"):
57
+ print_once("PyTorch >= 2.0 detected, will use SDPA if available.")
58
+ self.use_torch_2_sdpa = True
59
+
60
+ # Настройки для PyTorch >= 2.0
61
+ self.cpu_config = FlashAttentionConfig(True, True, True)
62
+ self.cuda_config = None
63
+
64
+ if torch.cuda.is_available():
65
+ device_properties = torch.cuda.get_device_properties(torch.device("cuda"))
66
+ device_version = version.parse(
67
+ f"{device_properties.major}.{device_properties.minor}"
68
+ )
69
+
70
+ if device_version >= version.parse("8.0"):
71
+ if os.name == "nt":
72
+ print_once(
73
+ "Windows OS detected, using math or mem efficient attention if input tensor is on cuda"
74
+ )
75
+ self.cuda_config = FlashAttentionConfig(False, True, True)
76
+ else:
77
+ print_once(
78
+ "GPU Compute Capability equal or above 8.0, using flash attention if input tensor is on cuda"
79
+ )
80
+ self.cuda_config = FlashAttentionConfig(True, False, False)
81
+ else:
82
+ print_once(
83
+ "GPU Compute Capability below 8.0, using math or mem efficient attention if input tensor is on cuda"
84
+ )
85
+ self.cuda_config = FlashAttentionConfig(False, True, True)
86
+ else:
87
+ print_once("PyTorch < 2.0 detected, flash attention will use einsum fallback.")
88
+ self.use_torch_2_sdpa = False
89
+
90
+ self._config_checked = True
91
+
92
+ def flash_attn_torch2(self, q, k, v):
93
+ """SDPA для PyTorch >= 2.0"""
94
+ if exists(self.scale):
95
+ default_scale = q.shape[-1] ** -0.5
96
+ q = q * (self.scale / default_scale)
97
+
98
+ is_cuda = q.is_cuda
99
+ config = self.cuda_config if is_cuda else self.cpu_config
100
+
101
+ old_sdp_kernel = False
102
+ if hasattr(torch, "backends"):
103
+ if hasattr(torch.backends, "cuda"):
104
+ if hasattr(torch.backends.cuda, "sdp_kernel"):
105
+ old_sdp_kernel = True
106
+ new_sdp_kernel = False
107
+ if hasattr(torch, "nn"):
108
+ if hasattr(torch.nn, "attention"):
109
+ if hasattr(torch.nn.attention, "sdpa_kernel") and hasattr(torch.nn.attention, "SDPBackend"):
110
+ new_sdp_kernel = True
111
+
112
+ if old_sdp_kernel and not new_sdp_kernel:
113
+ with torch.backends.cuda.sdp_kernel(**config._asdict()):
114
+ out = F.scaled_dot_product_attention(
115
+ q, k, v, dropout_p=self.dropout if self.training else 0.0
116
+ )
117
+ else:
118
+ backends = []
119
+ if config.enable_flash:
120
+ backends.append(torch.nn.attention.SDPBackend.FLASH_ATTENTION)
121
+ if config.enable_mem_efficient:
122
+ backends.append(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION)
123
+ if config.enable_math:
124
+ backends.append(torch.nn.attention.SDPBackend.MATH)
125
+ with torch.nn.attention.sdpa_kernel(backends):
126
+ out = F.scaled_dot_product_attention(
127
+ q, k, v, dropout_p=self.dropout if self.training else 0.0
128
+ )
129
+ return out
130
+
131
+ def forward(self, q, k, v):
132
+ q_len, k_len, device = q.shape[-2], k.shape[-2], q.device
133
+
134
+ scale = default(self.scale, q.shape[-1] ** -0.5)
135
+
136
+ if self.flash and self.use_torch_2_sdpa:
137
+ try:
138
+ return self.flash_attn_torch2(q, k, v)
139
+ except Exception as e:
140
+ print(f"Flash attention failed: {e}. Falling back to einsum.")
141
+ self.use_torch_2_sdpa = False
142
+
143
+ # Fallback для PyTorch < 2.0 или если flash отключен
144
+ sim = einsum(f"b h i d, b h j d -> b h i j", q, k) * scale
145
+
146
+ attn = sim.softmax(dim=-1)
147
+ attn = self.attn_dropout(attn)
148
+
149
+ out = einsum(f"b h i j, b h j d -> b h i d", attn, v)
150
+
151
  return out
models/bs_roformer/bs_conformer.py CHANGED
@@ -6,10 +6,6 @@ from torch.nn import Module, ModuleList
6
  import torch.nn.functional as F
7
 
8
  from .attend import Attend
9
- try:
10
- from .attend_sage import Attend as AttendSage
11
- except:
12
- pass
13
  from torch.utils.checkpoint import checkpoint
14
 
15
  from beartype.typing import Tuple, Optional, List, Callable
@@ -95,7 +91,6 @@ class Attention(Module):
95
  dropout=0.,
96
  rotary_embed=None,
97
  flash=True,
98
- sage_attention=False,
99
  ):
100
  super().__init__()
101
  self.heads = heads
@@ -104,10 +99,7 @@ class Attention(Module):
104
 
105
  self.rotary_embed = rotary_embed
106
 
107
- if sage_attention:
108
- self.attend = AttendSage(flash=flash, dropout=dropout)
109
- else:
110
- self.attend = Attend(flash=flash, dropout=dropout)
111
 
112
  self.norm = RMSNorm(dim)
113
  self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
@@ -151,7 +143,7 @@ class LinearAttention(Module):
151
  scale=8,
152
  flash=True,
153
  dropout=0.,
154
- sage_attention=False
155
  ):
156
  super().__init__()
157
  dim_inner = dim_head * heads
@@ -164,10 +156,7 @@ class LinearAttention(Module):
164
 
165
  self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
166
 
167
- if sage_attention:
168
- self.attend = AttendSage(scale=scale, dropout=dropout, flash=flash)
169
- else:
170
- self.attend = Attend(scale=scale, dropout=dropout, flash=flash)
171
 
172
  self.to_out = nn.Sequential(
173
  Rearrange('b h d n -> b n (h d)'),
@@ -200,7 +189,6 @@ class Transformer(Module):
200
  rotary_embed=None,
201
  flash_attn=True,
202
  linear_attn=False,
203
- sage_attention=False,
204
  ):
205
  super().__init__()
206
  self.layers = ModuleList([])
@@ -213,7 +201,6 @@ class Transformer(Module):
213
  heads=heads,
214
  dropout=attn_dropout,
215
  flash=flash_attn,
216
- sage_attention=sage_attention
217
  )
218
  else:
219
  attn = Attention(
@@ -223,7 +210,6 @@ class Transformer(Module):
223
  dropout=attn_dropout,
224
  rotary_embed=rotary_embed,
225
  flash=flash_attn,
226
- sage_attention=sage_attention
227
  )
228
 
229
  self.layers.append(ModuleList([
@@ -288,7 +274,6 @@ class ConformerBlock(nn.Module):
288
  conv_kernel_size=31,
289
  rotary_embed=None,
290
  flash_attn=True,
291
- sage_attention=False
292
  ):
293
  super().__init__()
294
  self.ff1 = MacaronFF(dim=dim, mult=ff_mult, dropout=ff_dropout)
@@ -299,7 +284,6 @@ class ConformerBlock(nn.Module):
299
  dropout=attn_dropout,
300
  rotary_embed=rotary_embed,
301
  flash=flash_attn,
302
- sage_attention=sage_attention
303
  )
304
  self.conv = ConformerConvModule(
305
  dim=dim,
@@ -331,7 +315,6 @@ class Conformer(Module):
331
  ff_mult=4,
332
  rotary_embed=None,
333
  flash_attn=True,
334
- sage_attention=False,
335
  conv_expansion_factor=2,
336
  conv_kernel_size=31,
337
  norm_output=True
@@ -349,7 +332,6 @@ class Conformer(Module):
349
  conv_kernel_size=conv_kernel_size,
350
  rotary_embed=rotary_embed,
351
  flash_attn=flash_attn,
352
- sage_attention=sage_attention
353
  ) for _ in range(depth)
354
  ])
355
  self.norm = RMSNorm(dim) if norm_output else nn.Identity()
@@ -473,11 +455,11 @@ class BSConformer(Module):
473
  mlp_expansion_factor = 4,
474
  use_torch_checkpoint = False,
475
  skip_connection = False,
476
- sage_attention = False,
477
  # conformer-specific
478
  ff_mult = 4,
479
  conv_expansion_factor = 2,
480
- conv_kernel_size = 31
 
481
  ):
482
  super().__init__()
483
  self.stereo = stereo
@@ -488,9 +470,6 @@ class BSConformer(Module):
488
 
489
  self.layers = ModuleList([])
490
 
491
- if sage_attention:
492
- print("Use Sage Attention")
493
-
494
  transformer_kwargs = dict(
495
  dim = dim,
496
  heads = heads,
@@ -498,7 +477,6 @@ class BSConformer(Module):
498
  attn_dropout = attn_dropout,
499
  ff_dropout = ff_dropout,
500
  flash_attn = flash_attn,
501
- sage_attention = sage_attention,
502
  norm_output = False
503
  )
504
 
 
6
  import torch.nn.functional as F
7
 
8
  from .attend import Attend
 
 
 
 
9
  from torch.utils.checkpoint import checkpoint
10
 
11
  from beartype.typing import Tuple, Optional, List, Callable
 
91
  dropout=0.,
92
  rotary_embed=None,
93
  flash=True,
 
94
  ):
95
  super().__init__()
96
  self.heads = heads
 
99
 
100
  self.rotary_embed = rotary_embed
101
 
102
+ self.attend = Attend(flash=flash, dropout=dropout)
 
 
 
103
 
104
  self.norm = RMSNorm(dim)
105
  self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
 
143
  scale=8,
144
  flash=True,
145
  dropout=0.,
146
+ **kwargs
147
  ):
148
  super().__init__()
149
  dim_inner = dim_head * heads
 
156
 
157
  self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
158
 
159
+ self.attend = Attend(scale=scale, dropout=dropout, flash=flash)
 
 
 
160
 
161
  self.to_out = nn.Sequential(
162
  Rearrange('b h d n -> b n (h d)'),
 
189
  rotary_embed=None,
190
  flash_attn=True,
191
  linear_attn=False,
 
192
  ):
193
  super().__init__()
194
  self.layers = ModuleList([])
 
201
  heads=heads,
202
  dropout=attn_dropout,
203
  flash=flash_attn,
 
204
  )
205
  else:
206
  attn = Attention(
 
210
  dropout=attn_dropout,
211
  rotary_embed=rotary_embed,
212
  flash=flash_attn,
 
213
  )
214
 
215
  self.layers.append(ModuleList([
 
274
  conv_kernel_size=31,
275
  rotary_embed=None,
276
  flash_attn=True,
 
277
  ):
278
  super().__init__()
279
  self.ff1 = MacaronFF(dim=dim, mult=ff_mult, dropout=ff_dropout)
 
284
  dropout=attn_dropout,
285
  rotary_embed=rotary_embed,
286
  flash=flash_attn,
 
287
  )
288
  self.conv = ConformerConvModule(
289
  dim=dim,
 
315
  ff_mult=4,
316
  rotary_embed=None,
317
  flash_attn=True,
 
318
  conv_expansion_factor=2,
319
  conv_kernel_size=31,
320
  norm_output=True
 
332
  conv_kernel_size=conv_kernel_size,
333
  rotary_embed=rotary_embed,
334
  flash_attn=flash_attn,
 
335
  ) for _ in range(depth)
336
  ])
337
  self.norm = RMSNorm(dim) if norm_output else nn.Identity()
 
455
  mlp_expansion_factor = 4,
456
  use_torch_checkpoint = False,
457
  skip_connection = False,
 
458
  # conformer-specific
459
  ff_mult = 4,
460
  conv_expansion_factor = 2,
461
+ conv_kernel_size = 31,
462
+ **kwargs
463
  ):
464
  super().__init__()
465
  self.stereo = stereo
 
470
 
471
  self.layers = ModuleList([])
472
 
 
 
 
473
  transformer_kwargs = dict(
474
  dim = dim,
475
  heads = heads,
 
477
  attn_dropout = attn_dropout,
478
  ff_dropout = ff_dropout,
479
  flash_attn = flash_attn,
 
480
  norm_output = False
481
  )
482
 
models/bs_roformer/bs_roformer.py CHANGED
@@ -1,767 +1,768 @@
1
- from functools import partial
2
-
3
- import torch
4
- from torch import nn, einsum, tensor, Tensor
5
- from torch.nn import Module, ModuleList
6
- import torch.nn.functional as F
7
-
8
- from .attend import Attend
9
-
10
- from torch.utils.checkpoint import checkpoint
11
-
12
- from beartype.typing import Tuple, Optional, List, Callable
13
- from beartype import beartype
14
-
15
- from rotary_embedding_torch import RotaryEmbedding
16
-
17
- from einops import rearrange, pack, unpack
18
- from einops.layers.torch import Rearrange
19
-
20
- try:
21
- from .pope.attention import flash_attn_with_pope
22
- from .pope.pope import PoPE
23
- _HAS_POPE = True
24
- except Exception:
25
- PoPE = None
26
- flash_attn_with_pope = None
27
- _HAS_POPE = False
28
-
29
- # helper functions
30
-
31
- def exists(val):
32
- return val is not None
33
-
34
-
35
- def default(v, d):
36
- return v if exists(v) else d
37
-
38
-
39
- def pack_one(t, pattern):
40
- return pack([t], pattern)
41
-
42
-
43
- def unpack_one(t, ps, pattern):
44
- return unpack(t, ps, pattern)[0]
45
-
46
-
47
- # norm
48
-
49
- def l2norm(t):
50
- return F.normalize(t, dim = -1, p = 2)
51
-
52
-
53
- class RMSNorm(Module):
54
- def __init__(self, dim):
55
- super().__init__()
56
- self.scale = dim ** 0.5
57
- self.gamma = nn.Parameter(torch.ones(dim))
58
-
59
- def forward(self, x):
60
- return F.normalize(x, dim=-1) * self.scale * self.gamma
61
-
62
-
63
- # attention
64
-
65
- class FeedForward(Module):
66
- def __init__(
67
- self,
68
- dim,
69
- mult=4,
70
- dropout=0.
71
- ):
72
- super().__init__()
73
- dim_inner = int(dim * mult)
74
- self.net = nn.Sequential(
75
- RMSNorm(dim),
76
- nn.Linear(dim, dim_inner),
77
- nn.GELU(),
78
- nn.Dropout(dropout),
79
- nn.Linear(dim_inner, dim),
80
- nn.Dropout(dropout)
81
- )
82
-
83
- def forward(self, x):
84
- return self.net(x)
85
-
86
-
87
- class Attention(Module):
88
- def __init__(
89
- self,
90
- dim,
91
- heads=8,
92
- dim_head=64,
93
- dropout=0.,
94
- rotary_embed=None,
95
- flash=True,
96
- pope_embed=None,
97
- learned_value_residual_mix = False
98
- ):
99
- super().__init__()
100
- self.heads = heads
101
- self.scale = dim_head ** -0.5
102
- dim_inner = heads * dim_head
103
-
104
- self.rotary_embed = rotary_embed
105
- self.pope_embed = pope_embed
106
- assert not (self.rotary_embed is not None and self.pope_embed is not None), \
107
- "cannot have both rotary and pope embeddings"
108
-
109
- self.attend = Attend(flash=flash, dropout=dropout)
110
-
111
- self.norm = RMSNorm(dim)
112
- self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
113
-
114
- self.to_value_residual_mix = nn.Linear(dim, heads) if learned_value_residual_mix else None
115
-
116
- self.to_gates = nn.Linear(dim, heads)
117
-
118
- self.to_out = nn.Sequential(
119
- nn.Linear(dim_inner, dim, bias=False),
120
- nn.Dropout(dropout)
121
- )
122
-
123
- def forward(self, x, value_residual = None):
124
- x = self.norm(x)
125
-
126
- q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)
127
-
128
- orig_v = v
129
-
130
- if exists(self.pope_embed):
131
- out = flash_attn_with_pope(
132
- q, k, v,
133
- pos_emb=self.pope_embed(q.shape[-2]),
134
- softmax_scale=self.scale
135
- )
136
- elif exists(self.rotary_embed):
137
- q = self.rotary_embed.rotate_queries_or_keys(q)
138
- k = self.rotary_embed.rotate_queries_or_keys(k)
139
- out = self.attend(q, k, v)
140
- elif exists(self.to_value_residual_mix):
141
- mix = self.to_value_residual_mix(x)
142
- mix = rearrange(mix, 'b n h -> b h n 1').sigmoid()
143
-
144
- assert exists(value_residual)
145
- v = v.lerp(value_residual, mix)
146
- out = self.attend(q, k, v)
147
- else:
148
- out = self.attend(q, k, v)
149
-
150
- gates = self.to_gates(x)
151
- out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
152
-
153
- out = rearrange(out, 'b h n d -> b n (h d)')
154
- return self.to_out(out), orig_v
155
-
156
-
157
- class LinearAttention(Module):
158
- """
159
- this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
160
- """
161
-
162
- @beartype
163
- def __init__(
164
- self,
165
- *,
166
- dim,
167
- dim_head=32,
168
- heads=8,
169
- scale=8,
170
- flash=False,
171
- dropout=0.
172
- ):
173
- super().__init__()
174
- dim_inner = dim_head * heads
175
- self.norm = RMSNorm(dim)
176
-
177
- self.to_qkv = nn.Sequential(
178
- nn.Linear(dim, dim_inner * 3, bias=False),
179
- Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads)
180
- )
181
-
182
- self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
183
-
184
- self.attend = Attend(
185
- scale=scale,
186
- dropout=dropout,
187
- flash=flash
188
- )
189
-
190
- self.to_out = nn.Sequential(
191
- Rearrange('b h d n -> b n (h d)'),
192
- nn.Linear(dim_inner, dim, bias=False)
193
- )
194
-
195
- def forward(
196
- self,
197
- x
198
- ):
199
- x = self.norm(x)
200
-
201
- q, k, v = self.to_qkv(x)
202
-
203
- q, k = map(l2norm, (q, k))
204
- q = q * self.temperature.exp()
205
-
206
- out = self.attend(q, k, v)
207
-
208
- return self.to_out(out)
209
-
210
-
211
- class Transformer(Module):
212
- def __init__(
213
- self,
214
- *,
215
- dim,
216
- depth,
217
- dim_head=64,
218
- heads=8,
219
- attn_dropout=0.,
220
- ff_dropout=0.,
221
- ff_mult=4,
222
- norm_output=True,
223
- rotary_embed=None,
224
- pope_embed=None,
225
- flash_attn=True,
226
- linear_attn=False,
227
- add_value_residual = False
228
- ):
229
- super().__init__()
230
- self.layers = ModuleList([])
231
-
232
- for _ in range(depth):
233
- if linear_attn:
234
- attn = LinearAttention(
235
- dim=dim,
236
- dim_head=dim_head,
237
- heads=heads,
238
- dropout=attn_dropout,
239
- flash=flash_attn
240
- )
241
- else:
242
- attn = Attention(
243
- dim=dim,
244
- dim_head=dim_head,
245
- heads=heads,
246
- dropout=attn_dropout,
247
- rotary_embed=rotary_embed,
248
- pope_embed=pope_embed,
249
- flash=flash_attn,
250
- learned_value_residual_mix=add_value_residual
251
- )
252
-
253
- self.layers.append(ModuleList([
254
- attn,
255
- FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
256
- ]))
257
-
258
- self.norm = RMSNorm(dim) if norm_output else nn.Identity()
259
-
260
- def forward(self, x, value_residual=None):
261
- first_values = None
262
-
263
- for attn, ff in self.layers:
264
- attn_out, next_values = attn(x, value_residual=value_residual)
265
-
266
- if first_values is None:
267
- first_values = next_values
268
-
269
- x = attn_out + x
270
- x = ff(x) + x
271
-
272
- return self.norm(x), first_values
273
-
274
-
275
- # bandsplit module
276
-
277
- class BandSplit(Module):
278
- @beartype
279
- def __init__(
280
- self,
281
- dim,
282
- dim_inputs: Tuple[int, ...]
283
- ):
284
- super().__init__()
285
- self.dim_inputs = dim_inputs
286
- self.to_features = ModuleList([])
287
-
288
- for dim_in in dim_inputs:
289
- net = nn.Sequential(
290
- RMSNorm(dim_in),
291
- nn.Linear(dim_in, dim)
292
- )
293
-
294
- self.to_features.append(net)
295
-
296
- def forward(self, x):
297
- x = x.split(self.dim_inputs, dim=-1)
298
-
299
- outs = []
300
- for split_input, to_feature in zip(x, self.to_features):
301
- split_output = to_feature(split_input)
302
- outs.append(split_output)
303
-
304
- return torch.stack(outs, dim=-2)
305
-
306
-
307
- def MLP(
308
- dim_in,
309
- dim_out,
310
- dim_hidden=None,
311
- depth=1,
312
- activation=nn.Tanh
313
- ):
314
- dim_hidden = default(dim_hidden, dim_in)
315
-
316
- net = []
317
- dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)
318
-
319
- for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
320
- is_last = ind == (len(dims) - 2)
321
-
322
- net.append(nn.Linear(layer_dim_in, layer_dim_out))
323
-
324
- if is_last:
325
- continue
326
-
327
- net.append(activation())
328
-
329
- return nn.Sequential(*net)
330
-
331
-
332
- class MaskEstimator(Module):
333
- @beartype
334
- def __init__(
335
- self,
336
- dim,
337
- dim_inputs: Tuple[int, ...],
338
- depth,
339
- mlp_expansion_factor=4
340
- ):
341
- super().__init__()
342
- self.dim_inputs = dim_inputs
343
- self.to_freqs = ModuleList([])
344
- dim_hidden = dim * mlp_expansion_factor
345
-
346
- for dim_in in dim_inputs:
347
- net = []
348
-
349
- mlp = nn.Sequential(
350
- MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
351
- nn.GLU(dim=-1)
352
- )
353
-
354
- self.to_freqs.append(mlp)
355
-
356
- def forward(self, x):
357
- x = x.unbind(dim=-2)
358
-
359
- outs = []
360
-
361
- for band_features, mlp in zip(x, self.to_freqs):
362
- freq_out = mlp(band_features)
363
- outs.append(freq_out)
364
-
365
- return torch.cat(outs, dim=-1)
366
-
367
-
368
- # main class
369
-
370
- DEFAULT_FREQS_PER_BANDS = (
371
- 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
372
- 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
373
- 2, 2, 2, 2,
374
- 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
375
- 12, 12, 12, 12, 12, 12, 12, 12,
376
- 24, 24, 24, 24, 24, 24, 24, 24,
377
- 48, 48, 48, 48, 48, 48, 48, 48,
378
- 128, 129,
379
- )
380
-
381
-
382
- class BSRoformer(Module):
383
-
384
- @beartype
385
- def __init__(
386
- self,
387
- dim,
388
- *,
389
- depth,
390
- stereo=False,
391
- num_stems=1,
392
- time_transformer_depth=2,
393
- freq_transformer_depth=2,
394
- linear_transformer_depth=0,
395
- freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
396
- # in the paper, they divide into ~60 bands, test with 1 for starters
397
- dim_head=64,
398
- heads=8,
399
- attn_dropout=0.,
400
- ff_dropout=0.,
401
- flash_attn=True,
402
- dim_freqs_in=1025,
403
- stft_n_fft=2048,
404
- stft_hop_length=512,
405
- # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
406
- stft_win_length=2048,
407
- stft_normalized=False,
408
- stft_window_fn: Optional[Callable] = None,
409
- zero_dc = True,
410
- mask_estimator_depth=2,
411
- multi_stft_resolution_loss_weight=1.,
412
- multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
413
- multi_stft_hop_size=147,
414
- multi_stft_normalized=False,
415
- multi_stft_window_fn: Callable = torch.hann_window,
416
- mlp_expansion_factor=4,
417
- use_torch_checkpoint=False,
418
- skip_connection=False,
419
- use_pope: bool = False,
420
- residual_value: bool = False
421
- ):
422
- super().__init__()
423
-
424
- self.stereo = stereo
425
- self.audio_channels = 2 if stereo else 1
426
- self.num_stems = num_stems
427
- self.use_torch_checkpoint = use_torch_checkpoint
428
- self.skip_connection = skip_connection
429
-
430
- self.layers = ModuleList([])
431
-
432
- transformer_kwargs = dict(
433
- dim=dim,
434
- heads=heads,
435
- dim_head=dim_head,
436
- attn_dropout=attn_dropout,
437
- ff_dropout=ff_dropout,
438
- flash_attn=flash_attn,
439
- norm_output=False,
440
- )
441
-
442
- if use_pope:
443
- time_pope_embed = PoPE(dim=dim_head, heads=heads)
444
- freq_pope_embed = PoPE(dim=dim_head, heads=heads)
445
- time_rotary_embed = None
446
- freq_rotary_embed = None
447
- else:
448
- time_rotary_embed = RotaryEmbedding(dim = dim_head)
449
- freq_rotary_embed = RotaryEmbedding(dim = dim_head)
450
- time_pope_embed = freq_pope_embed = None
451
-
452
- if residual_value:
453
- for layer_index in range(depth):
454
- tran_modules = []
455
- is_first = layer_index == 0
456
- if linear_transformer_depth > 0:
457
- tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, add_value_residual=not is_first, **transformer_kwargs))
458
- tran_modules.append(
459
- Transformer(
460
- depth=time_transformer_depth,
461
- rotary_embed=time_rotary_embed,
462
- pope_embed=time_pope_embed,
463
- add_value_residual=not is_first,
464
- **transformer_kwargs
465
- )
466
- )
467
- tran_modules.append(
468
- Transformer(
469
- depth=freq_transformer_depth,
470
- rotary_embed=freq_rotary_embed,
471
- pope_embed=freq_pope_embed,
472
- add_value_residual=not is_first,
473
- **transformer_kwargs
474
- )
475
- )
476
- self.layers.append(nn.ModuleList(tran_modules))
477
- else:
478
- for layer_index in range(depth):
479
- tran_modules = []
480
- if linear_transformer_depth > 0:
481
- tran_modules.append(Transformer(depth=linear_transformer_depth, linear_attn=True, add_value_residual=not is_first, **transformer_kwargs))
482
- tran_modules.append(
483
- Transformer(
484
- depth=time_transformer_depth,
485
- rotary_embed=time_rotary_embed,
486
- pope_embed=time_pope_embed,
487
- add_value_residual=False,
488
- **transformer_kwargs
489
- )
490
- )
491
- tran_modules.append(
492
- Transformer(
493
- depth=freq_transformer_depth,
494
- rotary_embed=freq_rotary_embed,
495
- pope_embed=freq_pope_embed,
496
- add_value_residual=False,
497
- **transformer_kwargs
498
- )
499
- )
500
- self.layers.append(nn.ModuleList(tran_modules))
501
-
502
- self.final_norm = RMSNorm(dim)
503
-
504
- self.stft_kwargs = dict(
505
- n_fft=stft_n_fft,
506
- hop_length=stft_hop_length,
507
- win_length=stft_win_length,
508
- normalized=stft_normalized
509
- )
510
-
511
- self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)
512
-
513
- freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_win_length), return_complex=True).shape[1]
514
-
515
- assert len(freqs_per_bands) > 1
516
- assert sum(
517
- freqs_per_bands) == freqs, f'the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}'
518
-
519
- freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in freqs_per_bands)
520
-
521
- self.band_split = BandSplit(
522
- dim=dim,
523
- dim_inputs=freqs_per_bands_with_complex
524
- )
525
-
526
- self.mask_estimators = nn.ModuleList([])
527
-
528
- for _ in range(num_stems):
529
- mask_estimator = MaskEstimator(
530
- dim=dim,
531
- dim_inputs=freqs_per_bands_with_complex,
532
- depth=mask_estimator_depth,
533
- mlp_expansion_factor=mlp_expansion_factor,
534
- )
535
-
536
- self.mask_estimators.append(mask_estimator)
537
-
538
- # whether to zero out dc
539
-
540
- self.zero_dc = zero_dc
541
-
542
- # for the multi-resolution stft loss
543
-
544
- self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
545
- self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
546
- self.multi_stft_n_fft = stft_n_fft
547
- self.multi_stft_window_fn = multi_stft_window_fn
548
-
549
- self.multi_stft_kwargs = dict(
550
- hop_length=multi_stft_hop_size,
551
- normalized=multi_stft_normalized
552
- )
553
-
554
- def forward(
555
- self,
556
- raw_audio,
557
- target=None,
558
- active_stem_ids=None,
559
- return_loss_breakdown=False
560
- ):
561
- """
562
- einops
563
-
564
- b - batch
565
- f - freq
566
- t - time
567
- s - audio channel (1 for mono, 2 for stereo)
568
- n - number of 'stems'
569
- c - complex (2)
570
- d - feature dimension
571
- """
572
-
573
- device = raw_audio.device
574
- x_is_mps = True if device.type == "mps" else False
575
-
576
- if raw_audio.ndim == 2:
577
- raw_audio = rearrange(raw_audio, 'b t -> b 1 t')
578
-
579
- channels = raw_audio.shape[1]
580
- assert (not self.stereo and channels == 1) or (
581
- self.stereo and channels == 2),\
582
- ('stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2).'
583
- ' also need to be False if mono (channel dimension of 1)')
584
-
585
- # to stft
586
-
587
- raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')
588
-
589
- stft_window = self.stft_window_fn(device=device)
590
-
591
- try:
592
- stft_repr = torch.stft(
593
- raw_audio,
594
- **self.stft_kwargs,
595
- window=stft_window,
596
- return_complex=True
597
- )
598
- except:
599
- stft_repr = torch.stft(
600
- raw_audio.cpu() if x_is_mps else raw_audio,
601
- **self.stft_kwargs,
602
- window=stft_window.cpu() if x_is_mps else stft_window,
603
- return_complex=True
604
- ).to(device)
605
- stft_repr = torch.view_as_real(stft_repr)
606
-
607
- stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
608
-
609
- # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
610
- stft_repr = rearrange(stft_repr,'b s f t c -> b (f s) t c')
611
-
612
- x = rearrange(stft_repr, 'b f t c -> b t (f c)')
613
-
614
- if self.use_torch_checkpoint:
615
- x = checkpoint(self.band_split, x, use_reentrant=False)
616
- else:
617
- x = self.band_split(x)
618
-
619
- # axial / hierarchical attention
620
-
621
- store = [None] * len(self.layers)
622
-
623
- # Initialize value residuals if residual_value is enabled
624
- time_v_residual = None
625
- freq_v_residual = None
626
-
627
- for i, transformer_block in enumerate(self.layers):
628
-
629
- if len(transformer_block) == 3:
630
- linear_transformer, time_transformer, freq_transformer = transformer_block
631
-
632
- x, ft_ps = pack([x], 'b * d')
633
- if self.use_torch_checkpoint:
634
- linear_out, _ = checkpoint(linear_transformer, x, use_reentrant=False)
635
- else:
636
- linear_out, _ = linear_transformer(x)
637
- x, = unpack(linear_out, ft_ps, 'b * d')
638
- else:
639
- time_transformer, freq_transformer = transformer_block
640
-
641
- if self.skip_connection:
642
- # Sum all previous
643
- for j in range(i):
644
- x = x + store[j]
645
-
646
- # Time transformer
647
- x = rearrange(x, 'b t f d -> b f t d')
648
- x, ps = pack([x], '* t d')
649
-
650
- if self.use_torch_checkpoint:
651
- time_out, next_time_v_residual = checkpoint(time_transformer, x, use_reentrant=False)
652
- else:
653
- time_out, next_time_v_residual = time_transformer(x, value_residual=time_v_residual)
654
-
655
- if time_v_residual is None:
656
- time_v_residual = next_time_v_residual
657
-
658
- x = time_out
659
- x, = unpack(x, ps, '* t d')
660
- x = rearrange(x, 'b f t d -> b t f d')
661
-
662
- # Frequency transformer
663
- x, ps = pack([x], '* f d')
664
-
665
- if self.use_torch_checkpoint:
666
- freq_out, next_freq_v_residual = checkpoint(freq_transformer, x, use_reentrant=False)
667
- else:
668
- freq_out, next_freq_v_residual = freq_transformer(x, value_residual=freq_v_residual)
669
-
670
- if freq_v_residual is None:
671
- freq_v_residual = next_freq_v_residual
672
-
673
- x = freq_out
674
- x, = unpack(x, ps, '* f d')
675
-
676
- if self.skip_connection:
677
- store[i] = x
678
-
679
- x = self.final_norm(x)
680
-
681
- if active_stem_ids is None:
682
- heads = self.mask_estimators
683
- stem_ids = list(range(len(self.mask_estimators)))
684
- else:
685
- heads = [self.mask_estimators[i] for i in active_stem_ids]
686
- stem_ids = active_stem_ids
687
-
688
- num_stems = len(heads)
689
-
690
- if self.use_torch_checkpoint:
691
- mask = torch.stack([checkpoint(fn, x, use_reentrant=False) for fn in heads], dim=1)
692
- else:
693
- mask = torch.stack([fn(x) for fn in heads], dim=1)
694
- mask = rearrange(mask, 'b n t (f c) -> b n f t c', c=2)
695
-
696
- # modulate frequency representation
697
-
698
- stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
699
-
700
- # complex number multiplication
701
-
702
- stft_repr = torch.view_as_complex(stft_repr)
703
- mask = torch.view_as_complex(mask)
704
-
705
- stft_repr = stft_repr * mask
706
-
707
- # istft
708
-
709
- stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)
710
-
711
- if self.zero_dc:
712
- # whether to dc filter
713
- stft_repr = stft_repr.index_fill(1, tensor(0, device = device), 0.)
714
-
715
- try:
716
- recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False, length=raw_audio.shape[-1])
717
- except:
718
- recon_audio = torch.istft(
719
- stft_repr.cpu() if x_is_mps else stft_repr,
720
- **self.stft_kwargs,
721
- window=stft_window.cpu() if x_is_mps else stft_window,
722
- return_complex=False,
723
- length=raw_audio.shape[-1]
724
- ).to(device)
725
-
726
- recon_audio = rearrange(
727
- recon_audio,
728
- '(b n s) t -> b n s t',
729
- s=self.audio_channels,
730
- n=num_stems
731
- )
732
-
733
- if not exists(target):
734
- return recon_audio
735
-
736
- if target.ndim == 2:
737
- target = rearrange(target, '... t -> ... 1 t')
738
-
739
- target = target[..., :recon_audio.shape[-1]] # protect against lost length on istft
740
-
741
- target_sel = target[:, stem_ids]
742
- loss = F.l1_loss(recon_audio, target_sel)
743
-
744
- multi_stft_resolution_loss = 0.
745
-
746
- for window_size in self.multi_stft_resolutions_window_sizes:
747
- res_stft_kwargs = dict(
748
- n_fft=max(window_size, self.multi_stft_n_fft), # not sure what n_fft is across multi resolution stft
749
- win_length=window_size,
750
- return_complex=True,
751
- window=self.multi_stft_window_fn(window_size, device=device),
752
- **self.multi_stft_kwargs,
753
- )
754
-
755
- recon_Y = torch.stft(rearrange(recon_audio, 'b n s t -> (b n s) t'),**res_stft_kwargs)
756
- target_Y = torch.stft(rearrange(target_sel, 'b n s t -> (b n s) t'),**res_stft_kwargs)
757
-
758
- multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)
759
-
760
- weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
761
-
762
- total_loss = loss + weighted_multi_resolution_loss
763
-
764
- if not return_loss_breakdown:
765
- return total_loss
766
-
 
767
  return total_loss, (loss, multi_stft_resolution_loss)
 
1
+ from functools import partial
2
+
3
+ import torch
4
+ from torch import nn, einsum, tensor, Tensor
5
+ from torch.nn import Module, ModuleList
6
+ import torch.nn.functional as F
7
+
8
+ from .attend import Attend
9
+
10
+ from torch.utils.checkpoint import checkpoint
11
+
12
+ from beartype.typing import Tuple, Optional, List, Callable
13
+ from beartype import beartype
14
+
15
+ from rotary_embedding_torch import RotaryEmbedding
16
+
17
+ from einops import rearrange, pack, unpack
18
+ from einops.layers.torch import Rearrange
19
+
20
+ try:
21
+ from .pope.attention import flash_attn_with_pope
22
+ from .pope.pope import PoPE
23
+ _HAS_POPE = True
24
+ except Exception:
25
+ PoPE = None
26
+ flash_attn_with_pope = None
27
+ _HAS_POPE = False
28
+
29
+ # helper functions
30
+
31
def exists(val):
    """Return True when `val` is anything other than None."""
    if val is None:
        return False
    return True
33
+
34
+
35
def default(v, d):
    """Return `v` unless it is None, in which case return the fallback `d`."""
    if v is None:
        return d
    return v
37
+
38
+
39
def pack_one(t, pattern):
    """Pack the single tensor `t` with einops `pack`.

    Returns whatever `pack` returns for a one-element list (callers in this
    file unpack it as `(packed, packed_shapes)`).
    """
    tensors = [t]
    return pack(tensors, pattern)
41
+
42
+
43
def unpack_one(t, ps, pattern):
    """Inverse of `pack_one`: unpack `t` with einops and return the first piece."""
    pieces = unpack(t, ps, pattern)
    return pieces[0]
45
+
46
+
47
+ # norm
48
+
49
def l2norm(t):
    """Scale `t` to unit L2 norm along its last dimension."""
    return F.normalize(t, p=2, dim=-1)
51
+
52
+
53
class RMSNorm(Module):
    """Normalise the last dimension to unit L2 norm, then rescale.

    The output is `normalize(x) * sqrt(dim) * gamma`, where `gamma` is a
    learned per-feature gain initialised to ones.
    """

    def __init__(self, dim):
        super().__init__()
        # sqrt(dim) restores a per-element magnitude of ~1 after unit-norm scaling
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        unit = F.normalize(x, dim=-1)
        return unit * self.scale * self.gamma
61
+
62
+
63
+ # attention
64
+
65
class FeedForward(Module):
    """Pre-normalised two-layer MLP: RMSNorm -> Linear -> GELU -> Dropout -> Linear -> Dropout.

    The hidden width is `int(dim * mult)`. No residual connection is added
    here; the caller adds it.
    """

    def __init__(
            self,
            dim,
            mult=4,
            dropout=0.
    ):
        super().__init__()
        hidden_width = int(dim * mult)
        stages = [
            RMSNorm(dim),
            nn.Linear(dim, hidden_width),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_width, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)
85
+
86
+
87
class Attention(Module):
    """Gated multi-head self-attention with an optional rotary or PoPE position encoding.

    `forward` returns a pair `(output, values)`, where `values` is this
    layer's own value projection before any residual mixing, so that a caller
    can hand it to later layers as `value_residual`.
    """

    def __init__(
            self,
            dim,
            heads=8,
            dim_head=64,
            dropout=0.,
            rotary_embed=None,
            flash=True,
            pope_embed=None,
            learned_value_residual_mix = False
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        dim_inner = heads * dim_head

        self.rotary_embed = rotary_embed
        self.pope_embed = pope_embed
        # at most one positional scheme may be configured
        assert self.rotary_embed is None or self.pope_embed is None, \
            "cannot have both rotary and pope embeddings"

        self.attend = Attend(flash=flash, dropout=dropout)

        self.norm = RMSNorm(dim)
        self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)

        # per-head mixing weight between own values and the supplied residual values
        if learned_value_residual_mix:
            self.to_value_residual_mix = nn.Linear(dim, heads)
        else:
            self.to_value_residual_mix = None

        # per-head sigmoid gate applied to the attention output
        self.to_gates = nn.Linear(dim, heads)

        self.to_out = nn.Sequential(
            nn.Linear(dim_inner, dim, bias=False),
            nn.Dropout(dropout)
        )

    def forward(self, x, value_residual = None):
        x = self.norm(x)

        projected = self.to_qkv(x)
        q, k, v = rearrange(projected, 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)

        # keep the un-mixed values to hand back to the caller
        orig_v = v

        if self.pope_embed is not None:
            out = flash_attn_with_pope(
                q, k, v,
                pos_emb=self.pope_embed(q.shape[-2]),
                softmax_scale=self.scale
            )
        else:
            if self.rotary_embed is not None:
                q = self.rotary_embed.rotate_queries_or_keys(q)
                k = self.rotary_embed.rotate_queries_or_keys(k)
            elif self.to_value_residual_mix is not None:
                # NOTE(review): the value-residual mix is only reached when neither
                # PoPE nor rotary embeddings are configured, so it is skipped
                # whenever a positional scheme is present - confirm this is intended.
                mix = self.to_value_residual_mix(x)
                mix = rearrange(mix, 'b n h -> b h n 1').sigmoid()

                assert value_residual is not None
                v = v.lerp(value_residual, mix)

            out = self.attend(q, k, v)

        gate = rearrange(self.to_gates(x), 'b n h -> b h n 1').sigmoid()
        out = out * gate

        merged = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(merged), orig_v
155
+
156
+
157
class LinearAttention(Module):
    """
    this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.

    Queries and keys are L2-normalised and the queries are scaled by a learned
    per-head temperature. `to_qkv` lays tensors out as `b h d n`, so `Attend`
    receives the feature axis where a standard layer would pass the sequence axis.
    `forward` returns a single tensor.
    """

    @beartype
    def __init__(
            self,
            *,
            dim,
            dim_head=32,
            heads=8,
            scale=8,
            flash=False,
            dropout=0.
    ):
        super().__init__()
        dim_inner = dim_head * heads
        self.norm = RMSNorm(dim)

        qkv_projection = nn.Linear(dim, dim_inner * 3, bias=False)
        split_heads = Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads)
        self.to_qkv = nn.Sequential(qkv_projection, split_heads)

        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

        self.attend = Attend(
            scale=scale,
            dropout=dropout,
            flash=flash
        )

        merge_heads = Rearrange('b h d n -> b n (h d)')
        out_projection = nn.Linear(dim_inner, dim, bias=False)
        self.to_out = nn.Sequential(merge_heads, out_projection)

    def forward(
            self,
            x
    ):
        normed = self.norm(x)

        q, k, v = self.to_qkv(normed)

        q = l2norm(q)
        k = l2norm(k)
        q = q * self.temperature.exp()

        attended = self.attend(q, k, v)

        return self.to_out(attended)
209
+
210
+
211
class Transformer(Module):
    """A stack of `depth` (attention, feed-forward) pairs with residual connections.

    Each layer uses `LinearAttention` when `linear_attn` is true, otherwise
    `Attention`. `rotary_embed`, `pope_embed` and `add_value_residual` are
    only forwarded to `Attention`; they are ignored for linear attention.

    `forward` returns `(output, first_values)`, where `first_values` is the
    value tensor reported by the first `Attention` layer, or None when the
    stack contains only linear-attention layers or is empty.
    """

    def __init__(
            self,
            *,
            dim,
            depth,
            dim_head=64,
            heads=8,
            attn_dropout=0.,
            ff_dropout=0.,
            ff_mult=4,
            norm_output=True,
            rotary_embed=None,
            pope_embed=None,
            flash_attn=True,
            linear_attn=False,
            add_value_residual = False
    ):
        super().__init__()
        self.layers = ModuleList([])

        for _ in range(depth):
            if linear_attn:
                attn = LinearAttention(
                    dim=dim,
                    dim_head=dim_head,
                    heads=heads,
                    dropout=attn_dropout,
                    flash=flash_attn
                )
            else:
                attn = Attention(
                    dim=dim,
                    dim_head=dim_head,
                    heads=heads,
                    dropout=attn_dropout,
                    rotary_embed=rotary_embed,
                    pope_embed=pope_embed,
                    flash=flash_attn,
                    learned_value_residual_mix=add_value_residual
                )

            self.layers.append(ModuleList([
                attn,
                FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
            ]))

        self.norm = RMSNorm(dim) if norm_output else nn.Identity()

    def forward(self, x, value_residual=None):
        first_values = None

        for attn, ff in self.layers:
            if isinstance(attn, LinearAttention):
                # LinearAttention.forward takes only `x` and returns one tensor;
                # calling it like Attention raised TypeError on every forward pass.
                attn_out, next_values = attn(x), None
            else:
                attn_out, next_values = attn(x, value_residual=value_residual)

            if first_values is None:
                first_values = next_values

            x = attn_out + x
            x = ff(x) + x

        return self.norm(x), first_values
273
+
274
+
275
+ # bandsplit module
276
+
277
class BandSplit(Module):
    """Project each frequency band to a common feature width.

    The last dimension of the input is split into chunks of sizes
    `dim_inputs`. Each chunk goes through its own RMSNorm + Linear to `dim`,
    and the results are stacked on a new second-to-last axis.
    """

    @beartype
    def __init__(
            self,
            dim,
            dim_inputs: Tuple[int, ...]
    ):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_features = ModuleList([])

        for band_width in dim_inputs:
            self.to_features.append(
                nn.Sequential(
                    RMSNorm(band_width),
                    nn.Linear(band_width, dim)
                )
            )

    def forward(self, x):
        bands = x.split(self.dim_inputs, dim=-1)
        projected = [
            to_feature(band)
            for band, to_feature in zip(bands, self.to_features)
        ]
        return torch.stack(projected, dim=-2)
305
+
306
+
307
def MLP(
        dim_in,
        dim_out,
        dim_hidden=None,
        depth=1,
        activation=nn.Tanh
):
    """Build an `nn.Sequential` of `depth` Linear layers with activations in between.

    Hidden layers have width `dim_hidden` (defaulting to `dim_in`). The final
    Linear has no activation after it.
    """
    if dim_hidden is None:
        dim_hidden = dim_in

    widths = [dim_in] + [dim_hidden] * (depth - 1) + [dim_out]
    transitions = list(zip(widths, widths[1:]))

    modules = []
    for position, (width_in, width_out) in enumerate(transitions, start=1):
        modules.append(nn.Linear(width_in, width_out))

        # every Linear except the last is followed by the activation
        if position < len(transitions):
            modules.append(activation())

    return nn.Sequential(*modules)
330
+
331
+
332
class MaskEstimator(Module):
    """Per-band heads that map band features back to frequency-bin values.

    Each band gets an MLP producing `2 * dim_in` values followed by a GLU,
    which halves that to `dim_in`. `forward` concatenates the band outputs
    along the last dimension.
    """

    @beartype
    def __init__(
            self,
            dim,
            dim_inputs: Tuple[int, ...],
            depth,
            mlp_expansion_factor=4
    ):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_freqs = ModuleList([])
        dim_hidden = dim * mlp_expansion_factor

        for band_width in dim_inputs:
            band_head = nn.Sequential(
                MLP(dim, band_width * 2, dim_hidden=dim_hidden, depth=depth),
                nn.GLU(dim=-1)
            )
            self.to_freqs.append(band_head)

    def forward(self, x):
        # one feature slice per band, taken from the second-to-last axis
        per_band = x.unbind(dim=-2)
        estimates = [
            band_head(band_features)
            for band_features, band_head in zip(per_band, self.to_freqs)
        ]
        return torch.cat(estimates, dim=-1)
366
+
367
+
368
+ # main class
369
+
370
# Default band layout: 62 bands whose widths sum to 1025 frequency bins,
# matching the default `stft_n_fft=2048` (2048 // 2 + 1 bins).
DEFAULT_FREQS_PER_BANDS = (
    (2,) * 24
    + (4,) * 12
    + (12,) * 8
    + (24,) * 8
    + (48,) * 8
    + (128, 129)
)
380
+
381
+
382
class BSRoformer(Module):
    """Band-split transformer that predicts complex STFT masks for source separation.

    The waveform is transformed with an STFT, split into frequency bands,
    processed by `depth` blocks of (optional linear, time, frequency)
    transformers, and turned into one complex mask per stem. Masked spectra
    are inverted back to audio. When `target` is given to `forward`, an L1
    waveform loss plus a multi-resolution STFT loss is returned instead.

    Extra keyword arguments are accepted via `**kwargs` and not used here.
    """

    @beartype
    def __init__(
            self,
            dim,
            *,
            depth,
            stereo=False,
            num_stems=1,
            time_transformer_depth=2,
            freq_transformer_depth=2,
            linear_transformer_depth=0,
            freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
            # in the paper, they divide into ~60 bands, test with 1 for starters
            dim_head=64,
            heads=8,
            attn_dropout=0.,
            ff_dropout=0.,
            flash_attn=True,
            dim_freqs_in=1025,
            stft_n_fft=2048,
            stft_hop_length=512,
            # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
            stft_win_length=2048,
            stft_normalized=False,
            stft_window_fn: Optional[Callable] = None,
            zero_dc = True,
            mask_estimator_depth=2,
            multi_stft_resolution_loss_weight=1.,
            multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
            multi_stft_hop_size=147,
            multi_stft_normalized=False,
            multi_stft_window_fn: Callable = torch.hann_window,
            mlp_expansion_factor=4,
            use_torch_checkpoint=False,
            skip_connection=False,
            use_pope: bool = False,
            residual_value: bool = False,
            **kwargs
    ):
        super().__init__()

        self.stereo = stereo
        self.audio_channels = 2 if stereo else 1
        self.num_stems = num_stems
        self.use_torch_checkpoint = use_torch_checkpoint
        self.skip_connection = skip_connection

        self.layers = ModuleList([])

        transformer_kwargs = dict(
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout,
            flash_attn=flash_attn,
            norm_output=False,
        )

        if use_pope:
            # PoPE is an optional import; fail with a clear message instead of
            # "'NoneType' object is not callable" when it is unavailable.
            if PoPE is None:
                raise ImportError('use_pope=True requires the optional .pope package, which failed to import')
            time_pope_embed = PoPE(dim=dim_head, heads=heads)
            freq_pope_embed = PoPE(dim=dim_head, heads=heads)
            time_rotary_embed = None
            freq_rotary_embed = None
        else:
            time_rotary_embed = RotaryEmbedding(dim = dim_head)
            freq_rotary_embed = RotaryEmbedding(dim = dim_head)
            time_pope_embed = freq_pope_embed = None

        for layer_index in range(depth):
            # Value-residual mixing is only enabled from the second block on, and
            # only when `residual_value` is set. Computing the flag once for both
            # modes fixes the NameError on `is_first` that the non-residual branch
            # raised whenever linear_transformer_depth > 0.
            add_value_residual = residual_value and layer_index > 0

            tran_modules = []
            if linear_transformer_depth > 0:
                tran_modules.append(
                    Transformer(
                        depth=linear_transformer_depth,
                        linear_attn=True,
                        add_value_residual=add_value_residual,
                        **transformer_kwargs
                    )
                )
            tran_modules.append(
                Transformer(
                    depth=time_transformer_depth,
                    rotary_embed=time_rotary_embed,
                    pope_embed=time_pope_embed,
                    add_value_residual=add_value_residual,
                    **transformer_kwargs
                )
            )
            tran_modules.append(
                Transformer(
                    depth=freq_transformer_depth,
                    rotary_embed=freq_rotary_embed,
                    pope_embed=freq_pope_embed,
                    add_value_residual=add_value_residual,
                    **transformer_kwargs
                )
            )
            self.layers.append(nn.ModuleList(tran_modules))

        self.final_norm = RMSNorm(dim)

        self.stft_kwargs = dict(
            n_fft=stft_n_fft,
            hop_length=stft_hop_length,
            win_length=stft_win_length,
            normalized=stft_normalized
        )

        self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)

        # probe the STFT once to learn how many frequency bins these settings produce
        freqs = torch.stft(torch.randn(1, 4096), **self.stft_kwargs, window=torch.ones(stft_win_length), return_complex=True).shape[1]

        assert len(freqs_per_bands) > 1
        assert sum(
            freqs_per_bands) == freqs, f'the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}'

        # each bin contributes real + imaginary parts for every audio channel
        freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in freqs_per_bands)

        self.band_split = BandSplit(
            dim=dim,
            dim_inputs=freqs_per_bands_with_complex
        )

        self.mask_estimators = nn.ModuleList([])

        for _ in range(num_stems):
            mask_estimator = MaskEstimator(
                dim=dim,
                dim_inputs=freqs_per_bands_with_complex,
                depth=mask_estimator_depth,
                mlp_expansion_factor=mlp_expansion_factor,
            )

            self.mask_estimators.append(mask_estimator)

        # whether to zero out dc

        self.zero_dc = zero_dc

        # for the multi-resolution stft loss

        self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
        self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
        self.multi_stft_n_fft = stft_n_fft
        self.multi_stft_window_fn = multi_stft_window_fn

        self.multi_stft_kwargs = dict(
            hop_length=multi_stft_hop_size,
            normalized=multi_stft_normalized
        )

    def forward(
            self,
            raw_audio,
            target=None,
            active_stem_ids=None,
            return_loss_breakdown=False
    ):
        """
        Separate `raw_audio` into stems, or compute the training loss when `target` is given.

        `active_stem_ids`, when provided, selects which mask estimators (and
        which stems of `target`) are used. Without `target` the reconstructed
        audio `b n s t` is returned. With `target` the total loss is returned,
        or `(total_loss, (l1_loss, multi_stft_loss))` when
        `return_loss_breakdown` is true.

        einops

        b - batch
        f - freq
        t - time
        s - audio channel (1 for mono, 2 for stereo)
        n - number of 'stems'
        c - complex (2)
        d - feature dimension
        """

        device = raw_audio.device
        x_is_mps = device.type == "mps"

        if raw_audio.ndim == 2:
            raw_audio = rearrange(raw_audio, 'b t -> b 1 t')

        channels = raw_audio.shape[1]
        assert (not self.stereo and channels == 1) or (
                self.stereo and channels == 2), \
            ('stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2).'
             ' also need to be False if mono (channel dimension of 1)')

        # to stft

        raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')

        stft_window = self.stft_window_fn(device=device)

        try:
            stft_repr = torch.stft(
                raw_audio,
                **self.stft_kwargs,
                window=stft_window,
                return_complex=True
            )
        except Exception:
            # Only MPS has a CPU fallback; on other devices a retry would fail
            # the same way, so surface the original error.
            if not x_is_mps:
                raise
            stft_repr = torch.stft(
                raw_audio.cpu(),
                **self.stft_kwargs,
                window=stft_window.cpu(),
                return_complex=True
            ).to(device)
        stft_repr = torch.view_as_real(stft_repr)

        stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')

        # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
        stft_repr = rearrange(stft_repr, 'b s f t c -> b (f s) t c')

        x = rearrange(stft_repr, 'b f t c -> b t (f c)')

        if self.use_torch_checkpoint:
            x = checkpoint(self.band_split, x, use_reentrant=False)
        else:
            x = self.band_split(x)

        # axial / hierarchical attention

        store = [None] * len(self.layers)

        # value tensors from the first block, handed to later blocks as `value_residual`
        time_v_residual = None
        freq_v_residual = None

        for i, transformer_block in enumerate(self.layers):

            if len(transformer_block) == 3:
                linear_transformer, time_transformer, freq_transformer = transformer_block

                x, ft_ps = pack([x], 'b * d')
                if self.use_torch_checkpoint:
                    linear_out, _ = checkpoint(linear_transformer, x, use_reentrant=False)
                else:
                    linear_out, _ = linear_transformer(x)
                x, = unpack(linear_out, ft_ps, 'b * d')
            else:
                time_transformer, freq_transformer = transformer_block

            if self.skip_connection:
                # Sum all previous
                for j in range(i):
                    x = x + store[j]

            # Time transformer
            x = rearrange(x, 'b t f d -> b f t d')
            x, ps = pack([x], '* t d')

            if self.use_torch_checkpoint:
                # non-reentrant checkpoint forwards keyword arguments, so the
                # checkpointed path receives the same value residual as the plain path
                time_out, next_time_v_residual = checkpoint(
                    time_transformer, x, value_residual=time_v_residual, use_reentrant=False
                )
            else:
                time_out, next_time_v_residual = time_transformer(x, value_residual=time_v_residual)

            if time_v_residual is None:
                time_v_residual = next_time_v_residual

            x = time_out
            x, = unpack(x, ps, '* t d')
            x = rearrange(x, 'b f t d -> b t f d')

            # Frequency transformer
            x, ps = pack([x], '* f d')

            if self.use_torch_checkpoint:
                freq_out, next_freq_v_residual = checkpoint(
                    freq_transformer, x, value_residual=freq_v_residual, use_reentrant=False
                )
            else:
                freq_out, next_freq_v_residual = freq_transformer(x, value_residual=freq_v_residual)

            if freq_v_residual is None:
                freq_v_residual = next_freq_v_residual

            x = freq_out
            x, = unpack(x, ps, '* f d')

            if self.skip_connection:
                store[i] = x

        x = self.final_norm(x)

        if active_stem_ids is None:
            heads = self.mask_estimators
            stem_ids = list(range(len(self.mask_estimators)))
        else:
            heads = [self.mask_estimators[i] for i in active_stem_ids]
            stem_ids = active_stem_ids

        num_stems = len(heads)

        if self.use_torch_checkpoint:
            mask = torch.stack([checkpoint(fn, x, use_reentrant=False) for fn in heads], dim=1)
        else:
            mask = torch.stack([fn(x) for fn in heads], dim=1)
        mask = rearrange(mask, 'b n t (f c) -> b n f t c', c=2)

        # modulate frequency representation

        stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')

        # complex number multiplication

        stft_repr = torch.view_as_complex(stft_repr)
        mask = torch.view_as_complex(mask)

        stft_repr = stft_repr * mask

        # istft

        stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)

        if self.zero_dc:
            # whether to dc filter
            stft_repr = stft_repr.index_fill(1, tensor(0, device = device), 0.)

        try:
            recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False, length=raw_audio.shape[-1])
        except Exception:
            # same MPS-only CPU fallback as for the forward STFT
            if not x_is_mps:
                raise
            recon_audio = torch.istft(
                stft_repr.cpu(),
                **self.stft_kwargs,
                window=stft_window.cpu(),
                return_complex=False,
                length=raw_audio.shape[-1]
            ).to(device)

        recon_audio = rearrange(
            recon_audio,
            '(b n s) t -> b n s t',
            s=self.audio_channels,
            n=num_stems
        )

        if not exists(target):
            return recon_audio

        if target.ndim == 2:
            target = rearrange(target, '... t -> ... 1 t')

        target = target[..., :recon_audio.shape[-1]]  # protect against lost length on istft

        target_sel = target[:, stem_ids]
        loss = F.l1_loss(recon_audio, target_sel)

        multi_stft_resolution_loss = 0.

        for window_size in self.multi_stft_resolutions_window_sizes:
            res_stft_kwargs = dict(
                n_fft=max(window_size, self.multi_stft_n_fft),  # not sure what n_fft is across multi resolution stft
                win_length=window_size,
                return_complex=True,
                window=self.multi_stft_window_fn(window_size, device=device),
                **self.multi_stft_kwargs,
            )

            recon_Y = torch.stft(rearrange(recon_audio, 'b n s t -> (b n s) t'), **res_stft_kwargs)
            target_Y = torch.stft(rearrange(target_sel, 'b n s t -> (b n s) t'), **res_stft_kwargs)

            multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)

        weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight

        total_loss = loss + weighted_multi_resolution_loss

        if not return_loss_breakdown:
            return total_loss

        return total_loss, (loss, multi_stft_resolution_loss)
models/bs_roformer/bs_roformer_conditional.py CHANGED
@@ -7,10 +7,7 @@ import torch.nn.functional as F
7
 
8
  from .attend import Attend
9
  from .conditioner import BandEmbedder
10
- try:
11
- from .attend_sage import Attend as AttendSage
12
- except:
13
- pass
14
  from torch.utils.checkpoint import checkpoint
15
 
16
  from beartype.typing import Tuple, Optional, List, Callable
@@ -88,7 +85,6 @@ class Attention(Module):
88
  dropout=0.,
89
  rotary_embed=None,
90
  flash=True,
91
- sage_attention=False,
92
  ):
93
  super().__init__()
94
  self.heads = heads
@@ -97,10 +93,7 @@ class Attention(Module):
97
 
98
  self.rotary_embed = rotary_embed
99
 
100
- if sage_attention:
101
- self.attend = AttendSage(flash=flash, dropout=dropout)
102
- else:
103
- self.attend = Attend(flash=flash, dropout=dropout)
104
 
105
  self.norm = RMSNorm(dim)
106
  self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
@@ -145,7 +138,6 @@ class LinearAttention(Module):
145
  scale=8,
146
  flash=True,
147
  dropout=0.,
148
- sage_attention=False,
149
  ):
150
  super().__init__()
151
  dim_inner = dim_head * heads
@@ -158,18 +150,11 @@ class LinearAttention(Module):
158
 
159
  self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
160
 
161
- if sage_attention:
162
- self.attend = AttendSage(
163
- scale=scale,
164
- dropout=dropout,
165
- flash=flash
166
- )
167
- else:
168
- self.attend = Attend(
169
- scale=scale,
170
- dropout=dropout,
171
- flash=flash
172
- )
173
 
174
  self.to_out = nn.Sequential(
175
  Rearrange('b h d n -> b n (h d)'),
@@ -207,7 +192,6 @@ class Transformer(Module):
207
  rotary_embed=None,
208
  flash_attn=True,
209
  linear_attn=False,
210
- sage_attention=False,
211
  ):
212
  super().__init__()
213
  self.layers = ModuleList([])
@@ -220,7 +204,6 @@ class Transformer(Module):
220
  heads=heads,
221
  dropout=attn_dropout,
222
  flash=flash_attn,
223
- sage_attention=sage_attention
224
  )
225
  else:
226
  attn = Attention(
@@ -230,7 +213,6 @@ class Transformer(Module):
230
  dropout=attn_dropout,
231
  rotary_embed=rotary_embed,
232
  flash=flash_attn,
233
- sage_attention=sage_attention
234
  )
235
 
236
  self.layers.append(ModuleList([
@@ -268,7 +250,6 @@ class ScaleTransformer(Module):
268
  rotary_embed=None,
269
  flash_attn=True,
270
  linear_attn=False,
271
- sage_attention=False,
272
  ):
273
  super().__init__()
274
  self.layers = ModuleList([])
@@ -285,7 +266,6 @@ class ScaleTransformer(Module):
285
  heads=heads,
286
  dropout=attn_dropout,
287
  flash=flash_attn,
288
- sage_attention=sage_attention
289
  )
290
  else:
291
  attn = Attention(
@@ -295,7 +275,6 @@ class ScaleTransformer(Module):
295
  dropout=attn_dropout,
296
  rotary_embed=rotary_embed,
297
  flash=flash_attn,
298
- sage_attention=sage_attention
299
  )
300
 
301
  norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
@@ -353,7 +332,6 @@ class Transformer(Module):
353
  rotary_embed=None,
354
  flash_attn=True,
355
  linear_attn=False,
356
- sage_attention=False,
357
  ):
358
  super().__init__()
359
  self.layers = ModuleList([])
@@ -366,7 +344,6 @@ class Transformer(Module):
366
  heads=heads,
367
  dropout=attn_dropout,
368
  flash=flash_attn,
369
- sage_attention=sage_attention
370
  )
371
  else:
372
  attn = Attention(
@@ -376,7 +353,6 @@ class Transformer(Module):
376
  dropout=attn_dropout,
377
  rotary_embed=rotary_embed,
378
  flash=flash_attn,
379
- sage_attention=sage_attention
380
  )
381
 
382
  self.layers.append(ModuleList([
@@ -538,7 +514,7 @@ class BandConditionalBSRoformer(Module):
538
  mlp_expansion_factor=4,
539
  use_torch_checkpoint=False,
540
  skip_connection=False,
541
- sage_attention=False,
542
  ):
543
  super().__init__()
544
 
@@ -551,9 +527,6 @@ class BandConditionalBSRoformer(Module):
551
 
552
  self.layers = ModuleList([])
553
 
554
- if sage_attention:
555
- print("Use Sage Attention")
556
-
557
  transformer_kwargs = dict(
558
  dim=dim,
559
  heads=heads,
@@ -562,7 +535,6 @@ class BandConditionalBSRoformer(Module):
562
  ff_dropout=ff_dropout,
563
  flash_attn=flash_attn,
564
  norm_output=False,
565
- sage_attention=sage_attention,
566
  )
567
 
568
  time_rotary_embed = RotaryEmbedding(dim=dim_head)
@@ -884,7 +856,7 @@ class BSRoformer_Conditional(Module):
884
  mlp_expansion_factor=4,
885
  use_torch_checkpoint=False,
886
  skip_connection=False,
887
- sage_attention=False,
888
  ):
889
  super().__init__()
890
 
@@ -896,9 +868,6 @@ class BSRoformer_Conditional(Module):
896
 
897
  self.layers = ModuleList([])
898
 
899
- if sage_attention:
900
- print("Use Sage Attention")
901
-
902
  transformer_kwargs = dict(
903
  dim=dim,
904
  heads=heads,
@@ -907,7 +876,6 @@ class BSRoformer_Conditional(Module):
907
  ff_dropout=ff_dropout,
908
  flash_attn=flash_attn,
909
  norm_output=False,
910
- sage_attention=sage_attention,
911
  )
912
 
913
  time_rotary_embed = RotaryEmbedding(dim=dim_head)
 
7
 
8
  from .attend import Attend
9
  from .conditioner import BandEmbedder
10
+
 
 
 
11
  from torch.utils.checkpoint import checkpoint
12
 
13
  from beartype.typing import Tuple, Optional, List, Callable
 
85
  dropout=0.,
86
  rotary_embed=None,
87
  flash=True,
 
88
  ):
89
  super().__init__()
90
  self.heads = heads
 
93
 
94
  self.rotary_embed = rotary_embed
95
 
96
+ self.attend = Attend(flash=flash, dropout=dropout)
 
 
 
97
 
98
  self.norm = RMSNorm(dim)
99
  self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
 
138
  scale=8,
139
  flash=True,
140
  dropout=0.,
 
141
  ):
142
  super().__init__()
143
  dim_inner = dim_head * heads
 
150
 
151
  self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
152
 
153
+ self.attend = Attend(
154
+ scale=scale,
155
+ dropout=dropout,
156
+ flash=flash
157
+ )
 
 
 
 
 
 
 
158
 
159
  self.to_out = nn.Sequential(
160
  Rearrange('b h d n -> b n (h d)'),
 
192
  rotary_embed=None,
193
  flash_attn=True,
194
  linear_attn=False,
 
195
  ):
196
  super().__init__()
197
  self.layers = ModuleList([])
 
204
  heads=heads,
205
  dropout=attn_dropout,
206
  flash=flash_attn,
 
207
  )
208
  else:
209
  attn = Attention(
 
213
  dropout=attn_dropout,
214
  rotary_embed=rotary_embed,
215
  flash=flash_attn,
 
216
  )
217
 
218
  self.layers.append(ModuleList([
 
250
  rotary_embed=None,
251
  flash_attn=True,
252
  linear_attn=False,
 
253
  ):
254
  super().__init__()
255
  self.layers = ModuleList([])
 
266
  heads=heads,
267
  dropout=attn_dropout,
268
  flash=flash_attn,
 
269
  )
270
  else:
271
  attn = Attention(
 
275
  dropout=attn_dropout,
276
  rotary_embed=rotary_embed,
277
  flash=flash_attn,
 
278
  )
279
 
280
  norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
 
332
  rotary_embed=None,
333
  flash_attn=True,
334
  linear_attn=False,
 
335
  ):
336
  super().__init__()
337
  self.layers = ModuleList([])
 
344
  heads=heads,
345
  dropout=attn_dropout,
346
  flash=flash_attn,
 
347
  )
348
  else:
349
  attn = Attention(
 
353
  dropout=attn_dropout,
354
  rotary_embed=rotary_embed,
355
  flash=flash_attn,
 
356
  )
357
 
358
  self.layers.append(ModuleList([
 
514
  mlp_expansion_factor=4,
515
  use_torch_checkpoint=False,
516
  skip_connection=False,
517
+ **kwargs
518
  ):
519
  super().__init__()
520
 
 
527
 
528
  self.layers = ModuleList([])
529
 
 
 
 
530
  transformer_kwargs = dict(
531
  dim=dim,
532
  heads=heads,
 
535
  ff_dropout=ff_dropout,
536
  flash_attn=flash_attn,
537
  norm_output=False,
 
538
  )
539
 
540
  time_rotary_embed = RotaryEmbedding(dim=dim_head)
 
856
  mlp_expansion_factor=4,
857
  use_torch_checkpoint=False,
858
  skip_connection=False,
859
+ **kwargs
860
  ):
861
  super().__init__()
862
 
 
868
 
869
  self.layers = ModuleList([])
870
 
 
 
 
871
  transformer_kwargs = dict(
872
  dim=dim,
873
  heads=heads,
 
876
  ff_dropout=ff_dropout,
877
  flash_attn=flash_attn,
878
  norm_output=False,
 
879
  )
880
 
881
  time_rotary_embed = RotaryEmbedding(dim=dim_head)
models/bs_roformer/bs_roformer_fno.py CHANGED
@@ -1,704 +1,685 @@
1
- from functools import partial
2
-
3
- import torch
4
- from torch import nn, einsum, Tensor
5
- from torch.nn import Module, ModuleList
6
- import torch.nn.functional as F
7
- from .fno1d import FNO1d
8
-
9
- from .attend import Attend
10
-
11
- try:
12
- from .attend_sage import Attend as AttendSage
13
- except:
14
- pass
15
- from torch.utils.checkpoint import checkpoint
16
-
17
- from beartype.typing import Tuple, Optional, List, Callable
18
- from beartype import beartype
19
-
20
- from rotary_embedding_torch import RotaryEmbedding
21
-
22
- from einops import rearrange, pack, unpack
23
- from einops.layers.torch import Rearrange
24
-
25
-
26
- def exists(val):
27
- return val is not None
28
-
29
-
30
- def default(v, d):
31
- return v if exists(v) else d
32
-
33
-
34
- def pack_one(t, pattern):
35
- return pack([t], pattern)
36
-
37
-
38
- def unpack_one(t, ps, pattern):
39
- return unpack(t, ps, pattern)[0]
40
-
41
-
42
- def l2norm(t):
43
- return F.normalize(t, dim=-1, p=2)
44
-
45
-
46
- class RMSNorm(Module):
47
- def __init__(self, dim):
48
- super().__init__()
49
- self.scale = dim**0.5
50
- self.gamma = nn.Parameter(torch.ones(dim))
51
-
52
- def forward(self, x):
53
- return F.normalize(x, dim=-1) * self.scale * self.gamma
54
-
55
-
56
- class FeedForward(Module):
57
- def __init__(self, dim, mult=4, dropout=0.0):
58
- super().__init__()
59
- dim_inner = int(dim * mult)
60
- self.net = nn.Sequential(
61
- RMSNorm(dim),
62
- nn.Linear(dim, dim_inner),
63
- nn.GELU(),
64
- nn.Dropout(dropout),
65
- nn.Linear(dim_inner, dim),
66
- nn.Dropout(dropout),
67
- )
68
-
69
- def forward(self, x):
70
- return self.net(x)
71
-
72
-
73
- class Attention(Module):
74
- def __init__(
75
- self,
76
- dim,
77
- heads=8,
78
- dim_head=64,
79
- dropout=0.0,
80
- rotary_embed=None,
81
- flash=True,
82
- sage_attention=False,
83
- ):
84
- super().__init__()
85
- self.heads = heads
86
- self.scale = dim_head**-0.5
87
- dim_inner = heads * dim_head
88
-
89
- self.rotary_embed = rotary_embed
90
-
91
- if sage_attention:
92
- self.attend = AttendSage(flash=flash, dropout=dropout)
93
- else:
94
- self.attend = Attend(flash=flash, dropout=dropout)
95
-
96
- self.norm = RMSNorm(dim)
97
- self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
98
-
99
- self.to_gates = nn.Linear(dim, heads)
100
-
101
- self.to_out = nn.Sequential(
102
- nn.Linear(dim_inner, dim, bias=False), nn.Dropout(dropout)
103
- )
104
-
105
- def forward(self, x):
106
- x = self.norm(x)
107
-
108
- q, k, v = rearrange(
109
- self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads
110
- )
111
-
112
- if exists(self.rotary_embed):
113
- q = self.rotary_embed.rotate_queries_or_keys(q)
114
- k = self.rotary_embed.rotate_queries_or_keys(k)
115
-
116
- out = self.attend(q, k, v)
117
-
118
- gates = self.to_gates(x)
119
- out = out * rearrange(gates, "b n h -> b h n 1").sigmoid()
120
-
121
- out = rearrange(out, "b h n d -> b n (h d)")
122
- return self.to_out(out)
123
-
124
-
125
- class LinearAttention(Module):
126
-
127
- @beartype
128
- def __init__(
129
- self,
130
- *,
131
- dim,
132
- dim_head=32,
133
- heads=8,
134
- scale=8,
135
- flash=False,
136
- dropout=0.0,
137
- sage_attention=False,
138
- ):
139
- super().__init__()
140
- dim_inner = dim_head * heads
141
- self.norm = RMSNorm(dim)
142
-
143
- self.to_qkv = nn.Sequential(
144
- nn.Linear(dim, dim_inner * 3, bias=False),
145
- Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads),
146
- )
147
-
148
- self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
149
-
150
- if sage_attention:
151
- self.attend = AttendSage(scale=scale, dropout=dropout, flash=flash)
152
- else:
153
- self.attend = Attend(scale=scale, dropout=dropout, flash=flash)
154
-
155
- self.to_out = nn.Sequential(
156
- Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False)
157
- )
158
-
159
- def forward(self, x):
160
- x = self.norm(x)
161
-
162
- q, k, v = self.to_qkv(x)
163
-
164
- q, k = map(l2norm, (q, k))
165
- q = q * self.temperature.exp()
166
-
167
- out = self.attend(q, k, v)
168
-
169
- return self.to_out(out)
170
-
171
-
172
- class Transformer(Module):
173
- def __init__(
174
- self,
175
- *,
176
- dim,
177
- depth,
178
- dim_head=64,
179
- heads=8,
180
- attn_dropout=0.0,
181
- ff_dropout=0.0,
182
- ff_mult=4,
183
- norm_output=True,
184
- rotary_embed=None,
185
- flash_attn=True,
186
- linear_attn=False,
187
- sage_attention=False,
188
- ):
189
- super().__init__()
190
- self.layers = ModuleList([])
191
-
192
- for _ in range(depth):
193
- if linear_attn:
194
- attn = LinearAttention(
195
- dim=dim,
196
- dim_head=dim_head,
197
- heads=heads,
198
- dropout=attn_dropout,
199
- flash=flash_attn,
200
- sage_attention=sage_attention,
201
- )
202
- else:
203
- attn = Attention(
204
- dim=dim,
205
- dim_head=dim_head,
206
- heads=heads,
207
- dropout=attn_dropout,
208
- rotary_embed=rotary_embed,
209
- flash=flash_attn,
210
- sage_attention=sage_attention,
211
- )
212
-
213
- self.layers.append(
214
- ModuleList(
215
- [attn, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)]
216
- )
217
- )
218
-
219
- self.norm = RMSNorm(dim) if norm_output else nn.Identity()
220
-
221
- def forward(self, x):
222
-
223
- for attn, ff in self.layers:
224
- x = attn(x) + x
225
- x = ff(x) + x
226
-
227
- return self.norm(x)
228
-
229
-
230
- class BandSplit(Module):
231
- @beartype
232
- def __init__(self, dim, dim_inputs: Tuple[int, ...]):
233
- super().__init__()
234
- self.dim_inputs = dim_inputs
235
- self.to_features = ModuleList([])
236
-
237
- for dim_in in dim_inputs:
238
- net = nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim))
239
-
240
- self.to_features.append(net)
241
-
242
- def forward(self, x):
243
- x = x.split(self.dim_inputs, dim=-1)
244
-
245
- outs = []
246
- for split_input, to_feature in zip(x, self.to_features):
247
- split_output = to_feature(split_input)
248
- outs.append(split_output)
249
-
250
- return torch.stack(outs, dim=-2)
251
-
252
-
253
- def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh):
254
- dim_hidden = default(dim_hidden, dim_in)
255
-
256
- net = []
257
- dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)
258
-
259
- for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
260
- is_last = ind == (len(dims) - 2)
261
-
262
- net.append(nn.Linear(layer_dim_in, layer_dim_out))
263
-
264
- if is_last:
265
- continue
266
-
267
- net.append(activation())
268
-
269
- return nn.Sequential(*net)
270
-
271
-
272
- class MaskEstimator(Module):
273
- @beartype
274
- def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4):
275
- super().__init__()
276
- self.dim_inputs = dim_inputs
277
- self.to_freqs = ModuleList([])
278
- dim_hidden = dim * mlp_expansion_factor
279
-
280
- for dim_in in dim_inputs:
281
- net = []
282
-
283
- mlp = nn.Sequential(
284
- FNO1d(
285
- n_modes_height=64,
286
- hidden_channels=dim,
287
- in_channels=dim,
288
- out_channels=dim_in * 2,
289
- lifting_channels=dim,
290
- projection_channels=dim,
291
- n_layers=3,
292
- separable=True,
293
- ),
294
- nn.GLU(dim=-2),
295
- )
296
-
297
- self.to_freqs.append(mlp)
298
-
299
- def forward(self, x):
300
- x = x.unbind(dim=-2)
301
-
302
- outs = []
303
-
304
- for band_features, mlp in zip(x, self.to_freqs):
305
- band_features = rearrange(band_features, "b t c -> b c t")
306
- with torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32):
307
- freq_out = mlp(band_features).float()
308
- freq_out = rearrange(freq_out, "b c t -> b t c")
309
- outs.append(freq_out)
310
-
311
- return torch.cat(outs, dim=-1)
312
-
313
-
314
- DEFAULT_FREQS_PER_BANDS = (
315
- 2,
316
- 2,
317
- 2,
318
- 2,
319
- 2,
320
- 2,
321
- 2,
322
- 2,
323
- 2,
324
- 2,
325
- 2,
326
- 2,
327
- 2,
328
- 2,
329
- 2,
330
- 2,
331
- 2,
332
- 2,
333
- 2,
334
- 2,
335
- 2,
336
- 2,
337
- 2,
338
- 2,
339
- 4,
340
- 4,
341
- 4,
342
- 4,
343
- 4,
344
- 4,
345
- 4,
346
- 4,
347
- 4,
348
- 4,
349
- 4,
350
- 4,
351
- 12,
352
- 12,
353
- 12,
354
- 12,
355
- 12,
356
- 12,
357
- 12,
358
- 12,
359
- 24,
360
- 24,
361
- 24,
362
- 24,
363
- 24,
364
- 24,
365
- 24,
366
- 24,
367
- 48,
368
- 48,
369
- 48,
370
- 48,
371
- 48,
372
- 48,
373
- 48,
374
- 48,
375
- 128,
376
- 129,
377
- )
378
-
379
-
380
- class BSRoformer_FNO(Module):
381
-
382
- @beartype
383
- def __init__(
384
- self,
385
- dim,
386
- *,
387
- depth,
388
- stereo=False,
389
- num_stems=1,
390
- time_transformer_depth=2,
391
- freq_transformer_depth=2,
392
- linear_transformer_depth=0,
393
- freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
394
- dim_head=64,
395
- heads=8,
396
- attn_dropout=0.0,
397
- ff_dropout=0.0,
398
- flash_attn=True,
399
- dim_freqs_in=1025,
400
- stft_n_fft=2048,
401
- stft_hop_length=512,
402
- stft_win_length=2048,
403
- stft_normalized=False,
404
- stft_window_fn: Optional[Callable] = None,
405
- mask_estimator_depth=2,
406
- multi_stft_resolution_loss_weight=1.0,
407
- multi_stft_resolutions_window_sizes: Tuple[int, ...] = (
408
- 4096,
409
- 2048,
410
- 1024,
411
- 512,
412
- 256,
413
- ),
414
- multi_stft_hop_size=147,
415
- multi_stft_normalized=False,
416
- multi_stft_window_fn: Callable = torch.hann_window,
417
- mlp_expansion_factor=4,
418
- use_torch_checkpoint=False,
419
- skip_connection=False,
420
- sage_attention=False,
421
- ):
422
- super().__init__()
423
-
424
- self.stereo = stereo
425
- self.audio_channels = 2 if stereo else 1
426
- self.num_stems = num_stems
427
- self.use_torch_checkpoint = use_torch_checkpoint
428
- self.skip_connection = skip_connection
429
-
430
- self.layers = ModuleList([])
431
-
432
- if sage_attention:
433
- print("Use Sage Attention")
434
-
435
- transformer_kwargs = dict(
436
- dim=dim,
437
- heads=heads,
438
- dim_head=dim_head,
439
- attn_dropout=attn_dropout,
440
- ff_dropout=ff_dropout,
441
- flash_attn=flash_attn,
442
- norm_output=False,
443
- sage_attention=sage_attention,
444
- )
445
-
446
- time_rotary_embed = RotaryEmbedding(dim=dim_head)
447
- freq_rotary_embed = RotaryEmbedding(dim=dim_head)
448
-
449
- for _ in range(depth):
450
- tran_modules = []
451
- if linear_transformer_depth > 0:
452
- tran_modules.append(
453
- Transformer(
454
- depth=linear_transformer_depth,
455
- linear_attn=True,
456
- **transformer_kwargs,
457
- )
458
- )
459
- tran_modules.append(
460
- Transformer(
461
- depth=time_transformer_depth,
462
- rotary_embed=time_rotary_embed,
463
- **transformer_kwargs,
464
- )
465
- )
466
- tran_modules.append(
467
- Transformer(
468
- depth=freq_transformer_depth,
469
- rotary_embed=freq_rotary_embed,
470
- **transformer_kwargs,
471
- )
472
- )
473
- self.layers.append(nn.ModuleList(tran_modules))
474
-
475
- self.final_norm = RMSNorm(dim)
476
-
477
- self.stft_kwargs = dict(
478
- n_fft=stft_n_fft,
479
- hop_length=stft_hop_length,
480
- win_length=stft_win_length,
481
- normalized=stft_normalized,
482
- )
483
-
484
- self.stft_window_fn = partial(
485
- default(stft_window_fn, torch.hann_window), stft_win_length
486
- )
487
-
488
- freqs = torch.stft(
489
- torch.randn(1, 4096),
490
- **self.stft_kwargs,
491
- window=torch.ones(stft_win_length),
492
- return_complex=True,
493
- ).shape[1]
494
-
495
- assert len(freqs_per_bands) > 1
496
- assert (
497
- sum(freqs_per_bands) == freqs
498
- ), f"the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}"
499
-
500
- freqs_per_bands_with_complex = tuple(
501
- 2 * f * self.audio_channels for f in freqs_per_bands
502
- )
503
-
504
- self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)
505
-
506
- self.mask_estimators = nn.ModuleList([])
507
-
508
- for _ in range(num_stems):
509
- mask_estimator = MaskEstimator(
510
- dim=dim,
511
- dim_inputs=freqs_per_bands_with_complex,
512
- depth=mask_estimator_depth,
513
- mlp_expansion_factor=mlp_expansion_factor,
514
- )
515
-
516
- self.mask_estimators.append(mask_estimator)
517
-
518
- self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
519
- self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
520
- self.multi_stft_n_fft = stft_n_fft
521
- self.multi_stft_window_fn = multi_stft_window_fn
522
-
523
- self.multi_stft_kwargs = dict(
524
- hop_length=multi_stft_hop_size, normalized=multi_stft_normalized
525
- )
526
-
527
- def forward(self, raw_audio, target=None, return_loss_breakdown=False):
528
-
529
- device = raw_audio.device
530
-
531
- x_is_mps = True if device.type == "mps" else False
532
-
533
- if raw_audio.ndim == 2:
534
- raw_audio = rearrange(raw_audio, "b t -> b 1 t")
535
-
536
- channels = raw_audio.shape[1]
537
- assert (not self.stereo and channels == 1) or (
538
- self.stereo and channels == 2
539
- ), "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)"
540
-
541
- raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")
542
-
543
- stft_window = self.stft_window_fn(device=device)
544
-
545
- try:
546
- stft_repr = torch.stft(
547
- raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True
548
- )
549
- except:
550
- stft_repr = torch.stft(
551
- raw_audio.cpu() if x_is_mps else raw_audio,
552
- **self.stft_kwargs,
553
- window=stft_window.cpu() if x_is_mps else stft_window,
554
- return_complex=True,
555
- ).to(device)
556
- stft_repr = torch.view_as_real(stft_repr)
557
-
558
- stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")
559
-
560
- stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c")
561
-
562
- x = rearrange(stft_repr, "b f t c -> b t (f c)")
563
-
564
- if self.use_torch_checkpoint:
565
- x = checkpoint(self.band_split, x, use_reentrant=False)
566
- else:
567
- x = self.band_split(x)
568
-
569
- store = [None] * len(self.layers)
570
- for i, transformer_block in enumerate(self.layers):
571
-
572
- if len(transformer_block) == 3:
573
- linear_transformer, time_transformer, freq_transformer = (
574
- transformer_block
575
- )
576
-
577
- x, ft_ps = pack([x], "b * d")
578
- if self.use_torch_checkpoint:
579
- x = checkpoint(linear_transformer, x, use_reentrant=False)
580
- else:
581
- x = linear_transformer(x)
582
- (x,) = unpack(x, ft_ps, "b * d")
583
- else:
584
- time_transformer, freq_transformer = transformer_block
585
-
586
- if self.skip_connection:
587
- for j in range(i):
588
- x = x + store[j]
589
-
590
- x = rearrange(x, "b t f d -> b f t d")
591
- x, ps = pack([x], "* t d")
592
-
593
- if self.use_torch_checkpoint:
594
- x = checkpoint(time_transformer, x, use_reentrant=False)
595
- else:
596
- x = time_transformer(x)
597
-
598
- (x,) = unpack(x, ps, "* t d")
599
- x = rearrange(x, "b f t d -> b t f d")
600
- x, ps = pack([x], "* f d")
601
-
602
- if self.use_torch_checkpoint:
603
- x = checkpoint(freq_transformer, x, use_reentrant=False)
604
- else:
605
- x = freq_transformer(x)
606
-
607
- (x,) = unpack(x, ps, "* f d")
608
-
609
- if self.skip_connection:
610
- store[i] = x
611
-
612
- x = self.final_norm(x)
613
-
614
- num_stems = len(self.mask_estimators)
615
-
616
- if self.use_torch_checkpoint:
617
- mask = torch.stack(
618
- [checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators],
619
- dim=1,
620
- )
621
- else:
622
- mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
623
- mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2)
624
-
625
- stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")
626
-
627
- stft_repr = torch.view_as_complex(stft_repr)
628
- mask = torch.view_as_complex(mask)
629
-
630
- stft_repr = stft_repr * mask
631
-
632
- stft_repr = rearrange(
633
- stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels
634
- )
635
-
636
- try:
637
- recon_audio = torch.istft(
638
- stft_repr,
639
- **self.stft_kwargs,
640
- window=stft_window,
641
- return_complex=False,
642
- length=raw_audio.shape[-1],
643
- )
644
- except:
645
- recon_audio = torch.istft(
646
- stft_repr.cpu() if x_is_mps else stft_repr,
647
- **self.stft_kwargs,
648
- window=stft_window.cpu() if x_is_mps else stft_window,
649
- return_complex=False,
650
- length=raw_audio.shape[-1],
651
- ).to(device)
652
-
653
- recon_audio = rearrange(
654
- recon_audio, "(b n s) t -> b n s t", s=self.audio_channels, n=num_stems
655
- )
656
-
657
- if num_stems == 1:
658
- recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")
659
-
660
- if not exists(target):
661
- return recon_audio
662
-
663
- if self.num_stems > 1:
664
- assert target.ndim == 4 and target.shape[1] == self.num_stems
665
-
666
- if target.ndim == 2:
667
- target = rearrange(target, "... t -> ... 1 t")
668
-
669
- target = target[..., : recon_audio.shape[-1]]
670
-
671
- loss = F.l1_loss(recon_audio, target)
672
-
673
- multi_stft_resolution_loss = 0.0
674
-
675
- for window_size in self.multi_stft_resolutions_window_sizes:
676
- res_stft_kwargs = dict(
677
- n_fft=max(window_size, self.multi_stft_n_fft),
678
- win_length=window_size,
679
- return_complex=True,
680
- window=self.multi_stft_window_fn(window_size, device=device),
681
- **self.multi_stft_kwargs,
682
- )
683
-
684
- recon_Y = torch.stft(
685
- rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs
686
- )
687
- target_Y = torch.stft(
688
- rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs
689
- )
690
-
691
- multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(
692
- recon_Y, target_Y
693
- )
694
-
695
- weighted_multi_resolution_loss = (
696
- multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
697
- )
698
-
699
- total_loss = loss + weighted_multi_resolution_loss
700
-
701
- if not return_loss_breakdown:
702
- return total_loss
703
-
704
- return total_loss, (loss, multi_stft_resolution_loss)
 
1
+ from functools import partial
2
+
3
+ import torch
4
+ from torch import nn, einsum, Tensor
5
+ from torch.nn import Module, ModuleList
6
+ import torch.nn.functional as F
7
+ from .fno1d import FNO1d
8
+
9
+ from .attend import Attend
10
+
11
+ from torch.utils.checkpoint import checkpoint
12
+
13
+ from beartype.typing import Tuple, Optional, List, Callable
14
+ from beartype import beartype
15
+
16
+ from rotary_embedding_torch import RotaryEmbedding
17
+
18
+ from einops import rearrange, pack, unpack
19
+ from einops.layers.torch import Rearrange
20
+
21
+
22
+ def exists(val):
23
+ return val is not None
24
+
25
+
26
+ def default(v, d):
27
+ return v if exists(v) else d
28
+
29
+
30
+ def pack_one(t, pattern):
31
+ return pack([t], pattern)
32
+
33
+
34
+ def unpack_one(t, ps, pattern):
35
+ return unpack(t, ps, pattern)[0]
36
+
37
+
38
+ def l2norm(t):
39
+ return F.normalize(t, dim=-1, p=2)
40
+
41
+
42
+ class RMSNorm(Module):
43
+ def __init__(self, dim):
44
+ super().__init__()
45
+ self.scale = dim**0.5
46
+ self.gamma = nn.Parameter(torch.ones(dim))
47
+
48
+ def forward(self, x):
49
+ return F.normalize(x, dim=-1) * self.scale * self.gamma
50
+
51
+
52
+ class FeedForward(Module):
53
+ def __init__(self, dim, mult=4, dropout=0.0):
54
+ super().__init__()
55
+ dim_inner = int(dim * mult)
56
+ self.net = nn.Sequential(
57
+ RMSNorm(dim),
58
+ nn.Linear(dim, dim_inner),
59
+ nn.GELU(),
60
+ nn.Dropout(dropout),
61
+ nn.Linear(dim_inner, dim),
62
+ nn.Dropout(dropout),
63
+ )
64
+
65
+ def forward(self, x):
66
+ return self.net(x)
67
+
68
+
69
+ class Attention(Module):
70
+ def __init__(
71
+ self,
72
+ dim,
73
+ heads=8,
74
+ dim_head=64,
75
+ dropout=0.0,
76
+ rotary_embed=None,
77
+ flash=True,
78
+ ):
79
+ super().__init__()
80
+ self.heads = heads
81
+ self.scale = dim_head**-0.5
82
+ dim_inner = heads * dim_head
83
+
84
+ self.rotary_embed = rotary_embed
85
+
86
+ self.attend = Attend(flash=flash, dropout=dropout)
87
+
88
+ self.norm = RMSNorm(dim)
89
+ self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
90
+
91
+ self.to_gates = nn.Linear(dim, heads)
92
+
93
+ self.to_out = nn.Sequential(
94
+ nn.Linear(dim_inner, dim, bias=False), nn.Dropout(dropout)
95
+ )
96
+
97
+ def forward(self, x):
98
+ x = self.norm(x)
99
+
100
+ q, k, v = rearrange(
101
+ self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads
102
+ )
103
+
104
+ if exists(self.rotary_embed):
105
+ q = self.rotary_embed.rotate_queries_or_keys(q)
106
+ k = self.rotary_embed.rotate_queries_or_keys(k)
107
+
108
+ out = self.attend(q, k, v)
109
+
110
+ gates = self.to_gates(x)
111
+ out = out * rearrange(gates, "b n h -> b h n 1").sigmoid()
112
+
113
+ out = rearrange(out, "b h n d -> b n (h d)")
114
+ return self.to_out(out)
115
+
116
+
117
+ class LinearAttention(Module):
118
+
119
+ @beartype
120
+ def __init__(
121
+ self,
122
+ *,
123
+ dim,
124
+ dim_head=32,
125
+ heads=8,
126
+ scale=8,
127
+ flash=False,
128
+ dropout=0.0,
129
+ ):
130
+ super().__init__()
131
+ dim_inner = dim_head * heads
132
+ self.norm = RMSNorm(dim)
133
+
134
+ self.to_qkv = nn.Sequential(
135
+ nn.Linear(dim, dim_inner * 3, bias=False),
136
+ Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads),
137
+ )
138
+
139
+ self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
140
+
141
+ self.attend = Attend(scale=scale, dropout=dropout, flash=flash)
142
+
143
+ self.to_out = nn.Sequential(
144
+ Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False)
145
+ )
146
+
147
+ def forward(self, x):
148
+ x = self.norm(x)
149
+
150
+ q, k, v = self.to_qkv(x)
151
+
152
+ q, k = map(l2norm, (q, k))
153
+ q = q * self.temperature.exp()
154
+
155
+ out = self.attend(q, k, v)
156
+
157
+ return self.to_out(out)
158
+
159
+
160
+ class Transformer(Module):
161
+ def __init__(
162
+ self,
163
+ *,
164
+ dim,
165
+ depth,
166
+ dim_head=64,
167
+ heads=8,
168
+ attn_dropout=0.0,
169
+ ff_dropout=0.0,
170
+ ff_mult=4,
171
+ norm_output=True,
172
+ rotary_embed=None,
173
+ flash_attn=True,
174
+ linear_attn=False,
175
+ ):
176
+ super().__init__()
177
+ self.layers = ModuleList([])
178
+
179
+ for _ in range(depth):
180
+ if linear_attn:
181
+ attn = LinearAttention(
182
+ dim=dim,
183
+ dim_head=dim_head,
184
+ heads=heads,
185
+ dropout=attn_dropout,
186
+ flash=flash_attn,
187
+ )
188
+ else:
189
+ attn = Attention(
190
+ dim=dim,
191
+ dim_head=dim_head,
192
+ heads=heads,
193
+ dropout=attn_dropout,
194
+ rotary_embed=rotary_embed,
195
+ flash=flash_attn,
196
+ )
197
+
198
+ self.layers.append(
199
+ ModuleList(
200
+ [attn, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)]
201
+ )
202
+ )
203
+
204
+ self.norm = RMSNorm(dim) if norm_output else nn.Identity()
205
+
206
+ def forward(self, x):
207
+
208
+ for attn, ff in self.layers:
209
+ x = attn(x) + x
210
+ x = ff(x) + x
211
+
212
+ return self.norm(x)
213
+
214
+
215
+ class BandSplit(Module):
216
+ @beartype
217
+ def __init__(self, dim, dim_inputs: Tuple[int, ...]):
218
+ super().__init__()
219
+ self.dim_inputs = dim_inputs
220
+ self.to_features = ModuleList([])
221
+
222
+ for dim_in in dim_inputs:
223
+ net = nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim))
224
+
225
+ self.to_features.append(net)
226
+
227
+ def forward(self, x):
228
+ x = x.split(self.dim_inputs, dim=-1)
229
+
230
+ outs = []
231
+ for split_input, to_feature in zip(x, self.to_features):
232
+ split_output = to_feature(split_input)
233
+ outs.append(split_output)
234
+
235
+ return torch.stack(outs, dim=-2)
236
+
237
+
238
+ def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh):
239
+ dim_hidden = default(dim_hidden, dim_in)
240
+
241
+ net = []
242
+ dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)
243
+
244
+ for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
245
+ is_last = ind == (len(dims) - 2)
246
+
247
+ net.append(nn.Linear(layer_dim_in, layer_dim_out))
248
+
249
+ if is_last:
250
+ continue
251
+
252
+ net.append(activation())
253
+
254
+ return nn.Sequential(*net)
255
+
256
+
257
+ class MaskEstimator(Module):
258
+ @beartype
259
+ def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4):
260
+ super().__init__()
261
+ self.dim_inputs = dim_inputs
262
+ self.to_freqs = ModuleList([])
263
+ dim_hidden = dim * mlp_expansion_factor
264
+
265
+ for dim_in in dim_inputs:
266
+ net = []
267
+
268
+ mlp = nn.Sequential(
269
+ FNO1d(
270
+ n_modes_height=64,
271
+ hidden_channels=dim,
272
+ in_channels=dim,
273
+ out_channels=dim_in * 2,
274
+ lifting_channels=dim,
275
+ projection_channels=dim,
276
+ n_layers=3,
277
+ separable=True,
278
+ ),
279
+ nn.GLU(dim=-2),
280
+ )
281
+
282
+ self.to_freqs.append(mlp)
283
+
284
+ def forward(self, x):
285
+ x = x.unbind(dim=-2)
286
+
287
+ outs = []
288
+
289
+ for band_features, mlp in zip(x, self.to_freqs):
290
+ band_features = rearrange(band_features, "b t c -> b c t")
291
+ with torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32):
292
+ freq_out = mlp(band_features).float()
293
+ freq_out = rearrange(freq_out, "b c t -> b t c")
294
+ outs.append(freq_out)
295
+
296
+ return torch.cat(outs, dim=-1)
297
+
298
+
299
+ DEFAULT_FREQS_PER_BANDS = (
300
+ 2,
301
+ 2,
302
+ 2,
303
+ 2,
304
+ 2,
305
+ 2,
306
+ 2,
307
+ 2,
308
+ 2,
309
+ 2,
310
+ 2,
311
+ 2,
312
+ 2,
313
+ 2,
314
+ 2,
315
+ 2,
316
+ 2,
317
+ 2,
318
+ 2,
319
+ 2,
320
+ 2,
321
+ 2,
322
+ 2,
323
+ 2,
324
+ 4,
325
+ 4,
326
+ 4,
327
+ 4,
328
+ 4,
329
+ 4,
330
+ 4,
331
+ 4,
332
+ 4,
333
+ 4,
334
+ 4,
335
+ 4,
336
+ 12,
337
+ 12,
338
+ 12,
339
+ 12,
340
+ 12,
341
+ 12,
342
+ 12,
343
+ 12,
344
+ 24,
345
+ 24,
346
+ 24,
347
+ 24,
348
+ 24,
349
+ 24,
350
+ 24,
351
+ 24,
352
+ 48,
353
+ 48,
354
+ 48,
355
+ 48,
356
+ 48,
357
+ 48,
358
+ 48,
359
+ 48,
360
+ 128,
361
+ 129,
362
+ )
363
+
364
+
365
+ class BSRoformer_FNO(Module):
366
+
367
+ @beartype
368
+ def __init__(
369
+ self,
370
+ dim,
371
+ *,
372
+ depth,
373
+ stereo=False,
374
+ num_stems=1,
375
+ time_transformer_depth=2,
376
+ freq_transformer_depth=2,
377
+ linear_transformer_depth=0,
378
+ freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
379
+ dim_head=64,
380
+ heads=8,
381
+ attn_dropout=0.0,
382
+ ff_dropout=0.0,
383
+ flash_attn=True,
384
+ dim_freqs_in=1025,
385
+ stft_n_fft=2048,
386
+ stft_hop_length=512,
387
+ stft_win_length=2048,
388
+ stft_normalized=False,
389
+ stft_window_fn: Optional[Callable] = None,
390
+ mask_estimator_depth=2,
391
+ multi_stft_resolution_loss_weight=1.0,
392
+ multi_stft_resolutions_window_sizes: Tuple[int, ...] = (
393
+ 4096,
394
+ 2048,
395
+ 1024,
396
+ 512,
397
+ 256,
398
+ ),
399
+ multi_stft_hop_size=147,
400
+ multi_stft_normalized=False,
401
+ multi_stft_window_fn: Callable = torch.hann_window,
402
+ mlp_expansion_factor=4,
403
+ use_torch_checkpoint=False,
404
+ skip_connection=False,
405
+ **kwargs
406
+ ):
407
+ super().__init__()
408
+
409
+ self.stereo = stereo
410
+ self.audio_channels = 2 if stereo else 1
411
+ self.num_stems = num_stems
412
+ self.use_torch_checkpoint = use_torch_checkpoint
413
+ self.skip_connection = skip_connection
414
+
415
+ self.layers = ModuleList([])
416
+
417
+ transformer_kwargs = dict(
418
+ dim=dim,
419
+ heads=heads,
420
+ dim_head=dim_head,
421
+ attn_dropout=attn_dropout,
422
+ ff_dropout=ff_dropout,
423
+ flash_attn=flash_attn,
424
+ norm_output=False,
425
+ )
426
+
427
+ time_rotary_embed = RotaryEmbedding(dim=dim_head)
428
+ freq_rotary_embed = RotaryEmbedding(dim=dim_head)
429
+
430
+ for _ in range(depth):
431
+ tran_modules = []
432
+ if linear_transformer_depth > 0:
433
+ tran_modules.append(
434
+ Transformer(
435
+ depth=linear_transformer_depth,
436
+ linear_attn=True,
437
+ **transformer_kwargs,
438
+ )
439
+ )
440
+ tran_modules.append(
441
+ Transformer(
442
+ depth=time_transformer_depth,
443
+ rotary_embed=time_rotary_embed,
444
+ **transformer_kwargs,
445
+ )
446
+ )
447
+ tran_modules.append(
448
+ Transformer(
449
+ depth=freq_transformer_depth,
450
+ rotary_embed=freq_rotary_embed,
451
+ **transformer_kwargs,
452
+ )
453
+ )
454
+ self.layers.append(nn.ModuleList(tran_modules))
455
+
456
+ self.final_norm = RMSNorm(dim)
457
+
458
+ self.stft_kwargs = dict(
459
+ n_fft=stft_n_fft,
460
+ hop_length=stft_hop_length,
461
+ win_length=stft_win_length,
462
+ normalized=stft_normalized,
463
+ )
464
+
465
+ self.stft_window_fn = partial(
466
+ default(stft_window_fn, torch.hann_window), stft_win_length
467
+ )
468
+
469
+ freqs = torch.stft(
470
+ torch.randn(1, 4096),
471
+ **self.stft_kwargs,
472
+ window=torch.ones(stft_win_length),
473
+ return_complex=True,
474
+ ).shape[1]
475
+
476
+ assert len(freqs_per_bands) > 1
477
+ assert (
478
+ sum(freqs_per_bands) == freqs
479
+ ), f"the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}"
480
+
481
+ freqs_per_bands_with_complex = tuple(
482
+ 2 * f * self.audio_channels for f in freqs_per_bands
483
+ )
484
+
485
+ self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)
486
+
487
+ self.mask_estimators = nn.ModuleList([])
488
+
489
+ for _ in range(num_stems):
490
+ mask_estimator = MaskEstimator(
491
+ dim=dim,
492
+ dim_inputs=freqs_per_bands_with_complex,
493
+ depth=mask_estimator_depth,
494
+ mlp_expansion_factor=mlp_expansion_factor,
495
+ )
496
+
497
+ self.mask_estimators.append(mask_estimator)
498
+
499
+ self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
500
+ self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
501
+ self.multi_stft_n_fft = stft_n_fft
502
+ self.multi_stft_window_fn = multi_stft_window_fn
503
+
504
+ self.multi_stft_kwargs = dict(
505
+ hop_length=multi_stft_hop_size, normalized=multi_stft_normalized
506
+ )
507
+
508
+ def forward(self, raw_audio, target=None, return_loss_breakdown=False):
509
+
510
+ device = raw_audio.device
511
+
512
+ x_is_mps = True if device.type == "mps" else False
513
+
514
+ if raw_audio.ndim == 2:
515
+ raw_audio = rearrange(raw_audio, "b t -> b 1 t")
516
+
517
+ channels = raw_audio.shape[1]
518
+ assert (not self.stereo and channels == 1) or (
519
+ self.stereo and channels == 2
520
+ ), "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)"
521
+
522
+ raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")
523
+
524
+ stft_window = self.stft_window_fn(device=device)
525
+
526
+ try:
527
+ stft_repr = torch.stft(
528
+ raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True
529
+ )
530
+ except:
531
+ stft_repr = torch.stft(
532
+ raw_audio.cpu() if x_is_mps else raw_audio,
533
+ **self.stft_kwargs,
534
+ window=stft_window.cpu() if x_is_mps else stft_window,
535
+ return_complex=True,
536
+ ).to(device)
537
+ stft_repr = torch.view_as_real(stft_repr)
538
+
539
+ stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")
540
+
541
+ stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c")
542
+
543
+ x = rearrange(stft_repr, "b f t c -> b t (f c)")
544
+
545
+ if self.use_torch_checkpoint:
546
+ x = checkpoint(self.band_split, x, use_reentrant=False)
547
+ else:
548
+ x = self.band_split(x)
549
+
550
+ store = [None] * len(self.layers)
551
+ for i, transformer_block in enumerate(self.layers):
552
+
553
+ if len(transformer_block) == 3:
554
+ linear_transformer, time_transformer, freq_transformer = (
555
+ transformer_block
556
+ )
557
+
558
+ x, ft_ps = pack([x], "b * d")
559
+ if self.use_torch_checkpoint:
560
+ x = checkpoint(linear_transformer, x, use_reentrant=False)
561
+ else:
562
+ x = linear_transformer(x)
563
+ (x,) = unpack(x, ft_ps, "b * d")
564
+ else:
565
+ time_transformer, freq_transformer = transformer_block
566
+
567
+ if self.skip_connection:
568
+ for j in range(i):
569
+ x = x + store[j]
570
+
571
+ x = rearrange(x, "b t f d -> b f t d")
572
+ x, ps = pack([x], "* t d")
573
+
574
+ if self.use_torch_checkpoint:
575
+ x = checkpoint(time_transformer, x, use_reentrant=False)
576
+ else:
577
+ x = time_transformer(x)
578
+
579
+ (x,) = unpack(x, ps, "* t d")
580
+ x = rearrange(x, "b f t d -> b t f d")
581
+ x, ps = pack([x], "* f d")
582
+
583
+ if self.use_torch_checkpoint:
584
+ x = checkpoint(freq_transformer, x, use_reentrant=False)
585
+ else:
586
+ x = freq_transformer(x)
587
+
588
+ (x,) = unpack(x, ps, "* f d")
589
+
590
+ if self.skip_connection:
591
+ store[i] = x
592
+
593
+ x = self.final_norm(x)
594
+
595
+ num_stems = len(self.mask_estimators)
596
+
597
+ if self.use_torch_checkpoint:
598
+ mask = torch.stack(
599
+ [checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators],
600
+ dim=1,
601
+ )
602
+ else:
603
+ mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
604
+ mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2)
605
+
606
+ stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")
607
+
608
+ stft_repr = torch.view_as_complex(stft_repr)
609
+ mask = torch.view_as_complex(mask)
610
+
611
+ stft_repr = stft_repr * mask
612
+
613
+ stft_repr = rearrange(
614
+ stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels
615
+ )
616
+
617
+ try:
618
+ recon_audio = torch.istft(
619
+ stft_repr,
620
+ **self.stft_kwargs,
621
+ window=stft_window,
622
+ return_complex=False,
623
+ length=raw_audio.shape[-1],
624
+ )
625
+ except:
626
+ recon_audio = torch.istft(
627
+ stft_repr.cpu() if x_is_mps else stft_repr,
628
+ **self.stft_kwargs,
629
+ window=stft_window.cpu() if x_is_mps else stft_window,
630
+ return_complex=False,
631
+ length=raw_audio.shape[-1],
632
+ ).to(device)
633
+
634
+ recon_audio = rearrange(
635
+ recon_audio, "(b n s) t -> b n s t", s=self.audio_channels, n=num_stems
636
+ )
637
+
638
+ if num_stems == 1:
639
+ recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")
640
+
641
+ if not exists(target):
642
+ return recon_audio
643
+
644
+ if self.num_stems > 1:
645
+ assert target.ndim == 4 and target.shape[1] == self.num_stems
646
+
647
+ if target.ndim == 2:
648
+ target = rearrange(target, "... t -> ... 1 t")
649
+
650
+ target = target[..., : recon_audio.shape[-1]]
651
+
652
+ loss = F.l1_loss(recon_audio, target)
653
+
654
+ multi_stft_resolution_loss = 0.0
655
+
656
+ for window_size in self.multi_stft_resolutions_window_sizes:
657
+ res_stft_kwargs = dict(
658
+ n_fft=max(window_size, self.multi_stft_n_fft),
659
+ win_length=window_size,
660
+ return_complex=True,
661
+ window=self.multi_stft_window_fn(window_size, device=device),
662
+ **self.multi_stft_kwargs,
663
+ )
664
+
665
+ recon_Y = torch.stft(
666
+ rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs
667
+ )
668
+ target_Y = torch.stft(
669
+ rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs
670
+ )
671
+
672
+ multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(
673
+ recon_Y, target_Y
674
+ )
675
+
676
+ weighted_multi_resolution_loss = (
677
+ multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
678
+ )
679
+
680
+ total_loss = loss + weighted_multi_resolution_loss
681
+
682
+ if not return_loss_breakdown:
683
+ return total_loss
684
+
685
+ return total_loss, (loss, multi_stft_resolution_loss)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/bs_roformer/bs_roformer_hyperace.py CHANGED
@@ -1,1122 +1,1103 @@
1
- from functools import partial
2
-
3
- import torch
4
- from torch import nn, einsum, Tensor
5
- from torch.nn import Module, ModuleList
6
- import torch.nn.functional as F
7
-
8
- from .attend import Attend
9
-
10
- try:
11
- from .attend_sage import Attend as AttendSage
12
- except:
13
- pass
14
- from torch.utils.checkpoint import checkpoint
15
-
16
- from beartype.typing import Tuple, Optional, List, Callable
17
- from beartype import beartype
18
-
19
- from rotary_embedding_torch import RotaryEmbedding
20
-
21
- from einops import rearrange, pack, unpack
22
- from einops.layers.torch import Rearrange
23
- import torchaudio
24
-
25
-
26
- def exists(val):
27
- return val is not None
28
-
29
-
30
- def default(v, d):
31
- return v if exists(v) else d
32
-
33
-
34
- def pack_one(t, pattern):
35
- return pack([t], pattern)
36
-
37
-
38
- def unpack_one(t, ps, pattern):
39
- return unpack(t, ps, pattern)[0]
40
-
41
-
42
- def l2norm(t):
43
- return F.normalize(t, dim=-1, p=2)
44
-
45
-
46
- class RMSNorm(Module):
47
- def __init__(self, dim):
48
- super().__init__()
49
- self.scale = dim**0.5
50
- self.gamma = nn.Parameter(torch.ones(dim))
51
-
52
- def forward(self, x):
53
- return F.normalize(x, dim=-1) * self.scale * self.gamma
54
-
55
-
56
- class FeedForward(Module):
57
- def __init__(self, dim, mult=4, dropout=0.0):
58
- super().__init__()
59
- dim_inner = int(dim * mult)
60
- self.net = nn.Sequential(
61
- RMSNorm(dim),
62
- nn.Linear(dim, dim_inner),
63
- nn.GELU(),
64
- nn.Dropout(dropout),
65
- nn.Linear(dim_inner, dim),
66
- nn.Dropout(dropout),
67
- )
68
-
69
- def forward(self, x):
70
- return self.net(x)
71
-
72
-
73
- class Attention(Module):
74
- def __init__(
75
- self,
76
- dim,
77
- heads=8,
78
- dim_head=64,
79
- dropout=0.0,
80
- rotary_embed=None,
81
- flash=True,
82
- sage_attention=False,
83
- ):
84
- super().__init__()
85
- self.heads = heads
86
- self.scale = dim_head**-0.5
87
- dim_inner = heads * dim_head
88
-
89
- self.rotary_embed = rotary_embed
90
-
91
- if sage_attention:
92
- self.attend = AttendSage(flash=flash, dropout=dropout)
93
- else:
94
- self.attend = Attend(flash=flash, dropout=dropout)
95
-
96
- self.norm = RMSNorm(dim)
97
- self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
98
-
99
- self.to_gates = nn.Linear(dim, heads)
100
-
101
- self.to_out = nn.Sequential(
102
- nn.Linear(dim_inner, dim, bias=False), nn.Dropout(dropout)
103
- )
104
-
105
- def forward(self, x):
106
- x = self.norm(x)
107
-
108
- q, k, v = rearrange(
109
- self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads
110
- )
111
-
112
- if exists(self.rotary_embed):
113
- q = self.rotary_embed.rotate_queries_or_keys(q)
114
- k = self.rotary_embed.rotate_queries_or_keys(k)
115
-
116
- out = self.attend(q, k, v)
117
-
118
- gates = self.to_gates(x)
119
- out = out * rearrange(gates, "b n h -> b h n 1").sigmoid()
120
-
121
- out = rearrange(out, "b h n d -> b n (h d)")
122
- return self.to_out(out)
123
-
124
-
125
- class LinearAttention(Module):
126
-
127
- @beartype
128
- def __init__(
129
- self,
130
- *,
131
- dim,
132
- dim_head=32,
133
- heads=8,
134
- scale=8,
135
- flash=True,
136
- dropout=0.0,
137
- sage_attention=False,
138
- ):
139
- super().__init__()
140
- dim_inner = dim_head * heads
141
- self.norm = RMSNorm(dim)
142
-
143
- self.to_qkv = nn.Sequential(
144
- nn.Linear(dim, dim_inner * 3, bias=False),
145
- Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads),
146
- )
147
-
148
- self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
149
-
150
- if sage_attention:
151
- self.attend = AttendSage(scale=scale, dropout=dropout, flash=flash)
152
- else:
153
- self.attend = Attend(scale=scale, dropout=dropout, flash=flash)
154
-
155
- self.to_out = nn.Sequential(
156
- Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False)
157
- )
158
-
159
- def forward(self, x):
160
- x = self.norm(x)
161
-
162
- q, k, v = self.to_qkv(x)
163
-
164
- q, k = map(l2norm, (q, k))
165
- q = q * self.temperature.exp()
166
-
167
- out = self.attend(q, k, v)
168
-
169
- return self.to_out(out)
170
-
171
-
172
- class Transformer(Module):
173
- def __init__(
174
- self,
175
- *,
176
- dim,
177
- depth,
178
- dim_head=64,
179
- heads=8,
180
- attn_dropout=0.0,
181
- ff_dropout=0.0,
182
- ff_mult=4,
183
- norm_output=True,
184
- rotary_embed=None,
185
- flash_attn=True,
186
- linear_attn=False,
187
- sage_attention=False,
188
- ):
189
- super().__init__()
190
- self.layers = ModuleList([])
191
-
192
- for _ in range(depth):
193
- if linear_attn:
194
- attn = LinearAttention(
195
- dim=dim,
196
- dim_head=dim_head,
197
- heads=heads,
198
- dropout=attn_dropout,
199
- flash=flash_attn,
200
- sage_attention=sage_attention,
201
- )
202
- else:
203
- attn = Attention(
204
- dim=dim,
205
- dim_head=dim_head,
206
- heads=heads,
207
- dropout=attn_dropout,
208
- rotary_embed=rotary_embed,
209
- flash=flash_attn,
210
- sage_attention=sage_attention,
211
- )
212
-
213
- self.layers.append(
214
- ModuleList(
215
- [attn, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)]
216
- )
217
- )
218
-
219
- self.norm = RMSNorm(dim) if norm_output else nn.Identity()
220
-
221
- def forward(self, x):
222
-
223
- for attn, ff in self.layers:
224
- x = attn(x) + x
225
- x = ff(x) + x
226
-
227
- return self.norm(x)
228
-
229
-
230
- class BandSplit(Module):
231
- @beartype
232
- def __init__(self, dim, dim_inputs: Tuple[int, ...]):
233
- super().__init__()
234
- self.dim_inputs = dim_inputs
235
- self.to_features = ModuleList([])
236
-
237
- for dim_in in dim_inputs:
238
- net = nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim))
239
-
240
- self.to_features.append(net)
241
-
242
- def forward(self, x):
243
-
244
- x = x.split(self.dim_inputs, dim=-1)
245
-
246
- outs = []
247
- for split_input, to_feature in zip(x, self.to_features):
248
- split_output = to_feature(split_input)
249
- outs.append(split_output)
250
-
251
- x = torch.stack(outs, dim=-2)
252
-
253
- return x
254
-
255
-
256
- class Conv(nn.Module):
257
- def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
258
- super().__init__()
259
- self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
260
- self.bn = nn.InstanceNorm2d(c2, affine=True, eps=1e-8)
261
- self.act = nn.SiLU() if act else nn.Identity()
262
-
263
- def forward(self, x):
264
- return self.act(self.bn(self.conv(x)))
265
-
266
-
267
- def autopad(k, p=None):
268
- if p is None:
269
- p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
270
- return p
271
-
272
-
273
- class DSConv(nn.Module):
274
- def __init__(self, c1, c2, k=3, s=1, p=None, act=True):
275
- super().__init__()
276
- self.dwconv = nn.Conv2d(c1, c1, k, s, autopad(k, p), groups=c1, bias=False)
277
- self.pwconv = nn.Conv2d(c1, c2, 1, 1, 0, bias=False)
278
- self.bn = nn.InstanceNorm2d(c2, affine=True, eps=1e-8)
279
- self.act = nn.SiLU() if act else nn.Identity()
280
-
281
- def forward(self, x):
282
- return self.act(self.bn(self.pwconv(self.dwconv(x))))
283
-
284
-
285
- class DS_Bottleneck(nn.Module):
286
- def __init__(self, c1, c2, k=3, shortcut=True):
287
- super().__init__()
288
- c_ = c1
289
- self.dsconv1 = DSConv(c1, c_, k=3, s=1)
290
- self.dsconv2 = DSConv(c_, c2, k=k, s=1)
291
- self.shortcut = shortcut and c1 == c2
292
-
293
- def forward(self, x):
294
- return (
295
- x + self.dsconv2(self.dsconv1(x))
296
- if self.shortcut
297
- else self.dsconv2(self.dsconv1(x))
298
- )
299
-
300
-
301
- class DS_C3k(nn.Module):
302
- def __init__(self, c1, c2, n=1, k=3, e=0.5):
303
- super().__init__()
304
- c_ = int(c2 * e)
305
- self.cv1 = Conv(c1, c_, 1, 1)
306
- self.cv2 = Conv(c1, c_, 1, 1)
307
- self.cv3 = Conv(2 * c_, c2, 1, 1)
308
- self.m = nn.Sequential(
309
- *[DS_Bottleneck(c_, c_, k=k, shortcut=True) for _ in range(n)]
310
- )
311
-
312
- def forward(self, x):
313
- return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
314
-
315
-
316
- class DS_C3k2(nn.Module):
317
- def __init__(self, c1, c2, n=1, k=3, e=0.5):
318
- super().__init__()
319
- c_ = int(c2 * e)
320
- self.cv1 = Conv(c1, c_, 1, 1)
321
- self.m = DS_C3k(c_, c_, n=n, k=k, e=1.0)
322
- self.cv2 = Conv(c_, c2, 1, 1)
323
-
324
- def forward(self, x):
325
- x_ = self.cv1(x)
326
- x_ = self.m(x_)
327
- return self.cv2(x_)
328
-
329
-
330
- class AdaptiveHyperedgeGeneration(nn.Module):
331
- def __init__(self, in_channels, num_hyperedges, num_heads=8):
332
- super().__init__()
333
- self.num_hyperedges = num_hyperedges
334
- self.num_heads = num_heads
335
- self.head_dim = in_channels // num_heads
336
-
337
- self.global_proto = nn.Parameter(torch.randn(num_hyperedges, in_channels))
338
-
339
- self.context_mapper = nn.Linear(
340
- 2 * in_channels, num_hyperedges * in_channels, bias=False
341
- )
342
-
343
- self.query_proj = nn.Linear(in_channels, in_channels, bias=False)
344
-
345
- self.scale = self.head_dim**-0.5
346
-
347
- def forward(self, x):
348
- B, N, C = x.shape
349
-
350
- f_avg = F.adaptive_avg_pool1d(x.permute(0, 2, 1), 1).squeeze(-1)
351
- f_max = F.adaptive_max_pool1d(x.permute(0, 2, 1), 1).squeeze(-1)
352
- f_ctx = torch.cat((f_avg, f_max), dim=1)
353
-
354
- delta_P = self.context_mapper(f_ctx).view(B, self.num_hyperedges, C)
355
- P = self.global_proto.unsqueeze(0) + delta_P
356
-
357
- z = self.query_proj(x)
358
-
359
- z = z.view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
360
-
361
- P = P.view(B, self.num_hyperedges, self.num_heads, self.head_dim).permute(
362
- 0, 2, 3, 1
363
- )
364
-
365
- sim = (z @ P) * self.scale
366
-
367
- s_bar = sim.mean(dim=1)
368
-
369
- A = F.softmax(s_bar.permute(0, 2, 1), dim=-1)
370
-
371
- return A
372
-
373
-
374
- class HypergraphConvolution(nn.Module):
375
- def __init__(self, in_channels, out_channels):
376
- super().__init__()
377
- self.W_e = nn.Linear(in_channels, in_channels, bias=False)
378
- self.W_v = nn.Linear(in_channels, out_channels, bias=False)
379
- self.act = nn.SiLU()
380
-
381
- def forward(self, x, A):
382
- f_m = torch.bmm(A, x)
383
- f_m = self.act(self.W_e(f_m))
384
-
385
- x_out = torch.bmm(A.transpose(1, 2), f_m)
386
- x_out = self.act(self.W_v(x_out))
387
-
388
- return x + x_out
389
-
390
-
391
- class AdaptiveHypergraphComputation(nn.Module):
392
- def __init__(self, in_channels, out_channels, num_hyperedges=8, num_heads=8):
393
- super().__init__()
394
- self.adaptive_hyperedge_gen = AdaptiveHyperedgeGeneration(
395
- in_channels, num_hyperedges, num_heads
396
- )
397
- self.hypergraph_conv = HypergraphConvolution(in_channels, out_channels)
398
-
399
- def forward(self, x):
400
- B, C, H, W = x.shape
401
- x_flat = x.flatten(2).permute(0, 2, 1)
402
-
403
- A = self.adaptive_hyperedge_gen(x_flat)
404
-
405
- x_out_flat = self.hypergraph_conv(x_flat, A)
406
-
407
- x_out = x_out_flat.permute(0, 2, 1).view(B, -1, H, W)
408
- return x_out
409
-
410
-
411
- class C3AH(nn.Module):
412
- def __init__(self, c1, c2, num_hyperedges=8, num_heads=8, e=0.5):
413
- super().__init__()
414
- c_ = int(c1 * e)
415
- self.cv1 = Conv(c1, c_, 1, 1)
416
- self.cv2 = Conv(c1, c_, 1, 1)
417
- self.ahc = AdaptiveHypergraphComputation(c_, c_, num_hyperedges, num_heads)
418
- self.cv3 = Conv(2 * c_, c2, 1, 1)
419
-
420
- def forward(self, x):
421
- x_lateral = self.cv1(x)
422
- x_ahc = self.ahc(self.cv2(x))
423
- return self.cv3(torch.cat((x_ahc, x_lateral), dim=1))
424
-
425
-
426
- class HyperACE(nn.Module):
427
- def __init__(
428
- self,
429
- in_channels: List[int],
430
- out_channels: int,
431
- num_hyperedges=8,
432
- num_heads=8,
433
- k=2,
434
- l=1,
435
- c_h=0.5,
436
- c_l=0.25,
437
- ):
438
- super().__init__()
439
-
440
- c2, c3, c4, c5 = in_channels
441
- c_mid = c4
442
-
443
- self.fuse_conv = Conv(c2 + c3 + c4 + c5, c_mid, 1, 1)
444
-
445
- self.c_h = int(c_mid * c_h)
446
- self.c_l = int(c_mid * c_l)
447
- self.c_s = c_mid - self.c_h - self.c_l
448
- assert self.c_s > 0, "Channel split error"
449
-
450
- self.high_order_branch = nn.ModuleList(
451
- [
452
- C3AH(self.c_h, self.c_h, num_hyperedges, num_heads, e=1.0)
453
- for _ in range(k)
454
- ]
455
- )
456
- self.high_order_fuse = Conv(self.c_h * k, self.c_h, 1, 1)
457
-
458
- self.low_order_branch = nn.Sequential(
459
- *[DS_C3k(self.c_l, self.c_l, n=1, k=3, e=1.0) for _ in range(l)]
460
- )
461
-
462
- self.final_fuse = Conv(self.c_h + self.c_l + self.c_s, out_channels, 1, 1)
463
-
464
- def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
465
- B2, B3, B4, B5 = x
466
-
467
- B, _, H4, W4 = B4.shape
468
-
469
- B2_resized = F.interpolate(
470
- B2, size=(H4, W4), mode="bilinear", align_corners=False
471
- )
472
- B3_resized = F.interpolate(
473
- B3, size=(H4, W4), mode="bilinear", align_corners=False
474
- )
475
- B5_resized = F.interpolate(
476
- B5, size=(H4, W4), mode="bilinear", align_corners=False
477
- )
478
-
479
- x_b = self.fuse_conv(torch.cat((B2_resized, B3_resized, B4, B5_resized), dim=1))
480
-
481
- x_h, x_l, x_s = torch.split(x_b, [self.c_h, self.c_l, self.c_s], dim=1)
482
-
483
- x_h_outs = [m(x_h) for m in self.high_order_branch]
484
- x_h_fused = self.high_order_fuse(torch.cat(x_h_outs, dim=1))
485
-
486
- x_l_out = self.low_order_branch(x_l)
487
-
488
- y = self.final_fuse(torch.cat((x_h_fused, x_l_out, x_s), dim=1))
489
-
490
- return y
491
-
492
-
493
- class GatedFusion(nn.Module):
494
- def __init__(self, in_channels):
495
- super().__init__()
496
- self.gamma = nn.Parameter(torch.zeros(1, in_channels, 1, 1))
497
-
498
- def forward(self, f_in, h):
499
- if f_in.shape[1] != h.shape[1]:
500
- raise ValueError(f"Channel mismatch: f_in={f_in.shape}, h={h.shape}")
501
- return f_in + self.gamma * h
502
-
503
-
504
- class Backbone(nn.Module):
505
- def __init__(self, in_channels=256, base_channels=64, base_depth=3):
506
- super().__init__()
507
- c = base_channels
508
- c2 = base_channels
509
- c3 = 256
510
- c4 = 384
511
- c5 = 512
512
- c6 = 768
513
-
514
- self.stem = DSConv(in_channels, c2, k=3, s=(2, 1), p=1)
515
-
516
- self.p2 = nn.Sequential(
517
- DSConv(c2, c3, k=3, s=(2, 1), p=1), DS_C3k2(c3, c3, n=base_depth)
518
- )
519
-
520
- self.p3 = nn.Sequential(
521
- DSConv(c3, c4, k=3, s=(2, 1), p=1), DS_C3k2(c4, c4, n=base_depth * 2)
522
- )
523
-
524
- self.p4 = nn.Sequential(
525
- DSConv(c4, c5, k=3, s=(2, 1), p=1), DS_C3k2(c5, c5, n=base_depth * 2)
526
- )
527
-
528
- self.p5 = nn.Sequential(
529
- DSConv(c5, c6, k=3, s=(2, 1), p=1), DS_C3k2(c6, c6, n=base_depth)
530
- )
531
-
532
- self.out_channels = [c3, c4, c5, c6]
533
-
534
- def forward(self, x):
535
- x = self.stem(x)
536
- x2 = self.p2(x)
537
- x3 = self.p3(x2)
538
- x4 = self.p4(x3)
539
- x5 = self.p5(x4)
540
- return [x2, x3, x4, x5]
541
-
542
-
543
- class Decoder(nn.Module):
544
- def __init__(
545
- self,
546
- encoder_channels: List[int],
547
- hyperace_out_c: int,
548
- decoder_channels: List[int],
549
- ):
550
- super().__init__()
551
- c_p2, c_p3, c_p4, c_p5 = encoder_channels
552
- c_d2, c_d3, c_d4, c_d5 = decoder_channels
553
-
554
- self.h_to_d5 = Conv(hyperace_out_c, c_d5, 1, 1)
555
- self.h_to_d4 = Conv(hyperace_out_c, c_d4, 1, 1)
556
- self.h_to_d3 = Conv(hyperace_out_c, c_d3, 1, 1)
557
- self.h_to_d2 = Conv(hyperace_out_c, c_d2, 1, 1)
558
-
559
- self.fusion_d5 = GatedFusion(c_d5)
560
- self.fusion_d4 = GatedFusion(c_d4)
561
- self.fusion_d3 = GatedFusion(c_d3)
562
- self.fusion_d2 = GatedFusion(c_d2)
563
-
564
- self.skip_p5 = Conv(c_p5, c_d5, 1, 1)
565
- self.skip_p4 = Conv(c_p4, c_d4, 1, 1)
566
- self.skip_p3 = Conv(c_p3, c_d3, 1, 1)
567
- self.skip_p2 = Conv(c_p2, c_d2, 1, 1)
568
-
569
- self.up_d5 = DS_C3k2(c_d5, c_d4, n=1)
570
- self.up_d4 = DS_C3k2(c_d4, c_d3, n=1)
571
- self.up_d3 = DS_C3k2(c_d3, c_d2, n=1)
572
-
573
- self.final_d2 = DS_C3k2(c_d2, c_d2, n=1)
574
-
575
- def forward(self, enc_feats: List[torch.Tensor], h_ace: torch.Tensor):
576
- p2, p3, p4, p5 = enc_feats
577
-
578
- d5 = self.skip_p5(p5)
579
- h_d5 = self.h_to_d5(F.interpolate(h_ace, size=d5.shape[2:], mode="bilinear"))
580
- d5 = self.fusion_d5(d5, h_d5)
581
-
582
- d5_up = F.interpolate(d5, size=p4.shape[2:], mode="bilinear")
583
- d4_skip = self.skip_p4(p4)
584
- d4 = self.up_d5(d5_up) + d4_skip
585
-
586
- h_d4 = self.h_to_d4(F.interpolate(h_ace, size=d4.shape[2:], mode="bilinear"))
587
- d4 = self.fusion_d4(d4, h_d4)
588
-
589
- d4_up = F.interpolate(d4, size=p3.shape[2:], mode="bilinear")
590
- d3_skip = self.skip_p3(p3)
591
- d3 = self.up_d4(d4_up) + d3_skip
592
-
593
- h_d3 = self.h_to_d3(F.interpolate(h_ace, size=d3.shape[2:], mode="bilinear"))
594
- d3 = self.fusion_d3(d3, h_d3)
595
-
596
- d3_up = F.interpolate(d3, size=p2.shape[2:], mode="bilinear")
597
- d2_skip = self.skip_p2(p2)
598
- d2 = self.up_d3(d3_up) + d2_skip
599
-
600
- h_d2 = self.h_to_d2(F.interpolate(h_ace, size=d2.shape[2:], mode="bilinear"))
601
- d2 = self.fusion_d2(d2, h_d2)
602
-
603
- d2_final = self.final_d2(d2)
604
-
605
- return d2_final
606
-
607
-
608
- class FreqPixelShuffle(nn.Module):
609
- def __init__(self, in_channels, out_channels, scale=2):
610
- super().__init__()
611
- self.scale = scale
612
- self.conv = DSConv(in_channels, out_channels * scale, k=3, s=1, p=1)
613
- self.act = nn.SiLU()
614
-
615
- def forward(self, x):
616
- x = self.conv(x)
617
- B, C_r, H, W = x.shape
618
- out_c = C_r // self.scale
619
-
620
- x = x.view(B, out_c, self.scale, H, W)
621
-
622
- x = x.permute(0, 1, 3, 4, 2).contiguous()
623
- x = x.view(B, out_c, H, W * self.scale)
624
-
625
- return x
626
-
627
-
628
- class ProgressiveUpsampleHead(nn.Module):
629
- def __init__(self, in_channels, out_channels, target_bins=1025):
630
- super().__init__()
631
- self.target_bins = target_bins
632
-
633
- c = in_channels
634
-
635
- self.block1 = FreqPixelShuffle(c, c, scale=2)
636
- self.block2 = FreqPixelShuffle(c, c // 2, scale=2)
637
- self.block3 = FreqPixelShuffle(c // 2, c // 2, scale=2)
638
- self.block4 = FreqPixelShuffle(c // 2, c // 4, scale=2)
639
-
640
- self.final_conv = nn.Conv2d(c // 4, out_channels, kernel_size=1, bias=False)
641
-
642
- def forward(self, x):
643
-
644
- x = self.block1(x)
645
- x = self.block2(x)
646
- x = self.block3(x)
647
- x = self.block4(x)
648
-
649
- if x.shape[-1] != self.target_bins:
650
- x = F.interpolate(
651
- x,
652
- size=(x.shape[2], self.target_bins),
653
- mode="bilinear",
654
- align_corners=False,
655
- )
656
-
657
- x = self.final_conv(x)
658
- return x
659
-
660
-
661
- class SegmModel(nn.Module):
662
- def __init__(
663
- self,
664
- in_bands=62,
665
- in_dim=256,
666
- out_bins=1025,
667
- out_channels=4,
668
- base_channels=64,
669
- base_depth=2,
670
- num_hyperedges=16,
671
- num_heads=8,
672
- ):
673
- super().__init__()
674
-
675
- self.backbone = Backbone(
676
- in_channels=in_dim, base_channels=base_channels, base_depth=base_depth
677
- )
678
- enc_channels = self.backbone.out_channels
679
- c2, c3, c4, c5 = enc_channels
680
-
681
- hyperace_in_channels = enc_channels
682
- hyperace_out_channels = c4
683
- self.hyperace = HyperACE(
684
- hyperace_in_channels,
685
- hyperace_out_channels,
686
- num_hyperedges,
687
- num_heads,
688
- k=3,
689
- l=2,
690
- )
691
-
692
- decoder_channels = [c2, c3, c4, c5]
693
- self.decoder = Decoder(enc_channels, hyperace_out_channels, decoder_channels)
694
-
695
- self.upsample_head = ProgressiveUpsampleHead(
696
- in_channels=decoder_channels[0],
697
- out_channels=out_channels,
698
- target_bins=out_bins,
699
- )
700
-
701
- def forward(self, x):
702
- H, W = x.shape[2:]
703
-
704
- enc_feats = self.backbone(x)
705
-
706
- h_ace_feats = self.hyperace(enc_feats)
707
-
708
- dec_feat = self.decoder(enc_feats, h_ace_feats)
709
-
710
- feat_time_restored = F.interpolate(
711
- dec_feat, size=(H, dec_feat.shape[-1]), mode="bilinear", align_corners=False
712
- )
713
-
714
- out = self.upsample_head(feat_time_restored)
715
-
716
- return out
717
-
718
-
719
- def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh):
720
- dim_hidden = default(dim_hidden, dim_in)
721
-
722
- net = []
723
- dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)
724
-
725
- for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
726
- is_last = ind == (len(dims) - 2)
727
-
728
- net.append(nn.Linear(layer_dim_in, layer_dim_out))
729
-
730
- if is_last:
731
- continue
732
-
733
- net.append(activation())
734
-
735
- return nn.Sequential(*net)
736
-
737
-
738
- class MaskEstimator(Module):
739
- @beartype
740
- def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4):
741
- super().__init__()
742
- self.dim_inputs = dim_inputs
743
- self.to_freqs = ModuleList([])
744
- dim_hidden = dim * mlp_expansion_factor
745
-
746
- for dim_in in dim_inputs:
747
- net = []
748
-
749
- mlp = nn.Sequential(
750
- MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), nn.GLU(dim=-1)
751
- )
752
-
753
- self.to_freqs.append(mlp)
754
-
755
- self.segm = SegmModel(
756
- in_bands=len(dim_inputs), in_dim=dim, out_bins=sum(dim_inputs) // 4
757
- )
758
-
759
- def forward(self, x):
760
- y = rearrange(x, "b t f c -> b c t f")
761
- y = self.segm(y)
762
- y = rearrange(y, "b c t f -> b t (f c)")
763
-
764
- x = x.unbind(dim=-2)
765
-
766
- outs = []
767
-
768
- for band_features, mlp in zip(x, self.to_freqs):
769
- freq_out = mlp(band_features)
770
- outs.append(freq_out)
771
-
772
- return torch.cat(outs, dim=-1) + y
773
-
774
-
775
- DEFAULT_FREQS_PER_BANDS = (
776
- 2,
777
- 2,
778
- 2,
779
- 2,
780
- 2,
781
- 2,
782
- 2,
783
- 2,
784
- 2,
785
- 2,
786
- 2,
787
- 2,
788
- 2,
789
- 2,
790
- 2,
791
- 2,
792
- 2,
793
- 2,
794
- 2,
795
- 2,
796
- 2,
797
- 2,
798
- 2,
799
- 2,
800
- 4,
801
- 4,
802
- 4,
803
- 4,
804
- 4,
805
- 4,
806
- 4,
807
- 4,
808
- 4,
809
- 4,
810
- 4,
811
- 4,
812
- 12,
813
- 12,
814
- 12,
815
- 12,
816
- 12,
817
- 12,
818
- 12,
819
- 12,
820
- 24,
821
- 24,
822
- 24,
823
- 24,
824
- 24,
825
- 24,
826
- 24,
827
- 24,
828
- 48,
829
- 48,
830
- 48,
831
- 48,
832
- 48,
833
- 48,
834
- 48,
835
- 48,
836
- 128,
837
- 129,
838
- )
839
-
840
-
841
- class BSRoformerHyperACE(Module):
842
-
843
- @beartype
844
- def __init__(
845
- self,
846
- dim,
847
- *,
848
- depth,
849
- stereo=False,
850
- num_stems=1,
851
- time_transformer_depth=2,
852
- freq_transformer_depth=2,
853
- linear_transformer_depth=0,
854
- freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
855
- dim_head=64,
856
- heads=8,
857
- attn_dropout=0.0,
858
- ff_dropout=0.0,
859
- flash_attn=True,
860
- dim_freqs_in=1025,
861
- stft_n_fft=2048,
862
- stft_hop_length=512,
863
- stft_win_length=2048,
864
- stft_normalized=False,
865
- stft_window_fn: Optional[Callable] = None,
866
- mask_estimator_depth=2,
867
- multi_stft_resolution_loss_weight=1.0,
868
- multi_stft_resolutions_window_sizes: Tuple[int, ...] = (
869
- 4096,
870
- 2048,
871
- 1024,
872
- 512,
873
- 256,
874
- ),
875
- multi_stft_hop_size=147,
876
- multi_stft_normalized=False,
877
- multi_stft_window_fn: Callable = torch.hann_window,
878
- mlp_expansion_factor=4,
879
- use_torch_checkpoint=False,
880
- skip_connection=False,
881
- sage_attention=False,
882
- ):
883
- super().__init__()
884
-
885
- self.stereo = stereo
886
- self.audio_channels = 2 if stereo else 1
887
- self.num_stems = num_stems
888
- self.use_torch_checkpoint = use_torch_checkpoint
889
- self.skip_connection = skip_connection
890
-
891
- self.layers = ModuleList([])
892
-
893
- if sage_attention:
894
- print("Use Sage Attention")
895
-
896
- transformer_kwargs = dict(
897
- dim=dim,
898
- heads=heads,
899
- dim_head=dim_head,
900
- attn_dropout=attn_dropout,
901
- ff_dropout=ff_dropout,
902
- flash_attn=flash_attn,
903
- norm_output=False,
904
- sage_attention=sage_attention,
905
- )
906
-
907
- time_rotary_embed = RotaryEmbedding(dim=dim_head)
908
- freq_rotary_embed = RotaryEmbedding(dim=dim_head)
909
-
910
- for _ in range(depth):
911
- tran_modules = []
912
- tran_modules.append(
913
- Transformer(
914
- depth=time_transformer_depth,
915
- rotary_embed=time_rotary_embed,
916
- **transformer_kwargs,
917
- )
918
- )
919
- tran_modules.append(
920
- Transformer(
921
- depth=freq_transformer_depth,
922
- rotary_embed=freq_rotary_embed,
923
- **transformer_kwargs,
924
- )
925
- )
926
- self.layers.append(nn.ModuleList(tran_modules))
927
-
928
- self.final_norm = RMSNorm(dim)
929
-
930
- self.stft_kwargs = dict(
931
- n_fft=stft_n_fft,
932
- hop_length=stft_hop_length,
933
- win_length=stft_win_length,
934
- normalized=stft_normalized,
935
- )
936
-
937
- self.stft_window_fn = partial(
938
- default(stft_window_fn, torch.hann_window), stft_win_length
939
- )
940
-
941
- freqs = torch.stft(
942
- torch.randn(1, 4096),
943
- **self.stft_kwargs,
944
- window=torch.ones(stft_win_length),
945
- return_complex=True,
946
- ).shape[1]
947
-
948
- assert len(freqs_per_bands) > 1
949
- assert (
950
- sum(freqs_per_bands) == freqs
951
- ), f"the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}"
952
-
953
- freqs_per_bands_with_complex = tuple(
954
- 2 * f * self.audio_channels for f in freqs_per_bands
955
- )
956
-
957
- self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)
958
-
959
- self.mask_estimators = nn.ModuleList([])
960
-
961
- for _ in range(num_stems):
962
- mask_estimator = MaskEstimator(
963
- dim=dim,
964
- dim_inputs=freqs_per_bands_with_complex,
965
- depth=mask_estimator_depth,
966
- mlp_expansion_factor=mlp_expansion_factor,
967
- )
968
-
969
- self.mask_estimators.append(mask_estimator)
970
-
971
- self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
972
- self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
973
- self.multi_stft_n_fft = stft_n_fft
974
- self.multi_stft_window_fn = multi_stft_window_fn
975
-
976
- self.multi_stft_kwargs = dict(
977
- hop_length=multi_stft_hop_size, normalized=multi_stft_normalized
978
- )
979
-
980
- def forward(self, raw_audio, target=None, return_loss_breakdown=False):
981
-
982
- device = raw_audio.device
983
-
984
- x_is_mps = True if device.type == "mps" else False
985
-
986
- if raw_audio.ndim == 2:
987
- raw_audio = rearrange(raw_audio, "b t -> b 1 t")
988
-
989
- channels = raw_audio.shape[1]
990
- assert (not self.stereo and channels == 1) or (
991
- self.stereo and channels == 2
992
- ), "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)"
993
-
994
- raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")
995
-
996
- stft_window = self.stft_window_fn(device=device)
997
-
998
- try:
999
- stft_repr = torch.stft(
1000
- raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True
1001
- )
1002
- except:
1003
- stft_repr = torch.stft(
1004
- raw_audio.cpu() if x_is_mps else raw_audio,
1005
- **self.stft_kwargs,
1006
- window=stft_window.cpu() if x_is_mps else stft_window,
1007
- return_complex=True,
1008
- ).to(device)
1009
- stft_repr = torch.view_as_real(stft_repr)
1010
-
1011
- stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")
1012
-
1013
- stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c")
1014
-
1015
- x = rearrange(stft_repr, "b f t c -> b t (f c)")
1016
-
1017
- x = self.band_split(x)
1018
-
1019
- for i, transformer_block in enumerate(self.layers):
1020
-
1021
- time_transformer, freq_transformer = transformer_block
1022
-
1023
- x = rearrange(x, "b t f d -> b f t d")
1024
- x, ps = pack([x], "* t d")
1025
-
1026
- x = time_transformer(x)
1027
-
1028
- (x,) = unpack(x, ps, "* t d")
1029
- x = rearrange(x, "b f t d -> b t f d")
1030
- x, ps = pack([x], "* f d")
1031
-
1032
- x = freq_transformer(x)
1033
-
1034
- (x,) = unpack(x, ps, "* f d")
1035
-
1036
- x = self.final_norm(x)
1037
-
1038
- num_stems = len(self.mask_estimators)
1039
-
1040
- mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
1041
- mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2)
1042
-
1043
- stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")
1044
-
1045
- stft_repr = torch.view_as_complex(stft_repr)
1046
- mask = torch.view_as_complex(mask)
1047
-
1048
- stft_repr = stft_repr * mask
1049
-
1050
- stft_repr = rearrange(
1051
- stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels
1052
- )
1053
-
1054
- try:
1055
- recon_audio = torch.istft(
1056
- stft_repr,
1057
- **self.stft_kwargs,
1058
- window=stft_window,
1059
- return_complex=False,
1060
- length=raw_audio.shape[-1],
1061
- )
1062
- except:
1063
- recon_audio = torch.istft(
1064
- stft_repr.cpu() if x_is_mps else stft_repr,
1065
- **self.stft_kwargs,
1066
- window=stft_window.cpu() if x_is_mps else stft_window,
1067
- return_complex=False,
1068
- length=raw_audio.shape[-1],
1069
- ).to(device)
1070
-
1071
- recon_audio = rearrange(
1072
- recon_audio, "(b n s) t -> b n s t", s=self.audio_channels, n=num_stems
1073
- )
1074
-
1075
- if num_stems == 1:
1076
- recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")
1077
-
1078
- if not exists(target):
1079
- return recon_audio
1080
-
1081
- if self.num_stems > 1:
1082
- assert target.ndim == 4 and target.shape[1] == self.num_stems
1083
-
1084
- if target.ndim == 2:
1085
- target = rearrange(target, "... t -> ... 1 t")
1086
-
1087
- target = target[..., : recon_audio.shape[-1]]
1088
-
1089
- loss = F.l1_loss(recon_audio, target)
1090
-
1091
- multi_stft_resolution_loss = 0.0
1092
-
1093
- for window_size in self.multi_stft_resolutions_window_sizes:
1094
- res_stft_kwargs = dict(
1095
- n_fft=max(window_size, self.multi_stft_n_fft),
1096
- win_length=window_size,
1097
- return_complex=True,
1098
- window=self.multi_stft_window_fn(window_size, device=device),
1099
- **self.multi_stft_kwargs,
1100
- )
1101
-
1102
- recon_Y = torch.stft(
1103
- rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs
1104
- )
1105
- target_Y = torch.stft(
1106
- rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs
1107
- )
1108
-
1109
- multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(
1110
- recon_Y, target_Y
1111
- )
1112
-
1113
- weighted_multi_resolution_loss = (
1114
- multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
1115
- )
1116
-
1117
- total_loss = loss + weighted_multi_resolution_loss
1118
-
1119
- if not return_loss_breakdown:
1120
- return total_loss
1121
-
1122
- return total_loss, (loss, multi_stft_resolution_loss)
 
1
+ from functools import partial
2
+
3
+ import torch
4
+ from torch import nn, einsum, Tensor
5
+ from torch.nn import Module, ModuleList
6
+ import torch.nn.functional as F
7
+
8
+ from .attend import Attend
9
+
10
+ from torch.utils.checkpoint import checkpoint
11
+
12
+ from beartype.typing import Tuple, Optional, List, Callable
13
+ from beartype import beartype
14
+
15
+ from rotary_embedding_torch import RotaryEmbedding
16
+
17
+ from einops import rearrange, pack, unpack
18
+ from einops.layers.torch import Rearrange
19
+ import torchaudio
20
+
21
+
22
def exists(val):
    """Return True when `val` is anything other than None."""
    if val is None:
        return False
    return True
24
+
25
+
26
def default(v, d):
    """Return `v` unless it is None, in which case return the fallback `d`."""
    if v is None:
        return d
    return v
28
+
29
+
30
def pack_one(t, pattern):
    """Pack the single tensor `t` with einops `pack`.

    Returns the `(packed_tensor, packed_shapes)` pair produced by `pack`,
    so the result can later be undone with `unpack_one`.
    """
    packed, packed_shapes = pack([t], pattern)
    return packed, packed_shapes
32
+
33
+
34
def unpack_one(t, ps, pattern):
    """Undo an einops `pack` and return only the first unpacked tensor."""
    first, *_ = unpack(t, ps, pattern)
    return first
36
+
37
+
38
def l2norm(t):
    """Scale `t` to unit L2 norm along its last dimension."""
    unit = F.normalize(t, p=2, dim=-1)
    return unit
40
+
41
+
42
class RMSNorm(Module):
    """Normalize the last dimension to unit L2 norm, then rescale.

    The output is `normalize(x) * sqrt(dim) * gamma`, where `gamma` is a
    learnable per-feature gain initialized to ones.
    """

    def __init__(self, dim):
        super().__init__()
        self.scale = dim**0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        unit = F.normalize(x, dim=-1)
        return unit * self.scale * self.gamma
50
+
51
+
52
class FeedForward(Module):
    """Pre-normalized MLP: RMSNorm -> Linear -> GELU -> Dropout -> Linear -> Dropout.

    The hidden width is `int(dim * mult)`; input and output widths are `dim`.
    """

    def __init__(self, dim, mult=4, dropout=0.0):
        super().__init__()
        hidden = int(dim * mult)
        stages = [
            RMSNorm(dim),
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        out = self.net(x)
        return out
67
+
68
+
69
class Attention(Module):
    """Multi-head self-attention with pre-norm and per-head sigmoid gates.

    The input is RMS-normalized, projected to q/k/v, optionally rotated with
    the supplied rotary embedding, attended with `Attend`, gated per head and
    projected back to `dim`.
    """

    def __init__(
        self,
        dim,
        heads=8,
        dim_head=64,
        dropout=0.0,
        rotary_embed=None,
        flash=True,
    ):
        super().__init__()
        inner = heads * dim_head

        self.heads = heads
        self.scale = dim_head**-0.5
        self.rotary_embed = rotary_embed

        self.attend = Attend(flash=flash, dropout=dropout)

        self.norm = RMSNorm(dim)
        self.to_qkv = nn.Linear(dim, inner * 3, bias=False)

        # One gate logit per head, computed from the normalized input.
        self.to_gates = nn.Linear(dim, heads)

        out_proj = nn.Linear(inner, dim, bias=False)
        self.to_out = nn.Sequential(out_proj, nn.Dropout(dropout))

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads)

        rotary = self.rotary_embed
        if rotary is not None:
            q = rotary.rotate_queries_or_keys(q)
            k = rotary.rotate_queries_or_keys(k)

        attended = self.attend(q, k, v)

        gate = rearrange(self.to_gates(x), "b n h -> b h n 1").sigmoid()
        attended = attended * gate

        merged = rearrange(attended, "b h n d -> b n (h d)")
        return self.to_out(merged)
115
+
116
+
117
class LinearAttention(Module):
    """Attention variant whose q/k/v are laid out as "b h d n".

    q and k are L2-normalized along their last axis and q is multiplied by
    `exp(temperature)`, a learnable per-head scalar, before `Attend`.
    """

    @beartype
    def __init__(
        self,
        *,
        dim,
        dim_head=32,
        heads=8,
        scale=8,
        flash=True,
        dropout=0.0,
    ):
        super().__init__()
        inner = dim_head * heads
        self.norm = RMSNorm(dim)

        qkv_proj = nn.Linear(dim, inner * 3, bias=False)
        split_heads = Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads)
        self.to_qkv = nn.Sequential(qkv_proj, split_heads)

        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

        self.attend = Attend(scale=scale, dropout=dropout, flash=flash)

        merge_heads = Rearrange("b h d n -> b n (h d)")
        self.to_out = nn.Sequential(merge_heads, nn.Linear(inner, dim, bias=False))

    def forward(self, x):
        normed = self.norm(x)

        q, k, v = self.to_qkv(normed)

        q = l2norm(q)
        k = l2norm(k)
        q = q * self.temperature.exp()

        attended = self.attend(q, k, v)
        return self.to_out(attended)
158
+
159
+
160
class Transformer(Module):
    """Stack of `depth` residual (attention, feed-forward) layer pairs.

    Each layer uses `LinearAttention` when `linear_attn` is true and
    `Attention` (with the optional rotary embedding) otherwise. The stack
    output passes through an `RMSNorm` when `norm_output` is true.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head=64,
        heads=8,
        attn_dropout=0.0,
        ff_dropout=0.0,
        ff_mult=4,
        norm_output=True,
        rotary_embed=None,
        flash_attn=True,
        linear_attn=False,
    ):
        super().__init__()
        self.layers = ModuleList([])

        # Arguments common to both attention flavours.
        shared = dict(
            dim=dim,
            dim_head=dim_head,
            heads=heads,
            dropout=attn_dropout,
            flash=flash_attn,
        )

        for _ in range(depth):
            if linear_attn:
                attn = LinearAttention(**shared)
            else:
                attn = Attention(rotary_embed=rotary_embed, **shared)

            ff = FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
            self.layers.append(ModuleList([attn, ff]))

        self.norm = RMSNorm(dim) if norm_output else nn.Identity()

    def forward(self, x):
        for attn, ff in self.layers:
            residual = x
            x = attn(x) + residual
            residual = x
            x = ff(x) + residual

        return self.norm(x)
213
+
214
+
215
class BandSplit(Module):
    """Project each band of the last dimension to a common width `dim`.

    The last dimension of the input is split into chunks of sizes
    `dim_inputs`; each chunk goes through its own RMSNorm + Linear, and the
    results are stacked on a new second-to-last axis (one entry per band).
    """

    @beartype
    def __init__(self, dim, dim_inputs: Tuple[int, ...]):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_features = ModuleList([])

        for band_width in dim_inputs:
            projector = nn.Sequential(RMSNorm(band_width), nn.Linear(band_width, dim))
            self.to_features.append(projector)

    def forward(self, x):
        bands = x.split(self.dim_inputs, dim=-1)
        projected = [
            to_feature(band) for band, to_feature in zip(bands, self.to_features)
        ]
        return torch.stack(projected, dim=-2)
239
+
240
+
241
class Conv(nn.Module):
    """Bias-free Conv2d followed by affine InstanceNorm2d and optional SiLU.

    Padding comes from `autopad(k, p)`; `g` is the convolution group count.
    `act=False` replaces the SiLU with an identity.
    """

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
        super().__init__()
        padding = autopad(k, p)
        self.conv = nn.Conv2d(c1, c2, k, s, padding, groups=g, bias=False)
        # Named `bn`, but this is an instance norm rather than a batch norm.
        self.bn = nn.InstanceNorm2d(c2, affine=True, eps=1e-8)
        if act:
            self.act = nn.SiLU()
        else:
            self.act = nn.Identity()

    def forward(self, x):
        y = self.conv(x)
        y = self.bn(y)
        return self.act(y)
250
+
251
+
252
def autopad(k, p=None):
    """Return padding `p`, or half the kernel size when `p` is None.

    `k` may be an int or an iterable of ints; for an iterable the result is a
    list holding `size // 2` for each entry.
    """
    if p is not None:
        return p
    if isinstance(k, int):
        return k // 2
    return [size // 2 for size in k]
256
+
257
+
258
class DSConv(nn.Module):
    """Depthwise conv then 1x1 pointwise conv, affine InstanceNorm2d, optional SiLU.

    The depthwise stage keeps `c1` channels (`groups=c1`) and carries the
    kernel size, stride and padding; the pointwise stage maps `c1 -> c2`.
    """

    def __init__(self, c1, c2, k=3, s=1, p=None, act=True):
        super().__init__()
        padding = autopad(k, p)
        self.dwconv = nn.Conv2d(c1, c1, k, s, padding, groups=c1, bias=False)
        self.pwconv = nn.Conv2d(c1, c2, 1, 1, 0, bias=False)
        self.bn = nn.InstanceNorm2d(c2, affine=True, eps=1e-8)
        if act:
            self.act = nn.SiLU()
        else:
            self.act = nn.Identity()

    def forward(self, x):
        y = self.dwconv(x)
        y = self.pwconv(y)
        y = self.bn(y)
        return self.act(y)
268
+
269
+
270
class DS_Bottleneck(nn.Module):
    """Two stacked `DSConv` layers with an optional residual connection.

    The first layer always uses a 3x3 kernel; `k` only sets the kernel of the
    second layer. The residual is applied only when `shortcut` is true and
    the input and output channel counts match.
    """

    def __init__(self, c1, c2, k=3, shortcut=True):
        super().__init__()
        hidden = c1
        self.dsconv1 = DSConv(c1, hidden, k=3, s=1)
        self.dsconv2 = DSConv(hidden, c2, k=k, s=1)
        self.shortcut = shortcut and c1 == c2

    def forward(self, x):
        y = self.dsconv2(self.dsconv1(x))
        if self.shortcut:
            return x + y
        return y
284
+
285
+
286
class DS_C3k(nn.Module):
    """Two-branch block: 1x1 conv + `n` DS bottlenecks, concatenated with a 1x1 conv branch.

    Both branches have `int(c2 * e)` channels; the concatenation is fused back
    to `c2` channels by a final 1x1 conv.
    """

    def __init__(self, c1, c2, n=1, k=3, e=0.5):
        super().__init__()
        hidden = int(c2 * e)
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(c1, hidden, 1, 1)
        self.cv3 = Conv(2 * hidden, c2, 1, 1)

        bottlenecks = []
        for _ in range(n):
            bottlenecks.append(DS_Bottleneck(hidden, hidden, k=k, shortcut=True))
        self.m = nn.Sequential(*bottlenecks)

    def forward(self, x):
        deep = self.m(self.cv1(x))
        shallow = self.cv2(x)
        merged = torch.cat((deep, shallow), dim=1)
        return self.cv3(merged)
299
+
300
+
301
class DS_C3k2(nn.Module):
    """1x1 conv down to `int(c2 * e)` channels, a `DS_C3k` block, then 1x1 conv to `c2`."""

    def __init__(self, c1, c2, n=1, k=3, e=0.5):
        super().__init__()
        hidden = int(c2 * e)
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.m = DS_C3k(hidden, hidden, n=n, k=k, e=1.0)
        self.cv2 = Conv(hidden, c2, 1, 1)

    def forward(self, x):
        return self.cv2(self.m(self.cv1(x)))
313
+
314
+
315
class AdaptiveHyperedgeGeneration(nn.Module):
    """Produce a soft vertex-to-hyperedge assignment matrix from node features.

    Hyperedge prototypes are a learned global table plus an input-dependent
    offset computed from average- and max-pooled node features. Projected
    node features are compared with the prototypes per head, the similarities
    are averaged over heads, and a softmax is taken over the node axis.
    """

    def __init__(self, in_channels, num_hyperedges, num_heads=8):
        super().__init__()
        self.num_hyperedges = num_hyperedges
        self.num_heads = num_heads
        # The `.view` calls in forward need in_channels == num_heads * head_dim.
        self.head_dim = in_channels // num_heads

        self.global_proto = nn.Parameter(torch.randn(num_hyperedges, in_channels))

        self.context_mapper = nn.Linear(
            2 * in_channels, num_hyperedges * in_channels, bias=False
        )

        self.query_proj = nn.Linear(in_channels, in_channels, bias=False)

        self.scale = self.head_dim**-0.5

    def forward(self, x):
        """Map `x` of shape (B, N, C) to an assignment matrix of shape (B, num_hyperedges, N)."""
        B, N, C = x.shape
        edges, heads, head_dim = self.num_hyperedges, self.num_heads, self.head_dim

        # Global context: per-channel mean and max over the node axis.
        channels_first = x.permute(0, 2, 1)
        f_avg = F.adaptive_avg_pool1d(channels_first, 1).squeeze(-1)
        f_max = F.adaptive_max_pool1d(channels_first, 1).squeeze(-1)
        context = torch.cat((f_avg, f_max), dim=1)

        offsets = self.context_mapper(context).view(B, edges, C)
        prototypes = self.global_proto.unsqueeze(0) + offsets

        queries = self.query_proj(x)
        queries = queries.view(B, N, heads, head_dim).permute(0, 2, 1, 3)

        prototypes = prototypes.view(B, edges, heads, head_dim).permute(0, 2, 3, 1)

        similarity = (queries @ prototypes) * self.scale

        # Average over the head axis, then normalize each hyperedge over nodes.
        head_mean = similarity.mean(dim=1)
        return F.softmax(head_mean.permute(0, 2, 1), dim=-1)
357
+
358
+
359
class HypergraphConvolution(nn.Module):
    """Two-step message passing through hyperedges, with a residual connection.

    Node features are aggregated into hyperedges with `A`, transformed, and
    scattered back to the nodes with `A` transposed. The residual add in
    forward needs the `W_v` output to be broadcast-compatible with `x`.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.W_e = nn.Linear(in_channels, in_channels, bias=False)
        self.W_v = nn.Linear(in_channels, out_channels, bias=False)
        self.act = nn.SiLU()

    def forward(self, x, A):
        # Nodes -> hyperedges.
        edge_feats = self.act(self.W_e(torch.bmm(A, x)))

        # Hyperedges -> nodes.
        node_update = torch.bmm(A.transpose(1, 2), edge_feats)
        node_update = self.act(self.W_v(node_update))

        return x + node_update
374
+
375
+
376
class AdaptiveHypergraphComputation(nn.Module):
    """Apply adaptive hypergraph convolution to a (B, C, H, W) feature map.

    The two spatial axes are flattened into a node axis, an assignment matrix
    is generated from the nodes, the hypergraph convolution is applied, and
    the result is reshaped back to (B, -1, H, W).
    """

    def __init__(self, in_channels, out_channels, num_hyperedges=8, num_heads=8):
        super().__init__()
        self.adaptive_hyperedge_gen = AdaptiveHyperedgeGeneration(
            in_channels, num_hyperedges, num_heads
        )
        self.hypergraph_conv = HypergraphConvolution(in_channels, out_channels)

    def forward(self, x):
        B, _, H, W = x.shape
        nodes = x.flatten(2).permute(0, 2, 1)

        assignment = self.adaptive_hyperedge_gen(nodes)
        updated = self.hypergraph_conv(nodes, assignment)

        return updated.permute(0, 2, 1).view(B, -1, H, W)
394
+
395
+
396
class C3AH(nn.Module):
    """Two-branch block: a lateral 1x1 conv, and a 1x1 conv followed by hypergraph computation.

    Both branches have `int(c1 * e)` channels; their concatenation is fused to
    `c2` channels by a final 1x1 conv.
    """

    def __init__(self, c1, c2, num_hyperedges=8, num_heads=8, e=0.5):
        super().__init__()
        hidden = int(c1 * e)
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(c1, hidden, 1, 1)
        self.ahc = AdaptiveHypergraphComputation(
            hidden, hidden, num_hyperedges, num_heads
        )
        self.cv3 = Conv(2 * hidden, c2, 1, 1)

    def forward(self, x):
        lateral = self.cv1(x)
        hyper = self.ahc(self.cv2(x))
        merged = torch.cat((hyper, lateral), dim=1)
        return self.cv3(merged)
409
+
410
+
411
class HyperACE(nn.Module):
    """Fuse four feature maps, then process them in three channel groups.

    All inputs are resized to the third map's spatial size, concatenated and
    fused to `c_mid` channels (the third entry of `in_channels`). The fused
    map is split into a high-order group (`k` parallel `C3AH` blocks), a
    low-order group (`l` sequential `DS_C3k` blocks) and a pass-through
    group. The three results are concatenated and fused to `out_channels`.
    """

    def __init__(
        self,
        in_channels: List[int],
        out_channels: int,
        num_hyperedges=8,
        num_heads=8,
        k=2,
        l=1,
        c_h=0.5,
        c_l=0.25,
    ):
        super().__init__()

        c2, c3, c4, c5 = in_channels
        c_mid = c4

        self.fuse_conv = Conv(c2 + c3 + c4 + c5, c_mid, 1, 1)

        # Channel budget: high-order, low-order, and whatever is left over.
        self.c_h = int(c_mid * c_h)
        self.c_l = int(c_mid * c_l)
        self.c_s = c_mid - self.c_h - self.c_l
        assert self.c_s > 0, "Channel split error"

        high_blocks = []
        for _ in range(k):
            high_blocks.append(
                C3AH(self.c_h, self.c_h, num_hyperedges, num_heads, e=1.0)
            )
        self.high_order_branch = nn.ModuleList(high_blocks)
        self.high_order_fuse = Conv(self.c_h * k, self.c_h, 1, 1)

        low_blocks = []
        for _ in range(l):
            low_blocks.append(DS_C3k(self.c_l, self.c_l, n=1, k=3, e=1.0))
        self.low_order_branch = nn.Sequential(*low_blocks)

        self.final_fuse = Conv(self.c_h + self.c_l + self.c_s, out_channels, 1, 1)

    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
        B2, B3, B4, B5 = x

        _, _, H4, W4 = B4.shape

        def to_b4_grid(feat):
            # Bring a feature map onto B4's spatial grid.
            return F.interpolate(
                feat, size=(H4, W4), mode="bilinear", align_corners=False
            )

        stacked = torch.cat(
            (to_b4_grid(B2), to_b4_grid(B3), B4, to_b4_grid(B5)), dim=1
        )
        fused = self.fuse_conv(stacked)

        x_h, x_l, x_s = torch.split(fused, [self.c_h, self.c_l, self.c_s], dim=1)

        high_outs = []
        for block in self.high_order_branch:
            high_outs.append(block(x_h))
        high = self.high_order_fuse(torch.cat(high_outs, dim=1))

        low = self.low_order_branch(x_l)

        return self.final_fuse(torch.cat((high, low, x_s), dim=1))
476
+
477
+
478
class GatedFusion(nn.Module):
    """Add `h` to `f_in`, scaled by a learnable per-channel gate.

    The gate `gamma` starts at zero, so a freshly constructed module returns
    `f_in` unchanged.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, in_channels, 1, 1))

    def forward(self, f_in, h):
        """Return `f_in + gamma * h`; raise ValueError when channel counts differ."""
        if f_in.shape[1] == h.shape[1]:
            return f_in + self.gamma * h
        raise ValueError(f"Channel mismatch: f_in={f_in.shape}, h={h.shape}")
487
+
488
+
489
class Backbone(nn.Module):
    """Convolutional encoder: a stem plus four stages, returning the four stage outputs.

    Every down-sampling `DSConv` uses stride 2 on the first spatial axis and
    stride 1 on the second. `base_channels` sets only the stem width; the
    stage widths are fixed at 256, 384, 512 and 768.
    """

    def __init__(self, in_channels=256, base_channels=64, base_depth=3):
        super().__init__()
        stem_width = base_channels
        widths = (256, 384, 512, 768)
        w2, w3, w4, w5 = widths

        def stage(c_in, c_out, n):
            # Strided depthwise-separable conv followed by a DS_C3k2 block.
            return nn.Sequential(
                DSConv(c_in, c_out, k=3, s=(2, 1), p=1), DS_C3k2(c_out, c_out, n=n)
            )

        self.stem = DSConv(in_channels, stem_width, k=3, s=(2, 1), p=1)

        self.p2 = stage(stem_width, w2, base_depth)
        self.p3 = stage(w2, w3, base_depth * 2)
        self.p4 = stage(w3, w4, base_depth * 2)
        self.p5 = stage(w4, w5, base_depth)

        self.out_channels = list(widths)

    def forward(self, x):
        feats = []
        x = self.stem(x)
        for stage in (self.p2, self.p3, self.p4, self.p5):
            x = stage(x)
            feats.append(x)
        return feats
526
+
527
+
528
class Decoder(nn.Module):
    """Top-down decoder that merges encoder features with a shared context map.

    Starting from the deepest encoder feature, each level projects its skip
    feature with a 1x1 conv, adds the resized and refined output of the level
    above, and injects the context map `h_ace` through a `GatedFusion`.
    """

    def __init__(
        self,
        encoder_channels: List[int],
        hyperace_out_c: int,
        decoder_channels: List[int],
    ):
        super().__init__()
        c_p2, c_p3, c_p4, c_p5 = encoder_channels
        c_d2, c_d3, c_d4, c_d5 = decoder_channels

        # 1x1 projections of the context map to each level's width.
        self.h_to_d5 = Conv(hyperace_out_c, c_d5, 1, 1)
        self.h_to_d4 = Conv(hyperace_out_c, c_d4, 1, 1)
        self.h_to_d3 = Conv(hyperace_out_c, c_d3, 1, 1)
        self.h_to_d2 = Conv(hyperace_out_c, c_d2, 1, 1)

        self.fusion_d5 = GatedFusion(c_d5)
        self.fusion_d4 = GatedFusion(c_d4)
        self.fusion_d3 = GatedFusion(c_d3)
        self.fusion_d2 = GatedFusion(c_d2)

        # 1x1 projections of the encoder skip features.
        self.skip_p5 = Conv(c_p5, c_d5, 1, 1)
        self.skip_p4 = Conv(c_p4, c_d4, 1, 1)
        self.skip_p3 = Conv(c_p3, c_d3, 1, 1)
        self.skip_p2 = Conv(c_p2, c_d2, 1, 1)

        self.up_d5 = DS_C3k2(c_d5, c_d4, n=1)
        self.up_d4 = DS_C3k2(c_d4, c_d3, n=1)
        self.up_d3 = DS_C3k2(c_d3, c_d2, n=1)

        self.final_d2 = DS_C3k2(c_d2, c_d2, n=1)

    @staticmethod
    def _inject(feat, h_ace, project, fuse):
        """Resize `h_ace` to `feat`'s spatial size, project it and gate it into `feat`."""
        context = project(F.interpolate(h_ace, size=feat.shape[2:], mode="bilinear"))
        return fuse(feat, context)

    def forward(self, enc_feats: List[torch.Tensor], h_ace: torch.Tensor):
        p2, p3, p4, p5 = enc_feats

        # Deepest level: skip feature only.
        d5 = self.skip_p5(p5)
        d5 = self._inject(d5, h_ace, self.h_to_d5, self.fusion_d5)

        d5_up = F.interpolate(d5, size=p4.shape[2:], mode="bilinear")
        d4_skip = self.skip_p4(p4)
        d4 = self.up_d5(d5_up) + d4_skip
        d4 = self._inject(d4, h_ace, self.h_to_d4, self.fusion_d4)

        d4_up = F.interpolate(d4, size=p3.shape[2:], mode="bilinear")
        d3_skip = self.skip_p3(p3)
        d3 = self.up_d4(d4_up) + d3_skip
        d3 = self._inject(d3, h_ace, self.h_to_d3, self.fusion_d3)

        d3_up = F.interpolate(d3, size=p2.shape[2:], mode="bilinear")
        d2_skip = self.skip_p2(p2)
        d2 = self.up_d3(d3_up) + d2_skip
        d2 = self._inject(d2, h_ace, self.h_to_d2, self.fusion_d2)

        return self.final_d2(d2)
591
+
592
+
593
class FreqPixelShuffle(nn.Module):
    """Upsample the last axis by `scale` by moving channels into that axis.

    A `DSConv` produces `out_channels * scale` channels; each group of `scale`
    channels is then interleaved along the last axis, giving
    (B, out_channels, H, W * scale).
    """

    def __init__(self, in_channels, out_channels, scale=2):
        super().__init__()
        self.scale = scale
        self.conv = DSConv(in_channels, out_channels * scale, k=3, s=1, p=1)
        # Kept as an attribute; forward does not apply it.
        self.act = nn.SiLU()

    def forward(self, x):
        y = self.conv(x)
        B, expanded, H, W = y.shape
        r = self.scale
        channels = expanded // r

        # (B, C, r, H, W) -> (B, C, H, W, r) -> (B, C, H, W * r)
        y = y.view(B, channels, r, H, W)
        y = y.permute(0, 1, 3, 4, 2).contiguous()
        return y.view(B, channels, H, W * r)
611
+
612
+
613
class ProgressiveUpsampleHead(nn.Module):
    """Four x2 `FreqPixelShuffle` stages, a resize to `target_bins`, then a 1x1 conv.

    Channel width goes c -> c -> c//2 -> c//2 -> c//4 across the stages. If
    the last axis does not equal `target_bins` after the stages, it is
    bilinearly resized to that length before the final projection.
    """

    def __init__(self, in_channels, out_channels, target_bins=1025):
        super().__init__()
        self.target_bins = target_bins

        c = in_channels
        half, quarter = c // 2, c // 4

        self.block1 = FreqPixelShuffle(c, c, scale=2)
        self.block2 = FreqPixelShuffle(c, half, scale=2)
        self.block3 = FreqPixelShuffle(half, half, scale=2)
        self.block4 = FreqPixelShuffle(half, quarter, scale=2)

        self.final_conv = nn.Conv2d(quarter, out_channels, kernel_size=1, bias=False)

    def forward(self, x):
        for block in (self.block1, self.block2, self.block3, self.block4):
            x = block(x)

        if x.shape[-1] != self.target_bins:
            x = F.interpolate(
                x,
                size=(x.shape[2], self.target_bins),
                mode="bilinear",
                align_corners=False,
            )

        return self.final_conv(x)
644
+
645
+
646
class SegmModel(nn.Module):
    """Backbone -> HyperACE -> Decoder -> upsample head over a 4-D feature map.

    The decoder output is resized so its first spatial axis matches the
    input's, then the head expands the last axis to `out_bins` and projects
    to `out_channels`. `in_bands` is accepted but not used by this class.
    """

    def __init__(
        self,
        in_bands=62,
        in_dim=256,
        out_bins=1025,
        out_channels=4,
        base_channels=64,
        base_depth=2,
        num_hyperedges=16,
        num_heads=8,
    ):
        super().__init__()

        self.backbone = Backbone(
            in_channels=in_dim, base_channels=base_channels, base_depth=base_depth
        )
        enc_channels = self.backbone.out_channels
        c2, c3, c4, c5 = enc_channels

        # The context map uses the third encoder stage's width.
        hyperace_out_channels = c4
        self.hyperace = HyperACE(
            enc_channels,
            hyperace_out_channels,
            num_hyperedges,
            num_heads,
            k=3,
            l=2,
        )

        decoder_channels = [c2, c3, c4, c5]
        self.decoder = Decoder(enc_channels, hyperace_out_channels, decoder_channels)

        self.upsample_head = ProgressiveUpsampleHead(
            in_channels=decoder_channels[0],
            out_channels=out_channels,
            target_bins=out_bins,
        )

    def forward(self, x):
        H, _ = x.shape[2:]

        enc_feats = self.backbone(x)
        context = self.hyperace(enc_feats)
        decoded = self.decoder(enc_feats, context)

        # Undo the backbone's striding on the first spatial axis only.
        restored = F.interpolate(
            decoded, size=(H, decoded.shape[-1]), mode="bilinear", align_corners=False
        )

        return self.upsample_head(restored)
702
+
703
+
704
def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh):
    """Build an `nn.Sequential` of `depth` Linear layers with activations in between.

    Hidden layers have width `dim_hidden` (falling back to `dim_in` when it is
    None). No activation follows the last Linear layer.
    """
    if dim_hidden is None:
        dim_hidden = dim_in

    widths = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)
    last_index = len(widths) - 2

    layers = []
    for index, (width_in, width_out) in enumerate(zip(widths[:-1], widths[1:])):
        layers.append(nn.Linear(width_in, width_out))
        if index != last_index:
            layers.append(activation())

    return nn.Sequential(*layers)
721
+
722
+
723
class MaskEstimator(Module):
    """Estimate a mask as the sum of per-band MLP outputs and a `SegmModel` output.

    Each band has an MLP + GLU head that emits `dim_in` values for that band;
    the band outputs are concatenated along the last axis. A `SegmModel`
    over the whole band/time feature map produces a second estimate of the
    same width, which is added on.
    """

    @beartype
    def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_freqs = ModuleList([])
        hidden = dim * mlp_expansion_factor

        for band_width in dim_inputs:
            # The MLP emits 2 * band_width values; GLU halves that back to band_width.
            head = nn.Sequential(
                MLP(dim, band_width * 2, dim_hidden=hidden, depth=depth),
                nn.GLU(dim=-1),
            )
            self.to_freqs.append(head)

        # NOTE(review): with SegmModel's default out_channels=4, the add in
        # forward only matches widths when sum(dim_inputs) is divisible by 4.
        self.segm = SegmModel(
            in_bands=len(dim_inputs), in_dim=dim, out_bins=sum(dim_inputs) // 4
        )

    def forward(self, x):
        segm_in = rearrange(x, "b t f c -> b c t f")
        segm_out = self.segm(segm_in)
        segm_out = rearrange(segm_out, "b c t f -> b t (f c)")

        bands = x.unbind(dim=-2)
        band_outs = [head(band) for band, head in zip(bands, self.to_freqs)]

        return torch.cat(band_outs, dim=-1) + segm_out
758
+
759
+
760
# Default number of STFT frequency bins in each band, in band order.
# 62 bands whose widths sum to 1025 bins.
DEFAULT_FREQS_PER_BANDS = (
    (2,) * 24
    + (4,) * 12
    + (12,) * 8
    + (24,) * 8
    + (48,) * 8
    + (128, 129)
)
824
+
825
+
826
class BSRoformerHyperACE(Module):
    """Band-split transformer that masks an STFT to separate audio stems.

    The waveform is transformed with an STFT, its bins are grouped into bands
    and projected to `dim`, and `depth` pairs of (time, frequency)
    transformers are applied. One `MaskEstimator` per stem predicts a complex
    mask that multiplies the input STFT, and the product is inverted back to
    audio. When a `target` is passed, the forward pass returns a training loss
    instead of the audio.

    `linear_transformer_depth`, `dim_freqs_in` and any extra `**kwargs` are
    accepted but not used here. `use_torch_checkpoint` and `skip_connection`
    are stored on the module but not read by `forward`.
    """

    @beartype
    def __init__(
        self,
        dim,
        *,
        depth,
        stereo=False,
        num_stems=1,
        time_transformer_depth=2,
        freq_transformer_depth=2,
        linear_transformer_depth=0,
        freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
        dim_head=64,
        heads=8,
        attn_dropout=0.0,
        ff_dropout=0.0,
        flash_attn=True,
        dim_freqs_in=1025,
        stft_n_fft=2048,
        stft_hop_length=512,
        stft_win_length=2048,
        stft_normalized=False,
        stft_window_fn: Optional[Callable] = None,
        mask_estimator_depth=2,
        multi_stft_resolution_loss_weight=1.0,
        multi_stft_resolutions_window_sizes: Tuple[int, ...] = (
            4096,
            2048,
            1024,
            512,
            256,
        ),
        multi_stft_hop_size=147,
        multi_stft_normalized=False,
        multi_stft_window_fn: Callable = torch.hann_window,
        mlp_expansion_factor=4,
        use_torch_checkpoint=False,
        skip_connection=False,
        **kwargs
    ):
        super().__init__()

        self.stereo = stereo
        self.audio_channels = 2 if stereo else 1
        self.num_stems = num_stems
        self.use_torch_checkpoint = use_torch_checkpoint
        self.skip_connection = skip_connection

        self.layers = ModuleList([])

        transformer_kwargs = dict(
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout,
            flash_attn=flash_attn,
            norm_output=False,
        )

        # One rotary embedding per axis, shared by every layer on that axis.
        time_rotary_embed = RotaryEmbedding(dim=dim_head)
        freq_rotary_embed = RotaryEmbedding(dim=dim_head)

        for _ in range(depth):
            tran_modules = []
            tran_modules.append(
                Transformer(
                    depth=time_transformer_depth,
                    rotary_embed=time_rotary_embed,
                    **transformer_kwargs,
                )
            )
            tran_modules.append(
                Transformer(
                    depth=freq_transformer_depth,
                    rotary_embed=freq_rotary_embed,
                    **transformer_kwargs,
                )
            )
            self.layers.append(nn.ModuleList(tran_modules))

        self.final_norm = RMSNorm(dim)

        self.stft_kwargs = dict(
            n_fft=stft_n_fft,
            hop_length=stft_hop_length,
            win_length=stft_win_length,
            normalized=stft_normalized,
        )

        self.stft_window_fn = partial(
            default(stft_window_fn, torch.hann_window), stft_win_length
        )

        # Probe the STFT once to learn how many frequency bins it yields.
        freqs = torch.stft(
            torch.randn(1, 4096),
            **self.stft_kwargs,
            window=torch.ones(stft_win_length),
            return_complex=True,
        ).shape[1]

        assert len(freqs_per_bands) > 1
        assert (
            sum(freqs_per_bands) == freqs
        ), f"the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}"

        # Each bin contributes a real and imaginary part per audio channel.
        freqs_per_bands_with_complex = tuple(
            2 * f * self.audio_channels for f in freqs_per_bands
        )

        self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)

        self.mask_estimators = nn.ModuleList([])

        for _ in range(num_stems):
            mask_estimator = MaskEstimator(
                dim=dim,
                dim_inputs=freqs_per_bands_with_complex,
                depth=mask_estimator_depth,
                mlp_expansion_factor=mlp_expansion_factor,
            )

            self.mask_estimators.append(mask_estimator)

        self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
        self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
        self.multi_stft_n_fft = stft_n_fft
        self.multi_stft_window_fn = multi_stft_window_fn

        self.multi_stft_kwargs = dict(
            hop_length=multi_stft_hop_size, normalized=multi_stft_normalized
        )

    def _stft(self, audio, window, device, on_mps):
        """Complex STFT of `audio`, retried on CPU when the MPS device fails."""
        try:
            return torch.stft(
                audio, **self.stft_kwargs, window=window, return_complex=True
            )
        except Exception:
            # The CPU retry only helps on MPS; elsewhere let the error surface.
            if not on_mps:
                raise
            return torch.stft(
                audio.cpu(),
                **self.stft_kwargs,
                window=window.cpu(),
                return_complex=True,
            ).to(device)

    def _istft(self, stft_repr, window, length, device, on_mps):
        """Inverse STFT to `length` samples, retried on CPU when MPS fails."""
        try:
            return torch.istft(
                stft_repr,
                **self.stft_kwargs,
                window=window,
                return_complex=False,
                length=length,
            )
        except Exception:
            if not on_mps:
                raise
            return torch.istft(
                stft_repr.cpu(),
                **self.stft_kwargs,
                window=window.cpu(),
                return_complex=False,
                length=length,
            ).to(device)

    def forward(self, raw_audio, target=None, return_loss_breakdown=False):
        """Separate `raw_audio`, or compute the training loss when `target` is given.

        Args:
            raw_audio: waveform tensor, (b, t) or (b, channels, t). The
                channel count must be 2 when `stereo` is set and 1 otherwise.
            target: optional reference audio. When omitted, the separated
                audio is returned.
            return_loss_breakdown: with a target, also return the two loss
                terms.

        Returns:
            Without `target`: audio of shape (b, s, t) for a single stem, or
            (b, n, s, t) for several. With `target`: the total loss, or
            `(total_loss, (l1_loss, multi_stft_resolution_loss))` when
            `return_loss_breakdown` is true.
        """
        device = raw_audio.device

        x_is_mps = device.type == "mps"

        if raw_audio.ndim == 2:
            raw_audio = rearrange(raw_audio, "b t -> b 1 t")

        channels = raw_audio.shape[1]
        assert (not self.stereo and channels == 1) or (
            self.stereo and channels == 2
        ), "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)"

        # Fold batch and channel together so the STFT sees 2-D input.
        raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")

        stft_window = self.stft_window_fn(device=device)

        stft_repr = self._stft(raw_audio, stft_window, device, x_is_mps)
        stft_repr = torch.view_as_real(stft_repr)

        stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")

        # Interleave audio channels into the frequency axis.
        stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c")

        x = rearrange(stft_repr, "b f t c -> b t (f c)")

        x = self.band_split(x)

        for time_transformer, freq_transformer in self.layers:

            # Attend along time, with bands folded into the batch.
            x = rearrange(x, "b t f d -> b f t d")
            x, ps = pack([x], "* t d")

            x = time_transformer(x)

            (x,) = unpack(x, ps, "* t d")

            # Attend along bands, with time folded into the batch.
            x = rearrange(x, "b f t d -> b t f d")
            x, ps = pack([x], "* f d")

            x = freq_transformer(x)

            (x,) = unpack(x, ps, "* f d")

        x = self.final_norm(x)

        num_stems = len(self.mask_estimators)

        mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
        mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2)

        stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")

        stft_repr = torch.view_as_complex(stft_repr)
        mask = torch.view_as_complex(mask)

        # Complex multiplication applies the mask to every stem at once.
        stft_repr = stft_repr * mask

        stft_repr = rearrange(
            stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels
        )

        recon_audio = self._istft(
            stft_repr, stft_window, raw_audio.shape[-1], device, x_is_mps
        )

        recon_audio = rearrange(
            recon_audio, "(b n s) t -> b n s t", s=self.audio_channels, n=num_stems
        )

        if num_stems == 1:
            recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")

        if not exists(target):
            return recon_audio

        if self.num_stems > 1:
            assert target.ndim == 4 and target.shape[1] == self.num_stems

        if target.ndim == 2:
            target = rearrange(target, "... t -> ... 1 t")

        # Trim the target to the reconstructed length.
        target = target[..., : recon_audio.shape[-1]]

        loss = F.l1_loss(recon_audio, target)

        multi_stft_resolution_loss = 0.0

        for window_size in self.multi_stft_resolutions_window_sizes:
            res_stft_kwargs = dict(
                n_fft=max(window_size, self.multi_stft_n_fft),
                win_length=window_size,
                return_complex=True,
                window=self.multi_stft_window_fn(window_size, device=device),
                **self.multi_stft_kwargs,
            )

            recon_Y = torch.stft(
                rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs
            )
            target_Y = torch.stft(
                rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs
            )

            multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(
                recon_Y, target_Y
            )

        weighted_multi_resolution_loss = (
            multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
        )

        total_loss = loss + weighted_multi_resolution_loss

        if not return_loss_breakdown:
            return total_loss

        return total_loss, (loss, multi_stft_resolution_loss)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/bs_roformer/bs_roformer_hyperace2.py CHANGED
@@ -1,1166 +1,1147 @@
1
- from functools import partial
2
-
3
- import torch
4
- from torch import nn, einsum, Tensor
5
- from torch.nn import Module, ModuleList
6
- import torch.nn.functional as F
7
-
8
- from .attend import Attend
9
-
10
- try:
11
- from .attend_sage import Attend as AttendSage
12
- except:
13
- pass
14
- from torch.utils.checkpoint import checkpoint
15
-
16
- from beartype.typing import Tuple, Optional, List, Callable
17
- from beartype import beartype
18
-
19
- from rotary_embedding_torch import RotaryEmbedding
20
-
21
- from einops import rearrange, pack, unpack
22
- from einops.layers.torch import Rearrange
23
- import torchaudio
24
-
25
-
26
- def exists(val):
27
- return val is not None
28
-
29
-
30
- def default(v, d):
31
- return v if exists(v) else d
32
-
33
-
34
- def pack_one(t, pattern):
35
- return pack([t], pattern)
36
-
37
-
38
- def unpack_one(t, ps, pattern):
39
- return unpack(t, ps, pattern)[0]
40
-
41
-
42
- def l2norm(t):
43
- return F.normalize(t, dim=-1, p=2)
44
-
45
-
46
- class RMSNorm(Module):
47
- def __init__(self, dim):
48
- super().__init__()
49
- self.scale = dim**0.5
50
- self.gamma = nn.Parameter(torch.ones(dim))
51
-
52
- def forward(self, x):
53
- return F.normalize(x, dim=-1) * self.scale * self.gamma
54
-
55
-
56
- class FeedForward(Module):
57
- def __init__(self, dim, mult=4, dropout=0.0):
58
- super().__init__()
59
- dim_inner = int(dim * mult)
60
- self.net = nn.Sequential(
61
- RMSNorm(dim),
62
- nn.Linear(dim, dim_inner),
63
- nn.GELU(),
64
- nn.Dropout(dropout),
65
- nn.Linear(dim_inner, dim),
66
- nn.Dropout(dropout),
67
- )
68
-
69
- def forward(self, x):
70
- return self.net(x)
71
-
72
-
73
- class Attention(Module):
74
- def __init__(
75
- self,
76
- dim,
77
- heads=8,
78
- dim_head=64,
79
- dropout=0.0,
80
- rotary_embed=None,
81
- flash=True,
82
- sage_attention=False,
83
- ):
84
- super().__init__()
85
- self.heads = heads
86
- self.scale = dim_head**-0.5
87
- dim_inner = heads * dim_head
88
-
89
- self.rotary_embed = rotary_embed
90
-
91
- if sage_attention:
92
- self.attend = AttendSage(flash=flash, dropout=dropout)
93
- else:
94
- self.attend = Attend(flash=flash, dropout=dropout)
95
-
96
- self.norm = RMSNorm(dim)
97
- self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
98
-
99
- self.to_gates = nn.Linear(dim, heads)
100
-
101
- self.to_out = nn.Sequential(
102
- nn.Linear(dim_inner, dim, bias=False), nn.Dropout(dropout)
103
- )
104
-
105
- def forward(self, x):
106
- x = self.norm(x)
107
-
108
- q, k, v = rearrange(
109
- self.to_qkv(x), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads
110
- )
111
-
112
- if exists(self.rotary_embed):
113
- q = self.rotary_embed.rotate_queries_or_keys(q)
114
- k = self.rotary_embed.rotate_queries_or_keys(k)
115
-
116
- out = self.attend(q, k, v)
117
-
118
- gates = self.to_gates(x)
119
- out = out * rearrange(gates, "b n h -> b h n 1").sigmoid()
120
-
121
- out = rearrange(out, "b h n d -> b n (h d)")
122
- return self.to_out(out)
123
-
124
-
125
- class LinearAttention(Module):
126
-
127
- @beartype
128
- def __init__(
129
- self,
130
- *,
131
- dim,
132
- dim_head=32,
133
- heads=8,
134
- scale=8,
135
- flash=True,
136
- dropout=0.0,
137
- sage_attention=False,
138
- ):
139
- super().__init__()
140
- dim_inner = dim_head * heads
141
- self.norm = RMSNorm(dim)
142
-
143
- self.to_qkv = nn.Sequential(
144
- nn.Linear(dim, dim_inner * 3, bias=False),
145
- Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads),
146
- )
147
-
148
- self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
149
-
150
- if sage_attention:
151
- self.attend = AttendSage(scale=scale, dropout=dropout, flash=flash)
152
- else:
153
- self.attend = Attend(scale=scale, dropout=dropout, flash=flash)
154
-
155
- self.to_out = nn.Sequential(
156
- Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False)
157
- )
158
-
159
- def forward(self, x):
160
- x = self.norm(x)
161
-
162
- q, k, v = self.to_qkv(x)
163
-
164
- q, k = map(l2norm, (q, k))
165
- q = q * self.temperature.exp()
166
-
167
- out = self.attend(q, k, v)
168
-
169
- return self.to_out(out)
170
-
171
-
172
- class Transformer(Module):
173
- def __init__(
174
- self,
175
- *,
176
- dim,
177
- depth,
178
- dim_head=64,
179
- heads=8,
180
- attn_dropout=0.0,
181
- ff_dropout=0.0,
182
- ff_mult=4,
183
- norm_output=True,
184
- rotary_embed=None,
185
- flash_attn=True,
186
- linear_attn=False,
187
- sage_attention=False,
188
- ):
189
- super().__init__()
190
- self.layers = ModuleList([])
191
-
192
- for _ in range(depth):
193
- if linear_attn:
194
- attn = LinearAttention(
195
- dim=dim,
196
- dim_head=dim_head,
197
- heads=heads,
198
- dropout=attn_dropout,
199
- flash=flash_attn,
200
- sage_attention=sage_attention,
201
- )
202
- else:
203
- attn = Attention(
204
- dim=dim,
205
- dim_head=dim_head,
206
- heads=heads,
207
- dropout=attn_dropout,
208
- rotary_embed=rotary_embed,
209
- flash=flash_attn,
210
- sage_attention=sage_attention,
211
- )
212
-
213
- self.layers.append(
214
- ModuleList(
215
- [attn, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)]
216
- )
217
- )
218
-
219
- self.norm = RMSNorm(dim) if norm_output else nn.Identity()
220
-
221
- def forward(self, x):
222
-
223
- for attn, ff in self.layers:
224
- x = attn(x) + x
225
- x = ff(x) + x
226
-
227
- return self.norm(x)
228
-
229
-
230
- class BandSplit(Module):
231
- @beartype
232
- def __init__(self, dim, dim_inputs: Tuple[int, ...]):
233
- super().__init__()
234
- self.dim_inputs = dim_inputs
235
- self.to_features = ModuleList([])
236
-
237
- for dim_in in dim_inputs:
238
- net = nn.Sequential(RMSNorm(dim_in), nn.Linear(dim_in, dim))
239
-
240
- self.to_features.append(net)
241
-
242
- def forward(self, x):
243
-
244
- x = x.split(self.dim_inputs, dim=-1)
245
-
246
- outs = []
247
- for split_input, to_feature in zip(x, self.to_features):
248
- split_output = to_feature(split_input)
249
- outs.append(split_output)
250
-
251
- x = torch.stack(outs, dim=-2)
252
-
253
- return x
254
-
255
-
256
- class Conv(nn.Module):
257
- def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
258
- super().__init__()
259
- self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
260
- self.bn = nn.InstanceNorm2d(c2, affine=True, eps=1e-8)
261
- self.act = nn.SiLU() if act else nn.Identity()
262
-
263
- def forward(self, x):
264
- return self.act(self.bn(self.conv(x)))
265
-
266
-
267
- def autopad(k, p=None):
268
- if p is None:
269
- p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
270
- return p
271
-
272
-
273
- class DSConv(nn.Module):
274
- def __init__(self, c1, c2, k=3, s=1, p=None, act=True):
275
- super().__init__()
276
- self.dwconv = nn.Conv2d(c1, c1, k, s, autopad(k, p), groups=c1, bias=False)
277
- self.pwconv = nn.Conv2d(c1, c2, 1, 1, 0, bias=False)
278
- self.bn = nn.InstanceNorm2d(c2, affine=True, eps=1e-8)
279
- self.act = nn.SiLU() if act else nn.Identity()
280
-
281
- def forward(self, x):
282
- return self.act(self.bn(self.pwconv(self.dwconv(x))))
283
-
284
-
285
- class DS_Bottleneck(nn.Module):
286
- def __init__(self, c1, c2, k=3, shortcut=True):
287
- super().__init__()
288
- c_ = c1
289
- self.dsconv1 = DSConv(c1, c_, k=3, s=1)
290
- self.dsconv2 = DSConv(c_, c2, k=k, s=1)
291
- self.shortcut = shortcut and c1 == c2
292
-
293
- def forward(self, x):
294
- return (
295
- x + self.dsconv2(self.dsconv1(x))
296
- if self.shortcut
297
- else self.dsconv2(self.dsconv1(x))
298
- )
299
-
300
-
301
- class DS_C3k(nn.Module):
302
- def __init__(self, c1, c2, n=1, k=3, e=0.5):
303
- super().__init__()
304
- c_ = int(c2 * e)
305
- self.cv1 = Conv(c1, c_, 1, 1)
306
- self.cv2 = Conv(c1, c_, 1, 1)
307
- self.cv3 = Conv(2 * c_, c2, 1, 1)
308
- self.m = nn.Sequential(
309
- *[DS_Bottleneck(c_, c_, k=k, shortcut=True) for _ in range(n)]
310
- )
311
-
312
- def forward(self, x):
313
- return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
314
-
315
-
316
- class DS_C3k2(nn.Module):
317
- def __init__(self, c1, c2, n=1, k=3, e=0.5):
318
- super().__init__()
319
- c_ = int(c2 * e)
320
- self.cv1 = Conv(c1, c_, 1, 1)
321
- self.m = DS_C3k(c_, c_, n=n, k=k, e=1.0)
322
- self.cv2 = Conv(c_, c2, 1, 1)
323
-
324
- def forward(self, x):
325
- x_ = self.cv1(x)
326
- x_ = self.m(x_)
327
- return self.cv2(x_)
328
-
329
-
330
- class AdaptiveHyperedgeGeneration(nn.Module):
331
- def __init__(self, in_channels, num_hyperedges, num_heads=8):
332
- super().__init__()
333
- self.num_hyperedges = num_hyperedges
334
- self.num_heads = num_heads
335
- self.head_dim = in_channels // num_heads
336
-
337
- self.global_proto = nn.Parameter(torch.randn(num_hyperedges, in_channels))
338
-
339
- self.context_mapper = nn.Linear(
340
- 2 * in_channels, num_hyperedges * in_channels, bias=False
341
- )
342
-
343
- self.query_proj = nn.Linear(in_channels, in_channels, bias=False)
344
-
345
- self.scale = self.head_dim**-0.5
346
-
347
- def forward(self, x):
348
- B, N, C = x.shape
349
-
350
- f_avg = F.adaptive_avg_pool1d(x.permute(0, 2, 1), 1).squeeze(-1)
351
- f_max = F.adaptive_max_pool1d(x.permute(0, 2, 1), 1).squeeze(-1)
352
- f_ctx = torch.cat((f_avg, f_max), dim=1)
353
-
354
- delta_P = self.context_mapper(f_ctx).view(B, self.num_hyperedges, C)
355
- P = self.global_proto.unsqueeze(0) + delta_P
356
-
357
- z = self.query_proj(x)
358
-
359
- z = z.view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
360
-
361
- P = P.view(B, self.num_hyperedges, self.num_heads, self.head_dim).permute(
362
- 0, 2, 3, 1
363
- )
364
-
365
- sim = (z @ P) * self.scale
366
-
367
- s_bar = sim.mean(dim=1)
368
-
369
- A = F.softmax(s_bar.permute(0, 2, 1), dim=-1)
370
-
371
- return A
372
-
373
-
374
- class HypergraphConvolution(nn.Module):
375
- def __init__(self, in_channels, out_channels):
376
- super().__init__()
377
- self.W_e = nn.Linear(in_channels, in_channels, bias=False)
378
- self.W_v = nn.Linear(in_channels, out_channels, bias=False)
379
- self.act = nn.SiLU()
380
-
381
- def forward(self, x, A):
382
- f_m = torch.bmm(A, x)
383
- f_m = self.act(self.W_e(f_m))
384
-
385
- x_out = torch.bmm(A.transpose(1, 2), f_m)
386
- x_out = self.act(self.W_v(x_out))
387
-
388
- return x + x_out
389
-
390
-
391
- class AdaptiveHypergraphComputation(nn.Module):
392
- def __init__(self, in_channels, out_channels, num_hyperedges=8, num_heads=8):
393
- super().__init__()
394
- self.adaptive_hyperedge_gen = AdaptiveHyperedgeGeneration(
395
- in_channels, num_hyperedges, num_heads
396
- )
397
- self.hypergraph_conv = HypergraphConvolution(in_channels, out_channels)
398
-
399
- def forward(self, x):
400
- B, C, H, W = x.shape
401
- x_flat = x.flatten(2).permute(0, 2, 1)
402
-
403
- A = self.adaptive_hyperedge_gen(x_flat)
404
-
405
- x_out_flat = self.hypergraph_conv(x_flat, A)
406
-
407
- x_out = x_out_flat.permute(0, 2, 1).view(B, -1, H, W)
408
- return x_out
409
-
410
-
411
- class C3AH(nn.Module):
412
- def __init__(self, c1, c2, num_hyperedges=8, num_heads=8, e=0.5):
413
- super().__init__()
414
- c_ = int(c1 * e)
415
- self.cv1 = Conv(c1, c_, 1, 1)
416
- self.cv2 = Conv(c1, c_, 1, 1)
417
- self.ahc = AdaptiveHypergraphComputation(c_, c_, num_hyperedges, num_heads)
418
- self.cv3 = Conv(2 * c_, c2, 1, 1)
419
-
420
- def forward(self, x):
421
- x_lateral = self.cv1(x)
422
- x_ahc = self.ahc(self.cv2(x))
423
- return self.cv3(torch.cat((x_ahc, x_lateral), dim=1))
424
-
425
-
426
- class HyperACE(nn.Module):
427
- def __init__(
428
- self,
429
- in_channels: List[int],
430
- out_channels: int,
431
- num_hyperedges=8,
432
- num_heads=8,
433
- k=2,
434
- l=1,
435
- c_h=0.5,
436
- c_l=0.25,
437
- ):
438
- super().__init__()
439
-
440
- c2, c3, c4, c5 = in_channels
441
- c_mid = c4
442
-
443
- self.fuse_conv = Conv(c2 + c3 + c4 + c5, c_mid, 1, 1)
444
-
445
- self.c_h = int(c_mid * c_h)
446
- self.c_l = int(c_mid * c_l)
447
- self.c_s = c_mid - self.c_h - self.c_l
448
- assert self.c_s > 0, "Channel split error"
449
-
450
- self.high_order_branch = nn.ModuleList(
451
- [
452
- C3AH(self.c_h, self.c_h, num_hyperedges, num_heads, e=1.0)
453
- for _ in range(k)
454
- ]
455
- )
456
- self.high_order_fuse = Conv(self.c_h * k, self.c_h, 1, 1)
457
-
458
- self.low_order_branch = nn.Sequential(
459
- *[DS_C3k(self.c_l, self.c_l, n=1, k=3, e=1.0) for _ in range(l)]
460
- )
461
-
462
- self.final_fuse = Conv(self.c_h + self.c_l + self.c_s, out_channels, 1, 1)
463
-
464
- def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
465
- B2, B3, B4, B5 = x
466
-
467
- B, _, H4, W4 = B4.shape
468
-
469
- B2_resized = F.interpolate(
470
- B2, size=(H4, W4), mode="bilinear", align_corners=False
471
- )
472
- B3_resized = F.interpolate(
473
- B3, size=(H4, W4), mode="bilinear", align_corners=False
474
- )
475
- B5_resized = F.interpolate(
476
- B5, size=(H4, W4), mode="bilinear", align_corners=False
477
- )
478
-
479
- x_b = self.fuse_conv(torch.cat((B2_resized, B3_resized, B4, B5_resized), dim=1))
480
-
481
- x_h, x_l, x_s = torch.split(x_b, [self.c_h, self.c_l, self.c_s], dim=1)
482
-
483
- x_h_outs = [m(x_h) for m in self.high_order_branch]
484
- x_h_fused = self.high_order_fuse(torch.cat(x_h_outs, dim=1))
485
-
486
- x_l_out = self.low_order_branch(x_l)
487
-
488
- y = self.final_fuse(torch.cat((x_h_fused, x_l_out, x_s), dim=1))
489
-
490
- return y
491
-
492
-
493
- class GatedFusion(nn.Module):
494
- def __init__(self, in_channels):
495
- super().__init__()
496
- self.gamma = nn.Parameter(torch.zeros(1, in_channels, 1, 1))
497
-
498
- def forward(self, f_in, h):
499
- if f_in.shape[1] != h.shape[1]:
500
- raise ValueError(f"Channel mismatch: f_in={f_in.shape}, h={h.shape}")
501
- return f_in + self.gamma * h
502
-
503
-
504
- class Backbone(nn.Module):
505
- def __init__(self, in_channels=256, base_channels=64, base_depth=3):
506
- super().__init__()
507
- c = base_channels
508
- c2 = base_channels
509
- c3 = 256
510
- c4 = 384
511
- c5 = 512
512
- c6 = 768
513
-
514
- self.stem = DSConv(in_channels, c2, k=3, s=(2, 1), p=1)
515
-
516
- self.p2 = nn.Sequential(
517
- DSConv(c2, c3, k=3, s=(2, 1), p=1), DS_C3k2(c3, c3, n=base_depth)
518
- )
519
-
520
- self.p3 = nn.Sequential(
521
- DSConv(c3, c4, k=3, s=(2, 1), p=1), DS_C3k2(c4, c4, n=base_depth * 2)
522
- )
523
-
524
- self.p4 = nn.Sequential(
525
- DSConv(c4, c5, k=3, s=2, p=1), DS_C3k2(c5, c5, n=base_depth * 2)
526
- )
527
-
528
- self.p5 = nn.Sequential(
529
- DSConv(c5, c6, k=3, s=2, p=1), DS_C3k2(c6, c6, n=base_depth)
530
- )
531
-
532
- self.out_channels = [c3, c4, c5, c6]
533
-
534
- def forward(self, x):
535
- x = self.stem(x)
536
- x2 = self.p2(x)
537
- x3 = self.p3(x2)
538
- x4 = self.p4(x3)
539
- x5 = self.p5(x4)
540
- return [x2, x3, x4, x5]
541
-
542
-
543
- class Decoder(nn.Module):
544
- def __init__(
545
- self,
546
- encoder_channels: List[int],
547
- hyperace_out_c: int,
548
- decoder_channels: List[int],
549
- ):
550
- super().__init__()
551
- c_p2, c_p3, c_p4, c_p5 = encoder_channels
552
- c_d2, c_d3, c_d4, c_d5 = decoder_channels
553
-
554
- self.h_to_d5 = Conv(hyperace_out_c, c_d5, 1, 1)
555
- self.h_to_d4 = Conv(hyperace_out_c, c_d4, 1, 1)
556
- self.h_to_d3 = Conv(hyperace_out_c, c_d3, 1, 1)
557
- self.h_to_d2 = Conv(hyperace_out_c, c_d2, 1, 1)
558
-
559
- self.fusion_d5 = GatedFusion(c_d5)
560
- self.fusion_d4 = GatedFusion(c_d4)
561
- self.fusion_d3 = GatedFusion(c_d3)
562
- self.fusion_d2 = GatedFusion(c_d2)
563
-
564
- self.skip_p5 = Conv(c_p5, c_d5, 1, 1)
565
- self.skip_p4 = Conv(c_p4, c_d4, 1, 1)
566
- self.skip_p3 = Conv(c_p3, c_d3, 1, 1)
567
- self.skip_p2 = Conv(c_p2, c_d2, 1, 1)
568
-
569
- self.up_d5 = DS_C3k2(c_d5, c_d4, n=1)
570
- self.up_d4 = DS_C3k2(c_d4, c_d3, n=1)
571
- self.up_d3 = DS_C3k2(c_d3, c_d2, n=1)
572
-
573
- self.final_d2 = DS_C3k2(c_d2, c_d2, n=1)
574
-
575
- def forward(self, enc_feats: List[torch.Tensor], h_ace: torch.Tensor):
576
- p2, p3, p4, p5 = enc_feats
577
-
578
- d5 = self.skip_p5(p5)
579
- h_d5 = self.h_to_d5(F.interpolate(h_ace, size=d5.shape[2:], mode="bilinear"))
580
- d5 = self.fusion_d5(d5, h_d5)
581
-
582
- d5_up = F.interpolate(d5, size=p4.shape[2:], mode="bilinear")
583
- d4_skip = self.skip_p4(p4)
584
- d4 = self.up_d5(d5_up) + d4_skip
585
-
586
- h_d4 = self.h_to_d4(F.interpolate(h_ace, size=d4.shape[2:], mode="bilinear"))
587
- d4 = self.fusion_d4(d4, h_d4)
588
-
589
- d4_up = F.interpolate(d4, size=p3.shape[2:], mode="bilinear")
590
- d3_skip = self.skip_p3(p3)
591
- d3 = self.up_d4(d4_up) + d3_skip
592
-
593
- h_d3 = self.h_to_d3(F.interpolate(h_ace, size=d3.shape[2:], mode="bilinear"))
594
- d3 = self.fusion_d3(d3, h_d3)
595
-
596
- d3_up = F.interpolate(d3, size=p2.shape[2:], mode="bilinear")
597
- d2_skip = self.skip_p2(p2)
598
- d2 = self.up_d3(d3_up) + d2_skip
599
-
600
- h_d2 = self.h_to_d2(F.interpolate(h_ace, size=d2.shape[2:], mode="bilinear"))
601
- d2 = self.fusion_d2(d2, h_d2)
602
-
603
- d2_final = self.final_d2(d2)
604
-
605
- return d2_final
606
-
607
-
608
- class TFC_TDF(nn.Module):
609
- def __init__(self, in_c, c, l, f, bn=4):
610
- super().__init__()
611
-
612
- self.blocks = nn.ModuleList()
613
- for i in range(l):
614
- block = nn.Module()
615
-
616
- block.tfc1 = nn.Sequential(
617
- nn.InstanceNorm2d(in_c, affine=True, eps=1e-8),
618
- nn.SiLU(),
619
- nn.Conv2d(in_c, c, 3, 1, 1, bias=False),
620
- )
621
- block.tdf = nn.Sequential(
622
- nn.InstanceNorm2d(c, affine=True, eps=1e-8),
623
- nn.SiLU(),
624
- nn.Linear(f, f // bn, bias=False),
625
- nn.InstanceNorm2d(c, affine=True, eps=1e-8),
626
- nn.SiLU(),
627
- nn.Linear(f // bn, f, bias=False),
628
- )
629
- block.tfc2 = nn.Sequential(
630
- nn.InstanceNorm2d(c, affine=True, eps=1e-8),
631
- nn.SiLU(),
632
- nn.Conv2d(c, c, 3, 1, 1, bias=False),
633
- )
634
- block.shortcut = nn.Conv2d(in_c, c, 1, 1, 0, bias=False)
635
-
636
- self.blocks.append(block)
637
- in_c = c
638
-
639
- def forward(self, x):
640
- for block in self.blocks:
641
- s = block.shortcut(x)
642
- x = block.tfc1(x)
643
- x = x + block.tdf(x)
644
- x = block.tfc2(x)
645
- x = x + s
646
- return x
647
-
648
-
649
- class FreqPixelShuffle(nn.Module):
650
- def __init__(self, in_channels, out_channels, scale, f):
651
- super().__init__()
652
- self.scale = scale
653
- self.conv = DSConv(in_channels, out_channels * scale)
654
- self.out_conv = TFC_TDF(out_channels, out_channels, 2, f)
655
-
656
- def forward(self, x):
657
- x = self.conv(x)
658
- B, C_r, H, W = x.shape
659
- out_c = C_r // self.scale
660
-
661
- x = x.view(B, out_c, self.scale, H, W)
662
-
663
- x = x.permute(0, 1, 3, 4, 2).contiguous()
664
- x = x.view(B, out_c, H, W * self.scale)
665
-
666
- return self.out_conv(x)
667
-
668
-
669
- class ProgressiveUpsampleHead(nn.Module):
670
- def __init__(self, in_channels, out_channels, target_bins=1025, in_bands=62):
671
- super().__init__()
672
- self.target_bins = target_bins
673
-
674
- c = in_channels
675
-
676
- self.block1 = FreqPixelShuffle(c, c // 2, scale=2, f=in_bands * 2)
677
- self.block2 = FreqPixelShuffle(c // 2, c // 4, scale=2, f=in_bands * 4)
678
- self.block3 = FreqPixelShuffle(c // 4, c // 8, scale=2, f=in_bands * 8)
679
- self.block4 = FreqPixelShuffle(c // 8, c // 16, scale=2, f=in_bands * 16)
680
-
681
- self.final_conv = nn.Conv2d(
682
- c // 16, out_channels, kernel_size=3, stride=1, padding="same", bias=False
683
- )
684
-
685
- def forward(self, x):
686
-
687
- x = self.block1(x)
688
- x = self.block2(x)
689
- x = self.block3(x)
690
- x = self.block4(x)
691
-
692
- if x.shape[-1] != self.target_bins:
693
- x = F.interpolate(
694
- x,
695
- size=(x.shape[2], self.target_bins),
696
- mode="bilinear",
697
- align_corners=False,
698
- )
699
-
700
- x = self.final_conv(x)
701
- return x
702
-
703
-
704
- class SegmModel(nn.Module):
705
- def __init__(
706
- self,
707
- in_bands=62,
708
- in_dim=256,
709
- out_bins=1025,
710
- out_channels=4,
711
- base_channels=64,
712
- base_depth=2,
713
- num_hyperedges=32,
714
- num_heads=8,
715
- ):
716
- super().__init__()
717
-
718
- self.backbone = Backbone(
719
- in_channels=in_dim, base_channels=base_channels, base_depth=base_depth
720
- )
721
- enc_channels = self.backbone.out_channels
722
- c2, c3, c4, c5 = enc_channels
723
-
724
- hyperace_in_channels = enc_channels
725
- hyperace_out_channels = c4
726
- self.hyperace = HyperACE(
727
- hyperace_in_channels,
728
- hyperace_out_channels,
729
- num_hyperedges,
730
- num_heads,
731
- k=2,
732
- l=1,
733
- )
734
-
735
- decoder_channels = [c2, c3, c4, c5]
736
- self.decoder = Decoder(enc_channels, hyperace_out_channels, decoder_channels)
737
-
738
- self.upsample_head = ProgressiveUpsampleHead(
739
- in_channels=decoder_channels[0],
740
- out_channels=out_channels,
741
- target_bins=out_bins,
742
- in_bands=in_bands,
743
- )
744
-
745
- def forward(self, x):
746
- H, W = x.shape[2:]
747
-
748
- enc_feats = self.backbone(x)
749
-
750
- h_ace_feats = self.hyperace(enc_feats)
751
-
752
- dec_feat = self.decoder(enc_feats, h_ace_feats)
753
-
754
- feat_time_restored = F.interpolate(
755
- dec_feat, size=(H, dec_feat.shape[-1]), mode="bilinear", align_corners=False
756
- )
757
-
758
- out = self.upsample_head(feat_time_restored)
759
-
760
- return out
761
-
762
-
763
- def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh):
764
- dim_hidden = default(dim_hidden, dim_in)
765
-
766
- net = []
767
- dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)
768
-
769
- for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
770
- is_last = ind == (len(dims) - 2)
771
-
772
- net.append(nn.Linear(layer_dim_in, layer_dim_out))
773
-
774
- if is_last:
775
- continue
776
-
777
- net.append(activation())
778
-
779
- return nn.Sequential(*net)
780
-
781
-
782
- class MaskEstimator(Module):
783
- @beartype
784
- def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4):
785
- super().__init__()
786
- self.dim_inputs = dim_inputs
787
- self.to_freqs = ModuleList([])
788
- dim_hidden = dim * mlp_expansion_factor
789
-
790
- for dim_in in dim_inputs:
791
- net = []
792
-
793
- mlp = nn.Sequential(
794
- MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), nn.GLU(dim=-1)
795
- )
796
-
797
- self.to_freqs.append(mlp)
798
-
799
- self.segm = SegmModel(
800
- in_bands=len(dim_inputs), in_dim=dim, out_bins=sum(dim_inputs) // 4
801
- )
802
-
803
- def forward(self, x):
804
- y = rearrange(x, "b t f c -> b c t f")
805
- y = self.segm(y)
806
- y = rearrange(y, "b c t f -> b t (f c)")
807
-
808
- x = x.unbind(dim=-2)
809
-
810
- outs = []
811
-
812
- for band_features, mlp in zip(x, self.to_freqs):
813
- freq_out = mlp(band_features)
814
- outs.append(freq_out)
815
-
816
- return torch.cat(outs, dim=-1) + y
817
-
818
-
819
- DEFAULT_FREQS_PER_BANDS = (
820
- 2,
821
- 2,
822
- 2,
823
- 2,
824
- 2,
825
- 2,
826
- 2,
827
- 2,
828
- 2,
829
- 2,
830
- 2,
831
- 2,
832
- 2,
833
- 2,
834
- 2,
835
- 2,
836
- 2,
837
- 2,
838
- 2,
839
- 2,
840
- 2,
841
- 2,
842
- 2,
843
- 2,
844
- 4,
845
- 4,
846
- 4,
847
- 4,
848
- 4,
849
- 4,
850
- 4,
851
- 4,
852
- 4,
853
- 4,
854
- 4,
855
- 4,
856
- 12,
857
- 12,
858
- 12,
859
- 12,
860
- 12,
861
- 12,
862
- 12,
863
- 12,
864
- 24,
865
- 24,
866
- 24,
867
- 24,
868
- 24,
869
- 24,
870
- 24,
871
- 24,
872
- 48,
873
- 48,
874
- 48,
875
- 48,
876
- 48,
877
- 48,
878
- 48,
879
- 48,
880
- 128,
881
- 129,
882
- )
883
-
884
-
885
- class BSRoformerHyperACE_2(Module):
886
-
887
- @beartype
888
- def __init__(
889
- self,
890
- dim,
891
- *,
892
- depth,
893
- stereo=False,
894
- num_stems=1,
895
- time_transformer_depth=2,
896
- freq_transformer_depth=2,
897
- linear_transformer_depth=0,
898
- freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
899
- dim_head=64,
900
- heads=8,
901
- attn_dropout=0.0,
902
- ff_dropout=0.0,
903
- flash_attn=True,
904
- dim_freqs_in=1025,
905
- stft_n_fft=2048,
906
- stft_hop_length=512,
907
- stft_win_length=2048,
908
- stft_normalized=False,
909
- stft_window_fn: Optional[Callable] = None,
910
- mask_estimator_depth=2,
911
- multi_stft_resolution_loss_weight=1.0,
912
- multi_stft_resolutions_window_sizes: Tuple[int, ...] = (
913
- 4096,
914
- 2048,
915
- 1024,
916
- 512,
917
- 256,
918
- ),
919
- multi_stft_hop_size=147,
920
- multi_stft_normalized=False,
921
- multi_stft_window_fn: Callable = torch.hann_window,
922
- mlp_expansion_factor=4,
923
- use_torch_checkpoint=False,
924
- skip_connection=False,
925
- sage_attention=False,
926
- ):
927
- super().__init__()
928
-
929
- self.stereo = stereo
930
- self.audio_channels = 2 if stereo else 1
931
- self.num_stems = num_stems
932
- self.use_torch_checkpoint = use_torch_checkpoint
933
- self.skip_connection = skip_connection
934
-
935
- self.layers = ModuleList([])
936
-
937
- if sage_attention:
938
- print("Use Sage Attention")
939
-
940
- transformer_kwargs = dict(
941
- dim=dim,
942
- heads=heads,
943
- dim_head=dim_head,
944
- attn_dropout=attn_dropout,
945
- ff_dropout=ff_dropout,
946
- flash_attn=flash_attn,
947
- norm_output=False,
948
- sage_attention=sage_attention,
949
- )
950
-
951
- time_rotary_embed = RotaryEmbedding(dim=dim_head)
952
- freq_rotary_embed = RotaryEmbedding(dim=dim_head)
953
-
954
- for _ in range(depth):
955
- tran_modules = []
956
- tran_modules.append(
957
- Transformer(
958
- depth=time_transformer_depth,
959
- rotary_embed=time_rotary_embed,
960
- **transformer_kwargs,
961
- )
962
- )
963
- tran_modules.append(
964
- Transformer(
965
- depth=freq_transformer_depth,
966
- rotary_embed=freq_rotary_embed,
967
- **transformer_kwargs,
968
- )
969
- )
970
- self.layers.append(nn.ModuleList(tran_modules))
971
-
972
- self.final_norm = RMSNorm(dim)
973
-
974
- self.stft_kwargs = dict(
975
- n_fft=stft_n_fft,
976
- hop_length=stft_hop_length,
977
- win_length=stft_win_length,
978
- normalized=stft_normalized,
979
- )
980
-
981
- self.stft_window_fn = partial(
982
- default(stft_window_fn, torch.hann_window), stft_win_length
983
- )
984
-
985
- freqs = torch.stft(
986
- torch.randn(1, 4096),
987
- **self.stft_kwargs,
988
- window=torch.ones(stft_win_length),
989
- return_complex=True,
990
- ).shape[1]
991
-
992
- assert len(freqs_per_bands) > 1
993
- assert (
994
- sum(freqs_per_bands) == freqs
995
- ), f"the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}"
996
-
997
- freqs_per_bands_with_complex = tuple(
998
- 2 * f * self.audio_channels for f in freqs_per_bands
999
- )
1000
-
1001
- self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)
1002
-
1003
- self.mask_estimators = nn.ModuleList([])
1004
-
1005
- for _ in range(num_stems):
1006
- mask_estimator = MaskEstimator(
1007
- dim=dim,
1008
- dim_inputs=freqs_per_bands_with_complex,
1009
- depth=mask_estimator_depth,
1010
- mlp_expansion_factor=mlp_expansion_factor,
1011
- )
1012
-
1013
- self.mask_estimators.append(mask_estimator)
1014
-
1015
- self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
1016
- self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
1017
- self.multi_stft_n_fft = stft_n_fft
1018
- self.multi_stft_window_fn = multi_stft_window_fn
1019
-
1020
- self.multi_stft_kwargs = dict(
1021
- hop_length=multi_stft_hop_size, normalized=multi_stft_normalized
1022
- )
1023
-
1024
- def forward(self, raw_audio, target=None, return_loss_breakdown=False):
1025
-
1026
- device = raw_audio.device
1027
-
1028
- x_is_mps = True if device.type == "mps" else False
1029
-
1030
- if raw_audio.ndim == 2:
1031
- raw_audio = rearrange(raw_audio, "b t -> b 1 t")
1032
-
1033
- channels = raw_audio.shape[1]
1034
- assert (not self.stereo and channels == 1) or (
1035
- self.stereo and channels == 2
1036
- ), "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)"
1037
-
1038
- raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")
1039
-
1040
- stft_window = self.stft_window_fn(device=device)
1041
-
1042
- try:
1043
- stft_repr = torch.stft(
1044
- raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True
1045
- )
1046
- except:
1047
- stft_repr = torch.stft(
1048
- raw_audio.cpu() if x_is_mps else raw_audio,
1049
- **self.stft_kwargs,
1050
- window=stft_window.cpu() if x_is_mps else stft_window,
1051
- return_complex=True,
1052
- ).to(device)
1053
- stft_repr = torch.view_as_real(stft_repr)
1054
-
1055
- stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")
1056
-
1057
- stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c")
1058
-
1059
- x = rearrange(stft_repr, "b f t c -> b t (f c)")
1060
-
1061
- x = self.band_split(x)
1062
-
1063
- for i, transformer_block in enumerate(self.layers):
1064
-
1065
- time_transformer, freq_transformer = transformer_block
1066
-
1067
- x = rearrange(x, "b t f d -> b f t d")
1068
- x, ps = pack([x], "* t d")
1069
-
1070
- x = time_transformer(x)
1071
-
1072
- (x,) = unpack(x, ps, "* t d")
1073
- x = rearrange(x, "b f t d -> b t f d")
1074
- x, ps = pack([x], "* f d")
1075
-
1076
- x = freq_transformer(x)
1077
-
1078
- (x,) = unpack(x, ps, "* f d")
1079
-
1080
- x = self.final_norm(x)
1081
-
1082
- num_stems = len(self.mask_estimators)
1083
-
1084
- mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
1085
- mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2)
1086
-
1087
- stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")
1088
-
1089
- stft_repr = torch.view_as_complex(stft_repr)
1090
- mask = torch.view_as_complex(mask)
1091
-
1092
- stft_repr = stft_repr * mask
1093
-
1094
- stft_repr = rearrange(
1095
- stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels
1096
- )
1097
-
1098
- try:
1099
- recon_audio = torch.istft(
1100
- stft_repr,
1101
- **self.stft_kwargs,
1102
- window=stft_window,
1103
- return_complex=False,
1104
- length=raw_audio.shape[-1],
1105
- )
1106
- except:
1107
- recon_audio = torch.istft(
1108
- stft_repr.cpu() if x_is_mps else stft_repr,
1109
- **self.stft_kwargs,
1110
- window=stft_window.cpu() if x_is_mps else stft_window,
1111
- return_complex=False,
1112
- length=raw_audio.shape[-1],
1113
- ).to(device)
1114
-
1115
- recon_audio = rearrange(
1116
- recon_audio, "(b n s) t -> b n s t", s=self.audio_channels, n=num_stems
1117
- )
1118
-
1119
- if num_stems == 1:
1120
- recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")
1121
-
1122
- if not exists(target):
1123
- return recon_audio
1124
-
1125
- if self.num_stems > 1:
1126
- assert target.ndim == 4 and target.shape[1] == self.num_stems
1127
-
1128
- if target.ndim == 2:
1129
- target = rearrange(target, "... t -> ... 1 t")
1130
-
1131
- target = target[..., : recon_audio.shape[-1]]
1132
-
1133
- loss = F.l1_loss(recon_audio, target)
1134
-
1135
- multi_stft_resolution_loss = 0.0
1136
-
1137
- for window_size in self.multi_stft_resolutions_window_sizes:
1138
- res_stft_kwargs = dict(
1139
- n_fft=max(window_size, self.multi_stft_n_fft),
1140
- win_length=window_size,
1141
- return_complex=True,
1142
- window=self.multi_stft_window_fn(window_size, device=device),
1143
- **self.multi_stft_kwargs,
1144
- )
1145
-
1146
- recon_Y = torch.stft(
1147
- rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs
1148
- )
1149
- target_Y = torch.stft(
1150
- rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs
1151
- )
1152
-
1153
- multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(
1154
- recon_Y, target_Y
1155
- )
1156
-
1157
- weighted_multi_resolution_loss = (
1158
- multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
1159
- )
1160
-
1161
- total_loss = loss + weighted_multi_resolution_loss
1162
-
1163
- if not return_loss_breakdown:
1164
- return total_loss
1165
-
1166
- return total_loss, (loss, multi_stft_resolution_loss)
 
1
+ from functools import partial
2
+
3
+ import torch
4
+ from torch import nn, einsum, Tensor
5
+ from torch.nn import Module, ModuleList
6
+ import torch.nn.functional as F
7
+
8
+ from .attend import Attend
9
+
10
+ from torch.utils.checkpoint import checkpoint
11
+
12
+ from beartype.typing import Tuple, Optional, List, Callable
13
+ from beartype import beartype
14
+
15
+ from rotary_embedding_torch import RotaryEmbedding
16
+
17
+ from einops import rearrange, pack, unpack
18
+ from einops.layers.torch import Rearrange
19
+ import torchaudio
20
+
21
+
22
def exists(val):
    """Return ``True`` when *val* is anything other than ``None``."""
    if val is None:
        return False
    return True
24
+
25
+
26
def default(v, d):
    """Return *v* unless it is ``None``, in which case return the fallback *d*."""
    if v is None:
        return d
    return v
28
+
29
+
30
def pack_one(t, pattern):
    """Pack the single tensor *t* with ``einops.pack`` under *pattern*.

    Returns exactly what ``pack`` returns for a one-element list; callers in
    this file unpack it as ``(packed_tensor, packed_shape_info)``.
    """
    singleton = [t]
    return pack(singleton, pattern)
32
+
33
+
34
def unpack_one(t, ps, pattern):
    """Undo :func:`pack_one`: unpack *t* with ``einops.unpack`` and return the first piece."""
    pieces = unpack(t, ps, pattern)
    return pieces[0]
36
+
37
+
38
def l2norm(t):
    """L2-normalise *t* along its last dimension."""
    return F.normalize(t, p=2, dim=-1)
40
+
41
+
42
class RMSNorm(Module):
    """Normalise the last axis to unit L2 norm, then rescale.

    The output is ``normalize(x) * sqrt(dim) * gamma`` where ``gamma`` is a
    learnable per-feature gain initialised to ones.
    """

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.scale = dim**0.5

    def forward(self, x):
        unit = F.normalize(x, dim=-1)
        return unit * self.scale * self.gamma
50
+
51
+
52
class FeedForward(Module):
    """Pre-normed two-layer MLP: RMSNorm -> Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: input and output feature width.
        mult: hidden width multiplier (hidden width is ``int(dim * mult)``).
        dropout: dropout probability used after the activation and after the
            output projection.
    """

    def __init__(self, dim, mult=4, dropout=0.0):
        super().__init__()
        hidden = int(dim * mult)
        stages = [
            RMSNorm(dim),
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)
67
+
68
+
69
class Attention(Module):
    """Pre-normed multi-head self-attention with optional rotary embeddings and per-head gates.

    Args:
        dim: feature width of the input and output.
        heads: number of attention heads.
        dim_head: width of each head.
        dropout: dropout probability passed to ``Attend`` and applied after
            the output projection.
        rotary_embed: optional object exposing ``rotate_queries_or_keys``;
            when given it is applied to queries and keys.
        flash: forwarded to ``Attend``.
    """

    def __init__(
        self,
        dim,
        heads=8,
        dim_head=64,
        dropout=0.0,
        rotary_embed=None,
        flash=True,
    ):
        super().__init__()
        inner_width = heads * dim_head

        self.heads = heads
        self.scale = dim_head**-0.5
        self.rotary_embed = rotary_embed

        self.attend = Attend(flash=flash, dropout=dropout)
        self.norm = RMSNorm(dim)

        # Single projection producing queries, keys and values together.
        self.to_qkv = nn.Linear(dim, inner_width * 3, bias=False)
        # One gate logit per head and position; squashed with a sigmoid in forward.
        self.to_gates = nn.Linear(dim, heads)

        out_proj = nn.Linear(inner_width, dim, bias=False)
        self.to_out = nn.Sequential(out_proj, nn.Dropout(dropout))

    def forward(self, x):
        normed = self.norm(x)

        qkv = self.to_qkv(normed)
        q, k, v = rearrange(qkv, "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads)

        rotary = self.rotary_embed
        if rotary is not None:
            q = rotary.rotate_queries_or_keys(q)
            k = rotary.rotate_queries_or_keys(k)

        attended = self.attend(q, k, v)

        # Gates are computed from the normalised input, not from the attention output.
        gate = rearrange(self.to_gates(normed), "b n h -> b h n 1").sigmoid()
        gated = attended * gate

        merged = rearrange(gated, "b h n d -> b n (h d)")
        return self.to_out(merged)
115
+
116
+
117
class LinearAttention(Module):
    """Attention variant that lays q/k/v out as ``b h d n`` before calling ``Attend``.

    Queries and keys are L2-normalised along their last axis and the queries
    are multiplied by ``exp(temperature)``, a learnable per-head value
    initialised to ones.

    Args:
        dim: feature width of the input and output.
        dim_head: width of each head.
        heads: number of heads.
        scale: forwarded to ``Attend``.
        flash: forwarded to ``Attend``.
        dropout: forwarded to ``Attend``.
    """

    @beartype
    def __init__(
        self,
        *,
        dim,
        dim_head=32,
        heads=8,
        scale=8,
        flash=True,
        dropout=0.0,
    ):
        super().__init__()
        inner_width = dim_head * heads
        self.norm = RMSNorm(dim)

        qkv_proj = nn.Linear(dim, inner_width * 3, bias=False)
        split_heads = Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads)
        self.to_qkv = nn.Sequential(qkv_proj, split_heads)

        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

        self.attend = Attend(scale=scale, dropout=dropout, flash=flash)

        merge_heads = Rearrange("b h d n -> b n (h d)")
        self.to_out = nn.Sequential(merge_heads, nn.Linear(inner_width, dim, bias=False))

    def forward(self, x):
        normed = self.norm(x)
        q, k, v = self.to_qkv(normed)

        q = l2norm(q)
        k = l2norm(k)
        q = q * self.temperature.exp()

        attended = self.attend(q, k, v)
        return self.to_out(attended)
158
+
159
+
160
class Transformer(Module):
    """Stack of ``depth`` residual (attention, feed-forward) pairs.

    Args:
        dim: feature width.
        depth: number of (attention, feed-forward) pairs.
        dim_head, heads: attention head geometry.
        attn_dropout: dropout for the attention blocks.
        ff_dropout: dropout for the feed-forward blocks.
        ff_mult: hidden width multiplier of the feed-forward blocks.
        norm_output: apply a final ``RMSNorm`` when true, identity otherwise.
        rotary_embed: passed to ``Attention`` only; ``LinearAttention`` does
            not receive it.
        flash_attn: forwarded to the attention blocks as ``flash``.
        linear_attn: build ``LinearAttention`` blocks instead of ``Attention``.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head=64,
        heads=8,
        attn_dropout=0.0,
        ff_dropout=0.0,
        ff_mult=4,
        norm_output=True,
        rotary_embed=None,
        flash_attn=True,
        linear_attn=False,
    ):
        super().__init__()
        self.layers = ModuleList([])

        shared_attn_kwargs = dict(
            dim=dim,
            dim_head=dim_head,
            heads=heads,
            dropout=attn_dropout,
            flash=flash_attn,
        )

        for _ in range(depth):
            if linear_attn:
                attn = LinearAttention(**shared_attn_kwargs)
            else:
                attn = Attention(rotary_embed=rotary_embed, **shared_attn_kwargs)

            ff = FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
            self.layers.append(ModuleList([attn, ff]))

        if norm_output:
            self.norm = RMSNorm(dim)
        else:
            self.norm = nn.Identity()

    def forward(self, x):
        for attn, ff in self.layers:
            # Residual connection around each sub-block.
            x = attn(x) + x
            x = ff(x) + x
        return self.norm(x)
213
+
214
+
215
class BandSplit(Module):
    """Split the last axis into bands and project every band to ``dim`` features.

    Args:
        dim: output feature width of every band.
        dim_inputs: width of each band along the input's last axis, in order.

    ``forward`` splits the input's last axis by ``dim_inputs``, runs each
    chunk through its own ``RMSNorm`` + ``Linear`` and stacks the results on
    a new second-to-last axis.
    """

    @beartype
    def __init__(self, dim, dim_inputs: Tuple[int, ...]):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_features = ModuleList([])

        for band_width in dim_inputs:
            projector = nn.Sequential(RMSNorm(band_width), nn.Linear(band_width, dim))
            self.to_features.append(projector)

    def forward(self, x):
        bands = x.split(self.dim_inputs, dim=-1)
        projected = [
            to_feature(band) for band, to_feature in zip(bands, self.to_features)
        ]
        return torch.stack(projected, dim=-2)
239
+
240
+
241
class Conv(nn.Module):
    """Bias-free ``Conv2d`` followed by affine ``InstanceNorm2d`` and an optional SiLU.

    Args:
        c1: input channels.
        c2: output channels.
        k: kernel size.
        s: stride.
        p: padding; when ``None`` it is derived from ``k`` by ``autopad``.
        g: number of convolution groups.
        act: use SiLU when true, identity otherwise.
    """

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
        super().__init__()
        padding = autopad(k, p)
        self.conv = nn.Conv2d(c1, c2, k, s, padding, groups=g, bias=False)
        # Named ``bn`` although it is an instance norm.
        self.bn = nn.InstanceNorm2d(c2, affine=True, eps=1e-8)
        if act:
            self.act = nn.SiLU()
        else:
            self.act = nn.Identity()

    def forward(self, x):
        y = self.conv(x)
        y = self.bn(y)
        return self.act(y)
250
+
251
+
252
def autopad(k, p=None):
    """Return padding *p*, or half the kernel size when *p* is ``None``.

    For an ``int`` kernel the result is ``k // 2``; for any other kernel the
    result is a list with each entry halved (floor division).
    """
    if p is not None:
        return p
    if isinstance(k, int):
        return k // 2
    return [side // 2 for side in k]
256
+
257
+
258
class DSConv(nn.Module):
    """Depthwise-separable convolution: depthwise conv -> 1x1 pointwise conv -> InstanceNorm -> optional SiLU.

    Args:
        c1: input channels (also the depthwise group count).
        c2: output channels of the pointwise convolution.
        k: depthwise kernel size.
        s: depthwise stride.
        p: depthwise padding; derived from ``k`` by ``autopad`` when ``None``.
        act: use SiLU when true, identity otherwise.
    """

    def __init__(self, c1, c2, k=3, s=1, p=None, act=True):
        super().__init__()
        padding = autopad(k, p)
        self.dwconv = nn.Conv2d(c1, c1, k, s, padding, groups=c1, bias=False)
        self.pwconv = nn.Conv2d(c1, c2, 1, 1, 0, bias=False)
        self.bn = nn.InstanceNorm2d(c2, affine=True, eps=1e-8)
        if act:
            self.act = nn.SiLU()
        else:
            self.act = nn.Identity()

    def forward(self, x):
        y = self.dwconv(x)
        y = self.pwconv(y)
        y = self.bn(y)
        return self.act(y)
268
+
269
+
270
class DS_Bottleneck(nn.Module):
    """Two stacked ``DSConv`` layers with an optional identity shortcut.

    The first layer always uses a 3x3 kernel; ``k`` only sets the kernel of
    the second layer. The shortcut is active only when ``shortcut`` is true
    and ``c1 == c2``.
    """

    def __init__(self, c1, c2, k=3, shortcut=True):
        super().__init__()
        hidden = c1
        self.dsconv1 = DSConv(c1, hidden, k=3, s=1)
        self.dsconv2 = DSConv(hidden, c2, k=k, s=1)
        self.shortcut = shortcut and c1 == c2

    def forward(self, x):
        y = self.dsconv2(self.dsconv1(x))
        if self.shortcut:
            return x + y
        return y
284
+
285
+
286
class DS_C3k(nn.Module):
    """Two-branch block: a bottleneck branch and a plain 1x1 branch, concatenated and fused.

    Args:
        c1: input channels.
        c2: output channels.
        n: number of ``DS_Bottleneck`` layers in the main branch.
        k: kernel size handed to every bottleneck.
        e: hidden width ratio; both branches use ``int(c2 * e)`` channels.
    """

    def __init__(self, c1, c2, n=1, k=3, e=0.5):
        super().__init__()
        hidden = int(c2 * e)
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(c1, hidden, 1, 1)
        self.cv3 = Conv(2 * hidden, c2, 1, 1)
        bottlenecks = [
            DS_Bottleneck(hidden, hidden, k=k, shortcut=True) for _ in range(n)
        ]
        self.m = nn.Sequential(*bottlenecks)

    def forward(self, x):
        main = self.m(self.cv1(x))
        side = self.cv2(x)
        return self.cv3(torch.cat((main, side), dim=1))
299
+
300
+
301
class DS_C3k2(nn.Module):
    """1x1 reduce -> ``DS_C3k`` -> 1x1 expand.

    Args:
        c1: input channels.
        c2: output channels.
        n: bottleneck count forwarded to the inner ``DS_C3k``.
        k: kernel size forwarded to the inner ``DS_C3k``.
        e: hidden width ratio; the inner block runs at ``int(c2 * e)`` channels.
    """

    def __init__(self, c1, c2, n=1, k=3, e=0.5):
        super().__init__()
        hidden = int(c2 * e)
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.m = DS_C3k(hidden, hidden, n=n, k=k, e=1.0)
        self.cv2 = Conv(hidden, c2, 1, 1)

    def forward(self, x):
        return self.cv2(self.m(self.cv1(x)))
313
+
314
+
315
class AdaptiveHyperedgeGeneration(nn.Module):
    """Produce a soft node-to-hyperedge assignment matrix from node features.

    Learned global prototypes are shifted by an input-dependent offset
    (computed from average- and max-pooled features), compared with projected
    node features per head, averaged over heads and soft-maxed over nodes.

    Args:
        in_channels: feature width ``C`` of the nodes; must be divisible by
            ``num_heads``.
        num_hyperedges: number of hyperedges (prototypes) ``M``.
        num_heads: number of similarity heads.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide
            ``in_channels``. Without this check ``head_dim`` would be silently
            truncated and ``forward`` would fail with an opaque ``view`` error.
    """

    def __init__(self, in_channels, num_hyperedges, num_heads=8):
        super().__init__()
        if num_heads <= 0 or in_channels % num_heads != 0:
            raise ValueError(
                f"in_channels ({in_channels}) must be divisible by "
                f"num_heads ({num_heads})"
            )

        self.num_hyperedges = num_hyperedges
        self.num_heads = num_heads
        self.head_dim = in_channels // num_heads

        self.global_proto = nn.Parameter(torch.randn(num_hyperedges, in_channels))

        # Maps [avg-pooled ; max-pooled] context to one offset per prototype.
        self.context_mapper = nn.Linear(
            2 * in_channels, num_hyperedges * in_channels, bias=False
        )

        self.query_proj = nn.Linear(in_channels, in_channels, bias=False)

        self.scale = self.head_dim**-0.5

    def forward(self, x):
        """Return assignments of shape ``(B, num_hyperedges, N)`` for ``x`` of shape ``(B, N, C)``.

        Every ``(batch, hyperedge)`` row sums to one over the ``N`` nodes.
        """
        B, N, C = x.shape

        # Pool over the node axis; the permuted view is shared by both pools.
        x_cn = x.permute(0, 2, 1)
        f_avg = F.adaptive_avg_pool1d(x_cn, 1).squeeze(-1)
        f_max = F.adaptive_max_pool1d(x_cn, 1).squeeze(-1)
        f_ctx = torch.cat((f_avg, f_max), dim=1)

        delta_P = self.context_mapper(f_ctx).view(B, self.num_hyperedges, C)
        P = self.global_proto.unsqueeze(0) + delta_P

        z = self.query_proj(x)
        # (B, heads, N, head_dim)
        z = z.view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        # (B, heads, head_dim, M)
        P = P.view(B, self.num_hyperedges, self.num_heads, self.head_dim).permute(
            0, 2, 3, 1
        )

        sim = (z @ P) * self.scale

        # Average the per-head similarities, then normalise over nodes.
        s_bar = sim.mean(dim=1)
        A = F.softmax(s_bar.permute(0, 2, 1), dim=-1)

        return A
357
+
358
+
359
class HypergraphConvolution(nn.Module):
    """Node -> hyperedge -> node message passing with a residual connection.

    ``forward(x, A)`` aggregates node features into hyperedges with ``A``,
    transforms them, scatters them back with ``A`` transposed, transforms
    again and adds the result to ``x``. The residual add needs the second
    projection's output to be addable to ``x``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.W_e = nn.Linear(in_channels, in_channels, bias=False)
        self.W_v = nn.Linear(in_channels, out_channels, bias=False)
        self.act = nn.SiLU()

    def forward(self, x, A):
        # Gather: hyperedge features from member nodes.
        edge_feats = self.act(self.W_e(torch.bmm(A, x)))

        # Scatter: send hyperedge features back to the nodes.
        node_update = torch.bmm(A.transpose(1, 2), edge_feats)
        node_update = self.act(self.W_v(node_update))

        return x + node_update
374
+
375
+
376
class AdaptiveHypergraphComputation(nn.Module):
    """Run adaptive hypergraph message passing over a 4-D feature map.

    The map ``(B, C, H, W)`` is flattened to ``H * W`` nodes, a soft
    assignment matrix is generated from the nodes, a hypergraph convolution
    is applied, and the result is reshaped back to ``(B, -1, H, W)``.
    """

    def __init__(self, in_channels, out_channels, num_hyperedges=8, num_heads=8):
        super().__init__()
        self.adaptive_hyperedge_gen = AdaptiveHyperedgeGeneration(
            in_channels, num_hyperedges, num_heads
        )
        self.hypergraph_conv = HypergraphConvolution(in_channels, out_channels)

    def forward(self, x):
        B, C, H, W = x.shape

        # (B, C, H, W) -> (B, H*W, C): every spatial position is a node.
        nodes = x.flatten(2).permute(0, 2, 1)

        assignment = self.adaptive_hyperedge_gen(nodes)
        updated = self.hypergraph_conv(nodes, assignment)

        return updated.permute(0, 2, 1).view(B, -1, H, W)
394
+
395
+
396
class C3AH(nn.Module):
    """Two-branch block: a lateral 1x1 branch and a hypergraph branch, concatenated and fused.

    Args:
        c1: input channels.
        c2: output channels.
        num_hyperedges, num_heads: forwarded to ``AdaptiveHypergraphComputation``.
        e: hidden width ratio; both branches use ``int(c1 * e)`` channels.
    """

    def __init__(self, c1, c2, num_hyperedges=8, num_heads=8, e=0.5):
        super().__init__()
        hidden = int(c1 * e)
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(c1, hidden, 1, 1)
        self.ahc = AdaptiveHypergraphComputation(
            hidden, hidden, num_hyperedges, num_heads
        )
        self.cv3 = Conv(2 * hidden, c2, 1, 1)

    def forward(self, x):
        lateral = self.cv1(x)
        hyper = self.ahc(self.cv2(x))
        # The hypergraph branch comes first in the concatenation.
        return self.cv3(torch.cat((hyper, lateral), dim=1))
409
+
410
+
411
class HyperACE(nn.Module):
    """Fuse four feature maps and process them in high-order, low-order and shortcut channel groups.

    Args:
        in_channels: channel counts of the four input maps, in order.
        out_channels: channels of the fused output.
        num_hyperedges, num_heads: forwarded to every ``C3AH`` block.
        k: number of parallel ``C3AH`` blocks in the high-order branch.
        l: number of sequential ``DS_C3k`` blocks in the low-order branch.
        c_h: fraction of the fused channels given to the high-order branch.
        c_l: fraction of the fused channels given to the low-order branch.

    The fused width equals the third entry of ``in_channels``; whatever is
    left after the two branches passes through unchanged.
    """

    def __init__(
        self,
        in_channels: List[int],
        out_channels: int,
        num_hyperedges=8,
        num_heads=8,
        k=2,
        l=1,
        c_h=0.5,
        c_l=0.25,
    ):
        super().__init__()

        c2, c3, c4, c5 = in_channels
        fused_width = c4

        self.fuse_conv = Conv(c2 + c3 + c4 + c5, fused_width, 1, 1)

        self.c_h = int(fused_width * c_h)
        self.c_l = int(fused_width * c_l)
        self.c_s = fused_width - self.c_h - self.c_l
        assert self.c_s > 0, "Channel split error"

        high_blocks = [
            C3AH(self.c_h, self.c_h, num_hyperedges, num_heads, e=1.0)
            for _ in range(k)
        ]
        self.high_order_branch = nn.ModuleList(high_blocks)
        self.high_order_fuse = Conv(self.c_h * k, self.c_h, 1, 1)

        low_blocks = [DS_C3k(self.c_l, self.c_l, n=1, k=3, e=1.0) for _ in range(l)]
        self.low_order_branch = nn.Sequential(*low_blocks)

        self.final_fuse = Conv(self.c_h + self.c_l + self.c_s, out_channels, 1, 1)

    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
        B2, B3, B4, B5 = x

        # The third map sets the common spatial size.
        _, _, H4, W4 = B4.shape

        def to_common_size(feat):
            return F.interpolate(
                feat, size=(H4, W4), mode="bilinear", align_corners=False
            )

        stacked = torch.cat(
            (to_common_size(B2), to_common_size(B3), B4, to_common_size(B5)), dim=1
        )
        fused = self.fuse_conv(stacked)

        high_in, low_in, passthrough = torch.split(
            fused, [self.c_h, self.c_l, self.c_s], dim=1
        )

        # Every high-order block sees the same input; outputs are concatenated.
        high_outs = [block(high_in) for block in self.high_order_branch]
        high_out = self.high_order_fuse(torch.cat(high_outs, dim=1))

        low_out = self.low_order_branch(low_in)

        return self.final_fuse(torch.cat((high_out, low_out, passthrough), dim=1))
476
+
477
+
478
class GatedFusion(nn.Module):
    """Add ``h`` to ``f_in`` scaled by a learnable per-channel gain.

    The gain ``gamma`` has shape ``(1, in_channels, 1, 1)`` and starts at
    zero, so a freshly built module returns ``f_in`` unchanged.

    Raises:
        ValueError: in ``forward`` when the two inputs differ in size along
            dimension 1.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, in_channels, 1, 1))

    def forward(self, f_in, h):
        channels_match = f_in.shape[1] == h.shape[1]
        if not channels_match:
            raise ValueError(f"Channel mismatch: f_in={f_in.shape}, h={h.shape}")
        return f_in + self.gamma * h
487
+
488
+
489
class Backbone(nn.Module):
    """Convolutional encoder producing four feature maps of increasing width.

    Args:
        in_channels: channels of the input map.
        base_channels: output width of the stem only; the stage widths
            256 / 384 / 512 / 768 are fixed.
        base_depth: block count for the first and last stages; the two middle
            stages use twice this value.

    The stem and the first two stages use stride ``(2, 1)``; the last two
    stages use stride 2 on both spatial axes. ``out_channels`` lists the
    widths of the four returned maps.
    """

    def __init__(self, in_channels=256, base_channels=64, base_depth=3):
        super().__init__()
        stem_width = base_channels
        w2, w3, w4, w5 = 256, 384, 512, 768

        self.stem = DSConv(in_channels, stem_width, k=3, s=(2, 1), p=1)

        self.p2 = nn.Sequential(
            DSConv(stem_width, w2, k=3, s=(2, 1), p=1),
            DS_C3k2(w2, w2, n=base_depth),
        )

        self.p3 = nn.Sequential(
            DSConv(w2, w3, k=3, s=(2, 1), p=1),
            DS_C3k2(w3, w3, n=base_depth * 2),
        )

        self.p4 = nn.Sequential(
            DSConv(w3, w4, k=3, s=2, p=1),
            DS_C3k2(w4, w4, n=base_depth * 2),
        )

        self.p5 = nn.Sequential(
            DSConv(w4, w5, k=3, s=2, p=1),
            DS_C3k2(w5, w5, n=base_depth),
        )

        self.out_channels = [w2, w3, w4, w5]

    def forward(self, x):
        feats = []
        x = self.stem(x)
        for stage in (self.p2, self.p3, self.p4, self.p5):
            x = stage(x)
            feats.append(x)
        return feats
526
+
527
+
528
class Decoder(nn.Module):
    """Top-down decoder that mixes encoder skips with the HyperACE context at every level.

    Args:
        encoder_channels: widths of the four encoder maps, finest first.
        hyperace_out_c: channels of the HyperACE context map.
        decoder_channels: widths of the four decoder levels, finest first.

    At each level the context map is resized to the level's spatial size,
    projected by a 1x1 ``Conv`` and injected through a ``GatedFusion``.
    """

    def __init__(
        self,
        encoder_channels: List[int],
        hyperace_out_c: int,
        decoder_channels: List[int],
    ):
        super().__init__()
        c_p2, c_p3, c_p4, c_p5 = encoder_channels
        c_d2, c_d3, c_d4, c_d5 = decoder_channels

        # Context projections, coarsest level first.
        self.h_to_d5 = Conv(hyperace_out_c, c_d5, 1, 1)
        self.h_to_d4 = Conv(hyperace_out_c, c_d4, 1, 1)
        self.h_to_d3 = Conv(hyperace_out_c, c_d3, 1, 1)
        self.h_to_d2 = Conv(hyperace_out_c, c_d2, 1, 1)

        self.fusion_d5 = GatedFusion(c_d5)
        self.fusion_d4 = GatedFusion(c_d4)
        self.fusion_d3 = GatedFusion(c_d3)
        self.fusion_d2 = GatedFusion(c_d2)

        # 1x1 adapters from encoder widths to decoder widths.
        self.skip_p5 = Conv(c_p5, c_d5, 1, 1)
        self.skip_p4 = Conv(c_p4, c_d4, 1, 1)
        self.skip_p3 = Conv(c_p3, c_d3, 1, 1)
        self.skip_p2 = Conv(c_p2, c_d2, 1, 1)

        self.up_d5 = DS_C3k2(c_d5, c_d4, n=1)
        self.up_d4 = DS_C3k2(c_d4, c_d3, n=1)
        self.up_d3 = DS_C3k2(c_d3, c_d2, n=1)

        self.final_d2 = DS_C3k2(c_d2, c_d2, n=1)

    @staticmethod
    def _inject_context(level, h_ace, project, fuse):
        """Resize the context to ``level``'s spatial size, project it and gate it in."""
        context = F.interpolate(h_ace, size=level.shape[2:], mode="bilinear")
        return fuse(level, project(context))

    @staticmethod
    def _step_up(coarse, enc_feat, refine, skip):
        """Resize ``coarse`` to ``enc_feat``'s spatial size, refine it and add the adapted skip."""
        resized = F.interpolate(coarse, size=enc_feat.shape[2:], mode="bilinear")
        adapted_skip = skip(enc_feat)
        return refine(resized) + adapted_skip

    def forward(self, enc_feats: List[torch.Tensor], h_ace: torch.Tensor):
        p2, p3, p4, p5 = enc_feats

        d5 = self._inject_context(
            self.skip_p5(p5), h_ace, self.h_to_d5, self.fusion_d5
        )

        d4 = self._step_up(d5, p4, self.up_d5, self.skip_p4)
        d4 = self._inject_context(d4, h_ace, self.h_to_d4, self.fusion_d4)

        d3 = self._step_up(d4, p3, self.up_d4, self.skip_p3)
        d3 = self._inject_context(d3, h_ace, self.h_to_d3, self.fusion_d3)

        d2 = self._step_up(d3, p2, self.up_d3, self.skip_p2)
        d2 = self._inject_context(d2, h_ace, self.h_to_d2, self.fusion_d2)

        return self.final_d2(d2)
591
+
592
+
593
class TFC_TDF(nn.Module):
    """Stack of residual blocks mixing 3x3 convolutions with a bottlenecked linear map over the last axis.

    Args:
        in_c: input channels of the first block.
        c: channels produced by every block.
        l: number of blocks.
        f: size of the input's last axis; the linear layers act on it.
        bn: bottleneck factor; the inner linear width is ``f // bn``.

    Per block: ``tfc1`` (norm, SiLU, conv), a residual ``tdf`` (two
    norm/SiLU/linear stages), ``tfc2`` (norm, SiLU, conv), plus a 1x1
    convolution shortcut from the block input.
    """

    def __init__(self, in_c, c, l, f, bn=4):
        super().__init__()

        def norm_act(channels):
            return [nn.InstanceNorm2d(channels, affine=True, eps=1e-8), nn.SiLU()]

        bottleneck = f // bn
        self.blocks = nn.ModuleList()

        for _ in range(l):
            block = nn.Module()

            block.tfc1 = nn.Sequential(
                *norm_act(in_c),
                nn.Conv2d(in_c, c, 3, 1, 1, bias=False),
            )
            block.tdf = nn.Sequential(
                *norm_act(c),
                nn.Linear(f, bottleneck, bias=False),
                *norm_act(c),
                nn.Linear(bottleneck, f, bias=False),
            )
            block.tfc2 = nn.Sequential(
                *norm_act(c),
                nn.Conv2d(c, c, 3, 1, 1, bias=False),
            )
            block.shortcut = nn.Conv2d(in_c, c, 1, 1, 0, bias=False)

            self.blocks.append(block)
            # Every block after the first consumes ``c`` channels.
            in_c = c

    def forward(self, x):
        for block in self.blocks:
            residual = block.shortcut(x)
            hidden = block.tfc1(x)
            hidden = hidden + block.tdf(hidden)
            hidden = block.tfc2(hidden)
            x = hidden + residual
        return x
632
+
633
+
634
class FreqPixelShuffle(nn.Module):
    """Lengthen the last axis by ``scale`` by folding channels into it, then refine.

    A ``DSConv`` produces ``out_channels * scale`` channels; groups of
    ``scale`` channels are interleaved into the last axis, and a two-block
    ``TFC_TDF`` refines the result.

    Args:
        in_channels: input channels.
        out_channels: output channels.
        scale: multiplier for the last axis.
        f: last-axis size expected by the ``TFC_TDF`` refinement.
    """

    def __init__(self, in_channels, out_channels, scale, f):
        super().__init__()
        self.scale = scale
        self.conv = DSConv(in_channels, out_channels * scale)
        self.out_conv = TFC_TDF(out_channels, out_channels, 2, f)

    def forward(self, x):
        expanded = self.conv(x)
        B, C_r, H, W = expanded.shape
        factor = self.scale
        out_c = C_r // factor

        # (B, out_c, factor, H, W) -> (B, out_c, H, W, factor) -> (B, out_c, H, W * factor)
        grouped = expanded.view(B, out_c, factor, H, W)
        interleaved = grouped.permute(0, 1, 3, 4, 2).contiguous()
        stretched = interleaved.view(B, out_c, H, W * factor)

        return self.out_conv(stretched)
652
+
653
+
654
class ProgressiveUpsampleHead(nn.Module):
    """Four x2 ``FreqPixelShuffle`` stages, an optional resize to ``target_bins`` and a final 3x3 conv.

    Args:
        in_channels: input channels; each stage halves the width (down to
            ``in_channels // 16``).
        out_channels: channels of the final convolution.
        target_bins: required size of the last axis before the final conv.
        in_bands: last-axis size the first stage starts from; the stages are
            built for ``in_bands * 2``, ``* 4``, ``* 8`` and ``* 16``.
    """

    def __init__(self, in_channels, out_channels, target_bins=1025, in_bands=62):
        super().__init__()
        self.target_bins = target_bins

        width = in_channels

        self.block1 = FreqPixelShuffle(width, width // 2, scale=2, f=in_bands * 2)
        self.block2 = FreqPixelShuffle(width // 2, width // 4, scale=2, f=in_bands * 4)
        self.block3 = FreqPixelShuffle(width // 4, width // 8, scale=2, f=in_bands * 8)
        self.block4 = FreqPixelShuffle(
            width // 8, width // 16, scale=2, f=in_bands * 16
        )

        self.final_conv = nn.Conv2d(
            width // 16,
            out_channels,
            kernel_size=3,
            stride=1,
            padding="same",
            bias=False,
        )

    def forward(self, x):
        for stage in (self.block1, self.block2, self.block3, self.block4):
            x = stage(x)

        # The stages yield ``in_bands * 16`` bins; resize when that is not the target.
        if x.shape[-1] != self.target_bins:
            x = F.interpolate(
                x,
                size=(x.shape[2], self.target_bins),
                mode="bilinear",
                align_corners=False,
            )

        return self.final_conv(x)
687
+
688
+
689
class SegmModel(nn.Module):
    """Backbone -> HyperACE -> Decoder -> upsampling head over a 4-D feature map.

    Args:
        in_bands: size of the input's last axis (forwarded to the head).
        in_dim: input channels.
        out_bins: size of the output's last axis.
        out_channels: output channels.
        base_channels, base_depth: forwarded to ``Backbone``.
        num_hyperedges, num_heads: forwarded to ``HyperACE``.

    The decoder output is resized so that dimension 2 matches the input's
    before the head is applied.
    """

    def __init__(
        self,
        in_bands=62,
        in_dim=256,
        out_bins=1025,
        out_channels=4,
        base_channels=64,
        base_depth=2,
        num_hyperedges=32,
        num_heads=8,
    ):
        super().__init__()

        self.backbone = Backbone(
            in_channels=in_dim, base_channels=base_channels, base_depth=base_depth
        )
        enc_channels = self.backbone.out_channels

        # HyperACE emits as many channels as the third encoder stage.
        hyperace_out_channels = enc_channels[2]
        self.hyperace = HyperACE(
            enc_channels,
            hyperace_out_channels,
            num_hyperedges,
            num_heads,
            k=2,
            l=1,
        )

        # Decoder widths mirror the encoder widths.
        w2, w3, w4, w5 = enc_channels
        decoder_channels = [w2, w3, w4, w5]
        self.decoder = Decoder(enc_channels, hyperace_out_channels, decoder_channels)

        self.upsample_head = ProgressiveUpsampleHead(
            in_channels=decoder_channels[0],
            out_channels=out_channels,
            target_bins=out_bins,
            in_bands=in_bands,
        )

    def forward(self, x):
        H, _ = x.shape[2:]

        enc_feats = self.backbone(x)
        context = self.hyperace(enc_feats)
        decoded = self.decoder(enc_feats, context)

        # Restore dimension 2 to the input's size; the last axis is kept as is.
        restored = F.interpolate(
            decoded, size=(H, decoded.shape[-1]), mode="bilinear", align_corners=False
        )

        return self.upsample_head(restored)
746
+
747
+
748
def MLP(dim_in, dim_out, dim_hidden=None, depth=1, activation=nn.Tanh):
    """Build an ``nn.Sequential`` of ``Linear`` layers separated by ``activation()``.

    Args:
        dim_in: input width.
        dim_out: output width.
        dim_hidden: hidden width; falls back to ``dim_in`` when ``None``.
        depth: number of linear layers (at least one is always built).
        activation: zero-argument factory for the module placed between
            linear layers; none follows the last layer.
    """
    hidden = dim_in if dim_hidden is None else dim_hidden
    widths = (dim_in, *((hidden,) * (depth - 1)), dim_out)
    last_index = len(widths) - 2

    layers = []
    for index, (width_in, width_out) in enumerate(zip(widths[:-1], widths[1:])):
        layers.append(nn.Linear(width_in, width_out))
        if index != last_index:
            layers.append(activation())

    return nn.Sequential(*layers)
765
+
766
+
767
class MaskEstimator(Module):
    """Estimate a mask as the sum of per-band MLP outputs and a ``SegmModel`` branch.

    Args:
        dim: feature width of every band.
        dim_inputs: output width of each band, in order.
        depth: number of linear layers in every band MLP.
        mlp_expansion_factor: hidden width multiplier of the band MLPs.

    Raises:
        ValueError: if ``sum(dim_inputs)`` is not a multiple of the four
            channels emitted by the segmentation branch. In that case the
            branch output would be narrower than the concatenated band
            outputs and ``forward`` could never add them.
    """

    @beartype
    def __init__(self, dim, dim_inputs: Tuple[int, ...], depth, mlp_expansion_factor=4):
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_freqs = ModuleList([])
        dim_hidden = dim * mlp_expansion_factor

        # The segmentation branch emits this many channels per output bin;
        # its flattened width must match the concatenated band widths.
        segm_out_channels = 4
        total_width = sum(dim_inputs)
        if total_width % segm_out_channels != 0:
            raise ValueError(
                f"sum(dim_inputs) ({total_width}) must be divisible by "
                f"{segm_out_channels} so the segmentation branch output matches "
                "the concatenated band outputs"
            )

        for dim_in in dim_inputs:
            # The MLP emits twice the band width; GLU halves it back.
            mlp = nn.Sequential(
                MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth), nn.GLU(dim=-1)
            )

            self.to_freqs.append(mlp)

        self.segm = SegmModel(
            in_bands=len(dim_inputs),
            in_dim=dim,
            out_bins=total_width // segm_out_channels,
            out_channels=segm_out_channels,
        )

    def forward(self, x):
        # Segmentation branch: move features to the channel axis, then flatten
        # (bins, channels) into the last axis.
        y = rearrange(x, "b t f c -> b c t f")
        y = self.segm(y)
        y = rearrange(y, "b c t f -> b t (f c)")

        # Band branch: one MLP per band along the second-to-last axis.
        bands = x.unbind(dim=-2)
        outs = [mlp(band_features) for band_features, mlp in zip(bands, self.to_freqs)]

        return torch.cat(outs, dim=-1) + y
802
+
803
+
804
# Default number of STFT frequency bins per band, in band order:
# 24 bands of 2 bins, 12 of 4, 8 of 12, 8 of 24, 8 of 48, then 128 and 129.
# That is 62 bands covering 1025 bins.
DEFAULT_FREQS_PER_BANDS = (
    (2,) * 24
    + (4,) * 12
    + (12,) * 8
    + (24,) * 8
    + (48,) * 8
    + (128, 129)
)
868
+
869
+
870
class BSRoformerHyperACE_2(Module):
    """Band-split RoFormer separator with HyperACE-augmented mask estimators.

    Pipeline of ``forward``: STFT -> band split -> alternating time/frequency
    transformers -> one ``MaskEstimator`` per stem -> complex mask applied to
    the STFT -> inverse STFT. When a ``target`` is given, an L1 waveform loss
    plus a multi-resolution STFT loss is returned instead of audio.

    Args:
        dim: feature width of the transformers.
        depth: number of (time transformer, frequency transformer) pairs.
        stereo: expect two audio channels when true, one otherwise.
        num_stems: number of separated stems (one mask estimator each).
        time_transformer_depth, freq_transformer_depth: layers per transformer.
        freqs_per_bands: STFT bins per band; must sum to the STFT bin count.
        dim_head, heads, attn_dropout, ff_dropout, flash_attn: transformer
            settings.
        stft_n_fft, stft_hop_length, stft_win_length, stft_normalized:
            settings of the main STFT.
        stft_window_fn: window factory; ``torch.hann_window`` when ``None``.
        mask_estimator_depth, mlp_expansion_factor: mask estimator settings.
        multi_stft_resolution_loss_weight: weight of the multi-resolution term.
        multi_stft_resolutions_window_sizes: window sizes of that term.
        multi_stft_hop_size, multi_stft_normalized, multi_stft_window_fn:
            settings of the multi-resolution STFTs.
        use_torch_checkpoint: run every transformer under
            ``torch.utils.checkpoint.checkpoint``.
        skip_connection: stored on the instance but not read by this class.
        linear_transformer_depth, dim_freqs_in, **kwargs: accepted and ignored.
    """

    @beartype
    def __init__(
        self,
        dim,
        *,
        depth,
        stereo=False,
        num_stems=1,
        time_transformer_depth=2,
        freq_transformer_depth=2,
        linear_transformer_depth=0,
        freqs_per_bands: Tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
        dim_head=64,
        heads=8,
        attn_dropout=0.0,
        ff_dropout=0.0,
        flash_attn=True,
        dim_freqs_in=1025,
        stft_n_fft=2048,
        stft_hop_length=512,
        stft_win_length=2048,
        stft_normalized=False,
        stft_window_fn: Optional[Callable] = None,
        mask_estimator_depth=2,
        multi_stft_resolution_loss_weight=1.0,
        multi_stft_resolutions_window_sizes: Tuple[int, ...] = (
            4096,
            2048,
            1024,
            512,
            256,
        ),
        multi_stft_hop_size=147,
        multi_stft_normalized=False,
        multi_stft_window_fn: Callable = torch.hann_window,
        mlp_expansion_factor=4,
        use_torch_checkpoint=False,
        skip_connection=False,
        **kwargs
    ):
        super().__init__()

        self.stereo = stereo
        self.audio_channels = 2 if stereo else 1
        self.num_stems = num_stems
        self.use_torch_checkpoint = use_torch_checkpoint
        self.skip_connection = skip_connection

        self.layers = ModuleList([])

        transformer_kwargs = dict(
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout,
            flash_attn=flash_attn,
            norm_output=False,
        )

        # One rotary embedding per axis, shared by every layer along that axis.
        time_rotary_embed = RotaryEmbedding(dim=dim_head)
        freq_rotary_embed = RotaryEmbedding(dim=dim_head)

        for _ in range(depth):
            tran_modules = []
            tran_modules.append(
                Transformer(
                    depth=time_transformer_depth,
                    rotary_embed=time_rotary_embed,
                    **transformer_kwargs,
                )
            )
            tran_modules.append(
                Transformer(
                    depth=freq_transformer_depth,
                    rotary_embed=freq_rotary_embed,
                    **transformer_kwargs,
                )
            )
            self.layers.append(nn.ModuleList(tran_modules))

        self.final_norm = RMSNorm(dim)

        self.stft_kwargs = dict(
            n_fft=stft_n_fft,
            hop_length=stft_hop_length,
            win_length=stft_win_length,
            normalized=stft_normalized,
        )

        self.stft_window_fn = partial(
            default(stft_window_fn, torch.hann_window), stft_win_length
        )

        # Probe the STFT once to learn how many frequency bins it yields.
        freqs = torch.stft(
            torch.randn(1, 4096),
            **self.stft_kwargs,
            window=torch.ones(stft_win_length),
            return_complex=True,
        ).shape[1]

        assert len(freqs_per_bands) > 1
        assert (
            sum(freqs_per_bands) == freqs
        ), f"the number of freqs in the bands must equal {freqs} based on the STFT settings, but got {sum(freqs_per_bands)}"

        # Every bin carries real and imaginary parts for every audio channel.
        freqs_per_bands_with_complex = tuple(
            2 * f * self.audio_channels for f in freqs_per_bands
        )

        self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)

        self.mask_estimators = nn.ModuleList([])

        for _ in range(num_stems):
            mask_estimator = MaskEstimator(
                dim=dim,
                dim_inputs=freqs_per_bands_with_complex,
                depth=mask_estimator_depth,
                mlp_expansion_factor=mlp_expansion_factor,
            )

            self.mask_estimators.append(mask_estimator)

        self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
        self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
        self.multi_stft_n_fft = stft_n_fft
        self.multi_stft_window_fn = multi_stft_window_fn

        self.multi_stft_kwargs = dict(
            hop_length=multi_stft_hop_size, normalized=multi_stft_normalized
        )

    @staticmethod
    def _stft_op_with_mps_fallback(op, signal, window, **op_kwargs):
        """Run ``op`` (``torch.stft`` / ``torch.istft``), retrying on CPU if it fails on MPS.

        On any other device the original error is re-raised, because repeating
        the identical call could not succeed.
        """
        try:
            return op(signal, window=window, **op_kwargs)
        except Exception:
            if signal.device.type != "mps":
                raise
            result = op(signal.cpu(), window=window.cpu(), **op_kwargs)
            return result.to(signal.device)

    def _run_transformer(self, transformer, x):
        """Apply ``transformer`` to ``x``, under activation checkpointing when enabled."""
        if self.use_torch_checkpoint:
            return checkpoint(transformer, x, use_reentrant=False)
        return transformer(x)

    def forward(self, raw_audio, target=None, return_loss_breakdown=False):
        """Separate ``raw_audio`` or, when ``target`` is given, compute the training loss.

        Args:
            raw_audio: ``(b, t)`` or ``(b, channels, t)`` waveform; a 2-D
                input gets a singleton channel axis.
            target: optional reference audio; when given a loss is returned.
            return_loss_breakdown: also return ``(l1_loss, multi_stft_loss)``.

        Returns:
            Without ``target``: separated audio ``(b, s, t)`` for a single
            stem, ``(b, n, s, t)`` otherwise. With ``target``: the total loss,
            or ``(total_loss, (l1_loss, multi_stft_loss))``.
        """
        device = raw_audio.device

        if raw_audio.ndim == 2:
            raw_audio = rearrange(raw_audio, "b t -> b 1 t")

        channels = raw_audio.shape[1]
        assert (not self.stereo and channels == 1) or (
            self.stereo and channels == 2
        ), "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)"

        # Fold batch and channel axes together for the STFT.
        raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, "* t")

        stft_window = self.stft_window_fn(device=device)

        stft_repr = self._stft_op_with_mps_fallback(
            torch.stft,
            raw_audio,
            stft_window,
            **self.stft_kwargs,
            return_complex=True,
        )
        stft_repr = torch.view_as_real(stft_repr)

        stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, "* f t c")

        # Merge the audio channels into the frequency axis (frequency-major).
        stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c")

        x = rearrange(stft_repr, "b f t c -> b t (f c)")

        x = self.band_split(x)

        for time_transformer, freq_transformer in self.layers:
            # Attend along time, independently for every band.
            x = rearrange(x, "b t f d -> b f t d")
            x, ps = pack([x], "* t d")

            x = self._run_transformer(time_transformer, x)

            (x,) = unpack(x, ps, "* t d")

            # Attend along bands, independently for every time frame.
            x = rearrange(x, "b f t d -> b t f d")
            x, ps = pack([x], "* f d")

            x = self._run_transformer(freq_transformer, x)

            (x,) = unpack(x, ps, "* f d")

        x = self.final_norm(x)

        num_stems = len(self.mask_estimators)

        mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
        mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2)

        # Add a stem axis so the mixture broadcasts against every stem's mask.
        stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")

        stft_repr = torch.view_as_complex(stft_repr)
        mask = torch.view_as_complex(mask)

        stft_repr = stft_repr * mask

        stft_repr = rearrange(
            stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels
        )

        recon_audio = self._stft_op_with_mps_fallback(
            torch.istft,
            stft_repr,
            stft_window,
            **self.stft_kwargs,
            return_complex=False,
            length=raw_audio.shape[-1],
        )

        recon_audio = rearrange(
            recon_audio, "(b n s) t -> b n s t", s=self.audio_channels, n=num_stems
        )

        if num_stems == 1:
            recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")

        if not exists(target):
            return recon_audio

        if self.num_stems > 1:
            assert target.ndim == 4 and target.shape[1] == self.num_stems

        if target.ndim == 2:
            target = rearrange(target, "... t -> ... 1 t")

        # Trim the target to the reconstructed length.
        target = target[..., : recon_audio.shape[-1]]

        loss = F.l1_loss(recon_audio, target)

        multi_stft_resolution_loss = 0.0

        for window_size in self.multi_stft_resolutions_window_sizes:
            res_stft_kwargs = dict(
                n_fft=max(window_size, self.multi_stft_n_fft),
                win_length=window_size,
                return_complex=True,
                window=self.multi_stft_window_fn(window_size, device=device),
                **self.multi_stft_kwargs,
            )

            recon_Y = torch.stft(
                rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs
            )
            target_Y = torch.stft(
                rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs
            )

            multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(
                recon_Y, target_Y
            )

        weighted_multi_resolution_loss = (
            multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
        )

        total_loss = loss + weighted_multi_resolution_loss

        if not return_loss_breakdown:
            return total_loss

        return total_loss, (loss, multi_stft_resolution_loss)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/bs_roformer/bs_roformer_sw.py CHANGED
@@ -1,676 +1,657 @@
1
- from __future__ import annotations
2
-
3
- from functools import partial
4
-
5
- import torch
6
- import torch.nn.functional as F
7
- from beartype import beartype
8
- from beartype.typing import Callable
9
- from einops import pack, rearrange, unpack
10
- from einops.layers.torch import Rearrange
11
- from torch import nn
12
- from torch.nn import Module, ModuleList
13
- from torch.utils.checkpoint import checkpoint
14
-
15
- from .attend import Attend
16
-
17
- try:
18
- from .attend_sage import AttendSage
19
- except ImportError:
20
- pass
21
-
22
-
23
def l2norm(t):
    """Scale ``t`` to unit Euclidean length along its last dimension.

    Thin wrapper over ``F.normalize`` with ``p=2`` and ``dim=-1``.
    """
    return F.normalize(t, p=2, dim=-1)
25
-
26
-
27
class CustomNorm(Module):
    """Normalise the last dimension to L2 length ``sqrt(dim)`` with a learned gain.

    Args:
        dim: Size of the last dimension; also sets the fixed ``dim ** 0.5`` scale.
        eps: Lower bound applied to the L2 norm before dividing (the default
            equals ``2 ** -24``).
    """

    def __init__(self, dim, eps: float = 5.960464477539063e-08):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(dim))
        self.scale = dim**0.5

    def forward(self, x):
        """Return ``x / max(||x||_2, eps) * sqrt(dim) * gamma``."""
        norm = torch.linalg.norm(x, dim=-1, keepdim=True)
        # Floor the norm so an all-zero vector divides by eps instead of 0.
        floor = torch.full_like(norm, self.eps)
        unit = x / torch.maximum(norm, floor)
        return unit * self.scale * self.gamma
39
-
40
-
41
class RotaryEmbedding(nn.Module):
    """Apply a rotary position embedding from externally supplied cos/sin tables.

    Args:
        cos_emb: Cosine table; two leading singleton axes are added before it
            is broadcast against the input.
        sin_emb: Sine table, handled the same way as ``cos_emb``.
    """

    def __init__(self, cos_emb, sin_emb):
        super().__init__()
        self.cos_emb = cos_emb
        self.sin_emb = sin_emb

    def rotate_half(self, x):
        """Map each adjacent pair ``(a, b)`` on the last axis to ``(-b, a)``."""
        pairs = x.reshape(*x.shape[:-1], -1, 2)
        even, odd = pairs.unbind(dim=-1)
        swapped = torch.stack((-odd, even), dim=-1)
        return swapped.flatten(-2)

    def forward(self, x):
        """Return ``x * cos + rotate_half(x) * sin`` in ``x``'s dtype.

        Both products are formed in ``x``'s dtype; only their sum is taken in
        float32 before casting back.
        """
        cos = self.cos_emb[None, None].to(x.device, x.dtype)
        sin = self.sin_emb[None, None].to(x.device, x.dtype)

        direct = (x * cos).to(torch.float32)
        rotated = (self.rotate_half(x) * sin).to(torch.float32)
        return (direct + rotated).to(x.dtype)
62
-
63
-
64
class FeedForward(Module):
    """Pre-normalised MLP: norm -> Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: Input and output feature size.
        mult: Hidden width multiplier; the hidden size is ``int(dim * mult)``.
        dropout: Probability used by both dropout layers.
    """

    def __init__(self, dim, mult=4, dropout=0.0):
        super().__init__()
        hidden = int(dim * mult)
        stages = [
            CustomNorm(dim),
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the stack; the residual add is left to the caller."""
        return self.net(x)
79
-
80
-
81
class Attention(Module):
    """Gated multi-head self-attention with optional rotary embedding.

    Args:
        dim: Feature size of the input and output.
        heads: Number of attention heads.
        dim_head: Feature size per head.
        dropout: Dropout probability for the attention op and output projection.
        shared_qkv_bias: Optional parameter installed as the bias of the fused
            q/k/v projection. When ``None`` that projection has no bias.
        shared_out_bias: Optional parameter installed as the bias of the
            output projection. When ``None`` that projection has no bias.
        rotary_embed: Optional module applied to queries and keys.
        flash: Forwarded to the attention backend.
        sage_attention: Use ``AttendSage`` instead of ``Attend``.

    Raises:
        ImportError: If ``sage_attention`` is true but the optional
            ``attend_sage`` module could not be imported.
    """

    def __init__(
        self,
        dim,
        heads=8,
        dim_head=64,
        dropout=0.0,
        shared_qkv_bias=None,
        shared_out_bias=None,
        rotary_embed: RotaryEmbedding | None = None,
        flash=True,
        sage_attention=False,
    ):
        super().__init__()
        self.heads = heads
        # Stored for reference only; forward() never reads it.
        self.scale = dim_head**-0.5
        dim_inner = heads * dim_head

        self.rotary_embed = rotary_embed

        if sage_attention:
            # The module-level import of AttendSage is wrapped in
            # ``try/except ImportError: pass``, so the name may not exist.
            # Turn the resulting NameError into an actionable error.
            try:
                attend_cls = AttendSage
            except NameError as err:
                raise ImportError(
                    "sage_attention=True requires the optional "
                    "'attend_sage' module, which could not be imported"
                ) from err
        else:
            attend_cls = Attend
        self.attend = attend_cls(flash=flash, dropout=dropout)  # type: ignore

        self.norm = CustomNorm(dim)
        self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=(shared_qkv_bias is not None))
        if shared_qkv_bias is not None:
            # Replace the freshly created bias with the caller-owned parameter.
            self.to_qkv.bias = shared_qkv_bias

        self.to_gates = nn.Linear(dim, heads)

        self.to_out = nn.Sequential(
            nn.Linear(dim_inner, dim, bias=(shared_out_bias is not None)),
            nn.Dropout(dropout),
        )
        if shared_out_bias is not None:
            self.to_out[0].bias = shared_out_bias

    def forward(self, x):
        """Attend over axis ``n`` of a ``(b, n, dim)`` input; output has the same layout."""
        x = self.norm(x)

        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads)

        if self.rotary_embed is not None:
            q = self.rotary_embed(q)
            k = self.rotary_embed(k)

        out = self.attend(q, k, v)

        # One sigmoid gate per head and position, computed from the normed input.
        gate_act = self.to_gates(x).sigmoid()
        out = out * rearrange(gate_act, "b n h -> b h n 1")

        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)
140
-
141
-
142
class LinearAttention(Module):
    """Attention over the feature axis with L2-normalised queries and keys.

    q/k/v are laid out as ``(b, h, d, n)``, so the attention op treats ``d``
    as the sequence axis.

    Args:
        dim: Feature size of the input and output.
        dim_head: Feature size per head.
        heads: Number of attention heads.
        scale: Forwarded to the attention backend.
        flash: Forwarded to the attention backend.
        dropout: Forwarded to the attention backend.
        sage_attention: Use ``AttendSage`` instead of ``Attend``.

    Raises:
        ImportError: If ``sage_attention`` is true but the optional
            ``attend_sage`` module could not be imported.
    """

    @beartype
    def __init__(
        self,
        *,
        dim,
        dim_head=32,
        heads=8,
        scale=8,
        flash=True,
        dropout=0.0,
        sage_attention=False,
    ):
        super().__init__()
        dim_inner = dim_head * heads
        self.norm = CustomNorm(dim)

        self.to_qkv = nn.Sequential(
            nn.Linear(dim, dim_inner * 3, bias=False),
            Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads),
        )

        # Per-head log-temperature; forward() multiplies q by its exp().
        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

        if sage_attention:
            # The module-level import of AttendSage is wrapped in
            # ``try/except ImportError: pass``, so the name may not exist.
            # Turn the resulting NameError into an actionable error.
            try:
                attend_cls = AttendSage
            except NameError as err:
                raise ImportError(
                    "sage_attention=True requires the optional "
                    "'attend_sage' module, which could not be imported"
                ) from err
        else:
            attend_cls = Attend
        self.attend = attend_cls(scale=scale, dropout=dropout, flash=flash)  # type: ignore

        self.to_out = nn.Sequential(
            Rearrange("b h d n -> b n (h d)"), nn.Linear(dim_inner, dim, bias=False)
        )

    def forward(self, x):
        """Apply the attention to a ``(b, n, dim)`` input; output has the same layout."""
        x = self.norm(x)

        q, k, v = self.to_qkv(x)

        q, k = map(l2norm, (q, k))
        q = q * self.temperature.exp()

        out = self.attend(q, k, v)

        return self.to_out(out)
187
-
188
-
189
class Transformer(Module):
    """Stack of ``depth`` residual (attention, feed-forward) layers.

    Args:
        dim: Feature size.
        depth: Number of layers.
        dim_head: Per-head feature size.
        heads: Number of attention heads.
        attn_dropout: Dropout passed to each attention module.
        ff_dropout: Dropout passed to each feed-forward module.
        ff_mult: Hidden width multiplier of each feed-forward module.
        norm_output: Apply a final ``CustomNorm`` when true, identity otherwise.
        rotary_embed: Rotary module for ``Attention`` (unused by ``LinearAttention``).
        flash_attn: Passed to each attention module as ``flash``.
        linear_attn: Build ``LinearAttention`` layers instead of ``Attention``.
        sage_attention: Passed through to each attention module.
        shared_qkv_bias: Passed to ``Attention`` only.
        shared_out_bias: Passed to ``Attention`` only.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head=64,
        heads=8,
        attn_dropout=0.0,
        ff_dropout=0.0,
        ff_mult=4,
        norm_output=True,
        rotary_embed: RotaryEmbedding | None = None,
        flash_attn=True,
        linear_attn=False,
        sage_attention=False,
        shared_qkv_bias=None,
        shared_out_bias=None,
    ):
        super().__init__()

        def make_attention():
            # Arguments common to both attention flavours.
            common = dict(
                dim=dim,
                dim_head=dim_head,
                heads=heads,
                dropout=attn_dropout,
                flash=flash_attn,
                sage_attention=sage_attention,
            )
            if linear_attn:
                return LinearAttention(**common)
            return Attention(
                shared_qkv_bias=shared_qkv_bias,
                shared_out_bias=shared_out_bias,
                rotary_embed=rotary_embed,
                **common,
            )

        self.layers = ModuleList(
            [
                ModuleList(
                    [
                        make_attention(),
                        FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout),
                    ]
                )
                for _ in range(depth)
            ]
        )

        self.norm = CustomNorm(dim) if norm_output else nn.Identity()

    def forward(self, x):
        """Run every layer with residual connections, then the output norm."""
        for attention, feed_forward in self.layers:  # type: ignore
            x = x + attention(x)
            x = x + feed_forward(x)
        return self.norm(x)
248
-
249
-
250
class BandSplit(Module):
    """Project each band of the last axis to a shared feature width.

    Args:
        dim: Output feature size of every band projection.
        dim_inputs: Width of each band; the input's last axis is split by these.
    """

    @beartype
    def __init__(self, dim, dim_inputs: tuple[int, ...]):
        super().__init__()
        self.dim_inputs = dim_inputs
        # One norm + linear projection per band, in band order.
        self.to_features = ModuleList(
            [
                nn.Sequential(CustomNorm(width), nn.Linear(width, dim))
                for width in dim_inputs
            ]
        )

    def forward(self, x):
        """Split the last axis into bands and stack their projections on axis -2."""
        bands = x.split(self.dim_inputs, dim=-1)
        projected = [
            project(band) for band, project in zip(bands, self.to_features)
        ]
        return torch.stack(projected, dim=-2)
271
-
272
-
273
def MLP(
    dim_in: int,
    dim_out: int,
    dim_hidden: int | None = None,
    depth: int = 1,
    activation=nn.Tanh,
):
    """Build an ``nn.Sequential`` of ``depth`` linear layers with activations between them.

    Args:
        dim_in: Input feature size.
        dim_out: Output feature size.
        dim_hidden: Width of the intermediate layers; a falsy value means ``dim_in``.
        depth: Number of linear layers.
        activation: Zero-argument factory for the activation module.

    Returns:
        The assembled ``nn.Sequential``; no activation follows the last layer.
    """
    dim_hidden = dim_hidden or dim_in
    widths = [dim_in] + [dim_hidden] * (depth - 1) + [dim_out]

    layers = []
    for fan_in, fan_out in zip(widths, widths[1:]):
        # An activation separates consecutive linears, so add one before every
        # linear except the first.
        if layers:
            layers.append(activation())
        layers.append(nn.Linear(fan_in, fan_out))

    return nn.Sequential(*layers)
296
-
297
-
298
class MaskEstimator(Module):
    """Map per-band features back to per-band outputs and concatenate them.

    Args:
        dim: Feature size of each incoming band vector.
        dim_inputs: Output width of each band after the GLU.
        depth: Number of linear layers in each band's MLP.
        mlp_expansion_factor: Hidden width of each MLP as a multiple of ``dim``.
    """

    @beartype
    def __init__(self, dim, dim_inputs: tuple[int, ...], depth, mlp_expansion_factor=4):
        super().__init__()
        self.dim_inputs = dim_inputs
        hidden = dim * mlp_expansion_factor
        # Each MLP emits twice the band width; the GLU halves it again.
        self.to_freqs = ModuleList(
            [
                nn.Sequential(
                    MLP(dim, width * 2, dim_hidden=hidden, depth=depth),
                    nn.GLU(dim=-1),
                )
                for width in dim_inputs
            ]
        )

    def forward(self, x):
        """Unbind bands on axis -2, run each through its MLP, concatenate on the last axis."""
        bands = x.unbind(dim=-2)
        return torch.cat(
            [head(band) for band, head in zip(bands, self.to_freqs)], dim=-1
        )
323
-
324
-
325
# Default band layout as (bins per band) * (number of bands): 62 bands that
# together cover 1025 frequency bins.
DEFAULT_FREQS_PER_BANDS = (
    (2,) * 24
    + (4,) * 12
    + (12,) * 8
    + (24,) * 8
    + (48,) * 8
    + (128, 129)
)
335
-
336
-
337
class BSRoformer_SW(Module):
    """Band-split transformer that masks an STFT to produce ``num_stems`` outputs.

    The waveform is transformed with ``torch.stft``, split into frequency bands,
    processed by ``depth`` blocks of (optional linear, time, frequency)
    transformers, and turned into one complex mask per stem. The masked STFT is
    inverted with ``torch.istft``.

    Args:
        dim: Model feature size.
        depth: Number of transformer blocks.
        stereo: Expect two input channels when true, one otherwise.
        num_stems: Number of mask estimators (output stems).
        time_transformer_depth: Layers in each time-axis transformer.
        freq_transformer_depth: Layers in each band-axis transformer.
        linear_transformer_depth: Layers in each linear-attention transformer;
            ``0`` leaves that transformer out of every block.
        freqs_per_bands: Number of STFT bins in each band.
        dim_head, heads, attn_dropout, ff_dropout, flash_attn: Forwarded to
            every ``Transformer``.
        stft_n_fft, stft_hop_length, stft_win_length, stft_normalized:
            Arguments for the main STFT / inverse STFT.
        stft_window_fn: Window factory; ``torch.hann_window`` when ``None``.
        mask_estimator_depth: MLP depth inside each ``MaskEstimator``.
        multi_stft_resolution_loss_weight: Weight of the multi-resolution term.
        multi_stft_resolutions_window_sizes: Window sizes for that term.
        multi_stft_hop_size, multi_stft_normalized, multi_stft_window_fn:
            STFT arguments for that term.
        mlp_expansion_factor: Hidden width multiplier of the mask MLPs.
        use_torch_checkpoint: Run sub-modules through ``checkpoint``.
        skip_connection: Add the outputs of all earlier blocks to each block's input.
        sage_attention: Forwarded to every ``Transformer``.
        use_shared_bias: Create one q/k/v bias and one output bias shared by
            all ``Attention`` modules; without it those projections have no bias.
        chunk_size: Length in samples used to size the time rotary tables
            (``chunk_size // stft_hop_length + 1`` rows).
    """

    @beartype
    def __init__(
        self,
        dim,
        *,
        depth,
        stereo=False,
        num_stems=1,
        time_transformer_depth=2,
        freq_transformer_depth=2,
        linear_transformer_depth=0,
        freqs_per_bands: tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
        dim_head=64,
        heads=8,
        attn_dropout=0.0,
        ff_dropout=0.0,
        flash_attn=True,
        stft_n_fft=2048,
        stft_hop_length=512,
        stft_win_length=2048,
        stft_normalized=False,
        stft_window_fn: Callable | None = None,
        mask_estimator_depth=2,
        multi_stft_resolution_loss_weight=1.0,
        multi_stft_resolutions_window_sizes: tuple[int, ...] = (
            4096,
            2048,
            1024,
            512,
            256,
        ),
        multi_stft_hop_size=147,
        multi_stft_normalized=False,
        multi_stft_window_fn: Callable = torch.hann_window,
        mlp_expansion_factor=4,
        use_torch_checkpoint=False,
        skip_connection=False,
        sage_attention=False,
        use_shared_bias=False,
        chunk_size: int = 588800,
    ):
        super().__init__()

        self.stereo = stereo
        self.audio_channels = 2 if stereo else 1
        self.num_stems = num_stems
        self.use_torch_checkpoint = use_torch_checkpoint
        self.skip_connection = skip_connection

        self.layers = ModuleList([])

        if sage_attention:
            print("Use Sage Attention")

        # Both attributes are read below on every path, so they must exist even
        # when no shared bias is requested; previously the default
        # ``use_shared_bias=False`` raised AttributeError here.
        if use_shared_bias:
            dim_inner = heads * dim_head
            self.shared_qkv_bias = nn.Parameter(torch.ones(dim_inner * 3))
            self.shared_out_bias = nn.Parameter(torch.ones(dim))
        else:
            self.shared_qkv_bias = None
            self.shared_out_bias = None

        transformer_kwargs = dict(
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout,
            flash_attn=flash_attn,
            norm_output=False,
            sage_attention=sage_attention,
            shared_qkv_bias=self.shared_qkv_bias,
            shared_out_bias=self.shared_out_bias,
        )

        # Rotary tables are zero-initialised parameters (presumably filled in
        # from a checkpoint or by training -- confirm). The time table has one
        # row per STFT frame of a ``chunk_size``-sample input.
        t_frames = chunk_size // stft_hop_length + 1
        self.cos_emb_time = nn.Parameter(torch.zeros(t_frames, dim_head))
        self.sin_emb_time = nn.Parameter(torch.zeros(t_frames, dim_head))
        time_rotary_embed = RotaryEmbedding(
            cos_emb=self.cos_emb_time, sin_emb=self.sin_emb_time
        )

        # The frequency table has one row per band.
        num_bands = len(freqs_per_bands)
        self.cos_emb_freq = nn.Parameter(torch.zeros(num_bands, dim_head))
        self.sin_emb_freq = nn.Parameter(torch.zeros(num_bands, dim_head))
        freq_rotary_embed = RotaryEmbedding(
            cos_emb=self.cos_emb_freq, sin_emb=self.sin_emb_freq
        )

        for _ in range(depth):
            tran_modules = []
            if linear_transformer_depth > 0:
                tran_modules.append(
                    Transformer(
                        depth=linear_transformer_depth,
                        linear_attn=True,
                        **transformer_kwargs,
                    )
                )
            tran_modules.append(
                Transformer(
                    depth=time_transformer_depth,
                    rotary_embed=time_rotary_embed,
                    **transformer_kwargs,
                )
            )
            tran_modules.append(
                Transformer(
                    depth=freq_transformer_depth,
                    rotary_embed=freq_rotary_embed,
                    **transformer_kwargs,
                )
            )
            self.layers.append(nn.ModuleList(tran_modules))

        self.final_norm = CustomNorm(dim)

        self.stft_kwargs = dict(
            n_fft=stft_n_fft,
            hop_length=stft_hop_length,
            win_length=stft_win_length,
            normalized=stft_normalized,
        )

        self.stft_window_fn = partial(
            stft_window_fn or torch.hann_window, stft_win_length
        )

        # Each bin contributes a real and an imaginary value per audio channel.
        freqs_per_bands_with_complex = tuple(
            2 * f * self.audio_channels for f in freqs_per_bands
        )

        self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)

        self.mask_estimators = nn.ModuleList([])

        for _ in range(num_stems):
            mask_estimator = MaskEstimator(
                dim=dim,
                dim_inputs=freqs_per_bands_with_complex,
                depth=mask_estimator_depth,
                mlp_expansion_factor=mlp_expansion_factor,
            )

            self.mask_estimators.append(mask_estimator)

        self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
        self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
        self.multi_stft_n_fft = stft_n_fft
        self.multi_stft_window_fn = multi_stft_window_fn

        self.multi_stft_kwargs = dict(
            hop_length=multi_stft_hop_size, normalized=multi_stft_normalized
        )

    def _maybe_checkpoint(self, fn, x):
        """Call ``fn(x)``, through ``checkpoint`` when ``use_torch_checkpoint`` is set."""
        if self.use_torch_checkpoint:
            return checkpoint(fn, x, use_reentrant=False)
        return fn(x)

    def forward(self, raw_audio, target=None, return_loss_breakdown=False):
        """Separate ``raw_audio`` and, when ``target`` is given, return the loss.

        Args:
            raw_audio: ``(b, t)`` or ``(b, channels, t)`` waveform; the channel
                count must match the ``stereo`` setting.
            target: Optional reference audio. With more than one stem it must
                be 4-D with ``num_stems`` on axis 1.
            return_loss_breakdown: Also return the individual loss terms.

        Returns:
            Without ``target``: the reconstructed audio, ``(b, s, t)`` for a
            single stem and ``(b, n, s, t)`` otherwise. With ``target``: the
            total loss, or ``(total_loss, (l1_loss, multi_stft_loss))`` when
            ``return_loss_breakdown`` is true.

        Raises:
            RuntimeError: If NaN/Inf values appear after the STFT or the band split.
        """
        device = raw_audio.device

        x_is_mps = device.type == "mps"

        if raw_audio.ndim == 2:
            raw_audio = rearrange(raw_audio, "b t -> b 1 t")

        channels = raw_audio.shape[1]
        assert (not self.stereo and channels == 1) or (
            self.stereo and channels == 2
        ), "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)"

        # Fold batch and channel together so the STFT sees 2-D input.
        raw_audio, batch_audio_channel_packed_shape = pack([raw_audio], "* t")

        stft_window = self.stft_window_fn(device=device)

        try:
            stft_repr = torch.stft(
                raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True
            )
        except Exception:
            # Retry; on MPS the retry runs on CPU and the result is moved back.
            stft_repr = torch.stft(
                raw_audio.cpu() if x_is_mps else raw_audio,
                **self.stft_kwargs,
                window=stft_window.cpu() if x_is_mps else stft_window,
                return_complex=True,
            ).to(device)
        stft_repr = torch.view_as_real(stft_repr)

        stft_repr = unpack(stft_repr, batch_audio_channel_packed_shape, "* f t c")[0]

        # Interleave channels into the frequency axis.
        stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c")

        x = rearrange(stft_repr, "b f t c -> b t (f c)")

        if torch.isnan(x).any() or torch.isinf(x).any():
            raise RuntimeError(
                f"NaN/Inf in x after stft: {x.isnan().sum()} NaNs, {x.isinf().sum()} Infs"
            )

        x = self._maybe_checkpoint(self.band_split, x)

        if torch.isnan(x).any() or torch.isinf(x).any():
            raise RuntimeError(
                f"NaN/Inf in x after band_split: {x.isnan().sum()} NaNs, {x.isinf().sum()} Infs"
            )

        store = [None] * len(self.layers)
        for i, transformer_block in enumerate(self.layers):
            if len(transformer_block) == 3:
                linear_transformer, time_transformer, freq_transformer = (
                    transformer_block
                )

                # Linear attention runs over time and bands flattened together.
                x, ft_ps = pack([x], "b * d")
                x = self._maybe_checkpoint(linear_transformer, x)
                (x,) = unpack(x, ft_ps, "b * d")
            else:
                time_transformer, freq_transformer = transformer_block

            if self.skip_connection:
                # Add the output of every earlier block.
                for j in range(i):
                    x = x + store[j]

            # Attend along time, independently for each band.
            x = rearrange(x, "b t f d -> b f t d")
            x, ps = pack([x], "* t d")

            x = self._maybe_checkpoint(time_transformer, x)

            (x,) = unpack(x, ps, "* t d")

            # Attend along bands, independently for each frame.
            x = rearrange(x, "b f t d -> b t f d")
            x, ps = pack([x], "* f d")

            x = self._maybe_checkpoint(freq_transformer, x)

            (x,) = unpack(x, ps, "* f d")

            if self.skip_connection:
                store[i] = x

        x = self.final_norm(x)

        num_stems = len(self.mask_estimators)

        mask = torch.stack(
            [self._maybe_checkpoint(fn, x) for fn in self.mask_estimators], dim=1
        )
        mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2)

        # Add a stem axis so the STFT broadcasts against every mask.
        stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")

        stft_repr = torch.view_as_complex(stft_repr)
        mask = torch.view_as_complex(mask)

        # Complex multiplication applies the mask.
        stft_repr = stft_repr * mask

        stft_repr = rearrange(
            stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels
        )

        try:
            recon_audio = torch.istft(
                stft_repr,
                **self.stft_kwargs,
                window=stft_window,
                return_complex=False,
                length=raw_audio.shape[-1],
            )
        except Exception:
            # Same retry policy as the forward STFT.
            recon_audio = torch.istft(
                stft_repr.cpu() if x_is_mps else stft_repr,
                **self.stft_kwargs,
                window=stft_window.cpu() if x_is_mps else stft_window,
                return_complex=False,
                length=raw_audio.shape[-1],
            ).to(device)

        recon_audio = rearrange(
            recon_audio, "(b n s) t -> b n s t", s=self.audio_channels, n=num_stems
        )

        if num_stems == 1:
            recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")

        if target is None:
            return recon_audio

        if self.num_stems > 1:
            assert target.ndim == 4 and target.shape[1] == self.num_stems

        if target.ndim == 2:
            target = rearrange(target, "... t -> ... 1 t")

        # Trim the target to the reconstructed length.
        target = target[..., : recon_audio.shape[-1]]

        loss = F.l1_loss(recon_audio, target)

        multi_stft_resolution_loss = 0.0

        for window_size in self.multi_stft_resolutions_window_sizes:
            res_stft_kwargs = dict(
                n_fft=max(window_size, self.multi_stft_n_fft),
                win_length=window_size,
                return_complex=True,
                window=self.multi_stft_window_fn(window_size, device=device),
                **self.multi_stft_kwargs,
            )

            recon_Y = torch.stft(
                rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs
            )
            target_Y = torch.stft(
                rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs
            )

            multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(
                recon_Y, target_Y
            )

        weighted_multi_resolution_loss = (
            multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
        )

        total_loss = loss + weighted_multi_resolution_loss

        if not return_loss_breakdown:
            return total_loss

        return total_loss, (loss, multi_stft_resolution_loss)
 
1
+ from __future__ import annotations
2
+
3
+ from functools import partial
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from beartype import beartype
8
+ from beartype.typing import Callable
9
+ from einops import pack, rearrange, unpack
10
+ from einops.layers.torch import Rearrange
11
+ from torch import nn
12
+ from torch.nn import Module, ModuleList
13
+ from torch.utils.checkpoint import checkpoint
14
+
15
+ from .attend import Attend
16
+
17
+
18
def l2norm(t):
    """Scale ``t`` to unit Euclidean length along its last dimension.

    Thin wrapper over ``F.normalize`` with ``p=2`` and ``dim=-1``.
    """
    return F.normalize(t, p=2, dim=-1)
20
+
21
+
22
class CustomNorm(Module):
    """Normalise the last dimension to L2 length ``sqrt(dim)`` with a learned gain.

    Args:
        dim: Size of the last dimension; also sets the fixed ``dim ** 0.5`` scale.
        eps: Lower bound applied to the L2 norm before dividing (the default
            equals ``2 ** -24``).
    """

    def __init__(self, dim, eps: float = 5.960464477539063e-08):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(dim))
        self.scale = dim**0.5

    def forward(self, x):
        """Return ``x / max(||x||_2, eps) * sqrt(dim) * gamma``."""
        norm = torch.linalg.norm(x, dim=-1, keepdim=True)
        # Floor the norm so an all-zero vector divides by eps instead of 0.
        floor = torch.full_like(norm, self.eps)
        unit = x / torch.maximum(norm, floor)
        return unit * self.scale * self.gamma
34
+
35
+
36
class RotaryEmbedding(nn.Module):
    """Apply a rotary position embedding from externally supplied cos/sin tables.

    Args:
        cos_emb: Cosine table; two leading singleton axes are added before it
            is broadcast against the input.
        sin_emb: Sine table, handled the same way as ``cos_emb``.
    """

    def __init__(self, cos_emb, sin_emb):
        super().__init__()
        self.cos_emb = cos_emb
        self.sin_emb = sin_emb

    def rotate_half(self, x):
        """Map each adjacent pair ``(a, b)`` on the last axis to ``(-b, a)``."""
        pairs = x.reshape(*x.shape[:-1], -1, 2)
        even, odd = pairs.unbind(dim=-1)
        swapped = torch.stack((-odd, even), dim=-1)
        return swapped.flatten(-2)

    def forward(self, x):
        """Return ``x * cos + rotate_half(x) * sin`` in ``x``'s dtype.

        Both products are formed in ``x``'s dtype; only their sum is taken in
        float32 before casting back.
        """
        cos = self.cos_emb[None, None].to(x.device, x.dtype)
        sin = self.sin_emb[None, None].to(x.device, x.dtype)

        direct = (x * cos).to(torch.float32)
        rotated = (self.rotate_half(x) * sin).to(torch.float32)
        return (direct + rotated).to(x.dtype)
57
+
58
+
59
class FeedForward(Module):
    """Pre-normalised MLP: norm -> Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: Input and output feature size.
        mult: Hidden width multiplier; the hidden size is ``int(dim * mult)``.
        dropout: Probability used by both dropout layers.
    """

    def __init__(self, dim, mult=4, dropout=0.0):
        super().__init__()
        hidden = int(dim * mult)
        stages = [
            CustomNorm(dim),
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the stack; the residual add is left to the caller."""
        return self.net(x)
74
+
75
+
76
class Attention(Module):
    """Gated multi-head self-attention with optional rotary embedding.

    Args:
        dim: Feature size of the input and output.
        heads: Number of attention heads.
        dim_head: Feature size per head.
        dropout: Dropout probability for the attention op and output projection.
        shared_qkv_bias: Optional parameter installed as the bias of the fused
            q/k/v projection. When ``None`` that projection has no bias.
        shared_out_bias: Optional parameter installed as the bias of the
            output projection. When ``None`` that projection has no bias.
        rotary_embed: Optional module applied to queries and keys.
        flash: Forwarded to ``Attend``.
    """

    def __init__(
        self,
        dim,
        heads=8,
        dim_head=64,
        dropout=0.0,
        shared_qkv_bias=None,
        shared_out_bias=None,
        rotary_embed: RotaryEmbedding | None = None,
        flash=True,
    ):
        super().__init__()
        inner_width = heads * dim_head
        has_qkv_bias = shared_qkv_bias is not None
        has_out_bias = shared_out_bias is not None

        self.heads = heads
        # Stored for reference only; forward() never reads it.
        self.scale = dim_head**-0.5

        self.rotary_embed = rotary_embed
        self.attend = Attend(flash=flash, dropout=dropout)  # type: ignore
        self.norm = CustomNorm(dim)

        self.to_qkv = nn.Linear(dim, inner_width * 3, bias=has_qkv_bias)
        if has_qkv_bias:
            # Replace the freshly created bias with the caller-owned parameter.
            self.to_qkv.bias = shared_qkv_bias

        self.to_gates = nn.Linear(dim, heads)

        out_proj = nn.Linear(inner_width, dim, bias=has_out_bias)
        self.to_out = nn.Sequential(out_proj, nn.Dropout(dropout))
        if has_out_bias:
            self.to_out[0].bias = shared_out_bias

    def forward(self, x):
        """Attend over axis ``n`` of a ``(b, n, dim)`` input; output has the same layout."""
        normed = self.norm(x)

        q, k, v = rearrange(
            self.to_qkv(normed), "b n (qkv h d) -> qkv b h n d", qkv=3, h=self.heads
        )

        if self.rotary_embed is not None:
            q = self.rotary_embed(q)
            k = self.rotary_embed(k)

        attended = self.attend(q, k, v)

        # One sigmoid gate per head and position, computed from the normed input.
        gate = self.to_gates(normed).sigmoid()
        gated = attended * rearrange(gate, "b n h -> b h n 1")

        merged = rearrange(gated, "b h n d -> b n (h d)")
        return self.to_out(merged)
131
+
132
+
133
class LinearAttention(Module):
    """Attention over the feature axis with L2-normalised queries and keys.

    q/k/v are laid out as ``(b, h, d, n)``, so the attention op treats ``d``
    as the sequence axis.

    Args:
        dim: Feature size of the input and output.
        dim_head: Feature size per head.
        heads: Number of attention heads.
        scale: Forwarded to ``Attend``.
        flash: Forwarded to ``Attend``.
        dropout: Forwarded to ``Attend``.
    """

    @beartype
    def __init__(
        self,
        *,
        dim,
        dim_head=32,
        heads=8,
        scale=8,
        flash=True,
        dropout=0.0,
    ):
        super().__init__()
        inner_width = dim_head * heads
        self.norm = CustomNorm(dim)

        qkv_proj = nn.Linear(dim, inner_width * 3, bias=False)
        split_heads = Rearrange("b n (qkv h d) -> qkv b h d n", qkv=3, h=heads)
        self.to_qkv = nn.Sequential(qkv_proj, split_heads)

        # Per-head log-temperature; forward() multiplies q by its exp().
        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

        self.attend = Attend(scale=scale, dropout=dropout, flash=flash)

        merge_heads = Rearrange("b h d n -> b n (h d)")
        out_proj = nn.Linear(inner_width, dim, bias=False)
        self.to_out = nn.Sequential(merge_heads, out_proj)

    def forward(self, x):
        """Apply the attention to a ``(b, n, dim)`` input; output has the same layout."""
        q, k, v = self.to_qkv(self.norm(x))

        q = l2norm(q)
        k = l2norm(k)
        q = q * self.temperature.exp()

        return self.to_out(self.attend(q, k, v))
174
+
175
+
176
class Transformer(Module):
    """Stack of ``depth`` residual (attention, feed-forward) layers.

    Args:
        dim: Feature size.
        depth: Number of layers.
        dim_head: Per-head feature size.
        heads: Number of attention heads.
        attn_dropout: Dropout passed to each attention module.
        ff_dropout: Dropout passed to each feed-forward module.
        ff_mult: Hidden width multiplier of each feed-forward module.
        norm_output: Apply a final ``CustomNorm`` when true, identity otherwise.
        rotary_embed: Rotary module for ``Attention`` (unused by ``LinearAttention``).
        flash_attn: Passed to each attention module as ``flash``.
        linear_attn: Build ``LinearAttention`` layers instead of ``Attention``.
        shared_qkv_bias: Passed to ``Attention`` only.
        shared_out_bias: Passed to ``Attention`` only.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        dim_head=64,
        heads=8,
        attn_dropout=0.0,
        ff_dropout=0.0,
        ff_mult=4,
        norm_output=True,
        rotary_embed: RotaryEmbedding | None = None,
        flash_attn=True,
        linear_attn=False,
        shared_qkv_bias=None,
        shared_out_bias=None,
        **kwargs
    ):
        super().__init__()

        def make_attention():
            # Arguments common to both attention flavours.
            common = dict(
                dim=dim,
                dim_head=dim_head,
                heads=heads,
                dropout=attn_dropout,
                flash=flash_attn,
            )
            if linear_attn:
                return LinearAttention(**common)
            return Attention(
                shared_qkv_bias=shared_qkv_bias,
                shared_out_bias=shared_out_bias,
                rotary_embed=rotary_embed,
                **common,
            )

        self.layers = ModuleList(
            [
                ModuleList(
                    [
                        make_attention(),
                        FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout),
                    ]
                )
                for _ in range(depth)
            ]
        )

        self.norm = CustomNorm(dim) if norm_output else nn.Identity()

    def forward(self, x):
        """Run every layer with residual connections, then the output norm."""
        for attention, feed_forward in self.layers:  # type: ignore
            x = x + attention(x)
            x = x + feed_forward(x)
        return self.norm(x)
233
+
234
+
235
class BandSplit(Module):
    """Project each band of the last axis to a shared feature width.

    Args:
        dim: Output feature size of every band projection.
        dim_inputs: Width of each band; the input's last axis is split by these.
    """

    @beartype
    def __init__(self, dim, dim_inputs: tuple[int, ...]):
        super().__init__()
        self.dim_inputs = dim_inputs
        # One norm + linear projection per band, in band order.
        self.to_features = ModuleList(
            [
                nn.Sequential(CustomNorm(width), nn.Linear(width, dim))
                for width in dim_inputs
            ]
        )

    def forward(self, x):
        """Split the last axis into bands and stack their projections on axis -2."""
        bands = x.split(self.dim_inputs, dim=-1)
        projected = [
            project(band) for band, project in zip(bands, self.to_features)
        ]
        return torch.stack(projected, dim=-2)
256
+
257
+
258
def MLP(
    dim_in: int,
    dim_out: int,
    dim_hidden: int | None = None,
    depth: int = 1,
    activation=nn.Tanh,
):
    """Build an ``nn.Sequential`` of ``depth`` linear layers with activations between them.

    Args:
        dim_in: Input feature size.
        dim_out: Output feature size.
        dim_hidden: Width of the intermediate layers; a falsy value means ``dim_in``.
        depth: Number of linear layers.
        activation: Zero-argument factory for the activation module.

    Returns:
        The assembled ``nn.Sequential``; no activation follows the last layer.
    """
    dim_hidden = dim_hidden or dim_in
    widths = [dim_in] + [dim_hidden] * (depth - 1) + [dim_out]

    layers = []
    for fan_in, fan_out in zip(widths, widths[1:]):
        # An activation separates consecutive linears, so add one before every
        # linear except the first.
        if layers:
            layers.append(activation())
        layers.append(nn.Linear(fan_in, fan_out))

    return nn.Sequential(*layers)
281
+
282
+
283
class MaskEstimator(Module):
    """Map per-band features back to per-band outputs and concatenate them.

    Args:
        dim: Feature size of each incoming band vector.
        dim_inputs: Output width of each band after the GLU.
        depth: Number of linear layers in each band's MLP.
        mlp_expansion_factor: Hidden width of each MLP as a multiple of ``dim``.
    """

    @beartype
    def __init__(self, dim, dim_inputs: tuple[int, ...], depth, mlp_expansion_factor=4):
        super().__init__()
        self.dim_inputs = dim_inputs
        hidden = dim * mlp_expansion_factor
        # Each MLP emits twice the band width; the GLU halves it again.
        self.to_freqs = ModuleList(
            [
                nn.Sequential(
                    MLP(dim, width * 2, dim_hidden=hidden, depth=depth),
                    nn.GLU(dim=-1),
                )
                for width in dim_inputs
            ]
        )

    def forward(self, x):
        """Unbind bands on axis -2, run each through its MLP, concatenate on the last axis."""
        bands = x.unbind(dim=-2)
        return torch.cat(
            [head(band) for band, head in zip(bands, self.to_freqs)], dim=-1
        )
308
+
309
+
310
# Default band layout as (bins per band) * (number of bands): 62 bands that
# together cover 1025 frequency bins.
DEFAULT_FREQS_PER_BANDS = (
    (2,) * 24
    + (4,) * 12
    + (12,) * 8
    + (24,) * 8
    + (48,) * 8
    + (128, 129)
)
320
+
321
+
322
+ class BSRoformer_SW(Module):
323
+ @beartype
324
+ def __init__(
325
+ self,
326
+ dim,
327
+ *,
328
+ depth,
329
+ stereo=False,
330
+ num_stems=1,
331
+ time_transformer_depth=2,
332
+ freq_transformer_depth=2,
333
+ linear_transformer_depth=0,
334
+ freqs_per_bands: tuple[int, ...] = DEFAULT_FREQS_PER_BANDS,
335
+ dim_head=64,
336
+ heads=8,
337
+ attn_dropout=0.0,
338
+ ff_dropout=0.0,
339
+ flash_attn=True,
340
+ stft_n_fft=2048,
341
+ stft_hop_length=512,
342
+ stft_win_length=2048,
343
+ stft_normalized=False,
344
+ stft_window_fn: Callable | None = None,
345
+ mask_estimator_depth=2,
346
+ multi_stft_resolution_loss_weight=1.0,
347
+ multi_stft_resolutions_window_sizes: tuple[int, ...] = (
348
+ 4096,
349
+ 2048,
350
+ 1024,
351
+ 512,
352
+ 256,
353
+ ),
354
+ multi_stft_hop_size=147,
355
+ multi_stft_normalized=False,
356
+ multi_stft_window_fn: Callable = torch.hann_window,
357
+ mlp_expansion_factor=4,
358
+ use_torch_checkpoint=False,
359
+ skip_connection=False,
360
+ use_shared_bias=False,
361
+ chunk_size: int = 588800,
362
+ **kwargs
363
+ ):
364
+ super().__init__()
365
+
366
+ self.stereo = stereo
367
+ self.audio_channels = 2 if stereo else 1
368
+ self.num_stems = num_stems
369
+ self.use_torch_checkpoint = use_torch_checkpoint
370
+ self.skip_connection = skip_connection
371
+
372
+ self.layers = ModuleList([])
373
+
374
+ if use_shared_bias:
375
+ dim_inner = heads * dim_head
376
+ self.shared_qkv_bias = nn.Parameter(torch.ones(dim_inner * 3))
377
+ self.shared_out_bias = nn.Parameter(torch.ones(dim))
378
+
379
+ transformer_kwargs = dict(
380
+ dim=dim,
381
+ heads=heads,
382
+ dim_head=dim_head,
383
+ attn_dropout=attn_dropout,
384
+ ff_dropout=ff_dropout,
385
+ flash_attn=flash_attn,
386
+ norm_output=False,
387
+ shared_qkv_bias=self.shared_qkv_bias,
388
+ shared_out_bias=self.shared_out_bias,
389
+ )
390
+
391
+ t_frames = chunk_size // stft_hop_length + 1
392
+ self.cos_emb_time = nn.Parameter(torch.zeros(t_frames, dim_head))
393
+ self.sin_emb_time = nn.Parameter(torch.zeros(t_frames, dim_head))
394
+ time_rotary_embed = RotaryEmbedding(
395
+ cos_emb=self.cos_emb_time, sin_emb=self.sin_emb_time
396
+ )
397
+
398
+ num_bands = len(freqs_per_bands)
399
+ self.cos_emb_freq = nn.Parameter(torch.zeros(num_bands, dim_head))
400
+ self.sin_emb_freq = nn.Parameter(torch.zeros(num_bands, dim_head))
401
+ freq_rotary_embed = RotaryEmbedding(
402
+ cos_emb=self.cos_emb_freq, sin_emb=self.sin_emb_freq
403
+ )
404
+
405
+ for _ in range(depth):
406
+ tran_modules = []
407
+ if linear_transformer_depth > 0:
408
+ tran_modules.append(
409
+ Transformer(
410
+ depth=linear_transformer_depth,
411
+ linear_attn=True,
412
+ **transformer_kwargs,
413
+ )
414
+ )
415
+ tran_modules.append(
416
+ Transformer(
417
+ depth=time_transformer_depth,
418
+ rotary_embed=time_rotary_embed,
419
+ **transformer_kwargs,
420
+ )
421
+ )
422
+ tran_modules.append(
423
+ Transformer(
424
+ depth=freq_transformer_depth,
425
+ rotary_embed=freq_rotary_embed,
426
+ **transformer_kwargs,
427
+ )
428
+ )
429
+ self.layers.append(nn.ModuleList(tran_modules))
430
+
431
+ self.final_norm = CustomNorm(dim)
432
+
433
+ self.stft_kwargs = dict(
434
+ n_fft=stft_n_fft,
435
+ hop_length=stft_hop_length,
436
+ win_length=stft_win_length,
437
+ normalized=stft_normalized,
438
+ )
439
+
440
+ self.stft_window_fn = partial(
441
+ stft_window_fn or torch.hann_window, stft_win_length
442
+ )
443
+
444
+ freqs_per_bands_with_complex = tuple(
445
+ 2 * f * self.audio_channels for f in freqs_per_bands
446
+ )
447
+
448
+ self.band_split = BandSplit(dim=dim, dim_inputs=freqs_per_bands_with_complex)
449
+
450
+ self.mask_estimators = nn.ModuleList([])
451
+
452
+ for _ in range(num_stems):
453
+ mask_estimator = MaskEstimator(
454
+ dim=dim,
455
+ dim_inputs=freqs_per_bands_with_complex,
456
+ depth=mask_estimator_depth,
457
+ mlp_expansion_factor=mlp_expansion_factor,
458
+ )
459
+
460
+ self.mask_estimators.append(mask_estimator)
461
+
462
+ self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
463
+ self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
464
+ self.multi_stft_n_fft = stft_n_fft
465
+ self.multi_stft_window_fn = multi_stft_window_fn
466
+
467
+ self.multi_stft_kwargs = dict(
468
+ hop_length=multi_stft_hop_size, normalized=multi_stft_normalized
469
+ )
470
+
471
+ def forward(self, raw_audio, target=None, return_loss_breakdown=False):
472
+
473
+ device = raw_audio.device
474
+
475
+ x_is_mps = True if device.type == "mps" else False
476
+
477
+ if raw_audio.ndim == 2:
478
+ raw_audio = rearrange(raw_audio, "b t -> b 1 t")
479
+
480
+ channels = raw_audio.shape[1]
481
+ assert (not self.stereo and channels == 1) or (
482
+ self.stereo and channels == 2
483
+ ), "stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)"
484
+
485
+ raw_audio, batch_audio_channel_packed_shape = pack([raw_audio], "* t")
486
+
487
+ stft_window = self.stft_window_fn(device=device)
488
+
489
+ try:
490
+ stft_repr = torch.stft(
491
+ raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True
492
+ )
493
+ except Exception:
494
+ stft_repr = torch.stft(
495
+ raw_audio.cpu() if x_is_mps else raw_audio,
496
+ **self.stft_kwargs,
497
+ window=stft_window.cpu() if x_is_mps else stft_window,
498
+ return_complex=True,
499
+ ).to(device)
500
+ stft_repr = torch.view_as_real(stft_repr)
501
+
502
+ stft_repr = unpack(stft_repr, batch_audio_channel_packed_shape, "* f t c")[0]
503
+
504
+ stft_repr = rearrange(stft_repr, "b s f t c -> b (f s) t c")
505
+
506
+ x = rearrange(stft_repr, "b f t c -> b t (f c)")
507
+
508
+ if torch.isnan(x).any() or torch.isinf(x).any():
509
+ raise RuntimeError(
510
+ f"NaN/Inf in x after stft: {x.isnan().sum()} NaNs, {x.isinf().sum()} Infs"
511
+ )
512
+
513
+ if self.use_torch_checkpoint:
514
+ x = checkpoint(self.band_split, x, use_reentrant=False)
515
+ else:
516
+ x = self.band_split(x)
517
+
518
+ if torch.isnan(x).any() or torch.isinf(x).any():
519
+ raise RuntimeError(
520
+ f"NaN/Inf in x after band_split: {x.isnan().sum()} NaNs, {x.isinf().sum()} Infs"
521
+ )
522
+
523
+ store = [None] * len(self.layers)
524
+ for i, transformer_block in enumerate(self.layers):
525
+ if len(transformer_block) == 3:
526
+ linear_transformer, time_transformer, freq_transformer = (
527
+ transformer_block
528
+ )
529
+
530
+ x, ft_ps = pack([x], "b * d")
531
+ if self.use_torch_checkpoint:
532
+ x = checkpoint(linear_transformer, x, use_reentrant=False)
533
+ else:
534
+ x = linear_transformer(x)
535
+ (x,) = unpack(x, ft_ps, "b * d")
536
+ else:
537
+ time_transformer, freq_transformer = transformer_block
538
+
539
+ if self.skip_connection:
540
+ for j in range(i):
541
+ x = x + store[j]
542
+
543
+ x = rearrange(x, "b t f d -> b f t d")
544
+ x, ps = pack([x], "* t d")
545
+
546
+ if self.use_torch_checkpoint:
547
+ x = checkpoint(time_transformer, x, use_reentrant=False)
548
+ else:
549
+ x = time_transformer(x)
550
+
551
+ (x,) = unpack(x, ps, "* t d")
552
+ x = rearrange(x, "b f t d -> b t f d")
553
+ x, ps = pack([x], "* f d")
554
+
555
+ if self.use_torch_checkpoint:
556
+ x = checkpoint(freq_transformer, x, use_reentrant=False)
557
+ else:
558
+ x = freq_transformer(x)
559
+
560
+ (x,) = unpack(x, ps, "* f d")
561
+
562
+ if self.skip_connection:
563
+ store[i] = x
564
+
565
+ x = self.final_norm(x)
566
+
567
+ num_stems = len(self.mask_estimators)
568
+
569
+ if self.use_torch_checkpoint:
570
+ mask = torch.stack(
571
+ [checkpoint(fn, x, use_reentrant=False) for fn in self.mask_estimators],
572
+ dim=1,
573
+ )
574
+ else:
575
+ mask = torch.stack([fn(x) for fn in self.mask_estimators], dim=1)
576
+ mask = rearrange(mask, "b n t (f c) -> b n f t c", c=2)
577
+
578
+ stft_repr = rearrange(stft_repr, "b f t c -> b 1 f t c")
579
+
580
+ stft_repr = torch.view_as_complex(stft_repr)
581
+ mask = torch.view_as_complex(mask)
582
+
583
+ stft_repr = stft_repr * mask
584
+
585
+ stft_repr = rearrange(
586
+ stft_repr, "b n (f s) t -> (b n s) f t", s=self.audio_channels
587
+ )
588
+
589
+ try:
590
+ recon_audio = torch.istft(
591
+ stft_repr,
592
+ **self.stft_kwargs,
593
+ window=stft_window,
594
+ return_complex=False,
595
+ length=raw_audio.shape[-1],
596
+ )
597
+ except Exception:
598
+ recon_audio = torch.istft(
599
+ stft_repr.cpu() if x_is_mps else stft_repr,
600
+ **self.stft_kwargs,
601
+ window=stft_window.cpu() if x_is_mps else stft_window,
602
+ return_complex=False,
603
+ length=raw_audio.shape[-1],
604
+ ).to(device)
605
+
606
+ recon_audio = rearrange(
607
+ recon_audio, "(b n s) t -> b n s t", s=self.audio_channels, n=num_stems
608
+ )
609
+
610
+ if num_stems == 1:
611
+ recon_audio = rearrange(recon_audio, "b 1 s t -> b s t")
612
+
613
+ if target is None:
614
+ return recon_audio
615
+
616
+ if self.num_stems > 1:
617
+ assert target.ndim == 4 and target.shape[1] == self.num_stems
618
+
619
+ if target.ndim == 2:
620
+ target = rearrange(target, "... t -> ... 1 t")
621
+
622
+ target = target[..., : recon_audio.shape[-1]]
623
+
624
+ loss = F.l1_loss(recon_audio, target)
625
+
626
+ multi_stft_resolution_loss = 0.0
627
+
628
+ for window_size in self.multi_stft_resolutions_window_sizes:
629
+ res_stft_kwargs = dict(
630
+ n_fft=max(window_size, self.multi_stft_n_fft),
631
+ win_length=window_size,
632
+ return_complex=True,
633
+ window=self.multi_stft_window_fn(window_size, device=device),
634
+ **self.multi_stft_kwargs,
635
+ )
636
+
637
+ recon_Y = torch.stft(
638
+ rearrange(recon_audio, "... s t -> (... s) t"), **res_stft_kwargs
639
+ )
640
+ target_Y = torch.stft(
641
+ rearrange(target, "... s t -> (... s) t"), **res_stft_kwargs
642
+ )
643
+
644
+ multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(
645
+ recon_Y, target_Y
646
+ )
647
+
648
+ weighted_multi_resolution_loss = (
649
+ multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight
650
+ )
651
+
652
+ total_loss = loss + weighted_multi_resolution_loss
653
+
654
+ if not return_loss_breakdown:
655
+ return total_loss
656
+
657
+ return total_loss, (loss, multi_stft_resolution_loss)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
models/bs_roformer/bs_roformer_unwa_inst_large_2.py CHANGED
@@ -6,10 +6,7 @@ from torch.nn import Module, ModuleList
6
  import torch.nn.functional as F
7
 
8
  from .attend import Attend
9
- try:
10
- from .attend_sage import Attend as AttendSage
11
- except:
12
- pass
13
  from torch.utils.checkpoint import checkpoint
14
 
15
  from beartype.typing import Tuple, Optional, List, Callable
@@ -85,7 +82,6 @@ class Attention(Module):
85
  dropout=0.,
86
  rotary_embed=None,
87
  flash=True,
88
- sage_attention=False,
89
  ):
90
  super().__init__()
91
  self.heads = heads
@@ -94,10 +90,7 @@ class Attention(Module):
94
 
95
  self.rotary_embed = rotary_embed
96
 
97
- if sage_attention:
98
- self.attend = AttendSage(flash=flash, dropout=dropout)
99
- else:
100
- self.attend = Attend(flash=flash, dropout=dropout)
101
 
102
  self.norm = RMSNorm(dim)
103
  self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
@@ -142,7 +135,6 @@ class LinearAttention(Module):
142
  scale=8,
143
  flash=False,
144
  dropout=0.,
145
- sage_attention=False,
146
  ):
147
  super().__init__()
148
  dim_inner = dim_head * heads
@@ -155,18 +147,11 @@ class LinearAttention(Module):
155
 
156
  self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
157
 
158
- if sage_attention:
159
- self.attend = AttendSage(
160
- scale=scale,
161
- dropout=dropout,
162
- flash=flash
163
- )
164
- else:
165
- self.attend = Attend(
166
- scale=scale,
167
- dropout=dropout,
168
- flash=flash
169
- )
170
 
171
  self.to_out = nn.Sequential(
172
  Rearrange('b h d n -> b n (h d)'),
@@ -203,7 +188,6 @@ class Transformer(Module):
203
  rotary_embed=None,
204
  flash_attn=True,
205
  linear_attn=False,
206
- sage_attention=False,
207
  ):
208
  super().__init__()
209
  self.layers = ModuleList([])
@@ -216,7 +200,6 @@ class Transformer(Module):
216
  heads=heads,
217
  dropout=attn_dropout,
218
  flash=flash_attn,
219
- sage_attention=sage_attention
220
  )
221
  else:
222
  attn = Attention(
@@ -226,7 +209,6 @@ class Transformer(Module):
226
  dropout=attn_dropout,
227
  rotary_embed=rotary_embed,
228
  flash=flash_attn,
229
- sage_attention=sage_attention
230
  )
231
 
232
  self.layers.append(ModuleList([
@@ -342,7 +324,6 @@ class MaskEstimator(Module):
342
  ff_dropout=0.,
343
  flash_attn=True,
344
  norm_output=False,
345
- sage_attention=False,
346
  )
347
 
348
  time_rotary_embed = RotaryEmbedding(dim=dim_head)
@@ -445,7 +426,7 @@ class BSRoformer_2(Module):
445
  mlp_expansion_factor=4,
446
  use_torch_checkpoint=False,
447
  skip_connection=False,
448
- sage_attention=False,
449
  ):
450
  super().__init__()
451
 
@@ -457,9 +438,6 @@ class BSRoformer_2(Module):
457
 
458
  self.layers = ModuleList([])
459
 
460
- if sage_attention:
461
- print("Use Sage Attention")
462
-
463
  transformer_kwargs = dict(
464
  dim=dim,
465
  heads=heads,
@@ -468,7 +446,6 @@ class BSRoformer_2(Module):
468
  ff_dropout=ff_dropout,
469
  flash_attn=flash_attn,
470
  norm_output=False,
471
- sage_attention=sage_attention,
472
  )
473
 
474
  time_rotary_embed = RotaryEmbedding(dim=dim_head)
 
6
  import torch.nn.functional as F
7
 
8
  from .attend import Attend
9
+
 
 
 
10
  from torch.utils.checkpoint import checkpoint
11
 
12
  from beartype.typing import Tuple, Optional, List, Callable
 
82
  dropout=0.,
83
  rotary_embed=None,
84
  flash=True,
 
85
  ):
86
  super().__init__()
87
  self.heads = heads
 
90
 
91
  self.rotary_embed = rotary_embed
92
 
93
+ self.attend = Attend(flash=flash, dropout=dropout)
 
 
 
94
 
95
  self.norm = RMSNorm(dim)
96
  self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
 
135
  scale=8,
136
  flash=False,
137
  dropout=0.,
 
138
  ):
139
  super().__init__()
140
  dim_inner = dim_head * heads
 
147
 
148
  self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
149
 
150
+ self.attend = Attend(
151
+ scale=scale,
152
+ dropout=dropout,
153
+ flash=flash
154
+ )
 
 
 
 
 
 
 
155
 
156
  self.to_out = nn.Sequential(
157
  Rearrange('b h d n -> b n (h d)'),
 
188
  rotary_embed=None,
189
  flash_attn=True,
190
  linear_attn=False,
 
191
  ):
192
  super().__init__()
193
  self.layers = ModuleList([])
 
200
  heads=heads,
201
  dropout=attn_dropout,
202
  flash=flash_attn,
 
203
  )
204
  else:
205
  attn = Attention(
 
209
  dropout=attn_dropout,
210
  rotary_embed=rotary_embed,
211
  flash=flash_attn,
 
212
  )
213
 
214
  self.layers.append(ModuleList([
 
324
  ff_dropout=0.,
325
  flash_attn=True,
326
  norm_output=False,
 
327
  )
328
 
329
  time_rotary_embed = RotaryEmbedding(dim=dim_head)
 
426
  mlp_expansion_factor=4,
427
  use_torch_checkpoint=False,
428
  skip_connection=False,
429
+ **kwargs
430
  ):
431
  super().__init__()
432
 
 
438
 
439
  self.layers = ModuleList([])
440
 
 
 
 
441
  transformer_kwargs = dict(
442
  dim=dim,
443
  heads=heads,
 
446
  ff_dropout=ff_dropout,
447
  flash_attn=flash_attn,
448
  norm_output=False,
 
449
  )
450
 
451
  time_rotary_embed = RotaryEmbedding(dim=dim_head)
models/bs_roformer/bs_siamese_roformer.py CHANGED
@@ -6,10 +6,6 @@ from torch.nn import Module, ModuleList
6
  import torch.nn.functional as F
7
 
8
  from .attend import Attend
9
- try:
10
- from .attend_sage import Attend as AttendSage
11
- except:
12
- pass
13
  from torch.utils.checkpoint import checkpoint
14
 
15
  from beartype.typing import Tuple, Optional, List, Callable
@@ -86,7 +82,6 @@ class Attention(Module):
86
  dropout=0.,
87
  rotary_embed=None,
88
  flash=True,
89
- sage_attention=False,
90
  ):
91
  super().__init__()
92
  self.heads = heads
@@ -95,10 +90,7 @@ class Attention(Module):
95
 
96
  self.rotary_embed = rotary_embed
97
 
98
- if sage_attention:
99
- self.attend = AttendSage(flash=flash, dropout=dropout)
100
- else:
101
- self.attend = Attend(flash=flash, dropout=dropout)
102
 
103
  self.norm = RMSNorm(dim)
104
  self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
@@ -143,7 +135,6 @@ class LinearAttention(Module):
143
  scale=8,
144
  flash=False,
145
  dropout=0.,
146
- sage_attention=False,
147
  ):
148
  super().__init__()
149
  dim_inner = dim_head * heads
@@ -156,18 +147,11 @@ class LinearAttention(Module):
156
 
157
  self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
158
 
159
- if sage_attention:
160
- self.attend = AttendSage(
161
- scale=scale,
162
- dropout=dropout,
163
- flash=flash
164
- )
165
- else:
166
- self.attend = Attend(
167
- scale=scale,
168
- dropout=dropout,
169
- flash=flash
170
- )
171
 
172
  self.to_out = nn.Sequential(
173
  Rearrange('b h d n -> b n (h d)'),
@@ -205,7 +189,6 @@ class SiameseTransformer(Module):
205
  rotary_embed=None,
206
  flash_attn=True,
207
  linear_attn=False,
208
- sage_attention=False,
209
  ):
210
  super().__init__()
211
  self.layers = ModuleList([])
@@ -223,7 +206,6 @@ class SiameseTransformer(Module):
223
  heads=heads,
224
  dropout=attn_dropout,
225
  flash=flash_attn,
226
- sage_attention=sage_attention
227
  )
228
  else:
229
  attn = Attention(
@@ -233,7 +215,6 @@ class SiameseTransformer(Module):
233
  dropout=attn_dropout,
234
  rotary_embed=rotary_embed,
235
  flash=flash_attn,
236
- sage_attention=sage_attention
237
  )
238
 
239
  self.layers.append(ModuleList([
@@ -415,7 +396,7 @@ class BSSiameseRoformer(Module):
415
  mlp_expansion_factor=4,
416
  use_torch_checkpoint=False,
417
  skip_connection=False,
418
- sage_attention=False,
419
  ):
420
  super().__init__()
421
 
@@ -427,9 +408,6 @@ class BSSiameseRoformer(Module):
427
 
428
  self.layers = ModuleList([])
429
 
430
- if sage_attention:
431
- print("Use Sage Attention")
432
-
433
  transformer_kwargs = dict(
434
  dim=dim,
435
  heads=heads,
@@ -438,7 +416,6 @@ class BSSiameseRoformer(Module):
438
  ff_dropout=ff_dropout,
439
  flash_attn=flash_attn,
440
  norm_output=False,
441
- sage_attention=sage_attention,
442
  )
443
 
444
  time_rotary_embed = RotaryEmbedding(dim=dim_head)
 
6
  import torch.nn.functional as F
7
 
8
  from .attend import Attend
 
 
 
 
9
  from torch.utils.checkpoint import checkpoint
10
 
11
  from beartype.typing import Tuple, Optional, List, Callable
 
82
  dropout=0.,
83
  rotary_embed=None,
84
  flash=True,
 
85
  ):
86
  super().__init__()
87
  self.heads = heads
 
90
 
91
  self.rotary_embed = rotary_embed
92
 
93
+ self.attend = Attend(flash=flash, dropout=dropout)
 
 
 
94
 
95
  self.norm = RMSNorm(dim)
96
  self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
 
135
  scale=8,
136
  flash=False,
137
  dropout=0.,
 
138
  ):
139
  super().__init__()
140
  dim_inner = dim_head * heads
 
147
 
148
  self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
149
 
150
+ self.attend = Attend(
151
+ scale=scale,
152
+ dropout=dropout,
153
+ flash=flash
154
+ )
 
 
 
 
 
 
 
155
 
156
  self.to_out = nn.Sequential(
157
  Rearrange('b h d n -> b n (h d)'),
 
189
  rotary_embed=None,
190
  flash_attn=True,
191
  linear_attn=False,
 
192
  ):
193
  super().__init__()
194
  self.layers = ModuleList([])
 
206
  heads=heads,
207
  dropout=attn_dropout,
208
  flash=flash_attn,
 
209
  )
210
  else:
211
  attn = Attention(
 
215
  dropout=attn_dropout,
216
  rotary_embed=rotary_embed,
217
  flash=flash_attn,
 
218
  )
219
 
220
  self.layers.append(ModuleList([
 
396
  mlp_expansion_factor=4,
397
  use_torch_checkpoint=False,
398
  skip_connection=False,
399
+ **kwargs
400
  ):
401
  super().__init__()
402
 
 
408
 
409
  self.layers = ModuleList([])
410
 
 
 
 
411
  transformer_kwargs = dict(
412
  dim=dim,
413
  heads=heads,
 
416
  ff_dropout=ff_dropout,
417
  flash_attn=flash_attn,
418
  norm_output=False,
 
419
  )
420
 
421
  time_rotary_embed = RotaryEmbedding(dim=dim_head)
models/bs_roformer/fno1d.py CHANGED
The diff for this file is too large to render. See raw diff
 
models/bs_roformer/mel_band_conformer.py CHANGED
@@ -6,10 +6,6 @@ from torch.nn import Module, ModuleList
6
  import torch.nn.functional as F
7
 
8
  from .attend import Attend
9
- try:
10
- from .attend_sage import Attend as AttendSage
11
- except:
12
- pass
13
  from torch.utils.checkpoint import checkpoint
14
 
15
  from beartype.typing import Tuple, Optional, List, Callable
@@ -97,7 +93,6 @@ class Attention(Module):
97
  dropout=0.,
98
  rotary_embed=None,
99
  flash=True,
100
- sage_attention=False,
101
  ):
102
  super().__init__()
103
  self.heads = heads
@@ -106,10 +101,7 @@ class Attention(Module):
106
 
107
  self.rotary_embed = rotary_embed
108
 
109
- if sage_attention:
110
- self.attend = AttendSage(flash=flash, dropout=dropout)
111
- else:
112
- self.attend = Attend(flash=flash, dropout=dropout)
113
 
114
  self.norm = RMSNorm(dim)
115
  self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
@@ -153,7 +145,6 @@ class LinearAttention(Module):
153
  scale=8,
154
  flash=True,
155
  dropout=0.,
156
- sage_attention=False
157
  ):
158
  super().__init__()
159
  dim_inner = dim_head * heads
@@ -166,10 +157,7 @@ class LinearAttention(Module):
166
 
167
  self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
168
 
169
- if sage_attention:
170
- self.attend = AttendSage(scale=scale, dropout=dropout, flash=flash)
171
- else:
172
- self.attend = Attend(scale=scale, dropout=dropout, flash=flash)
173
 
174
  self.to_out = nn.Sequential(
175
  Rearrange('b h d n -> b n (h d)'),
@@ -202,7 +190,6 @@ class Transformer(Module):
202
  rotary_embed=None,
203
  flash_attn=True,
204
  linear_attn=False,
205
- sage_attention=False,
206
  ):
207
  super().__init__()
208
  self.layers = ModuleList([])
@@ -215,7 +202,6 @@ class Transformer(Module):
215
  heads=heads,
216
  dropout=attn_dropout,
217
  flash=flash_attn,
218
- sage_attention=sage_attention
219
  )
220
  else:
221
  attn = Attention(
@@ -225,7 +211,6 @@ class Transformer(Module):
225
  dropout=attn_dropout,
226
  rotary_embed=rotary_embed,
227
  flash=flash_attn,
228
- sage_attention=sage_attention
229
  )
230
 
231
  self.layers.append(ModuleList([
@@ -290,7 +275,6 @@ class ConformerBlock(nn.Module):
290
  conv_kernel_size=31,
291
  rotary_embed=None,
292
  flash_attn=True,
293
- sage_attention=False
294
  ):
295
  super().__init__()
296
  self.ff1 = MacaronFF(dim=dim, mult=ff_mult, dropout=ff_dropout)
@@ -301,7 +285,6 @@ class ConformerBlock(nn.Module):
301
  dropout=attn_dropout,
302
  rotary_embed=rotary_embed,
303
  flash=flash_attn,
304
- sage_attention=sage_attention
305
  )
306
  self.conv = ConformerConvModule(
307
  dim=dim,
@@ -333,7 +316,6 @@ class Conformer(Module):
333
  ff_mult=4,
334
  rotary_embed=None,
335
  flash_attn=True,
336
- sage_attention=False,
337
  conv_expansion_factor=2,
338
  conv_kernel_size=31,
339
  norm_output=True
@@ -351,7 +333,6 @@ class Conformer(Module):
351
  conv_kernel_size=conv_kernel_size,
352
  rotary_embed=rotary_embed,
353
  flash_attn=flash_attn,
354
- sage_attention=sage_attention
355
  ) for _ in range(depth)
356
  ])
357
  self.norm = RMSNorm(dim) if norm_output else nn.Identity()
@@ -466,7 +447,6 @@ class MelBandConformer(Module):
466
  mlp_expansion_factor=4,
467
  use_torch_checkpoint=False,
468
  skip_connection=False,
469
- sage_attention=False,
470
  # conformer-specific
471
  ff_mult=4,
472
  conv_expansion_factor=2,
@@ -482,9 +462,6 @@ class MelBandConformer(Module):
482
 
483
  self.layers = ModuleList([])
484
 
485
- if sage_attention:
486
- print("Use Sage Attention")
487
-
488
  transformer_kwargs = dict(
489
  dim = dim,
490
  heads = heads,
@@ -492,7 +469,6 @@ class MelBandConformer(Module):
492
  attn_dropout = attn_dropout,
493
  ff_dropout = ff_dropout,
494
  flash_attn = flash_attn,
495
- sage_attention = sage_attention,
496
  norm_output = False
497
  )
498
 
 
6
  import torch.nn.functional as F
7
 
8
  from .attend import Attend
 
 
 
 
9
  from torch.utils.checkpoint import checkpoint
10
 
11
  from beartype.typing import Tuple, Optional, List, Callable
 
93
  dropout=0.,
94
  rotary_embed=None,
95
  flash=True,
 
96
  ):
97
  super().__init__()
98
  self.heads = heads
 
101
 
102
  self.rotary_embed = rotary_embed
103
 
104
+ self.attend = Attend(flash=flash, dropout=dropout)
 
 
 
105
 
106
  self.norm = RMSNorm(dim)
107
  self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
 
145
  scale=8,
146
  flash=True,
147
  dropout=0.,
 
148
  ):
149
  super().__init__()
150
  dim_inner = dim_head * heads
 
157
 
158
  self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
159
 
160
+ self.attend = Attend(scale=scale, dropout=dropout, flash=flash)
 
 
 
161
 
162
  self.to_out = nn.Sequential(
163
  Rearrange('b h d n -> b n (h d)'),
 
190
  rotary_embed=None,
191
  flash_attn=True,
192
  linear_attn=False,
 
193
  ):
194
  super().__init__()
195
  self.layers = ModuleList([])
 
202
  heads=heads,
203
  dropout=attn_dropout,
204
  flash=flash_attn,
 
205
  )
206
  else:
207
  attn = Attention(
 
211
  dropout=attn_dropout,
212
  rotary_embed=rotary_embed,
213
  flash=flash_attn,
 
214
  )
215
 
216
  self.layers.append(ModuleList([
 
275
  conv_kernel_size=31,
276
  rotary_embed=None,
277
  flash_attn=True,
 
278
  ):
279
  super().__init__()
280
  self.ff1 = MacaronFF(dim=dim, mult=ff_mult, dropout=ff_dropout)
 
285
  dropout=attn_dropout,
286
  rotary_embed=rotary_embed,
287
  flash=flash_attn,
 
288
  )
289
  self.conv = ConformerConvModule(
290
  dim=dim,
 
316
  ff_mult=4,
317
  rotary_embed=None,
318
  flash_attn=True,
 
319
  conv_expansion_factor=2,
320
  conv_kernel_size=31,
321
  norm_output=True
 
333
  conv_kernel_size=conv_kernel_size,
334
  rotary_embed=rotary_embed,
335
  flash_attn=flash_attn,
 
336
  ) for _ in range(depth)
337
  ])
338
  self.norm = RMSNorm(dim) if norm_output else nn.Identity()
 
447
  mlp_expansion_factor=4,
448
  use_torch_checkpoint=False,
449
  skip_connection=False,
 
450
  # conformer-specific
451
  ff_mult=4,
452
  conv_expansion_factor=2,
 
462
 
463
  self.layers = ModuleList([])
464
 
 
 
 
465
  transformer_kwargs = dict(
466
  dim = dim,
467
  heads = heads,
 
469
  attn_dropout = attn_dropout,
470
  ff_dropout = ff_dropout,
471
  flash_attn = flash_attn,
 
472
  norm_output = False
473
  )
474
 
models/bs_roformer/mel_band_roformer.py CHANGED
@@ -1,748 +1,749 @@
1
- from functools import partial
2
-
3
- import torch
4
- from torch import nn, einsum, tensor, Tensor
5
- from torch.nn import Module, ModuleList
6
- import torch.nn.functional as F
7
-
8
- from .attend import Attend
9
-
10
- from torch.utils.checkpoint import checkpoint
11
-
12
- from beartype.typing import Tuple, Optional, List, Callable
13
- from beartype import beartype
14
-
15
- from rotary_embedding_torch import RotaryEmbedding
16
-
17
- from einops import rearrange, pack, unpack, reduce, repeat
18
- from einops.layers.torch import Rearrange
19
-
20
- from librosa import filters
21
-
22
- try:
23
- from .pope.attention import flash_attn_with_pope
24
- from .pope.pope import PoPE
25
- _HAS_POPE = True
26
- except Exception:
27
- PoPE = None
28
- flash_attn_with_pope = None
29
- _HAS_POPE = False
30
-
31
- # helper functions
32
-
33
- def exists(val):
34
- return val is not None
35
-
36
-
37
- def default(v, d):
38
- return v if exists(v) else d
39
-
40
-
41
- def pack_one(t, pattern):
42
- return pack([t], pattern)
43
-
44
-
45
- def unpack_one(t, ps, pattern):
46
- return unpack(t, ps, pattern)[0]
47
-
48
-
49
- def pad_at_dim(t, pad, dim=-1, value=0.):
50
- dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
51
- zeros = ((0, 0) * dims_from_right)
52
- return F.pad(t, (*zeros, *pad), value=value)
53
-
54
-
55
- def l2norm(t):
56
- return F.normalize(t, dim=-1, p=2)
57
-
58
-
59
- # norm
60
-
61
- class RMSNorm(Module):
62
- def __init__(self, dim):
63
- super().__init__()
64
- self.scale = dim ** 0.5
65
- self.gamma = nn.Parameter(torch.ones(dim))
66
-
67
- def forward(self, x):
68
- return F.normalize(x, dim=-1) * self.scale * self.gamma
69
-
70
-
71
- # attention
72
-
73
- class FeedForward(Module):
74
- def __init__(
75
- self,
76
- dim,
77
- mult=4,
78
- dropout=0.
79
- ):
80
- super().__init__()
81
- dim_inner = int(dim * mult)
82
- self.net = nn.Sequential(
83
- RMSNorm(dim),
84
- nn.Linear(dim, dim_inner),
85
- nn.GELU(),
86
- nn.Dropout(dropout),
87
- nn.Linear(dim_inner, dim),
88
- nn.Dropout(dropout)
89
- )
90
-
91
- def forward(self, x):
92
- return self.net(x)
93
-
94
-
95
- class Attention(Module):
96
- def __init__(
97
- self,
98
- dim,
99
- heads=8,
100
- dim_head=64,
101
- dropout=0.,
102
- rotary_embed=None,
103
- flash=True,
104
- pope_embed=None,
105
- ):
106
- super().__init__()
107
- self.heads = heads
108
- self.scale = dim_head ** -0.5
109
- dim_inner = heads * dim_head
110
-
111
- self.rotary_embed = rotary_embed
112
- self.pope_embed = pope_embed
113
- assert not (self.rotary_embed is not None and self.pope_embed is not None), \
114
- "cannot have both rotary and pope embeddings"
115
-
116
- self.attend = Attend(flash=flash, dropout=dropout)
117
-
118
- self.norm = RMSNorm(dim)
119
- self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)
120
-
121
- self.to_gates = nn.Linear(dim, heads)
122
-
123
- self.to_out = nn.Sequential(
124
- nn.Linear(dim_inner, dim, bias=False),
125
- nn.Dropout(dropout)
126
- )
127
-
128
- def forward(self, x):
129
- x = self.norm(x)
130
-
131
- q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)
132
-
133
- if exists(self.pope_embed):
134
- assert _HAS_POPE, "PoPE requested but PoPE_pytorch is not installed"
135
- out = flash_attn_with_pope(
136
- q, k, v,
137
- pos_emb=self.pope_embed(q.shape[-2]),
138
- softmax_scale=self.scale
139
- )
140
- elif exists(self.rotary_embed):
141
- q = self.rotary_embed.rotate_queries_or_keys(q)
142
- k = self.rotary_embed.rotate_queries_or_keys(k)
143
- out = self.attend(q, k, v)
144
- else:
145
- out = self.attend(q, k, v)
146
-
147
- gates = self.to_gates(x)
148
- out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()
149
-
150
- out = rearrange(out, 'b h n d -> b n (h d)')
151
- return self.to_out(out)
152
-
153
-
154
- class LinearAttention(Module):
155
- """
156
- this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.
157
- """
158
-
159
- @beartype
160
- def __init__(
161
- self,
162
- *,
163
- dim,
164
- dim_head=32,
165
- heads=8,
166
- scale=8,
167
- flash=False,
168
- dropout=0.
169
- ):
170
- super().__init__()
171
- dim_inner = dim_head * heads
172
- self.norm = RMSNorm(dim)
173
-
174
- self.to_qkv = nn.Sequential(
175
- nn.Linear(dim, dim_inner * 3, bias=False),
176
- Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads)
177
- )
178
-
179
- self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
180
-
181
- self.attend = Attend(
182
- scale=scale,
183
- dropout=dropout,
184
- flash=flash
185
- )
186
-
187
- self.to_out = nn.Sequential(
188
- Rearrange('b h d n -> b n (h d)'),
189
- nn.Linear(dim_inner, dim, bias=False)
190
- )
191
-
192
- def forward(
193
- self,
194
- x
195
- ):
196
- x = self.norm(x)
197
-
198
- q, k, v = self.to_qkv(x)
199
-
200
- q, k = map(l2norm, (q, k))
201
- q = q * self.temperature.exp()
202
-
203
- out = self.attend(q, k, v)
204
-
205
- return self.to_out(out)
206
-
207
-
208
- class Transformer(Module):
209
- def __init__(
210
- self,
211
- *,
212
- dim,
213
- depth,
214
- dim_head=64,
215
- heads=8,
216
- attn_dropout=0.,
217
- ff_dropout=0.,
218
- ff_mult=4,
219
- norm_output=True,
220
- rotary_embed=None,
221
- pope_embed=None,
222
- flash_attn=True,
223
- linear_attn=False,
224
- ):
225
- super().__init__()
226
- self.layers = ModuleList([])
227
-
228
- for _ in range(depth):
229
- if linear_attn:
230
- attn = LinearAttention(
231
- dim=dim,
232
- dim_head=dim_head,
233
- heads=heads,
234
- dropout=attn_dropout,
235
- flash=flash_attn
236
- )
237
- else:
238
- attn = Attention(
239
- dim=dim,
240
- dim_head=dim_head,
241
- heads=heads,
242
- dropout=attn_dropout,
243
- rotary_embed=rotary_embed,
244
- pope_embed=pope_embed,
245
- flash=flash_attn
246
- )
247
-
248
- self.layers.append(ModuleList([
249
- attn,
250
- FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
251
- ]))
252
-
253
- self.norm = RMSNorm(dim) if norm_output else nn.Identity()
254
-
255
- def forward(self, x):
256
-
257
- for attn, ff in self.layers:
258
- x = attn(x) + x
259
- x = ff(x) + x
260
-
261
- return self.norm(x)
262
-
263
-
264
- # bandsplit module
265
-
266
- class BandSplit(Module):
267
- @beartype
268
- def __init__(
269
- self,
270
- dim,
271
- dim_inputs: Tuple[int, ...]
272
- ):
273
- super().__init__()
274
- self.dim_inputs = dim_inputs
275
- self.to_features = ModuleList([])
276
-
277
- for dim_in in dim_inputs:
278
- net = nn.Sequential(
279
- RMSNorm(dim_in),
280
- nn.Linear(dim_in, dim)
281
- )
282
-
283
- self.to_features.append(net)
284
-
285
- def forward(self, x):
286
- x = x.split(self.dim_inputs, dim=-1)
287
-
288
- outs = []
289
- for split_input, to_feature in zip(x, self.to_features):
290
- split_output = to_feature(split_input)
291
- outs.append(split_output)
292
-
293
- return torch.stack(outs, dim=-2)
294
-
295
-
296
- def MLP(
297
- dim_in,
298
- dim_out,
299
- dim_hidden=None,
300
- depth=1,
301
- activation=nn.Tanh
302
- ):
303
- dim_hidden = default(dim_hidden, dim_in)
304
-
305
- net = []
306
- dims = (dim_in, *((dim_hidden,) * depth), dim_out)
307
-
308
- for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
309
- is_last = ind == (len(dims) - 2)
310
-
311
- net.append(nn.Linear(layer_dim_in, layer_dim_out))
312
-
313
- if is_last:
314
- continue
315
-
316
- net.append(activation())
317
-
318
- return nn.Sequential(*net)
319
-
320
-
321
- class MaskEstimator(Module):
322
- @beartype
323
- def __init__(
324
- self,
325
- dim,
326
- dim_inputs: Tuple[int, ...],
327
- depth,
328
- mlp_expansion_factor=4
329
- ):
330
- super().__init__()
331
- self.dim_inputs = dim_inputs
332
- self.to_freqs = ModuleList([])
333
- dim_hidden = dim * mlp_expansion_factor
334
-
335
- for dim_in in dim_inputs:
336
- net = []
337
-
338
- mlp = nn.Sequential(
339
- MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
340
- nn.GLU(dim=-1)
341
- )
342
-
343
- self.to_freqs.append(mlp)
344
-
345
- def forward(self, x):
346
- x = x.unbind(dim=-2)
347
-
348
- outs = []
349
-
350
- for band_features, mlp in zip(x, self.to_freqs):
351
- freq_out = mlp(band_features)
352
- outs.append(freq_out)
353
-
354
- return torch.cat(outs, dim=-1)
355
-
356
-
357
- # main class
358
-
359
class MelBandRoformer(Module):
    """Mel-band transformer that predicts complex STFT masks for one or more stems.

    The audio is transformed with an STFT, its frequency bins are gathered into
    mel bands (taken from `librosa.filters.mel`, bands may overlap), every band
    is projected to `dim` features and processed by `depth` blocks of
    (optional linear-attention,) time-axis and frequency-axis transformers.
    One `MaskEstimator` per stem predicts a complex mask; masks are averaged
    over the bands that share a frequency, applied to the STFT and inverted.
    """

    @beartype
    def __init__(
            self,
            dim,
            *,
            depth,
            stereo=False,
            num_stems=1,
            time_transformer_depth=2,
            freq_transformer_depth=2,
            linear_transformer_depth=0,
            num_bands=60,
            dim_head=64,
            heads=8,
            attn_dropout=0.1,
            ff_dropout=0.1,
            flash_attn=True,
            dim_freqs_in=1025,
            sample_rate=44100,  # needed for mel filter bank from librosa
            stft_n_fft=2048,
            stft_hop_length=512,
            # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
            stft_win_length=2048,
            stft_normalized=False,
            stft_window_fn: Optional[Callable] = None,
            zero_dc=True,
            mask_estimator_depth=1,
            multi_stft_resolution_loss_weight=1.,
            multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
            multi_stft_hop_size=147,
            multi_stft_normalized=False,
            multi_stft_window_fn: Callable = torch.hann_window,
            match_input_audio_length=False,  # if True, pad output tensor to match length of input tensor
            mlp_expansion_factor=4,
            use_torch_checkpoint=False,
            skip_connection=False,
            use_pope: bool = False,
    ):
        """Build the transformer stack, the mel band layout and the mask estimators.

        `dim_freqs_in` is accepted but not used by this constructor.
        `use_pope=True` swaps the rotary embeddings for PoPE embeddings and
        requires the optional PoPE import to have succeeded.
        """
        super().__init__()

        self.stereo = stereo
        self.audio_channels = 2 if stereo else 1
        self.num_stems = num_stems
        self.use_torch_checkpoint = use_torch_checkpoint
        self.skip_connection = skip_connection

        self.layers = ModuleList([])

        transformer_kwargs = dict(
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout,
            flash_attn=flash_attn,
        )

        # exactly one positional scheme is active; the other pair stays None
        if use_pope:
            assert _HAS_POPE, "PoPE requested but PoPE_pytorch is not installed"
            time_pope_embed = PoPE(dim=dim_head, heads=heads)
            freq_pope_embed = PoPE(dim=dim_head, heads=heads)
            time_rotary_embed = None
            freq_rotary_embed = None
        else:
            time_rotary_embed = RotaryEmbedding(dim=dim_head)
            freq_rotary_embed = RotaryEmbedding(dim=dim_head)
            time_pope_embed = freq_pope_embed = None

        for _ in range(depth):
            tran_modules = []
            if linear_transformer_depth > 0:
                tran_modules.append(
                    Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs)
                )
            tran_modules.append(
                Transformer(
                    depth=time_transformer_depth,
                    rotary_embed=time_rotary_embed,
                    pope_embed=time_pope_embed,
                    **transformer_kwargs
                )
            )
            tran_modules.append(
                Transformer(
                    depth=freq_transformer_depth,
                    rotary_embed=freq_rotary_embed,
                    pope_embed=freq_pope_embed,
                    **transformer_kwargs
                )
            )
            self.layers.append(nn.ModuleList(tran_modules))

        self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)

        self.stft_kwargs = dict(
            n_fft=stft_n_fft,
            hop_length=stft_hop_length,
            win_length=stft_win_length,
            normalized=stft_normalized
        )

        # Run a dummy STFT only to learn the number of frequency bins.
        # torch.stft requires the window to be exactly `win_length` long, so the
        # dummy window is sized by stft_win_length (not stft_n_fft).
        freqs = torch.stft(
            torch.randn(1, 4096),
            **self.stft_kwargs,
            window=torch.ones(stft_win_length),
            return_complex=True
        ).shape[1]

        # create mel filter bank
        # with librosa.filters.mel as in section 2 of paper

        mel_filter_bank_numpy = filters.mel(sr=sample_rate, n_fft=stft_n_fft, n_mels=num_bands)

        mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy)

        # for some reason, it doesn't include the first freq? just force a value for now

        mel_filter_bank[0][0] = 1.

        # In some systems/envs we get 0.0 instead of ~1.9e-18 in the last position,
        # so let's force a positive value

        mel_filter_bank[-1, -1] = 1.

        # binary as in paper (then estimated masks are averaged for overlapping regions)

        freqs_per_band = mel_filter_bank > 0
        assert freqs_per_band.any(dim=0).all(), 'all frequencies need to be covered by all bands for now'

        repeated_freq_indices = repeat(torch.arange(freqs), 'f -> b f', b=num_bands)
        freq_indices = repeated_freq_indices[freqs_per_band]

        if stereo:
            # channels are interleaved on the frequency axis, so bin f maps to 2f and 2f + 1
            freq_indices = repeat(freq_indices, 'f -> f s', s=2)
            freq_indices = freq_indices * 2 + torch.arange(2)
            freq_indices = rearrange(freq_indices, 'f s -> (f s)')

        self.register_buffer('freq_indices', freq_indices, persistent=False)
        self.register_buffer('freqs_per_band', freqs_per_band, persistent=False)

        num_freqs_per_band = reduce(freqs_per_band, 'b f -> b', 'sum')
        num_bands_per_freq = reduce(freqs_per_band, 'b f -> f', 'sum')

        self.register_buffer('num_freqs_per_band', num_freqs_per_band, persistent=False)
        self.register_buffer('num_bands_per_freq', num_bands_per_freq, persistent=False)

        # band split and mask estimator

        freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in num_freqs_per_band.tolist())

        self.band_split = BandSplit(
            dim=dim,
            dim_inputs=freqs_per_bands_with_complex
        )

        self.mask_estimators = nn.ModuleList([])

        for _ in range(num_stems):
            mask_estimator = MaskEstimator(
                dim=dim,
                dim_inputs=freqs_per_bands_with_complex,
                depth=mask_estimator_depth,
                mlp_expansion_factor=mlp_expansion_factor,
            )

            self.mask_estimators.append(mask_estimator)

        # whether to zero out dc

        self.zero_dc = zero_dc

        # for the multi-resolution stft loss

        self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
        self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
        self.multi_stft_n_fft = stft_n_fft
        self.multi_stft_window_fn = multi_stft_window_fn

        self.multi_stft_kwargs = dict(
            hop_length=multi_stft_hop_size,
            normalized=multi_stft_normalized
        )

        self.match_input_audio_length = match_input_audio_length

    def _maybe_checkpoint(self, fn, x):
        """Call `fn(x)`, routed through `torch.utils.checkpoint` when `use_torch_checkpoint` is set."""
        if self.use_torch_checkpoint:
            return checkpoint(fn, x, use_reentrant=False)
        return fn(x)

    def forward(
            self,
            raw_audio,
            target=None,
            active_stem_ids=None,
            return_loss_breakdown=False
    ):
        """Separate `raw_audio`, or compute the training loss when `target` is given.

        einops

        b - batch
        f - freq
        t - time
        s - audio channel (1 for mono, 2 for stereo)
        n - number of 'stems'
        c - complex (2)
        d - feature dimension

        Args:
            raw_audio: `(b, t)` or `(b, s, t)` waveform.
            target: optional reference audio; when omitted the separated audio is returned.
            active_stem_ids: optional indices into `self.mask_estimators`; only
                those estimators are run. All estimators run when None.
            return_loss_breakdown: also return `(l1_loss, multi_stft_resolution_loss)`.

        Returns:
            Separated audio `(b, n, s, t)` (or `(b, s, t)` when a single stem is
            produced) if `target` is None; otherwise the total loss, optionally
            followed by the loss breakdown tuple.
        """

        device = raw_audio.device

        if raw_audio.ndim == 2:
            raw_audio = rearrange(raw_audio, 'b t -> b 1 t')

        batch, channels, raw_audio_length = raw_audio.shape

        istft_length = raw_audio_length if self.match_input_audio_length else None

        assert (not self.stereo and channels == 1) or (
                self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'

        # to stft

        raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')

        stft_window = self.stft_window_fn(device=device)

        stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
        stft_repr = torch.view_as_real(stft_repr)

        stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')

        # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
        stft_repr = rearrange(stft_repr, 'b s f t c -> b (f s) t c')

        # index out all frequencies for all frequency ranges across bands ascending in one go

        batch_arange = torch.arange(batch, device=device)[..., None]

        # account for stereo

        x = stft_repr[batch_arange, self.freq_indices]

        # fold the complex (real and imag) into the frequencies dimension

        x = rearrange(x, 'b f t c -> b t (f c)')

        x = self._maybe_checkpoint(self.band_split, x)

        # axial / hierarchical attention

        store = [None] * len(self.layers)
        for i, transformer_block in enumerate(self.layers):

            if len(transformer_block) == 3:
                linear_transformer, time_transformer, freq_transformer = transformer_block

                # linear attention runs over the flattened (time, band) axis
                x, ft_ps = pack([x], 'b * d')
                x = self._maybe_checkpoint(linear_transformer, x)
                x, = unpack(x, ft_ps, 'b * d')
            else:
                time_transformer, freq_transformer = transformer_block

            if self.skip_connection:
                # Sum all previous
                for j in range(i):
                    x = x + store[j]

            # attention along time, one sequence per (batch, band)
            x = rearrange(x, 'b t f d -> b f t d')
            x, ps = pack([x], '* t d')

            x = self._maybe_checkpoint(time_transformer, x)

            # attention along bands, one sequence per (batch, time)
            x, = unpack(x, ps, '* t d')
            x = rearrange(x, 'b f t d -> b t f d')
            x, ps = pack([x], '* f d')

            x = self._maybe_checkpoint(freq_transformer, x)

            x, = unpack(x, ps, '* f d')

            if self.skip_connection:
                store[i] = x

        if active_stem_ids is None:
            estimators = self.mask_estimators
            stem_ids = list(range(len(self.mask_estimators)))
        else:
            estimators = [self.mask_estimators[i] for i in active_stem_ids]
            stem_ids = active_stem_ids

        num_stems = len(estimators)

        masks = torch.stack([self._maybe_checkpoint(fn, x) for fn in estimators], dim=1)
        masks = rearrange(masks, 'b n t (f c) -> b n f t c', c=2)

        # modulate frequency representation

        stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')

        # complex number multiplication

        stft_repr = torch.view_as_complex(stft_repr)
        masks = torch.view_as_complex(masks)

        masks = masks.type(stft_repr.dtype)

        # need to average the estimated mask for the overlapped frequencies

        scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1])

        stft_repr_expanded_stems = repeat(stft_repr, 'b 1 ... -> b n ...', n=num_stems)
        masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(2, scatter_indices, masks)

        denom = repeat(self.num_bands_per_freq, 'f -> (f r) 1', r=channels)

        masks_averaged = masks_summed / denom.clamp(min=1e-8)

        # modulate stft repr with estimated mask

        stft_repr = stft_repr * masks_averaged

        # istft

        stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)

        if self.zero_dc:
            # whether to dc filter
            stft_repr = stft_repr.index_fill(1, tensor(0, device=device), 0.)

        recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False,
                                  length=istft_length)

        recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', b=batch, s=self.audio_channels, n=num_stems)

        if num_stems == 1:
            recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')

        # if a target is passed in, calculate loss for learning

        if not exists(target):
            return recon_audio

        if self.num_stems > 1:
            assert target.ndim == 4 and target.shape[1] == self.num_stems

        if target.ndim == 2:
            target = rearrange(target, '... t -> ... 1 t')

        target = target[..., :recon_audio.shape[-1]]  # protect against lost length on istft

        # Only a 4-d target has a stem axis to select from.  When a single stem
        # is produced recon_audio has already lost its stem axis, so the
        # selected target must lose it too; a 3-d target is used as is.
        if target.ndim == 4:
            target_sel = target[:, stem_ids]
            if num_stems == 1:
                target_sel = rearrange(target_sel, 'b 1 s t -> b s t')
        else:
            target_sel = target

        loss = F.l1_loss(recon_audio, target_sel)

        multi_stft_resolution_loss = 0.

        for window_size in self.multi_stft_resolutions_window_sizes:
            res_stft_kwargs = dict(
                n_fft=max(window_size, self.multi_stft_n_fft),  # not sure what n_fft is across multi resolution stft
                win_length=window_size,
                return_complex=True,
                window=self.multi_stft_window_fn(window_size, device=device),
                **self.multi_stft_kwargs,
            )

            # '...' keeps this valid with and without a stem axis
            recon_Y = torch.stft(
                rearrange(recon_audio, '... s t -> (... s) t'),
                **res_stft_kwargs
            )
            target_Y = torch.stft(
                rearrange(target_sel, '... s t -> (... s) t'),
                **res_stft_kwargs
            )

            multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)

        weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight

        total_loss = loss + weighted_multi_resolution_loss

        if not return_loss_breakdown:
            return total_loss

        return total_loss, (loss, multi_stft_resolution_loss)
 
1
+ from functools import partial
2
+
3
+ import torch
4
+ from torch import nn, einsum, tensor, Tensor
5
+ from torch.nn import Module, ModuleList
6
+ import torch.nn.functional as F
7
+
8
+ from .attend import Attend
9
+
10
+ from torch.utils.checkpoint import checkpoint
11
+
12
+ from beartype.typing import Tuple, Optional, List, Callable
13
+ from beartype import beartype
14
+
15
+ from rotary_embedding_torch import RotaryEmbedding
16
+
17
+ from einops import rearrange, pack, unpack, reduce, repeat
18
+ from einops.layers.torch import Rearrange
19
+
20
+ from librosa import filters
21
+
22
+ try:
23
+ from .pope.attention import flash_attn_with_pope
24
+ from .pope.pope import PoPE
25
+ _HAS_POPE = True
26
+ except Exception:
27
+ PoPE = None
28
+ flash_attn_with_pope = None
29
+ _HAS_POPE = False
30
+
31
+ # helper functions
32
+
33
def exists(val):
    """Return True when `val` is not None."""
    return val is not None
35
+
36
+
37
def default(v, d):
    """Return `v`, or the fallback `d` when `v` is None."""
    return v if v is not None else d
39
+
40
+
41
def pack_one(t, pattern):
    """Pack the single tensor `t` with einops `pack`.

    Returns the result of `pack([t], pattern)` unchanged, i.e. the packed
    tensor together with the packed-shape info that `unpack` needs later.
    """
    return pack([t], pattern)
43
+
44
+
45
def unpack_one(t, ps, pattern):
    """Undo a single-tensor pack: return the first (only) tensor from einops `unpack`."""
    return unpack(t, ps, pattern)[0]
47
+
48
+
49
def pad_at_dim(t, pad, dim=-1, value=0.):
    """Pad tensor `t` along a single dimension.

    Args:
        t: tensor to pad.
        pad: `(left, right)` amounts for dimension `dim`.
        dim: dimension to pad; negative values count from the end.
        value: fill value passed to `F.pad`.
    """
    # F.pad takes pairs starting from the last dimension, so emit a (0, 0)
    # pair for every dimension to the right of `dim` first.
    dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = (0, 0) * dims_from_right
    return F.pad(t, (*zeros, *pad), value=value)
53
+
54
+
55
def l2norm(t):
    """L2-normalise `t` along its last dimension."""
    return F.normalize(t, dim=-1, p=2)
57
+
58
+
59
+ # norm
60
+
61
class RMSNorm(Module):
    """Normalisation layer: L2-normalise the last dimension, rescale by sqrt(dim) and a learned gain."""

    def __init__(self, dim):
        super().__init__()
        # constant factor sqrt(dim) applied after the unit-norm projection
        self.scale = dim ** 0.5
        # learned per-feature gain, initialised to ones
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return F.normalize(x, dim=-1) * self.scale * self.gamma
69
+
70
+
71
+ # attention
72
+
73
class FeedForward(Module):
    """Pre-normalised MLP: RMSNorm -> Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(
            self,
            dim,
            mult=4,
            dropout=0.
    ):
        """
        Args:
            dim: input and output feature size.
            mult: hidden width as a multiple of `dim` (truncated to int).
            dropout: dropout probability used after the activation and after the output layer.
        """
        super().__init__()
        dim_inner = int(dim * mult)
        self.net = nn.Sequential(
            RMSNorm(dim),
            nn.Linear(dim, dim_inner),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim_inner, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)
93
+
94
+
95
class Attention(Module):
    """Gated multi-head self-attention with an optional rotary or PoPE positional scheme.

    The input is RMS-normalised, projected to queries/keys/values, attended,
    multiplied by a per-head sigmoid gate computed from the normalised input,
    and projected back to `dim`.
    """

    def __init__(
            self,
            dim,
            heads=8,
            dim_head=64,
            dropout=0.,
            rotary_embed=None,
            flash=True,
            pope_embed=None,
    ):
        """
        Args:
            dim: model feature size.
            heads: number of attention heads.
            dim_head: feature size of each head.
            dropout: dropout probability for `Attend` and the output projection.
            rotary_embed: optional rotary embedding applied to queries and keys.
            flash: forwarded to `Attend`.
            pope_embed: optional PoPE embedding; mutually exclusive with `rotary_embed`.
        """
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        dim_inner = heads * dim_head

        self.rotary_embed = rotary_embed
        self.pope_embed = pope_embed
        assert not (self.rotary_embed is not None and self.pope_embed is not None), \
            "cannot have both rotary and pope embeddings"

        self.attend = Attend(flash=flash, dropout=dropout)

        self.norm = RMSNorm(dim)
        self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False)

        self.to_gates = nn.Linear(dim, heads)

        self.to_out = nn.Sequential(
            nn.Linear(dim_inner, dim, bias=False),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        x = self.norm(x)

        q, k, v = rearrange(self.to_qkv(x), 'b n (qkv h d) -> qkv b h n d', qkv=3, h=self.heads)

        if exists(self.pope_embed):
            # the PoPE path uses its own attention kernel instead of self.attend
            assert _HAS_POPE, "PoPE requested but PoPE_pytorch is not installed"
            out = flash_attn_with_pope(
                q, k, v,
                pos_emb=self.pope_embed(q.shape[-2]),
                softmax_scale=self.scale
            )
        else:
            if exists(self.rotary_embed):
                q = self.rotary_embed.rotate_queries_or_keys(q)
                k = self.rotary_embed.rotate_queries_or_keys(k)
            out = self.attend(q, k, v)

        # one sigmoid gate per head and position, broadcast over the head features
        gates = self.to_gates(x)
        out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid()

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
152
+
153
+
154
class LinearAttention(Module):
    """
    this flavor of linear attention proposed in https://arxiv.org/abs/2106.09681 by El-Nouby et al.

    Queries, keys and values are laid out as `b h d n`, so `Attend` operates
    with the head-feature axis in the sequence position.
    """

    @beartype
    def __init__(
            self,
            *,
            dim,
            dim_head=32,
            heads=8,
            scale=8,
            flash=False,
            dropout=0.
    ):
        """
        Args:
            dim: model feature size.
            dim_head: feature size of each head.
            heads: number of heads.
            scale: forwarded to `Attend`.
            flash: forwarded to `Attend`.
            dropout: forwarded to `Attend`.
        """
        super().__init__()
        dim_inner = dim_head * heads
        self.norm = RMSNorm(dim)

        self.to_qkv = nn.Sequential(
            nn.Linear(dim, dim_inner * 3, bias=False),
            Rearrange('b n (qkv h d) -> qkv b h d n', qkv=3, h=heads)
        )

        # learned per-head log-temperature; applied to the queries as exp(temperature)
        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

        self.attend = Attend(
            scale=scale,
            dropout=dropout,
            flash=flash
        )

        self.to_out = nn.Sequential(
            Rearrange('b h d n -> b n (h d)'),
            nn.Linear(dim_inner, dim, bias=False)
        )

    def forward(
            self,
            x
    ):
        x = self.norm(x)

        q, k, v = self.to_qkv(x)

        # unit-normalise along the last (n) axis before scaling the queries
        q, k = map(l2norm, (q, k))
        q = q * self.temperature.exp()

        out = self.attend(q, k, v)

        return self.to_out(out)
206
+
207
+
208
class Transformer(Module):
    """Stack of residual (attention, feed-forward) layers with an optional final RMSNorm."""

    def __init__(
            self,
            *,
            dim,
            depth,
            dim_head=64,
            heads=8,
            attn_dropout=0.,
            ff_dropout=0.,
            ff_mult=4,
            norm_output=True,
            rotary_embed=None,
            pope_embed=None,
            flash_attn=True,
            linear_attn=False,
    ):
        """
        Args:
            dim: model feature size.
            depth: number of (attention, feed-forward) layers.
            dim_head, heads: attention head size and count.
            attn_dropout, ff_dropout: dropout for the attention and feed-forward parts.
            ff_mult: hidden-width multiplier of the feed-forward part.
            norm_output: apply a final `RMSNorm` when True, identity otherwise.
            rotary_embed, pope_embed: positional schemes passed to `Attention`;
                not passed to `LinearAttention`.
            flash_attn: forwarded as `flash` to the attention module.
            linear_attn: use `LinearAttention` instead of `Attention`.
        """
        super().__init__()
        self.layers = ModuleList([])

        for _ in range(depth):
            if linear_attn:
                attn = LinearAttention(
                    dim=dim,
                    dim_head=dim_head,
                    heads=heads,
                    dropout=attn_dropout,
                    flash=flash_attn
                )
            else:
                attn = Attention(
                    dim=dim,
                    dim_head=dim_head,
                    heads=heads,
                    dropout=attn_dropout,
                    rotary_embed=rotary_embed,
                    pope_embed=pope_embed,
                    flash=flash_attn
                )

            self.layers.append(ModuleList([
                attn,
                FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
            ]))

        self.norm = RMSNorm(dim) if norm_output else nn.Identity()

    def forward(self, x):
        # both sub-layers are residual
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)
262
+
263
+
264
+ # bandsplit module
265
+
266
class BandSplit(Module):
    """Split the last dimension into bands and project each band to `dim` features."""

    @beartype
    def __init__(
            self,
            dim,
            dim_inputs: Tuple[int, ...]
    ):
        """
        Args:
            dim: output feature size of every band.
            dim_inputs: width of each band along the last input dimension.
        """
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_features = ModuleList([])

        # one independent RMSNorm + Linear projection per band
        for dim_in in dim_inputs:
            net = nn.Sequential(
                RMSNorm(dim_in),
                nn.Linear(dim_in, dim)
            )

            self.to_features.append(net)

    def forward(self, x):
        bands = x.split(self.dim_inputs, dim=-1)

        outs = [to_feature(band) for band, to_feature in zip(bands, self.to_features)]

        # bands become a new second-to-last axis
        return torch.stack(outs, dim=-2)
294
+
295
+
296
def MLP(
        dim_in,
        dim_out,
        dim_hidden=None,
        depth=1,
        activation=nn.Tanh
):
    """Build an `nn.Sequential` of `depth + 1` linear layers with activations in between.

    Args:
        dim_in: input feature size.
        dim_out: output feature size.
        dim_hidden: width of the hidden layers; defaults to `dim_in`.
        depth: number of hidden layers.
        activation: module class instantiated after every linear layer except the last.
    """
    dim_hidden = default(dim_hidden, dim_in)

    dims = (dim_in, *((dim_hidden,) * depth), dim_out)
    last_ind = len(dims) - 2

    net = []
    for ind, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])):
        net.append(nn.Linear(layer_dim_in, layer_dim_out))

        # no activation after the output layer
        if ind != last_ind:
            net.append(activation())

    return nn.Sequential(*net)
319
+
320
+
321
class MaskEstimator(Module):
    """Map per-band features back to per-band outputs and concatenate them.

    Every band has its own MLP that emits `2 * dim_in` values, halved back to
    `dim_in` by a GLU.
    """

    @beartype
    def __init__(
            self,
            dim,
            dim_inputs: Tuple[int, ...],
            depth,
            mlp_expansion_factor=4
    ):
        """
        Args:
            dim: feature size of each incoming band.
            dim_inputs: output width of each band.
            depth: number of hidden layers in each band MLP.
            mlp_expansion_factor: hidden width of the MLPs as a multiple of `dim`.
        """
        super().__init__()
        self.dim_inputs = dim_inputs
        self.to_freqs = ModuleList([])
        dim_hidden = dim * mlp_expansion_factor

        for dim_in in dim_inputs:
            mlp = nn.Sequential(
                MLP(dim, dim_in * 2, dim_hidden=dim_hidden, depth=depth),
                nn.GLU(dim=-1)
            )

            self.to_freqs.append(mlp)

    def forward(self, x):
        # bands live on the second-to-last axis
        bands = x.unbind(dim=-2)

        outs = [mlp(band_features) for band_features, mlp in zip(bands, self.to_freqs)]

        return torch.cat(outs, dim=-1)
355
+
356
+
357
+ # main class
358
+
359
class MelBandRoformer(Module):
    """Mel-band transformer that predicts complex STFT masks for one or more stems.

    The audio is transformed with an STFT, its frequency bins are gathered into
    mel bands (taken from `librosa.filters.mel`, bands may overlap), every band
    is projected to `dim` features and processed by `depth` blocks of
    (optional linear-attention,) time-axis and frequency-axis transformers.
    One `MaskEstimator` per stem predicts a complex mask; masks are averaged
    over the bands that share a frequency, applied to the STFT and inverted.
    """

    @beartype
    def __init__(
            self,
            dim,
            *,
            depth,
            stereo=False,
            num_stems=1,
            time_transformer_depth=2,
            freq_transformer_depth=2,
            linear_transformer_depth=0,
            num_bands=60,
            dim_head=64,
            heads=8,
            attn_dropout=0.1,
            ff_dropout=0.1,
            flash_attn=True,
            dim_freqs_in=1025,
            sample_rate=44100,  # needed for mel filter bank from librosa
            stft_n_fft=2048,
            stft_hop_length=512,
            # 10ms at 44100Hz, from sections 4.1, 4.4 in the paper - @faroit recommends // 2 or // 4 for better reconstruction
            stft_win_length=2048,
            stft_normalized=False,
            stft_window_fn: Optional[Callable] = None,
            zero_dc=True,
            mask_estimator_depth=1,
            multi_stft_resolution_loss_weight=1.,
            multi_stft_resolutions_window_sizes: Tuple[int, ...] = (4096, 2048, 1024, 512, 256),
            multi_stft_hop_size=147,
            multi_stft_normalized=False,
            multi_stft_window_fn: Callable = torch.hann_window,
            match_input_audio_length=False,  # if True, pad output tensor to match length of input tensor
            mlp_expansion_factor=4,
            use_torch_checkpoint=False,
            skip_connection=False,
            use_pope: bool = False,
            **kwargs
    ):
        """Build the transformer stack, the mel band layout and the mask estimators.

        `dim_freqs_in` and any extra `**kwargs` are accepted but not used by
        this constructor.  `use_pope=True` swaps the rotary embeddings for PoPE
        embeddings and requires the optional PoPE import to have succeeded.
        """
        super().__init__()

        self.stereo = stereo
        self.audio_channels = 2 if stereo else 1
        self.num_stems = num_stems
        self.use_torch_checkpoint = use_torch_checkpoint
        self.skip_connection = skip_connection

        self.layers = ModuleList([])

        transformer_kwargs = dict(
            dim=dim,
            heads=heads,
            dim_head=dim_head,
            attn_dropout=attn_dropout,
            ff_dropout=ff_dropout,
            flash_attn=flash_attn,
        )

        # exactly one positional scheme is active; the other pair stays None
        if use_pope:
            assert _HAS_POPE, "PoPE requested but PoPE_pytorch is not installed"
            time_pope_embed = PoPE(dim=dim_head, heads=heads)
            freq_pope_embed = PoPE(dim=dim_head, heads=heads)
            time_rotary_embed = None
            freq_rotary_embed = None
        else:
            time_rotary_embed = RotaryEmbedding(dim=dim_head)
            freq_rotary_embed = RotaryEmbedding(dim=dim_head)
            time_pope_embed = freq_pope_embed = None

        for _ in range(depth):
            tran_modules = []
            if linear_transformer_depth > 0:
                tran_modules.append(
                    Transformer(depth=linear_transformer_depth, linear_attn=True, **transformer_kwargs)
                )
            tran_modules.append(
                Transformer(
                    depth=time_transformer_depth,
                    rotary_embed=time_rotary_embed,
                    pope_embed=time_pope_embed,
                    **transformer_kwargs
                )
            )
            tran_modules.append(
                Transformer(
                    depth=freq_transformer_depth,
                    rotary_embed=freq_rotary_embed,
                    pope_embed=freq_pope_embed,
                    **transformer_kwargs
                )
            )
            self.layers.append(nn.ModuleList(tran_modules))

        self.stft_window_fn = partial(default(stft_window_fn, torch.hann_window), stft_win_length)

        self.stft_kwargs = dict(
            n_fft=stft_n_fft,
            hop_length=stft_hop_length,
            win_length=stft_win_length,
            normalized=stft_normalized
        )

        # Run a dummy STFT only to learn the number of frequency bins.
        # torch.stft requires the window to be exactly `win_length` long, so the
        # dummy window is sized by stft_win_length (not stft_n_fft).
        freqs = torch.stft(
            torch.randn(1, 4096),
            **self.stft_kwargs,
            window=torch.ones(stft_win_length),
            return_complex=True
        ).shape[1]

        # create mel filter bank
        # with librosa.filters.mel as in section 2 of paper

        mel_filter_bank_numpy = filters.mel(sr=sample_rate, n_fft=stft_n_fft, n_mels=num_bands)

        mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy)

        # for some reason, it doesn't include the first freq? just force a value for now

        mel_filter_bank[0][0] = 1.

        # In some systems/envs we get 0.0 instead of ~1.9e-18 in the last position,
        # so let's force a positive value

        mel_filter_bank[-1, -1] = 1.

        # binary as in paper (then estimated masks are averaged for overlapping regions)

        freqs_per_band = mel_filter_bank > 0
        assert freqs_per_band.any(dim=0).all(), 'all frequencies need to be covered by all bands for now'

        repeated_freq_indices = repeat(torch.arange(freqs), 'f -> b f', b=num_bands)
        freq_indices = repeated_freq_indices[freqs_per_band]

        if stereo:
            # channels are interleaved on the frequency axis, so bin f maps to 2f and 2f + 1
            freq_indices = repeat(freq_indices, 'f -> f s', s=2)
            freq_indices = freq_indices * 2 + torch.arange(2)
            freq_indices = rearrange(freq_indices, 'f s -> (f s)')

        self.register_buffer('freq_indices', freq_indices, persistent=False)
        self.register_buffer('freqs_per_band', freqs_per_band, persistent=False)

        num_freqs_per_band = reduce(freqs_per_band, 'b f -> b', 'sum')
        num_bands_per_freq = reduce(freqs_per_band, 'b f -> f', 'sum')

        self.register_buffer('num_freqs_per_band', num_freqs_per_band, persistent=False)
        self.register_buffer('num_bands_per_freq', num_bands_per_freq, persistent=False)

        # band split and mask estimator

        freqs_per_bands_with_complex = tuple(2 * f * self.audio_channels for f in num_freqs_per_band.tolist())

        self.band_split = BandSplit(
            dim=dim,
            dim_inputs=freqs_per_bands_with_complex
        )

        self.mask_estimators = nn.ModuleList([])

        for _ in range(num_stems):
            mask_estimator = MaskEstimator(
                dim=dim,
                dim_inputs=freqs_per_bands_with_complex,
                depth=mask_estimator_depth,
                mlp_expansion_factor=mlp_expansion_factor,
            )

            self.mask_estimators.append(mask_estimator)

        # whether to zero out dc

        self.zero_dc = zero_dc

        # for the multi-resolution stft loss

        self.multi_stft_resolution_loss_weight = multi_stft_resolution_loss_weight
        self.multi_stft_resolutions_window_sizes = multi_stft_resolutions_window_sizes
        self.multi_stft_n_fft = stft_n_fft
        self.multi_stft_window_fn = multi_stft_window_fn

        self.multi_stft_kwargs = dict(
            hop_length=multi_stft_hop_size,
            normalized=multi_stft_normalized
        )

        self.match_input_audio_length = match_input_audio_length

    def _maybe_checkpoint(self, fn, x):
        """Call `fn(x)`, routed through `torch.utils.checkpoint` when `use_torch_checkpoint` is set."""
        if self.use_torch_checkpoint:
            return checkpoint(fn, x, use_reentrant=False)
        return fn(x)

    def forward(
            self,
            raw_audio,
            target=None,
            active_stem_ids=None,
            return_loss_breakdown=False
    ):
        """Separate `raw_audio`, or compute the training loss when `target` is given.

        einops

        b - batch
        f - freq
        t - time
        s - audio channel (1 for mono, 2 for stereo)
        n - number of 'stems'
        c - complex (2)
        d - feature dimension

        Args:
            raw_audio: `(b, t)` or `(b, s, t)` waveform.
            target: optional reference audio; when omitted the separated audio is returned.
            active_stem_ids: optional indices into `self.mask_estimators`; only
                those estimators are run. All estimators run when None.
            return_loss_breakdown: also return `(l1_loss, multi_stft_resolution_loss)`.

        Returns:
            Separated audio `(b, n, s, t)` (or `(b, s, t)` when a single stem is
            produced) if `target` is None; otherwise the total loss, optionally
            followed by the loss breakdown tuple.
        """

        device = raw_audio.device

        if raw_audio.ndim == 2:
            raw_audio = rearrange(raw_audio, 'b t -> b 1 t')

        batch, channels, raw_audio_length = raw_audio.shape

        istft_length = raw_audio_length if self.match_input_audio_length else None

        assert (not self.stereo and channels == 1) or (
                self.stereo and channels == 2), 'stereo needs to be set to True if passing in audio signal that is stereo (channel dimension of 2). also need to be False if mono (channel dimension of 1)'

        # to stft

        raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')

        stft_window = self.stft_window_fn(device=device)

        stft_repr = torch.stft(raw_audio, **self.stft_kwargs, window=stft_window, return_complex=True)
        stft_repr = torch.view_as_real(stft_repr)

        stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')

        # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
        stft_repr = rearrange(stft_repr, 'b s f t c -> b (f s) t c')

        # index out all frequencies for all frequency ranges across bands ascending in one go

        batch_arange = torch.arange(batch, device=device)[..., None]

        # account for stereo

        x = stft_repr[batch_arange, self.freq_indices]

        # fold the complex (real and imag) into the frequencies dimension

        x = rearrange(x, 'b f t c -> b t (f c)')

        x = self._maybe_checkpoint(self.band_split, x)

        # axial / hierarchical attention

        store = [None] * len(self.layers)
        for i, transformer_block in enumerate(self.layers):

            if len(transformer_block) == 3:
                linear_transformer, time_transformer, freq_transformer = transformer_block

                # linear attention runs over the flattened (time, band) axis
                x, ft_ps = pack([x], 'b * d')
                x = self._maybe_checkpoint(linear_transformer, x)
                x, = unpack(x, ft_ps, 'b * d')
            else:
                time_transformer, freq_transformer = transformer_block

            if self.skip_connection:
                # Sum all previous
                for j in range(i):
                    x = x + store[j]

            # attention along time, one sequence per (batch, band)
            x = rearrange(x, 'b t f d -> b f t d')
            x, ps = pack([x], '* t d')

            x = self._maybe_checkpoint(time_transformer, x)

            # attention along bands, one sequence per (batch, time)
            x, = unpack(x, ps, '* t d')
            x = rearrange(x, 'b f t d -> b t f d')
            x, ps = pack([x], '* f d')

            x = self._maybe_checkpoint(freq_transformer, x)

            x, = unpack(x, ps, '* f d')

            if self.skip_connection:
                store[i] = x

        if active_stem_ids is None:
            estimators = self.mask_estimators
            stem_ids = list(range(len(self.mask_estimators)))
        else:
            estimators = [self.mask_estimators[i] for i in active_stem_ids]
            stem_ids = active_stem_ids

        num_stems = len(estimators)

        masks = torch.stack([self._maybe_checkpoint(fn, x) for fn in estimators], dim=1)
        masks = rearrange(masks, 'b n t (f c) -> b n f t c', c=2)

        # modulate frequency representation

        stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')

        # complex number multiplication

        stft_repr = torch.view_as_complex(stft_repr)
        masks = torch.view_as_complex(masks)

        masks = masks.type(stft_repr.dtype)

        # need to average the estimated mask for the overlapped frequencies

        scatter_indices = repeat(self.freq_indices, 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1])

        stft_repr_expanded_stems = repeat(stft_repr, 'b 1 ... -> b n ...', n=num_stems)
        masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(2, scatter_indices, masks)

        denom = repeat(self.num_bands_per_freq, 'f -> (f r) 1', r=channels)

        masks_averaged = masks_summed / denom.clamp(min=1e-8)

        # modulate stft repr with estimated mask

        stft_repr = stft_repr * masks_averaged

        # istft

        stft_repr = rearrange(stft_repr, 'b n (f s) t -> (b n s) f t', s=self.audio_channels)

        if self.zero_dc:
            # whether to dc filter
            stft_repr = stft_repr.index_fill(1, tensor(0, device=device), 0.)

        recon_audio = torch.istft(stft_repr, **self.stft_kwargs, window=stft_window, return_complex=False,
                                  length=istft_length)

        recon_audio = rearrange(recon_audio, '(b n s) t -> b n s t', b=batch, s=self.audio_channels, n=num_stems)

        if num_stems == 1:
            recon_audio = rearrange(recon_audio, 'b 1 s t -> b s t')

        # if a target is passed in, calculate loss for learning

        if not exists(target):
            return recon_audio

        if self.num_stems > 1:
            assert target.ndim == 4 and target.shape[1] == self.num_stems

        if target.ndim == 2:
            target = rearrange(target, '... t -> ... 1 t')

        target = target[..., :recon_audio.shape[-1]]  # protect against lost length on istft

        # Only a 4-d target has a stem axis to select from.  When a single stem
        # is produced recon_audio has already lost its stem axis, so the
        # selected target must lose it too; a 3-d target is used as is.
        if target.ndim == 4:
            target_sel = target[:, stem_ids]
            if num_stems == 1:
                target_sel = rearrange(target_sel, 'b 1 s t -> b s t')
        else:
            target_sel = target

        loss = F.l1_loss(recon_audio, target_sel)

        multi_stft_resolution_loss = 0.

        for window_size in self.multi_stft_resolutions_window_sizes:
            res_stft_kwargs = dict(
                n_fft=max(window_size, self.multi_stft_n_fft),  # not sure what n_fft is across multi resolution stft
                win_length=window_size,
                return_complex=True,
                window=self.multi_stft_window_fn(window_size, device=device),
                **self.multi_stft_kwargs,
            )

            # '...' keeps this valid with and without a stem axis
            recon_Y = torch.stft(
                rearrange(recon_audio, '... s t -> (... s) t'),
                **res_stft_kwargs
            )
            target_Y = torch.stft(
                rearrange(target_sel, '... s t -> (... s) t'),
                **res_stft_kwargs
            )

            multi_stft_resolution_loss = multi_stft_resolution_loss + F.l1_loss(recon_Y, target_Y)

        weighted_multi_resolution_loss = multi_stft_resolution_loss * self.multi_stft_resolution_loss_weight

        total_loss = loss + weighted_multi_resolution_loss

        if not return_loss_breakdown:
            return total_loss

        return total_loss, (loss, multi_stft_resolution_loss)