gere committed on
Commit
b1bd12b
·
verified ·
1 Parent(s): 3250dd2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -65
app.py CHANGED
@@ -12,7 +12,6 @@ from scipy.signal import butter, lfilter
12
  from pedalboard import Pedalboard, Compressor, Limiter, HighpassFilter, LowpassFilter, Gain
13
  import subprocess
14
  from pydub import AudioSegment
15
- import torch
16
  import hashlib
17
  import pickle
18
 
@@ -21,7 +20,6 @@ CORS(app)
21
 
22
  SR = 44100
23
  TARGET_LOUDNESS = -9.0
24
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
25
  CACHE_DIR = "stem_cache"
26
  os.makedirs(CACHE_DIR, exist_ok=True)
27
 
@@ -57,8 +55,6 @@ def match_key(source, target):
57
  key_s = detect_key(source)
58
  key_t = detect_key(target)
59
  shift = key_t - key_s
60
- if shift == 0:
61
- return source
62
  return librosa.effects.pitch_shift(source, sr=SR, n_steps=float(shift))
63
 
64
  def beat_sync_warp(source, target):
@@ -66,7 +62,7 @@ def beat_sync_warp(source, target):
66
  tempo_s, _ = librosa.beat.beat_track(y=source, sr=SR)
67
  tempo_t = float(np.atleast_1d(tempo_t)[0])
68
  tempo_s = float(np.atleast_1d(tempo_s)[0])
69
- if tempo_t <= 0 or tempo_s <= 0 or abs(tempo_s - tempo_t) < 1:
70
  return librosa.util.fix_length(source, size=len(target))
71
  rate = tempo_s / tempo_t
72
  warped = librosa.effects.time_stretch(source, rate=float(rate))
@@ -80,53 +76,31 @@ def get_file_hash(file_path):
80
  hasher.update(chunk)
81
  return hasher.hexdigest()
82
 
83
- def get_cached_stems(file_path):
84
- file_hash = get_file_hash(file_path)
85
  cache_path = os.path.join(CACHE_DIR, f"{file_hash}_stems.pkl")
86
 
87
  if os.path.exists(cache_path):
 
88
  with open(cache_path, 'rb') as f:
89
- return pickle.load(f)
90
- return None
91
-
92
- def cache_stems(file_path, stems_dict):
93
- file_hash = get_file_hash(file_path)
94
- cache_path = os.path.join(CACHE_DIR, f"{file_hash}_stems.pkl")
95
-
96
- with open(cache_path, 'wb') as f:
97
- pickle.dump(stems_dict, f)
98
-
99
- def separate_stems(input_file, job_id):
100
- cached = get_cached_stems(input_file)
101
- if cached:
102
- return cached, None
103
 
 
104
  out_dir = f"sep_{job_id}"
105
 
106
- if DEVICE == "cuda":
107
- cmd = [
108
- "demucs",
109
- "--two-stems", "vocals",
110
- "-n", "htdemucs",
111
- "--out", out_dir,
112
- "-d", DEVICE,
113
- input_file
114
- ]
115
- else:
116
- cmd = [
117
- "demucs",
118
- "--two-stems", "vocals",
119
- "-n", "htdemucs",
120
- "--out", out_dir,
121
- "-d", "cpu",
122
- input_file
123
- ]
124
 
125
- subprocess.run(cmd, check=True)
 
 
 
 
 
 
 
126
 
127
  base = os.path.splitext(os.path.basename(input_file))[0]
128
  stem_dir = os.path.join(out_dir, "htdemucs", base)
129
-
130
  stems = {
131
  "drums": os.path.join(stem_dir, "drums.wav"),
132
  "bass": os.path.join(stem_dir, "bass.wav"),
@@ -134,14 +108,17 @@ def separate_stems(input_file, job_id):
134
  "vocals": os.path.join(stem_dir, "vocals.wav")
135
  }
136
 
137
- cache_stems(input_file, stems)
 
138
 
139
  return stems, out_dir
140
 
141
  @app.route('/', methods=['GET'])
142
  def health():
143
- response = make_response(jsonify({"status": "ready", "device": DEVICE}), 200)
144
  response.headers["Cache-Control"] = "no-cache, no-store, must-revalidate"
 
 
145
  return response
146
 
147
  @app.route('/fuse', methods=['POST'])
@@ -149,58 +126,59 @@ def fuse_api():
149
  job_id = uuid.uuid4().hex[:8]
150
  temp_files = []
151
  cleanup_dirs = []
152
-
153
  try:
154
  trad_req = request.files.get('melody')
155
  modern_req = request.files.get('style')
156
 
157
  if not trad_req or not modern_req:
158
  return jsonify({"error": "missing files"}), 400
159
-
160
  t_path = f"t_{job_id}.wav"
161
  m_path = f"m_{job_id}.wav"
162
  trad_req.save(t_path)
163
  modern_req.save(m_path)
164
  temp_files.extend([t_path, m_path])
165
-
166
  t_wav = convert_to_wav(t_path)
167
  m_wav = convert_to_wav(m_path)
168
  if t_wav != t_path: temp_files.append(t_wav)
169
  if m_wav != m_path: temp_files.append(m_wav)
170
-
171
  t_stems, t_dir = separate_stems(t_wav, f"t_{job_id}")
172
  m_stems, m_dir = separate_stems(m_wav, f"m_{job_id}")
173
-
174
  if t_dir: cleanup_dirs.append(t_dir)
175
  if m_dir: cleanup_dirs.append(m_dir)
176
-
 
177
  t_other = load_mono(t_stems["other"])
178
  t_bass = load_mono(t_stems["bass"])
179
  m_drums = load_mono(m_stems["drums"])
180
  m_bass = load_mono(m_stems["bass"])
181
  m_other = load_mono(m_stems["other"])
182
-
183
  target_len = min(len(t_other), len(m_drums))
184
  t_other = t_other[:target_len]
185
  t_bass = t_bass[:target_len]
186
  m_drums = m_drums[:target_len]
187
  m_bass = m_bass[:target_len]
188
  m_other = m_other[:target_len]
189
-
 
190
  t_other = match_key(t_other, m_other)
191
  t_bass = match_key(t_bass, m_bass)
192
  t_other = beat_sync_warp(t_other, m_drums)
193
  t_bass = beat_sync_warp(t_bass, m_drums)
194
-
195
  t_other = highpass(t_other, 120)
196
  t_bass = highpass(t_bass, 60)
197
  m_bass = lowpass(m_bass, 250)
198
  m_drums = lowpass(m_drums, 12000)
199
  m_other = highpass(m_other, 150)
200
-
201
  fusion = (1.0 * m_drums + 1.0 * m_bass + 1.2 * t_other + 0.5 * m_other + 0.8 * t_bass)
202
  fusion = normalize_audio(fusion)
203
-
204
  board = Pedalboard([
205
  HighpassFilter(30),
206
  LowpassFilter(18000),
@@ -208,31 +186,27 @@ def fuse_api():
208
  Gain(2.0),
209
  Limiter(threshold_db=-0.5)
210
  ])
211
-
212
  fusion_mastered = board(fusion, SR)
213
 
214
  meter = pyln.Meter(SR)
215
  loudness = meter.integrated_loudness(fusion_mastered)
216
- if not np.isinf(loudness) and not np.isnan(loudness):
217
- fusion_mastered = pyln.normalize.loudness(fusion_mastered, loudness, TARGET_LOUDNESS)
218
-
219
  buf = io.BytesIO()
220
  sf.write(buf, fusion_mastered, SR, format='WAV')
221
  buf.seek(0)
222
 
 
223
  return send_file(buf, mimetype="audio/wav", as_attachment=True, download_name="fusion.wav")
224
-
225
  except Exception as e:
 
226
  return jsonify({"error": str(e)}), 500
227
  finally:
228
  for f in temp_files:
229
- if os.path.exists(f):
230
- try: os.remove(f)
231
- except: pass
232
  for d in cleanup_dirs:
233
- if d and os.path.exists(d):
234
- try: shutil.rmtree(d, ignore_errors=True)
235
- except: pass
236
 
237
  if __name__ == "__main__":
238
  app.run(host='0.0.0.0', port=7860, threaded=True)
 
12
  from pedalboard import Pedalboard, Compressor, Limiter, HighpassFilter, LowpassFilter, Gain
13
  import subprocess
14
  from pydub import AudioSegment
 
15
  import hashlib
16
  import pickle
17
 
 
20
 
21
  SR = 44100
22
  TARGET_LOUDNESS = -9.0
 
23
  CACHE_DIR = "stem_cache"
24
  os.makedirs(CACHE_DIR, exist_ok=True)
25
 
 
55
  key_s = detect_key(source)
56
  key_t = detect_key(target)
57
  shift = key_t - key_s
 
 
58
  return librosa.effects.pitch_shift(source, sr=SR, n_steps=float(shift))
59
 
60
  def beat_sync_warp(source, target):
 
62
  tempo_s, _ = librosa.beat.beat_track(y=source, sr=SR)
63
  tempo_t = float(np.atleast_1d(tempo_t)[0])
64
  tempo_s = float(np.atleast_1d(tempo_s)[0])
65
+ if tempo_t <= 0 or tempo_s <= 0:
66
  return librosa.util.fix_length(source, size=len(target))
67
  rate = tempo_s / tempo_t
68
  warped = librosa.effects.time_stretch(source, rate=float(rate))
 
76
  hasher.update(chunk)
77
  return hasher.hexdigest()
78
 
79
+ def separate_stems(input_file, job_id):
80
+ file_hash = get_file_hash(input_file)
81
  cache_path = os.path.join(CACHE_DIR, f"{file_hash}_stems.pkl")
82
 
83
  if os.path.exists(cache_path):
84
+ print(f">>> [hf] using cached stems for {input_file}")
85
  with open(cache_path, 'rb') as f:
86
+ return pickle.load(f), None
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
+ print(f">>> [hf] starting demucs for {input_file}")
89
  out_dir = f"sep_{job_id}"
90
 
91
+ device = "cuda" if subprocess.run(["nvidia-smi"], capture_output=True).returncode == 0 else "cpu"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
+ subprocess.run([
94
+ "demucs",
95
+ "--two-stems", "vocals",
96
+ "-n", "htdemucs",
97
+ "--out", out_dir,
98
+ "-d", device,
99
+ input_file
100
+ ], check=True)
101
 
102
  base = os.path.splitext(os.path.basename(input_file))[0]
103
  stem_dir = os.path.join(out_dir, "htdemucs", base)
 
104
  stems = {
105
  "drums": os.path.join(stem_dir, "drums.wav"),
106
  "bass": os.path.join(stem_dir, "bass.wav"),
 
108
  "vocals": os.path.join(stem_dir, "vocals.wav")
109
  }
110
 
111
+ with open(cache_path, 'wb') as f:
112
+ pickle.dump(stems, f)
113
 
114
  return stems, out_dir
115
 
116
@app.route('/', methods=['GET'])
def health():
    """Health-check endpoint: report service readiness with caching disabled."""
    payload = jsonify({"status": "ready"})
    response = make_response(payload, 200)
    # Forbid caching at every layer: HTTP/1.1 (Cache-Control),
    # HTTP/1.0 (Pragma), and expiry-based proxies (Expires).
    no_cache_headers = {
        "Cache-Control": "no-cache, no-store, must-revalidate",
        "Pragma": "no-cache",
        "Expires": "0",
    }
    for header_name, header_value in no_cache_headers.items():
        response.headers[header_name] = header_value
    return response
123
 
124
  @app.route('/fuse', methods=['POST'])
 
126
  job_id = uuid.uuid4().hex[:8]
127
  temp_files = []
128
  cleanup_dirs = []
129
+ print(f">>> [hf] new request: job_{job_id}")
130
  try:
131
  trad_req = request.files.get('melody')
132
  modern_req = request.files.get('style')
133
 
134
  if not trad_req or not modern_req:
135
  return jsonify({"error": "missing files"}), 400
136
+
137
  t_path = f"t_{job_id}.wav"
138
  m_path = f"m_{job_id}.wav"
139
  trad_req.save(t_path)
140
  modern_req.save(m_path)
141
  temp_files.extend([t_path, m_path])
142
+
143
  t_wav = convert_to_wav(t_path)
144
  m_wav = convert_to_wav(m_path)
145
  if t_wav != t_path: temp_files.append(t_wav)
146
  if m_wav != m_path: temp_files.append(m_wav)
147
+
148
  t_stems, t_dir = separate_stems(t_wav, f"t_{job_id}")
149
  m_stems, m_dir = separate_stems(m_wav, f"m_{job_id}")
 
150
  if t_dir: cleanup_dirs.append(t_dir)
151
  if m_dir: cleanup_dirs.append(m_dir)
152
+
153
+ print(f">>> [hf] loading stems for job_{job_id}")
154
  t_other = load_mono(t_stems["other"])
155
  t_bass = load_mono(t_stems["bass"])
156
  m_drums = load_mono(m_stems["drums"])
157
  m_bass = load_mono(m_stems["bass"])
158
  m_other = load_mono(m_stems["other"])
159
+
160
  target_len = min(len(t_other), len(m_drums))
161
  t_other = t_other[:target_len]
162
  t_bass = t_bass[:target_len]
163
  m_drums = m_drums[:target_len]
164
  m_bass = m_bass[:target_len]
165
  m_other = m_other[:target_len]
166
+
167
+ print(f">>> [hf] matching key and warp...")
168
  t_other = match_key(t_other, m_other)
169
  t_bass = match_key(t_bass, m_bass)
170
  t_other = beat_sync_warp(t_other, m_drums)
171
  t_bass = beat_sync_warp(t_bass, m_drums)
172
+
173
  t_other = highpass(t_other, 120)
174
  t_bass = highpass(t_bass, 60)
175
  m_bass = lowpass(m_bass, 250)
176
  m_drums = lowpass(m_drums, 12000)
177
  m_other = highpass(m_other, 150)
178
+
179
  fusion = (1.0 * m_drums + 1.0 * m_bass + 1.2 * t_other + 0.5 * m_other + 0.8 * t_bass)
180
  fusion = normalize_audio(fusion)
181
+
182
  board = Pedalboard([
183
  HighpassFilter(30),
184
  LowpassFilter(18000),
 
186
  Gain(2.0),
187
  Limiter(threshold_db=-0.5)
188
  ])
 
189
  fusion_mastered = board(fusion, SR)
190
 
191
  meter = pyln.Meter(SR)
192
  loudness = meter.integrated_loudness(fusion_mastered)
193
+ fusion_mastered = pyln.normalize.loudness(fusion_mastered, loudness, TARGET_LOUDNESS)
194
+
 
195
  buf = io.BytesIO()
196
  sf.write(buf, fusion_mastered, SR, format='WAV')
197
  buf.seek(0)
198
 
199
+ print(f">>> [hf] job_{job_id} complete!")
200
  return send_file(buf, mimetype="audio/wav", as_attachment=True, download_name="fusion.wav")
201
+
202
  except Exception as e:
203
+ print(f">>> [hf] error in job_{job_id}: {str(e)}")
204
  return jsonify({"error": str(e)}), 500
205
  finally:
206
  for f in temp_files:
207
+ if os.path.exists(f): os.remove(f)
 
 
208
  for d in cleanup_dirs:
209
+ if d and os.path.exists(d): shutil.rmtree(d, ignore_errors=True)
 
 
210
 
211
if __name__ == "__main__":
    # Bind on all interfaces so the container's mapped port is reachable;
    # threaded mode lets concurrent requests be served while a job runs.
    app.run(threaded=True, host='0.0.0.0', port=7860)