import numpy as np
from pydub import AudioSegment
import torchaudio
import torch
import torch.nn as nn
from StutterNet import SEP28KDataset, StutterNet, sigmoid
import matplotlib.pyplot as plt
import gradio as gr
def clip(audio, step=500, clip_len=3000):
    # slide a clip_len-ms window across the audio in step-ms hops
    clips = []
    start = 0
    while start + clip_len < len(audio):
        clips.append(audio[start:start + clip_len])
        start += step
    return clips
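# Example (with the defaults above): a 10,000 ms recording yields window starts
# at 0, 500, ..., 6,500 ms, i.e. 14 overlapping 3-second clips.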
def align(audio, mean, std):
    # standardize to zero mean and unit variance
    audio = (audio - audio.mean()) / audio.std()
    # rescale to the training data's statistics
    audio = audio * std + mean
    return audio
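# After align(), a clip has mean ~= `mean` and std ~= `std`, matching the
# distribution the networks saw during training.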
def stutterPrediction(f1, f2):
    # prefer the microphone recording; fall back to the uploaded file
    SPEECH_FILE = f1 if f1 is not None else f2
    # mean/std of the training data, used to align the input distribution
    mean = -3.0667358666907646e-05
    std = 0.08361649180070069
    # load, resample to 16 kHz mono, and slice into overlapping 3-second clips
    audio = AudioSegment.from_wav(SPEECH_FILE)
    audio = audio.set_frame_rate(16000)
    audio = audio.set_channels(1)
    step, clip_len = 500, 3000
    audio = clip(audio, step, clip_len)
    for i in range(len(audio)):
        audio[i] = np.array(audio[i].get_array_of_samples())
    audio = np.array(audio).astype(np.float32)
    # align each clip with the training data's mean and std
    for i in range(audio.shape[0]):
        audio[i] = align(audio[i], mean, std)
    # get device
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # Mel spectrogram followed by dB conversion, scripted into a single transform
    spec = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=80,
                                                n_fft=512, f_min=0, f_max=8000,
                                                power=0.5, hop_length=152,
                                                win_length=480)
    db = torchaudio.transforms.AmplitudeToDB()
    transforms = torch.jit.script(nn.Sequential(spec, db))
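    # For a 3 s clip at 16 kHz (48,000 samples), hop_length=152 gives
    # 48000 // 152 + 1 = 316 frames (assuming torchaudio's default center
    # padding), so each clip becomes an 80 x 316 spectrogram.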
    # wrap the clips in a dataset; the labels are dummies since we only infer
    ds = SEP28KDataset(audio, np.ones((audio.shape[0], 12)), transform=transforms)
    dataloader = torch.utils.data.DataLoader(ds, batch_size=audio.shape[0],
                                             shuffle=False, num_workers=2,
                                             pin_memory=True)
    # ensemble of two StutterNet checkpoints
    state = torch.load("StutterNet2.pth", map_location=device)
    net1 = StutterNet(80, dropout=0.2, scale=2).to(device)
    net1.load_state_dict(state['state_dict'])
    net1.eval()
    state = torch.load("StutterNet.pth", map_location=device)
    net2 = StutterNet(80, dropout=0.2).to(device)
    net2.load_state_dict(state['state_dict'])
    net2.eval()
    # prediction placeholder
    preds = np.zeros((len(ds), 12))
    with torch.no_grad():
        for data in dataloader:
            # batch_size covers the whole dataset, so this loop runs once
            inputs = data[0].to(device)
            preds = (net1(inputs).cpu().numpy() + net2(inputs).cpu().numpy()) / 2
    preds = sigmoid(preds)
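    # Note: the ensemble averages raw logits before the sigmoid; averaging
    # probabilities after the sigmoid would give similar but not identical scores.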
    # plot the probability of classes 6-11 over time
    names = np.loadtxt('classes.txt', dtype=str)
    fig1 = plt.figure()
    for i in range(6, 12):
        plt.plot(np.arange(preds.shape[0]) * step / 1000, preds[:, i], label=names[i])
    plt.legend()
    plt.xlabel("Time (s)")
    plt.ylabel("Probability")
    plt.tight_layout()
    # frame-to-frame change in each class probability
    dpreds = np.zeros(preds.shape)
    dpreds[0] = preds[0]
    for i in range(1, preds.shape[0]):
        dpreds[i] = preds[i] - preds[i-1]
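    # Seeding dpreds[0] with the raw probabilities lets a confident first frame
    # trigger an onset even though it has no previous frame to diff against.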
    # ad-hoc smoothing: hold the current class until its probability drops sharply
    first_threshold = 0.3
    second_threshold = 0.2
    cur_stat = None
    pred = []
    for i in range(preds.shape[0]):
        ac = np.argmax(dpreds[i])  # class with the largest probability jump
        pr = np.argmax(preds[i])   # class with the largest probability
        if cur_stat is not None and dpreds[i, cur_stat] < -second_threshold:
            cur_stat = None
        if dpreds[i, ac] > second_threshold:
            cur_stat = ac
        if cur_stat is None and preds[i, pr] > first_threshold:
            cur_stat = pr
        pred.append('None' if cur_stat is None else names[cur_stat])
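    # In other words: a class turns on when its probability jumps by more than
    # second_threshold between frames (or is already above first_threshold),
    # and turns off again when it drops by more than second_threshold.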
    # plot the smoothed class decision over time
    fig2 = plt.figure()
    plt.plot(np.arange(preds.shape[0]) * step / 1000, pred)
    plt.xlabel("Time (s)")
    plt.ylabel("Class")
    plt.tight_layout()
    return fig1, fig2
microphone = gr.Audio(sources=["microphone"], type="filepath")
upload = gr.Audio(sources=["upload"], type="filepath")
demo = gr.Interface(
    fn=stutterPrediction,
    inputs=[microphone, upload],
    outputs=[gr.Plot(label="Classification map"), gr.Plot(label="Predicted map")],
)
demo.launch(share=True, show_error=True)
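# share=True asks Gradio to create a temporary public link in addition to the
# local server; show_error=True surfaces exceptions in the UI.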