# WatchMeSpeak / Main.py
# Last change: "Update Main.py" (commit a4681ba, verified), author: Siddarth.
import torch
# Keep autograd anomaly detection explicitly disabled: it is a debugging aid
# that adds extra checks to every backward pass and slows execution down.
# (The former `torch.autograd.profiler.profile(False)` and
# `torch.autograd.profiler.emit_nvtx(False)` calls only constructed
# context-manager objects that were never entered, so they had no effect and
# were removed.)
torch.autograd.set_detect_anomaly(False)
import warnings
# import pickle
# from Transformer import Transformer
import librosa
# import os.path as path
import json
# from tqdm import tqdm
# import time
import numpy as np
import os
# NOTE(review): this silences *every* warning for the whole process as a side
# effect of importing this module (including deprecation warnings from
# torch/librosa); consider narrowing it to specific categories/modules.
warnings.filterwarnings("ignore")

# Directory that contains this script; the model checkpoint is stored next to it.
dir_path = os.path.dirname(os.path.realpath(__file__))

# Look for config.json in the current working directory first (the original
# behaviour) and fall back to the script's own directory, so the module also
# works when it is launched from somewhere else.
_config_path = 'config.json'
if not os.path.isfile(_config_path):
    _config_path = os.path.join(dir_path, 'config.json')
with open(_config_path, encoding='utf-8') as json_data_file:
    data = json.load(json_data_file)

# Run settings read from the config. Every key below is required: a missing
# key raises KeyError at import time.
lr = data['learn_rate']
epochs = data['epochs']
batch_size = data['batch_size']
Training = data['Training']
Testing = data['Testing']
main_path = data['MainPath']
device = data['Device']
diag_attn = data['DiagAttn']

# Path of the trained checkpoint loaded by wav2art.
BestModelPath = os.path.join(dir_path, 'Best_GlobalModel_500_0_0.pt')
def pre_process_mfcc(mfcc):
    """Transpose an MFCC matrix and standardise each coefficient over time.

    Args:
        mfcc: 2-D array laid out as (n_coefficients, n_frames), e.g. the
            output of ``librosa.feature.mfcc`` as used by ``wav2art``.

    Returns:
        Array of shape (n_frames, n_coefficients). Each column is centred on
        its mean over the clip and divided by its standard deviation, then
        scaled by 0.5. Columns that are constant over the clip (zero standard
        deviation) come out as all zeros instead of NaN/inf.
    """
    frames = np.asarray(mfcc).T
    mean_g = np.mean(frames, axis=0)
    std_g = np.std(frames, axis=0)
    # A coefficient that never changes (e.g. digital silence) has std == 0;
    # dividing by it would produce NaN/inf that propagate through the model.
    # Dividing by 1 instead leaves such a centred column at exactly 0.
    std_g[std_g == 0] = 1
    # NOTE(review): the 0.5 factor presumably mirrors the normalisation used
    # when the model was trained -- confirm before changing it.
    return 0.5 * (frames - mean_g) / std_g
# Loaded checkpoints keyed by file path, so the model is unpickled once per
# process instead of on every wav2art() call.
_MODEL_CACHE = {}


def _load_model(path):
    """Load the pickled model at *path* onto the CPU in eval mode (memoised per path)."""
    model = _MODEL_CACHE.get(path)
    if model is None:
        cpu = torch.device('cpu')
        # The checkpoint is a whole pickled model object (it is called and
        # has .float()/.eval() invoked on it), not a state_dict. PyTorch >= 2.6
        # defaults to weights_only=True, which rejects such files, so request
        # full unpickling explicitly.
        # SECURITY: full unpickling can execute arbitrary code -- only ever
        # point this at a trusted, locally bundled checkpoint.
        try:
            model = torch.load(path, map_location=cpu, weights_only=False)
        except TypeError:
            # PyTorch < 1.13 has no weights_only keyword (and always unpickles fully).
            model = torch.load(path, map_location=cpu)
        model = model.float()
        model.eval()
        _MODEL_CACHE[path] = model
    return model


def wav2art(wav):
    """Run the trained model on a waveform and return its prediction.

    Args:
        wav: audio samples accepted by ``librosa.feature.mfcc``. The MFCC
            settings assume a 16 kHz sampling rate -- presumably mono float
            audio; confirm with callers.

    Returns:
        NumPy array holding the first (and only) batch item of the model
        output. Judging by the name this is presumably an articulatory
        trajectory per frame -- TODO confirm the shape with the model code.
    """
    rate = 16000
    # 13 coefficients, 10 ms hop and 20 ms analysis window at `rate` Hz.
    mfcc = librosa.feature.mfcc(y=wav, sr=rate, n_mfcc=13,
                                hop_length=int(0.010 * rate),
                                n_fft=int(0.020 * rate))
    mfcc = pre_process_mfcc(mfcc)
    # Add a leading batch dimension of 1. This matches the original
    # torch.tensor([mfcc]) but avoids the slow list-of-ndarrays conversion.
    features = torch.as_tensor(mfcc, dtype=torch.float32).unsqueeze(0)

    model = _load_model(BestModelPath)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        # NOTE(review): the meaning of the three trailing 0 arguments is
        # defined by the model's forward() signature, which is not visible here.
        p = model(features, 0, 0, 0)
    return p[0].detach().numpy()