# NOTE: removed non-code extraction residue ("Spaces: Paused" page-status text).
| from batch_processing import Batch | |
| import argparse | |
| from pathlib import Path | |
| from gyraudio.audio_separation.experiment_tracking.experiments import get_experience | |
| from gyraudio.audio_separation.experiment_tracking.storage import get_output_folder | |
| from gyraudio.default_locations import EXPERIMENT_STORAGE_ROOT | |
| from gyraudio.audio_separation.properties import ( | |
| SHORT_NAME, CLEAN, NOISY, MIXED, PREDICTED, ANNOTATIONS, PATHS, BUFFERS, SAMPLING_RATE, NAME | |
| ) | |
| import torch | |
| from gyraudio.audio_separation.experiment_tracking.storage import load_checkpoint | |
| from gyraudio.audio_separation.visualization.pre_load_audio import ( | |
| parse_command_line_audio_load, load_buffers, audio_loading_batch) | |
| from gyraudio.audio_separation.visualization.pre_load_custom_audio import ( | |
| parse_command_line_generic_audio_load, generic_audio_loading_batch, | |
| load_buffers_custom | |
| ) | |
| from torchaudio.functional import resample | |
| from typing import List | |
| import numpy as np | |
| import logging | |
| from interactive_pipe.data_objects.curves import Curve, SingleCurve | |
| from interactive_pipe import interactive, KeyboardControl, Control | |
| from interactive_pipe import interactive_pipeline | |
| from gyraudio.audio_separation.visualization.audio_player import audio_selector, audio_trim, audio_player | |
# Default inference device: prefer CUDA when a GPU is available.
default_device = "cuda" if torch.cuda.is_available() else "cpu"
# Sampling rate (Hz) the separation models were trained at; custom audio
# is resampled to this rate before inference (see `signal_selector`).
LEARNT_SAMPLING_RATE = 8000
def remix(signals, snr=0., global_params={}):
    """Mix the clean and noise buffers at a target SNR (in dB).

    The noise is rescaled so the clean/noise energy ratio matches ``snr``;
    the chosen SNR is recorded in ``global_params["snr"]`` for the title.
    """
    clean_buf = signals[BUFFERS][CLEAN]
    noise_buf = signals[BUFFERS][NOISY]
    # Noise gain achieving the requested SNR: ||clean|| / (||noise|| * 10^(snr/20))
    noise_gain = (10. ** (-snr / 20.)) * torch.norm(clean_buf) / torch.norm(noise_buf)
    global_params["snr"] = snr
    return clean_buf + noise_gain * noise_buf
def augment(signals, mixed, std_dev=0., amplify=1.):
    """Amplify every buffer and add white noise to the mixed signal.

    Note: the mixed/noisy/clean buffers in ``signals`` are scaled
    **in place** by ``amplify``; the returned mix additionally receives
    gaussian noise with standard deviation ``std_dev``.
    """
    for key in (MIXED, NOISY, CLEAN):
        signals[BUFFERS][key] *= amplify
    noisy_mix = amplify * mixed + std_dev * torch.randn_like(mixed)
    return signals, noisy_mix
| # @interactive( | |
| # device=("cuda", ["cpu", "cuda"] | |
| # ) if default_device == "cuda" else ("cpu", ["cpu"]) | |
| # ) | |
def select_device(device=default_device, global_params={}):
    """Store the chosen torch device in the pipeline's shared state."""
    global_params["device"] = device
| # @interactive( | |
| # model=KeyboardControl(value_default=0, value_range=[ | |
| # 0, 99], keyup="pagedown", keydown="pageup") | |
| # ) | |
# Display names of the selectable models; the chosen index is mapped
# modulo the number of actually-loaded models in `audio_sep_inference`.
ALL_MODELS = ["Tiny UNET", "Large UNET", "Large UNET (Bias Free)"]
def audio_sep_inference(mixed, models, configs, model: int = 0, global_params={}):
    """Run the selected separation model on the mixed signal.

    Returns the predicted clean signal (torch tensor, batch dim removed)
    and a numpy copy for plotting. The selected model's short name and
    annotations are exposed through ``global_params``.
    """
    if isinstance(model, str):
        # The GUI may hand over the display name instead of an index.
        model = ALL_MODELS.index(model)
    assert isinstance(model, int)
    chosen_idx = model % len(models)
    chosen_model = models[chosen_idx]
    chosen_config = configs[chosen_idx]
    global_params[SHORT_NAME] = chosen_config.get(SHORT_NAME, "")
    global_params[ANNOTATIONS] = chosen_config.get(ANNOTATIONS, "")
    device = global_params.get("device", "cpu")
    with torch.no_grad():
        chosen_model.eval()
        chosen_model.to(device)
        predicted_signal, _predicted_noise = chosen_model(
            mixed.to(device).unsqueeze(0))
    predicted_signal = predicted_signal.squeeze(0)
    pred_curve = predicted_signal.detach().cpu().numpy()
    return predicted_signal, pred_curve
def compute_metrics(pred, sig, global_params={}):
    """Compute MSE and SNR (dB) between prediction and clean target.

    Results are stored as torch scalars under ``global_params["metrics"]``.
    """
    METRICS = "metrics"
    target = sig[BUFFERS][CLEAN]
    # Hoist the residual: both metrics share (target - pred) on CPU.
    error = target - pred.cpu()
    global_params[METRICS] = {
        "MSE": torch.mean(error ** 2),
        "SNR": 10. * torch.log10(torch.sum(target ** 2) / torch.sum(error ** 2)),
    }
def get_trim(sig, zoom, center, num_samples=300):
    """Compute the slice (start, end, step) for a zoomed view of ``sig``.

    ``center`` is a relative position in [0, 1]; ``zoom`` >= 1 narrows the
    window around it while the step keeps roughly ``num_samples`` points.
    """
    total = len(sig)
    native_ds = total / num_samples  # decimation factor at zoom == 1
    center_idx = int(center * total)
    window = int(num_samples / zoom * native_ds)
    half_window = window // 2
    start_idx = max(0, center_idx - half_window)
    end_idx = min(total, center_idx + half_window)
    step = max(1, int(native_ds / zoom))
    return start_idx, end_idx, step


def zin(sig, zoom, center, num_samples=300):
    """Return a fixed-length (``num_samples``) zoomed/decimated view of ``sig``.

    Shorter trims are zero-padded on the right so the output size is stable.
    """
    start_idx, end_idx, step = get_trim(
        sig, zoom, center, num_samples=num_samples)
    view = sig[start_idx:end_idx:step][:num_samples]
    padded = np.zeros(num_samples)
    padded[:len(view)] = view
    return padded
def visualize_audio(signal: dict, mixed_signal, predicted_signal, zoom=1, zoomy=0., center=0.5, global_params={}):
    """Create the curves displayed by the interactive pipeline.

    Builds clean/noisy/mixed/predicted curves, trimmed and decimated to the
    current zoom window, plus a title summarizing the input SNR and metrics.
    The track currently selected for playback is marked with a '*'.
    Side effect: records the trim window in ``global_params["trim"]``.

    Fix: the duplicated ``selected = global_params.get("selected_audio", MIXED)``
    lookup (it appeared twice) was removed.
    """
    selected = global_params.get("selected_audio", MIXED)
    short_name = global_params.get(SHORT_NAME, "")
    annotations = global_params.get(ANNOTATIONS, "")
    zval = 1.5**zoom  # exponential mapping: one zoom step = x1.5
    start_idx, end_idx, _skip_factor = get_trim(
        signal[BUFFERS][CLEAN][0, :], zval, center)
    # Remembered so the audio player can crop playback to the visible window.
    global_params["trim"] = dict(start=start_idx, end=end_idx)
    pred = SingleCurve(y=zin(predicted_signal[0, :], zval, center),
                       style="g-", label=("*" if selected == PREDICTED else " ")+f"predicted_{short_name} {annotations}")
    clean = SingleCurve(y=zin(signal[BUFFERS][CLEAN][0, :], zval, center),
                        alpha=1.,
                        style="k-",
                        linewidth=0.9,
                        label=("*" if selected == CLEAN else " ")+"clean")
    noisy = SingleCurve(y=zin(signal[BUFFERS][NOISY][0, :], zval, center),
                        alpha=0.3,
                        style="y--",
                        linewidth=1,
                        label=("*" if selected == NOISY else " ") + "noisy"
                        )
    mixed = SingleCurve(y=zin(mixed_signal[0, :], zval, center), style="r-",
                        alpha=0.1,
                        linewidth=2,
                        label=("*" if selected == MIXED else " ") + "mixed")
    curves = [noisy, mixed, pred, clean]
    title = f"SNR in {global_params.get('snr', np.nan):.1f} dB"
    if "selected_info" in global_params:
        title += f" | {global_params['selected_info']}"
    title += "\n"
    for metric_name, metric_value in global_params.get("metrics", {}).items():
        title += f" | {metric_name} "
        # Scientific notation only for very small or very large values.
        use_scientific = abs(metric_value) < 1e-2 or abs(metric_value) > 1000
        title += f"{metric_value:.2e}" if use_scientific else f"{metric_value:.2f}"
    return Curve(curves, ylim=[-0.04 * 1.5 ** zoomy, 0.04 * 1.5 ** zoomy], xlabel="Time index", ylabel="Amplitude", title=title)
def signal_selector(signals, idx="Voice 1", idn=0, global_params={}):
    """Select the audio pair to process.

    Two input layouts are supported:
    - dict ``{CLEAN: [...], NOISY: [...]}`` (custom mode): a clean voice and
      a noise file are paired, lazily loaded, resampled to the learnt rate
      and cropped to a common 1024-aligned, single-channel length;
    - list of premixed signal dicts: the signal is lazily loaded as-is.

    Side effects: records selection info and sampling rate in ``global_params``.

    Fixes (both triggered on re-selection of a cached entry):
    - the resampled buffer is now written back next to the updated
      SAMPLING_RATE tag, otherwise a later call skipped resampling and
      reused the stale original-rate buffer;
    - the "File=" display prefix is only applied once instead of piling up.
    """
    # The GUI provides the selection as a label like "Voice 3".
    idx = int(idx.split("Voice ")[-1])
    if isinstance(signals, dict):
        clean_sigs = signals[CLEAN]
        clean = clean_sigs[idx % len(clean_sigs)]
        if BUFFERS not in clean:
            load_buffers_custom(clean)
        noise_sigs = signals[NOISY]
        noise = noise_sigs[idn % len(noise_sigs)]
        if BUFFERS not in noise:
            load_buffers_custom(noise)
        cbuf, nbuf = clean[BUFFERS], noise[BUFFERS]
        if clean[SAMPLING_RATE] != LEARNT_SAMPLING_RATE:
            cbuf = resample(cbuf, clean[SAMPLING_RATE], LEARNT_SAMPLING_RATE)
            clean[BUFFERS] = cbuf  # cache: keep buffer consistent with the rate tag
            clean[SAMPLING_RATE] = LEARNT_SAMPLING_RATE
        if noise[SAMPLING_RATE] != LEARNT_SAMPLING_RATE:
            nbuf = resample(nbuf, noise[SAMPLING_RATE], LEARNT_SAMPLING_RATE)
            noise[BUFFERS] = nbuf  # same caching fix as for the clean buffer
            noise[SAMPLING_RATE] = LEARNT_SAMPLING_RATE
        min_length = min(cbuf.shape[-1], nbuf.shape[-1])
        # Crop to a multiple of 1024 samples — presumably required by the
        # UNET downsampling stages; confirm against the model definition.
        min_length = min_length - min_length % 1024
        signal = {
            PATHS: {
                CLEAN: clean[PATHS],
                NOISY: noise[PATHS]
            },
            BUFFERS: {
                CLEAN: cbuf[..., :1, :min_length],
                NOISY: nbuf[..., :1, :min_length],
            },
            NAME: f"Clean={clean[NAME]} | Noise={noise[NAME]}",
            SAMPLING_RATE: LEARNT_SAMPLING_RATE
        }
    else:
        # signals are loaded in CPU
        signal = signals[idx % len(signals)]
        if BUFFERS not in signal:
            load_buffers(signal)
        global_params["premixed_snr"] = signal.get("premixed_snr", None)
        if not signal[NAME].startswith("File="):
            signal[NAME] = f"File={signal[NAME]}"
    global_params["selected_info"] = signal[NAME]
    global_params[SAMPLING_RATE] = signal[SAMPLING_RATE]
    return signal
def interactive_audio_separation_processing(signals, model_list, config_list):
    """Pipeline body wired by interactive_pipe: select -> mix -> infer -> display.

    Returns the Curve to display; audio playback (selection, trimming,
    player) happens as a side effect of the last three calls.
    """
    sig = signal_selector(signals)
    mixed = remix(sig)
    # sig, mixed = augment(sig, mixed)
    select_device()
    pred, pred_curve = audio_sep_inference(mixed, model_list, config_list)
    compute_metrics(pred, sig)
    sound = audio_selector(sig, mixed, pred)
    curve = visualize_audio(sig, mixed, pred_curve)
    trimmed_sound = audio_trim(sound)
    audio_player(trimmed_sound)
    return curve
def interactive_audio_separation_visualization(
    all_signals: List[dict],
    model_list: List[torch.nn.Module],
    config_list: List[dict],
    gui="gradio"
):
    """Wrap the processing function as an interactive pipeline and launch it."""
    pipeline = interactive_pipeline(gui=gui, cache=True, audio=True)(
        interactive_audio_separation_processing)
    pipeline(all_signals, model_list, config_list)
def visualization(
    all_signals: List[dict],
    model_list: List[torch.nn.Module],
    config_list: List[dict],
    device="cuda"
):
    """Static (non-interactive) display: one figure per signal, with the
    clean/noise references and every loaded model's prediction."""
    for signal in all_signals:
        if BUFFERS not in signal:
            load_buffers(signal, device="cpu")
        curves = [
            SingleCurve(y=signal[BUFFERS][CLEAN][0, :], label="clean"),
            SingleCurve(y=signal[BUFFERS][NOISY][0, :], label="noise", alpha=0.3),
        ]
        for config, model in zip(config_list, model_list):
            short_name = config.get(SHORT_NAME, "unknown")
            predicted_signal, _predicted_noise = model(
                signal[BUFFERS][MIXED].to(device).unsqueeze(0))
            curves.append(SingleCurve(
                y=predicted_signal.squeeze(0)[0, :].detach().cpu().numpy(),
                label=f"predicted_{short_name}"))
        Curve(curves).show()
def parse_command_line(parser: Batch = None, gradio_demo=True) -> argparse.ArgumentParser:
    """Dispatch to the gradio-demo or generic command line parser builder."""
    builder = parse_command_line_gradio if gradio_demo else parse_command_line_generic
    return builder(parser)
def parse_command_line_gradio(parser: Batch = None, gradio_demo=True) -> argparse.ArgumentParser:
    """CLI for the gradio demo: preset experiment ids, interactive mode on,
    gradio as the default GUI backend."""
    if parser is None:
        parser = parse_command_line_audio_load()
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    group = parser.add_argument_group("Audio separation visualization")
    group.add_argument("-e", "--experiments", type=int, nargs="+", default=[4, 1004, 3001,],
                       help="Experiment ids to be inferred sequentially")
    # Demo is always interactive: store_true with default=True keeps it on.
    group.add_argument("-p", "--interactive", default=True,
                       action="store_true", help="Play = Interactive mode")
    group.add_argument("-m", "--model-root", type=str,
                       default=EXPERIMENT_STORAGE_ROOT)
    group.add_argument("-d", "--device", type=str, default=default_device,
                       choices=["cpu", "cuda"] if default_device == "cuda" else ["cpu"])
    group.add_argument("-gui", "--gui", type=str,
                       default="gradio", choices=["qt", "mpl", "gradio"])
    return parser
def parse_command_line_generic(parser: Batch = None, gradio_demo=True) -> argparse.ArgumentParser:
    """Generic CLI: experiment ids are required, interactive mode is opt-in,
    qt is the default GUI backend."""
    if parser is None:
        parser = parse_command_line_audio_load()
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    group = parser.add_argument_group("Audio separation visualization")
    group.add_argument("-e", "--experiments", type=int, nargs="+", required=True,
                       help="Experiment ids to be inferred sequentially")
    group.add_argument("-p", "--interactive",
                       action="store_true", help="Play = Interactive mode")
    group.add_argument("-m", "--model-root", type=str,
                       default=EXPERIMENT_STORAGE_ROOT)
    group.add_argument("-d", "--device", type=str, default=default_device,
                       choices=["cpu", "cuda"] if default_device == "cuda" else ["cpu"])
    group.add_argument("-gui", "--gui", type=str,
                       default="qt", choices=["qt", "mpl", "gradio"])
    return parser
def _load_experiment_models(experiment_ids, model_root, device):
    """Load checkpointed models and their configs for the given experiment ids.

    Returns (models_list, config_list) in the same order as ``experiment_ids``.
    Raises AssertionError if an experiment folder is missing under ``model_root``.
    """
    models_list = []
    config_list = []
    model_dir = Path(model_root)
    for exp in experiment_ids:
        short_name, model, config, _dl = get_experience(exp)
        _, exp_dir = get_output_folder(
            config, root_dir=model_dir, override=False)
        assert exp_dir.exists(
        ), f"Experiment {short_name} does not exist in {model_dir}"
        model.eval()
        model.to(device)
        model, __optimizer, epoch, config = load_checkpoint(
            model, exp_dir, epoch=None, device=device)
        config[SHORT_NAME] = short_name
        models_list.append(model)
        config_list.append(config)
    return models_list, config_list


def main(argv: List[str]):
    """Paired signals and noise in folders.

    Parses the command line, loads all requested experiment checkpoints,
    loads the audio buffers through the Batch runner, then hands off to the
    static or interactive visualization depending on ``--interactive``.
    (Fix: removed the dead ``exp = args.experiments[0]`` assignment that was
    immediately overwritten; model loading extracted into a helper.)
    """
    batch = Batch(argv)
    batch.set_io_description(
        input_help='input audio files',
        output_help=argparse.SUPPRESS
    )
    batch.set_multiprocessing_enabled(False)
    parser = parse_command_line()
    args = batch.parse_args(parser)
    device = args.device
    logging.info(f"Loading experiments models {args.experiments}")
    models_list, config_list = _load_experiment_models(
        args.experiments, args.model_root, device)
    logging.info("Load audio buffers:")
    all_signals = batch.run(audio_loading_batch)
    if not args.interactive:
        visualization(all_signals, models_list, config_list, device=device)
    else:
        interactive_audio_separation_visualization(
            all_signals, models_list, config_list, gui=args.gui)
def main_custom(argv: List[str]):
    """Handle custom noise and custom signals.

    Like `main`, but instead of paired folders the user supplies separate
    signal (-s) and noise (-n) file lists; each list is loaded through its
    own Batch run and handed to the interactive pipeline as a dict keyed
    by CLEAN / NOISY.
    """
    parser = parse_command_line()
    parser.add_argument("-s", "--signal", type=str, required=True,
                        nargs="+", help="Signal to be preloaded")
    parser.add_argument("-n", "--noise", type=str, required=True,
                        nargs="+", help="Noise to be preloaded")
    args = parser.parse_args(argv)
    exp = args.experiments[0]  # NOTE(review): dead assignment — overwritten by the loop below
    device = args.device
    models_list = []
    config_list = []
    logging.info(f"Loading experiments models {args.experiments}")
    # Load each experiment's model and checkpointed config (mirrors `main`).
    for exp in args.experiments:
        model_dir = Path(args.model_root)
        short_name, model, config, _dl = get_experience(exp)
        _, exp_dir = get_output_folder(
            config, root_dir=model_dir, override=False)
        assert exp_dir.exists(
        ), f"Experiment {short_name} does not exist in {model_dir}"
        model.eval()
        model.to(device)
        model, __optimizer, epoch, config = load_checkpoint(
            model, exp_dir, epoch=None, device=args.device)
        config[SHORT_NAME] = short_name
        models_list.append(model)
        config_list.append(config)
    all_signals = {}
    # Load clean signals and noises through two independent Batch runs.
    for args_paths, key in zip([args.signal, args.noise], [CLEAN, NOISY]):
        new_argv = ["-i"] + args_paths
        # NOTE(review): assumes a --preload flag is defined by
        # parse_command_line_audio_load — not visible in this file; confirm.
        if args.preload:
            new_argv += ["--preload"]
        batch = Batch(new_argv)
        new_parser = parse_command_line_generic_audio_load()
        batch.set_io_description(
            input_help=argparse.SUPPRESS,  # 'input audio files',
            output_help=argparse.SUPPRESS
        )
        batch.set_multiprocessing_enabled(False)
        _ = batch.parse_args(new_parser)
        all_signals[key] = batch.run(generic_audio_loading_batch)
    interactive_audio_separation_visualization(
        all_signals, models_list, config_list, gui=args.gui)