import gradio as gr
import os
import numpy as np
import torch
import torch.nn as nn
import torchaudio
import matplotlib.pyplot as plt
import io
from PIL import Image

# Device and label IDs
device_ids = ['a', 'b', 'c', 's1', 's2', 's3']
label_ids = ['airport', 'bus', 'metro', 'metro_station', 'park', 'public_square',
             'shopping_mall', 'street_pedestrian', 'street_traffic', 'tram']

# Directories
audio_dir = os.path.join('demo', 'audio')
ir_dir = os.path.join('demo', 'impulse_responses')
ir_names = ['Altec_639.wav', 'Altec_670A.wav', 'Altec_670B.wav']

# Resampling: the demo clips are assumed to be stored at orig_sample_rate and
# are resampled to sample_rate, the rate the spectrogram and model operate at.
orig_sample_rate = 44100
sample_rate = 32000
resample = torchaudio.transforms.Resample(
    orig_freq=orig_sample_rate, new_freq=sample_rate
)

# Load the impulse responses used for the device-response augmentation.
# They are convolved with audio that has already been resampled to
# `sample_rate`, so each IR is brought to that same rate (otherwise its
# frequency response would be shifted), and multi-channel IRs are downmixed
# to mono so the convolved signal keeps the audio's channel count.
irs = []
for ir_name in ir_names:
    ir_path = os.path.join(ir_dir, ir_name)
    ir, ir_sample_rate = torchaudio.load(ir_path)
    if ir.shape[0] > 1:
        ir = ir.mean(dim=0, keepdim=True)
    if ir_sample_rate != sample_rate:
        ir = torchaudio.functional.resample(
            ir, orig_freq=ir_sample_rate, new_freq=sample_rate
        )
    irs.append(ir)

# Mel-spectrogram parameters
n_fft = 4096
window_length = 3072
hop_length = 500
n_mels = 256
f_min = 0
f_max = None

mel_spectrogram = torchaudio.transforms.MelSpectrogram(
    sample_rate=sample_rate,
    n_fft=n_fft,
    win_length=window_length,
    hop_length=hop_length,
    n_mels=n_mels,
    f_min=f_min,
    f_max=f_max,
)

# SpecAugment-style masking. A mask parameter of 0 (timem) means no time
# bins are masked, so only frequency masking has a visible effect.
freqm = 48
timem = 0
freq_mask = torchaudio.transforms.FrequencyMasking(freqm, iid_masks=True)
time_mask = torchaudio.transforms.TimeMasking(timem, iid_masks=True)
mel_augment = torch.nn.Sequential(
    freq_mask,
    time_mask,
)


# Mixstyle function
def mixstyle(x, p=0.4, alpha=0.3, eps=1e-6):
    """Mix per-example feature statistics across a batch (MixStyle).

    With probability ``p`` the mean and standard deviation of each example,
    computed over dims 1 and 3, are replaced by a convex combination of its
    own statistics and those of a randomly permuted example in the batch.
    Otherwise ``x`` is returned unchanged.

    Args:
        x: 4-D tensor; presumably (batch, channels, freq, time) - TODO confirm.
        p: Probability of applying the mixing.
        alpha: Concentration of the Beta(alpha, alpha) mixing weights.
        eps: Added to the variance for numerical stability.

    Returns:
        A tensor with the same shape as ``x``.

    Note:
        With a batch of one, the permutation is the identity. The output then
        equals the input up to floating-point error.
    """
    if np.random.rand() > p:
        return x
    batch_size = x.size(0)
    f_mu = x.mean(dim=[1, 3], keepdim=True)
    f_var = x.var(dim=[1, 3], keepdim=True)
    f_sig = (f_var + eps).sqrt()
    # Statistics are treated as constants so no gradient flows through them.
    f_mu, f_sig = f_mu.detach(), f_sig.detach()
    x_normed = (x - f_mu) / f_sig
    perm = torch.randperm(batch_size)
    f_mu_perm, f_sig_perm = f_mu[perm], f_sig[perm]
    lmda = torch.distributions.Beta(alpha, alpha).sample((batch_size, 1, 1, 1))
    lmda = lmda.to(x.device)
    mu_mix = f_mu * lmda + f_mu_perm * (1 - lmda)
    sig_mix = f_sig * lmda + f_sig_perm * (1 - lmda)
    x = x_normed * sig_mix + mu_mix
    return x
# Model definition
class Residual(nn.Module):
    """Wrap a module ``fn`` so that ``forward(x)`` returns ``fn(x) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x


def ConvMixer(in_channels, filter, depth, kernel_size, patch_size, n_classes):
    """Build a ConvMixer classifier as an ``nn.Sequential``.

    The network has three parts:

    * A patch-embedding convolution whose kernel and stride are both
      ``patch_size``.
    * ``depth`` blocks, each a residual depthwise convolution followed by a
      pointwise (1x1) convolution, each with GELU + BatchNorm.
    * Global average pooling and a linear head with ``n_classes`` outputs.

    The module nesting must stay exactly like this so that saved state-dict
    keys still match. ``filter`` (the hidden channel count) shadows the builtin
    but is kept for compatibility with keyword callers.
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, filter, kernel_size=patch_size, stride=patch_size),
        nn.GELU(),
        nn.BatchNorm2d(filter),
        *[
            nn.Sequential(
                Residual(nn.Sequential(
                    nn.Conv2d(filter, filter, kernel_size, groups=filter, padding="same"),
                    nn.GELU(),
                    nn.BatchNorm2d(filter),
                )),
                nn.Conv2d(filter, filter, kernel_size=1),
                nn.GELU(),
                nn.BatchNorm2d(filter),
            )
            for _ in range(depth)
        ],
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        nn.Linear(filter, n_classes),
    )


# Instantiate and load the model
# Model parameters (should match those used during training)
in_channels = 1
filter = 64
depth = 9
kernel_size = 3
patch_size = 5
n_classes = 10

model = ConvMixer(in_channels, filter, depth, kernel_size, patch_size, n_classes)
model_path = 'model.pth'  # Path to the saved model weights

# Load the model weights
if os.path.exists(model_path):
    model.load_state_dict(
        torch.load(model_path, map_location=torch.device('cpu'), weights_only=True)
    )
else:
    print(f"Model file '{model_path}' not found. "
          "Please place the model file in the same directory.")
    # Optionally, you can raise an exception or exit
    # raise FileNotFoundError(f"Model file '{model_path}' not found.")
# The demo only runs inference, so always switch to eval mode (BatchNorm uses
# its running statistics), even when no weights were loaded.
model.eval()

# Number of example clips shown per query and outputs produced per clip.
# These must agree with the output components built in gradio_interface.
_MAX_FILES = 3
_OUTPUTS_PER_FILE = 7


def _fig_to_image(fig):
    """Render a matplotlib figure to a PNG-backed PIL image and close it."""
    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def _spectrogram_image(spec, title):
    """Plot ``spec`` (squeezed to 2-D) as an image titled ``title``."""
    fig, ax = plt.subplots()
    ax.imshow(spec.squeeze().numpy(), origin='lower', aspect='auto')
    ax.set_title(title)
    ax.axis('off')
    return _fig_to_image(fig)


def _to_model_rate(waveform, sr):
    """Resample ``waveform`` from its native rate ``sr`` to ``sample_rate``."""
    if sr == orig_sample_rate:
        # Reuse the module-level transform built for orig_sample_rate.
        return resample(waveform)
    return torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=sample_rate)


def _find_matching_files(selected_label, selected_device, limit=_MAX_FILES):
    """Return up to ``limit`` .wav names in ``audio_dir`` matching scene and device.

    Names are split on '-'. Only names with exactly six fields are considered;
    the first field is the scene and the last is the device. Files are scanned
    in sorted order so the selection is deterministic.
    """
    matches = []
    for filename in sorted(os.listdir(audio_dir)):
        if not filename.endswith('.wav'):
            continue
        parts = os.path.splitext(filename)[0].split('-')
        # Exactly six fields. A "< 6" check would let longer names through
        # and crash the 6-way unpacking.
        if len(parts) != 6:
            continue
        scene, device = parts[0], parts[-1]
        if scene == selected_label and device == selected_device:
            matches.append(filename)
            if len(matches) >= limit:
                break
    return matches


def _process_file(audio_file, selected_label):
    """Return the seven outputs for one clip, in the order process_audio documents."""
    waveform, sr = torchaudio.load(os.path.join(audio_dir, audio_file))
    waveform_resampled = _to_model_rate(waveform, sr)
    original_np = waveform_resampled.squeeze().numpy()

    # Device augmentation: convolve with a randomly chosen impulse response,
    # then keep the original number of samples.
    ir = irs[np.random.randint(len(irs))]
    augmented_waveform = torchaudio.functional.convolve(
        waveform_resampled, ir
    )[:, :waveform_resampled.shape[1]]
    augmented_np = augmented_waveform.squeeze().numpy()

    # Waveform plot of original vs augmented
    fig, ax = plt.subplots()
    ax.plot(original_np, label='normal')
    ax.plot(augmented_np, label='augmented', linestyle='-.', alpha=0.8)
    ax.set_title(f'Label: {selected_label}')
    ax.legend()
    ax.set_xlabel('Time Samples')
    ax.set_ylabel('Amplitude')
    waveform_plot = _fig_to_image(fig)

    # Natural-log mel spectrogram; 1e-5 avoids log(0).
    mel_spec_log = (mel_spectrogram(augmented_waveform) + 1e-5).log()
    masked_mel_spec = mel_augment(mel_spec_log)
    # NOTE(review): with a batch of one, the MixStyle permutation is the
    # identity, so this image matches the masked spectrogram up to float error.
    x_mix = mixstyle(masked_mel_spec.unsqueeze(0), p=1.0)

    # The prediction uses the clean (non-augmented) resampled audio.
    with torch.no_grad():
        x = (mel_spectrogram(waveform_resampled) + 1e-5).log().unsqueeze(0)
        predicted_idx = model(x).argmax(dim=1).item()

    return [
        (sample_rate, original_np),
        (sample_rate, augmented_np),
        waveform_plot,
        _spectrogram_image(mel_spec_log, 'Mel-Spectrogram'),
        _spectrogram_image(masked_mel_spec, 'Frequency and Time Masking'),
        _spectrogram_image(x_mix, 'MixStyle'),
        f"Predicted Class: {label_ids[predicted_idx]}",
    ]


# Function to process audio and generate outputs
def process_audio(selected_label, selected_device):
    """Build the Gradio outputs for up to three clips matching label and device.

    Each matching clip produces seven outputs, in this order:

    1. Original audio.
    2. IR-augmented audio.
    3. Waveform overlay plot.
    4. Mel-spectrogram.
    5. Frequency/time-masked mel-spectrogram.
    6. MixStyle image.
    7. Predicted-class text.

    Slots for missing clips are filled with ``None``, which every output
    component accepts as "empty".
    """
    matching_files = _find_matching_files(selected_label, selected_device)
    if not matching_files:
        # Audio/Image components would treat a string as a file path, so the
        # message only goes to the prediction textboxes.
        message = "No matching audio files found"
        return ([None] * (_OUTPUTS_PER_FILE - 1) + [message]) * _MAX_FILES

    outputs = []
    for audio_file in matching_files:
        outputs.extend(_process_file(audio_file, selected_label))

    # If fewer than three files matched, pad the remaining slots.
    outputs += [None] * (_MAX_FILES * _OUTPUTS_PER_FILE - len(outputs))
    return outputs


def gradio_interface():
    """Build the themed Gradio Interface around process_audio and launch it."""
    theme = gr.themes.Base(
        primary_hue="blue",
        secondary_hue="blue",
        neutral_hue="gray",
    )
    theme.set(
        body_background_fill="*primary_50",
        body_background_fill_dark="*checkbox_background_color_focus",
        body_text_color_dark="white",
        body_text_color="*neutral_800",
        background_fill_secondary_dark="*checkbox_border_color_hover",
        block_background_fill="*background_fill_primary",
        block_background_fill_dark="*neutral_800",
        block_border_color_dark="*primary_100",
        block_border_width_dark="4px",
        block_border_width="4px",
        block_border_color="*secondary_200",
    )

    # One group of seven components per clip, in process_audio's output order.
    output_components = []
    for i in range(1, _MAX_FILES + 1):
        output_components += [
            gr.Audio(label=f"Original Audio {i}"),
            gr.Audio(label=f"Augmented Audio {i}"),
            gr.Image(label=f"Waveform Plot {i}"),
            gr.Image(label=f"Mel-Spectrogram {i}"),
            gr.Image(label=f"Frequency and Time Masking {i}"),
            gr.Image(label=f"MixStyle {i}"),
            gr.Textbox(label=f"Predicted Class {i}"),
        ]

    interface = gr.Interface(
        fn=process_audio,
        inputs=[
            gr.Dropdown(choices=label_ids, label="Select Label"),
            gr.Dropdown(choices=device_ids, label="Select Device"),
        ],
        outputs=output_components,
        # The original title literal was split across lines (a syntax error);
        # only "ASCDomain" survived, so that is used as the title.
        title="ASCDomain",
        # NOTE(review): "Isotopic" may be a typo for "Isotropic" (ConvMixer is
        # an isotropic architecture). Confirm against the report title before
        # changing it.
        description=(
            "\n"
            "ASCDomain: Domain Invariant Device-Self-Challenging Isotopic Convolutional Neural Architecture\n"
            "ASCDomain repository: ASCDomain\n"
            "Explore different acoustic scenes and mobile devices in our latest model.\n"
            "Options:\n"
        ),
        theme=theme,
        live=True,
        allow_flagging="never",
    )
    interface.launch()


gradio_interface()