AKMESSI commited on
Commit
92ca38c
Β·
verified Β·
1 Parent(s): 7feb0e5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +189 -189
app.py CHANGED
@@ -1,189 +1,189 @@
1
- import streamlit as st
2
- st.set_page_config(page_title="Neural Beatbox Live", page_icon="πŸ₯")
3
-
4
- import torch
5
- import torchaudio
6
- import numpy as np
7
- from pathlib import Path
8
- import tempfile
9
- from streamlit_mic_recorder import mic_recorder
10
-
11
- SAMPLE_RATE = 22050
12
- TARGET_LENGTH = 17640
13
- N_FFT = 1024
14
- HOP_LENGTH = 256
15
- N_MELS = 64
16
- F_MAX = 8000
17
-
18
- mel_transform = torchaudio.transforms.MelSpectrogram(
19
- sample_rate=SAMPLE_RATE,
20
- n_fft=N_FFT,
21
- hop_length=HOP_LENGTH,
22
- n_mels=N_MELS,
23
- f_min=0,
24
- f_max=F_MAX,
25
- window_fn=torch.hann_window,
26
- power=2.0,
27
- )
28
-
29
- # model
30
- class NeuralBeatboxCNN(torch.nn.Module):
31
- def __init__(self, num_classes=3):
32
- super().__init__()
33
-
34
- def dw_block(in_ch, out_ch, stride=1):
35
- return torch.nn.Sequential(
36
- torch.nn.Conv2d(in_ch, in_ch, 3, stride=stride, padding=1, groups=in_ch, bias=False),
37
- torch.nn.BatchNorm2d(in_ch),
38
- torch.nn.ReLU6(inplace=True),
39
- torch.nn.Conv2d(in_ch, out_ch, 1, bias=False),
40
- torch.nn.BatchNorm2d(out_ch),
41
- torch.nn.ReLU6(inplace=True),
42
- )
43
-
44
- self.features = torch.nn.Sequential(
45
- torch.nn.Conv2d(1, 32, 3, stride=1, padding=1, bias=False),
46
- torch.nn.BatchNorm2d(32),
47
- torch.nn.ReLU6(inplace=True),
48
- dw_block(32, 64),
49
- dw_block(64, 128),
50
- dw_block(128, 128),
51
- dw_block(128, 256),
52
- dw_block(256, 256),
53
- )
54
-
55
- self.classifier = torch.nn.Sequential(
56
- torch.nn.AdaptiveAvgPool2d(1),
57
- torch.nn.Flatten(),
58
- torch.nn.Dropout(0.3),
59
- torch.nn.Linear(256, num_classes)
60
- )
61
-
62
- def forward(self, x):
63
- x = self.features(x)
64
- x = self.classifier(x)
65
- return x
66
-
67
- @st.cache_resource
68
- def load_model():
69
- model = NeuralBeatboxCNN()
70
- state_path = Path("state_dict.pth")
71
- if not state_path.exists():
72
- st.error("Model weights missing! Add state_dict.pth to repo.")
73
- st.stop()
74
- model.load_state_dict(torch.load(state_path, map_location="cpu"))
75
- model.eval()
76
- return model
77
-
78
- model = load_model()
79
-
80
- # preprocessing
81
- def preprocess_waveform(waveform: torch.Tensor, orig_sr: int) -> torch.Tensor:
82
- if orig_sr != SAMPLE_RATE:
83
- waveform = torchaudio.transforms.Resample(orig_sr, SAMPLE_RATE)(waveform)
84
- waveform = waveform.mean(dim=0, keepdim=True)
85
- if waveform.shape[1] < TARGET_LENGTH:
86
- pad = TARGET_LENGTH - waveform.shape[1]
87
- waveform = torch.nn.functional.pad(waveform, (pad//2, pad - pad//2))
88
- else:
89
- start = (waveform.shape[1] - TARGET_LENGTH) // 2
90
- waveform = waveform[:, start:start + TARGET_LENGTH]
91
- mel = mel_transform(waveform)
92
- log_mel = torch.log(mel + 1e-9)
93
- return log_mel.unsqueeze(0)
94
-
95
- # drums
96
- drum_samples = {
97
- 0: "drum_samples/kick.wav",
98
- 1: "drum_samples/snare.wav",
99
- 2: "drum_samples/hihat.wav"
100
- }
101
-
102
- label_names = {0: "Boom (Kick)", 1: "Kah (Snare)", 2: "Tss (Hi-hat)"}
103
-
104
- # UI
105
- st.title("πŸ₯ Neural Beatbox β€” Live Mic + Upload")
106
- st.markdown("**Beatbox into your mic or upload a recording** β†’ get a clean drum version instantly!")
107
-
108
- tab1, tab2 = st.tabs(["🎀 Live Mic Recording", "πŸ“ Upload File"])
109
-
110
- with tab1:
111
- st.markdown("Click to start/stop recording (beatbox clearly!)")
112
- audio_info = mic_recorder(
113
- start_prompt="🎀 Start Beatboxing",
114
- stop_prompt="⏹️ Stop",
115
- format="wav",
116
- key="live_mic"
117
- )
118
-
119
- if audio_info:
120
- st.audio(audio_info['bytes'], format='audio/wav')
121
- waveform = torch.from_numpy(np.frombuffer(audio_info['bytes'], np.int16)).float() / 32768
122
- waveform = waveform.unsqueeze(0) # (1, samples)
123
- orig_sr = 44100 # typical browser recording rate
124
-
125
- with tab2:
126
- uploaded_file = st.file_uploader("Or upload WAV/MP3", type=["wav", "mp3", "ogg"])
127
- if uploaded_file:
128
- st.audio(uploaded_file)
129
- with tempfile.NamedTemporaryFile(delete=False, suffix=Path(uploaded_file.name).suffix) as tmp:
130
- tmp.write(uploaded_file.getvalue())
131
- waveform, orig_sr = torchaudio.load(tmp.name)
132
-
133
- # Process if audio available
134
- if 'waveform' in locals() and waveform is not None:
135
- # Sliding window + energy trigger
136
- window_samples = TARGET_LENGTH
137
- hop_samples = TARGET_LENGTH // 2
138
- energy_threshold = 0.02
139
- confidence_threshold = 0.6
140
-
141
- predictions = []
142
- positions_ms = []
143
-
144
- with torch.no_grad():
145
- for start in range(0, waveform.shape[1] - window_samples + 1, hop_samples):
146
- segment = waveform[:, start:start + window_samples]
147
- energy = torch.mean(segment ** 2).item()
148
- if energy > energy_threshold:
149
- spec = preprocess_waveform(segment, orig_sr)
150
- output = model(spec)
151
- pred = output.argmax(dim=1).item()
152
- conf = torch.softmax(output, dim=1).max().item()
153
- if conf > confidence_threshold:
154
- predictions.append(pred)
155
- positions_ms.append(int(start / SAMPLE_RATE * 1000))
156
-
157
- st.write(f"**Detected {len(predictions)} drum hits**")
158
- if predictions:
159
- seq_text = " β†’ ".join([label_names[p] for p in predictions])
160
- st.markdown(f"**Beat sequence:** {seq_text}")
161
-
162
- # Build drum beat
163
- combined = torch.zeros(1, 0)
164
- prev_pos = 0
165
- for i, pred in enumerate(predictions):
166
- drum_wave, drum_sr = torchaudio.load(drum_samples[pred])
167
- drum_wave = drum_wave.mean(dim=0, keepdim=True)
168
- if drum_sr != SAMPLE_RATE:
169
- drum_wave = torchaudio.transforms.Resample(drum_sr, SAMPLE_RATE)(drum_wave)
170
-
171
- silence_ms = positions_ms[i] - prev_pos if i > 0 else positions_ms[i]
172
- silence_samples = int(silence_ms / 1000 * SAMPLE_RATE)
173
- silence = torch.zeros(1, silence_samples)
174
-
175
- drum_segment = torch.cat([silence, drum_wave], dim=1)
176
- if combined.shape[1] < drum_segment.shape[1]:
177
- pad = drum_segment.shape[1] - combined.shape[1]
178
- combined = torch.nn.functional.pad(combined, (0, pad))
179
- combined += drum_segment[:, :combined.shape[1]]
180
-
181
- prev_pos = positions_ms[i]
182
-
183
- with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as out:
184
- torchaudio.save(out.name, combined, SAMPLE_RATE)
185
- st.audio(out.name)
186
- st.balloons()
187
- st.success("Your beatbox β†’ studio drums! πŸŽ‰")
188
- else:
189
- st.info("No clear hits detected β€” try louder, sharper sounds!")
 
import streamlit as st
# NOTE: per Streamlit's docs, set_page_config must be the first Streamlit
# command executed in the script, which is why it precedes the other imports.
st.set_page_config(page_title="Neural Beatbox Live", page_icon="🥁")

import torch
import torchaudio
import numpy as np
from pathlib import Path
import tempfile
from streamlit_mic_recorder import mic_recorder

# Audio/feature constants shared by preprocessing and inference.
SAMPLE_RATE = 22050    # model's working sample rate (Hz)
TARGET_LENGTH = 17640  # samples per analysis window (0.8 s at 22050 Hz)
N_FFT = 1024           # FFT size for the spectrogram
HOP_LENGTH = 256       # STFT hop in samples
N_MELS = 64            # number of mel bands
F_MAX = 8000           # upper mel-filterbank frequency (Hz)

# Single shared mel-spectrogram transform (power spectrogram, Hann window),
# reused for every analysis window in preprocess_waveform().
mel_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=SAMPLE_RATE,
    n_fft=N_FFT,
    hop_length=HOP_LENGTH,
    n_mels=N_MELS,
    f_min=0,
    f_max=F_MAX,
    window_fn=torch.hann_window,
    power=2.0,
)
# model
class NeuralBeatboxCNN(torch.nn.Module):
    """Small depthwise-separable CNN over log-mel spectrograms.

    Input:  (batch, 1, N_MELS, time) log-mel tensor.
    Output: (batch, num_classes) unnormalized class scores.
    """

    def __init__(self, num_classes=3):
        super().__init__()

        def separable(cin, cout, stride=1):
            # Depthwise 3x3 conv followed by pointwise 1x1 conv,
            # each with BatchNorm + ReLU6 (MobileNet-style building block).
            return torch.nn.Sequential(
                torch.nn.Conv2d(cin, cin, 3, stride=stride, padding=1, groups=cin, bias=False),
                torch.nn.BatchNorm2d(cin),
                torch.nn.ReLU6(inplace=True),
                torch.nn.Conv2d(cin, cout, 1, bias=False),
                torch.nn.BatchNorm2d(cout),
                torch.nn.ReLU6(inplace=True),
            )

        # Stem conv + five separable stages; same module tree as the
        # original Sequential, so saved state_dicts remain loadable.
        stem = [
            torch.nn.Conv2d(1, 32, 3, stride=1, padding=1, bias=False),
            torch.nn.BatchNorm2d(32),
            torch.nn.ReLU6(inplace=True),
        ]
        widths = [(32, 64), (64, 128), (128, 128), (128, 256), (256, 256)]
        self.features = torch.nn.Sequential(*stem, *(separable(a, b) for a, b in widths))

        # Global average pool -> dropout -> linear head.
        self.classifier = torch.nn.Sequential(
            torch.nn.AdaptiveAvgPool2d(1),
            torch.nn.Flatten(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(256, num_classes),
        )

    def forward(self, x):
        return self.classifier(self.features(x))
66
+
# Load the classifier once per Streamlit session; cache_resource keeps the
# same model object alive across script reruns.
@st.cache_resource
def load_model():
    """Build NeuralBeatboxCNN, load weights from state_dict.pth, return it in eval mode.

    Halts the Streamlit app with an error message if the weights file is absent.
    """
    model = NeuralBeatboxCNN()
    state_path = Path("state_dict.pth")
    if not state_path.exists():
        st.error("Model weights missing! Add state_dict.pth to repo.")
        st.stop()
    # weights_only=True restricts the unpickler to tensors/containers,
    # closing the arbitrary-code-execution hole of a plain torch.load.
    model.load_state_dict(torch.load(state_path, map_location="cpu", weights_only=True))
    model.eval()
    return model

model = load_model()
79
+
# preprocessing
def preprocess_waveform(waveform: torch.Tensor, orig_sr: int) -> torch.Tensor:
    """Convert a (channels, samples) waveform into a (1, 1, N_MELS, T) log-mel tensor.

    Resamples to SAMPLE_RATE if needed, mixes down to mono, then center-pads
    or center-crops to exactly TARGET_LENGTH samples before taking the
    log mel spectrogram.
    """
    if orig_sr != SAMPLE_RATE:
        resampler = torchaudio.transforms.Resample(orig_sr, SAMPLE_RATE)
        waveform = resampler(waveform)
    waveform = waveform.mean(dim=0, keepdim=True)  # mono mixdown

    n_samples = waveform.shape[1]
    if n_samples < TARGET_LENGTH:
        deficit = TARGET_LENGTH - n_samples
        left = deficit // 2
        waveform = torch.nn.functional.pad(waveform, (left, deficit - left))
    elif n_samples > TARGET_LENGTH:
        offset = (n_samples - TARGET_LENGTH) // 2
        waveform = waveform[:, offset:offset + TARGET_LENGTH]

    # Small epsilon keeps log() finite on silent frames.
    log_mel = torch.log(mel_transform(waveform) + 1e-9)
    return log_mel.unsqueeze(0)
94
+
# drums
# Class index -> drum sample file rendered into the output beat,
# and the matching human-readable label shown in the UI.
drum_samples = dict(enumerate(("kick.wav", "snare.wav", "hihat.wav")))

label_names = dict(enumerate(("Boom (Kick)", "Kah (Snare)", "Tss (Hi-hat)")))
103
+
# UI
st.title("🥁 Neural Beatbox — Live Mic + Upload")
st.markdown("**Beatbox into your mic or upload a recording** → get a clean drum version instantly!")

# Two input paths: record live in the browser, or upload an audio file.
tab1, tab2 = st.tabs(["🎤 Live Mic Recording", "📁 Upload File"])
109
+
with tab1:
    st.markdown("Click to start/stop recording (beatbox clearly!)")
    audio_info = mic_recorder(
        start_prompt="🎤 Start Beatboxing",
        stop_prompt="⏹️ Stop",
        format="wav",
        key="live_mic"
    )

    if audio_info:
        st.audio(audio_info['bytes'], format='audio/wav')
        # Parse the WAV container instead of reinterpreting the raw bytes:
        # np.frombuffer over the whole payload also decoded the RIFF header
        # as bogus samples, and the sample rate was hard-coded to 44100
        # rather than read from the recording itself.
        import io
        import wave
        with wave.open(io.BytesIO(audio_info['bytes']), 'rb') as wav_file:
            orig_sr = wav_file.getframerate()
            n_channels = wav_file.getnchannels()
            frames = wav_file.readframes(wav_file.getnframes())
        # assumes mic_recorder(format="wav") emits 16-bit PCM — TODO confirm
        pcm = np.frombuffer(frames, dtype=np.int16).astype(np.float32) / 32768.0
        if n_channels > 1:
            pcm = pcm.reshape(-1, n_channels).mean(axis=1)  # mono mixdown
        waveform = torch.from_numpy(pcm).unsqueeze(0)  # (1, samples)
124
+
with tab2:
    uploaded_file = st.file_uploader("Or upload WAV/MP3", type=["wav", "mp3", "ogg"])
    if uploaded_file:
        st.audio(uploaded_file)
        import os
        tmp_path = None
        try:
            # torchaudio.load needs a real path, so stage the upload in a
            # closed temp file (closing first also keeps this Windows-safe).
            with tempfile.NamedTemporaryFile(delete=False, suffix=Path(uploaded_file.name).suffix) as tmp:
                tmp.write(uploaded_file.getvalue())
                tmp_path = tmp.name
            waveform, orig_sr = torchaudio.load(tmp_path)
        finally:
            # The original never removed the delete=False temp file,
            # leaking one file per upload.
            if tmp_path is not None and os.path.exists(tmp_path):
                os.unlink(tmp_path)
132
+
# Process if audio is available from either tab
if 'waveform' in locals() and waveform is not None:
    # Resample once up front so window sizes and timestamps are all
    # expressed at SAMPLE_RATE. The original sliced windows at orig_sr but
    # converted hit positions with SAMPLE_RATE, mis-timing the beat
    # whenever the source rate differed from 22050 Hz.
    if orig_sr != SAMPLE_RATE:
        waveform = torchaudio.transforms.Resample(orig_sr, SAMPLE_RATE)(waveform)
        orig_sr = SAMPLE_RATE

    # Sliding window + energy trigger
    window_samples = TARGET_LENGTH
    hop_samples = TARGET_LENGTH // 2
    energy_threshold = 0.02        # mean-squared-amplitude gate for "a hit happened"
    confidence_threshold = 0.6     # softmax gate on the winning class

    predictions = []   # predicted class index per detected hit
    positions_ms = []  # hit onset (window start) in milliseconds

    with torch.no_grad():
        for start in range(0, waveform.shape[1] - window_samples + 1, hop_samples):
            segment = waveform[:, start:start + window_samples]
            energy = torch.mean(segment ** 2).item()
            if energy > energy_threshold:
                output = model(preprocess_waveform(segment, orig_sr))
                conf = torch.softmax(output, dim=1).max().item()
                if conf > confidence_threshold:
                    predictions.append(output.argmax(dim=1).item())
                    positions_ms.append(int(start / SAMPLE_RATE * 1000))

    st.write(f"**Detected {len(predictions)} drum hits**")
    if predictions:
        seq_text = " → ".join([label_names[p] for p in predictions])
        st.markdown(f"**Beat sequence:** {seq_text}")

        # Load each referenced drum sample once (mono, at SAMPLE_RATE)
        # instead of re-reading the same file on every hit.
        drum_waves = {}
        for cls in set(predictions):
            hit_wave, drum_sr = torchaudio.load(drum_samples[cls])
            hit_wave = hit_wave.mean(dim=0, keepdim=True)
            if drum_sr != SAMPLE_RATE:
                hit_wave = torchaudio.transforms.Resample(drum_sr, SAMPLE_RATE)(hit_wave)
            drum_waves[cls] = hit_wave

        # Overlay every hit at its absolute onset time. The original
        # prepended only inter-hit silence and summed each segment from
        # sample 0, which placed hits at delta offsets instead of their
        # real positions and raised a shape-mismatch error once the mix
        # grew longer than the next segment.
        combined = torch.zeros(1, 0)
        for pos_ms, pred in zip(positions_ms, predictions):
            drum_wave = drum_waves[pred]
            offset = int(pos_ms / 1000 * SAMPLE_RATE)
            end = offset + drum_wave.shape[1]
            if combined.shape[1] < end:
                combined = torch.nn.functional.pad(combined, (0, end - combined.shape[1]))
            combined[:, offset:end] += drum_wave

        # Keep the mix in valid PCM range after overlapping hits.
        combined = combined.clamp(-1.0, 1.0)

        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as out:
            torchaudio.save(out.name, combined, SAMPLE_RATE)
        st.audio(out.name)
        st.balloons()
        st.success("Your beatbox → studio drums! 🎉")
    else:
        st.info("No clear hits detected — try louder, sharper sounds!")