Tachyeon commited on
Commit
dc8229b
Β·
verified Β·
1 Parent(s): 5a98f35

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -0
app.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import librosa
4
+ import soundfile as sf
5
+ import numpy as np
6
+ import os
7
+ import sys
8
+ from huggingface_hub import hf_hub_download
9
+
10
+ # 1. SETUP
11
+ # Import the architecture directly since we installed the repo via requirements.txt
12
+ from models.bs_roformer.bs_roformer import BSRoformer
13
+
14
+ DEVICE = "cpu" # Free Tier uses CPU
15
+
16
+ # 2. DOWNLOAD & LOAD MODEL
17
+ # πŸ‘‡ REPLACE THIS with your actual Model Repo ID (e.g. "Rahul/IAM-RoFormer-Weights")
18
+ REPO_ID = "Tachyeon/IAM-RoFormer-Model-Weights"
19
+ FILENAME = "v11_consensus_epoch_30.pt"
20
+
21
+ print(f">>> ⏳ Downloading Model from {REPO_ID}...")
22
+ try:
23
+ # This downloads the 4.5GB file from your storage repo to the Space
24
+ model_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
25
+ print(f">>> βœ… Download Complete: {model_path}")
26
+
27
+ # Initialize Architecture
28
+ model = BSRoformer(
29
+ dim=512, depth=12, stereo=True, num_stems=4,
30
+ time_transformer_depth=1, freq_transformer_depth=1,
31
+ flash_attn=False
32
+ ).to(DEVICE)
33
+
34
+ # Load Weights
35
+ state = torch.load(model_path, map_location=DEVICE)
36
+ if 'model' in state: state = state['model']
37
+ model.load_state_dict(state, strict=False)
38
+ model.eval()
39
+ print(">>> βœ… Model Loaded Successfully!")
40
+ except Exception as e:
41
+ print(f"❌ Error: {e}")
42
+ raise e
43
+
44
+ # 3. INFERENCE LOGIC (V15 PURE FIDELITY)
45
+ def separate_audio(audio_file, progress=gr.Progress()):
46
+ if audio_file is None: return None, None, None, None
47
+
48
+ progress(0, desc="Loading Audio...")
49
+ print(f">>> πŸͺ„ Processing: {audio_file}")
50
+
51
+ mix, sr = librosa.load(audio_file, sr=44100, mono=False)
52
+ if len(mix.shape) == 1: mix = np.stack([mix, mix], axis=0)
53
+
54
+ # Chunking (Safe for CPU)
55
+ chunk_size = 44100 * 10
56
+ overlap = 44100 * 1
57
+
58
+ mix_tensor = torch.tensor(mix, dtype=torch.float32).to(DEVICE)
59
+ if mix_tensor.dim() == 2: mix_tensor = mix_tensor.unsqueeze(0)
60
+
61
+ length = mix_tensor.shape[-1]
62
+ final_output = torch.zeros(1, 4, 2, length).to(DEVICE)
63
+ counts = torch.zeros(1, 4, 2, length).to(DEVICE)
64
+
65
+ progress(0.1, desc="Separating Stems...")
66
+ with torch.no_grad():
67
+ for start in range(0, length, int(chunk_size - overlap)):
68
+ end = min(start + int(chunk_size), length)
69
+ chunk = mix_tensor[:, :, start:end]
70
+ if chunk.shape[-1] < chunk_size:
71
+ pad_len = int(chunk_size - chunk.shape[-1])
72
+ chunk = torch.nn.functional.pad(chunk, (0, pad_len))
73
+
74
+ pred = model(chunk)
75
+
76
+ valid_len = end - start
77
+ final_output[:, :, :, start:end] += pred[:, :, :, :valid_len]
78
+ counts[:, :, :, start:end] += 1.0
79
+
80
+ current_progress = 0.1 + (0.8 * (end / length))
81
+ progress(current_progress, desc="Processing...")
82
+
83
+ stems = (final_output / torch.clamp(counts, min=1.0)).cpu().numpy()[0]
84
+
85
+ # V15 Safety Normalization
86
+ peak = np.max(np.abs(stems))
87
+ if peak > 0.99: stems = stems / peak
88
+
89
+ outputs = []
90
+ for i in range(4):
91
+ outfile = f"stem_{i}.wav"
92
+ sf.write(outfile, stems[i].T, sr)
93
+ outputs.append(outfile)
94
+
95
+ return outputs[0], outputs[1], outputs[2], outputs[3]
96
+
97
+ # 4. UI
98
+ custom_css = "#title {text-align: center} #desc {text-align: center}"
99
+
100
+ with gr.Blocks(css=custom_css, title="IAM Source Separation") as demo:
101
+ gr.Markdown("# 🎻 Indian Art Music Source Separator", elem_id="title")
102
+ gr.Markdown("### Powered by RoFormer | Epoch 30 Consensus Model", elem_id="desc")
103
+
104
+ with gr.Row():
105
+ with gr.Column():
106
+ input_audio = gr.Audio(label="Input Mixture", type="filepath")
107
+ submit_btn = gr.Button("✨ Separate Audio", variant="primary", size="lg")
108
+ with gr.Column():
109
+ out_vocals = gr.Audio(label="Vocals", interactive=False)
110
+ out_drums = gr.Audio(label="Mridangam", interactive=False)
111
+ out_bass = gr.Audio(label="Tanpura", interactive=False)
112
+ out_other = gr.Audio(label="Violin", interactive=False)
113
+
114
+ submit_btn.click(separate_audio, inputs=input_audio, outputs=[out_vocals, out_drums, out_bass, out_other])
115
+
116
+ demo.launch()