| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | """Various utilities.""" |
| |
|
| | from hashlib import sha256 |
| | from pathlib import Path |
| | import typing as tp |
| |
|
| | import torch |
| | import torchaudio |
| |
|
| |
|
| | def _linear_overlap_add(frames: tp.List[torch.Tensor], stride: int): |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | assert len(frames) |
| | device = frames[0].device |
| | dtype = frames[0].dtype |
| | shape = frames[0].shape[:-1] |
| | total_size = stride * (len(frames) - 1) + frames[-1].shape[-1] |
| |
|
| | frame_length = frames[0].shape[-1] |
| | t = torch.linspace(0, 1, frame_length + 2, device=device, dtype=dtype)[1: -1] |
| | weight = 0.5 - (t - 0.5).abs() |
| |
|
| | sum_weight = torch.zeros(total_size, device=device, dtype=dtype) |
| | out = torch.zeros(*shape, total_size, device=device, dtype=dtype) |
| | offset: int = 0 |
| |
|
| | for frame in frames: |
| | frame_length = frame.shape[-1] |
| | out[..., offset:offset + frame_length] += weight[:frame_length] * frame |
| | sum_weight[offset:offset + frame_length] += weight[:frame_length] |
| | offset += stride |
| | assert sum_weight.min() > 0 |
| | return out / sum_weight |
| |
|
| |
|
| | def _get_checkpoint_url(root_url: str, checkpoint: str): |
| | if not root_url.endswith('/'): |
| | root_url += '/' |
| | return root_url + checkpoint |
| |
|
| |
|
| | def _check_checksum(path: Path, checksum: str): |
| | sha = sha256() |
| | with open(path, 'rb') as file: |
| | while True: |
| | buf = file.read(2**20) |
| | if not buf: |
| | break |
| | sha.update(buf) |
| | actual_checksum = sha.hexdigest()[:len(checksum)] |
| | if actual_checksum != checksum: |
| | raise RuntimeError(f'Invalid checksum for file {path}, ' |
| | f'expected {checksum} but got {actual_checksum}') |
| |
|
| |
|
def convert_audio(wav: torch.Tensor, sr: int, target_sr: int, target_channels: int):
    """Convert `wav` to `target_channels` channels at sample rate `target_sr`.

    Args:
        wav: audio tensor shaped `[..., channels, length]` with 1 or 2 channels.
        sr: sample rate of `wav`.
        target_sr: sample rate of the returned audio.
        target_channels: number of channels of the returned audio.

    Returns:
        Tensor shaped `[..., target_channels, new_length]`. Downmixing to mono
        averages the channels; upmixing repeats the mono channel.

    Raises:
        AssertionError: if `wav` has fewer than 2 dimensions or is neither
            mono nor stereo.
        RuntimeError: if stereo input is asked to become more than 2 channels.
    """
    assert wav.dim() >= 2, "Audio tensor must have at least 2 dimensions"
    assert wav.shape[-2] in [1, 2], "Audio must be mono or stereo."
    *shape, channels, length = wav.shape
    if target_channels == 1:
        wav = wav.mean(-2, keepdim=True)
    elif target_channels == 2 or channels == 1:
        # Spell out the leading dims so batched mono input can be expanded
        # too; `expand(target_channels, -1)` only works for 2-D tensors.
        wav = wav.expand(*shape, target_channels, length)
    else:
        raise RuntimeError(f"Impossible to convert from {channels} to {target_channels}")
    if sr != target_sr:
        # Resampling to the same rate is a pass-through, so skip building
        # the transform in that case.
        wav = torchaudio.transforms.Resample(sr, target_sr)(wav)
    return wav
| |
|
| |
|
def save_audio(wav: torch.Tensor, path: tp.Union[Path, str],
               sample_rate: int, rescale: bool = False):
    """Write `wav` to `path` as 16-bit signed PCM, keeping samples within ±0.99.

    Args:
        wav: audio tensor handed to `torchaudio.save`.
        path: destination file; converted to `str` before saving.
        sample_rate: sample rate passed to `torchaudio.save`.
        rescale: if True, scale the whole signal down so its peak is at most
            0.99 (it is never amplified); if False, clamp each sample to
            `[-0.99, 0.99]` instead.
    """
    limit = 0.99
    if rescale:
        # The peak is only needed here; computing it unconditionally costs a
        # full reduction for nothing and raises on an empty tensor.
        peak = wav.abs().max()
        # Capping the gain at 1 means quiet audio is left untouched.
        wav = wav * min(limit / peak, 1)
    else:
        wav = wav.clamp(-limit, limit)
    torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)