import numpy as np
import torch
import torch.nn as nn
import torchaudio
import matplotlib.pyplot as plt
import gradio as gr
from pydub import AudioSegment

from StutterNet import SEP28KDataset, StutterNet, sigmoid


def clip(audio, step=500, clip_len=3000):
    """Cut the recording into overlapping clips of clip_len ms, advancing by step ms."""
    clips = []
    start = 0
    while start + clip_len < len(audio):  # len(AudioSegment) is its duration in ms
        clips.append(audio[start:start + clip_len])
        start += step
    return clips


def align(audio, mean, std):
    """Standardize a clip, then rescale it to the training set's mean and std."""
    audio = (audio - audio.mean()) / audio.std()
    return audio * std + mean


def stutterPrediction(f1, f2):
    SPEECH_FILE = f1 if f1 is not None else f2

    # Mean and std of the training data, used to align the input distribution.
    mean = -3.0667358666907646e-05
    std = 0.08361649180070069

    # Load the recording, resample to 16 kHz mono, and cut it into 3 s clips
    # that slide forward by 500 ms.
    audio = AudioSegment.from_wav(SPEECH_FILE)
    audio = audio.set_frame_rate(16000)
    audio = audio.set_channels(1)
    step, clip_len = 500, 3000
    audio = clip(audio, step, clip_len)

    # Stack the clips into a float32 array of shape (n_clips, n_samples).
    for i in range(len(audio)):
        audio[i] = np.array(audio[i].get_array_of_samples())
    audio = np.array(audio).astype(np.float32)

    # Align every clip with the training data's mean and std.
    for i in range(audio.shape[0]):
        audio[i] = align(audio[i], mean, std)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Feature pipeline: mel spectrogram followed by conversion to decibels.
    spec = torchaudio.transforms.MelSpectrogram(
        sample_rate=16000, n_mels=80, n_fft=512, f_min=0, f_max=8000,
        power=0.5, hop_length=152, win_length=480)
    db = torchaudio.transforms.AmplitudeToDB()
    transforms = torch.jit.script(nn.Sequential(spec, db))

    # Wrap the clips in a dataset; the labels are dummies since we only predict.
    ds = SEP28KDataset(audio, np.ones((audio.shape[0], 12)), transform=transforms)
    dataloader = torch.utils.data.DataLoader(
        ds, batch_size=audio.shape[0], shuffle=False,
        num_workers=2, pin_memory=True)

    # Ensemble of two StutterNet checkpoints; their predictions are averaged.
    state = torch.load("StutterNet2.pth", map_location=device)
    net1 = StutterNet(80, dropout=0.2, scale=2).to(device)
    net1.load_state_dict(state['state_dict'])
    net1.eval()

    state = torch.load("StutterNet.pth", map_location=device)
    net2 = StutterNet(80, dropout=0.2).to(device)
    net2.load_state_dict(state['state_dict'])
    net2.eval()

    # The loader yields a single batch containing every clip.
    preds = np.zeros((len(ds), 12))
    with torch.no_grad():
        for data in dataloader:
            inputs = data[0].to(device)
            preds = (net1(inputs).cpu().numpy() + net2(inputs).cpu().numpy()) / 2
    preds = sigmoid(preds)

    names = np.loadtxt('classes.txt', dtype=str)

    # Per-class probabilities over time for classes 6-11.
    fig1 = plt.figure()
    for i in range(6, 12):
        plt.plot(np.arange(preds.shape[0]) * step / 1000, preds[:, i], label=names[i])
    plt.legend()
    plt.xlabel("Time (seconds)")
    plt.ylabel("Probability")
    plt.tight_layout()

    # First-order differences of the probabilities, used to detect sharp changes.
    dpreds = np.zeros(preds.shape)
    dpreds[0] = preds[0]
    for i in range(1, preds.shape[0]):
        dpreds[i] = preds[i] - preds[i - 1]

    # Ad-hoc decision rule: enter a class on a high absolute probability or a
    # sharp rise, and leave it again on a sharp drop.
    first_threshold = 0.3
    second_threshold = 0.2
    cur_stat = None
    pred = []
    for i in range(preds.shape[0]):
        ac = np.argmax(dpreds[i])
        pr = np.argmax(preds[i])
        if cur_stat is not None and dpreds[i, cur_stat] < -second_threshold:
            cur_stat = None
        if dpreds[i, ac] > second_threshold:
            cur_stat = ac
        if cur_stat is None and preds[i, pr] > first_threshold:
            cur_stat = pr
        pred.append('None' if cur_stat is None else names[cur_stat])

    # Predicted class label over time.
    fig2 = plt.figure()
    plt.plot(np.arange(preds.shape[0]) * step / 1000, pred)
    plt.xlabel("Time (seconds)")
    plt.ylabel("Class")
    plt.tight_layout()

    return fig1, fig2


microphone = gr.Audio(sources=["microphone"], type="filepath")
upload = gr.Audio(sources=["upload"], type="filepath")

demo = gr.Interface(
    fn=stutterPrediction,
    inputs=[microphone, upload],
    outputs=[gr.Plot(label="Classification map"), gr.Plot(label="Predicted map")],
)
demo.launch(share=True, show_error=True)
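
# Optional local smoke test (a minimal sketch; "sample.wav" is a hypothetical
# path, not part of the original script). To try the pipeline without the web
# UI, comment out demo.launch() above and uncomment the lines below:
#
#   fig1, fig2 = stutterPrediction("sample.wav", None)
#   fig1.savefig("classification_map.png")
#   fig2.savefig("predicted_map.png")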