ThreadAbort committed on
Commit
4f0353a
·
1 Parent(s): fa099e3

refactor: update audio loading and saving methods to use torchcodec

Browse files
music_dcae/music_dcae_pipeline.py CHANGED
@@ -45,7 +45,7 @@ class MusicDCAE(ModelMixin, ConfigMixin, FromOriginalModelMixin):
45
  self.shift_factor = -1.9091
46
 
47
  def load_audio(self, audio_path):
48
- audio, sr = torchaudio.load(audio_path)
49
  return audio, sr
50
 
51
  def forward_mel(self, audios):
@@ -121,7 +121,7 @@ class MusicDCAE(ModelMixin, ConfigMixin, FromOriginalModelMixin):
121
 
122
  if __name__ == "__main__":
123
 
124
- audio, sr = torchaudio.load("test.wav")
125
  audio_lengths = torch.tensor([audio.shape[1]])
126
  audios = audio.unsqueeze(0)
127
 
@@ -137,5 +137,5 @@ if __name__ == "__main__":
137
  print("latents shape: ", latents.shape)
138
  print("latent_lengths: ", latent_lengths)
139
  print("sr: ", sr)
140
- torchaudio.save("test_reconstructed.flac", pred_wavs[0], sr)
141
  print("test_reconstructed.flac")
 
45
  self.shift_factor = -1.9091
46
 
47
  def load_audio(self, audio_path):
48
+ audio, sr = torchaudio.load_with_torchcodec(audio_path)
49
  return audio, sr
50
 
51
  def forward_mel(self, audios):
 
121
 
122
  if __name__ == "__main__":
123
 
124
+ audio, sr = torchaudio.load_with_torchcodec("test.wav")
125
  audio_lengths = torch.tensor([audio.shape[1]])
126
  audios = audio.unsqueeze(0)
127
 
 
137
  print("latents shape: ", latents.shape)
138
  print("latent_lengths: ", latent_lengths)
139
  print("sr: ", sr)
140
+ torchaudio.save_with_torchcodec("test_reconstructed.flac", pred_wavs[0], sr)
141
  print("test_reconstructed.flac")
pipeline_ace_step.py CHANGED
@@ -36,7 +36,6 @@ from apg_guidance import (
36
  cfg_double_condition_forward,
37
  )
38
  import torchaudio
39
- import torio
40
 
41
 
42
  torch.backends.cudnn.benchmark = False
@@ -1428,12 +1427,11 @@ class ACEStepPipeline:
1428
  f"{base_path}/output_{time.strftime('%Y%m%d%H%M%S')}_{idx}.{format}"
1429
  )
1430
  target_wav = target_wav.float()
1431
- torchaudio.save(
1432
  output_path_flac,
1433
  target_wav,
1434
  sample_rate=sample_rate,
1435
- format=format,
1436
- compression=torio.io.CodecConfig(bit_rate=320000),
1437
  )
1438
  return output_path_flac
1439
 
 
36
  cfg_double_condition_forward,
37
  )
38
  import torchaudio
 
39
 
40
 
41
  torch.backends.cudnn.benchmark = False
 
1427
  f"{base_path}/output_{time.strftime('%Y%m%d%H%M%S')}_{idx}.{format}"
1428
  )
1429
  target_wav = target_wav.float()
1430
+ torchaudio.save_with_torchcodec(
1431
  output_path_flac,
1432
  target_wav,
1433
  sample_rate=sample_rate,
1434
+ compression=320000,
 
1435
  )
1436
  return output_path_flac
1437
 
requirements.txt CHANGED
@@ -10,6 +10,7 @@ pytorch_lightning==2.5.1
10
  soundfile==0.13.1
11
  torch==2.8.0
12
  torchaudio==2.8.0
 
13
  torchvision==0.23.0
14
  tqdm
15
  transformers==4.50.0
 
10
  soundfile==0.13.1
11
  torch==2.8.0
12
  torchaudio==2.8.0
13
+ torchcodec>=0.2
14
  torchvision==0.23.0
15
  tqdm
16
  transformers==4.50.0