Tachyeon committed on
Commit
5f6cebb
·
verified ·
1 Parent(s): b9f4091

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +137 -89
app.py CHANGED
@@ -6,9 +6,9 @@ import soundfile as sf
6
  import numpy as np
7
  from huggingface_hub import hf_hub_download
8
 
9
- # ==========================================
10
- # 1. ENGINE SETUP (UNCHANGED)
11
- # ==========================================
12
  try:
13
  from bs_roformer import BSRoformer
14
  from attend import Attend
@@ -18,7 +18,9 @@ except ImportError:
18
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
19
 
20
  def safe_attend_forward(self, q, k, v, mask=None):
21
- return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0., is_causal=False)
 
 
22
 
23
  try:
24
  Attend.forward = safe_attend_forward
@@ -27,7 +29,7 @@ except Exception:
27
 
28
  def load_model():
29
  print("Connecting to model...")
30
- checkpoint_path = hf_hub_download(
31
  repo_id="Tachyeon/IAM-RoFormer-Model-Weights",
32
  filename="v11_consensus_epoch_30.pt"
33
  )
@@ -42,169 +44,215 @@ def load_model():
42
  flash_attn=True
43
  ).to(DEVICE)
44
 
45
- ck = torch.load(checkpoint_path, map_location=DEVICE)
46
- model.load_state_dict(ck["model"] if "model" in ck else ck)
47
  model.eval()
48
  return model
49
 
50
  model = load_model()
51
 
52
  def separate_audio(audio_path):
53
- if model is None or not audio_path:
54
  return [None] * 4
55
 
56
  mix, sr = librosa.load(audio_path, sr=44100, mono=False)
57
  if mix.ndim == 1:
58
- mix = np.stack([mix, mix], axis=0)
59
 
60
- chunk_size = 44100 * 10
61
  overlap = 44100
62
 
63
- mix_tensor = torch.tensor(mix).float().to(DEVICE).unsqueeze(0)
64
- length = mix_tensor.shape[-1]
65
 
66
- output = torch.zeros(1, 4, 2, length, device=DEVICE)
67
- count = torch.zeros_like(output)
68
 
69
  with torch.no_grad(), torch.autocast("cuda", enabled=torch.cuda.is_available()):
70
- for start in range(0, length, chunk_size - overlap):
71
- end = min(start + chunk_size, length)
72
- chunk = mix_tensor[:, :, start:end]
73
- if chunk.shape[-1] < chunk_size:
74
- chunk = F.pad(chunk, (0, chunk_size - chunk.shape[-1]))
75
-
76
- pred = model(chunk)
77
- valid = end - start
78
- output[:, :, :, start:end] += pred[:, :, :, :valid]
79
- count[:, :, :, start:end] += 1
80
-
81
- stems = (output / count.clamp(min=1)).cpu().numpy()[0]
82
-
83
  files = []
84
  for i in range(4):
85
- fname = f"stem_{i}.wav"
86
- sf.write(fname, stems[i].T, sr)
87
- files.append(fname)
88
-
89
  return files
90
 
91
- # ==========================================
92
- # 2. UI (Gradio 6 SAFE)
93
- # ==========================================
94
  css = """
95
  @import url('https://fonts.googleapis.com/css2?family=Anton&family=Playfair+Display:ital@1&family=Poppins:wght@400;600;700&display=swap');
96
 
97
- :root{
98
- --bg:#2b1620;
99
- --panel:#3a2430;
 
 
100
  --ink:#f6efe8;
101
  --muted:#c7bfbf;
102
  --accent:#ff73a6;
103
  }
104
 
 
105
  html, body, .gradio-container {
106
  height:100%;
107
- background:linear-gradient(180deg,#2b1620,#1b0d14)!important;
108
- color:var(--ink)!important;
 
 
 
 
109
  font-family:Poppins,sans-serif;
110
  }
111
 
112
- .contain{
113
- height:100vh;
114
- max-width:1200px;
115
- margin:auto;
116
- padding:20px;
 
117
  display:grid;
118
  grid-template-rows:auto 1fr;
119
- gap:20px;
 
120
  }
121
 
122
- .header{
 
123
  display:flex;
124
  justify-content:space-between;
125
  align-items:center;
126
- border:1px solid rgba(255,255,255,.05);
127
- padding:16px;
128
  }
129
 
130
- .logo{
131
  font-family:Anton,sans-serif;
132
- font-size:42px;
 
133
  }
134
 
135
- .subtitle{
136
  font-family:'Playfair Display',serif;
137
  font-style:italic;
138
  color:var(--accent);
 
139
  }
140
 
141
- .grid{
 
142
  display:grid;
143
  grid-template-columns:1fr 1fr;
144
- gap:20px;
145
  height:100%;
146
  }
147
 
148
- .card{
149
- border:1px solid rgba(255,255,255,.05);
150
- padding:20px;
 
 
 
151
  display:flex;
152
  flex-direction:column;
153
- gap:16px;
 
154
  }
155
 
156
- .input-box{
157
- border:1px dashed rgba(255,255,255,.08);
158
- padding:30px;
 
 
159
  text-align:center;
160
  }
161
 
162
- .run-btn{
163
- background:linear-gradient(90deg,#ff73a6,#ffd58a)!important;
164
- color:#12090b!important;
165
- font-weight:800!important;
 
 
 
 
 
 
 
 
 
166
  }
167
 
168
- .stems{
 
169
  display:grid;
170
  grid-template-columns:1fr 1fr;
171
- gap:14px;
172
  }
173
 
174
- .label{
 
 
 
 
 
 
 
175
  font-family:'Playfair Display',serif;
176
  font-style:italic;
177
  color:var(--accent);
 
 
 
 
 
 
 
178
  }
179
  """
180
 
181
  with gr.Blocks() as demo:
182
- with gr.Column(elem_classes="contain"):
183
 
184
  with gr.Row(elem_classes="header"):
185
- gr.HTML('<div class="logo">SWARA STUDIO</div>')
186
- gr.HTML('<div class="subtitle">audio source separation</div>')
187
-
188
- with gr.Row(elem_classes="grid"):
189
-
190
- with gr.Column(elem_classes="card"):
191
- gr.HTML('<div class="input-box"><b>MASTER AUDIO</b><br>Drop or upload WAV / MP3</div>')
192
- input_audio = gr.Audio(type="filepath")
193
- run_btn = gr.Button("RUN SEPARATION", elem_classes="run-btn")
194
-
195
- with gr.Column(elem_classes="card"):
 
 
 
 
 
196
  gr.HTML('<div class="label">STEMS</div>')
197
  with gr.Row(elem_classes="stems"):
198
- out_vocals = gr.Audio(label="Vocals", interactive=False)
199
- out_drums = gr.Audio(label="Drums", interactive=False)
200
- out_bass = gr.Audio(label="Bass", interactive=False)
201
- out_other = gr.Audio(label="Other", interactive=False)
202
-
203
- run_btn.click(
204
- separate_audio,
205
- input_audio,
206
- [out_vocals, out_drums, out_bass, out_other]
207
- )
 
 
 
 
208
 
209
  if __name__ == "__main__":
210
  demo.launch(css=css, theme=gr.themes.Base())
 
6
  import numpy as np
7
  from huggingface_hub import hf_hub_download
8
 
9
+ # =====================================================
10
+ # 1. MODEL LOGIC (UNCHANGED)
11
+ # =====================================================
12
  try:
13
  from bs_roformer import BSRoformer
14
  from attend import Attend
 
18
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
19
 
20
  def safe_attend_forward(self, q, k, v, mask=None):
21
+ return F.scaled_dot_product_attention(
22
+ q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
23
+ )
24
 
25
  try:
26
  Attend.forward = safe_attend_forward
 
29
 
30
  def load_model():
31
  print("Connecting to model...")
32
+ ckpt = hf_hub_download(
33
  repo_id="Tachyeon/IAM-RoFormer-Model-Weights",
34
  filename="v11_consensus_epoch_30.pt"
35
  )
 
44
  flash_attn=True
45
  ).to(DEVICE)
46
 
47
+ state = torch.load(ckpt, map_location=DEVICE)
48
+ model.load_state_dict(state["model"] if "model" in state else state)
49
  model.eval()
50
  return model
51
 
52
  model = load_model()
53
 
54
  def separate_audio(audio_path):
55
+ if not audio_path:
56
  return [None] * 4
57
 
58
  mix, sr = librosa.load(audio_path, sr=44100, mono=False)
59
  if mix.ndim == 1:
60
+ mix = np.stack([mix, mix])
61
 
62
+ chunk = 44100 * 10
63
  overlap = 44100
64
 
65
+ x = torch.tensor(mix).float().to(DEVICE)[None]
66
+ length = x.shape[-1]
67
 
68
+ out = torch.zeros(1, 4, 2, length, device=DEVICE)
69
+ cnt = torch.zeros_like(out)
70
 
71
  with torch.no_grad(), torch.autocast("cuda", enabled=torch.cuda.is_available()):
72
+ for s in range(0, length, chunk - overlap):
73
+ e = min(s + chunk, length)
74
+ part = x[:, :, s:e]
75
+ if part.shape[-1] < chunk:
76
+ part = F.pad(part, (0, chunk - part.shape[-1]))
77
+ pred = model(part)
78
+ out[:, :, :, s:e] += pred[:, :, :, : e - s]
79
+ cnt[:, :, :, s:e] += 1
80
+
81
+ stems = (out / cnt.clamp(min=1)).cpu().numpy()[0]
 
 
 
82
  files = []
83
  for i in range(4):
84
+ f = f"stem_{i}.wav"
85
+ sf.write(f, stems[i].T, sr)
86
+ files.append(f)
 
87
  return files
88
 
89
+ # =====================================================
90
+ # 2. POLISHED UI (FIXED LAYOUT, NO SCROLL)
91
+ # =====================================================
92
  css = """
93
  @import url('https://fonts.googleapis.com/css2?family=Anton&family=Playfair+Display:ital@1&family=Poppins:wght@400;600;700&display=swap');
94
 
95
+ :root {
96
+ --bg1:#2b1620;
97
+ --bg2:#1c0d14;
98
+ --panel:rgba(255,255,255,0.04);
99
+ --border:rgba(255,255,255,0.08);
100
  --ink:#f6efe8;
101
  --muted:#c7bfbf;
102
  --accent:#ff73a6;
103
  }
104
 
105
+ /* HARD RESET */
106
  html, body, .gradio-container {
107
  height:100%;
108
+ width:100%;
109
+ margin:0;
110
+ padding:0;
111
+ overflow:hidden !important;
112
+ background:linear-gradient(180deg,var(--bg1),var(--bg2)) !important;
113
+ color:var(--ink);
114
  font-family:Poppins,sans-serif;
115
  }
116
 
117
+ /* CENTERED APP */
118
+ .app {
119
+ max-width:1100px;
120
+ height:100%;
121
+ margin:0 auto;
122
+ padding:32px;
123
  display:grid;
124
  grid-template-rows:auto 1fr;
125
+ gap:28px;
126
+ box-sizing:border-box;
127
  }
128
 
129
+ /* HEADER */
130
+ .header {
131
  display:flex;
132
  justify-content:space-between;
133
  align-items:center;
 
 
134
  }
135
 
136
+ .title {
137
  font-family:Anton,sans-serif;
138
+ font-size:44px;
139
+ letter-spacing:1px;
140
  }
141
 
142
+ .subtitle {
143
  font-family:'Playfair Display',serif;
144
  font-style:italic;
145
  color:var(--accent);
146
+ margin-left:14px;
147
  }
148
 
149
+ /* MAIN GRID */
150
+ .main {
151
  display:grid;
152
  grid-template-columns:1fr 1fr;
153
+ gap:32px;
154
  height:100%;
155
  }
156
 
157
+ /* PANELS */
158
+ .panel {
159
+ background:var(--panel);
160
+ border:1px solid var(--border);
161
+ border-radius:16px;
162
+ padding:28px;
163
  display:flex;
164
  flex-direction:column;
165
+ gap:22px;
166
+ box-sizing:border-box;
167
  }
168
 
169
+ /* INPUT */
170
+ .drop {
171
+ border:1px dashed var(--border);
172
+ border-radius:12px;
173
+ padding:32px;
174
  text-align:center;
175
  }
176
 
177
+ .drop h3 {
178
+ margin:0;
179
+ font-size:18px;
180
+ letter-spacing:1px;
181
+ }
182
+
183
+ /* BUTTON */
184
+ .run {
185
+ background:linear-gradient(90deg,#ff73a6,#ffd58a) !important;
186
+ color:#160c10 !important;
187
+ font-weight:800 !important;
188
+ border-radius:10px !important;
189
+ border:none !important;
190
  }
191
 
192
+ /* STEMS */
193
+ .stems {
194
  display:grid;
195
  grid-template-columns:1fr 1fr;
196
+ gap:18px;
197
  }
198
 
199
+ .stem {
200
+ background:rgba(255,255,255,0.03);
201
+ border:1px solid var(--border);
202
+ border-radius:12px;
203
+ padding:16px;
204
+ }
205
+
206
+ .label {
207
  font-family:'Playfair Display',serif;
208
  font-style:italic;
209
  color:var(--accent);
210
+ margin-bottom:6px;
211
+ }
212
+
213
+ /* AUDIO FIX */
214
+ audio {
215
+ width:100%;
216
+ max-height:36px;
217
  }
218
  """
219
 
220
  with gr.Blocks() as demo:
221
+ with gr.Column(elem_classes="app"):
222
 
223
  with gr.Row(elem_classes="header"):
224
+ gr.HTML('<div class="title">SWARA STUDIO</div>')
225
+ gr.HTML('<div class="subtitle">Audio Source Separation</div>')
226
+
227
+ with gr.Row(elem_classes="main"):
228
+
229
+ with gr.Column(elem_classes="panel"):
230
+ gr.HTML("""
231
+ <div class="drop">
232
+ <h3>MASTER AUDIO</h3>
233
+ <p>Drop or upload WAV / MP3</p>
234
+ </div>
235
+ """)
236
+ inp = gr.Audio(type="filepath")
237
+ btn = gr.Button("RUN SEPARATION", elem_classes="run")
238
+
239
+ with gr.Column(elem_classes="panel"):
240
  gr.HTML('<div class="label">STEMS</div>')
241
  with gr.Row(elem_classes="stems"):
242
+ with gr.Column(elem_classes="stem"):
243
+ gr.HTML('<div class="label">Vocals</div>')
244
+ o1 = gr.Audio(interactive=False)
245
+ with gr.Column(elem_classes="stem"):
246
+ gr.HTML('<div class="label">Drums</div>')
247
+ o2 = gr.Audio(interactive=False)
248
+ with gr.Column(elem_classes="stem"):
249
+ gr.HTML('<div class="label">Bass</div>')
250
+ o3 = gr.Audio(interactive=False)
251
+ with gr.Column(elem_classes="stem"):
252
+ gr.HTML('<div class="label">Other</div>')
253
+ o4 = gr.Audio(interactive=False)
254
+
255
+ btn.click(separate_audio, inp, [o1, o2, o3, o4])
256
 
257
  if __name__ == "__main__":
258
  demo.launch(css=css, theme=gr.themes.Base())