# Uploaded to the Hugging Face Hub by haoxiangsnr via huggingface_hub (revision 50de2e0, verified).
# Copyright (c) 2021 Shuai Wang (wsstriving@gmail.com)
# 2022 Zhengyang Chen (chenzhengyang117@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ResNet in PyTorch.
Some modifications from the original architecture:
1. Smaller kernel size for the input layer
2. Smaller number of Channels
3. No max_pooling involved
Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
Deep Residual Learning for Image Recognition. arXiv:1512.03385
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
# import pooling_layers
class TSTP(nn.Module):
    """Temporal statistics pooling (x-vector style).

    Concatenates the per-feature mean and standard deviation taken over the
    temporal (last) axis.  Note: plain concatenation does not fully exploit
    the two statistics.
    """

    def __init__(self, in_dim=0, **kwargs):
        super(TSTP, self).__init__()
        self.in_dim = in_dim

    def forward(self, x):
        # Statistics are computed over the last (temporal) dimension; the
        # epsilon keeps the sqrt differentiable when the variance is zero.
        mean = x.mean(dim=-1).flatten(start_dim=1)
        std = (torch.var(x, dim=-1) + 1e-7).sqrt().flatten(start_dim=1)
        return torch.cat((mean, std), 1)

    def get_out_dim(self):
        # Mean and std are concatenated, doubling the feature dimension.
        self.out_dim = self.in_dim * 2
        return self.out_dim
class BasicBlock(nn.Module):
    """Two-convolution (3x3 + 3x3) residual block with identity or
    1x1-projection shortcut, as in ResNet-18/34."""

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # Project the skip path only when the spatial size or the channel
        # count changes; otherwise it is an identity mapping.
        needs_projection = stride != 1 or in_planes != self.expansion * planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        residual = self.shortcut(x)
        y = F.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        return F.relu(y + residual)
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block with 4x channel
    expansion, as in ResNet-50/101/152."""

    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        out_planes = self.expansion * planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        # 1x1 projection on the skip path when the shape changes,
        # identity otherwise.
        if stride != 1 or in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        identity = self.shortcut(x)
        h = F.relu(self.bn1(self.conv1(x)))
        h = F.relu(self.bn2(self.conv2(h)))
        h = self.bn3(self.conv3(h))
        return F.relu(h + identity)
class ResNet(nn.Module):
    """Speaker-embedding ResNet trunk.

    Modifications vs. the canonical image ResNet (He et al., 2015): a 3x3
    stride-1 stem on a single input channel, a smaller base channel count,
    and no max pooling, so feature maps are only downsampled by the strided
    residual stages (8x overall).

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``).
        num_blocks: per-stage block counts, e.g. ``[3, 4, 6, 3]``.
        m_channels: base channel width; doubled at each subsequent stage.
        feat_dim: input feature dimension (number of frequency bins).
        embed_dim: size of the output speaker embedding.
        pooling_func: name of a pooling class defined in this module
            (currently only ``"TSTP"``).
        two_emb_layer: if True, a second head (BatchNorm1d without affine
            parameters, then Linear) produces an additional embedding.
        context_window: number of pooled frames of temporal context used
            by :meth:`inference_frame`.
    """

    def __init__(
        self,
        block,
        num_blocks,
        m_channels=32,
        feat_dim=40,
        embed_dim=128,
        pooling_func="TSTP",
        two_emb_layer=True,
        context_window=5,
    ):
        super(ResNet, self).__init__()
        self.in_planes = m_channels
        self.feat_dim = feat_dim
        self.embed_dim = embed_dim
        # The frequency axis shrinks 8x across layer2-4 while the channel
        # count grows 8x, so the flattened (channels * freq) size fed to the
        # pooling layer stays int(feat_dim / 8) * m_channels * 8.
        self.stats_dim = int(feat_dim / 8) * m_channels * 8
        self.two_emb_layer = two_emb_layer

        self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(m_channels)
        self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, m_channels * 2, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, m_channels * 4, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, m_channels * 8, num_blocks[3], stride=2)

        # Resolve the pooling class by name from this module's globals
        # instead of eval(): identical lookup for plain identifiers such as
        # "TSTP", but it cannot execute arbitrary code.
        try:
            pooling_cls = globals()[pooling_func]
        except KeyError as err:
            raise ValueError(f"Unknown pooling_func: {pooling_func!r}") from err
        self.pool = pooling_cls(in_dim=self.stats_dim * block.expansion)
        self.pool_out_dim = self.pool.get_out_dim()
        self.seg_1 = nn.Linear(self.pool_out_dim, embed_dim)
        if self.two_emb_layer:
            self.seg_bn_1 = nn.BatchNorm1d(embed_dim, affine=False)
            self.seg_2 = nn.Linear(embed_dim, embed_dim)
        else:
            self.seg_bn_1 = nn.Identity()
            self.seg_2 = nn.Identity()
        # Context size (in pooled frames) for frame-level inference.
        self.context_window = context_window

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` residual blocks; only the first may stride."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Compute utterance-level embeddings.

        Args:
            x: (batch, time, feat_dim) feature tensor.

        Returns:
            ``(embed_a, embed_b)`` when ``two_emb_layer`` is True, otherwise
            ``(tensor(0.0), embed_a)``.
        """
        x = x.permute(0, 2, 1)  # (B, T, F) => (B, F, T)
        x = x.unsqueeze(1)  # add a singleton channel axis: (B, 1, F, T)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        stats = self.pool(out)
        embed_a = self.seg_1(stats)
        if self.two_emb_layer:
            out = F.relu(embed_a)
            out = self.seg_bn_1(out)
            embed_b = self.seg_2(out)
            return embed_a, embed_b
        else:
            return torch.tensor(0.0), embed_a

    def inference_frame(self, x, downsample_rate=1):
        """Compute frame-level embeddings with a sliding context window.

        The trunk runs once over the whole utterance; the resulting feature
        map is then unfolded into overlapping windows of
        ``self.context_window`` pooled frames (replicate-padded at the
        edges), and each window is pooled and embedded independently.

        Args:
            x: (batch, time, feat_dim) feature tensor.
            downsample_rate: keep every ``downsample_rate``-th frame
                embedding along the temporal axis.

        Returns:
            ``(embed_a, embed_b)`` when ``two_emb_layer`` is True, otherwise
            ``(tensor(0.0), embed_a)``; embeddings carry a trailing frame
            axis.
        """
        batch_size = x.shape[0]
        x = rearrange(x, "b t f -> b () f t")
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # F.pad with "replicate" needs a 3-D input here, so fold channels
        # into the batch axis, pad the temporal axis, then unfold back.
        out = rearrange(out, "b c f t -> (b c) f t")
        out = F.pad(out, (self.context_window // 2, self.context_window // 2), "replicate")
        out = rearrange(out, "(b c) f t -> b c f t", b=batch_size)
        # Sliding windows over time: [b, c, f, n_frames, context_window].
        unfold_out = out.unfold(-1, self.context_window, 1)
        # Treat every window as an independent pooling example.
        out = rearrange(unfold_out, "b c f n w -> (b n) c f w")
        stats = self.pool(out)
        embed_a = self.seg_1(stats)
        embed_a = rearrange(embed_a, "(b n) c -> b c n", b=batch_size)
        embed_a = embed_a[..., ::downsample_rate]
        if self.two_emb_layer:
            out = F.relu(embed_a)
            out = self.seg_bn_1(out)
            embed_b = self.seg_2(out)
            return embed_a, embed_b
        else:
            return torch.tensor(0.0), embed_a
def ResNet18(feat_dim, embed_dim, pooling_func="TSTP", two_emb_layer=True, context_window=5):
    """Build a ResNet-18 trunk: BasicBlock with stage depths [2, 2, 2, 2]."""
    options = dict(
        feat_dim=feat_dim,
        embed_dim=embed_dim,
        pooling_func=pooling_func,
        two_emb_layer=two_emb_layer,
        context_window=context_window,
    )
    return ResNet(BasicBlock, [2, 2, 2, 2], **options)
def ResNet34(feat_dim, embed_dim, pooling_func="TSTP", two_emb_layer=True, context_window=5):
    """Build a ResNet-34 trunk: BasicBlock with stage depths [3, 4, 6, 3]."""
    options = dict(
        feat_dim=feat_dim,
        embed_dim=embed_dim,
        pooling_func=pooling_func,
        two_emb_layer=two_emb_layer,
        context_window=context_window,
    )
    return ResNet(BasicBlock, [3, 4, 6, 3], **options)
def ResNet50(feat_dim, embed_dim, pooling_func="TSTP", two_emb_layer=True, context_window=5):
    """Build a ResNet-50 trunk: Bottleneck with stage depths [3, 4, 6, 3]."""
    options = dict(
        feat_dim=feat_dim,
        embed_dim=embed_dim,
        pooling_func=pooling_func,
        two_emb_layer=two_emb_layer,
        context_window=context_window,
    )
    return ResNet(Bottleneck, [3, 4, 6, 3], **options)
def ResNet101(feat_dim, embed_dim, pooling_func="TSTP", two_emb_layer=True, context_window=5):
    """Build a ResNet-101 trunk: Bottleneck with stage depths [3, 4, 23, 3]."""
    options = dict(
        feat_dim=feat_dim,
        embed_dim=embed_dim,
        pooling_func=pooling_func,
        two_emb_layer=two_emb_layer,
        context_window=context_window,
    )
    return ResNet(Bottleneck, [3, 4, 23, 3], **options)
def ResNet152(feat_dim, embed_dim, pooling_func="TSTP", two_emb_layer=True, context_window=5):
    """Build a ResNet-152 trunk: Bottleneck with stage depths [3, 8, 36, 3]."""
    options = dict(
        feat_dim=feat_dim,
        embed_dim=embed_dim,
        pooling_func=pooling_func,
        two_emb_layer=two_emb_layer,
        context_window=context_window,
    )
    return ResNet(Bottleneck, [3, 8, 36, 3], **options)
def ResNet221(feat_dim, embed_dim, pooling_func="TSTP", two_emb_layer=True, context_window=5):
    """Build a ResNet-221 trunk: Bottleneck with stage depths [6, 16, 48, 3]."""
    options = dict(
        feat_dim=feat_dim,
        embed_dim=embed_dim,
        pooling_func=pooling_func,
        two_emb_layer=two_emb_layer,
        context_window=context_window,
    )
    return ResNet(Bottleneck, [6, 16, 48, 3], **options)
def ResNet293(feat_dim, embed_dim, pooling_func="TSTP", two_emb_layer=True, context_window=5):
    """Build a ResNet-293 trunk: Bottleneck with stage depths [10, 20, 64, 3]."""
    options = dict(
        feat_dim=feat_dim,
        embed_dim=embed_dim,
        pooling_func=pooling_func,
        two_emb_layer=two_emb_layer,
        context_window=context_window,
    )
    return ResNet(Bottleneck, [10, 20, 64, 3], **options)
def compute_fbank(waveform):
    """Extract 80-dim Kaldi-style fbank features, mean-normalized per utterance.

    Args:
        waveform: (num_utts, num_samples) float tensor; assumed to be
            16 kHz audio scaled to [-1, 1] — TODO confirm with callers.

    Returns:
        (num_utts, num_frames, 80) feature tensor, with the per-utterance
        temporal mean subtracted from every frame.
    """
    # Local import: in the original file torchaudio is only imported inside
    # the __main__ guard, so referencing `kaldi` here raised NameError when
    # this module was imported rather than executed as a script.
    import torchaudio.compliance.kaldi as kaldi

    # Kaldi's front end expects 16-bit integer sample range.
    waveform = (waveform * (1 << 15)).round()
    feats = []
    for i in range(waveform.shape[0]):
        feat = kaldi.fbank(
            waveform[i].unsqueeze(0),
            num_mel_bins=80,
            frame_length=25,
            frame_shift=10,
            dither=0.0,
            sample_frequency=16000,
            window_type="hamming",
            use_energy=False,
        )
        # Cepstral mean normalization over the time axis.
        feats.append(feat - torch.mean(feat, dim=0))
    return torch.stack(feats)
if __name__ == "__main__":
    # Ad-hoc developer sanity-check script: loads a pretrained VoxCeleb
    # ResNet293 checkpoint, extracts fbank features from one local wav
    # file, and exercises both utterance-level and frame-level inference.
    # NOTE(review): paths are hard-coded to a specific cluster; this block
    # is only runnable on that machine.
    import torchaudio
    import torchaudio.compliance.kaldi as kaldi
    from tqdm import tqdm
    model_path = "/apdcephfs/share_1149801/speech_user/tomasyu/experiment/frontend/Enhancement-Paas/egs2/PDNS/exp/modelzoo/svmodel/voxceleb_resnet293_LM/voxceleb_resnet293_LM.pt"
    wav1, sr = torchaudio.load("/apdcephfs/InstructSpeech/packup/LJ037-0171.wav")
    # Resample to 16 kHz if needed; the fbank extractor assumes 16 kHz.
    if sr != 16000:
        wav1 = torchaudio.transforms.Resample(sr, 16000)(wav1)
    # import pdb; pdb.set_trace()
    # Keep only the first 16000 samples (one second at 16 kHz).
    wav1 = wav1[..., :16000]
    x = compute_fbank(wav1)
    state_dict = torch.load(model_path, map_location="cpu")
    model = ResNet293(feat_dim=80, embed_dim=256, pooling_func="TSTP", two_emb_layer=False, context_window=10)
    # import pdb; pdb.set_trace()
    # strict=False tolerates key mismatches between checkpoint and model.
    model.load_state_dict(state_dict, strict=False)
    # # x = torch.zeros(10, 200, 80)
    # model = ResNet34(feat_dim=80,
    # embed_dim=256,
    # pooling_func='TSTP',
    # two_emb_layer=False)
    model.eval()
    out = model(x)
    out_frame = model.inference_frame(x, downsample_rate=3)
    # NOTE(review): live breakpoint left in — execution halts here whenever
    # this script is run; presumably for comparing `out` vs `out_frame`.
    import pdb
    pdb.set_trace()
    # print((out[1] - out_frame[1][...,5]).sum())
    # import pdb; pdb.set_trace()
    # print((out[2][...,3:8] - out_frame[2][5]).sum())
    # Rough throughput check on random 600-frame inputs.
    with torch.no_grad():
        for i in tqdm(range(100)):
            model(torch.randn(1, 600, 80))
    print(out[-1].size())
    # Parameter count and FLOPs report (thop is a third-party profiler).
    num_params = sum(p.numel() for p in model.parameters())
    print("{} M".format(num_params / 1e6))
    from thop import profile
    x_np = torch.randn(1, 100, 80)
    flops, params = profile(model, inputs=(x_np,))
    print("FLOPS: {} G, Params: {} M".format(flops / 1e9, params / 1e6))
    # (model.seg_1(model.pool(out[2][...,3:8])) - model.seg_1(model.pool(out_frame[2][5]))).sum()
    # model.seg_1(model.pool(out_frame[2].reshape(95*out_frame[2].shape[1], out_frame[2].shape[2], out_frame[2].shape[3], out_frame[2].shape[4])))
    # model.seg_1(model.pool(out_frame[2].view(95*out_frame[2].shape[1], out_frame[2].shape[2], out_frame[2].shape[3], out_frame[2].shape[4])[0:6]))[-1] - model.pool(out[2][...,3:8])