Spaces:
Build error
Build error
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """Various utilities.""" | |
| from hashlib import sha256 | |
| from pathlib import Path | |
| import typing as tp | |
| import torch | |
| import torchaudio | |
def _linear_overlap_add(frames: tp.List[torch.Tensor], stride: int):
    """Stitch overlapping frames back together with a triangular cross-fade.

    Frame ``i`` is placed at offset ``i * stride`` along the last axis.
    Overlapping contributions are blended using a triangular weight window,
    and the sum is divided by the accumulated weight at each position.

    Args:
        frames: non-empty list of tensors sharing all leading dimensions. The
            window length is taken from ``frames[0]``; later frames may be
            shorter and then use a prefix of the window.
        stride: hop, in samples, between the starts of consecutive frames.

    Returns:
        A tensor of shape ``(*frames[0].shape[:-1], total_size)``, where
        ``total_size = stride * (len(frames) - 1) + frames[-1].shape[-1]``.
    """
    # How the weighting works: each frame is multiplied by a triangle that
    # peaks at its middle, and the result is divided by the summed weights.
    # - A position covered by a single frame is left unchanged.
    # - Where two frames overlap:
    #        ...   ...
    #       /   \ /   \
    #      /    / \    \
    #          S   T      (S: start of the second frame, T: end of the first)
    #   the weights at offset t are (t - S) and (T - t); after normalisation
    #   the second frame gets (t - S) / (T - S), i.e. a linear cross-fade.
    # - With more than two overlapping frames we rely on the same
    #   normalisation producing something sensible.
    assert len(frames)
    first = frames[0]
    device, dtype = first.device, first.dtype
    batch_shape = first.shape[:-1]
    window_length = first.shape[-1]
    total_size = stride * (len(frames) - 1) + frames[-1].shape[-1]

    # Sample the ramp strictly inside (0, 1) by dropping both endpoints,
    # so every weight is positive.
    grid = torch.linspace(0, 1, window_length + 2, device=device, dtype=dtype)[1:-1]
    weight = 0.5 - (grid - 0.5).abs()

    out = torch.zeros(*batch_shape, total_size, device=device, dtype=dtype)
    sum_weight = torch.zeros(total_size, device=device, dtype=dtype)
    for index, frame in enumerate(frames):
        start = index * stride
        size = frame.shape[-1]
        end = start + size
        window = weight[:size]
        out[..., start:end] += window * frame
        sum_weight[start:end] += window

    # Every output position must be covered by at least one frame.
    assert sum_weight.min() > 0
    return out / sum_weight
def _get_checkpoint_url(root_url: str, checkpoint: str):
    """Append ``checkpoint`` to ``root_url``.

    A '/' separator is inserted unless ``root_url`` already ends with one.
    """
    separator = '' if root_url.endswith('/') else '/'
    return root_url + separator + checkpoint
def _check_checksum(path: Path, checksum: str):
    """Check that the SHA-256 hex digest of the file at ``path`` starts with ``checksum``.

    The file is hashed in 1 MiB chunks, so large checkpoints are never fully
    loaded into memory.

    Raises:
        RuntimeError: if the digest prefix differs from ``checksum``.
    """
    hasher = sha256()
    chunk_size = 2**20
    with open(path, 'rb') as stream:
        # ``read`` returns b'' at end of file, which ends the iteration.
        for chunk in iter(lambda: stream.read(chunk_size), b''):
            hasher.update(chunk)
    # Only the first len(checksum) hex characters are compared.
    # NOTE(review): this means an empty ``checksum`` accepts any file.
    prefix = hasher.hexdigest()[:len(checksum)]
    if prefix == checksum:
        return
    raise RuntimeError(f'Invalid checksum for file {path}, '
                       f'expected {checksum} but got {prefix}')
def convert_audio(wav: torch.Tensor, sr: int, target_sr: int, target_channels: int):
    """Convert ``wav`` to ``target_channels`` channels, then resample from ``sr`` to ``target_sr``.

    Args:
        wav: audio tensor of shape ``(*batch, channels, length)``, where
            ``channels`` must be 1 (mono) or 2 (stereo).
        sr: sample rate of ``wav``.
        target_sr: desired output sample rate.
        target_channels: desired channel count.
            - 1 averages the channels.
            - 2 duplicates a mono signal; stereo input is left as is.
            - Any other count is only supported for mono input, which is
              duplicated.

    Returns:
        Tensor of shape ``(*batch, target_channels, new_length)``. Channel
        duplication uses ``Tensor.expand``. When no resampling happens, the
        result may therefore be a view sharing memory with ``wav``.

    Raises:
        RuntimeError: if stereo input is asked for a channel count other
            than 1 or 2.
    """
    assert wav.dim() >= 2, "Audio tensor must have at least 2 dimensions"
    assert wav.shape[-2] in [1, 2], "Audio must be mono or stereo."
    *shape, channels, length = wav.shape
    if target_channels == 1:
        wav = wav.mean(-2, keepdim=True)
    elif target_channels == 2 or channels == 1:
        # Broadcast only the channel axis and keep any leading batch
        # dimensions. Stereo -> 2 channels is a no-op expand.
        wav = wav.expand(*shape, target_channels, length)
    else:
        raise RuntimeError(f"Impossible to convert from {channels} to {target_channels}")
    # Resampling between equal rates is an identity, so skip building the
    # resampler in that case.
    if sr != target_sr:
        wav = torchaudio.transforms.Resample(sr, target_sr)(wav)
    return wav
def save_audio(wav: torch.Tensor, path: tp.Union[Path, str],
               sample_rate: int, rescale: bool = False):
    """Write ``wav`` to ``path`` as 16-bit signed PCM at ``sample_rate``.

    Before writing, the signal is brought into [-0.99, 0.99]:
    - With ``rescale``, the whole waveform is scaled down (never up) so its
      peak magnitude is at most 0.99.
    - Otherwise, samples outside that range are clipped.

    ``wav`` is passed straight to ``torchaudio.save``. NOTE(review): with its
    default layout that presumably means ``(channels, length)``; confirm
    against callers.
    """
    ceiling = 0.99
    peak = wav.abs().max()
    wav = wav * min(ceiling / peak, 1) if rescale else wav.clamp(-ceiling, ceiling)
    torchaudio.save(str(path), wav, sample_rate=sample_rate, encoding='PCM_S', bits_per_sample=16)