File size: 6,622 Bytes
92ca38c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
import streamlit as st
st.set_page_config(page_title="Neural Beatbox Live", page_icon="πŸ₯")

import io
import tempfile
from pathlib import Path

import numpy as np
import torch
import torchaudio
from streamlit_mic_recorder import mic_recorder

# Audio / spectrogram configuration — must match the training pipeline exactly.
SAMPLE_RATE = 22050    # Hz; all audio is resampled to this rate before analysis
TARGET_LENGTH = 17640  # samples per model input window (~0.8 s at 22050 Hz)
N_FFT = 1024           # FFT size for the mel spectrogram
HOP_LENGTH = 256       # hop between FFT frames, in samples
N_MELS = 64            # number of mel frequency bins
F_MAX = 8000           # Hz; upper bound of the mel filterbank

# Shared mel-spectrogram transform, built once at import time and reused
# for every window (power=2.0 -> power spectrogram).
mel_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=SAMPLE_RATE,
    n_fft=N_FFT,
    hop_length=HOP_LENGTH,
    n_mels=N_MELS,
    f_min=0,
    f_max=F_MAX,
    window_fn=torch.hann_window,
    power=2.0,
)

# model
class NeuralBeatboxCNN(torch.nn.Module):
    """Compact depthwise-separable CNN that classifies log-mel spectrograms
    into beatbox sound classes (default: 3).

    Input:  (batch, 1, n_mels, frames) log-mel spectrogram.
    Output: (batch, num_classes) raw logits.
    """

    def __init__(self, num_classes=3):
        super().__init__()
        nn = torch.nn

        def separable(cin, cout, stride=1):
            # Depthwise 3x3 conv followed by a pointwise 1x1 projection,
            # each with BatchNorm + ReLU6 (MobileNet-v1 style block).
            return nn.Sequential(
                nn.Conv2d(cin, cin, 3, stride=stride, padding=1,
                          groups=cin, bias=False),
                nn.BatchNorm2d(cin),
                nn.ReLU6(inplace=True),
                nn.Conv2d(cin, cout, 1, bias=False),
                nn.BatchNorm2d(cout),
                nn.ReLU6(inplace=True),
            )

        # Plain 3x3 stem, then five separable blocks widening 32 -> 256.
        stem = [
            nn.Conv2d(1, 32, 3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU6(inplace=True),
        ]
        widths = [(32, 64), (64, 128), (128, 128), (128, 256), (256, 256)]
        self.features = nn.Sequential(
            *stem, *(separable(cin, cout) for cin, cout in widths)
        )

        # Global average pool -> dropout -> linear head.
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return class logits for a batch of spectrograms."""
        return self.classifier(self.features(x))

@st.cache_resource
def load_model():
    """Load the beatbox classifier once per server process.

    Returns the model in eval mode on CPU. Halts the Streamlit run with an
    error message if the weights file is missing.
    """
    model = NeuralBeatboxCNN()
    state_path = Path("state_dict.pth")
    if not state_path.exists():
        st.error("Model weights missing! Add state_dict.pth to repo.")
        st.stop()
    # weights_only=True restricts unpickling to tensors/containers, closing
    # the arbitrary-code-execution hole of loading a full pickle from disk.
    model.load_state_dict(
        torch.load(state_path, map_location="cpu", weights_only=True)
    )
    model.eval()
    return model

model = load_model()

# preprocessing
def preprocess_waveform(waveform: torch.Tensor, orig_sr: int) -> torch.Tensor:
    """Turn a (channels, samples) waveform into a model-ready log-mel tensor.

    Resamples to SAMPLE_RATE if needed, mixes down to mono, center-pads or
    center-crops to exactly TARGET_LENGTH samples, then returns the log-mel
    spectrogram with a leading batch dimension: (1, 1, n_mels, frames).
    """
    if orig_sr != SAMPLE_RATE:
        resampler = torchaudio.transforms.Resample(orig_sr, SAMPLE_RATE)
        waveform = resampler(waveform)

    # Mono mix-down (channel average), keeping the channel dimension.
    mono = waveform.mean(dim=0, keepdim=True)

    length = mono.shape[1]
    if length < TARGET_LENGTH:
        # Split the deficit between both sides (extra sample goes right).
        deficit = TARGET_LENGTH - length
        left = deficit // 2
        mono = torch.nn.functional.pad(mono, (left, deficit - left))
    else:
        # Take the centered TARGET_LENGTH slice.
        offset = (length - TARGET_LENGTH) // 2
        mono = mono[:, offset:offset + TARGET_LENGTH]

    # Log compression; the epsilon guards against log(0) on silent frames.
    log_mel = torch.log(mel_transform(mono) + 1e-9)
    return log_mel.unsqueeze(0)

# drums
# Map model class index -> drum sample file used to render that hit.
# Files are expected next to the app; index order must match training labels.
drum_samples = {
    0: "kick.wav",
    1: "snare.wav",
    2: "hihat.wav"
}

# Human-readable name per class index, shown in the detected beat sequence.
label_names = {0: "Boom (Kick)", 1: "Kah (Snare)", 2: "Tss (Hi-hat)"}

# UI
st.title("πŸ₯ Neural Beatbox β€” Live Mic + Upload")
st.markdown("**Beatbox into your mic or upload a recording** β†’ get a clean drum version instantly!")

tab1, tab2 = st.tabs(["🎀 Live Mic Recording", "πŸ“ Upload File"])

with tab1:
    st.markdown("Click to start/stop recording (beatbox clearly!)")
    audio_info = mic_recorder(
        start_prompt="🎀 Start Beatboxing",
        stop_prompt="⏹️ Stop",
        format="wav",
        key="live_mic"
    )

    if audio_info:
        st.audio(audio_info['bytes'], format='audio/wav')
        # The recorder returns a complete WAV file (RIFF header + PCM data).
        # Decoding it with torchaudio honours the header — previously the
        # header bytes were parsed as int16 samples and the sample rate was
        # hard-coded to 44100 instead of read from the file.
        waveform, orig_sr = torchaudio.load(io.BytesIO(audio_info['bytes']))

with tab2:
    uploaded_file = st.file_uploader("Or upload WAV/MP3", type=["wav", "mp3", "ogg"])
    if uploaded_file:
        st.audio(uploaded_file)
        # Write the upload to a named temp file so torchaudio can decode it.
        # Close it before loading (reading a still-open NamedTemporaryFile
        # fails on Windows) and always remove it afterwards so repeated
        # Streamlit reruns don't leak temp files.
        suffix = Path(uploaded_file.name).suffix
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp.write(uploaded_file.getvalue())
            tmp_path = tmp.name
        try:
            waveform, orig_sr = torchaudio.load(tmp_path)
        finally:
            Path(tmp_path).unlink(missing_ok=True)

# Process if audio available
if 'waveform' in locals() and waveform is not None:
    # Sliding window + energy trigger.
    # Resample the whole recording up front: the window length
    # (TARGET_LENGTH) and the ms positions below assume SAMPLE_RATE, so
    # sliding over a 44.1 kHz waveform would halve the window duration and
    # double every reported position.
    if orig_sr != SAMPLE_RATE:
        waveform = torchaudio.transforms.Resample(orig_sr, SAMPLE_RATE)(waveform)
        orig_sr = SAMPLE_RATE

    window_samples = TARGET_LENGTH
    hop_samples = TARGET_LENGTH // 2
    energy_threshold = 0.02
    confidence_threshold = 0.6

    # A clip shorter than one window would make range() empty and silently
    # skip classification — pad it out to a single full window instead.
    if waveform.shape[1] < window_samples:
        waveform = torch.nn.functional.pad(
            waveform, (0, window_samples - waveform.shape[1]))

    predictions = []   # predicted class index per accepted window
    positions_ms = []  # window start time in milliseconds

    with torch.no_grad():
        for start in range(0, waveform.shape[1] - window_samples + 1, hop_samples):
            segment = waveform[:, start:start + window_samples]
            # Cheap silence gate before running the model.
            energy = torch.mean(segment ** 2).item()
            if energy > energy_threshold:
                spec = preprocess_waveform(segment, orig_sr)
                output = model(spec)
                pred = output.argmax(dim=1).item()
                conf = torch.softmax(output, dim=1).max().item()
                if conf > confidence_threshold:
                    predictions.append(pred)
                    positions_ms.append(int(start / SAMPLE_RATE * 1000))

    st.write(f"**Detected {len(predictions)} drum hits**")
    if predictions:
        seq_text = " β†’ ".join([label_names[p] for p in predictions])
        st.markdown(f"**Beat sequence:** {seq_text}")

        # Build the drum track by overlaying each sample at its absolute
        # detection time. The previous approach appended delta-length
        # silence but added every segment at offset 0, which stacked hits
        # near the start and raised a broadcast error whenever a later
        # segment was shorter than the mix built so far.
        combined = torch.zeros(1, 0)
        for pred, pos_ms in zip(predictions, positions_ms):
            drum_wave, drum_sr = torchaudio.load(drum_samples[pred])
            drum_wave = drum_wave.mean(dim=0, keepdim=True)
            if drum_sr != SAMPLE_RATE:
                drum_wave = torchaudio.transforms.Resample(drum_sr, SAMPLE_RATE)(drum_wave)

            offset = int(pos_ms / 1000 * SAMPLE_RATE)
            end = offset + drum_wave.shape[1]
            if combined.shape[1] < end:
                combined = torch.nn.functional.pad(
                    combined, (0, end - combined.shape[1]))
            combined[:, offset:end] += drum_wave

        # Overlapping hits can sum past full scale — clamp to avoid clipping.
        combined = combined.clamp(-1.0, 1.0)

        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as out:
            torchaudio.save(out.name, combined, SAMPLE_RATE)
            st.audio(out.name)
            st.balloons()
            st.success("Your beatbox β†’ studio drums! πŸŽ‰")
    else:
        st.info("No clear hits detected β€” try louder, sharper sounds!")