"""Neural Beatbox Live — Streamlit app.

Classifies beatbox sounds (kick / snare / hi-hat) from a live mic recording or
an uploaded file, then renders the detected hits as a drum track.

Pipeline: audio -> resample to 22.05 kHz -> sliding energy-gated window ->
log-mel spectrogram -> depthwise-separable CNN -> drum one-shots mixed at the
detected hit positions.
"""
import io
import tempfile
from pathlib import Path

import numpy as np
import streamlit as st
import torch
import torchaudio
from streamlit_mic_recorder import mic_recorder

st.set_page_config(page_title="Neural Beatbox Live", page_icon="đŸĨ")

# --- Audio / feature constants (must match the training configuration) ---
SAMPLE_RATE = 22050        # model input sample rate (Hz)
TARGET_LENGTH = 17640      # samples per analysis window (~0.8 s at 22.05 kHz)
N_FFT = 1024
HOP_LENGTH = 256
N_MELS = 64
F_MAX = 8000

# Shared mel-spectrogram front end (power spectrogram, Hann window).
mel_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=SAMPLE_RATE,
    n_fft=N_FFT,
    hop_length=HOP_LENGTH,
    n_mels=N_MELS,
    f_min=0,
    f_max=F_MAX,
    window_fn=torch.hann_window,
    power=2.0,
)


class NeuralBeatboxCNN(torch.nn.Module):
    """Small MobileNet-style classifier over log-mel spectrograms.

    Input: (batch, 1, n_mels, time) log-mel patch.
    Output: (batch, num_classes) logits.
    """

    def __init__(self, num_classes: int = 3):
        super().__init__()

        def dw_block(in_ch: int, out_ch: int, stride: int = 1) -> torch.nn.Sequential:
            # Depthwise 3x3 conv followed by pointwise 1x1 conv (separable conv).
            return torch.nn.Sequential(
                torch.nn.Conv2d(in_ch, in_ch, 3, stride=stride, padding=1,
                                groups=in_ch, bias=False),
                torch.nn.BatchNorm2d(in_ch),
                torch.nn.ReLU6(inplace=True),
                torch.nn.Conv2d(in_ch, out_ch, 1, bias=False),
                torch.nn.BatchNorm2d(out_ch),
                torch.nn.ReLU6(inplace=True),
            )

        self.features = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, 3, stride=1, padding=1, bias=False),
            torch.nn.BatchNorm2d(32),
            torch.nn.ReLU6(inplace=True),
            dw_block(32, 64),
            dw_block(64, 128),
            dw_block(128, 128),
            dw_block(128, 256),
            dw_block(256, 256),
        )
        self.classifier = torch.nn.Sequential(
            torch.nn.AdaptiveAvgPool2d(1),
            torch.nn.Flatten(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(256, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = self.classifier(x)
        return x


@st.cache_resource
def load_model() -> NeuralBeatboxCNN:
    """Load the trained classifier once per Streamlit session (eval mode, CPU)."""
    model = NeuralBeatboxCNN()
    state_path = Path("state_dict.pth")
    if not state_path.exists():
        st.error("Model weights missing! \nAdd state_dict.pth to repo.")
        st.stop()
    model.load_state_dict(torch.load(state_path, map_location="cpu"))
    model.eval()
    return model


model = load_model()


def preprocess_waveform(waveform: torch.Tensor, orig_sr: int) -> torch.Tensor:
    """Convert a (channels, samples) waveform into a (1, 1, n_mels, time) log-mel patch.

    Resamples to SAMPLE_RATE if needed, mixes to mono, center-pads/crops to
    TARGET_LENGTH, then applies the shared mel transform with a log floor.
    """
    if orig_sr != SAMPLE_RATE:
        waveform = torchaudio.transforms.Resample(orig_sr, SAMPLE_RATE)(waveform)
    waveform = waveform.mean(dim=0, keepdim=True)
    if waveform.shape[1] < TARGET_LENGTH:
        pad = TARGET_LENGTH - waveform.shape[1]
        waveform = torch.nn.functional.pad(waveform, (pad // 2, pad - pad // 2))
    else:
        start = (waveform.shape[1] - TARGET_LENGTH) // 2
        waveform = waveform[:, start:start + TARGET_LENGTH]
    mel = mel_transform(waveform)
    log_mel = torch.log(mel + 1e-9)  # epsilon avoids log(0) on silent bins
    return log_mel.unsqueeze(0)


# Class index -> drum one-shot sample on disk, and human-readable labels.
drum_samples = {0: "kick.wav", 1: "snare.wav", 2: "hihat.wav"}
label_names = {0: "Boom (Kick)", 1: "Kah (Snare)", 2: "Tss (Hi-hat)"}


def _decode_mic_audio(audio_info: dict) -> tuple[torch.Tensor, int]:
    """Decode the mic recorder's WAV bytes into (waveform, sample_rate).

    torchaudio parses the RIFF container, so we get the browser's true sample
    rate and skip the header — unlike a raw np.frombuffer over the bytes,
    which would treat the 44-byte WAV header as PCM samples and had to guess
    the sample rate.
    """
    return torchaudio.load(io.BytesIO(audio_info["bytes"]))


def _detect_hits(waveform: torch.Tensor, orig_sr: int) -> tuple[list[int], list[int]]:
    """Slide an energy-gated window over the audio and classify loud windows.

    The waveform is resampled to SAMPLE_RATE *before* windowing so that
    TARGET_LENGTH (defined at 22.05 kHz) covers the intended duration and hit
    positions are computed in the correct time base.

    Returns (predicted class indices, hit positions in milliseconds).
    """
    if orig_sr != SAMPLE_RATE:
        waveform = torchaudio.transforms.Resample(orig_sr, SAMPLE_RATE)(waveform)
    waveform = waveform.mean(dim=0, keepdim=True)

    hop_samples = TARGET_LENGTH // 2         # 50% window overlap
    energy_threshold = 0.02                  # mean-square gate: skip quiet windows
    confidence_threshold = 0.6               # softmax gate: skip uncertain windows

    predictions: list[int] = []
    positions_ms: list[int] = []
    with torch.no_grad():
        for start in range(0, waveform.shape[1] - TARGET_LENGTH + 1, hop_samples):
            segment = waveform[:, start:start + TARGET_LENGTH]
            if torch.mean(segment ** 2).item() <= energy_threshold:
                continue
            logits = model(preprocess_waveform(segment, SAMPLE_RATE))
            conf, pred = torch.softmax(logits, dim=1).max(dim=1)
            if conf.item() > confidence_threshold:
                predictions.append(pred.item())
                positions_ms.append(int(start / SAMPLE_RATE * 1000))
    return predictions, positions_ms


def _render_drums(predictions: list[int], positions_ms: list[int]) -> torch.Tensor:
    """Mix drum one-shots at their absolute hit times into a (1, n) track.

    Each one-shot is placed at its own sample offset in a pre-sized buffer,
    so hits land where they were detected and overlapping tails simply sum —
    no shape-mismatch when an earlier hit's tail outlasts a later hit.
    """
    events: list[tuple[int, torch.Tensor]] = []
    total_len = 0
    for pred, pos_ms in zip(predictions, positions_ms):
        drum_wave, drum_sr = torchaudio.load(drum_samples[pred])
        drum_wave = drum_wave.mean(dim=0, keepdim=True)
        if drum_sr != SAMPLE_RATE:
            drum_wave = torchaudio.transforms.Resample(drum_sr, SAMPLE_RATE)(drum_wave)
        offset = int(pos_ms / 1000 * SAMPLE_RATE)
        events.append((offset, drum_wave))
        total_len = max(total_len, offset + drum_wave.shape[1])

    combined = torch.zeros(1, total_len)
    for offset, drum_wave in events:
        combined[:, offset:offset + drum_wave.shape[1]] += drum_wave
    return combined


# --- UI ---
st.title("đŸĨ Neural Beatbox — Live Mic + Upload")
st.markdown("**Beatbox into your mic or upload a recording** → get a clean drum version instantly!")

tab1, tab2 = st.tabs(["🎤 Live Mic Recording", "📁 Upload File"])

# Explicit sentinel instead of probing locals() for a maybe-defined name.
waveform: torch.Tensor | None = None
orig_sr = SAMPLE_RATE

with tab1:
    st.markdown("Click to start/stop recording (beatbox clearly!)")
    audio_info = mic_recorder(
        start_prompt="🎤 Start Beatboxing",
        stop_prompt="âšī¸ Stop",
        format="wav",
        key="live_mic",
    )
    if audio_info:
        st.audio(audio_info['bytes'], format='audio/wav')
        waveform, orig_sr = _decode_mic_audio(audio_info)

with tab2:
    uploaded_file = st.file_uploader("Or upload WAV/MP3", type=["wav", "mp3", "ogg"])
    if uploaded_file:
        st.audio(uploaded_file)
        with tempfile.NamedTemporaryFile(
            delete=False, suffix=Path(uploaded_file.name).suffix
        ) as tmp:
            tmp.write(uploaded_file.getvalue())
            tmp_path = tmp.name
        try:
            waveform, orig_sr = torchaudio.load(tmp_path)
        finally:
            # Clean up the scratch file once decoded.
            Path(tmp_path).unlink(missing_ok=True)

# Process whichever tab produced audio.
if waveform is not None:
    predictions, positions_ms = _detect_hits(waveform, orig_sr)
    st.write(f"**Detected {len(predictions)} drum hits**")

    if predictions:
        seq_text = " → ".join(label_names[p] for p in predictions)
        st.markdown(f"**Beat sequence:** {seq_text}")

        combined = _render_drums(predictions, positions_ms)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as out:
            torchaudio.save(out.name, combined, SAMPLE_RATE)
            st.audio(out.name)
        st.balloons()
        st.success("Your beatbox → studio drums! 🎉")
    else:
        st.info("No clear hits detected — try louder, sharper sounds!")