Spaces:
Running
on
Zero
Running
on
Zero
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
# Modified from 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker)
| import time | |
| import torch | |
| import numpy as np | |
| from collections import OrderedDict | |
| from contextlib import contextmanager | |
| from distutils.version import LooseVersion | |
| from funasr_detach.register import tables | |
| from funasr_detach.models.campplus.utils import extract_feature | |
| from funasr_detach.utils.load_utils import load_audio_text_image_video | |
| from funasr_detach.models.campplus.components import ( | |
| DenseLayer, | |
| StatsPool, | |
| TDNNLayer, | |
| CAMDenseTDNNBlock, | |
| TransitLayer, | |
| get_nonlinear, | |
| FCM, | |
| ) | |
# Resolve `autocast` by probing for the feature instead of comparing version
# strings: `torch.cuda.amp` only exists in torch>=1.6.0, so an ImportError is
# equivalent to the old `LooseVersion(torch.__version__) < "1.6.0"` branch and
# does not depend on `distutils` (removed from the stdlib in Python 3.12).
try:
    from torch.cuda.amp import autocast
except ImportError:
    # Nothing to do if torch<1.6.0: expose a no-op stand-in with the same call
    # shape. The `@contextmanager` decorator is required; without it the
    # function returns a plain generator and `with autocast():` fails because
    # generators do not implement `__enter__` / `__exit__`.
    @contextmanager
    def autocast(enabled=True):
        """No-op replacement for ``torch.cuda.amp.autocast`` on old torch."""
        yield
# NOTE(review): `tables` is imported at file level but no registration
# decorator is visible on this class -- confirm whether one was lost.
class CAMPPlus(torch.nn.Module):
    """CAMPPlus speaker-embedding model.

    The network is an ``FCM`` front-end (``self.head``) followed by a
    sequential ``self.xvector`` stack: one ``TDNNLayer``, three
    ``CAMDenseTDNNBlock`` stages (12, 24 and 16 layers) each followed by a
    channel-halving ``TransitLayer``, and a final non-linearity. When
    ``output_level == "segment"`` a ``StatsPool`` and a ``DenseLayer``
    projecting to ``embedding_size`` are appended.

    Args:
        feat_dim: Feature dimension handed to ``FCM``.
        embedding_size: Output size of the final ``DenseLayer``
            (only used when ``output_level == "segment"``).
        growth_rate: Channels added by every layer of a dense block.
        bn_size: Multiplier applied to ``growth_rate`` to obtain the
            bottleneck width (``bn_channels``) of the dense blocks.
        init_channels: Output channels of the first ``TDNNLayer``.
        config_str: Normalisation / activation spec forwarded to the layers.
        memory_efficient: Forwarded to ``CAMDenseTDNNBlock``.
        output_level: ``"segment"`` (pooled embedding) or ``"frame"``.
        **kwargs: Accepted for call-site compatibility; not used here.

    Raises:
        ValueError: If ``output_level`` is neither ``"segment"`` nor
            ``"frame"``.
    """

    def __init__(
        self,
        feat_dim=80,
        embedding_size=192,
        growth_rate=32,
        bn_size=4,
        init_channels=128,
        config_str="batchnorm-relu",
        memory_efficient=True,
        output_level="segment",
        **kwargs,
    ):
        super().__init__()

        # Validate before building anything: an `assert` would be stripped
        # under `python -O` and silently yield a half-configured model.
        if output_level not in ("segment", "frame"):
            raise ValueError(
                "`output_level` should be set to 'segment' or 'frame'. "
            )

        self.head = FCM(feat_dim=feat_dim)
        channels = self.head.out_channels
        self.output_level = output_level

        self.xvector = torch.nn.Sequential(
            OrderedDict(
                [
                    (
                        "tdnn",
                        TDNNLayer(
                            channels,
                            init_channels,
                            5,
                            stride=2,
                            dilation=1,
                            padding=-1,
                            config_str=config_str,
                        ),
                    ),
                ]
            )
        )
        channels = init_channels

        # Three dense stages; each one grows the channel count by
        # `num_layers * growth_rate`, then a transit layer halves it.
        for i, (num_layers, kernel_size, dilation) in enumerate(
            zip((12, 24, 16), (3, 3, 3), (1, 2, 2))
        ):
            block = CAMDenseTDNNBlock(
                num_layers=num_layers,
                in_channels=channels,
                out_channels=growth_rate,
                bn_channels=bn_size * growth_rate,
                kernel_size=kernel_size,
                dilation=dilation,
                config_str=config_str,
                memory_efficient=memory_efficient,
            )
            self.xvector.add_module("block%d" % (i + 1), block)
            channels = channels + num_layers * growth_rate
            self.xvector.add_module(
                "transit%d" % (i + 1),
                TransitLayer(
                    channels, channels // 2, bias=False, config_str=config_str
                ),
            )
            channels //= 2

        self.xvector.add_module("out_nonlinear", get_nonlinear(config_str, channels))

        if self.output_level == "segment":
            self.xvector.add_module("stats", StatsPool())
            # The dense layer takes `channels * 2` inputs -- presumably
            # because StatsPool concatenates two statistics; TODO confirm.
            self.xvector.add_module(
                "dense",
                DenseLayer(channels * 2, embedding_size, config_str="batchnorm_"),
            )

        # Kaiming-normal weights and zero biases for every Conv1d / Linear.
        for m in self.modules():
            if isinstance(m, (torch.nn.Conv1d, torch.nn.Linear)):
                torch.nn.init.kaiming_normal_(m.weight.data)
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)

    def forward(self, x):
        """Run the head and the x-vector stack on a batch of features.

        Args:
            x: Input tensor laid out as (B, T, F).

        Returns:
            The output of ``self.xvector``; when ``output_level`` is
            ``"frame"`` dimensions 1 and 2 are swapped before returning.
        """
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
        x = self.head(x)
        x = self.xvector(x)
        if self.output_level == "frame":
            x = x.transpose(1, 2)
        return x

    def inference(
        self,
        data_in,
        data_lengths=None,
        key: list = None,
        tokenizer=None,
        frontend=None,
        **kwargs,
    ):
        """Load audio, extract features and compute speaker embeddings.

        Args:
            data_in: Input handed to ``load_audio_text_image_video``.
            data_lengths: Unused by this method.
            key: Unused by this method.
            tokenizer: Unused by this method.
            frontend: Unused by this method.
            **kwargs: ``fs`` (source sample rate, default 16000) and
                ``device`` (target device for the features). When ``device``
                is missing or ``None`` the device of the model's first
                parameter is used, falling back to CPU.

        Returns:
            A tuple ``(results, meta_data)`` where ``results`` is a
            one-element list ``[{"spk_embedding": <forward output>}]`` and
            ``meta_data`` holds the ``"load_data"`` and ``"extract_feat"``
            timings (formatted strings) plus ``"batch_data_time"``.
        """
        # extract fbank feats
        meta_data = {}
        time1 = time.perf_counter()
        audio_sample_list = load_audio_text_image_video(
            data_in, fs=16000, audio_fs=kwargs.get("fs", 16000), data_type="sound"
        )
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"

        speech, _, speech_times = extract_feature(audio_sample_list)

        # Do not crash with KeyError when the caller omits `device`; follow
        # the device the model's weights already live on instead.
        device = kwargs.get("device")
        if device is None:
            first_param = next(self.parameters(), None)
            device = (
                first_param.device if first_param is not None else torch.device("cpu")
            )
        speech = speech.to(device=device, dtype=torch.float32)

        time3 = time.perf_counter()
        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
        # Sum of `speech_times` divided by 16000 -- presumably seconds of
        # audio at the 16 kHz rate requested above; TODO confirm units.
        meta_data["batch_data_time"] = np.array(speech_times).sum().item() / 16000.0

        results = [{"spk_embedding": self.forward(speech)}]
        return results, meta_data