mirohristov committed on
Commit
fb4f4cb
·
verified ·
1 Parent(s): 92ae278

Fixed gr.Plot and sources to work with gradio>=3.16.0

Browse files
Files changed (1) hide show
  1. app.py +78 -89
app.py CHANGED
@@ -1,110 +1,99 @@
1
  import gradio as gr
2
- import json
3
- import pandas as pd
4
  import collections
5
  import scipy.signal
6
  import numpy as np
 
7
  from functools import partial
8
- from openwakeword.model import Model
9
-
10
- #################################################
11
 
12
  from openwakeword.utils import download_models
13
-
14
- # this will pull down all of the ONNX + TFLite wake-word models into
15
- # openwakeword/resources/models/
16
- download_models()
17
-
18
  from openwakeword.model import Model
19
- model = Model(inference_framework="onnx")
20
 
21
- ###############################################
 
22
 
23
- # Load openWakeWord models
24
  model = Model(inference_framework="onnx")
25
 
26
- # Define function to process audio
27
- def process_audio(audio, state=collections.defaultdict(partial(collections.deque, maxlen=60))):
28
- # Resample audio to 16khz if needed
29
- if audio[0] != 16000:
30
- data = scipy.signal.resample(audio[1], int(float(audio[1].shape[0])/audio[0]*16000))
31
-
32
- # Get predictions
33
- for i in range(0, data.shape[0], 1280):
34
- if len(data.shape) == 2 or data.shape[-1] == 2:
35
- chunk = data[i:i+1280][:, 0] # just get one channel of audio
36
- else:
37
- chunk = data[i:i+1280]
38
-
39
- if chunk.shape[0] == 1280:
40
- prediction = model.predict(chunk)
41
- for key in prediction:
42
- #Fill deque with zeros if it's empty
43
- if len(state[key]) == 0:
44
- state[key].extend(np.zeros(60))
45
-
46
- # Add prediction
47
- state[key].append(prediction[key])
48
-
49
- # Make line plot
50
- dfs = []
51
- for key in state.keys():
52
- df = pd.DataFrame({"x": np.arange(len(state[key])), "y": state[key], "Model": key})
53
- dfs.append(df)
54
-
55
- df = pd.concat(dfs)
56
- plot = gr.LinePlot().update(value = df, x='x', y='y', color="Model", y_lim = (0,1), tooltip="Model",
57
- width=600, height=300, x_title="Time (frames)", y_title="Model Score", color_legend_position="bottom")
58
-
59
- # Manually adjust how the legend is displayed
60
- tmp = json.loads(plot["value"]["plot"])
61
- tmp["layer"][0]['encoding']['color']['legend']["direction"] = "vertical"
62
- tmp["layer"][0]['encoding']['color']['legend']["columns"] = 4
63
- tmp["layer"][0]['encoding']['color']['legend']["labelFontSize"] = 12
64
- tmp["layer"][0]['encoding']['color']['legend']["titleFontSize"] = 14
65
-
66
- plot["value"]['plot'] = json.dumps(tmp)
67
-
68
- return plot, state
69
-
70
- # Create Gradio interface and launch
71
-
72
- desc = """
73
- This is a demo of the pre-trained models included in the latest release
74
- of the [openWakeWord](https://github.com/dscripka/openWakeWord) library.
75
-
76
- Click on the "record from microphone" button below to start capturing.
77
- The real-time scores from each model will be shown in the line plot. Hover over
78
- each line to see the name of the corresponding model.
79
-
80
- Different models will respond to different wake words/phrases (see [the model docs](https://github.com/dscripka/openWakeWord/tree/main/docs/models) for more details).
81
- If everything is working properly,
82
- you should see a spike in the score for a given model after speaking a related word/phrase. Below are some suggested phrases to try!
83
-
84
- | Model Name | Word/Phrase |
85
- | --- | --- |
86
- | alexa | "alexa" |
87
- | hey_mycroft | "hey mycroft"|
88
- | hey_jarvis | "hey jarvis"|
89
- | hey_rhasspy | "hey rhasspy"|
90
- | weather | "what's the weather", "tell me today's weather" |
91
- | x_minute_timer | "set a timer for 1 minute", "create 1 hour alarm" |
92
-
93
  """
94
 
95
- gr_int = gr.Interface(
96
- title = "openWakeWord Live Demo",
97
- description = desc,
98
- css = ".flex {flex-direction: column} .gr-panel {width: 100%}",
99
  fn=process_audio,
 
 
100
  inputs=[
101
  gr.Audio(sources=["microphone"], type="numpy", streaming=True, show_label=False),
102
- "state"
103
  ],
104
  outputs=[
105
- gr.LinePlot(show_label=False),
106
- "state"
 
107
  ],
108
- live=True)
 
 
109
 
110
- gr_int.launch()
 
 
import gradio as gr
import collections
import scipy.signal
import numpy as np
import matplotlib.pyplot as plt
from functools import partial

from openwakeword.utils import download_models
from openwakeword.model import Model

# Download all ONNX + TFLite models once
download_models()

# Initialize the ONNX-based wake-word model
model = Model(inference_framework="onnx")

# Factory for per-model rolling buffers: each model name maps to a deque
# holding the last 60 frame scores (older scores fall off automatically).
# NOTE(review): passed to gr.State below — presumably Gradio deep-copies it
# per session so users don't share buffers; confirm for the pinned version.
initial_state = collections.defaultdict(partial(collections.deque, maxlen=60))
19
+
20
def process_audio(audio, state):
    """Score one streamed microphone chunk against every wake-word model.

    Parameters
    ----------
    audio : tuple[int, np.ndarray] | None
        ``(sample_rate, samples)`` from the streaming ``gr.Audio`` input.
        ``samples`` may be 1-D mono or 2-D ``(n, channels)``. May be None
        when the stream delivers an empty event.
    state : collections.defaultdict[str, collections.deque]
        Per-session rolling buffers of the last 60 scores per model.

    Returns
    -------
    tuple
        ``(matplotlib figure, updated state, detection message)`` —
        matching the Interface outputs (Plot, State, Textbox).
    """
    detected_msg = ""  # Will hold our "Detected X!" text

    # Guard against an empty streaming event; still redraw the plot below.
    if audio is not None:
        sr, samples = audio
        # Resample if not 16 kHz (openwakeword models expect 16 kHz input).
        # NOTE(review): resample returns float64 even for int16 input —
        # dtype then differs between the two paths; confirm model accepts both.
        if sr != 16000:
            samples = scipy.signal.resample(samples, int(len(samples) / sr * 16000))

        # Slide over the audio in 1280-sample (80 ms) windows.
        for i in range(0, len(samples), 1280):
            chunk = samples[i : i + 1280]
            # Collapse ANY 2-D chunk to channel 0. The previous check
            # (`ndim == 2 and shape[1] == 2`) missed (n, 1)-shaped mono
            # arrays, which would reach model.predict() still 2-D.
            if chunk.ndim == 2:
                chunk = chunk[:, 0]
            # Only score full windows; a trailing partial chunk is dropped.
            if len(chunk) == 1280:
                preds = model.predict(chunk)
                for name, score in preds.items():
                    # Prime with zeros the first time so the plot starts flat.
                    if len(state[name]) == 0:
                        state[name].extend(np.zeros(60))
                    state[name].append(score)

                    # Simple threshold trigger; report only the first hit.
                    if score > 0.8 and not detected_msg:
                        detected_msg = f"🗣 Detected **{name}**!"

    # Build the score plot from the rolling buffers.
    fig, ax = plt.subplots()
    for name, dq in state.items():
        ax.plot(np.arange(len(dq)), list(dq), label=name)

    ax.set_ylim(0, 1)
    ax.set_xlabel("Time (frames)")
    ax.set_ylabel("Model Score")

    # Only add a legend if at least one line has a label.
    if state:
        ax.legend(loc="center left", bbox_to_anchor=(1, 0.5), fontsize="small")

    plt.tight_layout()

    # Close to release pyplot's reference and avoid a per-call memory leak;
    # Gradio can still render the (closed) figure object we return.
    plt.close(fig)

    return fig, state, detected_msg
67
+
68
# Markdown shown above the interface: usage hint plus the wake-word table.
# Runtime string — rendered verbatim in the UI; do not edit casually.
description = """
Speak one of the wake-words into your mic and watch its score spike!

| Model Name | Phrase |
|----------------|------------------------------------|
| alexa | "alexa" |
| hey_mycroft | "hey mycroft" |
| hey_jarvis | "hey jarvis" |
| hey_rhasspy | "hey rhasspy" |
| weather | "what's the weather" |
| x_minute_timer | "set a timer for 1 minute" |
"""
80
 
81
# Assemble the Gradio app: streaming mic input feeds process_audio live;
# outputs are the rolling score plot, the session state, and a detection box.
_mic = gr.Audio(sources=["microphone"], type="numpy", streaming=True, show_label=False)
_plot = gr.Plot(label="Model Scores")
_detection = gr.Textbox(label="Detection", interactive=False)

iface = gr.Interface(
    fn=process_audio,
    inputs=[_mic, gr.State(initial_state)],
    outputs=[_plot, gr.State(), _detection],
    title="openWakeWord Live Demo",
    description=description,
    live=True,
    css=".flex {flex-direction: column} .gr-panel {width: 100%}",
)

# Launch only when run as a script (not when imported).
if __name__ == "__main__":
    iface.launch()