noblebarkrr committed
Commit edcdeff · verified · 1 Parent(s): 0125dc3

Update mvsepless/infer.py

Files changed (1)
  1. mvsepless/infer.py +766 -768
mvsepless/infer.py CHANGED
@@ -1,769 +1,767 @@
 import os
 import sys
 sys.stdout.reconfigure(encoding='utf-8')
 sys.stderr.reconfigure(encoding='utf-8')
 import json
 import argparse
 import time
 import gc
-from gradio_helper import hf_spaces_gpu, zerogpu_available
 import torch
 import numpy as np
 import torch.nn as nn
 from typing import Literal, Optional, List, Tuple, Any, Dict

 from audio import read, multiwrite, output_formats, subtractor, bitrate_to_int
 from namer import Namer
 from i18n import _i18n

 namer = Namer()

 from infer_utils import demix, get_model_from_config


 def normalize_peak(audio: np.ndarray, peak: float) -> np.ndarray:
     """
     Normalize audio to a target peak value

     Args:
         audio: Audio data
         peak: Target peak value

     Returns:
         Normalized audio data
     """
     current_peak = np.max(np.abs(audio))
     if current_peak == 0:
         return audio
     scale_factor = peak / current_peak
     return audio * scale_factor


 def create_output_path(
     input_path: str,
     stem_name: str,
     model_name: str,
     model_id: int,
     output_format: str,
     store_dir: str,
     template: str
 ) -> str:
     """
     Build the path for the output file

     Args:
         input_path: Path to the input file
         stem_name: Stem name
         model_name: Model name
         model_id: Model ID
         output_format: Output format
         store_dir: Directory to save to
         template: Name template

     Returns:
         Path to the output file
     """
     file_name = os.path.splitext(os.path.basename(input_path))[0]
     file_name_shorted = namer.short_input_name_template(
         template, STEM=stem_name, MODEL=model_name, ID=model_id, NAME=file_name
     )
     custom_name = namer.template(
         template,
         STEM=stem_name,
         MODEL=model_name,
         ID=model_id,
         NAME=file_name_shorted,
     )
     return os.path.join(store_dir, f"{custom_name}.{output_format}")


 gc.enable()


 def cleanup_model(model: Optional[nn.Module]) -> None:
     """
     Release a model from memory

     Args:
         model: Model to release
     """
     try:
         if model is None:
             return

         if isinstance(model, torch.nn.DataParallel):
             model = model.module

         model.to("cpu")

         for name, param in list(model.named_parameters()):
             del param
         for name, buf in list(model.named_buffers()):
             del buf

         del model

         if torch.cuda.is_available():
             torch.cuda.empty_cache()
             torch.cuda.ipc_collect()

         gc.collect()
     except Exception as e:
         pass


 def once_inference(
     path: str = None,
     model: Any = None,
     config: Any = None,
     device: Any = None,
     model_type: str = None,
     extract_instrumental: bool = False,
     output_format: Literal[
         "mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "aiff"
     ] = "mp3",
     output_bitrate: str = "320k",
     model_name: str = None,
     sample_rate: int = 44100,
     instruments: List[str] = [],
     store_dir: str = None,
     template: str = None,
     selected_instruments: List[str] = [],
     model_id: int = 0,
     spec_invert_target_instrument: bool = False
 ) -> List[Tuple[str, str]]:
     """
     Single inference pass

     Args:
         path: Path to the input file
         model: Model
         config: Configuration
         device: Device
         model_type: Model type
         extract_instrumental: Extract the instrumental
         output_format: Output format
         output_bitrate: Bitrate
         model_name: Model name
         sample_rate: Sample rate
         instruments: List of instruments
         store_dir: Directory to save to
         template: Name template
         selected_instruments: Selected instruments
         model_id: Model ID
         spec_invert_target_instrument: Use spectrogram inversion for the target instrument

     Returns:
         List of (stem name, file path) tuples
     """
     results = []
     sys.stdout.write(json.dumps({"reading": path}, ensure_ascii=False) + "\n")
     sys.stdout.flush()
     sys.stdout.write(
         json.dumps({"selected_stems": selected_instruments}, ensure_ascii=False) + "\n"
     )
     sys.stdout.flush()

     output_instruments = []
     output_waveforms = {}

     mono_bool = False
     if hasattr(config, "model"):
         if hasattr(config.model, "stereo"):
             mono_bool = False if config.model.stereo else True
     try:
         mix, sr = read(path=path, sr=sample_rate, mono=mono_bool)
     except Exception as e:
         error_msg = _i18n("audio_read_error", path=path, error=str(e))
         sys.stdout.write(json.dumps({"error": error_msg}, ensure_ascii=False) + "\n")
         sys.stdout.flush()
         return results

     mix_orig = mix.copy()

     mean = std = None
     if config.inference.get("normalize", False):
         mono = mix.mean(0)
         mean = mono.mean()
         std = mono.std()
         mix = (mix - mean) / std

     waveforms = {}

     try:
         waveforms = demix(
             config, model, mix_orig, device, model_type
         )
     except Exception as e:
         sys.stdout.write(
             json.dumps({"error": _i18n("demix_error", error=str(e))}, ensure_ascii=False)
             + "\n"
         )
         sys.stdout.flush()
         gc.collect()

     if not waveforms:
         sys.stdout.write(
             json.dumps({"error": _i18n("empty_demix_result")}, ensure_ascii=False)
             + "\n"
         )
         sys.stdout.flush()
         return results

     # If a target instrument is defined and no stems were selected
     if config.training.target_instrument is not None:
         if not selected_instruments:
             output_waveforms[config.training.target_instrument] = waveforms[config.training.target_instrument]
             second_stem = None
             for instr_ in instruments:
                 if instr_ != config.training.target_instrument:
                     second_stem = instr_
                     break
             if second_stem:
                 output_waveforms[second_stem] = subtractor(mix_orig, waveforms[config.training.target_instrument], sample_rate, sample_rate, spectrogram=spec_invert_target_instrument)[0]
         else:  # If a target instrument is defined and at least one stem was selected
             if config.training.target_instrument in selected_instruments:
                 output_waveforms[config.training.target_instrument] = waveforms[config.training.target_instrument]
             second_stem = None
             for instr_ in instruments:
                 if instr_ != config.training.target_instrument:
                     second_stem = instr_
                     break
             if second_stem and second_stem in selected_instruments:
                 output_waveforms[second_stem] = subtractor(mix_orig, waveforms[config.training.target_instrument], sample_rate, sample_rate, spectrogram=spec_invert_target_instrument)[0]

     elif config.training.target_instrument is None:
         if not selected_instruments:
             for instr in waveforms:
                 output_waveforms[instr] = waveforms[instr]
             if extract_instrumental:
                 if (
                     all(
                         instr in instruments
                         for instr in ["bass", "drums", "other", "vocals"]
                     )
                     or all(
                         instr in instruments
                         for instr in ["bass", "drums", "other", "vocals", "piano", "guitar"]
                     )
                 ):
                     output_waveforms["instrumental -"] = mix_orig.copy()
                     output_waveforms["instrumental -"] = subtractor(output_waveforms["instrumental -"], waveforms["vocals"], sample_rate, sample_rate, spectrogram=spec_invert_target_instrument)[0]

                     non_vocal_stems = [s for s in instruments if s not in ["vocals"]]
                     if non_vocal_stems:
                         output_waveforms["instrumental +"] = np.zeros_like(mix_orig)
                         for stem in non_vocal_stems:
                             if stem in waveforms:
                                 output_waveforms["instrumental +"] += waveforms[stem]

                         peak = np.max(np.abs(output_waveforms["instrumental -"]))
                         output_waveforms["instrumental +"] = normalize_peak(output_waveforms["instrumental +"], peak)
         else:
             for instr in waveforms:
                 if instr in selected_instruments:
                     output_waveforms[instr] = waveforms[instr]
             if extract_instrumental:
                 if len(instruments) >= 3:
                     output_waveforms["inverted -"] = mix_orig.copy()
                     for instr_ in selected_instruments:
                         if instr_ in waveforms:
                             output_waveforms["inverted -"] = subtractor(output_waveforms["inverted -"], waveforms[instr_], sample_rate, sample_rate, spectrogram=spec_invert_target_instrument)[0]

                     unselected_stems = [
                         s for s in instruments if s not in selected_instruments
                     ]
                     if unselected_stems:
                         output_waveforms["inverted +"] = np.zeros_like(mix_orig)
                         for stem in unselected_stems:
                             if stem in waveforms:
                                 output_waveforms["inverted +"] += waveforms[stem]
                         if "inverted +" not in instruments:
                             instruments.append("inverted +")

                         peak = np.max(np.abs(output_waveforms["inverted -"]))
                         output_waveforms["inverted +"] = normalize_peak(output_waveforms["inverted +"], peak)

     output_instruments = [instr__ for instr__ in output_waveforms]

     # Prepare the name template
     template = namer.sanitize(template)
     template = namer.dedup_template(template, keys=["NAME", "MODEL", "STEM", "ID"])
     template = namer.short(template, length=40)

     output_paths = [create_output_path(path, instr, model_name, model_id, output_format, store_dir, template) for instr in output_instruments]

     if mean is not None and std is not None:
         output_arrays = [output_waveforms[instr] * std + mean for instr in output_instruments]
     else:
         output_arrays = [output_waveforms[instr] for instr in output_instruments]
     output_sample_rates = [sample_rate for _c in range(len(output_instruments))]

     def flush_writing_file(file: str) -> None:
         sys.stdout.write(
             json.dumps({"writing": file}, ensure_ascii=False) + "\n"
         )
         sys.stdout.flush()

     try:
         writed_files = multiwrite(output_arrays, output_sample_rates, [namer.iter(output_path_) for output_path_ in output_paths], output_bitrate, callable_func=flush_writing_file, strict=True)
     except Exception as e:
         sys.stdout.write(
             json.dumps(
                 {"error": _i18n("write_error", error=str(e))}, ensure_ascii=False
             )
             + "\n"
         )
         sys.stdout.flush()
         gc.collect()

     results = list(zip(output_instruments, writed_files))

     del mix, mix_orig, waveforms, output_arrays
     gc.collect()

     return results


 def run_inference(
     model: Any = None,
     config: Any = None,
     input_path: str = None,
     store_dir: str = None,
     device: Any = None,
     model_type: str = None,
     extract_instrumental: bool = False,
     output_format: Literal[
         "mp3", "wav", "flac", "ogg", "opus", "m4a", "aac", "aiff"
     ] = "mp3",
     output_bitrate: str = "320k",
     model_name: str = None,
     template: str = "NAME_STEM",
     selected_instruments: List[str] = [],
     model_id: int = 0,
     spec_invert_target_instrument: bool = False
 ) -> List[Tuple[str, str]]:
     """
     Run inference

     Args:
         model: Model
         config: Configuration
         input_path: Path to the input file
         store_dir: Directory to save to
         device: Device
         model_type: Model type
         extract_instrumental: Extract the instrumental
         output_format: Output format
         output_bitrate: Bitrate
         model_name: Model name
         template: Name template
         selected_instruments: Selected instruments
         model_id: Model ID
         spec_invert_target_instrument: Use spectrogram inversion for the target instrument

     Returns:
         List of (stem name, file path) tuples
     """
     start_time = time.time()
     if model_type != "vr":
         model.eval()
     sample_rate = 44100
     if "sample_rate" in config.audio:
         sample_rate = config.audio["sample_rate"]

     instruments = config.training.instruments

     os.makedirs(store_dir, exist_ok=True)

     results = once_inference(
         path=input_path,
         model=model,
         config=config,
         device=device,
         model_type=model_type,
         extract_instrumental=extract_instrumental,
         output_format=output_format,
         output_bitrate=output_bitrate,
         model_name=model_name,
         sample_rate=sample_rate,
         instruments=instruments,
         store_dir=store_dir,
         template=template,
         selected_instruments=selected_instruments,
         model_id=model_id,
         spec_invert_target_instrument=spec_invert_target_instrument
     )

     time.sleep(1)
     time_taken = time.time() - start_time
     sys.stdout.write(
         json.dumps({"time": _i18n("time_seconds", seconds=f"{time_taken:.2f}")}, ensure_ascii=False) + "\n"
     )
     sys.stdout.flush()
     sys.stdout.write(json.dumps({"done": results}, ensure_ascii=False) + "\n")
     sys.stdout.flush()
     return results


 def load_model(
     model_type: str,
     config_path: str,
     start_check_point: str,
     device: str
 ) -> Tuple[Any, Any, torch.device]:
     """
     Load a model

     Args:
         model_type: Model type
         config_path: Path to the config
         start_check_point: Path to the checkpoint
         device: Device string

     Returns:
         Tuple of (model, config, device)
     """
     sys.stdout.write(json.dumps({"device": device}, ensure_ascii=False) + "\n")
     sys.stdout.flush()

     # Determine the device type
     if "cuda" in device.lower():
         # Extract the CUDA device IDs
         if ":" in device:
             device_spec = device.split(":")[1]
             device_ids = [int(id) for id in device_spec.split(",") if id.isdigit()]
         else:
             # If plain "cuda" is given, use all available GPUs
             device_ids = list(range(torch.cuda.device_count()))
         torch_device = torch.device("cuda" if not device_ids else f"cuda:{device_ids[0]}")
     elif "mps" in device.lower():
         device_ids = None
         torch_device = torch.device("mps")
     else:
         # CPU
         device_ids = None
         torch_device = torch.device("cpu")

     model_load_start_time = time.time()

     # Enable optimizations for CUDA only
     if torch_device.type == "cuda":
         if hasattr(torch, 'backends'):
             if hasattr(torch.backends, 'cudnn'):
                 torch.backends.cudnn.benchmark = True

                 if hasattr(torch.backends.cudnn, 'allow_tf32'):
                     torch.backends.cudnn.allow_tf32 = True

             if hasattr(torch.backends, 'cuda') and hasattr(torch.backends.cuda, 'matmul'):
                 if hasattr(torch.backends.cuda.matmul, 'allow_tf32'):
                     torch.backends.cuda.matmul.allow_tf32 = True

     model, config = get_model_from_config(model_type, config_path)

     if model_type == "vr":
         enable_post_process = False
         if hasattr(config.inference, "enable_post_process"):
             enable_post_process = config.inference.enable_post_process
         model.load_checkpoint(start_check_point, torch_device)
         model.settings(
             enable_post_process=enable_post_process,
             post_process_threshold=config.inference.post_process_threshold,
             batch_size=config.inference.batch_size,
             window_size=config.inference.window_size,
             high_end_process=config.inference.high_end_process,
             primary_stem=config.training.instruments[0],
             secondary_stem=config.training.instruments[1],
         )
         return model, config, torch_device

     elif model_type == "medley_vox":
         if start_check_point != "":
             checkpoint = torch.load(start_check_point, map_location=torch_device)
             if config.model.ema:
                 model_dict = model.state_dict()
                 # 1. filter out unnecessary keys
                 checkpoint = {
                     k.replace("ema_model.module.", ""): v
                     for k, v in checkpoint.items()
                     if k.replace("ema_model.module.", "") in model_dict
                 }
                 # 2. overwrite entries in the existing state dict
                 model_dict.update(checkpoint)
                 # 3. load the new state dict
                 model.load_state_dict(model_dict)
             elif not config.model.ema:
                 model_dict = model.state_dict()
                 # 1. filter out unnecessary keys
                 checkpoint = {
                     k.replace("online_model.module.", ""): v
                     for k, v in checkpoint.items()
                     if k.replace("online_model.module.", "") in model_dict
                 }
                 # 2. overwrite entries in the existing state dict
                 model_dict.update(checkpoint)
                 # 3. load the new state dict
                 model.load_state_dict(model_dict)
             else:
                 model.load_state_dict(checkpoint)
         model.eval()
         return model, config, torch_device

     elif model_type == "mdxnet":
         if start_check_point != "":
             sys.stdout.write(json.dumps({"checkpoint": start_check_point}) + "\n")
             sys.stdout.flush()
             model.init_onnx_session(start_check_point, torch_device, device_ids)
         return model, config, torch_device

     else:
         if start_check_point != "":
             sys.stdout.write(json.dumps({"checkpoint": start_check_point}) + "\n")
             sys.stdout.flush()

             if model_type in ["htdemucs", "apollo"]:
                 state_dict = torch.load(
                     start_check_point, map_location=torch_device, weights_only=False
                 )
             else:
                 if hasattr(config, "fno"):
                     with torch.serialization.safe_globals([torch._C._nn.gelu]):
                         state_dict = torch.load(
                             start_check_point, map_location=torch_device, weights_only=True
                         )
                 else:
                     try:
                         state_dict = torch.load(
                             start_check_point, map_location=torch_device, weights_only=True
                         )
                     except torch.serialization.pickle.UnpicklingError:
                         state_dict = torch.load(
                             start_check_point,
                             map_location=torch_device,
                             weights_only=False
                         )

             if "state" in state_dict:
                 state_dict = state_dict["state"]
             if "state_dict" in state_dict:
                 state_dict = state_dict["state_dict"]
             if "model_state_dict" in state_dict:
                 state_dict = state_dict["model_state_dict"]

             try:
                 model.load_state_dict(state_dict)
             except RuntimeError as e:
                 sys.stdout.write(
                     json.dumps({"stems": ["error", "error"]}, ensure_ascii=False)
                     + "\n"
                 )
                 sys.stdout.write(
                     json.dumps({"stems": [str(e)]}, ensure_ascii=False)
                     + "\n"
                 )
                 print(_i18n("state_dict_load_warning", error=str(e)))
                 model.load_state_dict(state_dict, strict=False)

     sys.stdout.write(
         json.dumps({"stems": list(config.training.instruments)}, ensure_ascii=False)
         + "\n"
     )
     sys.stdout.flush()

     # Move the model to the target device
     model = model.to(torch_device)

     # Use DataParallel only with multiple GPUs, and never on MPS
     if torch_device.type == "cuda" and len(device_ids) > 1:
         model = nn.DataParallel(model, device_ids=device_ids)
         print(_i18n("using_dataparallel", devices=device_ids))

     load_time = time.time() - model_load_start_time

     sys.stdout.write(
         json.dumps({"model_load_time": _i18n("time_seconds", seconds=f"{load_time:.2f}")}, ensure_ascii=False)
         + "\n"
     )
     sys.stdout.flush()

     return model, config, torch_device


 def mvsep_offline(
     input_path: str,
     store_dir: str,
     model_type: str,
     config_path: str,
     start_check_point: str,
     extract_instrumental: bool,
     output_format: str,
     output_bitrate: str,
     model_name: str,
     template: str,
     device: str = "cpu",
     selected_instruments: Optional[List[str]] = None,
     model_id: int = 0,
     spec_invert_target_instrument: bool = False
 ) -> List[Tuple[str, str]]:
     """
     Offline separation

     Args:
         input_path: Path to the input file
         store_dir: Directory to save to
         model_type: Model type
         config_path: Path to the config
         start_check_point: Path to the checkpoint
         extract_instrumental: Extract the instrumental
         output_format: Output format
         output_bitrate: Bitrate
         model_name: Model name
         template: Name template
         device: Device
         selected_instruments: Selected instruments
         model_id: Model ID
         spec_invert_target_instrument: Use spectrogram inversion for the target instrument

     Returns:
         List of (stem name, file path) tuples
     """
     model, config, device = load_model(
         model_type, config_path, start_check_point, device
     )

     results = run_inference(
         model=model,
         config=config,
         input_path=input_path,
         store_dir=store_dir,
         device=device,
         model_type=model_type,
         extract_instrumental=extract_instrumental,
         output_format=output_format,
         output_bitrate=output_bitrate,
         model_name=model_name,
         template=template,
         selected_instruments=selected_instruments or [],
         model_id=model_id,
         spec_invert_target_instrument=spec_invert_target_instrument
     )

     if model_type != "vr":
         cleanup_model(model)
     del config
     gc.collect()
     return results


 def parse_args() -> argparse.Namespace:
     """Parse command-line arguments"""
     parser = argparse.ArgumentParser(
         description=_i18n("infer_description")
     )

     parser.add_argument("--input", type=str, required=True, help=_i18n("input_path_help"))
     parser.add_argument(
         "--store_dir", type=str, required=True, help=_i18n("store_dir_help")
     )

     parser.add_argument(
         "--model_type",
         type=str,
         default="htdemucs",
         choices=[
             "mel_band_roformer",
             "bs_roformer",
             "mdx23c",
             "scnet",
             "scnet_masked",
             "scnet_tran",
             "htdemucs",
             "bandit",
             "bandit_v2",
             "mdxnet",
             "vr",
             "medley_vox"
         ],
         help=_i18n("model_type_help"),
     )
     parser.add_argument(
         "--config_path",
         type=str,
         required=True,
         help=_i18n("config_path_help"),
     )
     parser.add_argument(
         "--start_check_point", type=str, required=True, help=_i18n("checkpoint_help")
     )

     parser.add_argument(
         "--output_format",
         type=str,
         default="wav",
         choices=output_formats,
         help=_i18n("output_format_help"),
     )
     parser.add_argument(
         "--output_bitrate", type=str, required=True, help=_i18n("output_bitrate_help")
     )

     parser.add_argument(
         "--selected_instruments",
         nargs="+",
         help=_i18n("selected_instruments_help"),
     )
     parser.add_argument(
         "--extract_instrumental",
         action="store_true",
         help=_i18n("extract_instrumental_help"),
     )
     parser.add_argument(
         "--use_spec_invert",
         action="store_true",
         help=_i18n("use_spec_invert_help"),
     )
     parser.add_argument(
         "--template",
         type=str,
         default="NAME_STEM",
         help=_i18n("template_help"),
     )
     parser.add_argument(
         "--model_name",
         type=str,
         default="model",
         help=_i18n("model_name_help"),
     )
     parser.add_argument("-m_id", "--model_id", type=int, required=True, help=_i18n("model_id_help"))
     parser.add_argument(
         "--device", type=str, help=_i18n("device_help"), default="cuda:0"
     )
     parser.add_argument("--verbose", action="store_true", help=_i18n("verbose_help"))

     return parser.parse_args()

-@hf_spaces_gpu(duration=80)
 def main() -> None:
     """Main function"""
     args = parse_args()

     results = mvsep_offline(
         input_path=args.input,
         store_dir=args.store_dir,
         model_type=args.model_type,
         config_path=args.config_path,
         start_check_point=args.start_check_point,
         extract_instrumental=args.extract_instrumental,
         output_format=args.output_format,
         output_bitrate=args.output_bitrate,
         model_name=args.model_name,
         template=args.template,
-        device="cuda:0" if zerogpu_available else args.device,
+        device=args.device,
         selected_instruments=args.selected_instruments,
         model_id=args.model_id,
         spec_invert_target_instrument=args.use_spec_invert
     )

 if __name__ == "__main__":
     main()
 
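For reference, a minimal sketch of how the updated script might be invoked now that the device is taken directly from --device instead of being forced to "cuda:0" whenever ZeroGPU is available. The flag names mirror parse_args() above; every file path and the model choice are placeholders, not values from the commit:

    import subprocess

    # Hypothetical invocation of the updated entry point; all paths and the
    # model selection below are placeholders for illustration.
    subprocess.run([
        "python", "mvsepless/infer.py",
        "--input", "mix.wav",
        "--store_dir", "separated",
        "--model_type", "htdemucs",
        "--config_path", "config.yaml",
        "--start_check_point", "model.ckpt",
        "--output_format", "wav",
        "--output_bitrate", "320k",
        "--model_id", "0",
        "--device", "cpu",  # after this change, used as-is
    ], check=True)

The script reports progress as one JSON object per stdout line ("reading", "writing", "time", "done", "error", and so on), so a caller can stream-parse its output.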