| | |
| | |
| | |
| | |
| | |
| |
|
| | from itertools import product |
| | import random |
| |
|
| | import numpy as np |
| | import torch |
| | import torchaudio |
| |
|
| | from audiocraft.data.audio import audio_info, audio_read, audio_write, _av_read |
| |
|
| | from ..common_utils import TempDirMixin, get_white_noise, save_wav |
| |
|
| |
|
class TestInfo(TempDirMixin):
    """Check that `audio_info` reports the metadata of freshly written audio files."""

    def test_info_mp3(self):
        # Same check as the other formats, except the duration is not compared
        # (the mp3 test never did; lossy-codec padding may make it inexact — unverified).
        self._test_info_format('.mp3', check_duration=False)

    def _test_info_format(self, ext: str, check_duration: bool = True):
        """Write white noise to a `sample_wav{ext}` temp file and check `audio_info` on it.

        Args:
            ext: File extension, including the leading dot, appended to the temp
                file name (presumably `save_wav` picks the format from it).
            check_duration: When True, also require the reported duration to match
                the written duration to within 1e-5 seconds.
        """
        sample_rates = [8000, 16_000]
        channels = [1, 2]
        duration = 1.
        for sample_rate, ch in product(sample_rates, channels):
            n_frames = int(sample_rate * duration)
            wav = get_white_noise(ch, n_frames)
            path = self.get_temp_path(f'sample_wav{ext}')
            save_wav(path, wav, sample_rate)
            info = audio_info(path)
            assert info.sample_rate == sample_rate
            assert info.channels == ch
            if check_duration:
                assert np.isclose(info.duration, duration, atol=1e-5)

    def test_info_wav(self):
        self._test_info_format('.wav')

    def test_info_flac(self):
        self._test_info_format('.flac')

    def test_info_ogg(self):
        self._test_info_format('.ogg')

    def test_info_m4a(self):
        # TODO: no m4a fixture is generated yet, so there is nothing to check.
        # Report the test as skipped rather than letting an empty body pass silently
        # (pytest treats unittest.SkipTest as a skip).
        import unittest
        raise unittest.SkipTest('m4a info test not implemented: no m4a fixture is generated')
| |
|
| |
|
class TestRead(TempDirMixin):
    """Check `audio_read` on wav files: full reads, partial reads, seeking and padding."""

    # Tolerances for comparing read-back samples against the written ones
    # (presumably sized to absorb the quantization of the on-disk wav format).
    RTOL = 1e-03
    ATOL = 1e-04
    # Fixed seed so the random read durations / seek times are reproducible when a
    # test fails (consistent with the explicit seeding in TestAudioWrite).
    SEED = 1234

    def _rng(self) -> torch.Generator:
        """Return a freshly seeded local generator for drawing times in [0, 1)."""
        return torch.Generator().manual_seed(self.SEED)

    def _write_noise(self, sample_rate: int, ch: int, n_frames: int):
        """Write clamped white noise to the shared temp wav path; return ``(path, wav)``."""
        # Clamp strictly inside [-1, 1] so the samples stay representable on disk.
        wav = get_white_noise(ch, n_frames).clamp(-0.99, 0.99)
        path = self.get_temp_path('sample_wav.wav')
        save_wav(path, wav, sample_rate)
        return path, wav

    def test_read_full_wav(self):
        sample_rates = [8000, 16_000]
        channels = [1, 2]
        duration = 1.
        for sample_rate, ch in product(sample_rates, channels):
            n_frames = int(sample_rate * duration)
            path, wav = self._write_noise(sample_rate, ch, n_frames)
            read_wav, read_sr = audio_read(path)
            assert read_sr == sample_rate
            assert read_wav.shape[0] == wav.shape[0]
            assert read_wav.shape[1] == wav.shape[1]
            assert torch.allclose(read_wav, wav, rtol=self.RTOL, atol=self.ATOL)

    def test_read_partial_wav(self):
        sample_rates = [8000, 16_000]
        channels = [1, 2]
        duration = 1.
        # One random read duration in [0, 1) shared by every combination.
        read_duration = torch.rand(1, generator=self._rng()).item()
        for sample_rate, ch in product(sample_rates, channels):
            n_frames = int(sample_rate * duration)
            read_frames = int(sample_rate * read_duration)
            path, wav = self._write_noise(sample_rate, ch, n_frames)
            read_wav, read_sr = audio_read(path, 0, read_duration)
            assert read_sr == sample_rate
            assert read_wav.shape[0] == wav.shape[0]
            assert read_wav.shape[1] == read_frames
            assert torch.allclose(read_wav[..., 0:read_frames], wav[..., 0:read_frames],
                                  rtol=self.RTOL, atol=self.ATOL)

    def test_read_seek_time_wav(self):
        sample_rates = [8000, 16_000]
        channels = [1, 2]
        duration = 1.
        read_duration = 1.
        gen = self._rng()
        for sample_rate, ch in product(sample_rates, channels):
            n_frames = int(sample_rate * duration)
            path, wav = self._write_noise(sample_rate, ch, n_frames)
            seek_time = torch.rand(1, generator=gen).item()
            read_wav, read_sr = audio_read(path, seek_time, read_duration)
            # Without padding, the read stops at end-of-file: only the frames after
            # the seek point come back.
            seek_frames = int(sample_rate * seek_time)
            expected_frames = n_frames - seek_frames
            assert read_sr == sample_rate
            assert read_wav.shape[0] == wav.shape[0]
            assert read_wav.shape[1] == expected_frames
            assert torch.allclose(read_wav, wav[..., seek_frames:], rtol=self.RTOL, atol=self.ATOL)

    def test_read_seek_time_wav_padded(self):
        sample_rates = [8000, 16_000]
        channels = [1, 2]
        duration = 1.
        read_duration = 1.
        gen = self._rng()
        for sample_rate, ch in product(sample_rates, channels):
            n_frames = int(sample_rate * duration)
            read_frames = int(sample_rate * read_duration)
            path, wav = self._write_noise(sample_rate, ch, n_frames)
            seek_time = torch.rand(1, generator=gen).item()
            seek_frames = int(sample_rate * seek_time)
            expected_frames = n_frames - seek_frames
            read_wav, read_sr = audio_read(path, seek_time, read_duration, pad=True)
            # With pad=True the output has the full requested length; the part past
            # end-of-file must be zeros.
            expected_pad_wav = torch.zeros(wav.shape[0], read_frames - expected_frames)
            assert read_sr == sample_rate
            assert read_wav.shape[0] == wav.shape[0]
            assert read_wav.shape[1] == read_frames
            assert torch.allclose(read_wav[..., :expected_frames], wav[..., seek_frames:],
                                  rtol=self.RTOL, atol=self.ATOL)
            assert torch.allclose(read_wav[..., expected_frames:], expected_pad_wav)
| |
|
| |
|
class TestAvRead(TempDirMixin):
    """Check the frame counts returned by `_av_read` for various seek/duration combinations."""

    # Fixed seed so the randomized seek times are reproducible when a test fails.
    SEED = 1234
    # Number of random (seek_time, duration) draws per sample-rate/channel combination.
    N_TRIALS = 100

    def test_avread_seek_base(self):
        """Reads that stay inside the file return exactly `duration * sample_rate` frames."""
        rng = random.Random(self.SEED)
        sample_rates = [8000, 16_000]
        channels = [1, 2]
        duration = 2.
        for sample_rate, ch in product(sample_rates, channels):
            n_frames = int(sample_rate * duration)
            wav = get_white_noise(ch, n_frames)
            path = self.get_temp_path(f'reference_a_{sample_rate}_{ch}.wav')
            save_wav(path, wav, sample_rate)
            for _ in range(self.N_TRIALS):
                # seek_time + seek_duration <= 2 s, so the read never reaches end-of-file.
                seek_time = rng.uniform(0.0, 1.0)
                seek_duration = rng.uniform(0.001, 1.0)
                read_wav, read_sr = _av_read(path, seek_time, seek_duration)
                assert read_sr == sample_rate
                assert read_wav.shape[0] == wav.shape[0]
                assert read_wav.shape[-1] == int(seek_duration * sample_rate)

    def test_avread_seek_partial(self):
        """Reads that run past end-of-file are truncated to the remaining frames."""
        rng = random.Random(self.SEED)
        sample_rates = [8000, 16_000]
        channels = [1, 2]
        duration = 1.
        for sample_rate, ch in product(sample_rates, channels):
            n_frames = int(sample_rate * duration)
            wav = get_white_noise(ch, n_frames)
            path = self.get_temp_path(f'reference_b_{sample_rate}_{ch}.wav')
            save_wav(path, wav, sample_rate)
            for _ in range(self.N_TRIALS):
                # Seeking into the second half of a 1 s file and asking for 1 s always
                # overruns the end, so only the tail after the seek point comes back.
                seek_time = rng.uniform(0.5, 1.)
                seek_duration = 1.
                expected_num_frames = n_frames - int(seek_time * sample_rate)
                read_wav, read_sr = _av_read(path, seek_time, seek_duration)
                assert read_sr == sample_rate
                assert read_wav.shape[0] == wav.shape[0]
                assert read_wav.shape[-1] == expected_num_frames

    def test_avread_seek_outofbound(self):
        """Seeking past the end of the file yields an empty (zero-frame) read."""
        sample_rates = [8000, 16_000]
        channels = [1, 2]
        duration = 1.
        for sample_rate, ch in product(sample_rates, channels):
            n_frames = int(sample_rate * duration)
            wav = get_white_noise(ch, n_frames)
            path = self.get_temp_path(f'reference_c_{sample_rate}_{ch}.wav')
            save_wav(path, wav, sample_rate)
            seek_time = 1.5
            read_wav, read_sr = _av_read(path, seek_time, 1.)
            assert read_sr == sample_rate
            assert read_wav.shape[0] == wav.shape[0]
            assert read_wav.shape[-1] == 0

    def test_avread_seek_edge(self):
        """Seeking to the last frame returns the frames from that point to the end."""
        sample_rates = [8000, 16_000]
        # Several neighbouring frame counts, so that the float round-trip
        # frames -> seconds -> frames below is exercised with different rounding.
        n_frames = [1000, 1001, 1002]
        channels = [1, 2]
        for sample_rate, ch, frames in product(sample_rates, channels, n_frames):
            duration = frames / sample_rate
            wav = get_white_noise(ch, frames)
            path = self.get_temp_path(f'reference_d_{sample_rate}_{ch}.wav')
            save_wav(path, wav, sample_rate)
            seek_time = (frames - 1) / sample_rate
            # Compute the expected seek offset the same (truncating) way as the time
            # was derived, rather than assuming it is exactly frames - 1.
            seek_frames = int(seek_time * sample_rate)
            read_wav, read_sr = _av_read(path, seek_time, duration)
            assert read_sr == sample_rate
            assert read_wav.shape[0] == wav.shape[0]
            assert read_wav.shape[-1] == (frames - seek_frames)
| |
|
| |
|
class TestAudioWrite(TempDirMixin):
    """Round-trip `audio_write` output through `torchaudio.load` for several formats/strategies."""

    def test_audio_write_wav(self):
        torch.manual_seed(1234)
        sample_rates = [8000, 16_000]
        n_frames = [1000, 1001, 1002]
        channels = [1, 2]
        strategies = ["peak", "clip", "rms"]
        formats = ["wav", "mp3"]
        for sample_rate, ch, frames in product(sample_rates, channels, n_frames):
            for format_, strategy in product(formats, strategies):
                wav = get_white_noise(ch, frames)
                # `audio_write` receives the path without an extension; the written
                # file is read back from `{path}.{format_}`.
                path = self.get_temp_path(f'pred_{sample_rate}_{ch}')
                audio_write(path, wav, sample_rate, format_, strategy=strategy)
                read_wav, read_sr = torchaudio.load(f'{path}.{format_}')
                if format_ == "wav":
                    # For wav, the sample rate and frame count are expected to
                    # round-trip unchanged.
                    assert read_sr == sample_rate
                    assert read_wav.shape == wav.shape

                if format_ == "wav" and strategy in ["peak", "rms"]:
                    # These strategies rescale the signal before writing, so undo the
                    # gain by matching peak amplitudes before comparing.
                    rescaled_read_wav = read_wav / read_wav.abs().max() * wav.abs().max()
                    # The remaining error should be about one quantization step
                    # (1 / 2**15, presumably 16-bit PCM), amplified by the rescaling:
                    # empirically ~5x for peak normalization of Gaussian noise and
                    # ~20x for rms, which typically leaves more headroom.
                    atol = (5 if strategy == "peak" else 20) / 2**15
                    delta = (rescaled_read_wav - wav).abs().max()
                    assert torch.allclose(wav, rescaled_read_wav, rtol=0, atol=atol), (delta, atol)
| |
|