import h5py, math, numpy as np
from pydub import AudioSegment
import torchaudio
import torch
import torch.nn as nn
from StutterNet import SEP28KDataset, StutterNet, sigmoid
import matplotlib.pyplot as plt
import gradio as gr
def clip(audio, step=500, clip_len=3000):
    """Cut *audio* into overlapping windows of length ``clip_len``.

    Args:
        audio: any sliceable sequence with ``len()`` — in this app a
            ``pydub.AudioSegment``, indexed in milliseconds.
        step: hop between consecutive window starts (ms).
        clip_len: window length (ms); default 3000 = 3-second clips.

    Returns:
        list of windows, each of length ``clip_len``. Fix vs. previous
        version: the comparison is ``<=`` so an input whose length exactly
        accommodates a final full window produces that window instead of
        silently dropping it (an exactly-3000 ms file used to yield []).
    """
    clips = []
    start = 0
    while start + clip_len <= len(audio):
        clips.append(audio[start:start + clip_len])
        start += step
    return clips
def align(audio, mean, std):
    """Standardize *audio* and rescale it to the training distribution.

    Args:
        audio: 1-D numpy float array of raw samples for one clip.
        mean: target mean (the training corpus mean).
        std: target standard deviation (the training corpus std).

    Returns:
        the clip z-scored to zero mean / unit variance, then scaled by
        ``std`` and shifted by ``mean``.

    Fix vs. previous version: a constant clip (e.g. digital silence) has
    zero standard deviation, and the 0/0 division filled the clip with
    NaNs that silently corrupted every downstream prediction. Such clips
    now map to a constant array at the target mean.
    """
    clip_std = audio.std()
    if clip_std == 0:
        # Degenerate (constant) clip: no variance to rescale, so place it
        # at the training mean instead of producing NaNs.
        return np.full_like(audio, mean)
    normalized = (audio - audio.mean()) / clip_std
    return normalized * std + mean
def stutterPrediction(f1, f2):
    """Classify stuttering events in a speech recording and plot them.

    Args:
        f1: filepath from the Gradio microphone widget, or None.
        f2: filepath from the Gradio upload widget, or None.
        Whichever is not None is used; f1 takes precedence.

    Returns:
        (fig1, fig2): matplotlib figures — per-class probability curves
        over time, and the smoothed per-window class decision.
    """
    SPEECH_FILE = f1 if f1 is not None else f2
    # Per-sample mean/std of the training corpus; each clip is rescaled
    # to this distribution by align().
    mean = -3.0667358666907646e-05
    std = 0.08361649180070069
    # Load, resample to 16 kHz mono, and cut into overlapping 3 s windows
    # hopped every 0.5 s (milliseconds throughout).
    audio = AudioSegment.from_wav(SPEECH_FILE)
    audio = audio.set_frame_rate(16000)
    audio = audio.set_channels(1)
    step, clip_len = 500, 3000
    audio = clip(audio, step, clip_len)
    for i in range(len(audio)):
        audio[i] = np.array(audio[i].get_array_of_samples())
    audio = np.array(audio).astype(np.float32)
    # Normalize each clip, then shift/scale to the training distribution.
    for i in range(audio.shape[0]):
        audio[i] = align(audio[i], mean, std)
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # Log-mel spectrogram front end (80 mels at 16 kHz), scripted so the
    # dataset can apply it as a single transform.
    spec = torchaudio.transforms.MelSpectrogram(n_mels=80, sample_rate=16000,
                                                n_fft=512, f_max=8000, f_min=0,
                                                power=0.5, hop_length=152, win_length=480)
    db = torchaudio.transforms.AmplitudeToDB()
    transforms = torch.jit.script(nn.Sequential(spec, db))
    # SEP28KDataset requires a label array; predictions ignore it, so
    # all-ones placeholders are supplied.
    ds = SEP28KDataset(audio, np.ones((audio.shape[0], 12)), transform=transforms)
    # batch_size == len(ds): the whole recording goes through in one batch,
    # so the inference loop below executes exactly once.
    dataloader = torch.utils.data.DataLoader(ds, batch_size=audio.shape[0],
                                             shuffle=False, num_workers=2, pin_memory=True)
    # Two-checkpoint ensemble: average the raw outputs of both models.
    state = torch.load("StutterNet2.pth", map_location=device)
    net1 = StutterNet(80, dropout=0.2, scale=2).to(device)
    net1.load_state_dict(state['state_dict'])
    net1.eval()
    state = torch.load("StutterNet.pth", map_location=device)
    net2 = StutterNet(80, dropout=0.2).to(device)
    net2.load_state_dict(state['state_dict'])
    net2.eval()
    # Defensive default in case the dataloader yields nothing.
    preds = np.zeros((len(ds), 12))
    # no_grad: inference only — avoids building autograd graphs.
    with torch.no_grad():
        for data in dataloader:
            inputs = data[0].to(device)
            preds = (net1(inputs).detach().cpu().numpy()
                     + net2(inputs).detach().cpu().numpy()) / 2
    preds = sigmoid(preds)
    names = np.loadtxt('classes.txt', dtype=str)
    # Window start times in seconds, shared by both figures.
    t = np.arange(preds.shape[0]) * step / 1000
    # Figure 1: raw per-class probabilities over time (classes 6-11 only,
    # matching the original display choice).
    fig1 = plt.figure()
    for i in range(6, 12):
        plt.plot(t, preds[:, i], label=names[i])
    plt.legend()
    plt.xlabel("Seconds (time-axis)")
    plt.ylabel("Class")
    plt.tight_layout()
    # First differences of the probabilities; prepending a zero row makes
    # dpreds[0] == preds[0], exactly as the original manual loop produced.
    dpreds = np.diff(preds, axis=0, prepend=np.zeros((1, preds.shape[1])))
    # Ad-hoc hysteresis decision: enter a class on a sharp rise (or a high
    # absolute probability when idle); leave it on a sharp drop.
    first_threshold = 0.3
    second_threshold = 0.2
    cur_stat = None
    pred = []
    for i in range(preds.shape[0]):
        ac = np.argmax(dpreds[i])   # class with the sharpest rise this window
        pr = np.argmax(preds[i])    # most probable class this window
        if cur_stat is not None and dpreds[i, cur_stat] < -second_threshold:
            cur_stat = None
        if dpreds[i, ac] > second_threshold:
            cur_stat = ac
        if cur_stat is None and preds[i, pr] > first_threshold:
            cur_stat = pr
        pred.append('None' if cur_stat is None else names[cur_stat])
    # Figure 2: the decided class label per window.
    fig2 = plt.figure()
    plt.plot(t, pred)
    plt.xlabel("Seconds (time-axis)")
    plt.ylabel("Class")
    plt.tight_layout()
    return fig1, fig2
# Gradio wiring: two optional audio inputs (live microphone or uploaded
# file, both delivered as filepaths) feed stutterPrediction, which
# returns the two matplotlib figures shown as plots.
microphone = gr.Audio(sources="microphone", type="filepath")
upload = gr.Audio(sources="upload", type="filepath")
demo = gr.Interface(
    fn=stutterPrediction,
    inputs=[microphone, upload],
    outputs=[
        gr.Plot(label="Classification map"),
        gr.Plot(label="Predicted map"),
    ],
)
demo.launch(share=True, show_error=True)