File size: 4,043 Bytes
b319084
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96535cf
 
 
2192134
 
b319084
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2192134
b319084
 
 
 
2192134
b319084
 
 
 
 
 
 
 
 
 
 
 
 
 
 
542bfee
 
a63b267
 
 
 
 
 
 
 
b319084
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a63b267
b319084
 
 
 
 
b128274
b319084
96535cf
 
b319084
 
96535cf
a63b267
b319084
 
8ad2c9f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import h5py, math, numpy as np
from pydub import AudioSegment
import torchaudio
import torch
import torch.nn as nn
from StutterNet import SEP28KDataset, StutterNet, sigmoid
import matplotlib.pyplot as plt
import gradio as gr

def clip(audio, step=500, clip_len=3000):
  """Split *audio* into overlapping fixed-length clips.

  Slides a window of ``clip_len`` units over ``audio`` in hops of ``step``,
  so consecutive clips overlap by ``clip_len - step`` units. Works on any
  sliceable sequence — for a pydub ``AudioSegment`` the units are
  milliseconds, so the defaults yield 3-second clips every 0.5 s.

  Args:
    audio: sliceable sequence to split.
    step: hop size between consecutive clip starts.
    clip_len: length of every clip.

  Returns:
    List of clips; empty when ``audio`` is shorter than ``clip_len``.
  """
  clips = []
  start = 0
  # `<=` keeps the final clip when it ends exactly at the end of the
  # audio; the previous `<` comparison silently dropped that clip.
  while start + clip_len <= len(audio):
    clips.append(audio[start:start + clip_len])
    start += step
  return clips

def align(audio, mean, std):
  """Standardize *audio*, then rescale it to the training distribution.

  Z-scores the samples (zero mean, unit std) and maps them onto the
  training set's statistics so inference inputs match what the network
  was trained on.

  Args:
    audio: numpy array of audio samples.
    mean: target mean (training-set statistic).
    std: target standard deviation (training-set statistic).

  Returns:
    Rescaled array with approximately the requested mean and std.
  """
  sample_std = audio.std()
  # Guard against silent / constant clips: std == 0 would turn the
  # unguarded division into NaN/inf and poison everything downstream.
  if sample_std:
    audio = (audio - audio.mean()) / sample_std
  else:
    audio = audio - audio.mean()
  # align with training data
  return audio * std + mean

def stutterPrediction(f1, f2):
  """Run the StutterNet ensemble over a WAV file and plot the results.

  Takes two candidate file paths (microphone recording and upload — the
  Gradio interface supplies one of them and None for the other), slices
  the audio into overlapping 3-second clips, classifies each clip with an
  ensemble of two StutterNet checkpoints, and returns two matplotlib
  figures: per-class probabilities over time, and a hard per-clip label
  derived by an ad-hoc thresholding heuristic.

  Args:
    f1: path to the microphone recording, or None.
    f2: path to the uploaded file, or None.

  Returns:
    (fig1, fig2): probability curves and the discrete prediction track.
  """
  # Prefer the microphone input; fall back to the upload.
  SPEECH_FILE = f1 if f1 != None else f2

  # Training-set sample statistics used by align() to match the
  # distribution the networks were trained on (hard-coded from training).
  mean = -3.0667358666907646e-05
  std = 0.08361649180070069

  # clip: resample to 16 kHz mono, then cut into 3 s windows every 0.5 s.
  audio = AudioSegment.from_wav(SPEECH_FILE)
  audio = audio.set_frame_rate(16000)
  audio = audio.set_channels(1)
  step, clip_len = 500, 3000
  audio = clip(audio, step, clip_len)

  # Convert each pydub clip to raw samples; all clips share one length,
  # so the result stacks into a 2-D (n_clips, n_samples) float32 array.
  for i in range(len(audio)):
    audio[i] = np.array(audio[i].get_array_of_samples())
  audio = np.array(audio)
  audio = audio.astype(np.float32)

  # align by mean & std (per clip)
  for i in range(audio.shape[0]):
    audio[i] = align(audio[i], mean, std)

  # get device
  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') # get device

  # prepare data and predict
  # create required transforms: mel spectrogram -> dB scale, scripted so
  # the dataset can apply them as one module.
  spec = torchaudio.transforms.MelSpectrogram(n_mels=80, sample_rate=16000,
                                                n_fft=512, f_max=8000, f_min=0,
                                                power=0.5, hop_length=152, win_length=480)
  db = torchaudio.transforms.AmplitudeToDB()
  transforms = torch.jit.script(nn.Sequential(spec, db))

  # create datasets; labels are dummy ones (inference only — presumably
  # 12 output classes, matching the preds arrays below).
  ds = SEP28KDataset(audio, np.ones((audio.shape[0],12)), transform=transforms)
  dataloader = torch.utils.data.DataLoader(ds, batch_size=audio.shape[0], shuffle=False, num_workers=2, pin_memory=True)

  # ensemble learning: average two checkpoints. NOTE(review): weight
  # files are loaded from the current working directory — verify they
  # ship alongside this script.
  state = torch.load("StutterNet2.pth", map_location=device)
  net1 = StutterNet(80, dropout=0.2, scale=2).to(device)
  net1.load_state_dict(state['state_dict'])
  net1.eval()

  state = torch.load("StutterNet.pth", map_location=device)
  net2 = StutterNet(80, dropout=0.2).to(device)
  net2.load_state_dict(state['state_dict'])
  net2.eval()

  # prediction placeholders (overwritten below; batch_size equals the
  # whole dataset, so the loop runs a single iteration).
  preds = np.zeros((len(ds), 12))

  for data in iter(dataloader):
    # get features and labels (labels are the dummy ones and unused)
    inputs, labels = data[0].to(device), data[1].detach().cpu().numpy()

    # get predictions: mean of the two networks' raw logits
    preds = (net1(inputs).detach().cpu().numpy() + net2(inputs).detach().cpu().numpy()) / 2

  # Squash logits to probabilities; class display names come from disk.
  preds = sigmoid(preds)
  names = np.loadtxt('classes.txt', dtype=str)
  
  # Figure 1: probability curves for classes 6..11 only (presumably the
  # stutter-type classes — confirm against classes.txt). x-axis is clip
  # start time in seconds (step is in milliseconds).
  fig1 = plt.figure()
  for i in range(6,12):
    plt.plot(np.arange(preds.shape[0])*step/1000, preds[:,i], label=names[i])
  plt.legend()
  plt.xlabel("Seconds (time-axis)")
  plt.ylabel("Class")
  plt.tight_layout()

  # dpreds = first-order difference of the probabilities over time
  # (dpreds[0] is the raw first frame).
  dpreds = np.zeros(preds.shape)

  dpreds[0] = preds[0]
  for i in range(1, preds.shape[0]):
    dpreds[i] = preds[i] - preds[i-1]

  # ad-hoc labeling heuristic: hold a current class (cur_stat) while its
  # probability isn't falling sharply; enter a class on a sharp rise, or
  # on a high absolute probability when no class is held.
  first_threshold = 0.3
  second_threshold = 0.2
  cur_stat = None
  pred = []
  for i in range(preds.shape[0]):
    ac = np.argmax(dpreds[i])
    pr = np.argmax(preds[i])
    # Drop the held class if its probability fell sharply this frame.
    if cur_stat != None and dpreds[i,cur_stat] < -second_threshold:
      cur_stat = None
    # A sharp rise in some class takes over as the held class.
    if dpreds[i,ac] > second_threshold:
      cur_stat = ac
    # Otherwise, adopt the most likely class if it is confident enough.
    if cur_stat == None and preds[i,pr] > first_threshold:
      cur_stat = pr
    if cur_stat == None:
      pred.append('None')
    else:
      pred.append(names[cur_stat])

  # plot: discrete class name per time step (categorical y-axis)
  fig2 = plt.figure()
  plt.plot(np.arange(preds.shape[0])*step/1000, pred)
  plt.xlabel("Seconds (time-axis)")
  plt.ylabel("Class")
  plt.tight_layout()

  return fig1, fig2

# Two interchangeable audio sources; both hand the handler a file path
# (stutterPrediction receives whichever one the user supplied, plus None).
microphone = gr.Audio(sources="microphone", type="filepath")
upload = gr.Audio(sources="upload", type="filepath")

# Wire the prediction function to the UI: two audio inputs in, two
# matplotlib figures out.
demo = gr.Interface(
    fn=stutterPrediction,
    inputs=[microphone, upload],
    outputs=[
        gr.Plot(label="Classification map"),
        gr.Plot(label="Predicted map"),
    ],
)

# Publicly shareable link; surface handler exceptions in the UI.
demo.launch(share=True, show_error=True)