# ASCDomain-d / app.py
# hubtru's picture
# Update app.py
# b64e169 verified
import gradio as gr
import os
import numpy as np
import torch
import torch.nn as nn
import torchaudio
import matplotlib.pyplot as plt
import io
from PIL import Image
# --------------------- Force Dark Mode JavaScript --------------------- #
# JavaScript snippet handed to `gr.Blocks(js=...)` further down in this file.
# It reads the `__theme` query parameter of the current page URL and, when it
# is anything other than 'dark', sets it to 'dark' and navigates to the
# updated URL. Once the parameter is 'dark' the condition is false, so the
# page is not redirected again.
# NOTE: the string below is runtime text (JavaScript), not Python; keep it
# byte-for-byte as is.
js_func = """
function refresh() {
const url = new URL(window.location);
if (url.searchParams.get('__theme') !== 'dark') {
url.searchParams.set('__theme', 'dark');
window.location.href = url.href;
}
}
"""
# ---------------------------------------------------------------------- #
# Device and label IDs
# `device_ids` and `label_ids` populate the two dropdowns of the demo.
device_ids = ['a', 'b', 'c', 's1', 's2', 's3']
# The position of each label is significant: the predicted class index
# returned by the model is mapped back to a name via `label_ids[index]`.
label_ids = [
    'airport', 'bus', 'metro', 'metro_station', 'park',
    'public_square', 'shopping_mall', 'street_pedestrian',
    'street_traffic', 'tram'
]

# Directories
audio_dir = os.path.join('demo', 'audio')
ir_dir = os.path.join('demo', 'impulse_responses')
ir_names = ['Altec_639.wav', 'Altec_670A.wav', 'Altec_670B.wav']

# Sample rates: demo recordings are assumed to be stored at
# `orig_sample_rate` and every processing step works at `sample_rate`.
orig_sample_rate = 44100
sample_rate = 32000
resample = torchaudio.transforms.Resample(orig_freq=orig_sample_rate, new_freq=sample_rate)

# Load impulse responses
# The impulse responses are convolved with audio that has already been
# resampled to `sample_rate`, so they are brought to that same rate here.
# Convolving signals sampled at different rates would stretch the response
# in time and shift its frequency content.
irs = []
for ir_name in ir_names:
    ir_path = os.path.join(ir_dir, ir_name)
    ir, ir_sample_rate = torchaudio.load(ir_path)
    if ir_sample_rate != sample_rate:
        ir = torchaudio.functional.resample(ir, ir_sample_rate, sample_rate)
    irs.append(ir)

# Transforms and spectrogram parameters
n_fft = 4096
window_length = 3072
hop_length = 500
n_mels = 256
f_min = 0
f_max = None
mel_spectrogram = torchaudio.transforms.MelSpectrogram(
    sample_rate=sample_rate,
    n_fft=n_fft,
    win_length=window_length,
    hop_length=hop_length,
    n_mels=n_mels,
    f_min=f_min,
    f_max=f_max
)

# Augmentation transforms
# `freqm` / `timem` are the maximum mask widths handed to the torchaudio
# masking transforms; a time mask parameter of 0 is passed as configured.
freqm = 48
timem = 0
freq_mask = torchaudio.transforms.FrequencyMasking(freqm, iid_masks=True)
time_mask = torchaudio.transforms.TimeMasking(timem, iid_masks=True)
# Frequency masking is applied first, then time masking.
mel_augment = torch.nn.Sequential(
    freq_mask,
    time_mask
)
def mixstyle(x, p=0.4, alpha=0.3, eps=1e-6):
    """
    MixStyle data augmentation.

    With probability ``1 - p`` (decided by a NumPy uniform draw) the input is
    returned untouched. Otherwise each sample is normalised with its own mean
    and standard deviation, computed jointly over dimensions 1 and 3, and
    re-scaled with a convex combination of its own statistics and those of a
    randomly chosen other sample of the batch.

    Args:
        x: 4-D tensor whose first dimension is the batch dimension.
        p: Probability of applying the augmentation.
        alpha: Concentration of the symmetric Beta distribution the mixing
            weights are drawn from.
        eps: Added to the variance before the square root for stability.

    Returns:
        ``x`` itself when the augmentation is skipped, otherwise a new tensor
        of the same shape with mixed statistics.
    """
    if np.random.rand() > p:
        return x

    n_samples = x.size(0)
    stat_dims = [1, 3]

    # The statistics are treated as constants: no gradient flows through them.
    mean = x.mean(dim=stat_dims, keepdim=True).detach()
    std = (x.var(dim=stat_dims, keepdim=True) + eps).sqrt().detach()
    normalized = (x - mean) / std

    # Partner sample for every batch entry, then one mixing weight per sample.
    shuffle = torch.randperm(n_samples)
    beta = torch.distributions.Beta(alpha, alpha)
    weight = beta.sample((n_samples, 1, 1, 1)).to(x.device)

    mixed_mean = mean * weight + mean[shuffle] * (1 - weight)
    mixed_std = std * weight + std[shuffle] * (1 - weight)
    return normalized * mixed_std + mixed_mean
# Model definition
class Residual(nn.Module):
    """Skip-connection wrapper: computes ``fn(x) + x`` for a wrapped module.

    The wrapped module is stored as ``self.fn``, so its parameters appear
    under the ``fn.`` prefix in the state dict.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        transformed = self.fn(x)
        return transformed + x
def ConvMixer(in_channels, filter, depth, kernel_size, patch_size, n_classes):
    """Build a ConvMixer-style classifier as a flat ``nn.Sequential``.

    Layout, in order:
      1. Patch embedding: a convolution with kernel and stride ``patch_size``,
         followed by GELU and batch normalisation.
      2. ``depth`` mixing blocks, each consisting of a residual depthwise
         convolution (``groups=filter``, "same" padding) and a 1x1
         convolution, both followed by GELU and batch normalisation.
      3. Global average pooling, flattening and a linear classifier.

    Args:
        in_channels: Number of channels of the input tensor.
        filter: Channel width used throughout the network.
        depth: Number of mixing blocks.
        kernel_size: Kernel size of the depthwise convolutions.
        patch_size: Kernel size and stride of the patch embedding.
        n_classes: Number of output logits.

    Returns:
        The assembled ``nn.Sequential`` model.
    """
    # Modules are created in exactly the order they appear in the model so
    # that parameter initialisation order and state-dict indices are stable.
    layers = [
        nn.Conv2d(in_channels, filter, kernel_size=patch_size, stride=patch_size),
        nn.GELU(),
        nn.BatchNorm2d(filter),
    ]

    for _ in range(depth):
        depthwise = Residual(nn.Sequential(
            nn.Conv2d(filter, filter, kernel_size, groups=filter, padding="same"),
            nn.GELU(),
            nn.BatchNorm2d(filter),
        ))
        layers.append(nn.Sequential(
            depthwise,
            nn.Conv2d(filter, filter, kernel_size=1),
            nn.GELU(),
            nn.BatchNorm2d(filter),
        ))

    layers.append(nn.AdaptiveAvgPool2d((1, 1)))
    layers.append(nn.Flatten())
    layers.append(nn.Linear(filter, n_classes))
    return nn.Sequential(*layers)
# Instantiate and load the model
# Hyper-parameters of the ConvMixer; they have to match the checkpoint in
# `model_path`, otherwise `load_state_dict` rejects the weights.
in_channels = 1
filter = 64  # NOTE: shadows the builtin `filter`; name kept because it is part of the module namespace
depth = 9
kernel_size = 3
patch_size = 5
n_classes = 10
model = ConvMixer(in_channels, filter, depth, kernel_size, patch_size, n_classes)
model_path = 'model.pth'
if os.path.exists(model_path):
    # `weights_only=True` restricts unpickling to tensors and plain containers.
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'), weights_only=True))
else:
    print(f"Model file '{model_path}' not found. Please place the model file in the same directory.")
# The model is only ever used for inference, so switch to evaluation mode
# whether or not weights were found. Otherwise the BatchNorm layers would
# stay in training mode and update their running statistics on every
# prediction.
model.eval()
# Layout of the flat list returned by `process_audio`: for each of the
# `_MAX_FILES` file slots there are `_MEDIA_SLOTS_PER_FILE` media values
# (2 audio clips + 4 images) followed by one text value.
_MAX_FILES = 3
_MEDIA_SLOTS_PER_FILE = 6
_NO_MATCH_MESSAGE = "No matching audio files found"


def _empty_slots(text=""):
    """Return the values for one unused file slot: empty media plus `text`."""
    # Media components are cleared with None; a string would be taken for a file path.
    return [None] * _MEDIA_SLOTS_PER_FILE + [text]


def _find_matching_files(selected_label, selected_device, limit):
    """Return up to `limit` .wav names in `audio_dir` for the given scene and device."""
    matches = []
    for filename in os.listdir(audio_dir):
        if not filename.endswith('.wav'):
            continue
        parts = os.path.splitext(filename)[0].split('-')
        # Expected name: scene-city-x-y-z-device. Anything else is skipped
        # instead of failing while unpacking.
        if len(parts) != 6:
            continue
        scene, device = parts[0], parts[-1]
        if scene == selected_label and device == selected_device:
            matches.append(filename)
            if len(matches) >= limit:
                break
    return matches


def _figure_to_image(fig, **savefig_kwargs):
    """Render `fig` to an in-memory PNG, close the figure and return a PIL image."""
    buf = io.BytesIO()
    # Saving through the figure itself avoids relying on pyplot's global
    # "current figure".
    fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0, **savefig_kwargs)
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)


def _waveform_image(original, augmented, selected_label):
    """Plot the original and augmented waveforms on shared axes and return the image."""
    fig, ax = plt.subplots(figsize=(10, 3))
    ax.plot(original, label='normal')
    ax.plot(augmented, label='augmented', linestyle='-.', alpha=0.8)
    ax.set_title(f'Label: {selected_label}')
    ax.legend()
    ax.set_xlabel('Time Samples')
    ax.set_ylabel('Amplitude')
    return _figure_to_image(fig, dpi=100)


def _spectrogram_image(spec):
    """Render a spectrogram tensor as a bare image (no text, no axes)."""
    fig, ax = plt.subplots(figsize=(1, 4), dpi=100)
    ax.imshow(spec.squeeze().numpy(), origin='lower', aspect='auto', interpolation='nearest')
    ax.axis('off')
    return _figure_to_image(fig)


def process_audio(selected_label, selected_device):
    """Build the demo outputs for up to three recordings of a scene/device pair.

    For each matching file in `audio_dir` the function produces, in order:
    the resampled original audio, the audio convolved with a randomly chosen
    impulse response, a waveform comparison plot, the log-mel spectrogram of
    the augmented audio, that spectrogram after frequency/time masking, the
    masked spectrogram after MixStyle, and the label predicted by the model
    for the (un-augmented) resampled audio.

    Args:
        selected_label: Acoustic scene chosen in the UI (first filename field).
        selected_device: Recording device chosen in the UI (last filename field).

    Returns:
        A flat list of 21 values (3 file slots x 7 values). Audio values are
        ``(sample_rate, numpy_array)`` tuples and images are PIL images.
        Unused slots hold ``None`` for audio/image values and ``""`` for the
        text value; when no file matches at all, every text value carries
        the "No matching audio files found" message.
    """
    matching_files = _find_matching_files(selected_label, selected_device, _MAX_FILES)

    # If no matching files, return placeholders
    if not matching_files:
        return _empty_slots(_NO_MATCH_MESSAGE) * _MAX_FILES

    outputs = []
    for audio_file in matching_files:
        # Load and resample audio
        waveform, sr = torchaudio.load(os.path.join(audio_dir, audio_file))
        if sr == orig_sample_rate:
            waveform_resampled = resample(waveform)
        else:
            # The shared transform is only valid for `orig_sample_rate`
            # input; honour the rate the file actually declares.
            waveform_resampled = torchaudio.functional.resample(waveform, sr, sample_rate)
        original_samples = waveform_resampled.squeeze().numpy()
        outputs.append((sample_rate, original_samples))

        # Augment audio with a random impulse response
        ir = irs[np.random.randint(len(irs))]
        augmented_waveform = torchaudio.functional.convolve(waveform_resampled, ir)
        # Ensure augmented waveform isn't longer than the original
        augmented_waveform = augmented_waveform[:, :waveform_resampled.shape[1]]
        augmented_samples = augmented_waveform.squeeze().numpy()
        outputs.append((sample_rate, augmented_samples))

        # 1) Waveform Plot
        outputs.append(_waveform_image(original_samples, augmented_samples, selected_label))

        # 2) Mel-Spectrogram (no text, no axes)
        mel_spec = mel_spectrogram(augmented_waveform)
        mel_spec_db = (mel_spec + 1e-5).log()
        outputs.append(_spectrogram_image(mel_spec_db))

        # 3) Frequency & Time Masking (no text, no axes)
        masked_mel_spec = mel_augment(mel_spec_db)
        outputs.append(_spectrogram_image(masked_mel_spec))

        # 4) MixStyle (no text, no axes)
        # NOTE: the batch holds a single spectrogram, so `mixstyle` can only
        # mix its statistics with themselves (the permutation of a batch of
        # one is the identity); the image is expected to match step 3 closely.
        x_mix = mixstyle(masked_mel_spec.unsqueeze(0), p=1.0)
        outputs.append(_spectrogram_image(x_mix))

        # 5) Model Prediction (on the resampled, un-augmented audio)
        with torch.no_grad():
            x = mel_spectrogram(waveform_resampled)
            x = (x + 1e-5).log().unsqueeze(0)
            y_hat = model(x)
        predicted_idx = y_hat.argmax(dim=1).item()
        outputs.append(f"{label_ids[predicted_idx]}")

    # Pad outputs if fewer than 3 audio files are found
    for _ in range(_MAX_FILES - len(matching_files)):
        outputs.extend(_empty_slots())
    return outputs
# --------------------- The Main Gradio Demo --------------------- #
# Page layout: a title and description on top, then one row with the two
# selectors on the left and, on the right, three identical groups of output
# components (one group per audio file returned by `process_audio`).
with gr.Blocks(js=js_func) as demo:
    gr.Markdown("# ASCDomain")
    gr.Markdown("""
**ASCDomain:** Domain Invariant Device-Self-Challenging Isotopic Convolutional Neural Architecture
[ASCDomain repository](https://github.com/hubtru/ASCDomain)
Explore different acoustic scenes and mobile devices in our latest model.
**Options:**
- **Acoustic Scene:** Airport, Indoor shopping mall, Metro station, Pedestrian street, Public square, Street with medium level of traffic, Travelling by tram, Travelling by bus, Travelling by underground metro, Urban park
- **Mobile Device:** a, b, c, s1, s2, s3
""")
    with gr.Row():
        # Left column: dropdowns
        with gr.Column():
            label_dropdown = gr.Dropdown(
                choices=label_ids,
                label="Select Label",
                interactive=True
            )
            device_dropdown = gr.Dropdown(
                choices=device_ids,
                label="Select Device",
                interactive=True
            )
        # Right column: outputs
        # The component order inside each group must mirror the order of the
        # values `process_audio` emits per file.
        with gr.Column():
            outputs = []
            for slot in range(1, 4):
                outputs.extend([
                    gr.Audio(label=f"Original Audio {slot}"),
                    gr.Audio(label=f"Augmented Audio {slot}"),
                    gr.Image(label=f"Waveform Plot {slot}"),
                    gr.Image(label=f"Mel-Spectrogram {slot}"),
                    gr.Image(label=f"Frequency and Time Masking {slot}"),
                    gr.Image(label=f"MixStyle {slot}"),
                    gr.Textbox(label=f"Predicted Class {slot}"),
                ])
    # Trigger process_audio automatically when either dropdown changes
    for dropdown in (label_dropdown, device_dropdown):
        dropdown.change(
            fn=process_audio,
            inputs=[label_dropdown, device_dropdown],
            outputs=outputs
        )
demo.launch()