gere committed on
Commit
3f26513
·
verified ·
1 Parent(s): f3d75a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -75
app.py CHANGED
@@ -6,18 +6,24 @@ import numpy as np
6
  import librosa
7
  import soundfile as sf
8
  import pyloudnorm as pyln
 
9
  from flask import Flask, request, send_file, jsonify, make_response
10
  from flask_cors import CORS
11
  from scipy.signal import butter, lfilter
12
  from pedalboard import Pedalboard, Compressor, Limiter, HighpassFilter, LowpassFilter, Gain
13
  import subprocess
14
  from pydub import AudioSegment
 
15
 
16
  app = Flask(__name__)
17
  CORS(app)
18
 
19
- SR = 44100
20
- TARGET_LOUDNESS = -9.0
 
 
 
 
21
 
22
  def convert_to_wav(input_path):
23
  if input_path.lower().endswith(".mp3"):
@@ -28,37 +34,40 @@ def convert_to_wav(input_path):
28
 
29
  def load_mono(file_path):
30
  if not os.path.exists(file_path):
31
- return np.zeros(SR * 5)
32
- y, _ = librosa.load(file_path, sr=SR, mono=True)
 
 
33
  return y
34
 
35
  def normalize_audio(y):
36
  return y / (np.max(np.abs(y)) + 1e-9)
37
 
38
  def highpass(data, cutoff):
39
- b, a = butter(4, cutoff / (SR / 2), btype='high')
40
  return lfilter(b, a, data)
41
 
42
  def lowpass(data, cutoff):
43
- b, a = butter(4, cutoff / (SR / 2), btype='low')
44
  return lfilter(b, a, data)
45
 
46
  def detect_key(y):
47
- chroma = librosa.feature.chroma_cqt(y=y, sr=SR)
48
  return np.argmax(np.sum(chroma, axis=1))
49
 
50
  def match_key(source, target):
51
  key_s = detect_key(source)
52
  key_t = detect_key(target)
53
  shift = key_t - key_s
54
- return librosa.effects.pitch_shift(source, sr=SR, n_steps=float(shift))
 
55
 
56
  def beat_sync_warp(source, target):
57
- tempo_t, _ = librosa.beat.beat_track(y=target, sr=SR)
58
- tempo_s, _ = librosa.beat.beat_track(y=source, sr=SR)
59
  tempo_t = float(np.atleast_1d(tempo_t)[0])
60
  tempo_s = float(np.atleast_1d(tempo_s)[0])
61
- if tempo_t <= 0 or tempo_s <= 0:
62
  return librosa.util.fix_length(source, size=len(target))
63
  rate = tempo_s / tempo_t
64
  warped = librosa.effects.time_stretch(source, rate=float(rate))
@@ -66,111 +75,92 @@ def beat_sync_warp(source, target):
66
  return warped
67
 
68
  def separate_stems(input_file, job_id):
69
- print(f">>> [hf] starting demucs for {input_file}")
70
  out_dir = f"sep_{job_id}"
71
- subprocess.run(["demucs", "-n", "htdemucs", "--out", out_dir, input_file], check=True)
 
 
 
 
 
 
 
 
 
72
  base = os.path.splitext(os.path.basename(input_file))[0]
73
- stem_dir = os.path.join(out_dir, "htdemucs", base)
74
- stems = {
75
  "drums": os.path.join(stem_dir, "drums.wav"),
76
  "bass": os.path.join(stem_dir, "bass.wav"),
77
  "other": os.path.join(stem_dir, "other.wav"),
78
  "vocals": os.path.join(stem_dir, "vocals.wav")
79
- }
80
- return stems, out_dir
81
 
82
  @app.route('/', methods=['GET'])
83
  def health():
84
- response = make_response(jsonify({"status": "ready"}), 200)
85
- response.headers["Cache-Control"] = "no-cache, no-store, must-revalidate"
86
- response.headers["Pragma"] = "no-cache"
87
- response.headers["Expires"] = "0"
88
- return response
89
 
90
  @app.route('/fuse', methods=['POST'])
91
  def fuse_api():
92
  job_id = uuid.uuid4().hex[:8]
93
- temp_files = []
94
- cleanup_dirs = []
95
- print(f">>> [hf] new request: job_{job_id}")
96
  try:
97
- trad_req = request.files.get('melody')
98
- modern_req = request.files.get('style')
99
-
100
- if not trad_req or not modern_req:
101
- return jsonify({"error": "missing files"}), 400
102
 
103
- t_path = f"t_{job_id}.wav"
104
- m_path = f"m_{job_id}.wav"
105
- trad_req.save(t_path)
106
- modern_req.save(m_path)
107
  temp_files.extend([t_path, m_path])
108
 
109
- t_wav = convert_to_wav(t_path)
110
- m_wav = convert_to_wav(m_path)
111
- if t_wav != t_path: temp_files.append(t_wav)
112
- if m_wav != m_path: temp_files.append(m_wav)
113
-
114
- t_stems, t_dir = separate_stems(t_wav, f"t_{job_id}")
115
- m_stems, m_dir = separate_stems(m_wav, f"m_{job_id}")
 
 
 
 
116
  cleanup_dirs.extend([t_dir, m_dir])
117
 
118
- print(f">>> [hf] loading stems for job_{job_id}")
119
- t_other = load_mono(t_stems["other"])
120
- t_bass = load_mono(t_stems["bass"])
121
- m_drums = load_mono(m_stems["drums"])
122
- m_bass = load_mono(m_stems["bass"])
123
- m_other = load_mono(m_stems["other"])
124
 
125
  target_len = min(len(t_other), len(m_drums))
126
- t_other = t_other[:target_len]
127
- t_bass = t_bass[:target_len]
128
- m_drums = m_drums[:target_len]
129
- m_bass = m_bass[:target_len]
130
- m_other = m_other[:target_len]
131
 
132
- print(f">>> [hf] matching key and warp...")
133
  t_other = match_key(t_other, m_other)
134
  t_bass = match_key(t_bass, m_bass)
135
  t_other = beat_sync_warp(t_other, m_drums)
136
  t_bass = beat_sync_warp(t_bass, m_drums)
137
 
138
- t_other = highpass(t_other, 120)
139
- t_bass = highpass(t_bass, 60)
140
- m_bass = lowpass(m_bass, 250)
141
- m_drums = lowpass(m_drums, 12000)
142
- m_other = highpass(m_other, 150)
143
-
144
- fusion = (1.0 * m_drums + 1.0 * m_bass + 1.2 * t_other + 0.5 * m_other + 0.8 * t_bass)
145
- fusion = normalize_audio(fusion)
146
 
147
  board = Pedalboard([
148
- HighpassFilter(30),
149
- LowpassFilter(18000),
150
- Compressor(threshold_db=-20, ratio=2),
151
- Gain(2.0),
152
- Limiter(threshold_db=-0.5)
153
  ])
154
- fusion_mastered = board(fusion, SR)
155
 
156
- meter = pyln.Meter(SR)
157
  loudness = meter.integrated_loudness(fusion_mastered)
158
- fusion_mastered = pyln.normalize.loudness(fusion_mastered, loudness, TARGET_LOUDNESS)
159
 
160
  buf = io.BytesIO()
161
- sf.write(buf, fusion_mastered, SR, format='WAV')
162
  buf.seek(0)
163
-
164
- print(f">>> [hf] job_{job_id} complete!")
165
  return send_file(buf, mimetype="audio/wav", as_attachment=True, download_name="fusion.wav")
166
 
167
  except Exception as e:
168
- print(f">>> [hf] error in job_{job_id}: {str(e)}")
169
  return jsonify({"error": str(e)}), 500
170
  finally:
171
- for f in temp_files:
172
  if os.path.exists(f): os.remove(f)
173
- for d in cleanup_dirs:
174
  if os.path.exists(d): shutil.rmtree(d, ignore_errors=True)
175
 
176
  if __name__ == "__main__":
 
6
  import librosa
7
  import soundfile as sf
8
  import pyloudnorm as pyln
9
+ import torch
10
  from flask import Flask, request, send_file, jsonify, make_response
11
  from flask_cors import CORS
12
  from scipy.signal import butter, lfilter
13
  from pedalboard import Pedalboard, Compressor, Limiter, HighpassFilter, LowpassFilter, Gain
14
  import subprocess
15
  from pydub import AudioSegment
16
+ from concurrent.futures import ThreadPoolExecutor
17
 
18
  app = Flask(__name__)
19
  CORS(app)
20
 
21
+ # dynamic hardware detection
22
+ has_gpu = torch.cuda.is_available()
23
+ device_type = "cuda" if has_gpu else "cpu"
24
+ # use lower sample rate on cpu for speed; higher on gpu for quality
25
+ sr = 44100 if has_gpu else 22050
26
+ target_loudness = -9.0
27
 
28
  def convert_to_wav(input_path):
29
  if input_path.lower().endswith(".mp3"):
 
34
 
35
def load_mono(file_path):
    """Decode an audio file to a mono float array at the global sample rate.

    Returns 5 seconds of silence when the file does not exist, so the
    downstream mixing code never crashes on a missing stem.
    """
    if not os.path.exists(file_path):
        return np.zeros(sr * 5)
    # On CPU hosts cap decoding at 60 s to keep request latency bounded;
    # on GPU load the full file.
    max_seconds = None if has_gpu else 60
    audio, _ = librosa.load(file_path, sr=sr, mono=True, duration=max_seconds)
    return audio
42
 
43
def normalize_audio(y):
    """Peak-normalize `y` to roughly unit amplitude.

    The tiny epsilon keeps the division safe when `y` is all zeros.
    """
    peak = np.max(np.abs(y))
    return y / (peak + 1e-9)
45
 
46
def highpass(data, cutoff):
    """Apply a 4th-order Butterworth high-pass filter at `cutoff` Hz."""
    nyquist = sr / 2
    b, a = butter(4, cutoff / nyquist, btype='high')
    return lfilter(b, a, data)
49
 
50
def lowpass(data, cutoff):
    """Apply a 4th-order Butterworth low-pass filter at `cutoff` Hz."""
    nyquist = sr / 2
    b, a = butter(4, cutoff / nyquist, btype='low')
    return lfilter(b, a, data)
53
 
54
def detect_key(y):
    """Estimate the dominant pitch class (0-11) of `y`.

    Sums CQT chroma energy per pitch class over time and returns the
    index of the strongest one. This is a rough key proxy, not a full
    major/minor key detector.
    """
    chroma = librosa.feature.chroma_cqt(y=y, sr=sr, hop_length=1024)
    energy_per_pitch = chroma.sum(axis=1)
    return np.argmax(energy_per_pitch)
57
 
58
def match_key(source, target):
    """Pitch-shift `source` so its detected pitch class matches `target`'s.

    NOTE(review): the shift is the raw class difference (can be up to
    ±11 semitones); it is not wrapped to the nearest octave direction.
    """
    semitones = detect_key(target) - detect_key(source)
    if semitones == 0:
        # Already in the same key — skip the expensive pitch shift.
        return source
    return librosa.effects.pitch_shift(source, sr=sr, n_steps=float(semitones))
64
 
65
  def beat_sync_warp(source, target):
66
+ tempo_t, _ = librosa.beat.beat_track(y=target, sr=sr)
67
+ tempo_s, _ = librosa.beat.beat_track(y=source, sr=sr)
68
  tempo_t = float(np.atleast_1d(tempo_t)[0])
69
  tempo_s = float(np.atleast_1d(tempo_s)[0])
70
+ if tempo_t <= 0 or tempo_s <= 0 or abs(tempo_s - tempo_t) < 1:
71
  return librosa.util.fix_length(source, size=len(target))
72
  rate = tempo_s / tempo_t
73
  warped = librosa.effects.time_stretch(source, rate=float(rate))
 
75
  return warped
76
 
77
def separate_stems(input_file, job_id):
    """Run Demucs source separation on `input_file`.

    Returns a tuple ({stem_name: wav_path}, out_dir). The caller owns
    `out_dir` and is responsible for deleting it once the stems have
    been loaded.
    """
    out_dir = f"sep_{job_id}"
    cmd = [
        "demucs", "-n", "htdemucs_ft",
        "--out", out_dir,
        "--device", device_type,
        input_file,
    ]
    if not has_gpu:
        # Cap Demucs worker jobs on CPU hosts to avoid out-of-memory kills.
        cmd.extend(["-j", "2"])
    subprocess.run(cmd, check=True)

    base = os.path.splitext(os.path.basename(input_file))[0]
    stem_dir = os.path.join(out_dir, "htdemucs_ft", base)
    stems = {
        name: os.path.join(stem_dir, f"{name}.wav")
        for name in ("drums", "bass", "other", "vocals")
    }
    return stems, out_dir
 
97
 
98
@app.route('/', methods=['GET'])
def health():
    """Health probe: reports readiness and which device Demucs will use."""
    payload = {"status": "ready", "hardware": device_type}
    return jsonify(payload), 200
 
 
 
 
101
 
102
@app.route('/fuse', methods=['POST'])
def fuse_api():
    """Fuse a traditional 'melody' track with a modern 'style' track.

    Expects multipart form files 'melody' and 'style'. Separates both
    into stems with Demucs, key-matches and beat-warps the melody stems
    onto the style rhythm section, mixes, masters, and streams back a
    WAV. Returns 400 on missing files, 500 on any processing error.
    """
    job_id = uuid.uuid4().hex[:8]
    temp_files, cleanup_dirs = [], []
    try:
        t_req = request.files.get('melody')
        m_req = request.files.get('style')
        if not t_req or not m_req:
            return jsonify({"error": "missing files"}), 400

        t_path, m_path = f"t_{job_id}.wav", f"m_{job_id}.wav"
        t_req.save(t_path)
        m_req.save(m_path)
        temp_files.extend([t_path, m_path])

        # Register the Demucs output dirs *before* separation runs: the
        # names are deterministic (separate_stems uses f"sep_{job_id}"),
        # and registering them up front means the finally-block cleans up
        # even when the second separation fails after the first created
        # its directory (the original only registered dirs on full success,
        # leaking partial output on error).
        cleanup_dirs.extend([f"sep_t_{job_id}", f"sep_m_{job_id}"])

        # Separate in parallel on GPU; sequentially on CPU to limit
        # memory pressure.
        if has_gpu:
            with ThreadPoolExecutor(max_workers=2) as executor:
                f_t = executor.submit(separate_stems, t_path, f"t_{job_id}")
                f_m = executor.submit(separate_stems, m_path, f"m_{job_id}")
                t_stems, t_dir = f_t.result()
                m_stems, m_dir = f_m.result()
        else:
            t_stems, t_dir = separate_stems(t_path, f"t_{job_id}")
            m_stems, m_dir = separate_stems(m_path, f"m_{job_id}")

        # Load melody (t_*) harmonic/bass stems and style (m_*) rhythm stems.
        t_other, t_bass = load_mono(t_stems["other"]), load_mono(t_stems["bass"])
        m_drums, m_bass, m_other = (load_mono(m_stems["drums"]),
                                    load_mono(m_stems["bass"]),
                                    load_mono(m_stems["other"]))

        # Trim everything to the shorter of the two sources.
        target_len = min(len(t_other), len(m_drums))
        t_other, t_bass = t_other[:target_len], t_bass[:target_len]
        m_drums, m_bass, m_other = (m_drums[:target_len],
                                    m_bass[:target_len],
                                    m_other[:target_len])

        # Align the melody stems harmonically and rhythmically to the style.
        t_other = match_key(t_other, m_other)
        t_bass = match_key(t_bass, m_bass)
        t_other = beat_sync_warp(t_other, m_drums)
        t_bass = beat_sync_warp(t_bass, m_drums)

        # Weighted mix, then peak-normalize before mastering.
        fusion = normalize_audio(1.0 * m_drums + 1.0 * m_bass
                                 + 1.2 * t_other + 0.5 * m_other
                                 + 0.8 * t_bass)

        # Mastering chain: trim sub-bass rumble and extreme highs, glue
        # compression, makeup gain, brickwall limiter.
        board = Pedalboard([
            HighpassFilter(30), LowpassFilter(18000),
            Compressor(threshold_db=-20, ratio=2), Gain(2.0),
            Limiter(threshold_db=-0.5),
        ])
        fusion_mastered = board(fusion, sr)

        # Loudness-normalize to the configured integrated LUFS target.
        meter = pyln.Meter(sr)
        loudness = meter.integrated_loudness(fusion_mastered)
        fusion_mastered = pyln.normalize.loudness(fusion_mastered, loudness,
                                                  target_loudness)

        buf = io.BytesIO()
        sf.write(buf, fusion_mastered, sr, format='WAV')
        buf.seek(0)
        return send_file(buf, mimetype="audio/wav", as_attachment=True,
                         download_name="fusion.wav")

    except Exception as e:
        # NOTE(review): str(e) may leak internal paths to the client;
        # consider logging the detail and returning a generic message.
        return jsonify({"error": str(e)}), 500
    finally:
        for f in temp_files:
            if os.path.exists(f):
                os.remove(f)
        for d in cleanup_dirs:
            if os.path.exists(d):
                shutil.rmtree(d, ignore_errors=True)
165
 
166
  if __name__ == "__main__":