"""Inference script for a BS-RoFormer piano-stem separation model.

Downloads the model config and safetensors checkpoint from Hugging Face on
first use, splits each input into overlapping 10-second chunks, runs the
separator, and writes the piano stem as ``<name>_Piano.wav`` (or ``.flac``).

Example invocation (script name and paths are illustrative):

    python inference.py --input_dir songs/ --output_dir stems/
"""

import argparse
import json
import shutil
import tempfile
from pathlib import Path
from urllib.error import HTTPError, URLError
from urllib.request import Request, urlopen
|
|
| import numpy as np |
| import soundfile as sf |
| import torch |
| import torch.nn.functional as F |
| from einops import pack, rearrange, unpack |
| from rotary_embedding_torch import RotaryEmbedding |
| from safetensors.torch import load_file |
| from torch import einsum, nn |
|
|
|
|
def pack_one(tensor, pattern):
    """Pack a single tensor with einops.pack, returning (packed, shape_info)."""
    return pack([tensor], pattern)
|
|
|
|
def unpack_one(tensor, packed_shape, pattern):
    """Invert pack_one, restoring the packed tensor's original leading dims."""
    return unpack(tensor, packed_shape, pattern)[0]
|
|
|
|
class Attend(nn.Module):
    """Plain (non-causal) scaled dot-product attention."""
|
|
| def forward(self, q, k, v): |
| scale = q.shape[-1] ** -0.5 |
| sim = einsum('b h i d, b h j d -> b h i j', q, k) * scale |
| attn = sim.softmax(dim=-1) |
| return einsum('b h i j, b h j d -> b h i d', attn, v) |
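

# Note: Attend's two einsums are standard scaled dot-product attention. On
# PyTorch >= 2.0, `F.scaled_dot_product_attention(q, k, v)` is an equivalent
# (and usually faster) drop-in, since its default scale is also
# 1 / sqrt(head_dim); the explicit form is kept here for readability.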
|
|
|
|
| class RMSNorm(nn.Module): |
| def __init__(self, dim): |
| super().__init__() |
| self.scale = dim ** 0.5 |
| self.gamma = nn.Parameter(torch.ones(dim)) |
|
|
| def forward(self, x): |
| return F.normalize(x, dim=-1) * self.scale * self.gamma |
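

# RMSNorm identity: F.normalize divides by the L2 norm, and
# ||x||_2 = rms(x) * sqrt(dim), so F.normalize(x, dim=-1) * dim ** 0.5
# equals x / rms(x); gamma then applies a learned per-channel gain.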
|
|
|
|
class FeedForward(nn.Module):
    def __init__(self, dim, ff_mult):
        super().__init__()
        dim_inner = int(dim * ff_mult)
        # The nn.Identity() placeholders are no-ops (most likely dropout at
        # training time); they are kept so the nn.Sequential indices, and with
        # them the checkpoint's state_dict keys (net.1.*, net.4.*), line up.
        self.net = nn.Sequential(
            RMSNorm(dim),
            nn.Linear(dim, dim_inner),
            nn.GELU(),
            nn.Identity(),
            nn.Linear(dim_inner, dim),
            nn.Identity(),
        )
|
|
| def forward(self, x): |
| return self.net(x) |
|
|
|
|
| class Attention(nn.Module): |
| def __init__(self, dim, heads, dim_head, rotary_embed): |
| super().__init__() |
| self.heads = heads |
| dim_inner = heads * dim_head |
| self.rotary_embed = rotary_embed |
| self.attend = Attend() |
| self.norm = RMSNorm(dim) |
| self.to_qkv = nn.Linear(dim, dim_inner * 3, bias=False) |
        # Per-head sigmoid gates on the attention output, computed from the
        # normalized input.
        self.to_gates = nn.Linear(dim, heads)
        # The trailing nn.Identity() keeps the checkpoint's to_out.0.* key
        # layout intact (likely dropout at training time).
        self.to_out = nn.Sequential(
            nn.Linear(dim_inner, dim, bias=False),
            nn.Identity(),
        )
|
|
| def forward(self, x): |
| x = self.norm(x) |
| q, k, v = rearrange( |
| self.to_qkv(x), |
| 'b n (qkv h d) -> qkv b h n d', |
| qkv=3, |
| h=self.heads, |
| ) |
|
|
        # Rotary position embeddings are applied to queries and keys only.
        q = self.rotary_embed.rotate_queries_or_keys(q)
        k = self.rotary_embed.rotate_queries_or_keys(k)
|
|
| out = self.attend(q, k, v) |
| gates = self.to_gates(x) |
| out = out * rearrange(gates, 'b n h -> b h n 1').sigmoid() |
| out = rearrange(out, 'b h n d -> b n (h d)') |
| return self.to_out(out) |
|
|
|
|
| class Transformer(nn.Module): |
| def __init__(self, depth, dim, heads, dim_head, ff_mult, rotary_embed): |
| super().__init__() |
| self.layers = nn.ModuleList([]) |
|
|
| for _ in range(depth): |
| self.layers.append( |
| nn.ModuleList( |
| [ |
| Attention( |
| dim=dim, |
| heads=heads, |
| dim_head=dim_head, |
| rotary_embed=rotary_embed, |
| ), |
| FeedForward(dim=dim, ff_mult=ff_mult), |
| ] |
| ) |
| ) |
|
|
    def forward(self, x):
        # Pre-norm residual blocks; each sublayer normalizes its own input.
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x
|
|
|
|
class BandSplit(nn.Module):
    """Project each frequency band (all channels, real+imag) to the model dim."""

    def __init__(self, dim_inputs, feature_dim):
| super().__init__() |
| self.dim_inputs = dim_inputs |
| self.to_features = nn.ModuleList( |
| [nn.Sequential(nn.Linear(dim_in, feature_dim)) for dim_in in dim_inputs] |
| ) |
|
|
| def forward(self, x): |
| splits = x.split(self.dim_inputs, dim=-1) |
| features = [ |
| to_feature(split_input) |
| for split_input, to_feature in zip(splits, self.to_features) |
| ] |
| return torch.stack(features, dim=-2) |
|
|
|
|
def MLP(dim_in, dim_out, dim_hidden, depth=1):
    """Stack of Linear layers with Tanh in between; depth=1 is a single Linear."""
    dims = (dim_in, *((dim_hidden,) * (depth - 1)), dim_out)
|
|
| layers = [] |
| for index, (layer_dim_in, layer_dim_out) in enumerate(zip(dims[:-1], dims[1:])): |
| is_last = index == len(dims) - 2 |
| layers.append(nn.Linear(layer_dim_in, layer_dim_out)) |
| if not is_last: |
| layers.append(nn.Tanh()) |
|
|
| return nn.Sequential(*layers) |
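

# Illustrative sizes: MLP(4, 8, dim_hidden=16, depth=2) builds
# nn.Sequential(Linear(4, 16), Tanh(), Linear(16, 8)); `depth` counts the
# Linear layers, with Tanh between consecutive ones.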
|
|
|
|
class MaskEstimator(nn.Module):
    """Per-band MLP heads that predict a complex mask for each frequency band."""

    def __init__(self, dim_inputs, model_dim, depth, mlp_expansion_factor=4):
| super().__init__() |
| dim_hidden = int(model_dim * mlp_expansion_factor) |
| self.to_freqs = nn.ModuleList( |
| [ |
| nn.Sequential( |
| MLP( |
| model_dim, |
| dim_in * 2, |
| dim_hidden=dim_hidden, |
| depth=depth, |
| ), |
| nn.GLU(dim=-1), |
| ) |
| for dim_in in dim_inputs |
| ] |
| ) |
|
|
| def forward(self, x): |
| outputs = [ |
| mlp(band_features) |
| for band_features, mlp in zip(x.unbind(dim=-2), self.to_freqs) |
| ] |
| return torch.cat(outputs, dim=-1) |
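

# Each band head emits dim_in * 2 features and the GLU halves them back to
# dim_in, so the concatenated output matches BandSplit's flattened input
# layout: 2 (real/imag) * band_size * audio_channels values per band.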
|
|
|
|
| class BSRoformer(nn.Module): |
| def __init__( |
| self, |
| *, |
| model_dim, |
| model_depth, |
| audio_channels, |
| num_stems, |
| time_transformer_depth, |
| freq_transformer_depth, |
| dim_head, |
| heads, |
| ff_mult, |
| stft_n_fft, |
| stft_hop_length, |
| stft_win_length, |
| stft_normalized, |
| mask_estimator_depth, |
| freq_range, |
| freqs_per_bands, |
| mask_mlp_expansion_factor=4, |
| ): |
| super().__init__() |
|
|
| self.audio_channels = audio_channels |
| self.num_stems = num_stems |
| self.layers = nn.ModuleList([]) |
|
|
| time_rotary_embed = RotaryEmbedding(dim=dim_head) |
| freq_rotary_embed = RotaryEmbedding(dim=dim_head) |
|
|
| for _ in range(model_depth): |
| self.layers.append( |
| nn.ModuleList( |
| [ |
| Transformer( |
| depth=time_transformer_depth, |
| dim=model_dim, |
| heads=heads, |
| dim_head=dim_head, |
| ff_mult=ff_mult, |
| rotary_embed=time_rotary_embed, |
| ), |
| Transformer( |
| depth=freq_transformer_depth, |
| dim=model_dim, |
| heads=heads, |
| dim_head=dim_head, |
| ff_mult=ff_mult, |
| rotary_embed=freq_rotary_embed, |
| ), |
| ] |
| ) |
| ) |
|
|
        self.final_norm = RMSNorm(model_dim)
        self.stft_kwargs = dict(
            n_fft=stft_n_fft,
            hop_length=stft_hop_length,
            win_length=stft_win_length,
            normalized=stft_normalized,
        )
        # Plain attribute rather than a registered buffer; forward() moves it
        # to the input's device on every call.
        self.stft_window = torch.hann_window(stft_win_length)
|
|
| freqs = stft_n_fft // 2 + 1 |
| min_freq, max_freq = (int(value) for value in freq_range) |
| if not 0 <= min_freq < max_freq <= freqs: |
| raise ValueError( |
| f'freq_range must satisfy 0 <= min < max <= {freqs}, got {(min_freq, max_freq)}' |
| ) |
| self.freq_slice = slice(min_freq, max_freq) |
| self.freq_pad = (min_freq, freqs - max_freq) |
|
|
| freqs_per_bands = tuple(int(band_size) for band_size in freqs_per_bands) |
| band_frequencies = max_freq - min_freq |
| if sum(freqs_per_bands) != band_frequencies: |
| raise ValueError( |
| f'freqs_per_bands must sum to {band_frequencies}, got {sum(freqs_per_bands)}' |
| ) |
|
|
| freqs_per_bands_with_complex = tuple( |
| 2 * band_size * self.audio_channels for band_size in freqs_per_bands |
| ) |
| self.band_split = BandSplit( |
| dim_inputs=freqs_per_bands_with_complex, |
| feature_dim=model_dim, |
| ) |
| self.mask_estimators = nn.ModuleList( |
| [ |
| MaskEstimator( |
| dim_inputs=freqs_per_bands_with_complex, |
| model_dim=model_dim, |
| depth=mask_estimator_depth, |
| mlp_expansion_factor=mask_mlp_expansion_factor, |
| ) |
| for _ in range(num_stems) |
| ] |
| ) |
|
|
| def forward(self, raw_audio): |
| if raw_audio.ndim == 2: |
| raw_audio = rearrange(raw_audio, 'b t -> b 1 t') |
|
|
        batch, channels, raw_audio_length = raw_audio.shape
        if channels != self.audio_channels:
            raise ValueError(
                f'expected {self.audio_channels} audio channel(s), got {channels}'
            )
|
|
| raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t') |
|
|
| stft_window = self.stft_window.to(device=raw_audio.device) |
|
|
        stft_repr = torch.stft(
            raw_audio,
            **self.stft_kwargs,
            window=stft_window,
            return_complex=True,
        )
        stft_repr = torch.view_as_real(stft_repr)
        stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
        # Keep only the configured frequency range, then fold the audio
        # channels into the frequency axis.
        stft_repr = stft_repr[:, :, self.freq_slice]
        stft_repr = rearrange(stft_repr, 'b s f t c -> b (f s) t c')
|
|
| x = rearrange(stft_repr, 'b f t c -> b t (f c)') |
| x = self.band_split(x) |
|
|
        # Axial attention: alternately attend over the time axis and the
        # frequency-band axis.
        for time_transformer, freq_transformer in self.layers:
            x = rearrange(x, 'b t f d -> b f t d')
            x, packed_shape = pack([x], '* t d')
            x = time_transformer(x)
            x, = unpack(x, packed_shape, '* t d')


            x = rearrange(x, 'b f t d -> b t f d')
            x, packed_shape = pack([x], '* f d')
            x = freq_transformer(x)
            x, = unpack(x, packed_shape, '* f d')
|
|
| x = self.final_norm(x) |
|
|
| mask = torch.stack( |
| [mask_estimator(x) for mask_estimator in self.mask_estimators], |
| dim=1, |
| ) |
| mask = rearrange(mask, 'b n t (f c) -> b n f t c', c=2) |
|
|
        # Broadcast the mixture STFT across stems and apply the predicted
        # complex masks.
        stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
        stft_repr = torch.view_as_complex(stft_repr)
        mask = torch.view_as_complex(mask)
        stft_repr = stft_repr * mask


        # Unfold the audio channels out of the frequency axis and zero-pad the
        # cropped frequency range back to the full STFT size before the istft.
        stft_repr = rearrange(
            stft_repr,
            'b n (f s) t -> (b n s) f t',
            s=self.audio_channels,
        )
        stft_repr = F.pad(stft_repr, (0, 0, *self.freq_pad))
|
|
| recon_audio = torch.istft( |
| stft_repr, |
| **self.stft_kwargs, |
| window=stft_window, |
| return_complex=False, |
| length=raw_audio_length, |
| ) |
|
|
| return rearrange( |
| recon_audio, |
| '(b n s) t -> b n s t', |
| b=batch, |
| s=self.audio_channels, |
| n=self.num_stems, |
| ) |
|
|
|
|
| INPUT_EXTENSIONS = {'.flac', '.wav', '.mp3'} |
| OUTPUT_FORMATS = {'wav', 'flac'} |
| DEFAULT_CONFIG_PATH = Path(__file__).with_name('config.json') |
| MODEL_CONFIG_URL = 'https://huggingface.co/tjpurdy/Piano-Separation-Model-small/resolve/main/config.json' |
| MODEL_CHECKPOINT_URL = 'https://huggingface.co/tjpurdy/Piano-Separation-Model-small/resolve/main/model.safetensors' |
| DOWNLOAD_TIMEOUT_SECONDS = 60 |
| MODEL_SAMPLE_RATE = 44100 |
| SEGMENT_SECONDS = 10 |
| DEFAULT_OVERLAP = 0.25 |
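
# Inference chunking: each segment is SEGMENT_SECONDS * MODEL_SAMPLE_RATE =
# 441,000 samples, advancing by segment * (1 - overlap); the default overlap
# of 0.25 gives a stride of 330,750 samples.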
|
|
|
|
| def parse_output_format(value): |
| value = value.lower().lstrip('.') |
| if value not in OUTPUT_FORMATS: |
| raise argparse.ArgumentTypeError('output format must be wav or flac') |
| return value |
|
|
|
|
| def parse_overlap(value): |
| value = float(value) |
| if not 0 <= value < 1: |
| raise argparse.ArgumentTypeError('overlap must be in the range [0, 1)') |
| return value |
|
|
|
|
| def ensure_downloaded(file_path, url, description): |
| file_path = Path(file_path) |
| if file_path.exists(): |
| return file_path |
|
|
| file_path.parent.mkdir(parents=True, exist_ok=True) |
| temp_path = None |
| request = Request(url, headers={'User-Agent': 'inferencedownload/1.0'}) |
|
|
| try: |
| print(f'{description} not found at {file_path}, downloading from {url}') |
| with urlopen(request, timeout=DOWNLOAD_TIMEOUT_SECONDS) as response: |
| with tempfile.NamedTemporaryFile( |
| mode='wb', |
| delete=False, |
| dir=file_path.parent, |
| suffix='.download', |
| ) as temp_file: |
| temp_path = Path(temp_file.name) |
| shutil.copyfileobj(response, temp_file) |
|
|
| temp_path.replace(file_path) |
| print(f'Downloaded {description} to {file_path}') |
| return file_path |
| except (HTTPError, URLError, OSError) as exc: |
| if temp_path is not None and temp_path.exists(): |
| temp_path.unlink() |
| raise RuntimeError(f'Failed to download {description} from {url}: {exc}') from exc |
|
|
|
|
| def load_config(config_path): |
| config_path = ensure_downloaded(config_path, MODEL_CONFIG_URL, 'Model config') |
| with config_path.open('r', encoding='utf-8') as config_file: |
| return json.load(config_file) |
|
|
|
|
| def convert_audio(wav, from_sr, to_sr, channels): |
| if wav.ndim == 1: |
| wav = wav.unsqueeze(0) |
| if channels == 1: |
| wav = wav.mean(dim=0, keepdim=True) |
| elif wav.shape[0] == 1: |
| wav = wav.expand(channels, -1) |
| elif wav.shape[0] > channels: |
| wav = wav[:channels] |
| elif wav.shape[0] < channels: |
| raise ValueError('Audio has fewer channels than requested and is not mono.') |
| if from_sr == to_sr: |
| return wav |
|
|
| target_length = max(1, int(round(wav.shape[-1] * to_sr / from_sr))) |
| return F.interpolate( |
| wav.unsqueeze(0), |
| size=target_length, |
| mode='linear', |
| align_corners=False, |
| ).squeeze(0) |
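

# Resampling above is plain linear interpolation: cheap, but it can alias on
# large rate changes. If torchaudio is installed, a higher-quality drop-in
# would be `torchaudio.functional.resample(wav, from_sr, to_sr)` (windowed
# sinc interpolation).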
|
|
|
|
| def load_separator(checkpoint_path, model_config, device): |
| model = BSRoformer(**model_config).eval().to(device) |
|
|
    checkpoint_path = Path(checkpoint_path)
    checkpoint_was_missing = not checkpoint_path.exists()
    checkpoint_path = ensure_downloaded(
        checkpoint_path,
        MODEL_CHECKPOINT_URL,
        'Model checkpoint',
    )
    # A freshly downloaded checkpoint is always safetensors; otherwise go by
    # the file extension.
    if checkpoint_was_missing or checkpoint_path.suffix == '.safetensors':
        state = load_file(checkpoint_path)
    else:
        state = torch.load(checkpoint_path, map_location='cpu')
    # Some checkpoints nest the weights under a 'state' key and/or prefix keys
    # with 'module.' (DataParallel); normalize both before loading.
    state = state.get('state', state)
    state = {k[7:] if k.startswith('module.') else k: v for k, v in state.items()}
    model.load_state_dict(state)
| return model |
|
|
|
|
| def list_audio_files(input_path): |
| input_path = Path(input_path) |
| if input_path.is_file(): |
| if input_path.suffix.lower() not in INPUT_EXTENSIONS: |
| raise ValueError(f'Input file is not a supported audio file: {input_path}') |
| return [input_path] |
|
|
| if not input_path.is_dir(): |
| raise FileNotFoundError( |
| f'Input path does not exist or is not a supported file/directory: {input_path}' |
| ) |
|
|
| files = sorted( |
| path |
| for path in input_path.rglob('*') |
| if path.is_file() and path.suffix.lower() in INPUT_EXTENSIONS |
| ) |
    # Outputs are written flat under one directory, so inputs that share a
    # stem would overwrite each other; fail early and name the offenders.
    duplicates = {}
    for path in files:
        duplicates.setdefault(path.stem, []).append(path)
    duplicates = {stem: paths for stem, paths in duplicates.items() if len(paths) > 1}
    if duplicates:
        details = '\n'.join(
            f'{stem}: {", ".join(str(path) for path in paths)}'
            for stem, paths in sorted(duplicates.items())
        )
        raise ValueError(
            'Multiple input files share the same name, so flat output filenames would collide:\n'
            + details
        )
| return files |
|
|
|
|
def run_model(model, mix, overlap):
    """Separate `mix` chunk by chunk and overlap-add the weighted results."""
    length = mix.shape[-1]
    segment = MODEL_SAMPLE_RATE * SEGMENT_SECONDS
    stride = max(1, int(segment * (1 - overlap)))
    # Triangular cross-fade weights, e.g. [1, 2, 2, 1] for segment = 4, so
    # overlapping chunks blend smoothly instead of producing seams.
    weight = torch.cat((
        torch.arange(1, segment // 2 + 1, device=mix.device),
        torch.arange(segment - segment // 2, 0, -1, device=mix.device),
    )).float()
    estimate = None
    sum_weight = torch.zeros(length, device=mix.device)
|
|
| with torch.inference_mode(): |
| for start in range(0, length, stride): |
| chunk = mix[:, start:start + segment] |
| chunk_est = model(chunk[None])[0] |
| if estimate is None: |
| estimate = torch.zeros(*chunk_est.shape[:-1], length, device=mix.device) |
| chunk_weight = weight[:chunk.shape[-1]] |
| estimate[..., start:start + chunk.shape[-1]] += chunk_est * chunk_weight |
| sum_weight[start:start + chunk.shape[-1]] += chunk_weight |
|
|
| return estimate / sum_weight |
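

# Because chunk starts advance by stride <= segment (overlap < 1) from 0,
# every sample falls in at least one chunk, so sum_weight is strictly
# positive and the division above is safe.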
|
|
|
|
def separate_file(model, file_path, device, overlap):
    audio, sample_rate = sf.read(file_path, dtype='float32')
    # soundfile returns (frames, channels); the model wants (channels, frames).
    mix = torch.from_numpy(np.asarray(audio, np.float32))
    mix = mix.unsqueeze(0) if mix.ndim == 1 else mix.T
    source_channels = mix.shape[0]
    mix = convert_audio(mix.to(device), sample_rate, MODEL_SAMPLE_RATE, model.audio_channels)


    # Normalize by the mono mix's statistics; undone after separation.
    mono = mix.mean(0)
    mean = mono.mean()
    std = mono.std().clamp_min(1e-8)
    mix = (mix - mean) / std
|
|
| estimate = run_model(model, mix, overlap)[0] * std + mean |
| estimate = convert_audio(estimate, MODEL_SAMPLE_RATE, sample_rate, source_channels) |
| return estimate.T.cpu().numpy(), sample_rate |
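

# Note: .mp3 decoding in sf.read above depends on soundfile being backed by
# libsndfile >= 1.1.0; older builds cannot open mp3 files.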
|
|
|
|
| def parse_args(): |
| parser = argparse.ArgumentParser(description='Music source separation inference') |
| parser.add_argument('--input_dir', type=str, required=True, help='Input audio file or directory containing audio files') |
| parser.add_argument( |
| '--output_dir', |
| type=str, |
| default=None, |
| help='Output directory to save separated audio (default: same location as input)', |
| ) |
| parser.add_argument('--config_path', type=str, default=str(DEFAULT_CONFIG_PATH), help='Path to model config JSON') |
| parser.add_argument('--checkpoint_path', type=str, default='./model.safetensors', help='Path to model checkpoint file') |
| parser.add_argument('--output_format', type=parse_output_format, default='wav', help='Output file format: wav or flac (default: wav)') |
| parser.add_argument('--overlap', type=parse_overlap, default=DEFAULT_OVERLAP, help='Chunk overlap ratio in [0, 1) (default: 0.25)') |
| return parser.parse_args() |
|
|
|
|
| def main(): |
| args = parse_args() |
| input_path = Path(args.input_dir) |
| model_config = load_config(args.config_path) |
| audio_files = list_audio_files(args.input_dir) |
| if not audio_files: |
| print(f'No supported audio files found in {args.input_dir}') |
| return |
|
|
| if args.output_dir is not None: |
| output_dir = Path(args.output_dir) |
| else: |
| output_dir = input_path.parent if input_path.is_file() else input_path |
| output_dir.mkdir(parents=True, exist_ok=True) |
|
|
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
    if device.type == 'cpu':
        print('WARNING: CUDA not available, running on CPU (this will be slow)')
|
|
| model = load_separator(args.checkpoint_path, model_config, device) |
    print(f'Found {len(audio_files)} audio file(s) in {args.input_dir}')
|
|
| for file_path in audio_files: |
| print(f'Processing {file_path}') |
| estimate, sample_rate = separate_file(model, file_path, device, args.overlap) |
| save_path = output_dir / f'{file_path.stem}_Piano.{args.output_format}' |
| sf.write(save_path, estimate, sample_rate) |
| print(f'Saved {save_path}') |
|
|
|
|
| if __name__ == '__main__': |
| main() |
|
|