gere committed on
Commit
77eeeaa
·
verified ·
1 Parent(s): e845cba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -146
app.py CHANGED
@@ -1,31 +1,25 @@
1
  import os
2
- os.environ["OMP_NUM_THREADS"] = "2"
3
-
4
  import io
5
  import uuid
6
  import shutil
 
7
  import numpy as np
8
  import librosa
9
  import soundfile as sf
10
  import pyloudnorm as pyln
11
- from flask import Flask, request, send_file, jsonify
12
  from flask_cors import CORS
13
  from scipy.signal import butter, lfilter
14
  from pedalboard import Pedalboard, Compressor, Limiter, HighpassFilter, LowpassFilter, Gain
15
- import subprocess
16
  from pydub import AudioSegment
 
17
  import torch
18
- import threading
19
 
20
  app = Flask(__name__)
21
  CORS(app)
22
 
23
- SR = 22050
24
  TARGET_LOUDNESS = -9.0
25
- MAX_DURATION = 25
26
-
27
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
28
-
29
 
30
  def convert_to_wav(input_path):
31
  if input_path.lower().endswith(".mp3"):
@@ -34,174 +28,159 @@ def convert_to_wav(input_path):
34
  return wav_path
35
  return input_path
36
 
37
-
38
  def load_mono(file_path):
39
  if not os.path.exists(file_path):
40
  return np.zeros(SR * 5)
41
- y, _ = librosa.load(file_path, sr=SR, mono=True, duration=MAX_DURATION)
42
  return y
43
 
44
-
45
  def normalize_audio(y):
46
  return y / (np.max(np.abs(y)) + 1e-9)
47
 
48
-
49
  def highpass(data, cutoff):
50
  b, a = butter(4, cutoff / (SR / 2), btype='high')
51
  return lfilter(b, a, data)
52
 
53
-
54
  def lowpass(data, cutoff):
55
  b, a = butter(4, cutoff / (SR / 2), btype='low')
56
  return lfilter(b, a, data)
57
 
58
-
59
- def fast_align(source, target_len):
60
- return librosa.util.fix_length(source, size=target_len)
61
-
62
-
63
- def separate_stems(input_file, job_id):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  out_dir = f"sep_{job_id}"
65
-
66
- cmd = [
67
- "demucs",
68
- "-n", "mdx_q",
69
- "--two-stems=drums",
70
- "--out", out_dir,
71
- input_file
72
- ]
73
-
74
- if DEVICE == "cuda":
75
- cmd.insert(4, "--device")
76
- cmd.insert(5, "cuda")
77
-
78
- subprocess.run(cmd, check=True)
79
-
80
  base = os.path.splitext(os.path.basename(input_file))[0]
81
- stem_dir = os.path.join(out_dir, "mdx_q", base)
82
-
83
  stems = {
84
  "drums": os.path.join(stem_dir, "drums.wav"),
85
- "other": os.path.join(stem_dir, "no_drums.wav")
 
 
86
  }
87
-
88
  return stems, out_dir
89
 
90
-
91
  @app.route('/', methods=['GET'])
92
  def health():
93
- return jsonify({"status": "ready", "device": DEVICE})
94
-
 
 
 
95
 
96
  @app.route('/fuse', methods=['POST'])
97
  def fuse_api():
98
  job_id = uuid.uuid4().hex[:8]
99
-
100
  temp_files = []
101
  cleanup_dirs = []
102
 
103
- try:
104
- trad_req = request.files.get('melody')
105
- modern_req = request.files.get('style')
106
-
107
- if not trad_req or not modern_req:
108
- return jsonify({"error": "missing files"}), 400
109
-
110
- t_path = f"t_{job_id}.wav"
111
- m_path = f"m_{job_id}.wav"
112
-
113
- trad_req.save(t_path)
114
- modern_req.save(m_path)
115
-
116
- temp_files.extend([t_path, m_path])
117
-
118
- t_wav = convert_to_wav(t_path)
119
- m_wav = convert_to_wav(m_path)
120
-
121
- if t_wav != t_path:
122
- temp_files.append(t_wav)
123
- if m_wav != m_path:
124
- temp_files.append(m_wav)
125
-
126
- t_res, m_res = {}, {}
127
-
128
- def run_t():
129
- t_res["data"] = separate_stems(t_wav, f"t_{job_id}")
130
-
131
- def run_m():
132
- m_res["data"] = separate_stems(m_wav, f"m_{job_id}")
133
-
134
- th1 = threading.Thread(target=run_t)
135
- th2 = threading.Thread(target=run_m)
136
-
137
- th1.start()
138
- th2.start()
139
- th1.join()
140
- th2.join()
141
-
142
- t_stems, t_dir = t_res["data"]
143
- m_stems, m_dir = m_res["data"]
144
-
145
- cleanup_dirs.extend([t_dir, m_dir])
146
-
147
- t_other = load_mono(t_stems["other"])
148
- m_drums = load_mono(m_stems["drums"])
149
- m_other = load_mono(m_stems["other"])
150
-
151
- target_len = min(len(t_other), len(m_drums))
152
-
153
- t_other = fast_align(t_other, target_len)
154
- m_drums = fast_align(m_drums, target_len)
155
- m_other = fast_align(m_other, target_len)
156
-
157
- t_other = highpass(t_other, 120)
158
- m_drums = lowpass(m_drums, 12000)
159
- m_other = highpass(m_other, 150)
160
-
161
- fusion = (1.2 * t_other + 1.0 * m_drums + 0.6 * m_other)
162
- fusion = normalize_audio(fusion)
163
-
164
- board = Pedalboard([
165
- HighpassFilter(30),
166
- LowpassFilter(18000),
167
- Compressor(threshold_db=-20, ratio=2),
168
- Gain(2.0),
169
- Limiter(threshold_db=-0.5)
170
- ])
171
-
172
- fusion_mastered = board(fusion, SR)
173
-
174
- meter = pyln.Meter(SR)
175
- loudness = meter.integrated_loudness(fusion_mastered)
176
-
177
- fusion_mastered = pyln.normalize.loudness(
178
- fusion_mastered, loudness, TARGET_LOUDNESS
179
- )
180
-
181
- fusion_mastered = np.clip(fusion_mastered, -1.0, 1.0)
182
-
183
- buf = io.BytesIO()
184
- sf.write(buf, fusion_mastered, SR, format='WAV')
185
- buf.seek(0)
186
-
187
- return send_file(
188
- buf,
189
- mimetype="audio/wav",
190
- as_attachment=True,
191
- download_name="fusion.wav"
192
- )
193
-
194
- except Exception as e:
195
- return jsonify({"error": str(e)}), 500
196
-
197
- finally:
198
- for f in temp_files:
199
- if os.path.exists(f):
200
- os.remove(f)
201
- for d in cleanup_dirs:
202
- if os.path.exists(d):
203
- shutil.rmtree(d, ignore_errors=True)
204
-
205
 
206
  if __name__ == "__main__":
207
  app.run(host='0.0.0.0', port=7860)
 
1
  import os
 
 
2
  import io
3
  import uuid
4
  import shutil
5
+ import threading
6
  import numpy as np
7
  import librosa
8
  import soundfile as sf
9
  import pyloudnorm as pyln
10
+ from flask import Flask, request, send_file, jsonify, make_response
11
  from flask_cors import CORS
12
  from scipy.signal import butter, lfilter
13
  from pedalboard import Pedalboard, Compressor, Limiter, HighpassFilter, LowpassFilter, Gain
 
14
  from pydub import AudioSegment
15
+ import subprocess
16
  import torch
 
17
 
18
app = Flask(__name__)
CORS(app)  # allow cross-origin calls so browser front-ends can hit the API

# Working sample rate for all audio processing (Hz); 44.1 kHz is CD quality.
SR = 44100
# Integrated loudness target (LUFS) for the final master; -9 LUFS is a loud,
# streaming-style level.
TARGET_LOUDNESS = -9.0
 
 
 
 
23
 
24
  def convert_to_wav(input_path):
25
  if input_path.lower().endswith(".mp3"):
 
28
  return wav_path
29
  return input_path
30
 
 
31
def load_mono(file_path):
    """Load *file_path* as a mono signal resampled to the global SR.

    Returns five seconds of silence when the file is missing, so callers
    never have to special-case an absent stem.
    """
    if not os.path.exists(file_path):
        return np.zeros(SR * 5)
    audio, _sr = librosa.load(file_path, sr=SR, mono=True)
    return audio
36
 
 
37
def normalize_audio(y):
    """Peak-normalize *y* to roughly [-1, 1].

    The tiny epsilon in the denominator guards against division by zero
    when the signal is all zeros (silence).
    """
    peak = np.max(np.abs(y)) + 1e-9
    return y / peak
39
 
 
40
def highpass(data, cutoff):
    """Apply a 4th-order Butterworth high-pass at *cutoff* Hz (global SR)."""
    normalized_cutoff = cutoff / (SR / 2)  # fraction of the Nyquist frequency
    coeff_b, coeff_a = butter(4, normalized_cutoff, btype='high')
    return lfilter(coeff_b, coeff_a, data)
43
 
 
44
def lowpass(data, cutoff):
    """Apply a 4th-order Butterworth low-pass at *cutoff* Hz (global SR)."""
    normalized_cutoff = cutoff / (SR / 2)  # fraction of the Nyquist frequency
    coeff_b, coeff_a = butter(4, normalized_cutoff, btype='low')
    return lfilter(coeff_b, coeff_a, data)
47
 
48
def detect_key(y):
    """Estimate the dominant pitch class (0-11) of *y*.

    Sums chroma (CQT) energy per pitch class over time and returns the
    index of the strongest class.
    """
    chroma = librosa.feature.chroma_cqt(y=y, sr=SR)
    energy_per_class = np.sum(chroma, axis=1)
    return np.argmax(energy_per_class)
51
+
52
def match_key(source, target):
    """Pitch-shift *source* so its dominant pitch class matches *target*'s.

    The shift is the semitone difference between the two detected pitch
    classes (may be negative).
    """
    source_key = detect_key(source)
    target_key = detect_key(target)
    semitone_shift = target_key - source_key
    return librosa.effects.pitch_shift(source, sr=SR, n_steps=float(semitone_shift))
57
+
58
def beat_sync_warp(source, target):
    """Time-stretch *source* to *target*'s tempo, then pad/trim to its length.

    Falls back to a plain length fix when either tempo estimate is
    non-positive (e.g. silence), where a stretch ratio would be undefined.
    """
    tempo_target, _ = librosa.beat.beat_track(y=target, sr=SR)
    tempo_source, _ = librosa.beat.beat_track(y=source, sr=SR)
    # beat_track may return a scalar or a 1-element array depending on the
    # librosa version; normalize both to plain floats.
    tempo_target = float(np.atleast_1d(tempo_target)[0])
    tempo_source = float(np.atleast_1d(tempo_source)[0])
    if tempo_target <= 0 or tempo_source <= 0:
        return librosa.util.fix_length(source, size=len(target))
    stretch_rate = tempo_source / tempo_target
    stretched = librosa.effects.time_stretch(source, rate=float(stretch_rate))
    return librosa.util.fix_length(stretched, size=len(target))
69
+
70
def separate_stems(input_file, job_id, model_name="mdx_q"):
    """Run Demucs source separation on *input_file*.

    Parameters
    ----------
    input_file : str
        Path to the audio file to separate.
    job_id : str
        Unique id used to namespace the output directory.
    model_name : str
        Demucs model to invoke (default "mdx_q", a quantized fast model).

    Returns
    -------
    (dict, str)
        Mapping of stem name -> expected wav path, and the output
        directory the caller must clean up afterwards. Stem paths are not
        verified here; downstream loading substitutes silence for missing
        files.

    Raises
    ------
    RuntimeError
        If the demucs CLI is not installed or exits with a non-zero status.
    """
    out_dir = f"sep_{job_id}"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    cmd = ["demucs", "-n", model_name, "--device", device, "--out", out_dir, input_file]
    try:
        subprocess.run(cmd, check=True)
    except FileNotFoundError as e:
        # demucs binary missing from PATH — surface the documented error type
        raise RuntimeError("Demucs is not installed or not on PATH") from e
    except subprocess.CalledProcessError as e:
        # chain the cause so the original exit status stays in the traceback
        raise RuntimeError(f"Demucs failed: {e}") from e

    # Demucs writes stems to <out_dir>/<model_name>/<input basename>/
    base = os.path.splitext(os.path.basename(input_file))[0]
    stem_dir = os.path.join(out_dir, model_name, base)
    stems = {
        "drums": os.path.join(stem_dir, "drums.wav"),
        "bass": os.path.join(stem_dir, "bass.wav"),
        "other": os.path.join(stem_dir, "other.wav"),
        "vocals": os.path.join(stem_dir, "vocals.wav")
    }
    return stems, out_dir
87
 
 
88
@app.route('/', methods=['GET'])
def health():
    """Liveness probe; returns an uncacheable 200 readiness payload."""
    resp = make_response(jsonify({"status": "ready"}), 200)
    # Forbid caching at every layer so monitors always hit the live app.
    no_cache_headers = {
        "Cache-Control": "no-cache, no-store, must-revalidate",
        "Pragma": "no-cache",
        "Expires": "0",
    }
    for header_name, header_value in no_cache_headers.items():
        resp.headers[header_name] = header_value
    return resp
95
 
96
@app.route('/fuse', methods=['POST'])
def fuse_api():
    """Fuse an uploaded traditional 'melody' track with a modern 'style' track.

    Expects multipart form files 'melody' and 'style'. Both are separated
    into stems with Demucs (in parallel threads); the melody stems are
    key-matched and beat-aligned to the style track, the stems are EQ'd,
    mixed, mastered and loudness-normalized, and the result is streamed
    back as a WAV attachment.

    Returns 400 when either upload is missing and a JSON 500 on any
    processing failure. Temp files and separation directories are always
    removed, even on failure.
    """
    job_id = uuid.uuid4().hex[:8]
    temp_files = []
    cleanup_dirs = []

    try:
        trad_req = request.files.get('melody')
        modern_req = request.files.get('style')
        if not trad_req or not modern_req:
            return jsonify({"error": "missing files"}), 400

        t_path = f"t_{job_id}.wav"
        m_path = f"m_{job_id}.wav"
        trad_req.save(t_path)
        modern_req.save(m_path)
        temp_files.extend([t_path, m_path])

        t_wav = convert_to_wav(t_path)
        m_wav = convert_to_wav(m_path)
        if t_wav != t_path:
            temp_files.append(t_wav)
        if m_wav != m_path:
            temp_files.append(m_wav)

        # Run the two (slow) separations concurrently. Exceptions raised in
        # a Thread are otherwise swallowed by threading, so record them and
        # re-raise in the request thread instead of failing later with a
        # misleading KeyError.
        t_res, m_res = {}, {}

        def run_t():
            try:
                t_res["data"], t_dir = separate_stems(t_wav, f"t_{job_id}")
                cleanup_dirs.append(t_dir)
            except Exception as e:
                t_res["error"] = e

        def run_m():
            try:
                m_res["data"], m_dir = separate_stems(m_wav, f"m_{job_id}")
                cleanup_dirs.append(m_dir)
            except Exception as e:
                m_res["error"] = e

        t_thread = threading.Thread(target=run_t)
        m_thread = threading.Thread(target=run_m)
        t_thread.start()
        m_thread.start()
        t_thread.join()
        m_thread.join()

        for res in (t_res, m_res):
            if "error" in res:
                raise res["error"]

        t_stems = t_res["data"]
        m_stems = m_res["data"]

        # Load the stems used in the mix (missing stems come back as silence).
        t_other = load_mono(t_stems["other"])
        t_bass = load_mono(t_stems["bass"])
        m_drums = load_mono(m_stems["drums"])
        m_bass = load_mono(m_stems["bass"])
        m_other = load_mono(m_stems["other"])

        # Trim everything to the shorter of the two sources.
        target_len = min(len(t_other), len(m_drums))
        t_other, t_bass = t_other[:target_len], t_bass[:target_len]
        m_drums, m_bass, m_other = m_drums[:target_len], m_bass[:target_len], m_other[:target_len]

        # Musically align the traditional stems to the modern track.
        t_other, t_bass = match_key(t_other, m_other), match_key(t_bass, m_bass)
        t_other, t_bass = beat_sync_warp(t_other, m_drums), beat_sync_warp(t_bass, m_drums)

        # Carve out frequency space for each stem before summing.
        t_other = highpass(t_other, 120)
        t_bass = highpass(t_bass, 60)
        m_bass = lowpass(m_bass, 250)
        m_drums = lowpass(m_drums, 12000)
        m_other = highpass(m_other, 150)

        fusion = 1.0 * m_drums + 1.0 * m_bass + 1.2 * t_other + 0.5 * m_other + 0.8 * t_bass
        fusion = normalize_audio(fusion)

        # Mastering chain: clean the spectral extremes, glue with light
        # compression, push level, then stop overs with a limiter.
        board = Pedalboard([
            HighpassFilter(30),
            LowpassFilter(18000),
            Compressor(threshold_db=-20, ratio=2),
            Gain(2.0),
            Limiter(threshold_db=-0.5)
        ])
        fusion_mastered = board(fusion, SR)

        # Match integrated loudness to the target level.
        meter = pyln.Meter(SR)
        loudness = meter.integrated_loudness(fusion_mastered)
        fusion_mastered = pyln.normalize.loudness(fusion_mastered, loudness, TARGET_LOUDNESS)

        # Loudness normalization can push samples past full scale; clip so
        # the WAV encoder does not wrap or distort.
        fusion_mastered = np.clip(fusion_mastered, -1.0, 1.0)

        buf = io.BytesIO()
        sf.write(buf, fusion_mastered, SR, format='WAV')
        buf.seek(0)

        return send_file(buf, mimetype="audio/wav", as_attachment=True, download_name="fusion.wav")

    except Exception as e:
        # Surface the failure as JSON instead of an unhandled 500 page.
        return jsonify({"error": str(e)}), 500

    finally:
        # Always remove uploads, converted files and separation output.
        for f in temp_files:
            if os.path.exists(f):
                os.remove(f)
        for d in cleanup_dirs:
            if os.path.exists(d):
                shutil.rmtree(d, ignore_errors=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
 
185
if __name__ == "__main__":
    # Bind on all interfaces; port 7860 is the Hugging Face Spaces convention.
    app.run(host='0.0.0.0', port=7860)