hubtru commited on
Commit
bbd45e4
·
verified ·
1 Parent(s): c5efbb4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +281 -275
app.py CHANGED
@@ -1,276 +1,282 @@
1
- import gradio as gr
2
- import os
3
- import numpy as np
4
- import torch
5
- import torch.nn as nn
6
- import torchaudio
7
- import matplotlib.pyplot as plt
8
- import io
9
- from PIL import Image
10
-
11
- # Device and label IDs
12
- device_ids = ['a', 'b', 'c', 's1', 's2', 's3']
13
- label_ids = ['airport', 'bus', 'metro', 'metro_station', 'park',
14
- 'public_square', 'shopping_mall', 'street_pedestrian',
15
- 'street_traffic', 'tram']
16
-
17
- # Directories
18
- audio_dir = os.path.join('demo', 'audio')
19
- ir_dir = os.path.join('demo', 'impulse_responses')
20
- ir_names = ['Altec_639.wav', 'Altec_670A.wav', 'Altec_670B.wav']
21
-
22
- # Load impulse response files
23
- irs = []
24
- for ir_name in ir_names:
25
- ir_path = os.path.join(ir_dir, ir_name)
26
- ir, _ = torchaudio.load(ir_path)
27
- irs.append(ir)
28
-
29
- # Resampling and other transforms
30
- orig_sample_rate = 44100
31
- sample_rate = 32000
32
- resample = torchaudio.transforms.Resample(
33
- orig_freq=orig_sample_rate,
34
- new_freq=sample_rate
35
- )
36
- n_fft = 4096
37
- window_length = 3072
38
- hop_length = 500
39
- n_mels = 256
40
- f_min = 0
41
- f_max = None
42
- mel_spectrogram = torchaudio.transforms.MelSpectrogram(
43
- sample_rate=sample_rate,
44
- n_fft=n_fft,
45
- win_length=window_length,
46
- hop_length=hop_length,
47
- n_mels=n_mels,
48
- f_min=f_min,
49
- f_max=f_max
50
- )
51
-
52
- freqm = 48
53
- timem = 0
54
- freq_mask = torchaudio.transforms.FrequencyMasking(freqm, iid_masks=True)
55
- time_mask = torchaudio.transforms.TimeMasking(timem, iid_masks=True)
56
- mel_augment = torch.nn.Sequential(
57
- freq_mask,
58
- time_mask
59
- )
60
-
61
- # Mixstyle function
62
- def mixstyle(x, p=0.4, alpha=0.3, eps=1e-6):
63
- if np.random.rand() > p:
64
- return x
65
- batch_size = x.size(0)
66
- f_mu = x.mean(dim=[1, 3], keepdim=True)
67
- f_var = x.var(dim=[1, 3], keepdim=True)
68
- f_sig = (f_var + eps).sqrt()
69
- f_mu, f_sig = f_mu.detach(), f_sig.detach()
70
- x_normed = (x - f_mu) / f_sig
71
- perm = torch.randperm(batch_size)
72
- f_mu_perm, f_sig_perm = f_mu[perm], f_sig[perm]
73
- lmda = torch.distributions.Beta(alpha, alpha).sample((batch_size, 1, 1, 1))
74
- lmda = lmda.to(x.device)
75
- mu_mix = f_mu * lmda + f_mu_perm * (1 - lmda)
76
- sig_mix = f_sig * lmda + f_sig_perm * (1 - lmda)
77
- x = x_normed * sig_mix + mu_mix
78
- return x
79
-
80
- # Model definition
81
- class Residual(nn.Module):
82
- def __init__(self, fn):
83
- super().__init__()
84
- self.fn = fn
85
-
86
- def forward(self, x):
87
- return self.fn(x) + x
88
-
89
- def ConvMixer(in_channels, filter, depth, kernel_size, patch_size, n_classes):
90
- return nn.Sequential(
91
- nn.Conv2d(in_channels, filter, kernel_size=patch_size, stride=patch_size),
92
- nn.GELU(),
93
- nn.BatchNorm2d(filter),
94
- *[nn.Sequential(
95
- Residual(nn.Sequential(
96
- nn.Conv2d(filter, filter, kernel_size, groups=filter, padding="same"),
97
- nn.GELU(),
98
- nn.BatchNorm2d(filter)
99
- )),
100
- nn.Conv2d(filter, filter, kernel_size=1),
101
- nn.GELU(),
102
- nn.BatchNorm2d(filter)
103
- ) for i in range(depth)],
104
- nn.AdaptiveAvgPool2d((1,1)),
105
- nn.Flatten(),
106
- nn.Linear(filter, n_classes)
107
- )
108
-
109
- # Instantiate and load the model
110
- # Model parameters (should match those used during training)
111
- in_channels = 1
112
- filter = 64
113
- depth = 9
114
- kernel_size = 3
115
- patch_size = 5
116
- n_classes = 10
117
-
118
- model = ConvMixer(in_channels, filter, depth, kernel_size, patch_size, n_classes)
119
- model_path = 'model.pth' # Path to the saved model weights
120
-
121
- # Load the model weights
122
- if os.path.exists(model_path):
123
- model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'), weights_only=True))
124
- model.eval()
125
- else:
126
- print(f"Model file '{model_path}' not found. Please place the model file in the same directory.")
127
- # Optionally, you can raise an exception or exit
128
- # raise FileNotFoundError(f"Model file '{model_path}' not found.")
129
-
130
- # Function to process audio and generate outputs
131
- def process_audio(selected_label, selected_device):
132
- # Find matching audio files
133
- matching_files = []
134
- for filename in os.listdir(audio_dir):
135
- if not filename.endswith('.wav'):
136
- continue
137
- basename = os.path.splitext(filename)[0]
138
- parts = basename.split('-')
139
- if len(parts) < 6:
140
- continue
141
- scene, city, x, y, z, device = parts
142
- if scene == selected_label and device == selected_device:
143
- matching_files.append(filename)
144
- if len(matching_files) >= 3:
145
- break
146
- if not matching_files:
147
- return ["No matching audio files found"] * 21 # 21 outputs now
148
-
149
- outputs = []
150
- for audio_file in matching_files:
151
- # Load original audio
152
- audio_path = os.path.join(audio_dir, audio_file)
153
- waveform, sr = torchaudio.load(audio_path)
154
- # Resample
155
- waveform_resampled = resample(waveform)
156
- # Original audio player
157
- original_audio = (sample_rate, waveform_resampled.squeeze().numpy())
158
- outputs.append(original_audio)
159
-
160
- # Augment audio (apply impulse response)
161
- ir = irs[np.random.randint(len(irs))]
162
- augmented_waveform = torchaudio.functional.convolve(waveform_resampled, ir)[:, :waveform_resampled.shape[1]]
163
- # Augmented audio player
164
- augmented_audio = (sample_rate, augmented_waveform.squeeze().numpy())
165
- outputs.append(augmented_audio)
166
-
167
- # **Waveform plot of original vs augmented**
168
- fig, ax = plt.subplots()
169
- ax.plot(waveform_resampled.squeeze().numpy(), label='normal')
170
- ax.plot(augmented_waveform.squeeze().numpy(), label='augmented', linestyle='-.', alpha=0.8)
171
- ax.set_title(f'Label: {selected_label}')
172
- ax.legend()
173
- ax.set_xlabel('Time Samples')
174
- ax.set_ylabel('Amplitude')
175
- buf = io.BytesIO()
176
- plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
177
- plt.close(fig)
178
- buf.seek(0)
179
- waveform_plot_image = Image.open(buf)
180
- outputs.append(waveform_plot_image)
181
-
182
- # Mel-Spectrogram
183
- mel_spec = mel_spectrogram(augmented_waveform)
184
- mel_spec_db = (mel_spec + 1e-5).log()
185
- fig, ax = plt.subplots()
186
- ax.imshow(mel_spec_db.squeeze().numpy(), origin='lower', aspect='auto')
187
- ax.set_title('Mel-Spectrogram')
188
- plt.axis('off')
189
- buf = io.BytesIO()
190
- plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
191
- plt.close(fig)
192
- buf.seek(0)
193
- mel_spec_image = Image.open(buf)
194
- outputs.append(mel_spec_image)
195
-
196
- # Frequency and Time Masking
197
- masked_mel_spec = mel_augment(mel_spec_db)
198
- fig, ax = plt.subplots()
199
- ax.imshow(masked_mel_spec.squeeze().numpy(), origin='lower', aspect='auto')
200
- ax.set_title('Frequency and Time Masking')
201
- plt.axis('off')
202
- buf = io.BytesIO()
203
- plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
204
- plt.close(fig)
205
- buf.seek(0)
206
- masked_mel_spec_image = Image.open(buf)
207
- outputs.append(masked_mel_spec_image)
208
-
209
- # MixStyle Visualization
210
- x_mix = mixstyle(masked_mel_spec.unsqueeze(0), p=1.0)
211
- fig, ax = plt.subplots()
212
- ax.imshow(x_mix.squeeze().numpy(), origin='lower', aspect='auto')
213
- ax.set_title('MixStyle')
214
- plt.axis('off')
215
- buf = io.BytesIO()
216
- plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
217
- plt.close(fig)
218
- buf.seek(0)
219
- mixstyle_image = Image.open(buf)
220
- outputs.append(mixstyle_image)
221
-
222
- # Model Prediction
223
- with torch.no_grad():
224
- x = resample(waveform)
225
- x = mel_spectrogram(x)
226
- x = (x + 1e-5).log().unsqueeze(0)
227
- y_hat = model(x)
228
- predicted_idx = y_hat.argmax(dim=1).item()
229
- predicted_label = label_ids[predicted_idx]
230
- outputs.append(f"Predicted Class: {predicted_label}")
231
-
232
- # If less than 3 files, pad the outputs
233
- total_outputs_needed = 3 * 7 # 3 files * 7 outputs per file
234
- outputs += [""] * (total_outputs_needed - len(outputs))
235
- return outputs
236
-
237
- def gradio_interface():
238
- interface = gr.Interface(
239
- fn=process_audio,
240
- inputs=[
241
- gr.Dropdown(choices=label_ids, label="Select Label"),
242
- gr.Dropdown(choices=device_ids, label="Select Device")
243
- ],
244
- outputs=[
245
- gr.Audio(label="Original Audio 1"),
246
- gr.Audio(label="Augmented Audio 1"),
247
- gr.Image(label="Waveform Plot 1"),
248
- gr.Image(label="Mel-Spectrogram 1"),
249
- gr.Image(label="Frequency and Time Masking 1"),
250
- gr.Image(label="MixStyle 1"),
251
- gr.Textbox(label="Predicted Class 1"),
252
-
253
- gr.Audio(label="Original Audio 2"),
254
- gr.Audio(label="Augmented Audio 2"),
255
- gr.Image(label="Waveform Plot 2"),
256
- gr.Image(label="Mel-Spectrogram 2"),
257
- gr.Image(label="Frequency and Time Masking 2"),
258
- gr.Image(label="MixStyle 2"),
259
- gr.Textbox(label="Predicted Class 2"),
260
-
261
- gr.Audio(label="Original Audio 3"),
262
- gr.Audio(label="Augmented Audio 3"),
263
- gr.Image(label="Waveform Plot 3"),
264
- gr.Image(label="Mel-Spectrogram 3"),
265
- gr.Image(label="Frequency and Time Masking 3"),
266
- gr.Image(label="MixStyle 3"),
267
- gr.Textbox(label="Predicted Class 3")
268
- ],
269
- title="Acoustic Scene Classification Demo",
270
- description="Select a label and device to see audio examples, waveform plots, visualizations, and model predictions.",
271
- live=True,
272
- allow_flagging="never"
273
- )
274
- interface.launch()
275
-
 
 
 
 
 
 
276
  gradio_interface()
 
1
+ import gradio as gr
2
+ import os
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torchaudio
7
+ import matplotlib.pyplot as plt
8
+ import io
9
+ from PIL import Image
10
+
11
+ # Device and label IDs
12
+ device_ids = ['a', 'b', 'c', 's1', 's2', 's3']
13
+ label_ids = ['airport', 'bus', 'metro', 'metro_station', 'park',
14
+ 'public_square', 'shopping_mall', 'street_pedestrian',
15
+ 'street_traffic', 'tram']
16
+
17
+ # Directories
18
+ audio_dir = os.path.join('demo', 'audio')
19
+ ir_dir = os.path.join('demo', 'impulse_responses')
20
+ ir_names = ['Altec_639.wav', 'Altec_670A.wav', 'Altec_670B.wav']
21
+
22
+ # Load impulse response files
23
+ irs = []
24
+ for ir_name in ir_names:
25
+ ir_path = os.path.join(ir_dir, ir_name)
26
+ ir, _ = torchaudio.load(ir_path)
27
+ irs.append(ir)
28
+
29
+ # Resampling and other transforms
30
+ orig_sample_rate = 44100
31
+ sample_rate = 32000
32
+ resample = torchaudio.transforms.Resample(
33
+ orig_freq=orig_sample_rate,
34
+ new_freq=sample_rate
35
+ )
36
+ n_fft = 4096
37
+ window_length = 3072
38
+ hop_length = 500
39
+ n_mels = 256
40
+ f_min = 0
41
+ f_max = None
42
+ mel_spectrogram = torchaudio.transforms.MelSpectrogram(
43
+ sample_rate=sample_rate,
44
+ n_fft=n_fft,
45
+ win_length=window_length,
46
+ hop_length=hop_length,
47
+ n_mels=n_mels,
48
+ f_min=f_min,
49
+ f_max=f_max
50
+ )
51
+
52
+ freqm = 48
53
+ timem = 0
54
+ freq_mask = torchaudio.transforms.FrequencyMasking(freqm, iid_masks=True)
55
+ time_mask = torchaudio.transforms.TimeMasking(timem, iid_masks=True)
56
+ mel_augment = torch.nn.Sequential(
57
+ freq_mask,
58
+ time_mask
59
+ )
60
+
61
+ # Mixstyle function
62
+ def mixstyle(x, p=0.4, alpha=0.3, eps=1e-6):
63
+ if np.random.rand() > p:
64
+ return x
65
+ batch_size = x.size(0)
66
+ f_mu = x.mean(dim=[1, 3], keepdim=True)
67
+ f_var = x.var(dim=[1, 3], keepdim=True)
68
+ f_sig = (f_var + eps).sqrt()
69
+ f_mu, f_sig = f_mu.detach(), f_sig.detach()
70
+ x_normed = (x - f_mu) / f_sig
71
+ perm = torch.randperm(batch_size)
72
+ f_mu_perm, f_sig_perm = f_mu[perm], f_sig[perm]
73
+ lmda = torch.distributions.Beta(alpha, alpha).sample((batch_size, 1, 1, 1))
74
+ lmda = lmda.to(x.device)
75
+ mu_mix = f_mu * lmda + f_mu_perm * (1 - lmda)
76
+ sig_mix = f_sig * lmda + f_sig_perm * (1 - lmda)
77
+ x = x_normed * sig_mix + mu_mix
78
+ return x
79
+
80
+ # Model definition
81
+ class Residual(nn.Module):
82
+ def __init__(self, fn):
83
+ super().__init__()
84
+ self.fn = fn
85
+
86
+ def forward(self, x):
87
+ return self.fn(x) + x
88
+
89
+ def ConvMixer(in_channels, filter, depth, kernel_size, patch_size, n_classes):
90
+ return nn.Sequential(
91
+ nn.Conv2d(in_channels, filter, kernel_size=patch_size, stride=patch_size),
92
+ nn.GELU(),
93
+ nn.BatchNorm2d(filter),
94
+ *[nn.Sequential(
95
+ Residual(nn.Sequential(
96
+ nn.Conv2d(filter, filter, kernel_size, groups=filter, padding="same"),
97
+ nn.GELU(),
98
+ nn.BatchNorm2d(filter)
99
+ )),
100
+ nn.Conv2d(filter, filter, kernel_size=1),
101
+ nn.GELU(),
102
+ nn.BatchNorm2d(filter)
103
+ ) for i in range(depth)],
104
+ nn.AdaptiveAvgPool2d((1,1)),
105
+ nn.Flatten(),
106
+ nn.Linear(filter, n_classes)
107
+ )
108
+
109
+ # Instantiate and load the model
110
+ # Model parameters (should match those used during training)
111
+ in_channels = 1
112
+ filter = 64
113
+ depth = 9
114
+ kernel_size = 3
115
+ patch_size = 5
116
+ n_classes = 10
117
+
118
+ model = ConvMixer(in_channels, filter, depth, kernel_size, patch_size, n_classes)
119
+ model_path = 'model.pth' # Path to the saved model weights
120
+
121
+ # Load the model weights
122
+ if os.path.exists(model_path):
123
+ model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'), weights_only=True))
124
+ model.eval()
125
+ else:
126
+ print(f"Model file '{model_path}' not found. Please place the model file in the same directory.")
127
+ # Optionally, you can raise an exception or exit
128
+ # raise FileNotFoundError(f"Model file '{model_path}' not found.")
129
+
130
+ # Function to process audio and generate outputs
131
+ def process_audio(selected_label, selected_device):
132
+ # Find matching audio files
133
+ matching_files = []
134
+ for filename in os.listdir(audio_dir):
135
+ if not filename.endswith('.wav'):
136
+ continue
137
+ basename = os.path.splitext(filename)[0]
138
+ parts = basename.split('-')
139
+ if len(parts) < 6:
140
+ continue
141
+ scene, city, x, y, z, device = parts
142
+ if scene == selected_label and device == selected_device:
143
+ matching_files.append(filename)
144
+ if len(matching_files) >= 3:
145
+ break
146
+ if not matching_files:
147
+ return ["No matching audio files found"] * 21 # 21 outputs now
148
+
149
+ outputs = []
150
+ for audio_file in matching_files:
151
+ # Load original audio
152
+ audio_path = os.path.join(audio_dir, audio_file)
153
+ waveform, sr = torchaudio.load(audio_path)
154
+ # Resample
155
+ waveform_resampled = resample(waveform)
156
+ # Original audio player
157
+ original_audio = (sample_rate, waveform_resampled.squeeze().numpy())
158
+ outputs.append(original_audio)
159
+
160
+ # Augment audio (apply impulse response)
161
+ ir = irs[np.random.randint(len(irs))]
162
+ augmented_waveform = torchaudio.functional.convolve(waveform_resampled, ir)[:, :waveform_resampled.shape[1]]
163
+ # Augmented audio player
164
+ augmented_audio = (sample_rate, augmented_waveform.squeeze().numpy())
165
+ outputs.append(augmented_audio)
166
+
167
+ # **Waveform plot of original vs augmented**
168
+ fig, ax = plt.subplots()
169
+ ax.plot(waveform_resampled.squeeze().numpy(), label='normal')
170
+ ax.plot(augmented_waveform.squeeze().numpy(), label='augmented', linestyle='-.', alpha=0.8)
171
+ ax.set_title(f'Label: {selected_label}')
172
+ ax.legend()
173
+ ax.set_xlabel('Time Samples')
174
+ ax.set_ylabel('Amplitude')
175
+ buf = io.BytesIO()
176
+ plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
177
+ plt.close(fig)
178
+ buf.seek(0)
179
+ waveform_plot_image = Image.open(buf)
180
+ outputs.append(waveform_plot_image)
181
+
182
+ # Mel-Spectrogram
183
+ mel_spec = mel_spectrogram(augmented_waveform)
184
+ mel_spec_db = (mel_spec + 1e-5).log()
185
+ fig, ax = plt.subplots()
186
+ ax.imshow(mel_spec_db.squeeze().numpy(), origin='lower', aspect='auto')
187
+ ax.set_title('Mel-Spectrogram')
188
+ plt.axis('off')
189
+ buf = io.BytesIO()
190
+ plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
191
+ plt.close(fig)
192
+ buf.seek(0)
193
+ mel_spec_image = Image.open(buf)
194
+ outputs.append(mel_spec_image)
195
+
196
+ # Frequency and Time Masking
197
+ masked_mel_spec = mel_augment(mel_spec_db)
198
+ fig, ax = plt.subplots()
199
+ ax.imshow(masked_mel_spec.squeeze().numpy(), origin='lower', aspect='auto')
200
+ ax.set_title('Frequency and Time Masking')
201
+ plt.axis('off')
202
+ buf = io.BytesIO()
203
+ plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
204
+ plt.close(fig)
205
+ buf.seek(0)
206
+ masked_mel_spec_image = Image.open(buf)
207
+ outputs.append(masked_mel_spec_image)
208
+
209
+ # MixStyle Visualization
210
+ x_mix = mixstyle(masked_mel_spec.unsqueeze(0), p=1.0)
211
+ fig, ax = plt.subplots()
212
+ ax.imshow(x_mix.squeeze().numpy(), origin='lower', aspect='auto')
213
+ ax.set_title('MixStyle')
214
+ plt.axis('off')
215
+ buf = io.BytesIO()
216
+ plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
217
+ plt.close(fig)
218
+ buf.seek(0)
219
+ mixstyle_image = Image.open(buf)
220
+ outputs.append(mixstyle_image)
221
+
222
+ # Model Prediction
223
+ with torch.no_grad():
224
+ x = resample(waveform)
225
+ x = mel_spectrogram(x)
226
+ x = (x + 1e-5).log().unsqueeze(0)
227
+ y_hat = model(x)
228
+ predicted_idx = y_hat.argmax(dim=1).item()
229
+ predicted_label = label_ids[predicted_idx]
230
+ outputs.append(f"Predicted Class: {predicted_label}")
231
+
232
+ # If less than 3 files, pad the outputs
233
+ total_outputs_needed = 3 * 7 # 3 files * 7 outputs per file
234
+ outputs += [""] * (total_outputs_needed - len(outputs))
235
+ return outputs
236
+
237
+ def gradio_interface():
238
+ interface = gr.Interface(
239
+ fn=process_audio,
240
+ inputs=[
241
+ gr.Dropdown(choices=label_ids, label="Select Label"),
242
+ gr.Dropdown(choices=device_ids, label="Select Device")
243
+ ],
244
+ outputs=[
245
+ gr.Audio(label="Original Audio 1"),
246
+ gr.Audio(label="Augmented Audio 1"),
247
+ gr.Image(label="Waveform Plot 1"),
248
+ gr.Image(label="Mel-Spectrogram 1"),
249
+ gr.Image(label="Frequency and Time Masking 1"),
250
+ gr.Image(label="MixStyle 1"),
251
+ gr.Textbox(label="Predicted Class 1"),
252
+
253
+ gr.Audio(label="Original Audio 2"),
254
+ gr.Audio(label="Augmented Audio 2"),
255
+ gr.Image(label="Waveform Plot 2"),
256
+ gr.Image(label="Mel-Spectrogram 2"),
257
+ gr.Image(label="Frequency and Time Masking 2"),
258
+ gr.Image(label="MixStyle 2"),
259
+ gr.Textbox(label="Predicted Class 2"),
260
+
261
+ gr.Audio(label="Original Audio 3"),
262
+ gr.Audio(label="Augmented Audio 3"),
263
+ gr.Image(label="Waveform Plot 3"),
264
+ gr.Image(label="Mel-Spectrogram 3"),
265
+ gr.Image(label="Frequency and Time Masking 3"),
266
+ gr.Image(label="MixStyle 3"),
267
+ gr.Textbox(label="Predicted Class 3")
268
+ ],
269
+ title="ASCDomain",
270
+ description="
271
+ ASCDomain: Domain Invariant Device-Self-Challenging Isotopic Convolutional Neural Architecture
272
+ ASCDomain Repository: https://github.com/hubtru/ASCDomain
273
+ Options:
274
+ * Acoustic Scene: Airport, Indor shopping mall, metro station, pedestrian street, public square, street with medium level of traffic, travelling by a tram, travelling by a bus, travelling by an underground metro, urban park
275
+ * Mobile device: a, b, c, s1, s2, s3
276
+ Select a label and device to see audio examples, waveform plots, visualizations, and model predictions.",
277
+ live=True,
278
+ allow_flagging="never"
279
+ )
280
+ interface.launch()
281
+
282
  gradio_interface()