| import os |
| from pathlib import Path |
|
|
| import gradio as gr |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from huggingface_hub import hf_hub_download |
|
|
|
|
| |
| |
| |
class NNStompGRU(nn.Module):
    """Single-layer GRU that emulates a drive/saturation stompbox.

    Every timestep of the (batch, seq, 1) sample stream is concatenated
    with the per-item condition vector (normalized knob settings), fed
    through the GRU, then projected to one channel and tanh-squashed
    into [-1, 1].
    """

    def __init__(self, cond_dim: int, hidden_size: int = 40):
        super().__init__()
        self.cond_dim = cond_dim
        self.hidden_size = hidden_size
        self.gru = nn.GRU(
            input_size=1 + cond_dim,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=True,
        )
        self.dense = nn.Linear(hidden_size, 1)
        self.tanh = nn.Tanh()

    def forward(self, x, cond, hidden=None):
        """Run one (batch, seq, 1) block; returns (output, new hidden state).

        Pass the returned hidden state back in to process audio in
        consecutive chunks without resetting the recurrent state.
        """
        seq_len = x.size(1)
        # Tile the condition vector along time so it accompanies each sample.
        tiled = cond.unsqueeze(1).expand(-1, seq_len, -1)
        features, new_hidden = self.gru(torch.cat((x, tiled), dim=-1), hidden)
        return self.tanh(self.dense(features)), new_hidden
|
|
|
|
| |
| |
| |
# Registry of selectable neural models. Each entry maps a UI display name to:
#   repo_file - checkpoint path inside the Hugging Face model repo
#   cond_dim  - length of the conditioning vector the network expects
#   controls  - slider label -> {idx: slot in the condition vector,
#               min/max: raw slider range (normalized to [0, 1] before
#               inference), default: initial slider position}
MODELS = {
    "Blackstar (Drive A/B)": {
        "repo_file": "blackstar/best_model.pt",
        "cond_dim": 2,
        "controls": {
            "Drive A": {"idx": 0, "min": 0, "max": 100, "default": 50},
            "Drive B": {"idx": 1, "min": 0, "max": 100, "default": 0},
        },
    },
}

# Hugging Face Hub repository that hosts all checkpoints listed above.
MODEL_REPO = "intrect/nnstomps-models"
# Lazy cache of instantiated models, keyed by display name (see load_model).
_model_cache: dict[str, NNStompGRU] = {}
|
|
|
|
def load_model(name: str) -> NNStompGRU | None:
    """Return the model registered under *name*, downloading it on first use.

    Instantiated models are memoized in ``_model_cache`` so each checkpoint
    is fetched and built at most once per process. Returns None when the
    name is not present in ``MODELS``.
    """
    cached = _model_cache.get(name)
    if cached is not None:
        return cached

    cfg = MODELS.get(name)
    if cfg is None:
        return None

    # Pull the checkpoint from the (possibly private) Hub repo.
    checkpoint_path = hf_hub_download(
        repo_id=MODEL_REPO,
        filename=cfg["repo_file"],
        token=os.environ.get("HF_TOKEN"),
    )

    # weights_only=True keeps torch.load from unpickling arbitrary objects.
    state = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    net = NNStompGRU(state["config"]["cond_dim"], state["config"]["hidden_size"])
    net.load_state_dict(state["model_state"])
    net.eval()

    _model_cache[name] = net
    return net
|
|
|
|
| |
| |
| |
def process_audio(
    audio_input,
    model_name: str,
    param1: float,
    param2: float,
    mix: float,
    input_gain_db: float,
):
    """Apply the selected neural drive model to uploaded audio.

    Args:
        audio_input: Gradio numpy audio — a ``(sample_rate, data)`` tuple,
            or None when nothing has been uploaded.
        model_name: Key into ``MODELS`` selecting the checkpoint to run.
        param1: Raw slider value for the model's first control.
        param2: Raw slider value for the model's second control.
        mix: Dry/wet blend in [0, 1]; 1.0 is fully processed.
        input_gain_db: Gain applied before the model, in dB.

    Returns:
        ``(sample_rate, mono float32 array)``; None for missing input.
        Falls back to the gain-adjusted dry signal when the model name is
        unknown or the audio buffer is empty.
    """
    if audio_input is None:
        return None

    sr, data = audio_input
    mono = _to_float_mono(data)

    # dB -> linear gain, applied pre-model so it drives the saturation harder.
    mono = mono * (10.0 ** (input_gain_db / 20.0))

    model = load_model(model_name)
    # Unknown model or nothing to process: return the dry signal untouched.
    # (The empty-buffer guard also protects np.max below, which raises on
    # an empty array.)
    if model is None or mono.size == 0:
        return (sr, mono)

    cfg = MODELS[model_name]
    cond = _build_cond(cfg, param1, param2)
    processed = _run_model(model, mono, cond)

    # Dry/wet blend.
    blended = mono * (1.0 - mix) + processed * mix

    # Peak-normalize only when the result would exceed ~full scale.
    peak = np.max(np.abs(blended))
    if peak > 0.99:
        blended = blended * (0.99 / peak)

    return (sr, blended.astype(np.float32))


def _to_float_mono(data):
    """Convert a Gradio numpy audio buffer to mono float32 in [-1, 1]."""
    if data.dtype == np.int16:
        data = data.astype(np.float32) / 32768.0
    elif data.dtype == np.int32:
        data = data.astype(np.float32) / 2147483648.0
    elif data.dtype != np.float32:
        data = data.astype(np.float32)

    if data.ndim == 2:
        # Downmix by averaging across the channel axis. Gradio delivers
        # (samples, channels); treat the smaller dimension as channels so
        # multichannel (>2) files are also downmixed correctly instead of
        # being averaged along time.
        data = data.mean(axis=1 if data.shape[0] >= data.shape[1] else 0)
    return data


def _build_cond(cfg, param1, param2):
    """Normalize raw slider values into the model's [0, 1] condition vector."""
    cond = [0.0] * cfg["cond_dim"]
    # zip() pairs each provided slider with a declared control and stops at
    # the shorter sequence, so models with <2 controls ignore extra params.
    for value, ctrl in zip((param1, param2), cfg["controls"].values()):
        span = ctrl["max"] - ctrl["min"]
        # Guard degenerate ranges so a misconfigured control can't divide by 0.
        cond[ctrl["idx"]] = (value - ctrl["min"]) / span if span else 0.0
    return cond


def _run_model(model, mono, cond):
    """Stream ``mono`` through ``model`` in fixed-size chunks.

    The GRU hidden state is carried across chunk boundaries, so chunking is
    purely a memory optimization and does not change the output.
    """
    chunk_size = 8192
    output = np.zeros_like(mono)
    hidden = None

    with torch.no_grad():
        cond_t = torch.tensor([cond], dtype=torch.float32)
        for start in range(0, len(mono), chunk_size):
            end = min(start + chunk_size, len(mono))
            x = torch.from_numpy(mono[start:end]).unsqueeze(0).unsqueeze(-1)
            pred, hidden = model(x, cond_t, hidden)
            output[start:end] = pred[0, :, 0].numpy()
    return output
|
|
|
|
def update_controls(model_name: str):
    """Rebind the two parameter sliders to the selected model's controls.

    Returns a pair of gr.update payloads: each slider present in the model
    config gets its label, range, and default restored and is made visible;
    missing sliders are hidden (the second one is also zeroed so a stale
    value can't leak into the condition vector).
    """
    entries = list(MODELS.get(model_name, {}).get("controls", {}).items())

    updates = []
    for slot in range(2):
        if slot < len(entries):
            label, spec = entries[slot]
            updates.append(gr.update(
                label=label, minimum=spec["min"], maximum=spec["max"],
                value=spec["default"], visible=True,
            ))
        elif slot == 1:
            updates.append(gr.update(visible=False, value=0))
        else:
            updates.append(gr.update(visible=False))

    return updates[0], updates[1]
|
|
|
|
| |
| |
| |
# ---------------------------------------------------------------------------
# Gradio UI: model selector and effect parameters on the left, audio
# upload/playback on the right. Slider labels and ranges are re-bound by
# update_controls() whenever the model selection changes.
# ---------------------------------------------------------------------------
with gr.Blocks(
    title="NNStomps — Neural Drive",
    theme=gr.themes.Soft(primary_hue="orange"),
) as demo:
    gr.Markdown(
        "# NNStomps — Neural Drive\n"
        "GRU neural network based saturation/distortion. "
        "Upload audio and tweak the drive to hear the neural model in action."
    )

    with gr.Row():
        # Left column: model choice + controls.
        with gr.Column(scale=1):
            model_sel = gr.Dropdown(
                choices=list(MODELS.keys()),
                value=list(MODELS.keys())[0],
                label="Model",
            )
            # Generic parameter sliders; initial label/range match the
            # default model's controls, and update_controls() reconfigures
            # (or hides) them for other models.
            param1 = gr.Slider(
                minimum=0, maximum=100, value=50, step=1, label="Drive A",
            )
            param2 = gr.Slider(
                minimum=0, maximum=100, value=0, step=1, label="Drive B",
            )
            input_gain = gr.Slider(
                minimum=-12, maximum=12, value=0, step=0.5,
                label="Input Gain (dB)",
            )
            mix_slider = gr.Slider(
                minimum=0, maximum=1.0, value=1.0, step=0.05,
                label="Dry/Wet Mix",
            )
            process_btn = gr.Button("Process", variant="primary", size="lg")

        # Right column: audio input/output widgets (numpy in-memory format).
        with gr.Column(scale=2):
            audio_in = gr.Audio(label="Input Audio", type="numpy")
            audio_out = gr.Audio(label="Output Audio", type="numpy")

    # Event wiring: dropdown change reconfigures the sliders; the button
    # runs the neural effect over the uploaded audio.
    model_sel.change(
        fn=update_controls, inputs=[model_sel], outputs=[param1, param2],
    )
    process_btn.click(
        fn=process_audio,
        inputs=[audio_in, model_sel, param1, param2, mix_slider, input_gain],
        outputs=[audio_out],
    )

demo.launch()
|
|