import io
import tempfile
from pathlib import Path

import torch
import torchaudio
import streamlit as st
from streamlit_mic_recorder import mic_recorder

# set_page_config must be the first Streamlit call in the script
st.set_page_config(page_title="Neural Beatbox Live", page_icon="🥁")
SAMPLE_RATE = 22050
TARGET_LENGTH = 17640  # 0.8 s at 22050 Hz: the clip length the model expects
N_FFT = 1024
HOP_LENGTH = 256
N_MELS = 64
F_MAX = 8000
mel_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=SAMPLE_RATE,
    n_fft=N_FFT,
    hop_length=HOP_LENGTH,
    n_mels=N_MELS,
    f_min=0,
    f_max=F_MAX,
    window_fn=torch.hann_window,
    power=2.0,
)
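# With torchaudio's default center padding, a TARGET_LENGTH clip yields
# TARGET_LENGTH // HOP_LENGTH + 1 = 69 frames, i.e. a (1, 64, 69) log-mel tensor.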
# Model: a small MobileNet-style CNN built from depthwise-separable blocks
class NeuralBeatboxCNN(torch.nn.Module):
    def __init__(self, num_classes=3):
        super().__init__()

        def dw_block(in_ch, out_ch, stride=1):
            # Depthwise 3x3 conv followed by a pointwise 1x1 conv.
            return torch.nn.Sequential(
                torch.nn.Conv2d(in_ch, in_ch, 3, stride=stride, padding=1, groups=in_ch, bias=False),
                torch.nn.BatchNorm2d(in_ch),
                torch.nn.ReLU6(inplace=True),
                torch.nn.Conv2d(in_ch, out_ch, 1, bias=False),
                torch.nn.BatchNorm2d(out_ch),
                torch.nn.ReLU6(inplace=True),
            )

        self.features = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, 3, stride=1, padding=1, bias=False),
            torch.nn.BatchNorm2d(32),
            torch.nn.ReLU6(inplace=True),
            dw_block(32, 64),
            dw_block(64, 128),
            dw_block(128, 128),
            dw_block(128, 256),
            dw_block(256, 256),
        )
        self.classifier = torch.nn.Sequential(
            torch.nn.AdaptiveAvgPool2d(1),
            torch.nn.Flatten(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(256, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x
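# Hypothetical shape check (not executed by the app; for local debugging only):
#   dummy = torch.randn(1, 1, N_MELS, 69)
#   assert NeuralBeatboxCNN()(dummy).shape == (1, 3)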
@st.cache_resource
def load_model():
    # cache_resource keeps the loaded weights alive across Streamlit reruns.
    model = NeuralBeatboxCNN()
    state_path = Path("state_dict.pth")
    if not state_path.exists():
        st.error("Model weights missing! Add state_dict.pth to the repo.")
        st.stop()
    model.load_state_dict(torch.load(state_path, map_location="cpu"))
    model.eval()
    return model

model = load_model()
# Preprocessing: resample, downmix to mono, center-pad/crop to TARGET_LENGTH,
# then convert to a batched log-mel spectrogram.
def preprocess_waveform(waveform: torch.Tensor, orig_sr: int) -> torch.Tensor:
    if orig_sr != SAMPLE_RATE:
        waveform = torchaudio.transforms.Resample(orig_sr, SAMPLE_RATE)(waveform)
    waveform = waveform.mean(dim=0, keepdim=True)  # stereo -> mono
    if waveform.shape[1] < TARGET_LENGTH:
        pad = TARGET_LENGTH - waveform.shape[1]
        waveform = torch.nn.functional.pad(waveform, (pad // 2, pad - pad // 2))
    else:
        start = (waveform.shape[1] - TARGET_LENGTH) // 2
        waveform = waveform[:, start:start + TARGET_LENGTH]
    mel = mel_transform(waveform)
    log_mel = torch.log(mel + 1e-9)  # epsilon guards against log(0)
    return log_mel.unsqueeze(0)  # (1, 1, N_MELS, frames)
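# Example (hypothetical values): a 2 s stereo clip at 44.1 kHz becomes a
# (1, 1, 64, 69) tensor ready for the model:
#   x = preprocess_waveform(torch.randn(2, 88200), 44100)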
# Drum one-shot samples, mapped by class index (these WAVs ship with the repo)
drum_samples = {
    0: "kick.wav",
    1: "snare.wav",
    2: "hihat.wav",
}
label_names = {0: "Boom (Kick)", 1: "Kah (Snare)", 2: "Tss (Hi-hat)"}
# UI
st.title("🥁 Neural Beatbox: Live Mic + Upload")
st.markdown("**Beatbox into your mic or upload a recording** to get a clean drum version instantly!")

waveform, orig_sr = None, None  # set by whichever tab produces audio
tab1, tab2 = st.tabs(["🎤 Live Mic Recording", "📁 Upload File"])

with tab1:
    st.markdown("Click to start/stop recording (beatbox clearly!)")
    audio_info = mic_recorder(
        start_prompt="🎤 Start Beatboxing",
        stop_prompt="⏹️ Stop",
        format="wav",
        key="live_mic",
    )
    if audio_info:
        st.audio(audio_info['bytes'], format='audio/wav')
        # mic_recorder returns a complete WAV container; decode it with
        # torchaudio rather than reinterpreting the raw bytes as int16
        # (which would also swallow the WAV header) and guessing the
        # browser's sample rate.
        waveform, orig_sr = torchaudio.load(io.BytesIO(audio_info['bytes']), format="wav")
with tab2:
    uploaded_file = st.file_uploader("Or upload WAV/MP3", type=["wav", "mp3", "ogg"])
    if uploaded_file:
        st.audio(uploaded_file)
        with tempfile.NamedTemporaryFile(delete=False, suffix=Path(uploaded_file.name).suffix) as tmp:
            tmp.write(uploaded_file.getvalue())
        # Load after the with block so the temp file is flushed and closed.
        waveform, orig_sr = torchaudio.load(tmp.name)
# Process whatever audio the tabs produced
if waveform is not None:
    if orig_sr != SAMPLE_RATE:
        # Resample the full recording once so window offsets and
        # positions_ms are measured in model-rate samples.
        waveform = torchaudio.transforms.Resample(orig_sr, SAMPLE_RATE)(waveform)
        orig_sr = SAMPLE_RATE
    waveform = waveform.mean(dim=0, keepdim=True)  # downmix once, up front
    # Sliding window + energy trigger
    window_samples = TARGET_LENGTH
    hop_samples = TARGET_LENGTH // 2  # 50% overlap
    energy_threshold = 0.02
    confidence_threshold = 0.6
    predictions = []
    positions_ms = []
    with torch.no_grad():
        for start in range(0, waveform.shape[1] - window_samples + 1, hop_samples):
            segment = waveform[:, start:start + window_samples]
            energy = torch.mean(segment ** 2).item()
            if energy > energy_threshold:  # skip near-silent windows
                spec = preprocess_waveform(segment, orig_sr)
                output = model(spec)
                pred = output.argmax(dim=1).item()
                conf = torch.softmax(output, dim=1).max().item()
                if conf > confidence_threshold:  # keep confident hits only
                    predictions.append(pred)
                    positions_ms.append(int(start / SAMPLE_RATE * 1000))
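    # Note: with 50% overlap, one loud hit can trigger two adjacent windows
    # and appear twice; enforcing a minimum gap between entries of
    # positions_ms would deduplicate, and is left out here for simplicity.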
    st.write(f"**Detected {len(predictions)} drum hits**")
    if predictions:
        seq_text = " → ".join(label_names[p] for p in predictions)
        st.markdown(f"**Beat sequence:** {seq_text}")
        # Build the drum track by mixing each one-shot at its absolute
        # detection time (positions_ms holds absolute offsets, not deltas).
        combined = torch.zeros(1, 0)
        for pred, pos_ms in zip(predictions, positions_ms):
            drum_wave, drum_sr = torchaudio.load(drum_samples[pred])
            drum_wave = drum_wave.mean(dim=0, keepdim=True)
            if drum_sr != SAMPLE_RATE:
                drum_wave = torchaudio.transforms.Resample(drum_sr, SAMPLE_RATE)(drum_wave)
            start_sample = int(pos_ms / 1000 * SAMPLE_RATE)
            end_sample = start_sample + drum_wave.shape[1]
            if combined.shape[1] < end_sample:
                # Grow the mix buffer so the whole one-shot fits.
                combined = torch.nn.functional.pad(combined, (0, end_sample - combined.shape[1]))
            combined[:, start_sample:end_sample] += drum_wave
        combined = combined.clamp(-1.0, 1.0)  # avoid clipping where hits overlap
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as out:
            torchaudio.save(out.name, combined, SAMPLE_RATE)
        st.audio(out.name)
        st.balloons()
        st.success("Your beatbox → studio drums! 🎉")
    else:
        st.info("No clear hits detected. Try louder, sharper sounds!")
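# Assumed Space dependencies (requirements.txt; pin versions as needed):
#   streamlit
#   torch
#   torchaudio
#   streamlit-mic-recorder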