| |
| import io, os, torch, numpy as np, soundfile as sf |
| from huggingface_hub import snapshot_download |
| from model import UFormer, UFormerConfig |
|
|
| |
| |
| |
# Hugging Face repository that hosts the UFormer checkpoints.
REPO_ID = "yongyizang/MSR_UFormers"

# Run on the GPU whenever torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Fetch the repository snapshot; note this happens at import time.
local_dir = snapshot_download(REPO_ID)

# Single shared architecture config; every checkpoint is loaded into a UFormer built from it.
config = UFormerConfig()

# Loaded models keyed by checkpoint name; filled lazily by _get_model.
_model_cache = dict()

# Checkpoint names; each maps to <local_dir>/checkpoints/<name>.pth.
VALID_CKPTS = (
    "acoustic_guitar bass electric_guitar guitars keyboards "
    "orchestra rhythm_section synth vocals"
).split()
|
|
def _get_model(ckpt_name: str):
    """Return the UFormer for ``ckpt_name``, loading it on first request.

    Loaded models are kept in ``_model_cache`` so later calls with the same
    name reuse the already-built model instead of re-reading the checkpoint.

    Raises:
        ValueError: ``ckpt_name`` is not listed in ``VALID_CKPTS``.
    """
    if ckpt_name not in VALID_CKPTS:
        raise ValueError(f"Invalid checkpoint {ckpt_name!r}, choose from {VALID_CKPTS}")

    try:
        return _model_cache[ckpt_name]
    except KeyError:
        pass  # first request for this checkpoint: build it below

    weights_path = os.path.join(local_dir, "checkpoints", f"{ckpt_name}.pth")
    model = UFormer(config).to(device).eval()
    # Deserialize onto the CPU; load_state_dict then copies the values into the
    # model's parameters, which were already moved to `device` above.
    # NOTE(review): torch.load unpickles the file — only point this at trusted
    # checkpoints (newer torch versions support weights_only=True).
    state_dict = torch.load(weights_path, map_location="cpu")
    model.load_state_dict(state_dict)
    _model_cache[ckpt_name] = model
    return model
|
|
| |
| |
| |
def _overlap_add(model, x: np.ndarray, sr: int, chunk_s: float = 5., hop_s: float = 2.5):
    """Run ``model`` over ``x`` in overlapping windows and cross-fade the results.

    Args:
        model: callable taking a float32 tensor shaped (1, C, chunk) and returning
            a tensor that squeezes to (C, chunk). Assumes the model preserves the
            input shape — TODO confirm for UFormer.
        x: audio array shaped (C, T), channels first.
        sr: sample rate in Hz, used to convert ``chunk_s`` / ``hop_s`` to samples.
        chunk_s: window length in seconds.
        hop_s: stride between window starts in seconds (normally <= ``chunk_s``).

    Returns:
        Float64 array shaped (C, T) with the stitched model output.

    Raises:
        ValueError: if the window or hop rounds to fewer than one sample.
    """
    C, T = x.shape
    chunk, hop = int(sr * chunk_s), int(sr * hop_s)
    if chunk <= 0 or hop <= 0:
        raise ValueError(f"chunk ({chunk}) and hop ({hop}) must be at least one sample")
    if T == 0:
        # Nothing to process (and np.pad cannot reflect an empty axis).
        return np.zeros((C, 0))

    # Pad so the windows tile the signal: at least one full window, and
    # (padded length - chunk) a multiple of hop so the last window ends flush.
    pad = -(T - chunk) % hop if T > chunk else chunk - T
    x_pad = np.pad(x, ((0, 0), (0, pad)), mode="reflect")
    n_samples = x_pad.shape[1]
    n_chunks = 1 + (n_samples - chunk) // hop

    hann = np.hanning(chunk)
    half = chunk // 2
    out = np.zeros((C, n_samples))
    norm = np.zeros((1, n_samples))

    with torch.no_grad():
        for i in range(n_chunks):
            s = i * hop
            seg = x_pad[:, s:s + chunk].astype(np.float32)
            y = model(torch.from_numpy(seg[None]).to(device)).squeeze(0).cpu().numpy()
            # np.hanning is exactly 0 at both ends, so the very first and last
            # samples would get zero total weight (0/0 -> NaN). Nothing overlaps
            # the outer half of the first/last window, so keep those halves flat.
            w = hann.copy()
            if i == 0:
                w[:half] = 1.0
            if i == n_chunks - 1:
                w[half:] = 1.0
            out[:, s:s + chunk] += y * w
            norm[:, s:s + chunk] += w

    # Floor the weights so an uncovered gap (only possible when hop > chunk)
    # yields silence rather than NaN.
    return (out / np.maximum(norm, 1e-8))[:, :T]
|
|
| |
| |
| |
def inference(input_bytes: bytes, checkpoint: str = "guitars") -> bytes:
    """Restore an encoded audio file with one of the UFormer checkpoints.

    Audio bytes in -> restored WAV bytes out. The original note mentions an
    endpoint payload ``{"inputs": <bytes>, "parameters": {"checkpoint": "<name>"}}``;
    presumably an adapter outside this file maps it onto these two arguments.

    Args:
        input_bytes: contents of an audio file readable by ``soundfile``.
        checkpoint: checkpoint name, one of ``VALID_CKPTS``.

    Returns:
        The restored audio encoded as WAV at the input sample rate. Mono input is
        duplicated to two channels first, so the output is then two-channel.

    Raises:
        ValueError: if ``checkpoint`` is not a known checkpoint name.
    """
    audio, sr = sf.read(io.BytesIO(input_bytes))
    if audio.ndim == 1:
        # Mono input: duplicate into two identical channels before inference.
        audio = np.stack([audio, audio], axis=1)
    # soundfile yields (frames, channels) float64. The chunked path below already
    # feeds float32 to the model; do the same for the single-pass path so both
    # agree (a float64 tensor against float32 weights raises a dtype error).
    x = np.ascontiguousarray(audio.T, dtype=np.float32)

    model = _get_model(checkpoint)
    if x.shape[1] <= sr * 5:
        # Up to 5 s (the default chunk length of _overlap_add): one forward pass.
        with torch.no_grad():
            y = model(torch.from_numpy(x[None]).to(device)).squeeze(0).cpu().numpy()
    else:
        y = _overlap_add(model, x, sr)

    buf = io.BytesIO()
    sf.write(buf, y.T, sr, format="WAV")
    return buf.getvalue()
|
|
| |
| |
| |
if __name__ == "__main__":
    # Command-line entry point: restore one file (-i/-o), or launch a Gradio UI (--serve).
    import argparse
    import tempfile
    from pathlib import Path

    parser = argparse.ArgumentParser("UFormer RESTORE")
    parser.add_argument("-i", "--input", type=str, help="noisy WAV")
    parser.add_argument("-o", "--output", type=str, help="restored WAV")
    parser.add_argument("-c", "--checkpoint", type=str, default="guitars",
                        choices=VALID_CKPTS)
    parser.add_argument("--serve", action="store_true", help="launch Gradio")
    args = parser.parse_args()

    if args.serve:
        import gradio as gr

        def _gr(path, ckpt):
            """Gradio callback: restore the uploaded file and return the WAV's path."""
            if path is None:
                raise gr.Error("Please upload an audio file.")
            restored = inference(Path(path).read_bytes(), checkpoint=ckpt)
            # A gr.Audio output expects a file path (or a (rate, array) tuple),
            # not raw bytes, so persist the WAV where Gradio can serve it.
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                tmp.write(restored)
            return tmp.name

        gr.Interface(
            fn=_gr,
            inputs=[
                # No `source=` kwarg: Gradio 4 removed it (renamed `sources=`),
                # and file upload is offered by default.
                gr.Audio(type="filepath"),
                # Default to the CLI checkpoint so the callback never gets None.
                gr.Dropdown(VALID_CKPTS, value=args.checkpoint, label="Checkpoint"),
            ],
            outputs=gr.Audio(type="filepath"),
            title="Music Source Restoration",
            description="Choose which instrument/group model to run.",
        ).launch()
    else:
        # parser.error exits with a usage message; unlike assert it survives `python -O`.
        if not (args.input and args.output):
            parser.error("-i/--input and -o/--output are required unless --serve is given")
        out = inference(Path(args.input).read_bytes(), checkpoint=args.checkpoint)
        Path(args.output).write_bytes(out)
        print(f"Restored -> {args.output} using {args.checkpoint}")
|
|