gere committed on
Commit
a34d937
·
verified ·
1 Parent(s): 2ed3866

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -75
app.py CHANGED
@@ -12,16 +12,15 @@ from scipy.signal import butter, lfilter
12
  from pedalboard import Pedalboard, Compressor, Limiter, HighpassFilter, LowpassFilter, Gain
13
  import subprocess
14
  from pydub import AudioSegment
15
- import hashlib
16
- import pickle
17
 
18
  app = Flask(__name__)
19
  CORS(app)
20
 
21
  SR = 44100
22
  TARGET_LOUDNESS = -9.0
23
- CACHE_DIR = "stem_cache"
24
- os.makedirs(CACHE_DIR, exist_ok=True)
25
 
26
  def convert_to_wav(input_path):
27
  if input_path.lower().endswith(".mp3"):
@@ -55,6 +54,7 @@ def match_key(source, target):
55
  key_s = detect_key(source)
56
  key_t = detect_key(target)
57
  shift = key_t - key_s
 
58
  return librosa.effects.pitch_shift(source, sr=SR, n_steps=float(shift))
59
 
60
  def beat_sync_warp(source, target):
@@ -62,43 +62,23 @@ def beat_sync_warp(source, target):
62
  tempo_s, _ = librosa.beat.beat_track(y=source, sr=SR)
63
  tempo_t = float(np.atleast_1d(tempo_t)[0])
64
  tempo_s = float(np.atleast_1d(tempo_s)[0])
65
- if tempo_t <= 0 or tempo_s <= 0:
66
  return librosa.util.fix_length(source, size=len(target))
67
  rate = tempo_s / tempo_t
68
  warped = librosa.effects.time_stretch(source, rate=float(rate))
69
  warped = librosa.util.fix_length(warped, size=len(target))
70
  return warped
71
 
72
- def get_file_hash(file_path):
73
- hasher = hashlib.md5()
74
- with open(file_path, 'rb') as f:
75
- for chunk in iter(lambda: f.read(65536), b''):
76
- hasher.update(chunk)
77
- return hasher.hexdigest()
78
-
79
  def separate_stems(input_file, job_id):
80
- file_hash = get_file_hash(input_file)
81
- cache_path = os.path.join(CACHE_DIR, f"{file_hash}_stems.pkl")
82
-
83
- if os.path.exists(cache_path):
84
- print(f">>> [hf] using cached stems for {input_file}")
85
- with open(cache_path, 'rb') as f:
86
- return pickle.load(f), None
87
-
88
- print(f">>> [hf] starting demucs for {input_file}")
89
  out_dir = f"sep_{job_id}"
90
-
91
- device = "cuda" if subprocess.run(["nvidia-smi"], capture_output=True).returncode == 0 else "cpu"
92
-
93
- subprocess.run([
94
  "demucs",
95
- "--two-stems", "vocals",
96
  "-n", "htdemucs",
97
  "--out", out_dir,
98
- "-d", device,
99
  input_file
100
- ], check=True)
101
-
102
  base = os.path.splitext(os.path.basename(input_file))[0]
103
  stem_dir = os.path.join(out_dir, "htdemucs", base)
104
  stems = {
@@ -107,35 +87,25 @@ def separate_stems(input_file, job_id):
107
  "other": os.path.join(stem_dir, "other.wav"),
108
  "vocals": os.path.join(stem_dir, "vocals.wav")
109
  }
110
-
111
- with open(cache_path, 'wb') as f:
112
- pickle.dump(stems, f)
113
-
114
  return stems, out_dir
115
 
116
  @app.route('/', methods=['GET'])
117
  def health():
118
- response = make_response(jsonify({"status": "ready"}), 200)
119
  response.headers["Cache-Control"] = "no-cache, no-store, must-revalidate"
120
- response.headers["Pragma"] = "no-cache"
121
- response.headers["Expires"] = "0"
122
  return response
123
 
124
  @app.route('/fuse', methods=['POST'])
125
  def fuse_api():
126
  job_id = uuid.uuid4().hex[:8]
127
- temp_files = []
128
- cleanup_dirs = []
129
- print(f">>> [hf] new request: job_{job_id}")
130
  try:
131
  trad_req = request.files.get('melody')
132
  modern_req = request.files.get('style')
133
-
134
  if not trad_req or not modern_req:
135
  return jsonify({"error": "missing files"}), 400
136
 
137
- t_path = f"t_{job_id}.wav"
138
- m_path = f"m_{job_id}.wav"
139
  trad_req.save(t_path)
140
  modern_req.save(m_path)
141
  temp_files.extend([t_path, m_path])
@@ -145,46 +115,34 @@ def fuse_api():
145
  if t_wav != t_path: temp_files.append(t_wav)
146
  if m_wav != m_path: temp_files.append(m_wav)
147
 
148
- t_stems, t_dir = separate_stems(t_wav, f"t_{job_id}")
149
- m_stems, m_dir = separate_stems(m_wav, f"m_{job_id}")
150
- if t_dir: cleanup_dirs.append(t_dir)
151
- if m_dir: cleanup_dirs.append(m_dir)
 
 
 
152
 
153
- print(f">>> [hf] loading stems for job_{job_id}")
154
- t_other = load_mono(t_stems["other"])
155
- t_bass = load_mono(t_stems["bass"])
156
- m_drums = load_mono(m_stems["drums"])
157
- m_bass = load_mono(m_stems["bass"])
158
- m_other = load_mono(m_stems["other"])
159
 
160
  target_len = min(len(t_other), len(m_drums))
161
- t_other = t_other[:target_len]
162
- t_bass = t_bass[:target_len]
163
- m_drums = m_drums[:target_len]
164
- m_bass = m_bass[:target_len]
165
- m_other = m_other[:target_len]
166
 
167
- print(f">>> [hf] matching key and warp...")
168
  t_other = match_key(t_other, m_other)
169
  t_bass = match_key(t_bass, m_bass)
170
  t_other = beat_sync_warp(t_other, m_drums)
171
  t_bass = beat_sync_warp(t_bass, m_drums)
172
 
173
- t_other = highpass(t_other, 120)
174
- t_bass = highpass(t_bass, 60)
175
- m_bass = lowpass(m_bass, 250)
176
- m_drums = lowpass(m_drums, 12000)
177
- m_other = highpass(m_other, 150)
178
 
179
- fusion = (1.0 * m_drums + 1.0 * m_bass + 1.2 * t_other + 0.5 * m_other + 0.8 * t_bass)
180
- fusion = normalize_audio(fusion)
181
 
182
  board = Pedalboard([
183
- HighpassFilter(30),
184
- LowpassFilter(18000),
185
- Compressor(threshold_db=-20, ratio=2),
186
- Gain(2.0),
187
- Limiter(threshold_db=-0.5)
188
  ])
189
  fusion_mastered = board(fusion, SR)
190
 
@@ -195,18 +153,15 @@ def fuse_api():
195
  buf = io.BytesIO()
196
  sf.write(buf, fusion_mastered, SR, format='WAV')
197
  buf.seek(0)
198
-
199
- print(f">>> [hf] job_{job_id} complete!")
200
  return send_file(buf, mimetype="audio/wav", as_attachment=True, download_name="fusion.wav")
201
 
202
  except Exception as e:
203
- print(f">>> [hf] error in job_{job_id}: {str(e)}")
204
  return jsonify({"error": str(e)}), 500
205
  finally:
206
  for f in temp_files:
207
  if os.path.exists(f): os.remove(f)
208
  for d in cleanup_dirs:
209
- if d and os.path.exists(d): shutil.rmtree(d, ignore_errors=True)
210
 
211
  if __name__ == "__main__":
212
- app.run(host='0.0.0.0', port=7860, threaded=True)
 
12
  from pedalboard import Pedalboard, Compressor, Limiter, HighpassFilter, LowpassFilter, Gain
13
  import subprocess
14
  from pydub import AudioSegment
15
+ from concurrent.futures import ThreadPoolExecutor
16
+ import torch
17
 
18
  app = Flask(__name__)
19
  CORS(app)
20
 
21
  SR = 44100
22
  TARGET_LOUDNESS = -9.0
23
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
24
 
25
  def convert_to_wav(input_path):
26
  if input_path.lower().endswith(".mp3"):
 
54
  key_s = detect_key(source)
55
  key_t = detect_key(target)
56
  shift = key_t - key_s
57
+ if shift == 0: return source
58
  return librosa.effects.pitch_shift(source, sr=SR, n_steps=float(shift))
59
 
60
  def beat_sync_warp(source, target):
 
62
  tempo_s, _ = librosa.beat.beat_track(y=source, sr=SR)
63
  tempo_t = float(np.atleast_1d(tempo_t)[0])
64
  tempo_s = float(np.atleast_1d(tempo_s)[0])
65
+ if tempo_t <= 0 or tempo_s <= 0 or abs(tempo_s - tempo_t) < 1:
66
  return librosa.util.fix_length(source, size=len(target))
67
  rate = tempo_s / tempo_t
68
  warped = librosa.effects.time_stretch(source, rate=float(rate))
69
  warped = librosa.util.fix_length(warped, size=len(target))
70
  return warped
71
 
 
 
 
 
 
 
 
72
  def separate_stems(input_file, job_id):
 
 
 
 
 
 
 
 
 
73
  out_dir = f"sep_{job_id}"
74
+ cmd = [
 
 
 
75
  "demucs",
 
76
  "-n", "htdemucs",
77
  "--out", out_dir,
78
+ "-d", DEVICE,
79
  input_file
80
+ ]
81
+ subprocess.run(cmd, check=True)
82
  base = os.path.splitext(os.path.basename(input_file))[0]
83
  stem_dir = os.path.join(out_dir, "htdemucs", base)
84
  stems = {
 
87
  "other": os.path.join(stem_dir, "other.wav"),
88
  "vocals": os.path.join(stem_dir, "vocals.wav")
89
  }
 
 
 
 
90
  return stems, out_dir
91
 
92
  @app.route('/', methods=['GET'])
93
  def health():
94
+ response = make_response(jsonify({"status": "ready", "device": DEVICE}), 200)
95
  response.headers["Cache-Control"] = "no-cache, no-store, must-revalidate"
 
 
96
  return response
97
 
98
  @app.route('/fuse', methods=['POST'])
99
  def fuse_api():
100
  job_id = uuid.uuid4().hex[:8]
101
+ temp_files, cleanup_dirs = [], []
 
 
102
  try:
103
  trad_req = request.files.get('melody')
104
  modern_req = request.files.get('style')
 
105
  if not trad_req or not modern_req:
106
  return jsonify({"error": "missing files"}), 400
107
 
108
+ t_path, m_path = f"t_{job_id}.wav", f"m_{job_id}.wav"
 
109
  trad_req.save(t_path)
110
  modern_req.save(m_path)
111
  temp_files.extend([t_path, m_path])
 
115
  if t_wav != t_path: temp_files.append(t_wav)
116
  if m_wav != m_path: temp_files.append(m_wav)
117
 
118
+ with ThreadPoolExecutor(max_workers=2) as executor:
119
+ fut_t = executor.submit(separate_stems, t_wav, f"t_{job_id}")
120
+ fut_m = executor.submit(separate_stems, m_wav, f"m_{job_id}")
121
+ t_stems, t_dir = fut_t.result()
122
+ m_stems, m_dir = fut_m.result()
123
+
124
+ cleanup_dirs.extend([t_dir, m_dir])
125
 
126
+ t_other, t_bass = load_mono(t_stems["other"]), load_mono(t_stems["bass"])
127
+ m_drums, m_bass, m_other = load_mono(m_stems["drums"]), load_mono(m_stems["bass"]), load_mono(m_stems["other"])
 
 
 
 
128
 
129
  target_len = min(len(t_other), len(m_drums))
130
+ t_other, t_bass = t_other[:target_len], t_bass[:target_len]
131
+ m_drums, m_bass, m_other = m_drums[:target_len], m_bass[:target_len], m_other[:target_len]
 
 
 
132
 
 
133
  t_other = match_key(t_other, m_other)
134
  t_bass = match_key(t_bass, m_bass)
135
  t_other = beat_sync_warp(t_other, m_drums)
136
  t_bass = beat_sync_warp(t_bass, m_drums)
137
 
138
+ t_other, t_bass = highpass(t_other, 120), highpass(t_bass, 60)
139
+ m_bass, m_drums, m_other = lowpass(m_bass, 250), lowpass(m_drums, 12000), highpass(m_other, 150)
 
 
 
140
 
141
+ fusion = normalize_audio(1.0 * m_drums + 1.0 * m_bass + 1.2 * t_other + 0.5 * m_other + 0.8 * t_bass)
 
142
 
143
  board = Pedalboard([
144
+ HighpassFilter(30), LowpassFilter(18000),
145
+ Compressor(threshold_db=-20, ratio=2), Gain(2.0), Limiter(threshold_db=-0.5)
 
 
 
146
  ])
147
  fusion_mastered = board(fusion, SR)
148
 
 
153
  buf = io.BytesIO()
154
  sf.write(buf, fusion_mastered, SR, format='WAV')
155
  buf.seek(0)
 
 
156
  return send_file(buf, mimetype="audio/wav", as_attachment=True, download_name="fusion.wav")
157
 
158
  except Exception as e:
 
159
  return jsonify({"error": str(e)}), 500
160
  finally:
161
  for f in temp_files:
162
  if os.path.exists(f): os.remove(f)
163
  for d in cleanup_dirs:
164
+ shutil.rmtree(d, ignore_errors=True)
165
 
166
  if __name__ == "__main__":
167
+ app.run(host='0.0.0.0', port=7860)