gere committed on
Commit
8eb21dd
·
verified ·
1 Parent(s): 5f15633

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +140 -117
app.py CHANGED
@@ -1,154 +1,177 @@
1
  import os
2
-
3
- # Fix for the libgomp error in your logs
4
- os.environ["OMP_NUM_THREADS"] = "1"
5
- os.environ["MKL_NUM_THREADS"] = "1"
6
-
7
  import io
8
- import gc
9
  import uuid
 
10
  import numpy as np
11
  import librosa
12
  import soundfile as sf
13
  import pyloudnorm as pyln
14
- from flask import Flask, request, send_file, jsonify
15
  from flask_cors import CORS
16
- from pedalboard import Pedalboard, Compressor, Limiter, Gain
17
- import torch
18
- import torchaudio
19
- from demucs.pretrained import get_model
20
- from demucs.apply import apply_model
21
 
22
app = Flask(__name__)
CORS(app)  # allow cross-origin requests from the browser frontend

SR = 44100              # working sample rate (Hz) for all processing
TARGET_LOUDNESS = -9.0  # loudness-normalization target in LUFS
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# PRE-LOAD MODEL (Warm Start): load htdemucs once at import time so the
# first request doesn't pay the model-load cost. MODEL stays None on
# failure; the routes check for that and answer 503.
MODEL = None
try:
    print(f">>> [hf] booting demucs on {DEVICE}...")
    MODEL = get_model('htdemucs')
    MODEL.to(DEVICE)
    MODEL.eval()  # inference mode; no dropout/batch-norm updates
    print(">>> [hf] engine ready")
except Exception as e:
    print(f"!!! CRITICAL: model load failed: {e}")
39
-
40
def clear_vram():
    """Release cached CUDA memory (when a GPU is present), then collect garbage."""
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        # Return cached allocations and shared IPC handles to the driver.
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    gc.collect()
45
-
46
def separate_to_memory(input_file):
    """Separate `input_file` into Demucs stems entirely in memory.

    Loads audio straight from the Flask file-storage object, resamples to
    SR when needed, standardizes it for the model, runs htdemucs, then
    undoes the standardization. Returns {stem_name: mono numpy array}.
    GPU memory is cleared before and after the run.
    """
    clear_vram()
    # Load directly from the Flask file storage object
    wav, sr = torchaudio.load(input_file)
    if sr != SR:
        wav = torchaudio.transforms.Resample(sr, SR)(wav)

    wav = wav.to(DEVICE)
    # Standardize audio for model input (zero mean / unit std of the channel mean)
    ref = wav.mean(0)
    wav = (wav - ref.mean()) / (ref.std() + 1e-8)

    with torch.no_grad():
        # Aggressive speed settings for 16GB T4
        # shifts=0 and overlap=0 removes the extra processing passes
        sources = apply_model(MODEL, wav[None], device=DEVICE, shifts=0, overlap=0)[0]

    # Undo the input standardization on the separated sources
    sources = sources * ref.std() + ref.mean()

    # Map sources to dictionary of mono numpy arrays (in-memory only);
    # .mean(0) collapses channels, .cpu() moves off the GPU.
    stems = {name: source.mean(0).cpu().numpy() for name, source in zip(MODEL.sources, sources)}

    del wav, sources
    clear_vram()
    return stems
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
@app.route('/', methods=['GET'])
@app.route('/health', methods=['GET'])
def health():
    """Readiness probe: 503 while the model is still loading, 200 when ready."""
    if MODEL is None:
        return jsonify({"status": "starting", "device": DEVICE}), 503
    return jsonify({"status": "ready", "device": DEVICE}), 200
 
 
78
 
79
@app.route('/fuse', methods=['POST'])
def fuse_api():
    """Fuse the 'melody' and 'style' uploads into one mastered WAV.

    Pipeline: in-memory Demucs separation of both uploads -> key-match and
    tempo-warp the melody stems against the style stems -> weighted mix ->
    peak normalize -> compressor/gain/limiter mastering -> loudness
    normalization to TARGET_LOUDNESS -> stream back as fusion.wav.
    Returns 503 if the model never loaded, 400 on missing files, 500 with
    the error text on any processing failure. VRAM is cleared in `finally`.
    """
    if MODEL is None:
        return jsonify({"error": "engine not initialized"}), 503

    try:
        t_file = request.files.get('melody')
        m_file = request.files.get('style')

        if not t_file or not m_file:
            return jsonify({"error": "files missing"}), 400

        # Memory-based separation (Zero Disk Write)
        t_stems = separate_to_memory(t_file)
        m_stems = separate_to_memory(m_file)

        # Get stems
        t_other, t_bass = t_stems["other"], t_stems["bass"]
        m_drums, m_bass, m_other = m_stems["drums"], m_stems["bass"], m_stems["other"]

        # Everything is trimmed to the shortest of the two key stems.
        target_len = min(len(t_other), len(m_drums))

        # Fast Sync & Key matching
        def sync_audio(src, trg_key_ref, sync_ref):
            src_slice = src[:target_len]
            # Key Detection: dominant pitch class from summed chroma energy
            s_k = np.argmax(np.sum(librosa.feature.chroma_cqt(y=src_slice, sr=SR), axis=1))
            t_k = np.argmax(np.sum(librosa.feature.chroma_cqt(y=trg_key_ref[:target_len], sr=SR), axis=1))
            src_slice = librosa.effects.pitch_shift(src_slice, sr=SR, n_steps=float(t_k - s_k))

            # Beat Detection
            t_s, _ = librosa.beat.beat_track(y=src_slice, sr=SR)
            t_t, _ = librosa.beat.beat_track(y=sync_ref[:target_len], sr=SR)
            # NOTE(review): rate = source_tempo / target_tempo looks inverted —
            # time_stretch(rate > 1) speeds the signal up, so matching the
            # target tempo should use target/source. Confirm before relying on it.
            rate = float(np.atleast_1d(t_s)[0]) / float(np.atleast_1d(t_t)[0])

            return librosa.util.fix_length(librosa.effects.time_stretch(src_slice, rate=rate), size=target_len)

        # Build final mix
        t_other_sync = sync_audio(t_other, m_other, m_drums)
        t_bass_sync = sync_audio(t_bass, m_bass, m_drums)

        # Weighted sum: style rhythm section plus boosted melody stems.
        fusion = (1.0 * m_drums[:target_len] +
                  1.0 * m_bass[:target_len] +
                  1.2 * t_other_sync +
                  0.8 * t_bass_sync)

        # Prevent clipping (peak normalize; epsilon guards all-zero input)
        fusion /= (np.max(np.abs(fusion)) + 1e-9)

        # Mastering
        mastered = Pedalboard([
            Compressor(threshold_db=-18, ratio=2.5),
            Gain(1.5),
            Limiter(threshold_db=-0.1)
        ])(fusion, SR)

        # Loudness Normalization to TARGET_LOUDNESS (LUFS)
        meter = pyln.Meter(SR)
        loudness = meter.integrated_loudness(mastered)
        mastered = pyln.normalize.loudness(mastered, loudness, TARGET_LOUDNESS)

        # Send back without saving to disk
        buf = io.BytesIO()
        sf.write(buf, mastered, SR, format='WAV')
        buf.seek(0)

        return send_file(buf, mimetype="audio/wav", as_attachment=True, download_name="fusion.wav")

    except Exception as e:
        print(f"CRITICAL ERROR: {str(e)}")
        return jsonify({"error": str(e)}), 500
    finally:
        clear_vram()
 
 
 
152
 
153
if __name__ == "__main__":
    # threaded=False serializes requests so only one job holds the GPU at a time.
    app.run(host='0.0.0.0', port=7860, threaded=False)
 
1
  import os
 
 
 
 
 
2
  import io
 
3
  import uuid
4
+ import shutil
5
  import numpy as np
6
  import librosa
7
  import soundfile as sf
8
  import pyloudnorm as pyln
9
+ from flask import Flask, request, send_file, jsonify, make_response
10
  from flask_cors import CORS
11
+ from scipy.signal import butter, lfilter
12
+ from pedalboard import Pedalboard, Compressor, Limiter, HighpassFilter, LowpassFilter, Gain
13
+ import subprocess
14
+ from pydub import AudioSegment
 
15
 
16
app = Flask(__name__)
CORS(app)  # allow cross-origin requests from the browser frontend

SR = 44100              # working sample rate (Hz) for all processing
TARGET_LOUDNESS = -9.0  # loudness-normalization target in LUFS
21
+
22
def convert_to_wav(input_path):
    """Transcode an ``.mp3`` file to a uniquely named ``.wav``.

    Any other path is returned unchanged; the random suffix avoids
    collisions between concurrent jobs.
    """
    if not input_path.lower().endswith(".mp3"):
        return input_path
    stem = input_path.rsplit(".", 1)[0]
    wav_path = f"{stem}_{uuid.uuid4().hex}.wav"
    AudioSegment.from_mp3(input_path).export(wav_path, format="wav")
    return wav_path
28
+
29
def load_mono(file_path):
    """Load an audio file as mono at SR; return 5 s of silence if it is missing."""
    if os.path.exists(file_path):
        audio, _sr = librosa.load(file_path, sr=SR, mono=True)
        return audio
    # Missing stem (e.g. separation produced nothing): degrade gracefully.
    return np.zeros(SR * 5)
34
+
35
def normalize_audio(y):
    """Peak-normalize into [-1, 1]; the epsilon guards against all-zero input."""
    peak = np.max(np.abs(y)) + 1e-9
    return y / peak
37
+
38
def highpass(data, cutoff, sr=None):
    """Apply a 4th-order Butterworth high-pass filter.

    Args:
        data: 1-D array of audio samples.
        cutoff: cutoff frequency in Hz.
        sr: sample rate in Hz; defaults to the module-wide SR
            (generalized from the previously hard-coded global).

    Returns:
        Filtered signal, same length as ``data``.
    """
    if sr is None:
        sr = SR
    # Normalized cutoff relative to the Nyquist frequency (sr / 2).
    b, a = butter(4, cutoff / (sr / 2), btype='high')
    return lfilter(b, a, data)
41
+
42
def lowpass(data, cutoff, sr=None):
    """Apply a 4th-order Butterworth low-pass filter.

    Args:
        data: 1-D array of audio samples.
        cutoff: cutoff frequency in Hz.
        sr: sample rate in Hz; defaults to the module-wide SR
            (generalized from the previously hard-coded global).

    Returns:
        Filtered signal, same length as ``data``.
    """
    if sr is None:
        sr = SR
    # Normalized cutoff relative to the Nyquist frequency (sr / 2).
    b, a = butter(4, cutoff / (sr / 2), btype='low')
    return lfilter(b, a, data)
45
+
46
def detect_key(y):
    """Estimate the dominant pitch class (0-11) via summed chroma energy."""
    energy_per_pitch = np.sum(librosa.feature.chroma_cqt(y=y, sr=SR), axis=1)
    return np.argmax(energy_per_pitch)
49
+
50
def match_key(source, target):
    """Pitch-shift ``source`` so its detected key matches ``target``'s key."""
    semitones = detect_key(target) - detect_key(source)
    return librosa.effects.pitch_shift(source, sr=SR, n_steps=float(semitones))
55
+
56
def beat_sync_warp(source, target):
    """Time-stretch ``source`` so its estimated tempo matches ``target``'s.

    Falls back to simple length-fixing when either tempo estimate is
    non-positive. The output is always trimmed/padded to ``len(target)``.
    """
    tempo_t, _ = librosa.beat.beat_track(y=target, sr=SR)
    tempo_s, _ = librosa.beat.beat_track(y=source, sr=SR)
    # beat_track may return a scalar or a 1-element array depending on version.
    tempo_t = float(np.atleast_1d(tempo_t)[0])
    tempo_s = float(np.atleast_1d(tempo_s)[0])
    if tempo_t <= 0 or tempo_s <= 0:
        return librosa.util.fix_length(source, size=len(target))
    # BUG FIX: librosa.effects.time_stretch speeds the signal up by `rate`
    # (new tempo = old tempo * rate), so matching the target tempo needs
    # rate = tempo_t / tempo_s. The previous tempo_s / tempo_t inverted the
    # correction: a slow source got slower and a fast one got faster.
    rate = tempo_t / tempo_s
    warped = librosa.effects.time_stretch(source, rate=float(rate))
    warped = librosa.util.fix_length(warped, size=len(target))
    return warped
67
+
68
def separate_stems(input_file, job_id):
    """Run the demucs CLI on ``input_file``.

    Returns a tuple of ({stem_name: wav_path}, output_dir) where output_dir
    is the per-job directory the caller is responsible for cleaning up.
    """
    print(f">>> [hf] starting demucs for {input_file}")
    out_dir = f"sep_{job_id}"
    # List-form argv (no shell); check=True raises on a non-zero exit code.
    subprocess.run(["demucs", "-n", "htdemucs", "--out", out_dir, input_file], check=True)
    base = os.path.splitext(os.path.basename(input_file))[0]
    stem_dir = os.path.join(out_dir, "htdemucs", base)
    stem_names = ("drums", "bass", "other", "vocals")
    stems = {name: os.path.join(stem_dir, f"{name}.wav") for name in stem_names}
    return stems, out_dir
81
 
82
@app.route('/', methods=['GET'])
def health():
    """Health probe; no-store headers keep proxies from caching the status."""
    response = make_response(jsonify({"status": "ready"}), 200)
    no_cache_headers = {
        "Cache-Control": "no-cache, no-store, must-revalidate",
        "Pragma": "no-cache",
        "Expires": "0",
    }
    for header_name, header_value in no_cache_headers.items():
        response.headers[header_name] = header_value
    return response
89
 
90
@app.route('/fuse', methods=['POST'])
def fuse_api():
    """Fuse the 'melody' and 'style' uploads into one mastered WAV.

    Pipeline: save uploads -> mp3-to-wav conversion -> demucs stem
    separation -> key-match and tempo-warp the melody stems -> EQ carve ->
    weighted mix -> mastering chain -> loudness normalization to
    TARGET_LOUDNESS -> stream the result back as fusion.wav.
    Returns 400 on missing files and 500 with the error text on any
    processing failure. All temp files and directories are removed in
    ``finally``.
    """
    job_id = uuid.uuid4().hex[:8]
    temp_files = []
    cleanup_dirs = []
    print(f">>> [hf] new request: job_{job_id}")
    try:
        trad_req = request.files.get('melody')
        modern_req = request.files.get('style')

        if not trad_req or not modern_req:
            return jsonify({"error": "missing files"}), 400

        # BUG FIX: uploads were previously always saved with a .wav
        # extension, so convert_to_wav() (which keys off ".mp3") never
        # converted mp3 uploads. Preserve the uploaded filename's extension
        # so the conversion step actually fires for mp3 files.
        t_ext = os.path.splitext(trad_req.filename or "")[1].lower() or ".wav"
        m_ext = os.path.splitext(modern_req.filename or "")[1].lower() or ".wav"
        t_path = f"t_{job_id}{t_ext}"
        m_path = f"m_{job_id}{m_ext}"
        trad_req.save(t_path)
        modern_req.save(m_path)
        temp_files.extend([t_path, m_path])

        t_wav = convert_to_wav(t_path)
        m_wav = convert_to_wav(m_path)
        # Converted copies are extra temp files that also need cleanup.
        if t_wav != t_path: temp_files.append(t_wav)
        if m_wav != m_path: temp_files.append(m_wav)

        t_stems, t_dir = separate_stems(t_wav, f"t_{job_id}")
        m_stems, m_dir = separate_stems(m_wav, f"m_{job_id}")
        cleanup_dirs.extend([t_dir, m_dir])

        print(f">>> [hf] loading stems for job_{job_id}")
        t_other = load_mono(t_stems["other"])
        t_bass = load_mono(t_stems["bass"])
        m_drums = load_mono(m_stems["drums"])
        m_bass = load_mono(m_stems["bass"])
        m_other = load_mono(m_stems["other"])

        # Trim every stem to the shortest length so the arrays line up.
        target_len = min(len(t_other), len(m_drums))
        t_other = t_other[:target_len]
        t_bass = t_bass[:target_len]
        m_drums = m_drums[:target_len]
        m_bass = m_bass[:target_len]
        m_other = m_other[:target_len]

        print(f">>> [hf] matching key and warp...")
        t_other = match_key(t_other, m_other)
        t_bass = match_key(t_bass, m_bass)
        t_other = beat_sync_warp(t_other, m_drums)
        t_bass = beat_sync_warp(t_bass, m_drums)

        # Spectral carving so melody and style stems occupy different bands.
        t_other = highpass(t_other, 120)
        t_bass = highpass(t_bass, 60)
        m_bass = lowpass(m_bass, 250)
        m_drums = lowpass(m_drums, 12000)
        m_other = highpass(m_other, 150)

        # Weighted sum: style rhythm section plus boosted melody stems.
        fusion = (1.0 * m_drums + 1.0 * m_bass + 1.2 * t_other + 0.5 * m_other + 0.8 * t_bass)
        fusion = normalize_audio(fusion)

        # Mastering chain: trim the extremes, glue, make-up gain, brickwall.
        board = Pedalboard([
            HighpassFilter(30),
            LowpassFilter(18000),
            Compressor(threshold_db=-20, ratio=2),
            Gain(2.0),
            Limiter(threshold_db=-0.5)
        ])
        fusion_mastered = board(fusion, SR)

        # Normalize integrated loudness to TARGET_LOUDNESS (LUFS).
        meter = pyln.Meter(SR)
        loudness = meter.integrated_loudness(fusion_mastered)
        fusion_mastered = pyln.normalize.loudness(fusion_mastered, loudness, TARGET_LOUDNESS)

        # Stream the WAV from memory; the result is never written to disk.
        buf = io.BytesIO()
        sf.write(buf, fusion_mastered, SR, format='WAV')
        buf.seek(0)

        print(f">>> [hf] job_{job_id} complete!")
        return send_file(buf, mimetype="audio/wav", as_attachment=True, download_name="fusion.wav")

    except Exception as e:
        print(f">>> [hf] error in job_{job_id}: {str(e)}")
        return jsonify({"error": str(e)}), 500
    finally:
        for f in temp_files:
            if os.path.exists(f): os.remove(f)
        for d in cleanup_dirs:
            if os.path.exists(d): shutil.rmtree(d, ignore_errors=True)
175
 
176
if __name__ == "__main__":
    # Bind on all interfaces so the container is reachable; 7860 is the
    # conventional Hugging Face Spaces port (matches the '[hf]' log prefix).
    app.run(host='0.0.0.0', port=7860)