Tachyeon committed on
Commit
8a7120d
Β·
verified Β·
1 Parent(s): 1259067

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +195 -82
app.py CHANGED
@@ -6,159 +6,272 @@ import soundfile as sf
6
  import numpy as np
7
  import os
8
  import sys
9
- # New import to download your model
10
  from huggingface_hub import hf_hub_download
11
 
12
- # 1. IMPORT YOUR LOCAL MODULES
13
- from bs_roformer import BSRoformer
14
- from attend import Attend
 
 
 
 
 
 
 
 
15
 
16
- # 2. SETUP DEVICE & PATCHES
17
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
18
 
19
  def safe_attend_forward(self, q, k, v, mask=None):
20
  return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0., is_causal=False)
21
 
22
- Attend.forward = safe_attend_forward
23
-
24
- print(f">>> 🎡 INITIALIZING RAW V11 ENGINE on {DEVICE}...")
 
 
25
 
26
- # 3. LOAD MODEL FROM YOUR REPO
27
  def load_model():
28
- print(">>> πŸ“‘ Downloading/Loading weights from Tachyeon/IAM-RoFormer-Model-Weights...")
29
-
30
- # This automatically downloads the file if it's not cached
31
  try:
32
  checkpoint_path = hf_hub_download(
33
  repo_id="Tachyeon/IAM-RoFormer-Model-Weights",
34
  filename="v11_consensus_epoch_30.pt"
35
  )
36
- print(f">>> βœ… Weights found at: {checkpoint_path}")
37
  except Exception as e:
38
- print(f"❌ Error downloading model: {e}")
39
  return None
40
 
41
- # Initialize Architecture
42
- try:
43
- model = BSRoformer(
44
- dim=512, depth=12, stereo=True, num_stems=4,
45
- time_transformer_depth=1, freq_transformer_depth=1,
46
- flash_attn=True
47
- ).to(DEVICE)
48
- except:
49
- model = BSRoformer(
50
- dim=512, depth=12, stereo=True, num_stems=4,
51
- time_transformer_depth=1, freq_transformer_depth=1
52
- ).to(DEVICE)
53
 
54
  # Load Weights
55
  ck = torch.load(checkpoint_path, map_location=DEVICE)
56
- if 'model' in ck:
57
- model.load_state_dict(ck['model'])
58
- else:
59
- model.load_state_dict(ck)
60
 
61
  model.eval()
62
  return model
63
 
 
64
  model = load_model()
65
 
66
- # 4. INFERENCE LOGIC
 
 
67
  def separate_audio(audio_path):
68
- if model is None:
69
- raise ValueError("Model failed to load. Check logs.")
70
-
71
- if not audio_path:
72
- return None, None, None, None
73
 
74
- print(f"\n>>> πŸͺ„ Separating '{os.path.basename(audio_path)}'...")
75
-
76
  mix, sr = librosa.load(audio_path, sr=44100, mono=False)
77
- if len(mix.shape) == 1:
78
- mix = np.stack([mix, mix], axis=0)
79
 
 
80
  chunk_size = 44100 * 10
81
  overlap = 44100 * 1
82
 
83
  mix_tensor = torch.tensor(mix, dtype=torch.float32).to(DEVICE)
84
- if mix_tensor.dim() == 2:
85
- mix_tensor = mix_tensor.unsqueeze(0)
86
 
87
  length = mix_tensor.shape[-1]
88
  final_output = torch.zeros(1, 4, 2, length).to(DEVICE)
89
  counts = torch.zeros(1, 4, 2, length).to(DEVICE)
90
 
91
- print(" Processing chunks...")
92
  with torch.no_grad():
93
  context = torch.amp.autocast('cuda') if torch.cuda.is_available() else torch.no_grad()
94
  with context:
95
  for start in range(0, length, chunk_size - overlap):
96
  end = min(start + chunk_size, length)
97
  chunk = mix_tensor[:, :, start:end]
 
98
  if chunk.shape[-1] < chunk_size:
99
  chunk = F.pad(chunk, (0, chunk_size - chunk.shape[-1]))
 
100
  pred = model(chunk)
 
 
101
  valid = end - start
102
  final_output[:, :, :, start:end] += pred[:, :, :, :valid]
103
  counts[:, :, :, start:end] += 1.0
104
 
 
105
  stems = (final_output / torch.clamp(counts, min=1.0)).cpu().numpy()[0]
106
 
 
107
  outputs = []
108
- stem_names = ["Stem_1", "Stem_2", "Stem_3", "Stem_4"]
109
 
110
- for i, name in enumerate(stem_names):
111
- outfile = f"output_{i}_{name}.wav"
112
  sf.write(outfile, stems[i].T, sr)
113
  outputs.append(outfile)
114
 
115
  return outputs[0], outputs[1], outputs[2], outputs[3]
116
 
117
  # ==========================================
118
- # 6. PROFESSIONAL UI (SWARA STUDIO)
119
  # ==========================================
120
 
121
- custom_css = """
122
- @import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;600&display=swap');
123
- body, .gradio-container { background-color: #0b0f19 !important; color: #e2e8f0 !important; font-family: 'Inter', sans-serif !important; }
124
- .main-header { text-align: center; margin-bottom: 2rem; padding: 2rem 0; border-bottom: 1px solid #1e293b; }
125
- .title-text { font-size: 3.5rem; font-weight: 300; letter-spacing: 4px; color: #f8fafc; margin: 0; text-transform: uppercase; }
126
- .studio-panel { background: #111827; border: 1px solid #1f2937; border-radius: 16px; padding: 24px; }
127
- #process-btn { background: linear-gradient(135deg, #3b82f6 0%, #2563eb 100%); border: none; color: white; font-weight: 600; padding: 12px; border-radius: 8px; margin-top: 20px; }
128
- #process-btn:hover { transform: translateY(-1px); box-shadow: 0 10px 15px -3px rgba(37, 99, 235, 0.3); }
129
- audio { width: 100%; filter: invert(0.9) hue-rotate(180deg); opacity: 0.8; }
130
- .section-label { font-size: 0.85rem; font-weight: 600; color: #64748b; text-transform: uppercase; margin-bottom: 12px; display: block; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  """
132
 
133
- with gr.Blocks() as demo:
134
 
135
- with gr.Row():
136
- gr.HTML("""
137
- <div class="main-header">
138
- <h1 class="title-text">SWARA STUDIO</h1>
 
 
 
 
139
  </div>
140
- """)
141
-
142
- with gr.Row():
143
- with gr.Column(scale=1):
144
- with gr.Group(elem_classes="studio-panel"):
145
- gr.HTML("<span class='section-label'>// Source Material</span>")
146
- input_audio = gr.Audio(label="", type="filepath", interactive=True)
147
- process_btn = gr.Button("INITIALIZE SEPARATION", elem_id="process-btn", size="lg")
148
-
149
- with gr.Column(scale=1):
150
- with gr.Group(elem_classes="studio-panel"):
151
- gr.HTML("<span class='section-label'>// Isolated Stems</span>")
152
- out1 = gr.Audio(label="Vocals", interactive=False, show_label=True)
153
- out2 = gr.Audio(label="Mridangam", interactive=False, show_label=True)
154
- out3 = gr.Audio(label="Tanpura", interactive=False, show_label=True)
155
- out4 = gr.Audio(label="Other", interactive=False, show_label=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
 
157
  process_btn.click(
158
- fn=separate_audio,
159
- inputs=[input_audio],
160
- outputs=[out1, out2, out3, out4]
161
  )
162
 
163
  if __name__ == "__main__":
164
- demo.launch(theme=gr.themes.Base(primary_hue="slate"), css=custom_css)
 
6
  import numpy as np
7
  import os
8
  import sys
 
9
  from huggingface_hub import hf_hub_download
10
 
11
+ # ==========================================
12
+ # 1. SETUP & MODEL LOADING (Backend)
13
+ # ==========================================
14
+ # We keep your exact logic here, just ensuring robust imports
15
# Import the model architecture and attention module. These live next to
# app.py in the Space repo; keep the app bootable even if they are missing,
# but surface the failure instead of silently passing (a bare `pass` here
# previously left BSRoformer/Attend undefined with no diagnostic, producing
# a confusing NameError much later).
try:
    from bs_roformer import BSRoformer
    from attend import Attend
except ImportError as import_err:
    # Fallback if running locally without properly set paths.
    # Ensure bs_roformer.py and attend.py are in the Space root.
    print(f"⚠️ Could not import model modules: {import_err}")
22
 
 
23
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
24
 
25
def safe_attend_forward(self, q, k, v, mask=None):
    """Attention forward that delegates to PyTorch's fused SDPA kernel.

    Drop-in replacement for Attend.forward: no dropout, non-causal,
    optional additive/boolean mask passed straight through.
    """
    attended = F.scaled_dot_product_attention(
        q, k, v,
        attn_mask=mask,
        dropout_p=0.,
        is_causal=False,
    )
    return attended
27
 
28
# Monkey-patch Attend so every attention call goes through the safe SDPA
# wrapper above. If the Attend import failed, log it rather than passing
# silently — a bare `pass` hides the broken state until inference time.
try:
    Attend.forward = safe_attend_forward
except NameError:
    print("⚠️ Attend not imported; attention patch skipped.")
33
 
34
+ # Load Model with Caching
35
def load_model():
    """Download the checkpoint from the HF Hub and build the BSRoformer.

    Returns:
        The model on DEVICE in eval mode, or None if the weights could not
        be downloaded or the architecture class is not importable.
    """
    print(">>> 📡 Loading Model Weights...")
    try:
        # hf_hub_download caches locally, so repeated startups are cheap.
        checkpoint_path = hf_hub_download(
            repo_id="Tachyeon/IAM-RoFormer-Model-Weights",
            filename="v11_consensus_epoch_30.pt"
        )
    except Exception as e:
        print(f"Error: {e}")
        return None

    # Initialize Architecture (Standard BSRoformer Config).
    # The file-level import of BSRoformer is wrapped in a swallowing
    # try/except, so the name may be undefined here — fail soft with None
    # (the original crashed with an unhandled NameError in that case).
    try:
        model = BSRoformer(
            dim=512, depth=12, stereo=True, num_stems=4,
            time_transformer_depth=1, freq_transformer_depth=1,
            flash_attn=True
        ).to(DEVICE)
    except NameError:
        print("Error: BSRoformer is not importable; check bs_roformer.py.")
        return None

    # Load weights; some checkpoints wrap the state dict under a 'model' key.
    ck = torch.load(checkpoint_path, map_location=DEVICE)
    state_dict = ck['model'] if 'model' in ck else ck
    model.load_state_dict(state_dict)

    model.eval()
    return model
60
 
61
+ # Initialize Global Model
62
  model = load_model()
63
 
64
+ # ==========================================
65
+ # 2. INFERENCE LOGIC (Chunking)
66
+ # ==========================================
67
def separate_audio(audio_path):
    """Separate a mix into 4 stems via chunked overlap-add inference.

    Args:
        audio_path: Path to the input mix (anything librosa can decode).
            May be None/empty when the UI event fires without a file.

    Returns:
        Tuple of 4 WAV file paths (one per stem), or (None, None, None, None)
        when the model is unavailable or no input was provided.
    """
    from contextlib import nullcontext

    # Guard clauses: model failed to load at startup, or no file provided.
    if model is None or not audio_path:
        return None, None, None, None

    # Load at the model's native rate; mono input comes back 1-D.
    mix, sr = librosa.load(audio_path, sr=44100, mono=False)
    if mix.ndim == 1:
        # Duplicate the mono channel so the network sees a stereo pair.
        mix = np.stack([mix, mix], axis=0)

    # Chunking params: 10 s windows with 1 s overlap to hide seams.
    chunk_size = 44100 * 10
    overlap = 44100 * 1

    mix_tensor = torch.tensor(mix, dtype=torch.float32).to(DEVICE)
    if mix_tensor.dim() == 2:
        mix_tensor = mix_tensor.unsqueeze(0)  # add batch dim -> (1, 2, T)

    length = mix_tensor.shape[-1]
    # Overlap-add accumulators: (batch, stems, channels, time).
    final_output = torch.zeros(1, 4, 2, length, device=DEVICE)
    counts = torch.zeros(1, 4, 2, length, device=DEVICE)

    with torch.no_grad():
        # Mixed precision only on CUDA. nullcontext replaces the original's
        # redundant nested torch.no_grad() used as a CPU "fallback" context.
        context = torch.amp.autocast('cuda') if torch.cuda.is_available() else nullcontext()
        with context:
            for start in range(0, length, chunk_size - overlap):
                end = min(start + chunk_size, length)
                chunk = mix_tensor[:, :, start:end]
                # Right-pad the last window up to a full chunk.
                if chunk.shape[-1] < chunk_size:
                    chunk = F.pad(chunk, (0, chunk_size - chunk.shape[-1]))

                pred = model(chunk)

                # Overlap-add: accumulate predictions and visit counts.
                valid = end - start
                final_output[:, :, :, start:end] += pred[:, :, :, :valid]
                counts[:, :, :, start:end] += 1.0

    # Average overlapping regions (clamp protects any unvisited samples).
    stems = (final_output / torch.clamp(counts, min=1.0)).cpu().numpy()[0]

    # Write one WAV per stem; soundfile expects (time, channels), hence .T.
    # (Dropped the original's unused stem_names list.)
    outputs = []
    for i in range(4):
        outfile = f"stem_{i}.wav"
        sf.write(outfile, stems[i].T, sr)
        outputs.append(outfile)

    return outputs[0], outputs[1], outputs[2], outputs[3]
117
 
118
  # ==========================================
119
+ # 3. UI DESIGN (ELEGANT DARK MODE)
120
  # ==========================================
121
 
122
+ # CSS: High-End VST Plugin Look
123
+ css = """
124
+ @import url('https://fonts.googleapis.com/css2?family=Manrope:wght@300;400;600;800&display=swap');
125
+
126
+ :root {
127
+ --bg-dark: #0F1116;
128
+ --panel-bg: #161922;
129
+ --accent: #6C5CE7; /* Elegant Violet */
130
+ --accent-glow: rgba(108, 92, 231, 0.3);
131
+ --text-main: #E0E0E0;
132
+ --text-muted: #888899;
133
+ --border: #2A2D3A;
134
+ }
135
+
136
+ body, .gradio-container {
137
+ background-color: var(--bg-dark) !important;
138
+ font-family: 'Manrope', sans-serif !important;
139
+ color: var(--text-main) !important;
140
+ margin: 0;
141
+ padding: 0;
142
+ height: 100vh; /* Force full screen */
143
+ overflow: hidden; /* No scroll */
144
+ }
145
+
146
+ /* Remove Gradio Bloat */
147
+ footer { display: none !important; }
148
+ .contain { display: flex; flex-direction: column; height: 100%; padding: 20px !important; }
149
+
150
+ /* HEADER */
151
+ .header-bar {
152
+ display: flex;
153
+ justify-content: space-between;
154
+ align-items: center;
155
+ padding-bottom: 20px;
156
+ border-bottom: 1px solid var(--border);
157
+ margin-bottom: 20px;
158
+ }
159
+ .brand {
160
+ font-size: 1.5rem;
161
+ font-weight: 800;
162
+ letter-spacing: 1px;
163
+ background: linear-gradient(90deg, #fff, #a5b4fc);
164
+ -webkit-background-clip: text;
165
+ -webkit-text-fill-color: transparent;
166
+ }
167
+ .tagline {
168
+ font-size: 0.85rem;
169
+ color: var(--text-muted);
170
+ font-weight: 400;
171
+ }
172
+
173
+ /* PANELS */
174
+ .panel {
175
+ background: var(--panel-bg);
176
+ border: 1px solid var(--border);
177
+ border-radius: 16px;
178
+ padding: 24px;
179
+ height: 100%;
180
+ display: flex;
181
+ flex-direction: column;
182
+ box-shadow: 0 10px 30px rgba(0,0,0,0.2);
183
+ }
184
+
185
+ .panel-header {
186
+ font-size: 0.9rem;
187
+ color: var(--accent);
188
+ text-transform: uppercase;
189
+ letter-spacing: 2px;
190
+ font-weight: 600;
191
+ margin-bottom: 15px;
192
+ display: flex;
193
+ align-items: center;
194
+ gap: 8px;
195
+ }
196
+
197
+ /* BUTTONS */
198
+ button.primary-btn {
199
+ background: linear-gradient(135deg, var(--accent) 0%, #4834d4 100%) !important;
200
+ border: none !important;
201
+ color: white !important;
202
+ font-weight: 700 !important;
203
+ padding: 15px !important;
204
+ border-radius: 12px !important;
205
+ font-size: 1rem !important;
206
+ margin-top: auto !important; /* Push to bottom */
207
+ transition: all 0.3s ease !important;
208
+ box-shadow: 0 4px 15px var(--accent-glow) !important;
209
+ }
210
+ button.primary-btn:hover {
211
+ transform: translateY(-2px);
212
+ box-shadow: 0 8px 25px var(--accent-glow) !important;
213
+ }
214
+
215
+ /* AUDIO PLAYERS - Minimalist */
216
+ .audio-container {
217
+ background: transparent !important;
218
+ border: none !important;
219
+ }
220
  """
221
 
222
+ with gr.Blocks(css=css, theme=gr.themes.Base()) as demo:
223
 
224
+ with gr.Column(elem_classes="contain"):
225
+
226
+ # 1. TOP BAR
227
+ with gr.Row(elem_classes="header-bar"):
228
+ gr.HTML("""
229
+ <div>
230
+ <div class="brand">SWARA STUDIO <span style="font-weight:300; opacity:0.5;">| PRO</span></div>
231
+ <div class="tagline">Indian Art Music Source Separation Engine</div>
232
  </div>
233
+ """)
234
+
235
+ # 2. MAIN WORKSPACE (Grid)
236
+ with gr.Row(equal_height=True):
237
+
238
+ # LEFT: INPUT DECK
239
+ with gr.Column(scale=1):
240
+ with gr.Group(elem_classes="panel"):
241
+ gr.HTML('<div class="panel-header">πŸ’Ώ Source Deck</div>')
242
+
243
+ # File Input
244
+ input_audio = gr.Audio(
245
+ label="Drop Mix Here",
246
+ type="filepath",
247
+ interactive=True,
248
+ elem_classes="audio-container"
249
+ )
250
+
251
+ gr.Markdown("Supports WAV, MP3, FLAC (44.1kHz)", elem_classes="tagline")
252
+
253
+ # Separation Button (Pushed to bottom via CSS)
254
+ process_btn = gr.Button("⚑ SEPARATE TRACKS", elem_classes="primary-btn")
255
+
256
+ # RIGHT: OUTPUT RACK
257
+ with gr.Column(scale=2):
258
+ with gr.Group(elem_classes="panel"):
259
+ gr.HTML('<div class="panel-header">🎚️ Stem Rack</div>')
260
+
261
+ with gr.Row():
262
+ with gr.Column():
263
+ out_vocals = gr.Audio(label="🎀 Vocals", interactive=False, type="filepath")
264
+ out_drums = gr.Audio(label="πŸ₯ Mridangam / Drums", interactive=False, type="filepath")
265
+ with gr.Column():
266
+ out_bass = gr.Audio(label="🎸 Tanpura / Bass", interactive=False, type="filepath")
267
+ out_other = gr.Audio(label="🎻 Violin / Other", interactive=False, type="filepath")
268
 
269
+ # 3. WIRING
270
  process_btn.click(
271
+ fn=separate_audio,
272
+ inputs=[input_audio],
273
+ outputs=[out_vocals, out_drums, out_bass, out_other]
274
  )
275
 
276
  if __name__ == "__main__":
277
+ demo.launch()