theadityamittal commited on
Commit
acf8aa4
·
1 Parent(s): 3b79f8c

Initial Space with model download + Gradio demo

Browse files
.DS_Store ADDED
Binary file (6.15 kB). View file
 
config/default.yaml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# config/default.yaml
# Project-wide configuration, loaded with OmegaConf.

device: "mps"  # preferred PyTorch device string ("mps" = Apple Silicon GPU)

data:
  raw_path: data/raw
  splits: ["train", "test"]
  processed_path: data/processed
  sample_rate: 16000      # Hz; audio is resampled to this rate on load
  n_fft: 1024             # STFT window size -> 1024/2 + 1 = 513 frequency bins
  hop_length: 512
  n_mels: 80
  segment_length: 256     # STFT frames per model input segment

  # for DataLoader
  batch_size: 16
  num_workers: 4

  # list of all sources (including mixture)
  # sources[0] is the input mixture; the remaining entries are the stems
  # the model predicts, in output-channel order.
  sources: ["mixture", "drums", "bass", "other", "vocals"]

model:
  checkpoint_dir: models/checkpoints

  # for UNet
  chans: 32               # channels after the first encoder block
  num_pool_layers: 4

training:
  # for training loop
  epochs: 50
  lr: 1e-4
  max_steps: null         # null = no cap on the number of steps
  log_interval: 50 # how many batches between progress logs

augment:
  # defaults for your SpectrogramTransforms
  time_mask_param: 30
  freq_mask_param: 15
  time_warp_param: 40
  stripe_time_width: 1
  stripe_freq_width: 1
  stripe_time_count: 2
  stripe_freq_count: 2
  noise_std: 0.01

experiment:
  # MLflow experiment metadata
  name: default_experiment
  run_name: run1
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
torch
librosa
fastapi
uvicorn
mlflow
dvc
pytest
gradio
soundfile
huggingface_hub
omegaconf
serve.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# serve.py

import os
import tempfile

import numpy as np
import torch
import librosa
import soundfile as sf
import gradio as gr
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download

from src.models.unet import UNet

# 1) Load config and model once at startup.
CFG = OmegaConf.load("config/default.yaml")

# BUG FIX: the original expression was `torch.mps`, which is a *module*
# object — always truthy when present, and an AttributeError on builds
# without MPS — so it never actually detected the backend. Query the
# availability API instead, guarding for older torch builds.
DEVICE = torch.device(
    "mps"
    if getattr(torch.backends, "mps", None) is not None
    and torch.backends.mps.is_available()
    else "cpu"
)

MODEL = UNet(
    in_ch=1,
    # One output channel per separated stem; sources[0] is the input mixture.
    num_sources=len(CFG.data.sources) - 1,
    chans=CFG.model.chans,
    num_pool_layers=CFG.model.num_pool_layers,
).to(DEVICE)

# Fetch the best checkpoint from the Hub (cached across restarts).
ckpt_file = hf_hub_download(
    repo_id="theadityamittal/music-separator-unet",
    filename="checkpoints/unet_best.pt",
)
MODEL.load_state_dict(torch.load(ckpt_file, map_location=DEVICE))
MODEL.eval()  # inference only: freezes batch-norm statistics
33
+
34
+
35
def separate_file(mix_path):
    """Separate an uploaded mixture WAV into individual stem files.

    Args:
        mix_path: Filesystem path to the uploaded mixture ``.wav``.

    Returns:
        A list of paths to the separated ``.wav`` files, ordered to match
        ``CFG.data.sources[1:]`` (drums, bass, other, vocals) so they line
        up with the Gradio output components. (The original docstring
        claimed a dict; the function has always returned a list.)
    """
    # 1. Load audio (mono, resampled to the training rate) and take the STFT.
    wav, sr = librosa.load(mix_path, sr=CFG.data.sample_rate, mono=True)
    stft = librosa.stft(
        wav, n_fft=CFG.data.n_fft, hop_length=CFG.data.hop_length
    )
    mag, phase = np.abs(stft), np.angle(stft)
    F, T = mag.shape

    # 2. Zero-pad the time axis up to a whole number of model segments.
    SEG = CFG.data.segment_length
    pad = (SEG - (T % SEG)) % SEG
    if pad:
        mag = np.pad(mag, ((0, 0), (0, pad)), constant_values=0)
        phase = np.pad(phase, ((0, 0), (0, pad)), constant_values=0)
    n_seg = mag.shape[1] // SEG

    # 3. Run the model segment by segment and stitch predictions together.
    preds = []
    with torch.no_grad():
        for i in range(n_seg):
            mseg = mag[:, i * SEG:(i + 1) * SEG]
            x = torch.from_numpy(mseg).unsqueeze(0).unsqueeze(0).to(DEVICE).float()
            y = MODEL(x)  # (1, S, F, SEG)
            preds.append(y.squeeze(0).cpu().numpy())
    # Trim the padding so outputs match the original frame count.
    pred_mag = np.concatenate(preds, axis=2)[:, :, :T]
    phase = phase[:, :T]

    # 4. Reconstruct each stem with the mixture phase and write temp WAVs.
    out_paths = {}
    for idx, src in enumerate(CFG.data.sources[1:]):
        spec = pred_mag[idx] * np.exp(1j * phase)
        est = librosa.istft(
            spec,
            hop_length=CFG.data.hop_length,
            win_length=CFG.data.n_fft,
        )
        # mkstemp returns an open fd; close it so sf.write can reopen the path.
        fd, path = tempfile.mkstemp(suffix=f"_{src}.wav")
        os.close(fd)
        sf.write(path, est, sr)
        out_paths[src] = path

    # Order must match the Gradio outputs: drums, bass, other, vocals.
    return [out_paths[src] for src in CFG.data.sources[1:]]
84
+
85
+
86
# 5) Build the Gradio interface.
description = """
## Music Source Separation

Upload a mix `.wav` and get back **drums**, **bass**, **other**, and **vocals** stems separated by a U-Net model.
"""

iface = gr.Interface(
    fn=separate_file,
    inputs=gr.Audio(label="Mixture (.wav)", type="filepath"),
    # Output order must match the list returned by separate_file
    # (CFG.data.sources[1:] == drums, bass, other, vocals).
    outputs=[
        gr.Audio(label="Drums", type="filepath"),
        gr.Audio(label="Bass", type="filepath"),
        gr.Audio(label="Other", type="filepath"),
        gr.Audio(label="Vocals", type="filepath"),
    ],
    title="U-Net Music Separator",
    description=description,
    allow_flagging="never",
)

if __name__ == "__main__":
    # BUG FIX: dropped `share=True` — share links are unsupported inside a
    # Hugging Face Space (Gradio warns and ignores the flag there); the
    # Space serves the app directly on the bound host/port.
    iface.launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", 7860)))
src/.DS_Store ADDED
Binary file (6.15 kB). View file
 
src/__init__.py ADDED
File without changes
src/models/__init__.py ADDED
File without changes
src/models/unet.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/models/unet.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
class ConvBlock(nn.Module):
    """Two 3x3 conv -> batch-norm -> ReLU stages at a fixed resolution.

    Spatial size is preserved (padding=1); only the channel count changes
    from ``in_ch`` to ``out_ch`` in the first stage.
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        layers: list[nn.Module] = []
        # Two stages: in_ch -> out_ch, then out_ch -> out_ch.
        for c_in, c_out in ((in_ch, out_ch), (out_ch, out_ch)):
            layers.extend(
                (
                    nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                    nn.BatchNorm2d(c_out),
                    nn.ReLU(inplace=True),
                )
            )
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both conv stages to *x* of shape (N, C, H, W)."""
        return self.net(x)
21
+
22
class DownBlock(nn.Module):
    """Encoder stage: a ConvBlock followed by 2x2 max-pool downsampling."""

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.conv = ConvBlock(in_ch, out_ch)
        self.pool = nn.MaxPool2d(kernel_size=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve *x* then halve H and W; returns the pooled features."""
        return self.pool(self.conv(x))
31
+
32
class UpBlock(nn.Module):
    """Decoder stage: 2x upsample, align + concat the skip, then ConvBlock."""

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        # Transposed conv doubles H and W while halving the channel count;
        # after concatenation with the skip we are back at in_ch channels.
        self.upconv = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
        self.conv = ConvBlock(in_ch, out_ch)

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        """Upsample *x*, match it to *skip*'s spatial size, and fuse them."""
        up = self.upconv(x)
        dy = skip.size(2) - up.size(2)
        dx = skip.size(3) - up.size(3)
        if dy or dx:
            # F.pad accepts negative widths, so this crops when `up` is
            # larger than `skip` and pads when it is smaller.
            up = F.pad(up, [dx // 2, dx - dx // 2, dy // 2, dy - dy // 2])
        fused = torch.cat([skip, up], dim=1)
        return self.conv(fused)
47
+
48
class UNet(nn.Module):
    """U-Net that maps a (N, in_ch, F, T) magnitude spectrogram to
    (N, num_sources, F, T) per-source magnitude estimates.

    Note: with ``num_pool_layers`` pool stages the encoder only pools
    ``num_pool_layers - 1`` times (the first block is a plain ConvBlock),
    while the decoder upsamples ``num_pool_layers`` times; the first
    UpBlock's extra 2x upsample is cropped back to the deepest skip's size
    by UpBlock's negative padding. NOTE(review): this asymmetry looks
    intentional-but-fragile — confirm against the training code.
    """

    def __init__(self,
                 in_ch: int = 1,
                 num_sources: int = 4,
                 chans: int = 32,
                 num_pool_layers: int = 4):
        super().__init__()

        # --- Encoder ---
        self.down_blocks = nn.ModuleList()
        ch = chans
        # first layer: channel lift only, no pooling (full-resolution skip)
        self.down_blocks.append(ConvBlock(in_ch, ch))
        # further downsampling: each DownBlock doubles channels, halves H/W
        for _ in range(1, num_pool_layers):
            self.down_blocks.append(DownBlock(ch, ch*2))
            ch *= 2

        # --- Bottleneck ---
        self.bottleneck = ConvBlock(ch, ch*2)
        ch *= 2  # now channel count matches bottleneck output

        # --- Decoder ---
        self.up_blocks = nn.ModuleList()
        for _ in range(num_pool_layers):
            # in_ch = two times the skip channels
            self.up_blocks.append(UpBlock(ch, ch//2))
            ch //= 2

        # ch now equals the number of channels output by the last UpBlock
        # --- Final conv ---
        # 1x1 conv: per-pixel projection to one channel per output source.
        self.final_conv = nn.Conv2d(ch, num_sources, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Encoder pass; every stage's output is kept as a skip connection.
        skips = []
        for down in self.down_blocks:
            x = down(x)
            skips.append(x)

        x = self.bottleneck(x)

        # Decoder pass: deepest skip first (skips reversed pairs each
        # UpBlock with the encoder stage at the matching depth).
        for up, skip in zip(self.up_blocks, reversed(skips)):
            x = up(x, skip)

        return self.final_conv(x)