import warnings

import fire
import soundfile as sf
import torch
import torch.nn.functional as F
from torchaudio.functional import resample

from models.ecapa_tdnn import ECAPA_TDNN_SMALL

# Model names accepted by `verification`. `init_model` maps each name to a
# feature front-end; 'ecapa_tdnn' is the one handled by its fallback branch.
MODEL_LIST = [
    'ecapa_tdnn',     # fbank features, feat_dim=40
    'hubert_large',   # feat_type 'hubert_large_ll60k', feat_dim=1024
    'wav2vec2_xlsr',  # feat_type 'wav2vec2_xlsr', feat_dim=1024
    'unispeech_sat',  # feat_type 'unispeech_sat', feat_dim=1024
]
def init_model(model_name, checkpoint=None):
    """Build an ``ECAPA_TDNN_SMALL`` model for the requested feature front-end.

    Args:
        model_name: Front-end selector. ``'unispeech_sat'``, ``'hubert_large'``
            and ``'wav2vec2_xlsr'`` build a model with ``feat_dim=1024``; any
            other value (including ``'ecapa_tdnn'``) builds the ``'fbank'``
            model with ``feat_dim=40``.
        checkpoint: Optional checkpoint handed to ``torch.load``. The loaded
            object must be indexable by ``'model'``; that entry is applied as
            the state dict. ``None`` skips weight loading.

    Returns:
        The constructed model, with checkpoint weights applied when given.
    """
    # (model_name, feat_type, config_path) for the feat_dim=1024 front-ends.
    pretrained_specs = (
        ('unispeech_sat', 'unispeech_sat', 'config/unispeech_sat.th'),
        ('hubert_large', 'hubert_large_ll60k', None),
        ('wav2vec2_xlsr', 'wav2vec2_xlsr', None),
    )

    matched = None
    for name, feat_type, config_path in pretrained_specs:
        if model_name == name:
            matched = (feat_type, config_path)
            break

    if matched is None:
        # Fallback branch: filterbank features, no config_path argument.
        net = ECAPA_TDNN_SMALL(feat_dim=40, feat_type='fbank')
    else:
        feat_type, config_path = matched
        net = ECAPA_TDNN_SMALL(feat_dim=1024, feat_type=feat_type, config_path=config_path)

    if checkpoint is None:
        return net

    # NOTE(review): torch.load unpickles the file -- only pass trusted checkpoints.
    # The identity map_location leaves every storage where it was deserialized
    # instead of moving it to the device recorded in the checkpoint.
    saved = torch.load(checkpoint, map_location=lambda storage, loc: storage)
    # strict=False: the load is not required to match the model's keys exactly.
    net.load_state_dict(saved['model'], strict=False)
    return net
def _load_waveform(path, target_sr=16000):
    """Read an audio file as a mono ``(1, num_samples)`` float32 tensor at ``target_sr`` Hz."""
    samples, sample_rate = sf.read(path)
    if samples.ndim > 1:
        # soundfile returns (frames, channels) for multi-channel files. Without
        # a downmix, unsqueeze(0) would yield (1, frames, channels) and
        # resample would run along the channel axis instead of time.
        samples = samples.mean(axis=1)
    waveform = torch.from_numpy(samples).unsqueeze(0).float()
    return resample(waveform, orig_freq=sample_rate, new_freq=target_sr)


def verification(model_name, wav1, wav2, use_gpu=True, checkpoint=None):
    """Print the cosine similarity between the speaker embeddings of two audio files.

    Both files are read, downmixed to mono if they have several channels,
    resampled to 16 kHz, embedded with the model built by ``init_model`` and
    compared with ``F.cosine_similarity``. The score is printed to stdout;
    nothing is returned.

    Args:
        model_name: One of ``MODEL_LIST``.
        wav1: Path of the first audio file (anything ``soundfile.read`` accepts).
        wav2: Path of the second audio file.
        use_gpu: Run on CUDA when available. If CUDA is requested but not
            available, a warning is emitted and inference runs on CPU.
        checkpoint: Optional checkpoint forwarded to ``init_model``.

    Raises:
        AssertionError: If ``model_name`` is not in ``MODEL_LIST``.
    """
    # Kept as an assert so callers continue to see AssertionError.
    assert model_name in MODEL_LIST, 'The model_name should be in {}'.format(MODEL_LIST)
    model = init_model(model_name, checkpoint)

    wav1 = _load_waveform(wav1)
    wav2 = _load_waveform(wav2)

    if use_gpu and not torch.cuda.is_available():
        # use_gpu defaults to True; degrade gracefully on CPU-only machines
        # instead of failing inside .cuda().
        warnings.warn('CUDA is not available; running speaker verification on CPU.')
        use_gpu = False
    if use_gpu:
        model = model.cuda()
        wav1 = wav1.cuda()
        wav2 = wav2.cuda()

    model.eval()
    with torch.no_grad():
        emb1 = model(wav1)
        emb2 = model(wav2)

    sim = F.cosine_similarity(emb1, emb2)
    print("The similarity score between two audios is {:.4f} (-1.0, 1.0).".format(sim[0].item()))
if __name__ == "__main__":
    # Expose `verification` as the command-line entry point via python-fire.
    fire.Fire(verification)