gere committed on
Commit
7f8ecbe
·
verified ·
1 Parent(s): b979387

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -101
app.py CHANGED
@@ -10,11 +10,13 @@ from flask import Flask, request, send_file, jsonify, make_response
10
  from flask_cors import CORS
11
  from scipy.signal import butter, lfilter
12
  from pedalboard import Pedalboard, Compressor, Limiter, HighpassFilter, LowpassFilter, Gain
13
- import subprocess
14
  from pydub import AudioSegment
15
- from concurrent.futures import ThreadPoolExecutor
16
  import torch
17
 
 
 
 
 
18
  os.environ["OMP_NUM_THREADS"] = "1"
19
 
20
  app = Flask(__name__)
@@ -24,6 +26,12 @@ SR = 44100
24
  TARGET_LOUDNESS = -9.0
25
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
26
 
 
 
 
 
 
 
27
  def convert_to_wav(input_path):
28
  if input_path.lower().endswith(".mp3"):
29
  wav_path = input_path.rsplit(".", 1)[0] + f"_{uuid.uuid4().hex}.wav"
@@ -31,76 +39,31 @@ def convert_to_wav(input_path):
31
  return wav_path
32
  return input_path
33
 
34
def load_mono(file_path):
    """Load an audio file as a mono float array at the project sample rate SR.

    Falls back to 5 seconds of silence when the path does not exist, so
    downstream mixing code always receives a usable buffer.
    """
    if os.path.exists(file_path):
        audio, _ = librosa.load(file_path, sr=SR, mono=True)
        return audio
    return np.zeros(SR * 5)
39
-
40
def normalize_audio(y):
    """Peak-normalize *y* to ~[-1, 1]; the epsilon keeps silence divide-safe."""
    peak = np.max(np.abs(y)) + 1e-9
    return y / peak
42
-
43
def highpass(data, cutoff):
    """Apply a 4th-order Butterworth high-pass at *cutoff* Hz (Nyquist = SR/2)."""
    numer, denom = butter(4, cutoff / (SR / 2), btype='high')
    return lfilter(numer, denom, data)
46
-
47
def lowpass(data, cutoff):
    """Apply a 4th-order Butterworth low-pass at *cutoff* Hz (Nyquist = SR/2)."""
    numer, denom = butter(4, cutoff / (SR / 2), btype='low')
    return lfilter(numer, denom, data)
50
-
51
def detect_key(y):
    """Return the dominant pitch class (0-11) by summed CQT chroma energy."""
    chroma = librosa.feature.chroma_cqt(y=y, sr=SR)
    energy_per_class = np.sum(chroma, axis=1)
    return np.argmax(energy_per_class)
54
-
55
def match_key(source, target):
    """Pitch-shift *source* so its dominant pitch class matches *target*'s.

    Returns *source* unchanged when the keys already agree.
    NOTE(review): the shift is taken as key_t - key_s directly, so it may be
    up to 11 semitones in one direction rather than the nearest-direction
    equivalent — confirm this is intended.
    """
    shift = detect_key(target) - detect_key(source)
    if shift == 0:
        return source
    return librosa.effects.pitch_shift(source, sr=SR, n_steps=float(shift))
61
-
62
def beat_sync_warp(source, target):
    """Time-stretch *source* to *target*'s estimated tempo, then pad/trim it
    to exactly len(target) samples.

    Falls back to a plain length fix when either tempo estimate is
    non-positive or the two tempos are already within 1 BPM of each other.
    """
    tempo_target, _ = librosa.beat.beat_track(y=target, sr=SR)
    tempo_source, _ = librosa.beat.beat_track(y=source, sr=SR)
    # beat_track may return a scalar or a 1-element array depending on the
    # librosa version; normalize both to plain floats.
    tempo_target = float(np.atleast_1d(tempo_target)[0])
    tempo_source = float(np.atleast_1d(tempo_source)[0])
    tempos_close = abs(tempo_source - tempo_target) < 1
    if tempo_target <= 0 or tempo_source <= 0 or tempos_close:
        return librosa.util.fix_length(source, size=len(target))
    stretched = librosa.effects.time_stretch(source, rate=float(tempo_source / tempo_target))
    return librosa.util.fix_length(stretched, size=len(target))
73
-
74
def separate_stems(input_file, job_id):
    """Run the demucs CLI on *input_file* and return ({stem_name: wav_path}, out_dir).

    Blocks until separation finishes; raises subprocess.CalledProcessError on
    failure (check=True). The caller owns — and must eventually delete — the
    returned output directory.
    """
    out_dir = f"sep_{job_id}"
    subprocess.run(
        ["demucs", "-n", "htdemucs", "--out", out_dir, "-d", DEVICE, input_file],
        check=True,
    )
    # demucs writes stems under <out_dir>/<model>/<input basename>/
    base = os.path.splitext(os.path.basename(input_file))[0]
    stem_dir = os.path.join(out_dir, "htdemucs", base)
    stems = {
        name: os.path.join(stem_dir, f"{name}.wav")
        for name in ("drums", "bass", "other", "vocals")
    }
    return stems, out_dir
93
 
94
@app.route('/', methods=['GET'])
def health():
    """Liveness probe: report readiness and which compute device is active."""
    payload = jsonify({"status": "ready", "device": DEVICE})
    response = make_response(payload, 200)
    # Probes should always hit the live app, never a cached response.
    response.headers["Cache-Control"] = "no-cache, no-store, must-revalidate"
    return response
99
 
100
  @app.route('/fuse', methods=['POST'])
101
  def fuse_api():
102
  job_id = uuid.uuid4().hex[:8]
103
- temp_files, cleanup_dirs = [], []
104
  try:
105
  trad_req = request.files.get('melody')
106
  modern_req = request.files.get('style')
@@ -112,45 +75,24 @@ def fuse_api():
112
  modern_req.save(m_path)
113
  temp_files.extend([t_path, m_path])
114
 
 
115
  t_wav = convert_to_wav(t_path)
116
  m_wav = convert_to_wav(m_path)
117
  if t_wav != t_path: temp_files.append(t_wav)
118
  if m_wav != m_path: temp_files.append(m_wav)
119
 
120
- with ThreadPoolExecutor(max_workers=2) as executor:
121
- fut_t = executor.submit(separate_stems, t_wav, f"t_{job_id}")
122
- fut_m = executor.submit(separate_stems, m_wav, f"m_{job_id}")
123
- t_stems, t_dir = fut_t.result()
124
- m_stems, m_dir = fut_m.result()
125
-
126
- cleanup_dirs.extend([t_dir, m_dir])
127
-
128
- t_other, t_bass = load_mono(t_stems["other"]), load_mono(t_stems["bass"])
129
- m_drums, m_bass, m_other = load_mono(m_stems["drums"]), load_mono(m_stems["bass"]), load_mono(m_stems["other"])
130
-
131
- target_len = min(len(t_other), len(m_drums))
132
- t_other, t_bass = t_other[:target_len], t_bass[:target_len]
133
- m_drums, m_bass, m_other = m_drums[:target_len], m_bass[:target_len], m_other[:target_len]
134
-
135
- t_other = match_key(t_other, m_other)
136
- t_bass = match_key(t_bass, m_bass)
137
- t_other = beat_sync_warp(t_other, m_drums)
138
- t_bass = beat_sync_warp(t_bass, m_drums)
139
-
140
- t_other, t_bass = highpass(t_other, 120), highpass(t_bass, 60)
141
- m_bass, m_drums, m_other = lowpass(m_bass, 250), lowpass(m_drums, 12000), highpass(m_other, 150)
142
-
143
- fusion = normalize_audio(1.0 * m_drums + 1.0 * m_bass + 1.2 * t_other + 0.5 * m_other + 0.8 * t_bass)
144
-
145
- board = Pedalboard([
146
- HighpassFilter(30), LowpassFilter(18000),
147
- Compressor(threshold_db=-20, ratio=2), Gain(2.0), Limiter(threshold_db=-0.5)
148
- ])
149
- fusion_mastered = board(fusion, SR)
150
-
151
- meter = pyln.Meter(SR)
152
- loudness = meter.integrated_loudness(fusion_mastered)
153
- fusion_mastered = pyln.normalize.loudness(fusion_mastered, loudness, TARGET_LOUDNESS)
154
 
155
  buf = io.BytesIO()
156
  sf.write(buf, fusion_mastered, SR, format='WAV')
@@ -162,8 +104,6 @@ def fuse_api():
162
  finally:
163
  for f in temp_files:
164
  if os.path.exists(f): os.remove(f)
165
- for d in cleanup_dirs:
166
- shutil.rmtree(d, ignore_errors=True)
167
 
168
  if __name__ == "__main__":
169
  app.run(host='0.0.0.0', port=7860)
 
10
  from flask_cors import CORS
11
  from scipy.signal import butter, lfilter
12
  from pedalboard import Pedalboard, Compressor, Limiter, HighpassFilter, LowpassFilter, Gain
 
13
  from pydub import AudioSegment
 
14
  import torch
15
 
16
+ # Import Demucs Python API
17
+ from demucs.apply import apply_model
18
+ from demucs.pretrained import get_model
19
+
20
  os.environ["OMP_NUM_THREADS"] = "1"
21
 
22
  app = Flask(__name__)
 
26
  TARGET_LOUDNESS = -9.0
27
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
28
 
29
# Pre-load the separation model once at startup so every request reuses it.
print(f"Loading Demucs model on {DEVICE}...")
MODEL = get_model("htdemucs").to(DEVICE)  # nn.Module.to() returns the module
MODEL.eval()  # inference mode
34
+
35
  def convert_to_wav(input_path):
36
  if input_path.lower().endswith(".mp3"):
37
  wav_path = input_path.rsplit(".", 1)[0] + f"_{uuid.uuid4().hex}.wav"
 
39
  return wav_path
40
  return input_path
41
 
42
def separate_stems_fast(input_file):
    """Separate *input_file* into stems using the pre-loaded Demucs MODEL.

    Returns a dict mapping stem name ("drums", "bass", "other", "vocals")
    to a numpy array of shape (channels, samples) at MODEL.samplerate.
    """
    # Load at the model's native sample rate, preserving channels.
    wav, _ = librosa.load(input_file, sr=MODEL.samplerate, mono=False)
    # BUG FIX: librosa returns a 1-D array for mono files, but apply_model
    # expects (batch, channels, samples). Duplicate the mono signal across
    # the model's expected channel count instead of feeding a 2-D batch.
    if wav.ndim == 1:
        wav = np.tile(wav, (MODEL.audio_channels, 1))
    wav_tensor = torch.from_numpy(wav).to(DEVICE).float()

    # Process with the pre-loaded model; no gradients needed at inference.
    with torch.no_grad():
        # Add the batch dimension, then drop it from the output.
        sources = apply_model(MODEL, wav_tensor[None], device=DEVICE)[0]

    # Map stems by MODEL.sources rather than hard-coded indices, so the
    # mapping cannot silently drift from the model's actual output order
    # (for htdemucs this is drums, bass, other, vocals).
    return {
        name: sources[index].cpu().numpy()
        for index, name in enumerate(MODEL.sources)
    }
60
 
61
+ # ... (Keep your helper functions like match_key, beat_sync_warp, etc.) ...
 
 
 
 
62
 
63
  @app.route('/fuse', methods=['POST'])
64
  def fuse_api():
65
  job_id = uuid.uuid4().hex[:8]
66
+ temp_files = []
67
  try:
68
  trad_req = request.files.get('melody')
69
  modern_req = request.files.get('style')
 
75
  modern_req.save(m_path)
76
  temp_files.extend([t_path, m_path])
77
 
78
+ # Convert/Load
79
  t_wav = convert_to_wav(t_path)
80
  m_wav = convert_to_wav(m_path)
81
  if t_wav != t_path: temp_files.append(t_wav)
82
  if m_wav != m_path: temp_files.append(m_wav)
83
 
84
+ # SEPARATION (Using pre-loaded model in sequence to avoid GPU crash)
85
+ t_stems = separate_stems_fast(t_wav)
86
+ m_stems = separate_stems_fast(m_wav)
87
+
88
+ # Get mono tracks for your logic (mean across channels)
89
+ t_other = np.mean(t_stems["other"], axis=0)
90
+ t_bass = np.mean(t_stems["bass"], axis=0)
91
+ m_drums = np.mean(m_stems["drums"], axis=0)
92
+ m_bass = np.mean(m_stems["bass"], axis=0)
93
+ m_other = np.mean(m_stems["other"], axis=0)
94
+
95
+ # ... (Rest of your fusion and mastering logic) ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  buf = io.BytesIO()
98
  sf.write(buf, fusion_mastered, SR, format='WAV')
 
104
  finally:
105
  for f in temp_files:
106
  if os.path.exists(f): os.remove(f)
 
 
107
 
108
  if __name__ == "__main__":
109
  app.run(host='0.0.0.0', port=7860)