intrect committed on
Commit
27ba96f
·
verified ·
1 Parent(s): 85b8db6

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +228 -0
app.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+
4
+ import gradio as gr
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ from huggingface_hub import hf_hub_download
9
+
10
+
11
# ---------------------------------------------------------------------------
# Model definition (must match training code)
# ---------------------------------------------------------------------------
class NNStompGRU(nn.Module):
    """Conditioned GRU waveshaper.

    Maps a mono sample stream plus a static control vector to a bounded
    (tanh) output stream. The architecture must stay in sync with the
    training code, and the submodule names (``gru``, ``dense``) must match
    the checkpoint's state_dict keys.
    """

    def __init__(self, cond_dim: int, hidden_size: int = 40):
        super().__init__()
        self.cond_dim = cond_dim
        self.hidden_size = hidden_size
        # Each time step sees one audio sample plus the conditioning vector.
        self.gru = nn.GRU(
            input_size=1 + cond_dim,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=True,
        )
        self.dense = nn.Linear(hidden_size, 1)
        self.tanh = nn.Tanh()

    def forward(self, x, cond, hidden=None):
        """Run the GRU over a (batch, seq, 1) waveform chunk.

        ``cond`` is (batch, cond_dim) and is broadcast along the time axis;
        ``hidden`` carries GRU state between chunks for streaming inference.
        Returns ``(output, new_hidden)`` with output squashed into (-1, 1).
        """
        steps = x.size(1)
        # Repeat the static conditioning vector at every time step.
        tiled_cond = cond.unsqueeze(1).expand(-1, steps, -1)
        features, next_hidden = self.gru(torch.cat((x, tiled_cond), dim=-1), hidden)
        return self.tanh(self.dense(features)), next_hidden
35
+
36
+
37
# ---------------------------------------------------------------------------
# Model registry
# ---------------------------------------------------------------------------
# Each entry maps a UI display name to:
#   repo_file - checkpoint path inside MODEL_REPO
#   cond_dim  - size of the conditioning vector the network expects
#   controls  - UI knobs; "idx" is the slot in the condition vector, and
#               min/max/default are raw slider values (normalized to 0-1
#               by process_audio before inference)
MODELS = {
    "Blackstar (Drive A/B)": {
        "repo_file": "blackstar/best_model.pt",
        "cond_dim": 2,
        "controls": {
            "Drive A": {"idx": 0, "min": 0, "max": 100, "default": 50},
            "Drive B": {"idx": 1, "min": 0, "max": 100, "default": 0},
        },
    },
}

# Hugging Face Hub repository holding the model checkpoints.
MODEL_REPO = "intrect/nnstomps-models"
# Loaded models keyed by display name, so each checkpoint is fetched and
# deserialized at most once per process.
_model_cache: dict[str, NNStompGRU] = {}
53
+
54
+
55
def load_model(name: str) -> NNStompGRU | None:
    """Fetch, build, and memoize the checkpoint for *name*.

    Returns None for names not present in MODELS. The first request for a
    model downloads its checkpoint from MODEL_REPO (optionally authenticated
    via the HF_TOKEN env var); later calls hit the in-process cache.
    """
    cached = _model_cache.get(name)
    if cached is not None:
        return cached

    cfg = MODELS.get(name)
    if cfg is None:
        return None

    local_path = hf_hub_download(
        repo_id=MODEL_REPO,
        filename=cfg["repo_file"],
        token=os.environ.get("HF_TOKEN"),
    )

    # weights_only=True restricts torch.load to tensors/containers, avoiding
    # arbitrary pickle execution on the downloaded file.
    checkpoint = torch.load(local_path, map_location="cpu", weights_only=True)
    net = NNStompGRU(
        checkpoint["config"]["cond_dim"],
        checkpoint["config"]["hidden_size"],
    )
    net.load_state_dict(checkpoint["model_state"])
    net.eval()

    _model_cache[name] = net
    return net
75
+
76
+
77
+ # ---------------------------------------------------------------------------
78
+ # Audio processing
79
+ # ---------------------------------------------------------------------------
80
# ---------------------------------------------------------------------------
# Audio processing
# ---------------------------------------------------------------------------
def process_audio(
    audio_input,
    model_name: str,
    param1: float,
    param2: float,
    mix: float,
    input_gain_db: float,
):
    """Run the selected neural drive model over an uploaded clip.

    Parameters
    ----------
    audio_input : tuple[int, np.ndarray] | None
        Gradio "numpy" audio as (sample_rate, samples); samples may be
        int16, int32, or float, mono or stereo.
    model_name : str
        Key into MODELS.
    param1, param2 : float
        Raw slider values for the model's first and second control.
    mix : float
        Dry/wet blend in [0, 1]; 1.0 is fully processed.
    input_gain_db : float
        Gain applied before the model, in decibels.

    Returns
    -------
    tuple[int, np.ndarray] | None
        (sample_rate, float32 samples), or None when no audio was provided.
    """
    if audio_input is None:
        return None

    sr, data = audio_input

    # Normalize integer PCM to float32 in [-1, 1).
    if data.dtype == np.int16:
        data = data.astype(np.float32) / 32768.0
    elif data.dtype == np.int32:
        data = data.astype(np.float32) / 2147483648.0
    elif data.dtype != np.float32:
        data = data.astype(np.float32)

    # Stereo -> mono. Gradio delivers (samples, channels); if the second
    # axis is wider than 2, the layout is assumed channels-first.
    if data.ndim == 2:
        mono = data.mean(axis=1) if data.shape[1] <= 2 else data.mean(axis=0)
    else:
        mono = data

    # FIX: an empty clip previously crashed with ValueError at
    # np.max(np.abs(...)) during peak limiting (after needlessly loading
    # the model). Return the empty signal unchanged instead.
    if mono.size == 0:
        return (sr, mono.astype(np.float32))

    # Pre-gain in dB.
    mono = mono * (10 ** (input_gain_db / 20.0))

    model = load_model(model_name)
    if model is None:
        # Unknown model: pass the (gained) dry signal through.
        return (sr, mono)

    cfg = MODELS[model_name]
    controls = cfg["controls"]

    # Build the condition vector: sliders normalized to [0, 1] and placed
    # at each control's declared index.
    cond = [0.0] * cfg["cond_dim"]
    ctrl_list = list(controls.values())

    if len(ctrl_list) >= 1:
        c = ctrl_list[0]
        cond[c["idx"]] = (param1 - c["min"]) / (c["max"] - c["min"])
    if len(ctrl_list) >= 2:
        c = ctrl_list[1]
        cond[c["idx"]] = (param2 - c["min"]) / (c["max"] - c["min"])

    # Chunked GRU inference; the hidden state is carried across chunks so
    # the result matches a single continuous pass over the signal.
    chunk_size = 8192
    output = np.zeros_like(mono)
    hidden = None

    with torch.no_grad():
        cond_t = torch.tensor([cond], dtype=torch.float32)
        for start in range(0, len(mono), chunk_size):
            end = min(start + chunk_size, len(mono))
            chunk = mono[start:end]
            x = torch.from_numpy(chunk).unsqueeze(0).unsqueeze(-1)
            pred, hidden = model(x, cond_t, hidden)
            output[start:end] = pred[0, :, 0].numpy()

    # Dry/wet blend.
    mixed = mono * (1 - mix) + output * mix

    # Simple peak limiter to avoid clipping on export.
    peak = np.max(np.abs(mixed))
    if peak > 0.99:
        mixed = mixed * (0.99 / peak)

    return (sr, mixed.astype(np.float32))
151
+
152
+
153
def update_controls(model_name: str):
    """Retarget the two parameter sliders for the chosen model.

    Returns a pair of gr.update objects: a slider backing one of the
    model's controls gets that control's label, range, and default and is
    shown; a slider with no control behind it is hidden.
    """
    controls = MODELS.get(model_name, {}).get("controls", {})
    entries = list(controls.items())

    if entries:
        label, spec = entries[0]
        first = gr.update(
            label=label,
            minimum=spec["min"],
            maximum=spec["max"],
            value=spec["default"],
            visible=True,
        )
    else:
        first = gr.update(visible=False)

    if len(entries) > 1:
        label, spec = entries[1]
        second = gr.update(
            label=label,
            minimum=spec["min"],
            maximum=spec["max"],
            value=spec["default"],
            visible=True,
        )
    else:
        # Reset the hidden second slider so a stale value is not reused.
        second = gr.update(visible=False, value=0)

    return first, second
177
+
178
+
179
# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
with gr.Blocks(
    title="NNStomps — Neural Drive",
    theme=gr.themes.Soft(primary_hue="orange"),
) as demo:
    gr.Markdown(
        "# NNStomps — Neural Drive\n"
        "GRU neural network based saturation/distortion. "
        "Upload audio and tweak the drive to hear the neural model in action."
    )

    with gr.Row():
        # Left column: model choice and processing controls.
        with gr.Column(scale=1):
            model_sel = gr.Dropdown(
                choices=list(MODELS.keys()),
                value=list(MODELS.keys())[0],
                label="Model",
            )
            # Generic parameter sliders; update_controls relabels and
            # rescales them when the selected model changes.
            param1 = gr.Slider(
                minimum=0, maximum=100, value=50, step=1, label="Drive A",
            )
            param2 = gr.Slider(
                minimum=0, maximum=100, value=0, step=1, label="Drive B",
            )
            input_gain = gr.Slider(
                minimum=-12, maximum=12, value=0, step=0.5,
                label="Input Gain (dB)",
            )
            mix_slider = gr.Slider(
                minimum=0, maximum=1.0, value=1.0, step=0.05,
                label="Dry/Wet Mix",
            )
            process_btn = gr.Button("Process", variant="primary", size="lg")

        # Right column: audio in/out widgets (numpy = (sr, ndarray) tuples).
        with gr.Column(scale=2):
            audio_in = gr.Audio(label="Input Audio", type="numpy")
            audio_out = gr.Audio(label="Output Audio", type="numpy")

    # Re-target the sliders whenever a different model is picked.
    model_sel.change(
        fn=update_controls, inputs=[model_sel], outputs=[param1, param2],
    )
    # Run inference on click; slider values are passed through raw and
    # normalized inside process_audio.
    process_btn.click(
        fn=process_audio,
        inputs=[audio_in, model_sel, param1, param2, mix_slider, input_gain],
        outputs=[audio_out],
    )

demo.launch()