import io
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchaudio
from PIL import Image

# --------------------- Force Dark Mode JavaScript --------------------- #
# Executed by Gradio on page load: if the URL does not already request the
# dark theme, add `__theme=dark` to the query string and reload the page.
js_func = """
function refresh() {
    const url = new URL(window.location);

    if (url.searchParams.get('__theme') !== 'dark') {
        url.searchParams.set('__theme', 'dark');
        window.location.href = url.href;
    }
}
"""
# ---------------------------------------------------------------------- #

# Device and label IDs. Model predictions are mapped back to a label by
# indexing into `label_ids`, so its order must match the model's outputs.
device_ids = ['a', 'b', 'c', 's1', 's2', 's3']
label_ids = [
    'airport', 'bus', 'metro', 'metro_station', 'park', 'public_square',
    'shopping_mall', 'street_pedestrian', 'street_traffic', 'tram'
]

# How many example files are shown per (label, device) selection, and how
# many UI components each file fills (2 audio, 4 images, 1 textbox).
MAX_FILES = 3
OUTPUTS_PER_FILE = 7

# Directories
audio_dir = os.path.join('demo', 'audio')
ir_dir = os.path.join('demo', 'impulse_responses')
ir_names = ['Altec_639.wav', 'Altec_670A.wav', 'Altec_670B.wav']

# Load the impulse responses used to augment the demo audio.
# NOTE(review): the IR sample rate is discarded; the IRs are assumed to be
# at `sample_rate` already -- confirm.
irs = [torchaudio.load(os.path.join(ir_dir, name))[0] for name in ir_names]

# Transforms and spectrogram parameters
orig_sample_rate = 44100  # sample rate the precomputed `resample` expects
sample_rate = 32000       # rate used for playback, plots and the model input
resample = torchaudio.transforms.Resample(orig_freq=orig_sample_rate, new_freq=sample_rate)

n_fft = 4096
window_length = 3072
hop_length = 500
n_mels = 256
f_min = 0
f_max = None  # None -> sample_rate // 2 in torchaudio's MelSpectrogram
mel_spectrogram = torchaudio.transforms.MelSpectrogram(
    sample_rate=sample_rate,
    n_fft=n_fft,
    win_length=window_length,
    hop_length=hop_length,
    n_mels=n_mels,
    f_min=f_min,
    f_max=f_max,
)

# Augmentation transforms (in this file they are only used for the demo images)
freqm = 48
timem = 0  # zero maximum width: the time mask never masks anything
freq_mask = torchaudio.transforms.FrequencyMasking(freqm, iid_masks=True)
time_mask = torchaudio.transforms.TimeMasking(timem, iid_masks=True)
mel_augment = torch.nn.Sequential(freq_mask, time_mask)


def mixstyle(x, p=0.4, alpha=0.3, eps=1e-6):
    """MixStyle data augmentation.

    With probability ``1 - p`` ``x`` is returned unchanged. Otherwise every
    sample is normalised by its own mean/std taken over dims 1 and 3, then
    re-styled with statistics that are a Beta(alpha, alpha)-weighted mix of
    its own and those of a randomly chosen sample from the same batch.

    Args:
        x: 4-D tensor with the batch on dim 0 (presumably
            (batch, channel, freq, time) -- TODO confirm axis meaning).
        p: Probability of applying the augmentation.
        alpha: Concentration parameter of the Beta mixing distribution.
        eps: Added to the variance for numerical stability.

    Returns:
        A tensor with the same shape as ``x``.

    NOTE: with a batch of one the random permutation is the identity, so the
    mixed statistics equal the sample's own and the output equals the input
    up to floating-point rounding.
    """
    if np.random.rand() > p:
        return x

    batch_size = x.size(0)
    f_mu = x.mean(dim=[1, 3], keepdim=True)
    f_var = x.var(dim=[1, 3], keepdim=True)
    f_sig = (f_var + eps).sqrt()
    # The statistics are treated as constants: no gradient flows through them.
    f_mu, f_sig = f_mu.detach(), f_sig.detach()
    x_normed = (x - f_mu) / f_sig

    perm = torch.randperm(batch_size)
    f_mu_perm, f_sig_perm = f_mu[perm], f_sig[perm]
    lmda = torch.distributions.Beta(alpha, alpha).sample((batch_size, 1, 1, 1)).to(x.device)
    mu_mix = f_mu * lmda + f_mu_perm * (1 - lmda)
    sig_mix = f_sig * lmda + f_sig_perm * (1 - lmda)
    return x_normed * sig_mix + mu_mix


# Model definition
class Residual(nn.Module):
    """Wrap module ``fn`` with an identity skip connection: ``fn(x) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x


def ConvMixer(in_channels, filter, depth, kernel_size, patch_size, n_classes):
    """Build a ConvMixer classifier as an ``nn.Sequential``.

    Layout: a patch-embedding conv (kernel = stride = ``patch_size``), then
    ``depth`` blocks of [residual depthwise conv -> pointwise 1x1 conv], each
    conv followed by GELU and BatchNorm, then global average pooling and a
    linear head. The layer order must not change, or saved state_dicts
    will no longer load.

    Args:
        in_channels: Channels of the input tensor.
        filter: Channel width used throughout the network (parameter name
            kept for compatibility although it shadows the builtin).
        depth: Number of mixer blocks.
        kernel_size: Kernel size of the depthwise convolutions.
        patch_size: Kernel size and stride of the embedding convolution.
        n_classes: Number of output logits.
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, filter, kernel_size=patch_size, stride=patch_size),
        nn.GELU(),
        nn.BatchNorm2d(filter),
        *[nn.Sequential(
            Residual(nn.Sequential(
                nn.Conv2d(filter, filter, kernel_size, groups=filter, padding="same"),
                nn.GELU(),
                nn.BatchNorm2d(filter)
            )),
            nn.Conv2d(filter, filter, kernel_size=1),
            nn.GELU(),
            nn.BatchNorm2d(filter)
        ) for _ in range(depth)],
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        nn.Linear(filter, n_classes)
    )


# Instantiate and load the model
in_channels = 1
n_filters = 64
depth = 9
kernel_size = 3
patch_size = 5
n_classes = len(label_ids)  # one logit per scene label (10)

model = ConvMixer(in_channels, n_filters, depth, kernel_size, patch_size, n_classes)
model_path = 'model.pth'
if os.path.exists(model_path):
    model.load_state_dict(
        torch.load(model_path, map_location=torch.device('cpu'), weights_only=True)
    )
else:
    print(f"Model file '{model_path}' not found. Please place the model file in the same directory.")
# Inference mode in both cases, so BatchNorm uses its running statistics
# instead of the statistics of the single example being classified.
model.eval()


def _fig_to_image(fig):
    """Render ``fig`` to an in-memory PNG, close it, and return a PIL image."""
    buf = io.BytesIO()
    # Save this specific figure rather than pyplot's implicit "current"
    # figure, which is global state shared by all Gradio worker threads.
    fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0, dpi=100)
    plt.close(fig)
    buf.seek(0)
    image = Image.open(buf)
    image.load()  # decode now instead of lazily
    return image


def _waveform_plot_image(waveform, augmented_waveform, label):
    """Overlay the original and augmented waveforms in one titled 10x3 inch plot."""
    fig, ax = plt.subplots(figsize=(10, 3))
    ax.plot(waveform.squeeze().numpy(), label='normal')
    ax.plot(augmented_waveform.squeeze().numpy(), label='augmented', linestyle='-.', alpha=0.8)
    ax.set_title(f'Label: {label}')
    ax.legend()
    ax.set_xlabel('Time Samples')
    ax.set_ylabel('Amplitude')
    return _fig_to_image(fig)


def _spectrogram_image(spec):
    """Render a spectrogram tensor as a bare 1x4 inch image (no axes or text)."""
    fig, ax = plt.subplots(figsize=(1, 4), dpi=100)
    ax.imshow(spec.squeeze().numpy(), origin='lower', aspect='auto', interpolation='nearest')
    ax.axis('off')
    return _fig_to_image(fig)


def _load_resampled(path):
    """Load a wav file and resample it to ``sample_rate``."""
    waveform, sr = torchaudio.load(path)
    if sr == orig_sample_rate:
        return resample(waveform)  # reuse the precomputed resampling kernel
    # Use the file's real rate instead of silently assuming orig_sample_rate.
    return torchaudio.functional.resample(waveform, sr, sample_rate)


def _find_matching_files(selected_label, selected_device, limit=MAX_FILES):
    """Return up to ``limit`` wav names in ``audio_dir`` matching label and device.

    Names must have exactly six '-'-separated fields; the first is compared
    with ``selected_label`` and the last with ``selected_device``. Other names
    are skipped (a wrong field count would otherwise break unpacking). The
    directory is listed in sorted order so the chosen examples are stable.
    """
    matches = []
    for filename in sorted(os.listdir(audio_dir)):
        if not filename.endswith('.wav'):
            continue
        parts = os.path.splitext(filename)[0].split('-')
        if len(parts) != 6:
            continue
        scene, device = parts[0], parts[-1]
        if scene == selected_label and device == selected_device:
            matches.append(filename)
            if len(matches) >= limit:
                break
    return matches


def _placeholder_group(message):
    """Values for one unused output group: clear audio/images, show ``message``."""
    return [None] * (OUTPUTS_PER_FILE - 1) + [message]


def process_audio(selected_label, selected_device):
    """Build the demo outputs for up to ``MAX_FILES`` matching recordings.

    For each matching file this produces, in order: the resampled audio, the
    impulse-response augmented audio, a waveform comparison plot, the log-mel
    spectrogram of the augmented audio, that spectrogram after frequency/time
    masking, the MixStyle result, and the model's predicted label.

    Args:
        selected_label: Scene label from the dropdown (None until chosen).
        selected_device: Device id from the dropdown (None until chosen).

    Returns:
        A list of exactly ``MAX_FILES * OUTPUTS_PER_FILE`` values in the order
        of the Gradio ``outputs`` components. Unused audio/image slots are
        ``None`` (clears the component); unused textboxes get a message or "".
    """
    matching_files = _find_matching_files(selected_label, selected_device)
    if not matching_files:
        return _placeholder_group("No matching audio files found") * MAX_FILES

    outputs = []
    for audio_file in matching_files:
        waveform_resampled = _load_resampled(os.path.join(audio_dir, audio_file))
        outputs.append((sample_rate, waveform_resampled.squeeze().numpy()))

        # Augment with a randomly chosen impulse response, trimmed back to the
        # original length (a full convolution is len(ir) - 1 samples longer).
        ir = irs[np.random.randint(len(irs))]
        augmented_waveform = torchaudio.functional.convolve(waveform_resampled, ir)
        augmented_waveform = augmented_waveform[:, :waveform_resampled.shape[1]]
        outputs.append((sample_rate, augmented_waveform.squeeze().numpy()))

        # 1) Waveform comparison plot
        outputs.append(_waveform_plot_image(waveform_resampled, augmented_waveform, selected_label))

        # 2) Log-mel spectrogram of the augmented audio (no text, no axes)
        mel_spec_db = (mel_spectrogram(augmented_waveform) + 1e-5).log()
        outputs.append(_spectrogram_image(mel_spec_db))

        # 3) Frequency & time masking (no text, no axes)
        masked_mel_spec = mel_augment(mel_spec_db)
        outputs.append(_spectrogram_image(masked_mel_spec))

        # 4) MixStyle, forced on with p=1.0. NOTE: the batch has one sample,
        # so MixStyle reproduces its input and this image matches step 3.
        x_mix = mixstyle(masked_mel_spec.unsqueeze(0), p=1.0)
        outputs.append(_spectrogram_image(x_mix))

        # 5) Model prediction on the clean (non-augmented) resampled audio
        with torch.no_grad():
            x = (mel_spectrogram(waveform_resampled) + 1e-5).log().unsqueeze(0)
            predicted_idx = model(x).argmax(dim=1).item()
        outputs.append(label_ids[predicted_idx])

    # Clear the output groups of files that were not found.
    for _ in range(MAX_FILES - len(matching_files)):
        outputs += _placeholder_group("")
    return outputs


# --------------------- The Main Gradio Demo --------------------- #
with gr.Blocks(js=js_func) as demo:
    gr.Markdown("# ASCDomain")
    gr.Markdown("""
**ASCDomain:** Domain Invariant Device-Self-Challenging Isotopic Convolutional Neural Architecture

[ASCDomain repository](https://github.com/hubtru/ASCDomain)

Explore different acoustic scenes and mobile devices in our latest model.

**Options:**
- **Acoustic Scene:** Airport, Indoor shopping mall, Metro station, Pedestrian street, Public square, Street with medium level of traffic, Travelling by tram, Travelling by bus, Travelling by underground metro, Urban park
- **Mobile Device:** a, b, c, s1, s2, s3
""")

    with gr.Row():
        # Left column: dropdowns
        with gr.Column():
            label_dropdown = gr.Dropdown(
                choices=label_ids,
                label="Select Label",
                interactive=True
            )
            device_dropdown = gr.Dropdown(
                choices=device_ids,
                label="Select Device",
                interactive=True
            )

        # Right column: OUTPUTS_PER_FILE components per file, in exactly the
        # order process_audio returns its values.
        with gr.Column():
            outputs = []
            for i in range(1, MAX_FILES + 1):
                outputs += [
                    gr.Audio(label=f"Original Audio {i}"),
                    gr.Audio(label=f"Augmented Audio {i}"),
                    gr.Image(label=f"Waveform Plot {i}"),
                    gr.Image(label=f"Mel-Spectrogram {i}"),
                    gr.Image(label=f"Frequency and Time Masking {i}"),
                    gr.Image(label=f"MixStyle {i}"),
                    gr.Textbox(label=f"Predicted Class {i}"),
                ]

    # Trigger process_audio automatically when either dropdown changes
    for dropdown in (label_dropdown, device_dropdown):
        dropdown.change(
            fn=process_audio,
            inputs=[label_dropdown, device_dropdown],
            outputs=outputs
        )

if __name__ == "__main__":
    demo.launch()