# Hugging Face Spaces page-status header ("Spaces / Sleeping") — scrape artifact, not code.
| import torch | |
| import torchaudio | |
| import torchaudio.transforms as AT | |
| import numpy as np | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import io, PIL.Image | |
| from torchvision.models import mobilenet_v3_small | |
| from huggingface_hub import hf_hub_download | |
# --- Performance / device selection ---
torch.set_num_threads(1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Configuration ---
SAMPLE_RATE = 16000   # model input sample rate (Hz)
CHUNK_SEC = 2         # length of each analysis window (s)
STRIDE_SEC = 0.1      # hop between windows (s) -> 10 predictions per second
MAX_SEC = 60          # hard cap on analysed audio length (s)
IMG_SIZE = 224        # side of the square spectrogram image fed to the CNN
N_MELS = 128
N_FFT = 1024
HOP_LENGTH = 512
WIDTH_MULT = 1
HF_REPO_ID = "mmzoellner/MobileNet_ATD"
MOBILENET_WEIGHTS_FILENAME = "mobilenet_v3_small_fold2.pt"

# --- Event detection thresholds & smoothing ---
START_THRESHOLD = 0.3    # smoothed probability that opens an event
END_THRESHOLD = 0.1      # smoothed probability that closes an event
MOVING_AVG_WINDOW = 10   # smoothing window (10 samples at 0.1 s stride = 1 s)

# --- Audio -> mel-dB -> 3-channel 224x224 ImageNet-normalised tensor ---
mel_spec = AT.MelSpectrogram(
    sample_rate=SAMPLE_RATE, n_fft=N_FFT, hop_length=HOP_LENGTH, n_mels=N_MELS)
to_db = AT.AmplitudeToDB()
IM_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)  # ImageNet channel means
IM_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)   # ImageNet channel stds


def waveform_to_tensor_img(wav_1ch_16k: torch.Tensor) -> torch.Tensor:
    """Convert a mono 16 kHz waveform of shape (1, T) to a (3, 224, 224) tensor.

    The mel spectrogram (converted to dB) is replicated across three channels,
    resized with bilinear interpolation, and standardised with ImageNet
    statistics so it suits an ImageNet-pretrained MobileNetV3 backbone.
    """
    with torch.no_grad():
        spectro_db = to_db(mel_spec(wav_1ch_16k))
        img = spectro_db.expand(3, -1, -1).unsqueeze(0)
        img = torch.nn.functional.interpolate(
            img, size=(IMG_SIZE, IMG_SIZE), mode="bilinear", align_corners=False)
        return (img.squeeze(0) - IM_MEAN) / IM_STD
# --- Model Loading ---
def load_model():
    """Build MobileNetV3-Small with a 2-logit head and load fine-tuned weights.

    Weights are fetched from the Hugging Face Hub (hf_hub_download caches
    locally). Returns the model on `device` in eval mode. Index 1 of the
    output is treated as the "train" class downstream (see `predict`).
    """
    model = mobilenet_v3_small(width_mult=WIDTH_MULT, weights=None)
    # Swap the final classifier layer for a fresh 2-class linear head.
    in_features = model.classifier[3].in_features
    model.classifier[3] = torch.nn.Linear(in_features, 2)
    weights_path = hf_hub_download(repo_id=HF_REPO_ID, filename=MOBILENET_WEIGHTS_FILENAME)
    # NOTE(review): torch.load unpickles a downloaded file; consider
    # weights_only=True (torch >= 2.0) to avoid arbitrary-object unpickling.
    state = torch.load(weights_path, map_location="cpu")
    model.load_state_dict(state, strict=True)
    model.to(device).eval()
    return model


# Loaded once at import time so all Gradio requests share one model instance.
model = load_model()
# --- Plotting ---
def plot_event_predictions(times, raw_probs, smoothed_probs, events):
    """Render raw/smoothed probability curves with detected event spans shaded.

    Args:
        times: chunk start times in seconds; curves are drawn at chunk midpoints.
        raw_probs: per-chunk P(train).
        smoothed_probs: moving average of raw_probs (same length).
        events: dicts with 'start'/'end' keys in seconds.

    Returns:
        The rendered figure as a numpy image array.
    """
    midpoints = [t + (CHUNK_SEC / 2) for t in times]
    fig, ax = plt.subplots(figsize=(10, 3))
    # Raw vs smoothed probability traces
    ax.plot(midpoints, raw_probs, color='lightblue', linestyle='-', label="Raw Probability")
    ax.plot(midpoints, smoothed_probs, color='blue', marker='.', linestyle='-', markersize=4, label="Smoothed Probability")
    # Hysteresis threshold lines
    ax.axhline(y=START_THRESHOLD, color='green', linestyle='--', label=f'Start Threshold ({START_THRESHOLD})')
    ax.axhline(y=END_THRESHOLD, color='red', linestyle='--', label=f'End Threshold ({END_THRESHOLD})')
    # Shade detected events; '_nolegend_' keeps them out of the legend.
    for event in events:
        ax.axvspan(event['start'], event['end'], color='orange', alpha=0.3, label='_nolegend_')
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Probability")
    ax.set_ylim(0, 1)
    ax.set_title("Train Detection Events Over Time")
    ax.legend(loc='upper left')
    ax.grid(True)
    # Fix: operate on `fig` directly instead of pyplot's implicit "current
    # figure" -- that global state can point at a different figure under
    # concurrent Gradio requests.
    fig.tight_layout()
    buf = io.BytesIO()
    fig.savefig(buf, format='png')
    plt.close(fig)
    buf.seek(0)
    with PIL.Image.open(buf) as image:
        return np.array(image)
# --- Audio Loader ---
def load_audio_16k(path):
    """Load an audio file as a mono (1, T) tensor resampled to SAMPLE_RATE."""
    waveform, orig_sr = torchaudio.load(path)
    # Collapse any channel layout to mono by averaging across channels.
    mono = waveform.mean(dim=0, keepdim=True)
    if orig_sr == SAMPLE_RATE:
        return mono
    return torchaudio.functional.resample(mono, orig_sr, SAMPLE_RATE)
# --- Smoothing helper ---
def moving_average(data, window_size):
    """Smooth `data` with a centred box filter of `window_size` samples.

    'same'-mode convolution keeps the output length equal to the input;
    values near the edges are averaged against implicit zeros.
    """
    kernel = np.ones(window_size) / window_size
    return np.convolve(data, kernel, mode='same')
# --- Rolling-Window Inference with Event Detection ---
_BATCH_SIZE = 128  # inference mini-batch size; bounds peak memory for long files


def _chunk_waveform(wav, chunk_len, stride):
    """Slice `wav` into overlapping windows; return (image tensors, start times in s)."""
    chunks, times = [], []
    for start in range(0, wav.shape[-1] - chunk_len + 1, stride):
        chunks.append(waveform_to_tensor_img(wav[..., start:start + chunk_len]))
        times.append(start / SAMPLE_RATE)
    return chunks, times


def _train_probabilities(chunks):
    """Run the model over `chunks` in bounded mini-batches; return P(train) per chunk.

    Fix: the original stacked every chunk into one batch (up to ~580 images
    of 3x224x224 for 60 s of audio), which spikes memory; mini-batching
    produces identical outputs with a bounded footprint.
    """
    probs = []
    with torch.no_grad():
        for i in range(0, len(chunks), _BATCH_SIZE):
            x = torch.stack(chunks[i:i + _BATCH_SIZE], dim=0).to(device)
            logits = model(x)
            probs.append(torch.softmax(logits, dim=1)[:, 1].cpu())
    return torch.cat(probs).numpy()


def _detect_events(times, raw_probs, smoothed_probs):
    """Hysteresis state machine over smoothed probabilities.

    An event opens when the smoothed probability exceeds START_THRESHOLD and
    closes when it drops below END_THRESHOLD; raw probabilities collected
    while the event is active give its average confidence.
    """
    events = []
    is_event_active = False
    event_start_time = 0
    event_probs_agg = []  # raw probabilities aggregated during the active event
    for i, p_smooth in enumerate(smoothed_probs):
        chunk_center_time = times[i] + (CHUNK_SEC / 2)
        if not is_event_active:
            if p_smooth > START_THRESHOLD:
                is_event_active = True
                event_start_time = chunk_center_time
                event_probs_agg.append(raw_probs[i])
        elif p_smooth < END_THRESHOLD:
            is_event_active = False
            events.append({
                "start": round(event_start_time, 1),
                "end": round(chunk_center_time, 1),
                "confidence": round(float(np.mean(event_probs_agg)), 2),
            })
            event_probs_agg = []
        else:
            event_probs_agg.append(raw_probs[i])
    if is_event_active:
        # Event still open at end of audio: close it at the final chunk midpoint.
        events.append({
            "start": round(event_start_time, 1),
            "end": round(times[-1] + (CHUNK_SEC / 2), 1),
            "confidence": round(float(np.mean(event_probs_agg)), 2),
        })
    return events


def predict(audio_path):
    """Detect train events in an audio file.

    Returns a (summary text, timeline plot image) pair for the Gradio UI.
    """
    wav = load_audio_16k(audio_path)
    # Cap the analysed duration to keep latency and memory bounded.
    wav = wav[..., :SAMPLE_RATE * MAX_SEC]
    chunk_len = SAMPLE_RATE * CHUNK_SEC
    stride = int(SAMPLE_RATE * STRIDE_SEC)
    if wav.shape[-1] < chunk_len:
        # With CHUNK_SEC = 2 this message is identical to the original text.
        return (f"Audio is too short (must be at least {CHUNK_SEC} seconds).",
                np.zeros((10, 10, 3), dtype=np.uint8))
    chunks, times = _chunk_waveform(wav, chunk_len, stride)
    raw_probs = _train_probabilities(chunks)
    # Smooth the raw probabilities to reduce frame-to-frame noise.
    smoothed_probs = moving_average(raw_probs, MOVING_AVG_WINDOW)
    events = _detect_events(times, raw_probs, smoothed_probs)
    if not events:
        text = "No Train events detected."
    else:
        text = "\n".join(
            f"Train detected: {e['start']}s – {e['end']}s (Avg. Confidence: {e['confidence']*100:.1f}%)"
            for e in events
        )
    return text, plot_event_predictions(times, raw_probs, smoothed_probs, events)
# --- Gradio ---
demo = gr.Interface(
    fn=predict,
    # MAX_SEC = 60 keeps this label byte-identical to the original.
    inputs=gr.Audio(type="filepath", label=f"Upload Audio (max {MAX_SEC}s)"),
    outputs=[
        gr.Textbox(label="Detected Events"),
        gr.Image(label="Detection Timeline"),
    ],
    title="ATD: Acoustic Train Detection (Event-Based & Smoothed)",
    # Fix: the original description claimed events start at 50%, but
    # START_THRESHOLD is 0.3; derive the text from the constants so the UI
    # can never drift out of sync with the detector again.
    description=(
        "This demo detects the start and end of train sounds using smoothed "
        f"probabilities. An event starts when confidence exceeds {START_THRESHOLD:.0%} "
        f"and ends when it drops below {END_THRESHOLD:.0%}."
    ),
)

if __name__ == "__main__":
    demo.launch(share=True)