# ASCDomain / app.py
# Last commit 129057a ("Update app.py") by hubtru (verified).
import gradio as gr
import os
import numpy as np
import torch
import torch.nn as nn
import torchaudio
import matplotlib.pyplot as plt
import io
from PIL import Image
# Recording-device identifiers offered in the UI; they are compared against the
# last '-'-separated field of the demo audio file names.
device_ids = 'a b c s1 s2 s3'.split()
# Acoustic scene labels. Order matters: the model's argmax index is mapped to a
# name via label_ids[index].
label_ids = (
    'airport bus metro metro_station park public_square shopping_mall '
    'street_pedestrian street_traffic tram'
).split()
# Directories. Resolved relative to this script rather than the current working
# directory, so the demo assets are found wherever the app is launched from.
_app_dir = os.path.dirname(os.path.abspath(__file__))
audio_dir = os.path.join(_app_dir, 'demo', 'audio')
ir_dir = os.path.join(_app_dir, 'demo', 'impulse_responses')
ir_names = ['Altec_639.wav', 'Altec_670A.wav', 'Altec_670B.wav']
# Load the impulse responses once at start-up. torchaudio.load returns
# (waveform, sample_rate); only the waveform is kept. A comprehension avoids
# leaking loop temporaries into the module namespace.
# NOTE(review): the IR sample rate is discarded, and the IRs are later convolved
# with audio resampled to 32 kHz — confirm the IR files are stored at that rate.
irs = [torchaudio.load(os.path.join(ir_dir, ir_name))[0] for ir_name in ir_names]
# Resampling: a single pre-built resampler converts 44.1 kHz input to the
# 32 kHz rate that every downstream transform and the model use.
orig_sample_rate = 44100
sample_rate = 32000
resample = torchaudio.transforms.Resample(orig_freq=orig_sample_rate, new_freq=sample_rate)

# Mel-spectrogram front-end (hop of 500 samples = 15.625 ms at 32 kHz).
# f_max=None leaves the upper band edge to torchaudio's default.
n_fft, window_length, hop_length = 4096, 3072, 500
n_mels, f_min, f_max = 256, 0, None
mel_spectrogram = torchaudio.transforms.MelSpectrogram(
    sample_rate=sample_rate,
    n_fft=n_fft,
    win_length=window_length,
    hop_length=hop_length,
    n_mels=n_mels,
    f_min=f_min,
    f_max=f_max,
)

# SpecAugment-style masking applied to the log-mel spectrogram: up to `freqm`
# mel bins are masked. With timem == 0 the time mask has zero width, so time
# masking is effectively disabled. iid_masks=True draws independent masks per
# example/channel when given 4-D batches.
freqm, timem = 48, 0
freq_mask = torchaudio.transforms.FrequencyMasking(freqm, iid_masks=True)
time_mask = torchaudio.transforms.TimeMasking(timem, iid_masks=True)
mel_augment = torch.nn.Sequential(freq_mask, time_mask)
def mixstyle(x, p=0.4, alpha=0.3, eps=1e-6):
    """Randomly mix per-sample feature statistics across a batch (MixStyle).

    With probability ``1 - p`` the input is returned unchanged (the very same
    object). Otherwise:

    1. Mean and standard deviation are taken over dims 1 and 3, separately for
       each index along dims 0 and 2, and detached.
    2. ``x`` is normalized with them.
    3. ``x`` is re-scaled with a convex blend of its own statistics and those
       of a randomly permuted batch partner. The blend weights come from
       ``Beta(alpha, alpha)``.

    Args:
        x: 4-D tensor. Presumably (batch, channels, freq, time), which makes
            this frequency-wise MixStyle — TODO confirm against the caller.
        p: probability of applying the mixing.
        alpha: concentration parameter of the Beta distribution.
        eps: added to the variance before the square root for stability.

    Returns:
        Tensor with the same shape as ``x``.
    """
    # Skip branch: one uniform draw compared against p.
    if np.random.rand() > p:
        return x

    batch = x.size(0)
    stat_dims = [1, 3]
    # Statistics are detached so no gradient flows through the style swap.
    mu = x.mean(dim=stat_dims, keepdim=True).detach()
    sigma = (x.var(dim=stat_dims, keepdim=True) + eps).sqrt().detach()
    normalized = (x - mu) / sigma

    # Pair every sample with a random partner and draw one blend weight each.
    partner = torch.randperm(batch)
    weight = torch.distributions.Beta(alpha, alpha).sample((batch, 1, 1, 1)).to(x.device)
    keep = 1 - weight
    mixed_mu = mu * weight + mu[partner] * keep
    mixed_sigma = sigma * weight + sigma[partner] * keep
    return normalized * mixed_sigma + mixed_mu
# Model definition
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
def forward(self, x):
return self.fn(x) + x
def ConvMixer(in_channels, filter, depth, kernel_size, patch_size, n_classes):
    """Build a ConvMixer classifier as one flat ``nn.Sequential``.

    The layer indices form the ``state_dict`` keys that saved weights are loaded
    against, so the layout must not change:

    * 0-2: patch embedding — Conv2d(kernel = stride = patch_size), GELU, BatchNorm2d.
    * 3 .. 3+depth-1: ``depth`` mixer blocks. Each block is
      ``Residual(depthwise Conv2d(kernel_size, groups=filter, padding="same"), GELU, BN)``
      followed by a pointwise 1x1 Conv2d, GELU and BN.
    * Then AdaptiveAvgPool2d((1, 1)), Flatten, and Linear(filter -> n_classes).

    Args:
        in_channels: channels of the input image/spectrogram.
        filter: hidden channel width (name kept for caller compatibility).
        depth: number of mixer blocks.
        kernel_size: depthwise convolution kernel size.
        patch_size: patch-embedding kernel size and stride.
        n_classes: number of output logits.

    Returns:
        ``nn.Sequential`` mapping (batch, in_channels, H, W) to (batch, n_classes).
    """
    width = filter  # local alias; the parameter name shadows the builtin

    def mixer_block():
        # Depthwise spatial mixing inside a skip connection...
        depthwise = nn.Sequential(
            nn.Conv2d(width, width, kernel_size, groups=width, padding="same"),
            nn.GELU(),
            nn.BatchNorm2d(width),
        )
        # ...followed by pointwise channel mixing.
        return nn.Sequential(
            Residual(depthwise),
            nn.Conv2d(width, width, kernel_size=1),
            nn.GELU(),
            nn.BatchNorm2d(width),
        )

    layers = [
        nn.Conv2d(in_channels, width, kernel_size=patch_size, stride=patch_size),
        nn.GELU(),
        nn.BatchNorm2d(width),
    ]
    layers.extend(mixer_block() for _ in range(depth))
    layers += [
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        nn.Linear(width, n_classes),
    ]
    return nn.Sequential(*layers)
# Instantiate and load the model.
# Architecture hyper-parameters: these must match the configuration the
# checkpoint was trained with, otherwise load_state_dict fails on key/shape
# mismatches.
in_channels = 1
n_filters = 64  # channel width; named so it does not shadow the builtin filter()
depth = 9
kernel_size = 3
patch_size = 5
n_classes = len(label_ids)  # one logit per scene label; predictions index label_ids
model = ConvMixer(in_channels, n_filters, depth, kernel_size, patch_size, n_classes)

# Saved weights are expected next to this script (not in the current working
# directory), as the hint printed below says.
model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'model.pth')
if os.path.exists(model_path):
    # weights_only=True refuses arbitrary pickled objects; map_location lets a
    # GPU-saved checkpoint load on a CPU-only host.
    model.load_state_dict(
        torch.load(model_path, map_location=torch.device('cpu'), weights_only=True)
    )
else:
    print(f"Model file '{model_path}' not found. Please place the model file in the same directory.")
# Switch to inference mode unconditionally, so BatchNorm uses running statistics
# (and does not update them on every request) even when no checkpoint was found.
model.eval()
# Function to process audio and generate outputs.
# The UI shows up to _MAX_FILES example files with _OUTPUTS_PER_FILE components
# each; both numbers must agree with the output list in gradio_interface().
_MAX_FILES = 3
# original audio, augmented audio, waveform plot, mel-spectrogram,
# masked mel-spectrogram, MixStyle image, predicted class
_OUTPUTS_PER_FILE = 7
_NO_MATCH_MESSAGE = "No matching audio files found"


def _figure_to_image(fig):
    """Render ``fig`` to an in-memory PNG, close it, and return a PIL image."""
    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    plt.close(fig)  # pyplot keeps figures alive until they are closed
    buf.seek(0)
    return Image.open(buf)


def _spectrogram_image(spec, title):
    """Plot a spectrogram tensor (squeezed to 2-D) as an axis-less image."""
    fig, ax = plt.subplots()
    ax.imshow(spec.squeeze().numpy(), origin='lower', aspect='auto')
    ax.set_title(title)
    ax.axis('off')
    return _figure_to_image(fig)


def _waveform_image(original, augmented, title):
    """Overlay the clean and impulse-response-convolved waveforms in one plot."""
    fig, ax = plt.subplots()
    ax.plot(original.squeeze().numpy(), label='normal')
    ax.plot(augmented.squeeze().numpy(), label='augmented', linestyle='-.', alpha=0.8)
    ax.set_title(title)
    ax.legend()
    ax.set_xlabel('Time Samples')
    ax.set_ylabel('Amplitude')
    return _figure_to_image(fig)


def _find_matching_files(selected_label, selected_device, limit=_MAX_FILES):
    """Return up to ``limit`` .wav names in audio_dir for the scene/device pair.

    Names must have exactly six '-'-separated fields: the first is the scene
    and the last is the device. Other names are skipped instead of crashing the
    unpacking. The listing is sorted so the selection is stable across requests.
    """
    matches = []
    for filename in sorted(os.listdir(audio_dir)):
        if not filename.endswith('.wav'):
            continue
        parts = os.path.splitext(filename)[0].split('-')
        if len(parts) != 6:
            continue
        if parts[0] == selected_label and parts[-1] == selected_device:
            matches.append(filename)
            if len(matches) >= limit:
                break
    return matches


def _load_waveform(audio_path):
    """Load ``audio_path`` as a mono (1, samples) waveform at ``sample_rate``."""
    waveform, sr = torchaudio.load(audio_path)  # (channels, samples), file rate
    if waveform.size(0) > 1:
        # The players, the plots and the 1-input-channel model all expect mono.
        waveform = waveform.mean(dim=0, keepdim=True)
    if sr == orig_sample_rate:
        return resample(waveform)  # pre-built 44.1 kHz -> 32 kHz resampler
    # Any other file rate would be mis-converted by the fixed resampler.
    return torchaudio.functional.resample(waveform, sr, sample_rate)


def _process_file(audio_file, selected_label):
    """Produce the _OUTPUTS_PER_FILE UI values for one demo file."""
    waveform = _load_waveform(os.path.join(audio_dir, audio_file))
    n_samples = waveform.shape[1]

    # Device augmentation: convolve with a random impulse response and trim
    # back to the input length.
    # NOTE(review): a multi-channel IR would broadcast the result to several
    # channels — confirm the IR files are mono.
    ir = irs[np.random.randint(len(irs))]
    augmented = torchaudio.functional.convolve(waveform, ir)[:, :n_samples]

    log_mel = (mel_spectrogram(augmented) + 1e-5).log()  # natural log, not dB
    masked = mel_augment(log_mel)
    # NOTE(review): with a one-item batch mixstyle can only pair the sample with
    # itself, so this image is essentially the masked spectrogram.
    mixed = mixstyle(masked.unsqueeze(0), p=1.0)

    # The prediction uses the clean (non-augmented) audio. eval() makes sure
    # BatchNorm uses running statistics even if no checkpoint was loaded.
    model.eval()
    with torch.no_grad():
        features = (mel_spectrogram(waveform) + 1e-5).log().unsqueeze(0)
        predicted_idx = model(features).argmax(dim=1).item()

    return [
        (sample_rate, waveform.squeeze().numpy()),
        (sample_rate, augmented.squeeze().numpy()),
        _waveform_image(waveform, augmented, f'Label: {selected_label}'),
        _spectrogram_image(log_mel, 'Mel-Spectrogram'),
        _spectrogram_image(masked, 'Frequency and Time Masking'),
        _spectrogram_image(mixed, 'MixStyle'),
        f"Predicted Class: {label_ids[predicted_idx]}",
    ]


def process_audio(selected_label, selected_device):
    """Gradio callback: build every output for a scene label / device pair.

    Picks up to _MAX_FILES demo files. For each one it produces:
    original audio, IR-augmented audio, waveform plot, log-mel spectrogram,
    masked spectrogram, MixStyle spectrogram and the predicted class.

    Always returns exactly _MAX_FILES * _OUTPUTS_PER_FILE values.
    - Unused slots are None, which Gradio shows as empty.
    - When nothing matches (including unselected dropdowns), only the text
      slots carry a message. Gradio interprets strings given to Audio/Image
      outputs as file paths, so those slots get None instead.
    """
    matching_files = _find_matching_files(selected_label, selected_device)
    if not matching_files:
        placeholder = [None] * (_OUTPUTS_PER_FILE - 1) + [_NO_MATCH_MESSAGE]
        return placeholder * _MAX_FILES

    outputs = []
    for audio_file in matching_files:
        outputs.extend(_process_file(audio_file, selected_label))
    # Fewer than _MAX_FILES matches: blank out the remaining components.
    outputs += [None] * (_MAX_FILES * _OUTPUTS_PER_FILE - len(outputs))
    return outputs
def gradio_interface():
    """Build the ASCDomain Gradio demo around ``process_audio`` and launch it.

    Inputs are two dropdowns: a scene label from ``label_ids`` and a device
    from ``device_ids``. The callback's 21 return values are shown as three
    groups of seven components:
    original audio, augmented audio, waveform plot, mel-spectrogram,
    masked mel-spectrogram, MixStyle image and predicted class.
    ``live=True`` re-runs the callback whenever an input changes.
    """
    theme = gr.themes.Base(primary_hue="blue", secondary_hue="blue", neutral_hue="gray")
    theme_overrides = {
        "body_background_fill": "*primary_50",
        "body_background_fill_dark": "*checkbox_background_color_focus",
        "body_text_color_dark": "white",
        "body_text_color": "*neutral_800",
        "background_fill_secondary_dark": "*checkbox_border_color_hover",
        "block_background_fill": "*background_fill_primary",
        "block_background_fill_dark": "*neutral_800",
        "block_border_color_dark": "*primary_100",
        "block_border_width_dark": "4px",
        "block_border_width": "4px",
        "block_border_color": "*secondary_200",
    }
    theme.set(**theme_overrides)

    input_components = [
        gr.Dropdown(choices=label_ids, label="Select Label"),
        gr.Dropdown(choices=device_ids, label="Select Device"),
    ]

    # One group of seven components per example file, in the order that
    # process_audio returns its values.
    output_components = []
    for idx in range(1, 4):
        output_components.extend([
            gr.Audio(label=f"Original Audio {idx}"),
            gr.Audio(label=f"Augmented Audio {idx}"),
            gr.Image(label=f"Waveform Plot {idx}"),
            gr.Image(label=f"Mel-Spectrogram {idx}"),
            gr.Image(label=f"Frequency and Time Masking {idx}"),
            gr.Image(label=f"MixStyle {idx}"),
            gr.Textbox(label=f"Predicted Class {idx}"),
        ])

    title = "<h1 style='text-align: center; font-size: 1.5em;'>ASCDomain</h1>"
    description = """
<div style="font-size: 16px; letter-spacing: 1.2px; line-height: 1.8; text-align: justify;">
<strong>ASCDomain:</strong> Domain Invariant Device-Self-Challenging Isotopic Convolutional Neural Architecture<br>
<strong>ASCDomain repository:</strong> <a href="https://github.com/hubtru/ASCDomain" target="_blank">ASCDomain</a><br>
Explore different acoustic scenes and mobile devices in our latest model.<br>
<strong>Options:</strong> <br>
<ul>
<li><strong>Acoustic Scene:</strong> Airport, Indoor shopping mall, Metro station, Pedestrian street, Public square, Street with medium level of traffic, Travelling by tram, Travelling by bus, Travelling by underground metro, Urban park</li>
<li><strong>Mobile Device:</strong> a, b, c, s1, s2, s3</li>
</ul>
</div>
"""

    demo = gr.Interface(
        fn=process_audio,
        inputs=input_components,
        outputs=output_components,
        title=title,
        description=description,
        theme=theme,
        live=True,
        allow_flagging="never",
    )
    demo.launch()
# Launch the demo only when run as a script (e.g. `python app.py`), so the
# module can be imported for tests or reuse without starting a server.
if __name__ == "__main__":
    gradio_interface()