Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import os | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torchaudio | |
| import matplotlib.pyplot as plt | |
| import io | |
| from PIL import Image | |
# --------------------- Force Dark Mode JavaScript --------------------- #
# Handed to gr.Blocks(js=...). If the page URL does not already carry
# `__theme=dark`, the script adds that query parameter and navigates to the
# rewritten URL, which makes Gradio render with its dark theme.
js_func = """
function refresh() {
    const url = new URL(window.location);
    if (url.searchParams.get('__theme') !== 'dark') {
        url.searchParams.set('__theme', 'dark');
        window.location.href = url.href;
    }
}
"""
# ---------------------------------------------------------------------- #
# Device and label IDs
# Recording-device identifiers offered in the "Select Device" dropdown; they
# are compared against the device field of each demo audio filename.
device_ids = "a b c s1 s2 s3".split()
# Acoustic-scene class names. Order matters: model output index i is
# reported as label_ids[i].
label_ids = (
    "airport bus metro metro_station park "
    "public_square shopping_mall street_pedestrian "
    "street_traffic tram"
).split()
# Directories
audio_dir = os.path.join('demo', 'audio')
ir_dir = os.path.join('demo', 'impulse_responses')
ir_names = ['Altec_639.wav', 'Altec_670A.wav', 'Altec_670B.wav']

# Transforms and spectrogram parameters
orig_sample_rate = 44100  # rate the demo audio is assumed to be stored at
sample_rate = 32000       # rate everything downstream (augmentation, mels, model) works at
resample = torchaudio.transforms.Resample(orig_freq=orig_sample_rate, new_freq=sample_rate)
n_fft = 4096
window_length = 3072
hop_length = 500
n_mels = 256
f_min = 0
f_max = None
mel_spectrogram = torchaudio.transforms.MelSpectrogram(
    sample_rate=sample_rate,
    n_fft=n_fft,
    win_length=window_length,
    hop_length=hop_length,
    n_mels=n_mels,
    f_min=f_min,
    f_max=f_max,
)

# Load impulse responses.
# The IRs are convolved with audio that has already been resampled to
# `sample_rate`, so each IR is brought to that same rate. Convolving at
# mismatched rates would time-stretch the IR and shift its frequency response.
irs = []
for ir_name in ir_names:
    ir_path = os.path.join(ir_dir, ir_name)
    ir, ir_sample_rate = torchaudio.load(ir_path)
    if ir_sample_rate != sample_rate:
        ir = torchaudio.functional.resample(ir, ir_sample_rate, sample_rate)
    irs.append(ir)

# Augmentation transforms
freqm = 48  # maximum width (in mel bins) of the frequency mask
timem = 0   # maximum width (in frames) of the time mask; 0 leaves time unmasked
freq_mask = torchaudio.transforms.FrequencyMasking(freqm, iid_masks=True)
time_mask = torchaudio.transforms.TimeMasking(timem, iid_masks=True)
mel_augment = torch.nn.Sequential(
    freq_mask,
    time_mask,
)
def mixstyle(x, p=0.4, alpha=0.3, eps=1e-6):
    """
    MixStyle data augmentation.

    With probability ``p``, each example's statistics (mean and standard
    deviation pooled over dims 1 and 3, i.e. kept separately along dim 2) are
    replaced by a random convex combination of its own statistics and those
    of another example from the same batch. Otherwise ``x`` is returned as-is.
    For a (batch, channel, freq, time) input — presumably what callers pass,
    confirm — this mixes per-frequency-bin statistics.

    Args:
        x: 4-D tensor whose dim 0 is the batch dimension.
        p: probability of applying the augmentation.
        alpha: both concentration parameters of the Beta mixing distribution.
        eps: added to the variance before the square root, for stability.

    Returns:
        A tensor shaped like ``x`` (the same object when not applied).

    Note:
        With a batch of one, the permutation can only pair an example with
        itself, so the output equals the input up to rounding.
    """
    skip = np.random.rand() > p
    if skip:
        return x

    n = x.size(0)
    pooled_dims = [1, 3]
    # Statistics are treated as constants (no gradient flows through them).
    mean = x.mean(dim=pooled_dims, keepdim=True).detach()
    std = (x.var(dim=pooled_dims, keepdim=True) + eps).sqrt().detach()
    normalized = (x - mean) / std

    partner = torch.randperm(n)
    weight = torch.distributions.Beta(alpha, alpha).sample((n, 1, 1, 1)).to(x.device)
    mixed_mean = mean * weight + mean[partner] * (1 - weight)
    mixed_std = std * weight + std[partner] * (1 - weight)
    return normalized * mixed_std + mixed_mean
# Model definition
class Residual(nn.Module):
    """Wrap a module so that its input is added back onto its output."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        out = self.fn(x)
        return out + x


def ConvMixer(in_channels, filter, depth, kernel_size, patch_size, n_classes):
    """Build a ConvMixer classifier as one flat ``nn.Sequential``.

    Layout: a patch-embedding stem (strided conv -> GELU -> BatchNorm),
    ``depth`` mixer layers, global average pooling, flatten and a linear head.
    Each mixer layer is a residual depthwise conv -> GELU -> BatchNorm
    followed by a pointwise (1x1) conv -> GELU -> BatchNorm. The module
    nesting is fixed because saved state_dict keys depend on it.

    Args:
        in_channels: number of input channels.
        filter: hidden width of every convolution (parameter name kept for
            caller compatibility, although it shadows the builtin here).
        depth: number of mixer layers.
        kernel_size: kernel size of the depthwise convolutions.
        patch_size: kernel size and stride of the stem convolution.
        n_classes: number of output logits.
    """
    width = filter

    def conv_gelu_bn(conv):
        # A convolution followed by the GELU + BatchNorm pair used everywhere.
        return [conv, nn.GELU(), nn.BatchNorm2d(width)]

    def mixer_layer():
        token_mixing = Residual(nn.Sequential(*conv_gelu_bn(
            nn.Conv2d(width, width, kernel_size, groups=width, padding="same")
        )))
        channel_mixing = conv_gelu_bn(nn.Conv2d(width, width, kernel_size=1))
        return nn.Sequential(token_mixing, *channel_mixing)

    stem = conv_gelu_bn(
        nn.Conv2d(in_channels, width, kernel_size=patch_size, stride=patch_size)
    )
    body = [mixer_layer() for _ in range(depth)]
    head = [nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(), nn.Linear(width, n_classes)]
    return nn.Sequential(*stem, *body, *head)
# Instantiate and load the model
in_channels = 1
n_filters = 64  # hidden width; deliberately not named `filter` (would shadow the builtin)
depth = 9
kernel_size = 3
patch_size = 5
n_classes = 10
model = ConvMixer(in_channels, n_filters, depth, kernel_size, patch_size, n_classes)
model_path = 'model.pth'
if os.path.exists(model_path):
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'), weights_only=True))
else:
    # Without the checkpoint the network keeps its random initialisation,
    # so predictions shown in the UI are meaningless.
    print(f"Model file '{model_path}' not found. Please place the model file in the same directory.")
# Always switch to inference mode. In training mode BatchNorm would normalise
# with per-request batch statistics and keep updating its running stats on
# every prediction.
model.eval()
_MAX_FILES = 3          # examples shown per (label, device) selection
_OUTPUTS_PER_FILE = 7   # orig audio, augmented audio, 4 plots, predicted label
_TEXT_SLOT = _OUTPUTS_PER_FILE - 1  # index of the Textbox within each group


def _find_matching_files(selected_label, selected_device, limit=_MAX_FILES):
    """Return up to `limit` .wav names in `audio_dir` for a scene/device pair.

    Filenames must have at least six hyphen-separated fields; the first is
    the scene label and the last is the device id. Names are sorted so a
    given selection always shows the same examples.
    """
    matches = []
    for filename in sorted(os.listdir(audio_dir)):
        if not filename.endswith('.wav'):
            continue
        parts = os.path.splitext(filename)[0].split('-')
        if len(parts) < 6:
            continue
        if parts[0] == selected_label and parts[-1] == selected_device:
            matches.append(filename)
            if len(matches) >= limit:
                break
    return matches


def _load_mono_resampled(path):
    """Load an audio file as a (1, samples) tensor at `sample_rate`."""
    waveform, sr = torchaudio.load(path)
    if waveform.shape[0] > 1:
        # The model has a single input channel; downmix multi-channel audio.
        waveform = waveform.mean(dim=0, keepdim=True)
    if sr == orig_sample_rate:
        return resample(waveform)  # reuse the precomputed resampling kernel
    return torchaudio.functional.resample(waveform, sr, sample_rate)


def _figure_to_image(fig, **savefig_kwargs):
    """Render a Matplotlib figure to PNG and return it as a decoded PIL image."""
    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0, **savefig_kwargs)
    buf.seek(0)
    image = Image.open(buf)
    image.load()  # decode eagerly instead of lazily reading from `buf` later
    return image


def _waveform_image(clean, augmented, title):
    """Overlay the clean and augmented waveforms in a labelled 10x3 inch plot."""
    # Object-oriented Figure (no pyplot): pyplot's global state is not
    # thread-safe, and Gradio runs event handlers in worker threads.
    from matplotlib.figure import Figure
    fig = Figure(figsize=(10, 3))
    ax = fig.subplots()
    ax.plot(clean.squeeze().numpy(), label='normal')
    ax.plot(augmented.squeeze().numpy(), label='augmented', linestyle='-.', alpha=0.8)
    ax.set_title(title)
    ax.legend()
    ax.set_xlabel('Time Samples')
    ax.set_ylabel('Amplitude')
    return _figure_to_image(fig, dpi=100)


def _spectrogram_image(spec):
    """Draw a spectrogram as a borderless 1x4 inch image (no axes, no text)."""
    from matplotlib.figure import Figure
    fig = Figure(figsize=(1, 4), dpi=100)
    ax = fig.subplots()
    ax.imshow(spec.squeeze().numpy(), origin='lower', aspect='auto', interpolation='nearest')
    ax.axis('off')
    return _figure_to_image(fig)


def process_audio(selected_label, selected_device):
    """Build the 21 UI outputs for up to three matching demo recordings.

    For each matching file the list holds, in order:
      1. the original audio as ``(sample_rate, ndarray)``,
      2. the audio convolved with a randomly chosen device impulse response,
      3. a waveform plot overlaying both,
      4. the log-mel spectrogram of the augmented audio,
      5. that spectrogram after frequency/time masking,
      6. the masked spectrogram after MixStyle,
      7. the label predicted by the model for the original audio.
    Slots for missing files are ``None``. When nothing matches, every
    Textbox slot carries a message and all other slots are ``None``
    (Gradio would treat a string as a file path for Audio/Image outputs).
    """
    total_outputs = _MAX_FILES * _OUTPUTS_PER_FILE
    matching_files = _find_matching_files(selected_label, selected_device)
    if not matching_files:
        message = "No matching audio files found"
        return [message if i % _OUTPUTS_PER_FILE == _TEXT_SLOT else None
                for i in range(total_outputs)]

    outputs = []
    for audio_file in matching_files:
        waveform = _load_mono_resampled(os.path.join(audio_dir, audio_file))
        outputs.append((sample_rate, waveform.squeeze().numpy()))

        # Augment with a random impulse response and trim the convolution
        # tail so the result is no longer than the original.
        ir = irs[np.random.randint(len(irs))]
        augmented = torchaudio.functional.convolve(waveform, ir)[:, :waveform.shape[1]]
        outputs.append((sample_rate, augmented.squeeze().numpy()))

        outputs.append(_waveform_image(waveform, augmented, f'Label: {selected_label}'))

        # Natural-log mel spectrogram (the epsilon avoids log(0)).
        mel_log = (mel_spectrogram(augmented) + 1e-5).log()
        outputs.append(_spectrogram_image(mel_log))

        masked = mel_augment(mel_log)
        outputs.append(_spectrogram_image(masked))

        # NOTE(review): with a batch of one MixStyle can only mix an example
        # with itself, so this panel is essentially the masked spectrogram.
        mixed = mixstyle(masked.unsqueeze(0), p=1.0)
        outputs.append(_spectrogram_image(mixed))

        # Predict on the un-augmented audio.
        with torch.no_grad():
            x = (mel_spectrogram(waveform) + 1e-5).log().unsqueeze(0)
            predicted_idx = model(x).argmax(dim=1).item()
        outputs.append(label_ids[predicted_idx])

    # Empty slots for examples that were not found.
    outputs += [None] * (total_outputs - len(outputs))
    return outputs
# --------------------- The Main Gradio Demo --------------------- #
with gr.Blocks(js=js_func) as demo:
    gr.Markdown("# ASCDomain")
    gr.Markdown("""
**ASCDomain:** Domain Invariant Device-Self-Challenging Isotropic Convolutional Neural Architecture

[ASCDomain repository](https://github.com/hubtru/ASCDomain)

Explore different acoustic scenes and mobile devices in our latest model.

**Options:**
- **Acoustic Scene:** Airport, Indoor shopping mall, Metro station, Pedestrian street, Public square, Street with medium level of traffic, Travelling by tram, Travelling by bus, Travelling by underground metro, Urban park
- **Mobile Device:** a, b, c, s1, s2, s3
""")
    with gr.Row():
        # Left column: dropdowns
        with gr.Column():
            label_dropdown = gr.Dropdown(
                choices=label_ids,
                label="Select Label",
                interactive=True,
            )
            device_dropdown = gr.Dropdown(
                choices=device_ids,
                label="Select Device",
                interactive=True,
            )
        # Right column: outputs. Seven components per example, in exactly the
        # order process_audio returns its values.
        with gr.Column():
            outputs = []
            for i in range(1, 4):
                outputs += [
                    gr.Audio(label=f"Original Audio {i}"),
                    gr.Audio(label=f"Augmented Audio {i}"),
                    gr.Image(label=f"Waveform Plot {i}"),
                    gr.Image(label=f"Mel-Spectrogram {i}"),
                    gr.Image(label=f"Frequency and Time Masking {i}"),
                    gr.Image(label=f"MixStyle {i}"),
                    gr.Textbox(label=f"Predicted Class {i}"),
                ]
    # Trigger process_audio automatically when either dropdown changes
    for dropdown in (label_dropdown, device_dropdown):
        dropdown.change(
            fn=process_audio,
            inputs=[label_dropdown, device_dropdown],
            outputs=outputs,
        )

if __name__ == "__main__":
    demo.launch()