| import torch | |
| import torch.nn as nn | |
| from huggingface_hub import PyTorchModelHubMixin | |
| from huggingface_hub import ModelCard | |
| from tasnet import ConvTasNetStereo | |
class DynamicSourceSeparator(torch.nn.Module, PyTorchModelHubMixin):
    """Dispatch a mixture to per-instrument separation models on demand.

    Holds one pre-trained model per instrument name. At call time, only the
    models whose instrument is flagged active in ``indicator`` are run. Every
    inactive instrument receives an all-zero tensor shaped like the input
    mixture.
    """

    def __init__(self, pre_trained_models):
        """Register the per-instrument models.

        Args:
            pre_trained_models: mapping (or iterable of ``(name, module)``
                pairs) from instrument name to ``nn.Module``. It is wrapped in
                an ``nn.ModuleDict`` so the sub-models are registered as
                children of this module.
        """
        super().__init__()
        # NOTE(review): PyTorchModelHubMixin builds its saved config from the
        # __init__ arguments. A dict of nn.Modules is presumably not
        # JSON-serializable, so confirm that save/from_pretrained actually
        # round-trips this class.
        self.models = nn.ModuleDict(pre_trained_models)

    def forward(self, mixture, indicator):
        """Separate ``mixture`` for every instrument listed in ``indicator``.

        Args:
            mixture: input tensor. It is passed unchanged to each active
                model and is also the template for the zero outputs.
            indicator: mapping from instrument name to a truthy/falsy flag.
                Keys are visited in iteration order.

        Returns:
            dict mapping each key of ``indicator`` to a tensor:
            ``model(mixture)[:, 0, :, :]`` for active instruments, and
            ``torch.zeros_like(mixture)`` for inactive ones.

        Raises:
            KeyError: if an active instrument has no entry in ``self.models``.
        """
        # The model output is indexed with four subscripts, so it must be at
        # least 4-D. Presumably its layout is (batch, n_sources, channels, time)
        # with the target source at index 0 (TODO confirm against the model).
        # The zero tensor only has the same shape as that slice if the model
        # preserves the mixture's dimensions.
        return {
            name: (
                self.models[name](mixture)[:, 0, :, :]
                if enabled
                else torch.zeros_like(mixture)
            )
            for name, enabled in indicator.items()
        }