javier233455 committed on
Commit
865175e
·
verified ·
1 Parent(s): 62f44c9

Update music_dcae/music_dcae_pipeline.py

Browse files
Files changed (1) hide show
  1. music_dcae/music_dcae_pipeline.py +133 -20
music_dcae/music_dcae_pipeline.py CHANGED
@@ -21,7 +21,12 @@ VOCODER_PRETRAINED_PATH = os.path.join(root_dir, "checkpoints", "music_vocoder")
21
 
22
  class MusicDCAE(ModelMixin, ConfigMixin, FromOriginalModelMixin):
23
  @register_to_config
24
- def __init__(self, source_sample_rate=None, dcae_checkpoint_path=DEFAULT_PRETRAINED_PATH, vocoder_checkpoint_path=VOCODER_PRETRAINED_PATH):
 
 
 
 
 
25
  super(MusicDCAE, self).__init__()
26
 
27
  self.dcae = AutoencoderDC.from_pretrained(dcae_checkpoint_path)
@@ -35,6 +40,7 @@ class MusicDCAE(ModelMixin, ConfigMixin, FromOriginalModelMixin):
35
  self.transform = transforms.Compose([
36
  transforms.Normalize(0.5, 0.5),
37
  ])
 
38
  self.min_mel_value = -11.0
39
  self.max_mel_value = 3.0
40
  self.audio_chunk_size = int(round((1024 * 512 / 44100 * 48000)))
@@ -46,48 +52,128 @@ class MusicDCAE(ModelMixin, ConfigMixin, FromOriginalModelMixin):
46
 
47
  def load_audio(self, audio_path):
48
  audio, sr = torchaudio.load(audio_path)
 
 
 
 
 
 
 
 
 
 
49
  return audio, sr
50
 
51
  def forward_mel(self, audios):
52
  mels = []
 
53
  for i in range(len(audios)):
54
- image = self.vocoder.mel_transform(audios[i])
 
 
 
 
 
 
 
 
 
 
 
55
  mels.append(image)
 
56
  mels = torch.stack(mels)
57
  return mels
58
 
59
  @torch.no_grad()
60
  def encode(self, audios, audio_lengths=None, sr=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  if audio_lengths is None:
62
  audio_lengths = torch.tensor([audios.shape[2]] * audios.shape[0])
63
  audio_lengths = audio_lengths.to(audios.device)
64
 
65
- # audios: N x 2 x T, 48kHz
66
  device = audios.device
67
  dtype = audios.dtype
68
 
69
  if sr is None:
70
  sr = 48000
71
- resampler = self.resampler
72
  else:
73
  resampler = torchaudio.transforms.Resample(sr, 44100).to(device).to(dtype)
74
 
75
  audio = resampler(audios)
76
 
 
 
 
 
 
 
77
  max_audio_len = audio.shape[-1]
 
78
  if max_audio_len % (8 * 512) != 0:
79
- audio = torch.nn.functional.pad(audio, (0, 8 * 512 - max_audio_len % (8 * 512)))
 
 
 
80
 
81
  mels = self.forward_mel(audio)
82
- mels = (mels - self.min_mel_value) / (self.max_mel_value - self.min_mel_value)
 
 
 
83
  mels = self.transform(mels)
 
84
  latents = []
 
85
  for mel in mels:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  latent = self.dcae.encoder(mel.unsqueeze(0))
87
  latents.append(latent)
 
88
  latents = torch.cat(latents, dim=0)
89
- latent_lengths = (audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple).long()
 
 
 
 
90
  latents = (latents - self.shift_factor) * self.scale_factor
 
91
  return latents, latent_lengths
92
 
93
  @torch.no_grad()
@@ -99,43 +185,70 @@ class MusicDCAE(ModelMixin, ConfigMixin, FromOriginalModelMixin):
99
  for latent in latents:
100
  mels = self.dcae.decoder(latent.unsqueeze(0))
101
  mels = mels * 0.5 + 0.5
102
- mels = mels * (self.max_mel_value - self.min_mel_value) + self.min_mel_value
 
 
 
103
  wav = self.vocoder.decode(mels[0]).squeeze(1)
104
 
105
  if sr is not None:
106
- resampler = torchaudio.transforms.Resample(44100, sr).to(latents.device).to(latents.dtype)
 
 
 
107
  wav = resampler(wav)
108
  else:
109
  sr = 44100
 
110
  pred_wavs.append(wav)
111
 
112
  if audio_lengths is not None:
113
- pred_wavs = [wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)]
 
 
 
 
114
  return sr, pred_wavs
115
 
116
  def forward(self, audios, audio_lengths=None, sr=None):
117
- latents, latent_lengths = self.encode(audios=audios, audio_lengths=audio_lengths, sr=sr)
118
- sr, pred_wavs = self.decode(latents=latents, audio_lengths=audio_lengths, sr=sr)
 
 
 
 
 
 
 
 
 
 
119
  return sr, pred_wavs, latents, latent_lengths
120
 
121
 
122
  if __name__ == "__main__":
123
-
124
  audio, sr = torchaudio.load("test.wav")
 
 
 
 
 
 
 
 
 
 
125
  audio_lengths = torch.tensor([audio.shape[1]])
126
  audios = audio.unsqueeze(0)
127
-
128
- # test encode only
129
  model = MusicDCAE()
130
- # latents, latent_lengths = model.encode(audios, audio_lengths)
131
- # print("latents shape: ", latents.shape)
132
- # print("latent_lengths: ", latent_lengths)
133
 
134
- # test encode and decode
135
  sr, pred_wavs, latents, latent_lengths = model(audios, audio_lengths, sr)
 
136
  print("reconstructed wavs: ", pred_wavs[0].shape)
137
  print("latents shape: ", latents.shape)
138
  print("latent_lengths: ", latent_lengths)
139
  print("sr: ", sr)
 
140
  torchaudio.save("test_reconstructed.flac", pred_wavs[0], sr)
141
- print("test_reconstructed.flac")
 
21
 
22
  class MusicDCAE(ModelMixin, ConfigMixin, FromOriginalModelMixin):
23
  @register_to_config
24
+ def __init__(
25
+ self,
26
+ source_sample_rate=None,
27
+ dcae_checkpoint_path=DEFAULT_PRETRAINED_PATH,
28
+ vocoder_checkpoint_path=VOCODER_PRETRAINED_PATH
29
+ ):
30
  super(MusicDCAE, self).__init__()
31
 
32
  self.dcae = AutoencoderDC.from_pretrained(dcae_checkpoint_path)
 
40
  self.transform = transforms.Compose([
41
  transforms.Normalize(0.5, 0.5),
42
  ])
43
+
44
  self.min_mel_value = -11.0
45
  self.max_mel_value = 3.0
46
  self.audio_chunk_size = int(round((1024 * 512 / 44100 * 48000)))
 
52
 
53
  def load_audio(self, audio_path):
54
  audio, sr = torchaudio.load(audio_path)
55
+
56
+ # FIX: si el audio está en mono, duplicarlo a estéreo
57
+ if audio.dim() == 1:
58
+ audio = audio.unsqueeze(0)
59
+
60
+ if audio.shape[0] == 1:
61
+ audio = audio.repeat(2, 1)
62
+ elif audio.shape[0] > 2:
63
+ audio = audio[:2]
64
+
65
  return audio, sr
66
 
67
  def forward_mel(self, audios):
68
  mels = []
69
+
70
  for i in range(len(audios)):
71
+ audio_item = audios[i]
72
+
73
+ # FIX: asegurar audio estéreo antes de convertir a mel
74
+ if audio_item.dim() == 1:
75
+ audio_item = audio_item.unsqueeze(0)
76
+
77
+ if audio_item.shape[0] == 1:
78
+ audio_item = audio_item.repeat(2, 1)
79
+ elif audio_item.shape[0] > 2:
80
+ audio_item = audio_item[:2]
81
+
82
+ image = self.vocoder.mel_transform(audio_item)
83
  mels.append(image)
84
+
85
  mels = torch.stack(mels)
86
  return mels
87
 
88
  @torch.no_grad()
89
  def encode(self, audios, audio_lengths=None, sr=None):
90
+ # ============================================================
91
+ # FIX PRINCIPAL:
92
+ # ACE-Step / MusicDCAE espera audios con forma N x 2 x T.
93
+ # Si llega mono N x 1 x T, se duplica el canal.
94
+ # ============================================================
95
+
96
+ if audios.dim() == 1:
97
+ # T -> 1 x 1 x T
98
+ audios = audios.unsqueeze(0).unsqueeze(0)
99
+
100
+ elif audios.dim() == 2:
101
+ # Puede venir como C x T
102
+ audios = audios.unsqueeze(0)
103
+
104
+ if audios.shape[1] == 1:
105
+ # N x 1 x T -> N x 2 x T
106
+ audios = audios.repeat(1, 2, 1)
107
+
108
+ elif audios.shape[1] > 2:
109
+ # Si tiene más de 2 canales, usar solo los dos primeros
110
+ audios = audios[:, :2, :]
111
+
112
  if audio_lengths is None:
113
  audio_lengths = torch.tensor([audios.shape[2]] * audios.shape[0])
114
  audio_lengths = audio_lengths.to(audios.device)
115
 
116
+ # audios: N x 2 x T
117
  device = audios.device
118
  dtype = audios.dtype
119
 
120
  if sr is None:
121
  sr = 48000
122
+ resampler = self.resampler.to(device).to(dtype)
123
  else:
124
  resampler = torchaudio.transforms.Resample(sr, 44100).to(device).to(dtype)
125
 
126
  audio = resampler(audios)
127
 
128
+ # FIX extra después del resample
129
+ if audio.shape[1] == 1:
130
+ audio = audio.repeat(1, 2, 1)
131
+ elif audio.shape[1] > 2:
132
+ audio = audio[:, :2, :]
133
+
134
  max_audio_len = audio.shape[-1]
135
+
136
  if max_audio_len % (8 * 512) != 0:
137
+ audio = torch.nn.functional.pad(
138
+ audio,
139
+ (0, 8 * 512 - max_audio_len % (8 * 512))
140
+ )
141
 
142
  mels = self.forward_mel(audio)
143
+
144
+ mels = (mels - self.min_mel_value) / (
145
+ self.max_mel_value - self.min_mel_value
146
+ )
147
  mels = self.transform(mels)
148
+
149
  latents = []
150
+
151
  for mel in mels:
152
+ # ========================================================
153
+ # FIX FINAL:
154
+ # El encoder espera mel con 2 canales.
155
+ # Si mel viene como 1 x 128 x T, convertir a 2 x 128 x T.
156
+ # ========================================================
157
+
158
+ if mel.dim() == 2:
159
+ mel = mel.unsqueeze(0)
160
+
161
+ if mel.shape[0] == 1:
162
+ mel = mel.repeat(2, 1, 1)
163
+ elif mel.shape[0] > 2:
164
+ mel = mel[:2]
165
+
166
  latent = self.dcae.encoder(mel.unsqueeze(0))
167
  latents.append(latent)
168
+
169
  latents = torch.cat(latents, dim=0)
170
+
171
+ latent_lengths = (
172
+ audio_lengths / sr * 44100 / 512 / self.time_dimention_multiple
173
+ ).long()
174
+
175
  latents = (latents - self.shift_factor) * self.scale_factor
176
+
177
  return latents, latent_lengths
178
 
179
  @torch.no_grad()
 
185
  for latent in latents:
186
  mels = self.dcae.decoder(latent.unsqueeze(0))
187
  mels = mels * 0.5 + 0.5
188
+ mels = mels * (
189
+ self.max_mel_value - self.min_mel_value
190
+ ) + self.min_mel_value
191
+
192
  wav = self.vocoder.decode(mels[0]).squeeze(1)
193
 
194
  if sr is not None:
195
+ resampler = torchaudio.transforms.Resample(
196
+ 44100,
197
+ sr
198
+ ).to(latents.device).to(latents.dtype)
199
  wav = resampler(wav)
200
  else:
201
  sr = 44100
202
+
203
  pred_wavs.append(wav)
204
 
205
  if audio_lengths is not None:
206
+ pred_wavs = [
207
+ wav[:, :length].cpu()
208
+ for wav, length in zip(pred_wavs, audio_lengths)
209
+ ]
210
+
211
  return sr, pred_wavs
212
 
213
  def forward(self, audios, audio_lengths=None, sr=None):
214
+ latents, latent_lengths = self.encode(
215
+ audios=audios,
216
+ audio_lengths=audio_lengths,
217
+ sr=sr
218
+ )
219
+
220
+ sr, pred_wavs = self.decode(
221
+ latents=latents,
222
+ audio_lengths=audio_lengths,
223
+ sr=sr
224
+ )
225
+
226
  return sr, pred_wavs, latents, latent_lengths
227
 
228
 
229
  if __name__ == "__main__":
 
230
  audio, sr = torchaudio.load("test.wav")
231
+
232
+ # FIX para prueba local con audio mono
233
+ if audio.dim() == 1:
234
+ audio = audio.unsqueeze(0)
235
+
236
+ if audio.shape[0] == 1:
237
+ audio = audio.repeat(2, 1)
238
+ elif audio.shape[0] > 2:
239
+ audio = audio[:2]
240
+
241
  audio_lengths = torch.tensor([audio.shape[1]])
242
  audios = audio.unsqueeze(0)
243
+
 
244
  model = MusicDCAE()
 
 
 
245
 
 
246
  sr, pred_wavs, latents, latent_lengths = model(audios, audio_lengths, sr)
247
+
248
  print("reconstructed wavs: ", pred_wavs[0].shape)
249
  print("latents shape: ", latents.shape)
250
  print("latent_lengths: ", latent_lengths)
251
  print("sr: ", sr)
252
+
253
  torchaudio.save("test_reconstructed.flac", pred_wavs[0], sr)
254
+ print("test_reconstructed.flac")