tjpurdy committed on
Commit
d9b10d7
·
verified ·
1 Parent(s): ae325b5

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +574 -0
inference.py ADDED
@@ -0,0 +1,574 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import shutil
4
+ import tempfile
5
+ from pathlib import Path
6
+ from urllib.error import HTTPError, URLError
7
+ from urllib.request import Request, urlopen
8
+
9
+ import numpy as np
10
+ import soundfile as sf
11
+ import torch
12
+ import torch.nn.functional as F
13
+ from einops import pack, rearrange, unpack
14
+ from rotary_embedding_torch import RotaryEmbedding
15
+ from safetensors.torch import load_file
16
+ from torch import einsum, nn
17
+
18
+
19
def pack_one(tensor, pattern):
    """Pack a single tensor with einops ``pack``.

    The tensor is wrapped in a one-element list and handed to ``pack``; the
    value ``pack`` produces (the packed tensor together with the packed-shape
    record needed to undo the operation) is returned unchanged.
    """
    singleton = [tensor]
    packed = pack(singleton, pattern)
    return packed
21
+
22
+
23
def unpack_one(tensor, packed_shape, pattern):
    """Undo a single-tensor pack and return the first restored tensor.

    ``packed_shape`` is the shape record obtained when the tensor was packed
    and ``pattern`` is the einops pattern used for packing.
    """
    pieces = unpack(tensor, packed_shape, pattern)
    return pieces[0]
25
+
26
+
27
class Attend(nn.Module):
    """Scaled dot-product attention without masking or dropout."""

    def __init__(self):
        super().__init__()

    def forward(self, q, k, v):
        """Attend from queries ``q`` to keys ``k`` and mix the values ``v``.

        All three inputs are indexed as (batch, heads, positions, head_dim);
        the result has one row per query position.
        """
        head_dim = q.shape[-1]
        scale = head_dim ** -0.5
        # Pairwise query/key similarity, scaled by 1 / sqrt(head_dim).
        scores = einsum('b h i d, b h j d -> b h i j', q, k) * scale
        # Normalise over the key axis so each query's weights sum to one.
        weights = scores.softmax(dim=-1)
        return einsum('b h i j, b h j d -> b h i d', weights, v)
36
+
37
+
38
class RMSNorm(nn.Module):
    """L2-normalise the last axis, then rescale by ``sqrt(dim)`` and a learned gain."""

    def __init__(self, dim):
        super().__init__()
        # Learned per-feature gain, initialised to ones.
        self.gamma = nn.Parameter(torch.ones(dim))
        # Fixed factor applied after the unit-length normalisation.
        self.scale = dim ** 0.5

    def forward(self, x):
        """Return ``x`` normalised along its last dimension and rescaled."""
        unit = F.normalize(x, dim=-1)
        return unit * self.scale * self.gamma
46
+
47
+
48
class FeedForward(nn.Module):
    """Pre-normalised two-layer GELU feed-forward block.

    ``ff_mult`` sets the hidden width as a multiple of ``dim``.
    """

    def __init__(self, dim, ff_mult):
        super().__init__()
        hidden = int(dim * ff_mult)
        # The Identity entries are no-ops, but they fix the positions of the
        # two Linear layers inside the Sequential (indices 1 and 4), which is
        # what their parameter names are derived from.
        stages = [
            RMSNorm(dim),
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Identity(),
            nn.Linear(hidden, dim),
            nn.Identity(),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the block to ``x``; the last dimension stays ``dim``."""
        return self.net(x)
63
+
64
+
65
class Attention(nn.Module):
    """Pre-normalised multi-head self-attention with rotary positions and per-head gates."""

    def __init__(self, dim, heads, dim_head, rotary_embed):
        super().__init__()
        inner_width = heads * dim_head

        self.heads = heads
        self.rotary_embed = rotary_embed
        self.attend = Attend()
        self.norm = RMSNorm(dim)
        # One fused projection yields queries, keys and values.
        self.to_qkv = nn.Linear(dim, inner_width * 3, bias=False)
        # One scalar gate per head, computed from the normalised input.
        self.to_gates = nn.Linear(dim, heads)
        self.to_out = nn.Sequential(
            nn.Linear(inner_width, dim, bias=False),
            nn.Identity(),
        )

    def forward(self, x):
        """Run gated self-attention over ``x`` laid out as (batch, positions, dim)."""
        normed = self.norm(x)

        qkv = rearrange(
            self.to_qkv(normed),
            'b n (qkv h d) -> qkv b h n d',
            qkv=3,
            h=self.heads,
        )
        q, k, v = qkv

        # Rotary position information is applied to queries and keys only.
        q = self.rotary_embed.rotate_queries_or_keys(q)
        k = self.rotary_embed.rotate_queries_or_keys(k)

        attended = self.attend(q, k, v)

        # Sigmoid gates scale each head's output at every position.
        gate_logits = self.to_gates(normed)
        gates = rearrange(gate_logits, 'b n h -> b h n 1').sigmoid()
        attended = attended * gates

        merged = rearrange(attended, 'b h n d -> b n (h d)')
        return self.to_out(merged)
97
+
98
+
99
class Transformer(nn.Module):
    """Stack of ``depth`` (attention, feed-forward) pairs with residual connections."""

    def __init__(self, depth, dim, heads, dim_head, ff_mult, rotary_embed):
        super().__init__()
        # Every layer receives the same rotary embedding object.
        blocks = [
            nn.ModuleList(
                [
                    Attention(
                        dim=dim,
                        heads=heads,
                        dim_head=dim_head,
                        rotary_embed=rotary_embed,
                    ),
                    FeedForward(dim=dim, ff_mult=ff_mult),
                ]
            )
            for _ in range(depth)
        ]
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Pass ``x`` through every layer, adding each sub-block's output to its input."""
        for attention, feed_forward in self.layers:
            x = attention(x) + x
            x = feed_forward(x) + x
        return x
124
+
125
+
126
class BandSplit(nn.Module):
    """Split the last axis into bands and project each band to ``feature_dim``.

    ``dim_inputs`` lists the width of every band; each band gets its own
    linear projection.
    """

    def __init__(self, dim_inputs, feature_dim):
        super().__init__()
        self.dim_inputs = dim_inputs
        projections = []
        for band_width in dim_inputs:
            projections.append(nn.Sequential(nn.Linear(band_width, feature_dim)))
        self.to_features = nn.ModuleList(projections)

    def forward(self, x):
        """Return per-band features stacked on a new second-to-last axis."""
        bands = x.split(self.dim_inputs, dim=-1)
        projected = [
            projection(band)
            for projection, band in zip(self.to_features, bands)
        ]
        return torch.stack(projected, dim=-2)
141
+
142
+
143
def MLP(dim_in, dim_out, dim_hidden, depth=1):
    """Build a ``depth``-layer perceptron with Tanh between consecutive Linear layers.

    With ``depth == 1`` the result is a single ``Linear(dim_in, dim_out)``;
    each extra layer inserts a hidden stage of width ``dim_hidden``. No
    activation follows the final Linear.
    """
    widths = [dim_in] + [dim_hidden] * (depth - 1) + [dim_out]

    modules = []
    for fan_in, fan_out in zip(widths, widths[1:]):
        if modules:
            # Separate this Linear from the previous one with a Tanh.
            modules.append(nn.Tanh())
        modules.append(nn.Linear(fan_in, fan_out))

    return nn.Sequential(*modules)
154
+
155
+
156
class MaskEstimator(nn.Module):
    """Per-band MLP heads that map band features back to band-width outputs.

    Each head ends in a GLU, so its MLP emits twice the band width and the
    GLU halves it again.
    """

    def __init__(self, dim_inputs, model_dim, depth, mlp_expansion_factor=4):
        super().__init__()
        hidden = int(model_dim * mlp_expansion_factor)

        band_heads = []
        for band_width in dim_inputs:
            band_mlp = MLP(
                model_dim,
                band_width * 2,
                dim_hidden=hidden,
                depth=depth,
            )
            band_heads.append(nn.Sequential(band_mlp, nn.GLU(dim=-1)))
        self.to_freqs = nn.ModuleList(band_heads)

    def forward(self, x):
        """Apply each head to its band (second-to-last axis) and concatenate the results."""
        per_band = x.unbind(dim=-2)
        outputs = [
            head(features)
            for features, head in zip(per_band, self.to_freqs)
        ]
        return torch.cat(outputs, dim=-1)
181
+
182
+
183
class BSRoformer(nn.Module):
    """Band-split transformer that separates stems by masking the mixture's STFT.

    The forward pass takes a waveform, computes its STFT, keeps the
    configured frequency range, splits it into bands, alternates transformers
    along the time and band axes, predicts one complex mask per stem, applies
    the masks to the STFT and inverts the result back to waveforms.

    All constructor arguments are keyword-only.
    """

    def __init__(
        self,
        *,
        model_dim,
        model_depth,
        audio_channels,
        num_stems,
        time_transformer_depth,
        freq_transformer_depth,
        dim_head,
        heads,
        ff_mult,
        stft_n_fft,
        stft_hop_length,
        stft_win_length,
        stft_normalized,
        mask_estimator_depth,
        freq_range,
        freqs_per_bands,
        mask_mlp_expansion_factor=4,
    ):
        """Build the layers and validate the frequency/band configuration.

        Raises:
            ValueError: if ``freq_range`` is not ``0 <= min < max <= n_fft // 2 + 1``
                or if ``freqs_per_bands`` does not sum to ``max - min``.
        """
        super().__init__()

        self.audio_channels = audio_channels
        self.num_stems = num_stems
        self.layers = nn.ModuleList([])

        # One rotary embedding is shared by every time transformer and a
        # second one by every frequency transformer.
        time_rotary_embed = RotaryEmbedding(dim=dim_head)
        freq_rotary_embed = RotaryEmbedding(dim=dim_head)

        # Each of the model_depth layers is a (time transformer, frequency transformer) pair.
        for _ in range(model_depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        Transformer(
                            depth=time_transformer_depth,
                            dim=model_dim,
                            heads=heads,
                            dim_head=dim_head,
                            ff_mult=ff_mult,
                            rotary_embed=time_rotary_embed,
                        ),
                        Transformer(
                            depth=freq_transformer_depth,
                            dim=model_dim,
                            heads=heads,
                            dim_head=dim_head,
                            ff_mult=ff_mult,
                            rotary_embed=freq_rotary_embed,
                        ),
                    ]
                )
            )

        self.final_norm = RMSNorm(model_dim)
        # Keyword arguments shared by the torch.stft and torch.istft calls in forward().
        self.stft_kwargs = dict(
            n_fft=stft_n_fft,
            hop_length=stft_hop_length,
            win_length=stft_win_length,
            normalized=stft_normalized,
        )
        # Stored as a plain attribute (not a registered buffer); forward()
        # moves it to the input's device on every call.
        self.stft_window = torch.hann_window(stft_win_length)

        # Number of frequency bins the validation below allows as the upper bound.
        freqs = stft_n_fft // 2 + 1
        min_freq, max_freq = (int(value) for value in freq_range)
        if not 0 <= min_freq < max_freq <= freqs:
            raise ValueError(
                f'freq_range must satisfy 0 <= min < max <= {freqs}, got {(min_freq, max_freq)}'
            )
        # Bins kept for processing, and the zero padding (below, above) that
        # restores the full bin count before the inverse STFT.
        self.freq_slice = slice(min_freq, max_freq)
        self.freq_pad = (min_freq, freqs - max_freq)

        freqs_per_bands = tuple(int(band_size) for band_size in freqs_per_bands)
        band_frequencies = max_freq - min_freq
        if sum(freqs_per_bands) != band_frequencies:
            raise ValueError(
                f'freqs_per_bands must sum to {band_frequencies}, got {sum(freqs_per_bands)}'
            )

        # Width of each band once the real/imaginary pair (factor 2) and the
        # audio channels are flattened into the feature axis.
        freqs_per_bands_with_complex = tuple(
            2 * band_size * self.audio_channels for band_size in freqs_per_bands
        )
        self.band_split = BandSplit(
            dim_inputs=freqs_per_bands_with_complex,
            feature_dim=model_dim,
        )
        # One independent mask estimator per output stem.
        self.mask_estimators = nn.ModuleList(
            [
                MaskEstimator(
                    dim_inputs=freqs_per_bands_with_complex,
                    model_dim=model_dim,
                    depth=mask_estimator_depth,
                    mlp_expansion_factor=mask_mlp_expansion_factor,
                )
                for _ in range(num_stems)
            ]
        )

    def forward(self, raw_audio):
        """Separate ``raw_audio`` into stems.

        Args:
            raw_audio: waveform laid out as (batch, time) or
                (batch, channels, time); a 2-D input is given a channel axis
                of size one.

        Returns:
            Tensor laid out as (batch, num_stems, audio_channels, time), with
            the time axis equal to the input length.

        Raises:
            ValueError: if the channel count differs from ``audio_channels``.
        """
        if raw_audio.ndim == 2:
            raw_audio = rearrange(raw_audio, 'b t -> b 1 t')

        batch, channels, raw_audio_length = raw_audio.shape
        if channels != self.audio_channels:
            raise ValueError('audio channel count does not match the checkpoint architecture')

        # Fold batch and channel axes together so the STFT sees a 2-D input.
        raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t')

        stft_window = self.stft_window.to(device=raw_audio.device)

        stft_repr = torch.stft(
            raw_audio,
            **self.stft_kwargs,
            window=stft_window,
            return_complex=True,
        )
        # Expose real/imaginary parts as a trailing axis of size two.
        stft_repr = torch.view_as_real(stft_repr)
        stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c')
        # Keep only the configured frequency range.
        stft_repr = stft_repr[:, :, self.freq_slice]
        # Merge channels into the frequency axis, channel index varying fastest.
        stft_repr = rearrange(stft_repr, 'b s f t c -> b (f s) t c')

        # Flatten (frequency, real/imag) into one feature axis per time frame.
        x = rearrange(stft_repr, 'b f t c -> b t (f c)')
        x = self.band_split(x)

        for time_transformer, freq_transformer in self.layers:
            # Attend along time: every (batch, band) pair becomes one sequence.
            x = rearrange(x, 'b t f d -> b f t d')
            x, packed_shape = pack([x], '* t d')
            x = time_transformer(x)
            x, = unpack(x, packed_shape, '* t d')

            # Attend along bands: every (batch, frame) pair becomes one sequence.
            x = rearrange(x, 'b f t d -> b t f d')
            x, packed_shape = pack([x], '* f d')
            x = freq_transformer(x)
            x, = unpack(x, packed_shape, '* f d')

        x = self.final_norm(x)

        # One mask per stem, stacked on a new stem axis.
        mask = torch.stack(
            [mask_estimator(x) for mask_estimator in self.mask_estimators],
            dim=1,
        )
        mask = rearrange(mask, 'b n t (f c) -> b n f t c', c=2)

        # Complex multiplication of the mixture STFT by each stem's mask;
        # the size-one stem axis on the STFT broadcasts over the stems.
        stft_repr = rearrange(stft_repr, 'b f t c -> b 1 f t c')
        stft_repr = torch.view_as_complex(stft_repr)
        mask = torch.view_as_complex(mask)
        stft_repr = stft_repr * mask

        # Separate channels from frequency again and fold batch, stem and
        # channel into one leading axis for the inverse STFT.
        stft_repr = rearrange(
            stft_repr,
            'b n (f s) t -> (b n s) f t',
            s=self.audio_channels,
        )
        # Zero-fill the bins outside freq_slice (pad applies to the frequency axis only).
        stft_repr = F.pad(stft_repr, (0, 0, *self.freq_pad))

        recon_audio = torch.istft(
            stft_repr,
            **self.stft_kwargs,
            window=stft_window,
            return_complex=False,
            length=raw_audio_length,
        )

        return rearrange(
            recon_audio,
            '(b n s) t -> b n s t',
            b=batch,
            s=self.audio_channels,
            n=self.num_stems,
        )
354
+
355
+
356
# File suffixes (lower-case, with dot) accepted as input audio.
INPUT_EXTENSIONS = {'.flac', '.wav', '.mp3'}
# Values accepted by the --output_format option.
OUTPUT_FORMATS = {'wav', 'flac'}
# Default config location: config.json next to this script.
DEFAULT_CONFIG_PATH = Path(__file__).with_name('config.json')
# Download sources used when the config or checkpoint is missing locally.
MODEL_CONFIG_URL = 'https://huggingface.co/tjpurdy/Piano-Separation-Model-small/resolve/main/config.json'
MODEL_CHECKPOINT_URL = 'https://huggingface.co/tjpurdy/Piano-Separation-Model-small/resolve/main/model.safetensors'
# Timeout, in seconds, passed to urlopen for those downloads.
DOWNLOAD_TIMEOUT_SECONDS = 60
# Sample rate audio is converted to before it is given to the model.
MODEL_SAMPLE_RATE = 44100
# Length in seconds of each chunk processed by the model.
SEGMENT_SECONDS = 10
# Default fraction by which consecutive chunks overlap.
DEFAULT_OVERLAP = 0.25
365
+
366
+
367
def parse_output_format(value):
    """argparse type: normalise an output format name.

    The value is lower-cased and leading dots are removed, so ``.WAV`` and
    ``wav`` are equivalent.

    Raises:
        argparse.ArgumentTypeError: if the result is not in ``OUTPUT_FORMATS``.
    """
    normalized = value.lower().lstrip('.')
    if normalized in OUTPUT_FORMATS:
        return normalized
    raise argparse.ArgumentTypeError('output format must be wav or flac')
372
+
373
+
374
def parse_overlap(value):
    """argparse type: parse a chunk overlap ratio in the half-open range [0, 1).

    Raises:
        ValueError: if ``value`` cannot be converted to ``float``.
        argparse.ArgumentTypeError: if the number lies outside [0, 1)
            (this includes NaN, which fails the range comparison).
    """
    ratio = float(value)
    if 0 <= ratio < 1:
        return ratio
    raise argparse.ArgumentTypeError('overlap must be in the range [0, 1)')
379
+
380
+
381
def ensure_downloaded(file_path, url, description):
    """Return ``file_path``, downloading it from ``url`` first if it is missing.

    The payload is streamed into a temporary ``*.download`` file in the
    destination directory and renamed into place only after the transfer
    finished, so a failed download never leaves a file at ``file_path``.

    Args:
        file_path: destination path (``str`` or ``Path``).
        url: location to fetch when the destination does not exist.
        description: human-readable name used in progress and error messages.

    Returns:
        ``file_path`` as a ``Path``.

    Raises:
        RuntimeError: if the download fails, is cut short, or cannot be
            written; the original exception is chained as the cause.
    """
    # Local import: protocol-level failures such as IncompleteRead derive
    # from HTTPException, which is not an OSError subclass.
    from http.client import HTTPException

    file_path = Path(file_path)
    if file_path.exists():
        return file_path

    file_path.parent.mkdir(parents=True, exist_ok=True)
    temp_path = None
    request = Request(url, headers={'User-Agent': 'inferencedownload/1.0'})

    try:
        print(f'{description} not found at {file_path}, downloading from {url}')
        with urlopen(request, timeout=DOWNLOAD_TIMEOUT_SECONDS) as response:
            with tempfile.NamedTemporaryFile(
                mode='wb',
                delete=False,
                dir=file_path.parent,
                suffix='.download',
            ) as temp_file:
                temp_path = Path(temp_file.name)
                shutil.copyfileobj(response, temp_file)
                bytes_written = temp_file.tell()

            # A connection dropped mid-transfer can look like a normal end of
            # stream; compare against the advertised size when there is one.
            expected = response.headers.get('Content-Length')
            if expected is not None and expected.strip().isdigit():
                if int(expected) != bytes_written:
                    raise OSError(
                        f'incomplete download: expected {int(expected)} bytes, got {bytes_written}'
                    )

        temp_path.replace(file_path)
        temp_path = None  # ownership moved to file_path; nothing left to clean up
        print(f'Downloaded {description} to {file_path}')
        return file_path
    except (HTTPError, URLError, OSError, HTTPException) as exc:
        raise RuntimeError(f'Failed to download {description} from {url}: {exc}') from exc
    finally:
        # Runs for every exit path (including KeyboardInterrupt), so partial
        # downloads never accumulate next to the destination.
        if temp_path is not None and temp_path.exists():
            temp_path.unlink()
409
+
410
+
411
def load_config(config_path):
    """Load the model configuration JSON, fetching it from ``MODEL_CONFIG_URL`` if absent.

    Returns the parsed JSON document.
    """
    resolved_path = ensure_downloaded(config_path, MODEL_CONFIG_URL, 'Model config')
    raw_text = resolved_path.read_text(encoding='utf-8')
    return json.loads(raw_text)
415
+
416
+
417
def convert_audio(wav, from_sr, to_sr, channels):
    """Adapt a waveform to a target channel count and sample rate.

    Args:
        wav: tensor laid out as (channels, time); a 1-D tensor is treated as
            a single channel.
        from_sr: sample rate of ``wav``.
        to_sr: desired sample rate.
        channels: desired channel count.

    Channel handling: a request for one channel averages all channels; a
    single-channel input is repeated; surplus channels are dropped.
    Resampling, when needed, uses linear interpolation.

    Raises:
        ValueError: if the input has more than one but fewer than
            ``channels`` channels.
    """
    if wav.ndim == 1:
        wav = wav.unsqueeze(0)

    available = wav.shape[0]
    if channels == 1:
        wav = wav.mean(dim=0, keepdim=True)
    elif available == 1:
        wav = wav.expand(channels, -1)
    elif available > channels:
        wav = wav[:channels]
    elif available < channels:
        raise ValueError('Audio has fewer channels than requested and is not mono.')

    if from_sr != to_sr:
        new_length = max(1, int(round(wav.shape[-1] * to_sr / from_sr)))
        # F.interpolate expects a leading batch axis; add it and strip it again.
        batched = wav.unsqueeze(0)
        resampled = F.interpolate(
            batched,
            size=new_length,
            mode='linear',
            align_corners=False,
        )
        wav = resampled.squeeze(0)

    return wav
438
+
439
+
440
def load_separator(checkpoint_path, model_config, device):
    """Build the separator in eval mode on ``device`` and load its weights.

    A missing checkpoint is downloaded from ``MODEL_CHECKPOINT_URL`` and read
    as safetensors regardless of the requested file name. An existing file is
    read as safetensors when its suffix is ``.safetensors`` and with
    ``torch.load`` otherwise. Weights nested under a ``'state'`` key are
    unwrapped, and a leading ``module.`` is removed from parameter names.
    """
    model = BSRoformer(**model_config).eval().to(device)

    requested_path = Path(checkpoint_path)
    downloaded_now = not requested_path.exists()
    local_path = ensure_downloaded(
        requested_path,
        MODEL_CHECKPOINT_URL,
        'Model checkpoint',
    )

    if downloaded_now or local_path.suffix == '.safetensors':
        state = load_file(local_path)
    else:
        # NOTE(review): torch.load may unpickle arbitrary objects depending on
        # the installed torch version; only load checkpoints you trust.
        state = torch.load(local_path, map_location='cpu')
    state = state.get('state', state)

    cleaned_state = {}
    for key, tensor in state.items():
        if key.startswith('module.'):
            key = key[7:]
        cleaned_state[key] = tensor

    model.load_state_dict(cleaned_state)
    return model
455
+
456
+
457
def list_audio_files(input_path):
    """Resolve ``input_path`` to a sorted list of supported audio files.

    A file path yields a one-element list; a directory is searched
    recursively for files whose suffix is in ``INPUT_EXTENSIONS``.

    Raises:
        ValueError: if a single file has an unsupported suffix, or if two
            files found in a directory share the same stem (their flat output
            names would collide).
        FileNotFoundError: if the path is neither a file nor a directory.
    """
    input_path = Path(input_path)

    if input_path.is_file():
        if input_path.suffix.lower() in INPUT_EXTENSIONS:
            return [input_path]
        raise ValueError(f'Input file is not a supported audio file: {input_path}')

    if not input_path.is_dir():
        raise FileNotFoundError(
            f'Input path does not exist or is not a supported file/directory: {input_path}'
        )

    files = sorted(
        candidate
        for candidate in input_path.rglob('*')
        if candidate.is_file() and candidate.suffix.lower() in INPUT_EXTENSIONS
    )

    # Group by stem: outputs are named after the stem alone, so two inputs
    # with the same stem would overwrite each other.
    by_stem = {}
    for path in files:
        by_stem.setdefault(path.stem, []).append(path)
    clashes = sorted(
        (stem, paths) for stem, paths in by_stem.items() if len(paths) > 1
    )

    if clashes:
        lines = []
        for stem, paths in clashes:
            lines.append(f'{stem}: {", ".join(str(path) for path in paths)}')
        details = '\n'.join(lines)
        raise ValueError(
            'Multiple input files share the same name, so flat output filenames would collide:\n' + details
        )
    return files
484
+
485
+
486
def run_model(model, mix, overlap, segment=None):
    """Run ``model`` over ``mix`` in overlapping chunks and cross-fade the results.

    Each chunk's estimate is weighted by a triangular window before being
    accumulated; the sum is divided by the accumulated weights at the end.

    Args:
        model: callable taking a (1, channels, segment) tensor and returning
            a tensor whose first axis is the batch and whose last axis is time.
        mix: waveform laid out as (channels, time).
        overlap: fraction of a chunk shared with the next one, in [0, 1).
        segment: chunk length in samples; defaults to
            ``MODEL_SAMPLE_RATE * SEGMENT_SECONDS``.

    Returns:
        The model output with the batch axis removed and the time axis
        equal to ``mix.shape[-1]``.

    Raises:
        ValueError: if ``mix`` contains no samples.
    """
    if segment is None:
        segment = MODEL_SAMPLE_RATE * SEGMENT_SECONDS

    length = mix.shape[-1]
    if length == 0:
        raise ValueError('Cannot separate empty audio (0 samples).')

    stride = max(1, int(segment * (1 - overlap)))
    # Triangular window: ramps up over the first half, down over the second.
    weight = torch.cat((
        torch.arange(1, segment // 2 + 1, device=mix.device),
        torch.arange(segment - segment // 2, 0, -1, device=mix.device),
    )).float()
    estimate = None
    sum_weight = torch.zeros(length, device=mix.device)

    with torch.inference_mode():
        for start in range(0, length, stride):
            chunk = mix[:, start:start + segment]
            valid = chunk.shape[-1]
            if valid < segment:
                # Zero-pad short (usually trailing) chunks to the full segment.
                # A very short chunk would otherwise break the model's STFT,
                # whose reflect padding needs more samples than it pads.
                chunk = F.pad(chunk, (0, segment - valid))
            # Drop the part of the output that corresponds to the padding.
            chunk_est = model(chunk[None])[0][..., :valid]
            if estimate is None:
                estimate = torch.zeros(*chunk_est.shape[:-1], length, device=mix.device)
            chunk_weight = weight[:valid]
            estimate[..., start:start + valid] += chunk_est * chunk_weight
            sum_weight[start:start + valid] += chunk_weight

    return estimate / sum_weight
508
+
509
+
510
def separate_file(model, file_path, device, overlap):
    """Separate one audio file and return the first stem at the file's sample rate.

    The audio is converted to the model's sample rate and channel count,
    standardised with the mean and standard deviation of its mono mix, run
    through the model in chunks, de-standardised and converted back.

    Args:
        model: separator exposing ``audio_channels``.
        file_path: audio file to read.
        device: torch device on which inference runs.
        overlap: chunk overlap ratio forwarded to ``run_model``.

    Returns:
        ``(samples, sample_rate)`` where ``samples`` is a NumPy array laid
        out as (time, channels). The channel count matches the source file,
        except when the source has more channels than a multi-channel model:
        the surplus channels are dropped on input, so the output has the
        model's channel count.

    Raises:
        ValueError: if the file contains no samples, or if it has more than
            one channel but fewer than the model expects.
    """
    audio, sample_rate = sf.read(file_path, dtype='float32')
    mix = torch.from_numpy(np.asarray(audio, np.float32))
    # soundfile returns (time,) or (time, channels); the model wants (channels, time).
    mix = mix.unsqueeze(0) if mix.ndim == 1 else mix.T
    if mix.shape[-1] == 0:
        # Fail early: the statistics below would be NaN and nothing could be inferred.
        raise ValueError(f'Audio file contains no samples: {file_path}')
    source_channels = mix.shape[0]
    mix = convert_audio(mix.to(device), sample_rate, MODEL_SAMPLE_RATE, model.audio_channels)

    # Standardise with the statistics of the mono mix; undone after inference.
    mono = mix.mean(0)
    mean = mono.mean()
    std = mono.std().clamp_min(1e-8)
    mix = (mix - mean) / std

    estimate = run_model(model, mix, overlap)[0] * std + mean

    # convert_audio cannot invent channels from a multi-channel signal, so
    # asking for more channels than the estimate has would raise only after
    # all the inference work was done. Keep the estimate's count in that case.
    estimate_channels = estimate.shape[0]
    if estimate_channels == 1 or source_channels <= estimate_channels:
        output_channels = source_channels
    else:
        output_channels = estimate_channels

    estimate = convert_audio(estimate, MODEL_SAMPLE_RATE, sample_rate, output_channels)
    return estimate.T.cpu().numpy(), sample_rate
525
+
526
+
527
def parse_args():
    """Define the command-line interface and parse the process arguments."""
    parser = argparse.ArgumentParser(description='Music source separation inference')

    parser.add_argument(
        '--input_dir',
        type=str,
        required=True,
        help='Input audio file or directory containing audio files',
    )
    parser.add_argument(
        '--output_dir',
        type=str,
        default=None,
        help='Output directory to save separated audio (default: same location as input)',
    )
    parser.add_argument(
        '--config_path',
        type=str,
        default=str(DEFAULT_CONFIG_PATH),
        help='Path to model config JSON',
    )
    parser.add_argument(
        '--checkpoint_path',
        type=str,
        default='./model.safetensors',
        help='Path to model checkpoint file',
    )
    parser.add_argument(
        '--output_format',
        type=parse_output_format,
        default='wav',
        help='Output file format: wav or flac (default: wav)',
    )
    parser.add_argument(
        '--overlap',
        type=parse_overlap,
        default=DEFAULT_OVERLAP,
        help='Chunk overlap ratio in [0, 1) (default: 0.25)',
    )

    return parser.parse_args()
541
+
542
+
543
def main():
    """Command-line entry point: separate every input file and write the results.

    Each output is written as ``<input stem>_Piano.<format>`` into the output
    directory, which defaults to the input directory (or to the parent
    directory when the input is a single file).
    """
    args = parse_args()
    input_path = Path(args.input_dir)
    model_config = load_config(args.config_path)

    audio_files = list_audio_files(args.input_dir)
    if not audio_files:
        print(f'No supported audio files found in {args.input_dir}')
        return

    if args.output_dir is None:
        output_dir = input_path.parent if input_path.is_file() else input_path
    else:
        output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    cuda_available = torch.cuda.is_available()
    device = torch.device('cuda' if cuda_available else 'cpu')
    if device.type == 'cpu':
        print('WARNING, using CPU')

    model = load_separator(args.checkpoint_path, model_config, device)
    print(f'Found {len(audio_files)} audio file(s) from {args.input_dir}')

    for file_path in audio_files:
        print(f'Processing {file_path}')
        estimate, sample_rate = separate_file(model, file_path, device, args.overlap)
        output_name = f'{file_path.stem}_Piano.{args.output_format}'
        save_path = output_dir / output_name
        sf.write(save_path, estimate, sample_rate)
        print(f'Saved {save_path}')
571
+
572
+
573
# Run the command-line interface only when this file is executed as a script.
if __name__ == '__main__':
    main()