# Neural Beatbox Live — Streamlit app: classify beatbox sounds (kick/snare/hi-hat)
# from mic or uploaded audio and render a synchronized drum track.
import streamlit as st
st.set_page_config(page_title="Neural Beatbox Live", page_icon="π₯")

import io
import tempfile
from pathlib import Path

import numpy as np
import torch
import torchaudio
from streamlit_mic_recorder import mic_recorder
SAMPLE_RATE = 22050
TARGET_LENGTH = 17640
N_FFT = 1024
HOP_LENGTH = 256
N_MELS = 64
F_MAX = 8000
mel_transform = torchaudio.transforms.MelSpectrogram(
sample_rate=SAMPLE_RATE,
n_fft=N_FFT,
hop_length=HOP_LENGTH,
n_mels=N_MELS,
f_min=0,
f_max=F_MAX,
window_fn=torch.hann_window,
power=2.0,
)
# model
class NeuralBeatboxCNN(torch.nn.Module):
def __init__(self, num_classes=3):
super().__init__()
def dw_block(in_ch, out_ch, stride=1):
return torch.nn.Sequential(
torch.nn.Conv2d(in_ch, in_ch, 3, stride=stride, padding=1, groups=in_ch, bias=False),
torch.nn.BatchNorm2d(in_ch),
torch.nn.ReLU6(inplace=True),
torch.nn.Conv2d(in_ch, out_ch, 1, bias=False),
torch.nn.BatchNorm2d(out_ch),
torch.nn.ReLU6(inplace=True),
)
self.features = torch.nn.Sequential(
torch.nn.Conv2d(1, 32, 3, stride=1, padding=1, bias=False),
torch.nn.BatchNorm2d(32),
torch.nn.ReLU6(inplace=True),
dw_block(32, 64),
dw_block(64, 128),
dw_block(128, 128),
dw_block(128, 256),
dw_block(256, 256),
)
self.classifier = torch.nn.Sequential(
torch.nn.AdaptiveAvgPool2d(1),
torch.nn.Flatten(),
torch.nn.Dropout(0.3),
torch.nn.Linear(256, num_classes)
)
def forward(self, x):
x = self.features(x)
x = self.classifier(x)
return x
@st.cache_resource
def load_model():
model = NeuralBeatboxCNN()
state_path = Path("state_dict.pth")
if not state_path.exists():
st.error("Model weights missing! Add state_dict.pth to repo.")
st.stop()
model.load_state_dict(torch.load(state_path, map_location="cpu"))
model.eval()
return model
model = load_model()
# preprocessing
def preprocess_waveform(waveform: torch.Tensor, orig_sr: int) -> torch.Tensor:
if orig_sr != SAMPLE_RATE:
waveform = torchaudio.transforms.Resample(orig_sr, SAMPLE_RATE)(waveform)
waveform = waveform.mean(dim=0, keepdim=True)
if waveform.shape[1] < TARGET_LENGTH:
pad = TARGET_LENGTH - waveform.shape[1]
waveform = torch.nn.functional.pad(waveform, (pad//2, pad - pad//2))
else:
start = (waveform.shape[1] - TARGET_LENGTH) // 2
waveform = waveform[:, start:start + TARGET_LENGTH]
mel = mel_transform(waveform)
log_mel = torch.log(mel + 1e-9)
return log_mel.unsqueeze(0)
# drums
drum_samples = {
0: "kick.wav",
1: "snare.wav",
2: "hihat.wav"
}
label_names = {0: "Boom (Kick)", 1: "Kah (Snare)", 2: "Tss (Hi-hat)"}
# UI
st.title("π₯ Neural Beatbox β Live Mic + Upload")
st.markdown("**Beatbox into your mic or upload a recording** β get a clean drum version instantly!")
tab1, tab2 = st.tabs(["π€ Live Mic Recording", "π Upload File"])
with tab1:
st.markdown("Click to start/stop recording (beatbox clearly!)")
audio_info = mic_recorder(
start_prompt="π€ Start Beatboxing",
stop_prompt="βΉοΈ Stop",
format="wav",
key="live_mic"
)
if audio_info:
st.audio(audio_info['bytes'], format='audio/wav')
waveform = torch.from_numpy(np.frombuffer(audio_info['bytes'], np.int16)).float() / 32768
waveform = waveform.unsqueeze(0) # (1, samples)
orig_sr = 44100 # typical browser recording rate
with tab2:
uploaded_file = st.file_uploader("Or upload WAV/MP3", type=["wav", "mp3", "ogg"])
if uploaded_file:
st.audio(uploaded_file)
with tempfile.NamedTemporaryFile(delete=False, suffix=Path(uploaded_file.name).suffix) as tmp:
tmp.write(uploaded_file.getvalue())
waveform, orig_sr = torchaudio.load(tmp.name)
# Process if audio available
if 'waveform' in locals() and waveform is not None:
# Sliding window + energy trigger
window_samples = TARGET_LENGTH
hop_samples = TARGET_LENGTH // 2
energy_threshold = 0.02
confidence_threshold = 0.6
predictions = []
positions_ms = []
with torch.no_grad():
for start in range(0, waveform.shape[1] - window_samples + 1, hop_samples):
segment = waveform[:, start:start + window_samples]
energy = torch.mean(segment ** 2).item()
if energy > energy_threshold:
spec = preprocess_waveform(segment, orig_sr)
output = model(spec)
pred = output.argmax(dim=1).item()
conf = torch.softmax(output, dim=1).max().item()
if conf > confidence_threshold:
predictions.append(pred)
positions_ms.append(int(start / SAMPLE_RATE * 1000))
st.write(f"**Detected {len(predictions)} drum hits**")
if predictions:
seq_text = " β ".join([label_names[p] for p in predictions])
st.markdown(f"**Beat sequence:** {seq_text}")
# Build drum beat
combined = torch.zeros(1, 0)
prev_pos = 0
for i, pred in enumerate(predictions):
drum_wave, drum_sr = torchaudio.load(drum_samples[pred])
drum_wave = drum_wave.mean(dim=0, keepdim=True)
if drum_sr != SAMPLE_RATE:
drum_wave = torchaudio.transforms.Resample(drum_sr, SAMPLE_RATE)(drum_wave)
silence_ms = positions_ms[i] - prev_pos if i > 0 else positions_ms[i]
silence_samples = int(silence_ms / 1000 * SAMPLE_RATE)
silence = torch.zeros(1, silence_samples)
drum_segment = torch.cat([silence, drum_wave], dim=1)
if combined.shape[1] < drum_segment.shape[1]:
pad = drum_segment.shape[1] - combined.shape[1]
combined = torch.nn.functional.pad(combined, (0, pad))
combined += drum_segment[:, :combined.shape[1]]
prev_pos = positions_ms[i]
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as out:
torchaudio.save(out.name, combined, SAMPLE_RATE)
st.audio(out.name)
st.balloons()
st.success("Your beatbox β studio drums! π")
else:
st.info("No clear hits detected β try louder, sharper sounds!")