"""Gradio demo for NNStomps: a GRU-based neural saturation/distortion effect."""

import os
from pathlib import Path

import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download


# ---------------------------------------------------------------------------
# Model definition (must match training code)
# ---------------------------------------------------------------------------
class NNStompGRU(nn.Module):
    """Single-layer GRU followed by a linear layer and tanh.

    The attribute names (``gru``, ``dense``) are part of the checkpoint's
    state-dict keys, so they must not be renamed.

    Args:
        cond_dim: Number of conditioning values appended to every input sample.
        hidden_size: Width of the GRU hidden state.
    """

    def __init__(self, cond_dim: int, hidden_size: int = 40):
        super().__init__()
        self.cond_dim = cond_dim
        self.hidden_size = hidden_size
        self.gru = nn.GRU(
            input_size=1 + cond_dim,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=True,
        )
        self.dense = nn.Linear(hidden_size, 1)
        self.tanh = nn.Tanh()

    def forward(self, x, cond, hidden=None):
        """Run the network over one sequence chunk.

        Args:
            x: Input of shape ``(batch, seq_len, 1)``.
            cond: Conditioning of shape ``(batch, cond_dim)``; it is repeated
                for every time step.
            hidden: Optional GRU state from a previous call, used to carry
                state across consecutive chunks.

        Returns:
            ``(out, hidden_out)`` where ``out`` has shape
            ``(batch, seq_len, 1)`` and lies in ``[-1, 1]`` because of the
            final tanh.
        """
        seq_len = x.shape[1]
        cond_expanded = cond.unsqueeze(1).expand(-1, seq_len, -1)
        inp = torch.cat([x, cond_expanded], dim=-1)
        h, hidden_out = self.gru(inp, hidden)
        out = self.tanh(self.dense(h))
        return out, hidden_out


# ---------------------------------------------------------------------------
# Model registry
# ---------------------------------------------------------------------------
# Maps the display name shown in the UI to the checkpoint file inside
# MODEL_REPO and to the sliders ("controls") that feed the condition vector.
# Each control's "idx" is its position in the condition vector.
MODELS = {
    "Blackstar (Drive A/B)": {
        "repo_file": "blackstar/best_model.pt",
        "cond_dim": 2,
        "controls": {
            "Drive A": {"idx": 0, "min": 0, "max": 100, "default": 50},
            "Drive B": {"idx": 1, "min": 0, "max": 100, "default": 0},
        },
    },
}

MODEL_REPO = "intrect/nnstomps-models"

# Loaded models keyed by registry name, so each checkpoint is fetched once.
_model_cache: dict[str, NNStompGRU] = {}

# Number of samples fed to the GRU per forward call.
_CHUNK_SIZE = 8192

# Output is rescaled so its absolute peak never exceeds this value.
_PEAK_LIMIT = 0.99


def load_model(name: str) -> NNStompGRU | None:
    """Return the model registered under *name*, downloading it if needed.

    The checkpoint is fetched from ``MODEL_REPO`` (using ``HF_TOKEN`` from the
    environment when set) and the instance is cached for later calls. The
    network dimensions come from the checkpoint's own ``config`` entry.

    Args:
        name: Key into ``MODELS``.

    Returns:
        The model in eval mode, or ``None`` when *name* is not registered.
    """
    if name in _model_cache:
        return _model_cache[name]

    cfg = MODELS.get(name)
    if cfg is None:
        return None

    local_path = hf_hub_download(
        repo_id=MODEL_REPO,
        filename=cfg["repo_file"],
        token=os.environ.get("HF_TOKEN"),
    )
    # weights_only=True restricts unpickling to tensors and plain containers.
    ckpt = torch.load(local_path, map_location="cpu", weights_only=True)
    model = NNStompGRU(ckpt["config"]["cond_dim"], ckpt["config"]["hidden_size"])
    model.load_state_dict(ckpt["model_state"])
    model.eval()

    _model_cache[name] = model
    return model


# ---------------------------------------------------------------------------
# Audio processing
# ---------------------------------------------------------------------------
def _to_float32(data: np.ndarray) -> np.ndarray:
    """Convert PCM samples of any dtype to float32.

    Signed integers are divided by their full-scale value (int16 by 32768,
    int32 by 2147483648). Unsigned integers are treated as offset-binary PCM
    (as in 8-bit WAV files) and re-centred around zero. Other dtypes are cast
    without scaling.
    """
    if data.dtype == np.float32:
        return data
    if np.issubdtype(data.dtype, np.signedinteger):
        full_scale = float(np.iinfo(data.dtype).max) + 1.0
        return data.astype(np.float32) / full_scale
    if np.issubdtype(data.dtype, np.unsignedinteger):
        midpoint = (float(np.iinfo(data.dtype).max) + 1.0) / 2.0
        return (data.astype(np.float32) - midpoint) / midpoint
    return data.astype(np.float32)


def _to_mono(data: np.ndarray) -> np.ndarray:
    """Average the channels of a 2-D array; other ranks are returned as-is.

    The channel axis is not labelled, so it is guessed: up to two columns is
    read as ``(samples, channels)``; otherwise the shorter axis is assumed to
    hold the channels. This keeps multi-channel ``(samples, 6)`` input from
    being averaged over time.
    """
    if data.ndim != 2:
        return data
    n_rows, n_cols = data.shape
    if n_cols <= 2 or n_cols <= n_rows:
        return data.mean(axis=1)
    return data.mean(axis=0)


def _build_condition(cfg: dict, params: tuple[float, ...]) -> list[float]:
    """Map slider values onto the model's condition vector.

    Controls are paired with *params* in registry order; each value is
    normalised by the control's ``min``/``max`` and stored at its ``idx``.
    Slots without a control stay at 0.0.
    """
    cond = [0.0] * cfg["cond_dim"]
    for ctrl, value in zip(cfg["controls"].values(), params):
        span = ctrl["max"] - ctrl["min"]
        # A degenerate range would divide by zero; treat it as "fully off".
        cond[ctrl["idx"]] = (value - ctrl["min"]) / span if span else 0.0
    return cond


def _run_model(model: NNStompGRU, mono: np.ndarray, cond: list[float]) -> np.ndarray:
    """Run *mono* through *model* in fixed-size chunks.

    The GRU hidden state is passed from one chunk to the next, so chunking
    only bounds how much is processed per forward call.
    """
    output = np.zeros_like(mono)
    hidden = None
    cond_t = torch.tensor([cond], dtype=torch.float32)
    with torch.no_grad():
        for start in range(0, len(mono), _CHUNK_SIZE):
            chunk = mono[start:start + _CHUNK_SIZE]
            # Shape (1, chunk_len, 1): one batch item, one feature per step.
            x = torch.from_numpy(chunk).unsqueeze(0).unsqueeze(-1)
            pred, hidden = model(x, cond_t, hidden)
            output[start:start + len(chunk)] = pred[0, :, 0].numpy()
    return output


def process_audio(
    audio_input,
    model_name: str,
    param1: float,
    param2: float,
    mix: float,
    input_gain_db: float,
):
    """Apply the selected neural model to an audio clip.

    Args:
        audio_input: ``(sample_rate, samples)`` tuple, or ``None`` when no
            audio was supplied. Presumably what ``gr.Audio(type="numpy")``
            delivers; confirm against the Gradio version in use.
        model_name: Key into ``MODELS``.
        param1: Value of the model's first control, in that control's range.
        param2: Value of the model's second control, in that control's range.
        mix: Dry/wet ratio; 0 is fully dry, 1 is fully processed.
        input_gain_db: Gain in dB applied to the signal before the model.

    Returns:
        ``(sample_rate, float32 mono samples)``; ``None`` when *audio_input*
        is ``None``. If *model_name* is not registered, the gained dry signal
        is returned unprocessed.
    """
    if audio_input is None:
        return None

    sr, data = audio_input
    mono = _to_mono(_to_float32(np.asarray(data)))

    # input gain (dB -> linear amplitude)
    gain = 10 ** (input_gain_db / 20.0)
    mono = np.ascontiguousarray(mono * gain, dtype=np.float32)

    model = load_model(model_name)
    if model is None:
        return (sr, mono)

    cond = _build_condition(MODELS[model_name], (param1, param2))
    output = _run_model(model, mono, cond)

    # dry/wet mix
    wet = mono * (1 - mix) + output * mix

    # np.max raises on an empty array, so a zero-length clip skips the limiter.
    peak = float(np.max(np.abs(wet))) if wet.size else 0.0
    if peak > _PEAK_LIMIT:
        wet = wet * (_PEAK_LIMIT / peak)

    return (sr, wet.astype(np.float32))


def _slider_update(entry, hidden_kwargs: dict):
    """Build the update for one parameter slider.

    *entry* is a ``(label, control)`` pair, or ``None`` when the model has no
    control for this slider; in that case the slider is hidden and
    *hidden_kwargs* are applied as well.
    """
    if entry is None:
        return gr.update(visible=False, **hidden_kwargs)
    label, ctrl = entry
    return gr.update(
        label=label,
        minimum=ctrl["min"],
        maximum=ctrl["max"],
        value=ctrl["default"],
        visible=True,
    )


def update_controls(model_name: str):
    """Reconfigure the two parameter sliders for the selected model.

    Returns:
        ``(param1_update, param2_update)``. A slider with no matching control
        is hidden; the second one is also reset to 0 when hidden.
    """
    controls = list(MODELS.get(model_name, {}).get("controls", {}).items())
    first = controls[0] if len(controls) >= 1 else None
    second = controls[1] if len(controls) >= 2 else None
    return _slider_update(first, {}), _slider_update(second, {"value": 0})


# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
with gr.Blocks(
    title="NNStomps — Neural Drive",
    theme=gr.themes.Soft(primary_hue="orange"),
) as demo:
    gr.Markdown(
        "# NNStomps — Neural Drive\n"
        "GRU neural network based saturation/distortion. "
        "Upload audio and tweak the drive to hear the neural model in action."
    )

    with gr.Row():
        with gr.Column(scale=1):
            model_sel = gr.Dropdown(
                choices=list(MODELS),
                value=next(iter(MODELS)),
                label="Model",
            )
            param1 = gr.Slider(
                minimum=0, maximum=100, value=50, step=1,
                label="Drive A",
            )
            param2 = gr.Slider(
                minimum=0, maximum=100, value=0, step=1,
                label="Drive B",
            )
            input_gain = gr.Slider(
                minimum=-12, maximum=12, value=0, step=0.5,
                label="Input Gain (dB)",
            )
            mix_slider = gr.Slider(
                minimum=0, maximum=1.0, value=1.0, step=0.05,
                label="Dry/Wet Mix",
            )
            process_btn = gr.Button("Process", variant="primary", size="lg")

        with gr.Column(scale=2):
            audio_in = gr.Audio(label="Input Audio", type="numpy")
            audio_out = gr.Audio(label="Output Audio", type="numpy")

    model_sel.change(
        fn=update_controls,
        inputs=[model_sel],
        outputs=[param1, param2],
    )

    process_btn.click(
        fn=process_audio,
        inputs=[audio_in, model_sel, param1, param2, mix_slider, input_gain],
        outputs=[audio_out],
    )

# Only start the server when run as a script, so the module can be imported
# (e.g. for testing the processing functions) without blocking.
if __name__ == "__main__":
    demo.launch()