Spaces:
Running
Running
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from asteroid_filterbanks import make_enc_dec | |
| from asteroid.masknn import TDConvNet | |
| import utils | |
| from .base_models import ( | |
| BaseEncoderMaskerDecoderWithConfigs, | |
| BaseEncoderMaskerDecoderWithConfigsMaskOnOutput, | |
| BaseEncoderMaskerDecoderWithConfigsMultiChannelAsteroid, | |
| ) | |
| def load_model_with_args(args): | |
| if args.model_loss_params.architecture == "conv_tasnet_mask_on_output": | |
| encoder, decoder = make_enc_dec( | |
| "free", | |
| n_filters=args.conv_tasnet_params.n_filters, | |
| kernel_size=args.conv_tasnet_params.kernel_size, | |
| stride=args.conv_tasnet_params.stride, | |
| sample_rate=args.sample_rate, | |
| ) | |
| masker = TDConvNet( | |
| in_chan=encoder.n_feats_out * args.data_params.nb_channels, # stereo | |
| n_src=1, # for de-limit task. | |
| out_chan=encoder.n_feats_out, | |
| n_blocks=args.conv_tasnet_params.n_blocks, | |
| n_repeats=args.conv_tasnet_params.n_repeats, | |
| bn_chan=args.conv_tasnet_params.bn_chan, | |
| hid_chan=args.conv_tasnet_params.hid_chan, | |
| skip_chan=args.conv_tasnet_params.skip_chan, | |
| # conv_kernel_size=args.conv_tasnet_params.conv_kernel_size, | |
| norm_type=args.conv_tasnet_params.norm_type if args.conv_tasnet_params.norm_type else 'gLN', | |
| mask_act=args.conv_tasnet_params.mask_act, | |
| # causal=args.conv_tasnet_params.causal, | |
| ) | |
| model = BaseEncoderMaskerDecoderWithConfigsMaskOnOutput( | |
| encoder, | |
| masker, | |
| decoder, | |
| encoder_activation=args.conv_tasnet_params.encoder_activation, | |
| use_encoder=True, | |
| apply_mask=True, | |
| use_decoder=True, | |
| decoder_activation=args.conv_tasnet_params.decoder_activation, | |
| ) | |
| model.use_encoder_to_target = False | |
| elif args.model_loss_params.architecture == "conv_tasnet": | |
| encoder, decoder = make_enc_dec( | |
| "free", | |
| n_filters=args.conv_tasnet_params.n_filters, | |
| kernel_size=args.conv_tasnet_params.kernel_size, | |
| stride=args.conv_tasnet_params.stride, | |
| sample_rate=args.sample_rate, | |
| ) | |
| masker = TDConvNet( | |
| in_chan=encoder.n_feats_out * args.data_params.nb_channels, # stereo | |
| n_src=args.conv_tasnet_params.n_src, # for de-limit task with the standard conv-tasnet setting. | |
| out_chan=encoder.n_feats_out, | |
| n_blocks=args.conv_tasnet_params.n_blocks, | |
| n_repeats=args.conv_tasnet_params.n_repeats, | |
| bn_chan=args.conv_tasnet_params.bn_chan, | |
| hid_chan=args.conv_tasnet_params.hid_chan, | |
| skip_chan=args.conv_tasnet_params.skip_chan, | |
| # conv_kernel_size=args.conv_tasnet_params.conv_kernel_size, | |
| norm_type=args.conv_tasnet_params.norm_type if args.conv_tasnet_params.norm_type else 'gLN', | |
| mask_act=args.conv_tasnet_params.mask_act, | |
| # causal=args.conv_tasnet_params.causal, | |
| ) | |
| model = BaseEncoderMaskerDecoderWithConfigsMultiChannelAsteroid( | |
| encoder, | |
| masker, | |
| decoder, | |
| encoder_activation=args.conv_tasnet_params.encoder_activation, | |
| use_encoder=True, | |
| apply_mask=False if args.conv_tasnet_params.synthesis else True, | |
| use_decoder=True, | |
| decoder_activation=args.conv_tasnet_params.decoder_activation, | |
| ) | |
| model.use_encoder_to_target = False | |
| return model | |