# WatchMeSpeak / Main.py
# Hosting-page metadata: uploader avatar "Siddarth", commit d922c74
# ("Upload Main.py"), file size 1.3 kB.
import torch
# To speed-up training process
# NOTE(review): set_detect_anomaly(False) is a plain call that switches
# autograd anomaly detection off. The return values of the two profiler calls
# below are discarded; if these are context-manager classes (as the PyTorch
# docs describe), merely constructing them presumably changes nothing —
# confirm before relying on them for any speed-up.
torch.autograd.set_detect_anomaly(False)
torch.autograd.profiler.profile(False)
torch.autograd.profiler.emit_nvtx(False)
import warnings
import pickle
from Transformer import Transformer
import librosa
import os.path as path
import json
from tqdm import tqdm
import time
import numpy as np
# Silence every warning for the whole process (no category/module filter).
warnings.filterwarnings("ignore")

# Load the run settings from config.json, resolved against the current
# working directory. The encoding is pinned so the result does not depend on
# the platform's locale default.
with open("config.json", encoding="utf-8") as json_data_file:
    data = json.load(json_data_file)

# Every key below is required: a missing one raises KeyError at import time.
lr = data['learn_rate']
epochs = data['epochs']
batch_size = data['batch_size']
Training = data['Training']
Testing = data['Testing']
main_path = data['MainPath']
device = data['Device']
diag_attn = data['DiagAttn']

# Checkpoint file that wav2art loads for inference.
BestModelPath = 'Best_GlobalModel_500_0_0.pt'
def pre_process_mfcc(mfcc):
    """Normalise an MFCC matrix column-wise after transposing it.

    The input is transposed, then the mean and standard deviation are taken
    along axis 0 of the transposed array, so every column of the result is
    shifted to zero mean and scaled to a standard deviation of 0.5.

    Args:
        mfcc: Array-like of MFCC values. ``wav2art`` passes the output of
            ``librosa.feature.mfcc`` — presumably shaped
            ``(n_mfcc, n_frames)``; TODO confirm against the librosa docs.

    Returns:
        The transposed array with ``0.5 * (x - mean) / std`` applied per
        column. A column whose values are all identical (zero standard
        deviation) comes back as zeros rather than NaN.
    """
    mfcc = np.asarray(mfcc).T
    mean_g = np.mean(mfcc, axis=0)
    std_g = np.std(mfcc, axis=0)
    # A constant column has zero spread; dividing by it would produce 0/0 =
    # NaN. Its centred values are already all zero, so dividing by 1 instead
    # leaves them at zero.
    safe_std = np.where(std_g == 0, 1.0, std_g)
    return 0.5 * (mfcc - mean_g) / safe_std
# Loaded checkpoints keyed by file path, so repeated wav2art calls do not
# re-read and re-deserialise the model from disk every time.
_MODEL_CACHE = {}


def _load_model(model_path):
    """Return the float32, eval-mode CPU model stored at *model_path*, loading it once."""
    model = _MODEL_CACHE.get(model_path)
    if model is None:
        # NOTE(review): torch.load unpickles whatever object the checkpoint
        # holds; only point this at trusted files. The result is used as a
        # callable module below, so presumably the whole model object was
        # saved — confirm against how the checkpoint was written.
        model = torch.load(model_path, map_location=torch.device('cpu')).float()
        model.eval()
        _MODEL_CACHE[model_path] = model
    return model


def wav2art(wav):
    """Run the model from ``BestModelPath`` on the MFCCs of a waveform.

    Args:
        wav: Audio samples accepted by ``librosa.feature.mfcc``. The sample
            rate is hard-coded to 16000 Hz, so the signal is assumed to be
            sampled at that rate — TODO confirm with callers.

    Returns:
        Element 0 of the model output, detached and converted to a NumPy
        array.
    """
    rate = 16000
    # 13 coefficients; hop of 0.010 * rate and FFT size of 0.020 * rate
    # samples. y/sr are passed by keyword because newer librosa releases
    # reject them as positional arguments.
    mfcc = librosa.feature.mfcc(
        y=wav,
        sr=rate,
        n_mfcc=13,
        hop_length=int(0.010 * rate),
        n_fft=int(0.020 * rate),
    )
    mfcc = pre_process_mfcc(mfcc)
    # Add a leading dimension of size 1 (same shape as wrapping the array in
    # a list) without the slow list-of-ndarrays conversion path.
    batch = torch.as_tensor(mfcc, dtype=torch.float32).unsqueeze(0)
    test_model = _load_model(BestModelPath)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        p = test_model(batch, 0, 0, 0)
    return p[0].detach().numpy()