| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """ResNet in PyTorch. |
| |
| Some modifications from the original architecture: |
| 1. Smaller kernel size for the input layer |
| 2. Smaller number of Channels |
| 3. No max_pooling involved |
| |
| Reference: |
| [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun |
| Deep Residual Learning for Image Recognition. arXiv:1512.03385 |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from einops import rearrange |
|
|
|
|
| |
|
|
|
|
class TSTP(nn.Module):
    """Temporal statistics pooling, as used in x-vector systems.

    Concatenates the mean and the standard deviation computed over the
    last (time) axis. Note: a plain concatenation cannot make full use
    of the interaction between the two statistics.
    """

    def __init__(self, in_dim=0, **kwargs):
        super(TSTP, self).__init__()
        # Flattened feature size of the input; output is twice this.
        self.in_dim = in_dim

    def forward(self, x):
        """Pool over the trailing time axis; returns (batch, 2 * in_dim)."""
        mean = x.mean(dim=-1)
        # Small epsilon keeps the gradient of sqrt finite for zero variance.
        std = (x.var(dim=-1) + 1e-7).sqrt()
        return torch.cat((mean.flatten(1), std.flatten(1)), dim=1)

    def get_out_dim(self):
        """Return (and cache) the pooled output dimensionality."""
        self.out_dim = self.in_dim * 2
        return self.out_dim
|
|
|
|
class BasicBlock(nn.Module):
    """Two-convolution residual block (ResNet v1 style)."""

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        # Main path: 3x3 conv (optionally strided) -> BN -> 3x3 conv -> BN.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Identity shortcut unless the spatial size or channel count changes,
        # in which case a 1x1 projection matches the shapes.
        if stride == 1 and in_planes == self.expansion * planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes),
            )

    def forward(self, x):
        residual = self.shortcut(x)
        y = F.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        return F.relu(y + residual)
|
|
|
|
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block (ResNet v1 style)."""

    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        out_planes = self.expansion * planes
        # Reduce channels, apply the (optionally strided) spatial conv,
        # then expand back by `expansion`.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # Identity shortcut unless shapes change; otherwise 1x1 projection.
        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        residual = self.shortcut(x)
        y = F.relu(self.bn1(self.conv1(x)))
        y = F.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))
        return F.relu(y + residual)
|
|
|
|
class ResNet(nn.Module):
    """Speaker-embedding ResNet over (batch, time, freq) features.

    Compared with the image-classification original: small input kernel,
    fewer channels, no max-pooling stem (see module docstring).

    Args:
        block: residual block class, e.g. ``BasicBlock`` or ``Bottleneck``.
        num_blocks: sequence of four ints, blocks per stage.
        m_channels: base channel count; the four stages use 1x/2x/4x/8x.
        feat_dim: input feature (frequency) dimension. Assumed divisible
            by 8 so the three stride-2 stages leave an integral size —
            TODO confirm with callers (80 is used below, which is fine).
        embed_dim: output speaker-embedding dimension.
        pooling_func: name of a pooling class defined in this module
            (e.g. ``"TSTP"``).
        two_emb_layer: if True, add BN + ReLU + a second linear layer and
            return both embeddings.
        context_window: sliding-window width (in post-stem frames) used by
            :meth:`inference_frame`.
    """

    def __init__(
        self,
        block,
        num_blocks,
        m_channels=32,
        feat_dim=40,
        embed_dim=128,
        pooling_func="TSTP",
        two_emb_layer=True,
        context_window=5,
    ):
        super(ResNet, self).__init__()
        self.in_planes = m_channels
        self.feat_dim = feat_dim
        self.embed_dim = embed_dim
        # The frequency axis shrinks 8x across layers 2-4 while channels
        # grow 8x, so the flattened (channels * freq) size stays constant.
        self.stats_dim = int(feat_dim / 8) * m_channels * 8
        self.two_emb_layer = two_emb_layer

        self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(m_channels)
        self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, m_channels * 2, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, m_channels * 4, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, m_channels * 8, num_blocks[3], stride=2)

        # Resolve the pooling class by name from this module's namespace.
        # This replaces eval(pooling_func): identical result for class
        # names such as "TSTP", but without arbitrary-code execution, and
        # an unknown name fails with a clear KeyError.
        self.pool = globals()[pooling_func](in_dim=self.stats_dim * block.expansion)
        self.pool_out_dim = self.pool.get_out_dim()
        self.seg_1 = nn.Linear(self.pool_out_dim, embed_dim)
        if self.two_emb_layer:
            self.seg_bn_1 = nn.BatchNorm1d(embed_dim, affine=False)
            self.seg_2 = nn.Linear(embed_dim, embed_dim)
        else:
            self.seg_bn_1 = nn.Identity()
            self.seg_2 = nn.Identity()

        self.context_window = context_window

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first may downsample."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Utterance-level forward pass.

        Args:
            x: features of shape (batch, time, freq).

        Returns:
            ``(embed_a, embed_b)`` if ``two_emb_layer`` is True,
            otherwise ``(tensor(0.0), embed_a)``.
        """
        # (batch, time, freq) -> (batch, 1, freq, time) for Conv2d.
        x = x.permute(0, 2, 1)
        # Out-of-place unsqueeze: the previous in-place unsqueeze_() was a
        # needless mutation of the permuted view.
        x = x.unsqueeze(1)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)

        # Pool over time: (batch, c, f, t) -> (batch, c * f * 2) for TSTP.
        stats = self.pool(out)

        embed_a = self.seg_1(stats)
        if self.two_emb_layer:
            out = F.relu(embed_a)
            out = self.seg_bn_1(out)
            embed_b = self.seg_2(out)
            return embed_a, embed_b
        else:
            return torch.tensor(0.0), embed_a

    def inference_frame(self, x, downsample_rate=1):
        """Frame-level embeddings via a sliding window over time.

        Args:
            x: features of shape (batch, time, freq).
            downsample_rate: keep every ``downsample_rate``-th frame
                embedding on the output time axis.

        Returns:
            Same tuple convention as :meth:`forward`, but embeddings carry
            a trailing frame axis: (batch, embed_dim, n_frames).
        """
        batch_size = x.shape[0]
        x = rearrange(x, "b t f -> b () f t")

        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)

        # Replicate-pad the time axis so every frame has a full context
        # window. F.pad's "replicate" mode needs a 3D tensor here, hence
        # the temporary (b c) merge and split.
        out = rearrange(out, "b c f t -> (b c) f t")
        out = F.pad(out, (self.context_window // 2, self.context_window // 2), "replicate")
        out = rearrange(out, "(b c) f t -> b c f t", b=batch_size)

        # Slide a width-`context_window` window with hop 1 over time, then
        # fold the window positions into the batch axis so the pooling and
        # embedding layers process all frames in one pass.
        unfold_out = out.unfold(-1, self.context_window, 1)
        out = rearrange(unfold_out, "b c f n w -> (b n) c f w")

        stats = self.pool(out)

        embed_a = self.seg_1(stats)
        embed_a = rearrange(embed_a, "(b n) c -> b c n", b=batch_size)
        embed_a = embed_a[..., ::downsample_rate]

        if self.two_emb_layer:
            out = F.relu(embed_a)
            out = self.seg_bn_1(out)
            embed_b = self.seg_2(out)
            return embed_a, embed_b
        else:
            return torch.tensor(0.0), embed_a
|
|
|
|
def ResNet18(feat_dim, embed_dim, pooling_func="TSTP", two_emb_layer=True, context_window=5):
    """Build an 18-layer ResNet (BasicBlock, stages [2, 2, 2, 2])."""
    options = dict(
        feat_dim=feat_dim,
        embed_dim=embed_dim,
        pooling_func=pooling_func,
        two_emb_layer=two_emb_layer,
        context_window=context_window,
    )
    return ResNet(BasicBlock, [2, 2, 2, 2], **options)
|
|
|
|
def ResNet34(feat_dim, embed_dim, pooling_func="TSTP", two_emb_layer=True, context_window=5):
    """Build a 34-layer ResNet (BasicBlock, stages [3, 4, 6, 3])."""
    options = dict(
        feat_dim=feat_dim,
        embed_dim=embed_dim,
        pooling_func=pooling_func,
        two_emb_layer=two_emb_layer,
        context_window=context_window,
    )
    return ResNet(BasicBlock, [3, 4, 6, 3], **options)
|
|
|
|
def ResNet50(feat_dim, embed_dim, pooling_func="TSTP", two_emb_layer=True, context_window=5):
    """Build a 50-layer ResNet (Bottleneck, stages [3, 4, 6, 3])."""
    options = dict(
        feat_dim=feat_dim,
        embed_dim=embed_dim,
        pooling_func=pooling_func,
        two_emb_layer=two_emb_layer,
        context_window=context_window,
    )
    return ResNet(Bottleneck, [3, 4, 6, 3], **options)
|
|
|
|
def ResNet101(feat_dim, embed_dim, pooling_func="TSTP", two_emb_layer=True, context_window=5):
    """Build a 101-layer ResNet (Bottleneck, stages [3, 4, 23, 3])."""
    options = dict(
        feat_dim=feat_dim,
        embed_dim=embed_dim,
        pooling_func=pooling_func,
        two_emb_layer=two_emb_layer,
        context_window=context_window,
    )
    return ResNet(Bottleneck, [3, 4, 23, 3], **options)
|
|
|
|
def ResNet152(feat_dim, embed_dim, pooling_func="TSTP", two_emb_layer=True, context_window=5):
    """Build a 152-layer ResNet (Bottleneck, stages [3, 8, 36, 3])."""
    options = dict(
        feat_dim=feat_dim,
        embed_dim=embed_dim,
        pooling_func=pooling_func,
        two_emb_layer=two_emb_layer,
        context_window=context_window,
    )
    return ResNet(Bottleneck, [3, 8, 36, 3], **options)
|
|
|
|
def ResNet221(feat_dim, embed_dim, pooling_func="TSTP", two_emb_layer=True, context_window=5):
    """Build a 221-layer ResNet (Bottleneck, stages [6, 16, 48, 3])."""
    options = dict(
        feat_dim=feat_dim,
        embed_dim=embed_dim,
        pooling_func=pooling_func,
        two_emb_layer=two_emb_layer,
        context_window=context_window,
    )
    return ResNet(Bottleneck, [6, 16, 48, 3], **options)
|
|
|
|
def ResNet293(feat_dim, embed_dim, pooling_func="TSTP", two_emb_layer=True, context_window=5):
    """Build a 293-layer ResNet (Bottleneck, stages [10, 20, 64, 3])."""
    options = dict(
        feat_dim=feat_dim,
        embed_dim=embed_dim,
        pooling_func=pooling_func,
        two_emb_layer=two_emb_layer,
        context_window=context_window,
    )
    return ResNet(Bottleneck, [10, 20, 64, 3], **options)
|
|
|
|
def compute_fbank(waveform):
    """Compute 80-dim Kaldi-style fbank features with per-utterance CMN.

    Args:
        waveform: float tensor of shape (channels, samples), values assumed
            in [-1, 1] at 16 kHz — TODO confirm with callers.

    Returns:
        Tensor of shape (channels, frames, 80), mean-normalized over time.
    """
    # Previously `kaldi` was read from module globals, which is only bound
    # inside the __main__ guard below — importing this module and calling
    # compute_fbank() raised NameError. Import locally so the function is
    # self-contained.
    import torchaudio.compliance.kaldi as kaldi

    # Scale float audio to the 16-bit PCM range expected by Kaldi fbank.
    waveform = (waveform * (1 << 15)).round()
    mat = []
    for i in range(waveform.shape[0]):
        feat = kaldi.fbank(
            waveform[i].unsqueeze(0),
            num_mel_bins=80,
            frame_length=25,
            frame_shift=10,
            dither=0.0,
            sample_frequency=16000,
            window_type="hamming",
            use_energy=False,
        )
        # Per-utterance cepstral mean normalization over the time axis.
        feat = feat - torch.mean(feat, dim=0)
        mat.append(feat)
    mat = torch.stack(mat)
    return mat
|
|
|
|
if __name__ == "__main__":
    # Smoke test: load a pretrained checkpoint, extract utterance- and
    # frame-level embeddings, then report timing, parameter count and FLOPs.
    import torchaudio
    import torchaudio.compliance.kaldi as kaldi  # used by compute_fbank
    from tqdm import tqdm

    model_path = "/apdcephfs/share_1149801/speech_user/tomasyu/experiment/frontend/Enhancement-Paas/egs2/PDNS/exp/modelzoo/svmodel/voxceleb_resnet293_LM/voxceleb_resnet293_LM.pt"

    wav1, sr = torchaudio.load("/apdcephfs/InstructSpeech/packup/LJ037-0171.wav")
    if sr != 16000:
        wav1 = torchaudio.transforms.Resample(sr, 16000)(wav1)
    # Keep only the first second (16k samples) for a quick check.
    wav1 = wav1[..., :16000]
    x = compute_fbank(wav1)

    state_dict = torch.load(model_path, map_location="cpu")
    model = ResNet293(feat_dim=80, embed_dim=256, pooling_func="TSTP", two_emb_layer=False, context_window=10)
    # strict=False: tolerate key mismatches between checkpoint and this
    # model variant — NOTE(review): silently ignores missing weights.
    model.load_state_dict(state_dict, strict=False)

    model.eval()
    # Removed a leftover `import pdb; pdb.set_trace()` that halted the
    # script here; inference now also runs under no_grad to avoid building
    # autograd graphs.
    with torch.no_grad():
        out = model(x)
        out_frame = model.inference_frame(x, downsample_rate=3)

        # Rough throughput check on random 6-second inputs.
        for i in tqdm(range(100)):
            model(torch.randn(1, 600, 80))

    print(out[-1].size())

    num_params = sum(p.numel() for p in model.parameters())
    print("{} M".format(num_params / 1e6))

    from thop import profile

    x_np = torch.randn(1, 100, 80)
    flops, params = profile(model, inputs=(x_np,))
    print("FLOPS: {} G, Params: {} M".format(flops / 1e9, params / 1e6))
|
|
| |
|
|
|
|
| |
|
|
| |
|
|