| import gradio as gr |
| import os |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torchaudio |
| import matplotlib.pyplot as plt |
| import io |
| from PIL import Image |
|
|
| |
# Device identifiers: offered in the "Select Device" dropdown and compared
# against the last dash-separated field of each demo clip's filename.
device_ids = [
    'a',
    'b',
    'c',
    's1',
    's2',
    's3',
]

# Scene labels: offered in the "Select Label" dropdown. The position of a label
# in this list is the class index the classifier's argmax is mapped through,
# so the order must not be changed.
label_ids = [
    'airport',
    'bus',
    'metro',
    'metro_station',
    'park',
    'public_square',
    'shopping_mall',
    'street_pedestrian',
    'street_traffic',
    'tram',
]

# Where the demo clips and the impulse responses used for augmentation live
# (paths are relative to the working directory).
audio_dir = os.path.join('demo', 'audio')
ir_dir = os.path.join('demo', 'impulse_responses')
ir_names = [
    'Altec_639.wav',
    'Altec_670A.wav',
    'Altec_670B.wav',
]
|
|
| |
# Load every impulse response once at start-up. torchaudio.load returns a
# (waveform, sample_rate) pair; only the waveform tensor is kept.
irs = [
    torchaudio.load(os.path.join(ir_dir, ir_file))[0]
    for ir_file in ir_names
]
|
|
| |
# Clips are converted from the rate they are assumed to be stored at
# (orig_sample_rate) to the rate the rest of the pipeline works at.
orig_sample_rate, sample_rate = 44100, 32000
resample = torchaudio.transforms.Resample(
    orig_freq=orig_sample_rate,
    new_freq=sample_rate,
)

# Mel-spectrogram front end. All values are kept as module-level names so the
# configuration stays inspectable; f_max is passed through as None.
n_fft, window_length, hop_length = 4096, 3072, 500
n_mels = 256
f_min, f_max = 0, None
mel_spectrogram = torchaudio.transforms.MelSpectrogram(
    sample_rate=sample_rate,
    n_fft=n_fft,
    win_length=window_length,
    hop_length=hop_length,
    n_mels=n_mels,
    f_min=f_min,
    f_max=f_max,
)
|
|
# Spectrogram masking applied after the log-mel transform: a frequency mask
# parameterised by `freqm` followed by a time mask parameterised by `timem`.
# NOTE(review): timem is 0, so the time mask presumably never removes
# anything - confirm against the torchaudio TimeMasking documentation.
freqm, timem = 48, 0
freq_mask = torchaudio.transforms.FrequencyMasking(freqm, iid_masks=True)
time_mask = torchaudio.transforms.TimeMasking(timem, iid_masks=True)
mel_augment = nn.Sequential(freq_mask, time_mask)
|
|
| |
def mixstyle(x, p=0.4, alpha=0.3, eps=1e-6):
    """Blend each sample's statistics with those of a randomly chosen batch partner.

    A uniform random draw decides whether anything happens: when the draw
    exceeds ``p`` the input tensor itself is returned. Otherwise the mean and
    standard deviation over dims 1 and 3 are computed per sample, the input is
    normalised with them, and it is re-scaled with a convex combination of its
    own statistics and those of a permuted sample. The mixing weight is drawn
    per sample from ``Beta(alpha, alpha)``.

    Args:
        x: 4-D tensor indexed along dim 0 by sample. The caller in this file
            passes a tensor built as ``spectrogram.unsqueeze(0)``.
        p: probability threshold for applying the mixing.
        alpha: both concentration parameters of the Beta distribution.
        eps: added to the variance before the square root.

    Returns:
        ``x`` unchanged, or a new tensor of the same shape with mixed statistics.
    """
    skip = np.random.rand() > p
    if skip:
        return x

    n_samples = x.size(0)
    stat_dims = [1, 3]

    # Statistics are detached so no gradient flows through them.
    mean = x.mean(dim=stat_dims, keepdim=True).detach()
    std = torch.sqrt(x.var(dim=stat_dims, keepdim=True) + eps).detach()
    normalised = (x - mean) / std

    # Partner statistics come from a random permutation of the batch.
    shuffle = torch.randperm(n_samples)
    partner_mean = mean[shuffle]
    partner_std = std[shuffle]

    # One mixing weight per sample, broadcast over the remaining dims.
    weight = torch.distributions.Beta(alpha, alpha).sample((n_samples, 1, 1, 1))
    weight = weight.to(x.device)

    mixed_mean = mean * weight + partner_mean * (1 - weight)
    mixed_std = std * weight + partner_std * (1 - weight)
    return normalised * mixed_std + mixed_mean
|
|
| |
class Residual(nn.Module):
    """Wrap a callable module so its input is added back onto its output."""

    def __init__(self, fn):
        """Store ``fn`` as the wrapped submodule (registered under the name ``fn``)."""
        super().__init__()
        self.fn = fn

    def forward(self, x):
        """Return ``fn(x) + x``."""
        transformed = self.fn(x)
        return transformed + x
|
|
def ConvMixer(in_channels, filter, depth, kernel_size, patch_size, n_classes):
    """Assemble the classifier as a flat ``nn.Sequential``.

    Layout (the positions matter because checkpoint keys are derived from them):
      * patch embedding: strided ``Conv2d`` -> ``GELU`` -> ``BatchNorm2d``
      * ``depth`` mixer stages, each a ``Sequential`` of a residual-wrapped
        grouped convolution (``groups=filter``) followed by a 1x1 convolution,
        both with ``GELU`` + ``BatchNorm2d``
      * head: global average pooling -> flatten -> ``Linear``

    Args:
        in_channels: channels of the input tensor.
        filter: channel width used throughout (the parameter name shadows the
            builtin; it is kept because it is part of the signature).
        depth: number of mixer stages.
        kernel_size: kernel size of the grouped convolution in each stage.
        patch_size: kernel size and stride of the patch embedding.
        n_classes: number of output logits.

    Returns:
        The assembled ``nn.Sequential`` model.
    """
    layers = [
        nn.Conv2d(in_channels, filter, kernel_size=patch_size, stride=patch_size),
        nn.GELU(),
        nn.BatchNorm2d(filter),
    ]

    for _ in range(depth):
        spatial_mixing = nn.Sequential(
            nn.Conv2d(filter, filter, kernel_size, groups=filter, padding="same"),
            nn.GELU(),
            nn.BatchNorm2d(filter),
        )
        stage = nn.Sequential(
            Residual(spatial_mixing),
            nn.Conv2d(filter, filter, kernel_size=1),
            nn.GELU(),
            nn.BatchNorm2d(filter),
        )
        layers.append(stage)

    layers.append(nn.AdaptiveAvgPool2d((1, 1)))
    layers.append(nn.Flatten())
    layers.append(nn.Linear(filter, n_classes))
    return nn.Sequential(*layers)
|
|
| |
| |
# ConvMixer hyper-parameters. n_classes equals the number of entries in
# `label_ids`, which the predicted class index is looked up in.
in_channels = 1
filter = 64  # NOTE: shadows the builtin `filter`; name kept as existing module-level state
depth = 9
kernel_size = 3
patch_size = 5
n_classes = 10

model = ConvMixer(in_channels, filter, depth, kernel_size, patch_size, n_classes)
model_path = 'model.pth'

# Load the trained weights onto the CPU when the checkpoint is present;
# otherwise keep the randomly initialised network and tell the operator.
if os.path.exists(model_path):
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'), weights_only=True))
else:
    print(f"Model file '{model_path}' not found. Please place the model file in the same directory.")

# The model is only used for inference in this script, so it is switched to
# eval mode in both cases. Previously this happened only when the checkpoint
# existed, which left the BatchNorm layers in training mode (using and
# updating batch statistics on every prediction) whenever the file was missing.
model.eval()
| |
| |
|
|
| |
# Layout of the flat value list returned to the Gradio interface: each clip
# slot is 2 audio values, 4 images and 1 text value, and there are 3 slots.
_OUTPUTS_PER_CLIP = 7
_MAX_CLIPS = 3


def _parse_clip_name(filename):
    """Return ``(scene, device)`` parsed from a clip filename, or ``None``.

    Only ``.wav`` names whose stem has exactly six ``-``-separated fields are
    accepted; the first field is the scene and the last one the device.
    """
    if not filename.endswith('.wav'):
        return None
    parts = os.path.splitext(filename)[0].split('-')
    # Exactly six fields: names with more fields used to pass a `< 6` check
    # and then crash the six-way tuple unpacking.
    if len(parts) != 6:
        return None
    return parts[0], parts[-1]


def _blank_outputs(n_clips, message=""):
    """Return placeholder values for ``n_clips`` clip slots.

    Audio and image positions get ``None`` (the "no value" marker for those
    components); only the textbox position carries ``message``.
    """
    return ([None] * (_OUTPUTS_PER_CLIP - 1) + [message]) * n_clips


def _figure_to_image(fig):
    """Render ``fig`` into an in-memory PNG, close it, and return a PIL image."""
    buf = io.BytesIO()
    try:
        # Save through the figure itself rather than pyplot's "current figure".
        fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def _spectrogram_image(spec, title):
    """Draw a spectrogram-like tensor as an axis-free image with ``title``."""
    fig, ax = plt.subplots()
    ax.imshow(spec.squeeze().numpy(), origin='lower', aspect='auto')
    ax.set_title(title)
    ax.axis('off')
    return _figure_to_image(fig)


def process_audio(selected_label, selected_device):
    """Produce the demo outputs for up to three clips of one scene/device pair.

    Clips are looked up in ``audio_dir`` by filename. For each match the clip
    is resampled, convolved with a randomly chosen impulse response, plotted,
    turned into log-mel / masked / MixStyle spectrogram images, and the
    un-augmented clip is classified with ``model``.

    Args:
        selected_label: scene name to match against the first filename field.
        selected_device: device id to match against the last filename field.

    Returns:
        A list of 21 values, 7 per clip slot in this order: original audio
        ``(sample_rate, ndarray)``, augmented audio, waveform plot,
        mel-spectrogram image, masked image, MixStyle image, prediction text.
        Unused slots hold ``None`` for audio/image positions and ``""`` for
        the text position. When nothing matches, every text position holds
        ``"No matching audio files found"``.
    """
    matching_files = []
    # os.listdir order is arbitrary; sorting makes the chosen clips stable.
    for filename in sorted(os.listdir(audio_dir)):
        if _parse_clip_name(filename) != (selected_label, selected_device):
            continue
        matching_files.append(filename)
        if len(matching_files) >= _MAX_CLIPS:
            break
    if not matching_files:
        return _blank_outputs(_MAX_CLIPS, "No matching audio files found")

    outputs = []
    for audio_file in matching_files:
        audio_path = os.path.join(audio_dir, audio_file)
        waveform, sr = torchaudio.load(audio_path)

        # The prebuilt transform assumes `orig_sample_rate`; honour the rate
        # actually stored in the file when it differs.
        if sr == orig_sample_rate:
            waveform_resampled = resample(waveform)
        else:
            waveform_resampled = torchaudio.functional.resample(waveform, sr, sample_rate)
        outputs.append((sample_rate, waveform_resampled.squeeze().numpy()))

        # Device simulation: convolve with a random impulse response and trim
        # the result back to the length of the clip.
        ir = irs[np.random.randint(len(irs))]
        augmented_waveform = torchaudio.functional.convolve(
            waveform_resampled, ir
        )[:, :waveform_resampled.shape[1]]
        outputs.append((sample_rate, augmented_waveform.squeeze().numpy()))

        # Overlay of the clean and the augmented waveform.
        fig, ax = plt.subplots()
        ax.plot(waveform_resampled.squeeze().numpy(), label='normal')
        ax.plot(augmented_waveform.squeeze().numpy(), label='augmented', linestyle='-.', alpha=0.8)
        ax.set_title(f'Label: {selected_label}')
        ax.legend()
        ax.set_xlabel('Time Samples')
        ax.set_ylabel('Amplitude')
        outputs.append(_figure_to_image(fig))

        # Log-mel spectrogram of the augmented clip.
        mel_spec = mel_spectrogram(augmented_waveform)
        mel_spec_db = (mel_spec + 1e-5).log()
        outputs.append(_spectrogram_image(mel_spec_db, 'Mel-Spectrogram'))

        masked_mel_spec = mel_augment(mel_spec_db)
        outputs.append(_spectrogram_image(masked_mel_spec, 'Frequency and Time Masking'))

        # NOTE(review): with a single-sample batch mixstyle can only pair the
        # clip with itself, so this image is expected to look like the masked
        # spectrogram above - confirm whether that is intended for the demo.
        x_mix = mixstyle(masked_mel_spec.unsqueeze(0), p=1.0)
        outputs.append(_spectrogram_image(x_mix, 'MixStyle'))

        # Classify the clean (un-augmented) clip, reusing the resampled
        # waveform instead of resampling a second time.
        with torch.no_grad():
            x = mel_spectrogram(waveform_resampled)
            x = (x + 1e-5).log().unsqueeze(0)
            y_hat = model(x)
        predicted_label = label_ids[y_hat.argmax(dim=1).item()]
        outputs.append(f"Predicted Class: {predicted_label}")

    # Fill the slots of clips that were not found.
    outputs += _blank_outputs(_MAX_CLIPS - len(matching_files))
    return outputs
|
|
|
|
def gradio_interface():
    """Build the Gradio UI around ``process_audio`` and launch it.

    The interface has two dropdowns (scene label and device) and three
    identical groups of seven outputs, in the same order as the values
    returned by ``process_audio``. It runs in live mode with flagging disabled.
    """
    theme = gr.themes.Base(primary_hue="blue", secondary_hue="blue", neutral_hue="gray")
    theme_overrides = {
        "body_background_fill": "*primary_50",
        "body_background_fill_dark": "*checkbox_background_color_focus",
        "body_text_color_dark": "white",
        "body_text_color": "*neutral_800",
        "background_fill_secondary_dark": "*checkbox_border_color_hover",
        "block_background_fill": "*background_fill_primary",
        "block_background_fill_dark": "*neutral_800",
        "block_border_color_dark": "*primary_100",
        "block_border_width_dark": "4px",
        "block_border_width": "4px",
        "block_border_color": "*secondary_200",
    }
    theme.set(**theme_overrides)

    input_components = [
        gr.Dropdown(choices=label_ids, label="Select Label"),
        gr.Dropdown(choices=device_ids, label="Select Device"),
    ]

    # One group of seven components per clip slot, numbered 1 to 3.
    output_components = []
    for slot in (1, 2, 3):
        output_components.extend([
            gr.Audio(label=f"Original Audio {slot}"),
            gr.Audio(label=f"Augmented Audio {slot}"),
            gr.Image(label=f"Waveform Plot {slot}"),
            gr.Image(label=f"Mel-Spectrogram {slot}"),
            gr.Image(label=f"Frequency and Time Masking {slot}"),
            gr.Image(label=f"MixStyle {slot}"),
            gr.Textbox(label=f"Predicted Class {slot}"),
        ])

    page_title = "<h1 style='text-align: center; font-size: 1.5em;'>ASCDomain</h1>"
    # Kept flush-left on purpose: the text is passed to Gradio verbatim.
    page_description = """
<div style="font-size: 16px; letter-spacing: 1.2px; line-height: 1.8; text-align: justify;">
<strong>ASCDomain:</strong> Domain Invariant Device-Self-Challenging Isotopic Convolutional Neural Architecture<br>
<strong>ASCDomain repository:</strong> <a href="https://github.com/hubtru/ASCDomain" target="_blank">ASCDomain</a><br>
Explore different acoustic scenes and mobile devices in our latest model.<br>
<strong>Options:</strong> <br>
<ul>
<li><strong>Acoustic Scene:</strong> Airport, Indoor shopping mall, Metro station, Pedestrian street, Public square, Street with medium level of traffic, Travelling by tram, Travelling by bus, Travelling by underground metro, Urban park</li>
<li><strong>Mobile Device:</strong> a, b, c, s1, s2, s3</li>
</ul>
</div>
"""

    interface = gr.Interface(
        fn=process_audio,
        inputs=input_components,
        outputs=output_components,
        title=page_title,
        description=page_description,
        theme=theme,
        live=True,
        allow_flagging="never",
    )
    interface.launch()
|
|
# Start the demo only when this file is executed as a script, so importing the
# module (for example to reuse `process_audio` or `model`) does not launch a
# web server as a side effect.
if __name__ == "__main__":
    gradio_interface()
|
|
|
|