# atd_demo/app.py — HuggingFace Space "mmzoellner/MobileNet_ATD" demo (commit 4b52ccc)
import torch
import torchaudio
import torchaudio.transforms as AT
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
import io, PIL.Image
from torchvision.models import mobilenet_v3_small
from huggingface_hub import hf_hub_download
# --- Performance/Device ---
torch.set_num_threads(1)  # keep CPU inference to one thread (shared Space hardware)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Config ---
SAMPLE_RATE = 16000       # Hz; all audio is resampled to this rate
CHUNK_SEC = 2             # length of each sliding analysis window, seconds
STRIDE_SEC = 0.1          # hop between windows, seconds -> ~10 predictions per second
MAX_SEC = 60              # audio longer than this is truncated before inference
IMG_SIZE = 224            # spectrogram is resized to IMG_SIZE x IMG_SIZE for MobileNet
N_MELS = 128              # mel filterbank size
N_FFT = 1024              # STFT window size
HOP_LENGTH = 512          # STFT hop
WIDTH_MULT = 1            # MobileNetV3 width multiplier (must match the checkpoint)
HF_REPO_ID = "mmzoellner/MobileNet_ATD"
MOBILENET_WEIGHTS_FILENAME = "mobilenet_v3_small_fold2.pt"

# --- Event Detection Thresholds & Smoothing ---
START_THRESHOLD = 0.3     # smoothed probability above which an event starts
END_THRESHOLD = 0.1       # smoothed probability below which an active event ends
MOVING_AVG_WINDOW = 10    # smoothing window in samples; 10 samples = 1 s at 0.1 s stride
# --- Audio -> Mel-dB -> 3ch 224x224 ---
# Shared torchaudio transforms, built once at import time.
mel_spec = AT.MelSpectrogram(
    sample_rate=SAMPLE_RATE, n_fft=N_FFT, hop_length=HOP_LENGTH, n_mels=N_MELS)
to_db = AT.AmplitudeToDB()
# ImageNet channel statistics, used to normalize the pseudo-RGB spectrogram
# the same way MobileNet's pretraining data was normalized.
IM_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
IM_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
def waveform_to_tensor_img(wav_1ch_16k: torch.Tensor) -> torch.Tensor:
    """Turn a mono 16 kHz waveform into a normalized (3, 224, 224) tensor.

    Pipeline: mel spectrogram -> dB scale -> replicate to 3 channels ->
    bilinear resize to IMG_SIZE -> ImageNet mean/std normalization.
    Assumes input shape (1, T) so the mel output is (1, n_mels, frames)
    and the channel expand yields a 3-channel "image".
    """
    with torch.no_grad():
        spec_db = to_db(mel_spec(wav_1ch_16k))
        # Replicate the single channel three times, then resize as a batch of one.
        img = torch.nn.functional.interpolate(
            spec_db.expand(3, -1, -1).unsqueeze(0),
            size=(IMG_SIZE, IMG_SIZE),
            mode="bilinear", align_corners=False,
        ).squeeze(0)
        return (img - IM_MEAN) / IM_STD
# --- Model Loading ---
def load_model():
    """Build MobileNetV3-Small with a 2-class head and load fine-tuned weights.

    Downloads the checkpoint from the HF Hub, loads it strictly (all keys
    must match), and returns the model on `device` in eval mode.
    """
    net = mobilenet_v3_small(width_mult=WIDTH_MULT, weights=None)
    # Replace the final classifier layer with a binary (no-train / train) head.
    net.classifier[3] = torch.nn.Linear(net.classifier[3].in_features, 2)
    ckpt_path = hf_hub_download(repo_id=HF_REPO_ID, filename=MOBILENET_WEIGHTS_FILENAME)
    # NOTE(review): torch.load unpickles arbitrary objects; for untrusted
    # checkpoints consider weights_only=True (torch >= 1.13).
    state = torch.load(ckpt_path, map_location="cpu")
    net.load_state_dict(state, strict=True)
    return net.to(device).eval()

model = load_model()
# --- Plotting ---
def plot_event_predictions(times, raw_probs, smoothed_probs, events):
    """Plot raw vs. smoothed probabilities over time and shade event spans.

    `times` are chunk start times; curves are drawn at chunk midpoints.
    Returns the rendered figure as an RGB numpy array for gr.Image.
    """
    centers = [t + (CHUNK_SEC / 2) for t in times]
    fig, ax = plt.subplots(figsize=(10, 3))

    # Probability curves: faint raw trace plus the smoothed trace used for detection.
    ax.plot(centers, raw_probs, color='lightblue', linestyle='-', label="Raw Probability")
    ax.plot(centers, smoothed_probs, color='blue', marker='.', linestyle='-',
            markersize=4, label="Smoothed Probability")

    # Hysteresis thresholds.
    ax.axhline(y=START_THRESHOLD, color='green', linestyle='--',
               label=f'Start Threshold ({START_THRESHOLD})')
    ax.axhline(y=END_THRESHOLD, color='red', linestyle='--',
               label=f'End Threshold ({END_THRESHOLD})')

    # Shade each detected event span (kept out of the legend).
    for ev in events:
        ax.axvspan(ev['start'], ev['end'], color='orange', alpha=0.3, label='_nolegend_')

    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Probability")
    ax.set_ylim(0, 1)
    ax.set_title("Train Detection Events Over Time")
    ax.legend(loc='upper left')
    ax.grid(True)
    plt.tight_layout()

    # Render to an in-memory PNG and hand back pixel data.
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close(fig)
    buf.seek(0)
    return np.array(PIL.Image.open(buf))
# --- Audio Loader ---
def load_audio_16k(path):
    """Load an audio file as a mono (1, T) tensor resampled to SAMPLE_RATE."""
    waveform, orig_sr = torchaudio.load(path)
    waveform = waveform.mean(dim=0, keepdim=True)  # average channels down to mono
    if orig_sr != SAMPLE_RATE:
        waveform = torchaudio.functional.resample(waveform, orig_sr, SAMPLE_RATE)
    return waveform
# --- Helper function for smoothing ---
def moving_average(data, window_size):
    """Box-filter *data* with a window of `window_size`, same-length output.

    Uses `mode='same'`, so values near the edges are attenuated because
    the kernel overhangs the sequence there.
    """
    kernel = np.full(window_size, 1.0 / window_size)
    return np.convolve(data, kernel, mode='same')
# --- Rolling Window Inference with Event Detection ---
def predict(audio_path: str):
    """Detect train events in an audio file with a sliding-window classifier.

    Slides CHUNK_SEC windows with STRIDE_SEC hop over (at most MAX_SEC of)
    the audio, classifies each window, smooths the per-window probabilities,
    and runs a hysteresis state machine (START_THRESHOLD / END_THRESHOLD)
    over the smoothed curve to extract event spans.

    Returns a (summary_text, timeline_plot_array) tuple for the Gradio UI.
    """
    wav = load_audio_16k(audio_path)
    # Truncate to the maximum supported duration.
    T_max = SAMPLE_RATE * MAX_SEC
    if wav.shape[-1] > T_max:
        wav = wav[..., :T_max]
    chunk_len = SAMPLE_RATE * CHUNK_SEC
    stride = int(SAMPLE_RATE * STRIDE_SEC)
    if wav.shape[-1] < chunk_len:
        # Too short for even one window; return a placeholder image.
        return "Audio is too short (must be at least 2 seconds).", np.zeros((10,10,3), dtype=np.uint8)

    # Create chunks and get predictions
    chunks, times = [], []  # times[i] = start time (s) of chunk i
    for start in range(0, wav.shape[-1] - chunk_len + 1, stride):
        c = wav[..., start:start+chunk_len]
        chunks.append(waveform_to_tensor_img(c))
        times.append(start / SAMPLE_RATE)

    # Single batched forward pass over all chunks; column 1 = "train" probability.
    # NOTE(review): for very long inputs this builds every chunk tensor in
    # memory at once — consider mini-batching if MAX_SEC grows.
    x = torch.stack(chunks, dim=0).to(device)
    with torch.no_grad():
        logits = model(x)
        raw_probs = torch.softmax(logits, dim=1)[:, 1].cpu().numpy()

    # Smooth the raw probabilities to reduce noise
    smoothed_probs = moving_average(raw_probs, MOVING_AVG_WINDOW)

    # --- State machine logic to find events using smoothed probabilities ---
    # Hysteresis: an event opens when the smoothed prob rises above
    # START_THRESHOLD and closes only when it falls below END_THRESHOLD.
    # Confidence is the mean of RAW probs collected while the event is open
    # (the opening sample is included; the closing sample is not).
    events = []
    is_event_active = False
    event_start_time = 0
    event_probs_agg = []  # To aggregate raw probabilities during an event
    for i, p_smooth in enumerate(smoothed_probs):
        # Events are timestamped at chunk midpoints, matching the plot.
        chunk_center_time = times[i] + (CHUNK_SEC / 2)
        if not is_event_active:
            if p_smooth > START_THRESHOLD:
                is_event_active = True
                event_start_time = chunk_center_time
                event_probs_agg.append(raw_probs[i])
        else:
            if p_smooth < END_THRESHOLD:
                # Close the event and emit it; reset the aggregator.
                is_event_active = False
                event_end_time = chunk_center_time
                events.append({
                    "start": round(event_start_time, 1),
                    "end": round(event_end_time, 1),
                    "confidence": round(float(np.mean(event_probs_agg)), 2)
                })
                event_probs_agg = []
            else:
                event_probs_agg.append(raw_probs[i])

    # Flush an event still open at the end of the audio.
    if is_event_active:
        final_time = times[-1] + (CHUNK_SEC / 2)
        events.append({
            "start": round(event_start_time, 1),
            "end": round(final_time, 1),
            "confidence": round(float(np.mean(event_probs_agg)), 2)
        })

    if not events:
        text = "No Train events detected."
    else:
        text = "\n".join(
            f"Train detected: {e['start']}s – {e['end']}s (Avg. Confidence: {e['confidence']*100:.1f}%)"
            for e in events
        )
    plot = plot_event_predictions(times, raw_probs, smoothed_probs, events)
    return text, plot
# --- Gradio ---
# UI wiring: one audio input, a text summary plus the timeline plot as outputs.
# The description is built from the config constants so the user-facing text
# always matches the actual thresholds (the previous hard-coded text claimed
# events start at 50% while START_THRESHOLD is 0.3).
demo = gr.Interface(
    fn=predict,
    inputs=gr.Audio(type="filepath", label="Upload Audio (max 60s)"),
    outputs=[
        gr.Textbox(label="Detected Events"),
        gr.Image(label="Detection Timeline")
    ],
    title="ATD: Acoustic Train Detection (Event-Based & Smoothed)",
    description=(
        "This demo detects the start and end of train sounds using smoothed "
        f"probabilities. An event starts when confidence exceeds "
        f"{START_THRESHOLD:.0%} and ends when it drops below {END_THRESHOLD:.0%}."
    )
)

if __name__ == "__main__":
    # share=True exposes a public Gradio link when run outside HF Spaces.
    demo.launch(share=True)