# neural-beatbox β€” app.py
# (Hugging Face Space page header: uploaded by AKMESSI, "Update app.py", commit 92ca38c verified)
import io
import tempfile
import wave
from pathlib import Path

import numpy as np
import torch
import torchaudio

import streamlit as st

# st.set_page_config must be the first Streamlit command executed; keep it
# immediately after the streamlit import and before any component imports.
st.set_page_config(page_title="Neural Beatbox Live", page_icon="πŸ₯")

from streamlit_mic_recorder import mic_recorder
SAMPLE_RATE = 22050
TARGET_LENGTH = 17640
N_FFT = 1024
HOP_LENGTH = 256
N_MELS = 64
F_MAX = 8000
mel_transform = torchaudio.transforms.MelSpectrogram(
sample_rate=SAMPLE_RATE,
n_fft=N_FFT,
hop_length=HOP_LENGTH,
n_mels=N_MELS,
f_min=0,
f_max=F_MAX,
window_fn=torch.hann_window,
power=2.0,
)
# model
class NeuralBeatboxCNN(torch.nn.Module):
def __init__(self, num_classes=3):
super().__init__()
def dw_block(in_ch, out_ch, stride=1):
return torch.nn.Sequential(
torch.nn.Conv2d(in_ch, in_ch, 3, stride=stride, padding=1, groups=in_ch, bias=False),
torch.nn.BatchNorm2d(in_ch),
torch.nn.ReLU6(inplace=True),
torch.nn.Conv2d(in_ch, out_ch, 1, bias=False),
torch.nn.BatchNorm2d(out_ch),
torch.nn.ReLU6(inplace=True),
)
self.features = torch.nn.Sequential(
torch.nn.Conv2d(1, 32, 3, stride=1, padding=1, bias=False),
torch.nn.BatchNorm2d(32),
torch.nn.ReLU6(inplace=True),
dw_block(32, 64),
dw_block(64, 128),
dw_block(128, 128),
dw_block(128, 256),
dw_block(256, 256),
)
self.classifier = torch.nn.Sequential(
torch.nn.AdaptiveAvgPool2d(1),
torch.nn.Flatten(),
torch.nn.Dropout(0.3),
torch.nn.Linear(256, num_classes)
)
def forward(self, x):
x = self.features(x)
x = self.classifier(x)
return x
@st.cache_resource
def load_model():
model = NeuralBeatboxCNN()
state_path = Path("state_dict.pth")
if not state_path.exists():
st.error("Model weights missing! Add state_dict.pth to repo.")
st.stop()
model.load_state_dict(torch.load(state_path, map_location="cpu"))
model.eval()
return model
model = load_model()
# preprocessing
def preprocess_waveform(waveform: torch.Tensor, orig_sr: int) -> torch.Tensor:
if orig_sr != SAMPLE_RATE:
waveform = torchaudio.transforms.Resample(orig_sr, SAMPLE_RATE)(waveform)
waveform = waveform.mean(dim=0, keepdim=True)
if waveform.shape[1] < TARGET_LENGTH:
pad = TARGET_LENGTH - waveform.shape[1]
waveform = torch.nn.functional.pad(waveform, (pad//2, pad - pad//2))
else:
start = (waveform.shape[1] - TARGET_LENGTH) // 2
waveform = waveform[:, start:start + TARGET_LENGTH]
mel = mel_transform(waveform)
log_mel = torch.log(mel + 1e-9)
return log_mel.unsqueeze(0)
# drums
drum_samples = {
0: "kick.wav",
1: "snare.wav",
2: "hihat.wav"
}
label_names = {0: "Boom (Kick)", 1: "Kah (Snare)", 2: "Tss (Hi-hat)"}
# UI
st.title("πŸ₯ Neural Beatbox β€” Live Mic + Upload")
st.markdown("**Beatbox into your mic or upload a recording** β†’ get a clean drum version instantly!")
tab1, tab2 = st.tabs(["🎀 Live Mic Recording", "πŸ“ Upload File"])
with tab1:
st.markdown("Click to start/stop recording (beatbox clearly!)")
audio_info = mic_recorder(
start_prompt="🎀 Start Beatboxing",
stop_prompt="⏹️ Stop",
format="wav",
key="live_mic"
)
if audio_info:
st.audio(audio_info['bytes'], format='audio/wav')
waveform = torch.from_numpy(np.frombuffer(audio_info['bytes'], np.int16)).float() / 32768
waveform = waveform.unsqueeze(0) # (1, samples)
orig_sr = 44100 # typical browser recording rate
with tab2:
uploaded_file = st.file_uploader("Or upload WAV/MP3", type=["wav", "mp3", "ogg"])
if uploaded_file:
st.audio(uploaded_file)
with tempfile.NamedTemporaryFile(delete=False, suffix=Path(uploaded_file.name).suffix) as tmp:
tmp.write(uploaded_file.getvalue())
waveform, orig_sr = torchaudio.load(tmp.name)
# Process if audio available
if 'waveform' in locals() and waveform is not None:
# Sliding window + energy trigger
window_samples = TARGET_LENGTH
hop_samples = TARGET_LENGTH // 2
energy_threshold = 0.02
confidence_threshold = 0.6
predictions = []
positions_ms = []
with torch.no_grad():
for start in range(0, waveform.shape[1] - window_samples + 1, hop_samples):
segment = waveform[:, start:start + window_samples]
energy = torch.mean(segment ** 2).item()
if energy > energy_threshold:
spec = preprocess_waveform(segment, orig_sr)
output = model(spec)
pred = output.argmax(dim=1).item()
conf = torch.softmax(output, dim=1).max().item()
if conf > confidence_threshold:
predictions.append(pred)
positions_ms.append(int(start / SAMPLE_RATE * 1000))
st.write(f"**Detected {len(predictions)} drum hits**")
if predictions:
seq_text = " β†’ ".join([label_names[p] for p in predictions])
st.markdown(f"**Beat sequence:** {seq_text}")
# Build drum beat
combined = torch.zeros(1, 0)
prev_pos = 0
for i, pred in enumerate(predictions):
drum_wave, drum_sr = torchaudio.load(drum_samples[pred])
drum_wave = drum_wave.mean(dim=0, keepdim=True)
if drum_sr != SAMPLE_RATE:
drum_wave = torchaudio.transforms.Resample(drum_sr, SAMPLE_RATE)(drum_wave)
silence_ms = positions_ms[i] - prev_pos if i > 0 else positions_ms[i]
silence_samples = int(silence_ms / 1000 * SAMPLE_RATE)
silence = torch.zeros(1, silence_samples)
drum_segment = torch.cat([silence, drum_wave], dim=1)
if combined.shape[1] < drum_segment.shape[1]:
pad = drum_segment.shape[1] - combined.shape[1]
combined = torch.nn.functional.pad(combined, (0, pad))
combined += drum_segment[:, :combined.shape[1]]
prev_pos = positions_ms[i]
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as out:
torchaudio.save(out.name, combined, SAMPLE_RATE)
st.audio(out.name)
st.balloons()
st.success("Your beatbox β†’ studio drums! πŸŽ‰")
else:
st.info("No clear hits detected β€” try louder, sharper sounds!")