| import os |
| import io |
| import uuid |
| import shutil |
| import numpy as np |
| import librosa |
| import soundfile as sf |
| import pyloudnorm as pyln |
| import torch |
| from flask import Flask, request, send_file, jsonify, make_response |
| from flask_cors import CORS |
| from scipy.signal import butter, lfilter |
| from pedalboard import Pedalboard, Compressor, Limiter, HighpassFilter, LowpassFilter, Gain |
| import subprocess |
| from pydub import AudioSegment |
| from concurrent.futures import ThreadPoolExecutor |
|
|
# Flask application object; CORS is enabled with flask_cors' default settings.
app = Flask(__name__)
CORS(app)

# Decide once, at import time, whether CUDA is available to torch.
has_gpu = torch.cuda.is_available()
if has_gpu:
    device_type = "cuda"
else:
    device_type = "cpu"

# Sample rate (Hz) used for every load, filter and export in this module.
sr = 44100
# Target passed to pyloudnorm's loudness normalisation when mastering.
target_loudness = -9.0
|
|
def convert_to_wav(input_path):
    """Return a WAV path for *input_path*, converting MP3 files on the way.

    Paths that do not end in ``.mp3`` (case-insensitive) are returned
    unchanged. For MP3 input a new WAV file is written next to the source,
    with a random hex suffix in its name, and that new path is returned.
    """
    looks_like_mp3 = input_path.lower().endswith(".mp3")
    if not looks_like_mp3:
        return input_path

    stem = input_path.rsplit(".", 1)[0]
    unique_tag = uuid.uuid4().hex
    wav_path = f"{stem}_{unique_tag}.wav"

    decoded = AudioSegment.from_mp3(input_path)
    decoded.export(wav_path, format="wav")
    return wav_path
|
|
def load_mono(file_path):
    """Load *file_path* as a mono signal at the module sample rate ``sr``.

    If the path does not exist, five seconds' worth of zeros (``sr * 5``
    samples) is returned instead of raising.
    """
    if os.path.exists(file_path):
        samples, _ = librosa.load(file_path, sr=sr, mono=True)
        return samples
    # Missing file: fall back to silence rather than failing.
    return np.zeros(sr * 5)
|
|
def normalize_audio(y):
    """Scale *y* so that its largest absolute sample is (almost exactly) 1.

    Args:
        y: Array-like of audio samples.

    Returns:
        A NumPy array of ``y`` divided by ``max(|y|) + 1e-9``. The small
        epsilon keeps an all-zero signal from dividing by zero. An empty
        input is returned as an empty array instead of raising, because
        ``np.max`` is undefined on zero-size arrays.
    """
    samples = np.asarray(y)
    if samples.size == 0:
        return samples
    peak = np.max(np.abs(samples))
    return samples / (peak + 1e-9)
|
|
def highpass(data, cutoff):
    """Apply a 4th-order Butterworth high-pass filter to *data*.

    *cutoff* is given in Hz and is normalised by the Nyquist frequency
    (``sr / 2``) before the filter is designed.
    """
    nyquist = sr / 2
    numerator, denominator = butter(4, cutoff / nyquist, btype='high')
    return lfilter(numerator, denominator, data)
|
|
def lowpass(data, cutoff):
    """Apply a 4th-order Butterworth low-pass filter to *data*.

    *cutoff* is given in Hz and is normalised by the Nyquist frequency
    (``sr / 2``) before the filter is designed.
    """
    nyquist = sr / 2
    numerator, denominator = butter(4, cutoff / nyquist, btype='low')
    return lfilter(numerator, denominator, data)
|
|
def detect_key(y):
    """Return the index of the chroma row with the most total energy in *y*.

    The chromagram is computed with ``librosa.feature.chroma_cqt`` at the
    module sample rate with a hop length of 1024, summed along axis 1, and
    the index of the largest sum is returned.
    NOTE(review): presumably index 0 corresponds to C per librosa's default
    chroma layout - confirm before mapping indices to note names.
    """
    chromagram = librosa.feature.chroma_cqt(y=y, sr=sr, hop_length=1024)
    energy_per_row = chromagram.sum(axis=1)
    return np.argmax(energy_per_row)
|
|
def match_key(source, target):
    """Pitch-shift *source* so its dominant chroma index matches *target*'s.

    Args:
        source: Signal to be transposed.
        target: Signal whose dominant chroma index (see ``detect_key``) is
            the goal.

    Returns:
        ``source`` unchanged when no shift is needed, otherwise the result
        of ``librosa.effects.pitch_shift`` by the computed semitone count.
    """
    key_s = detect_key(source)
    key_t = detect_key(target)
    # Pitch classes are cyclic (mod 12): going from index 0 to index 11 is
    # one semitone down, not eleven up. Wrap the raw difference into
    # [-6, 5] so the shortest transposition is used, which keeps
    # pitch-shifting artefacts to a minimum.
    shift = (int(key_t) - int(key_s) + 6) % 12 - 6
    if shift == 0:
        return source
    return librosa.effects.pitch_shift(source, sr=sr, n_steps=float(shift))
|
|
def beat_sync_warp(source, target):
    """Time-stretch *source* to the tempo of *target* and match its length.

    Args:
        source: Signal to be warped.
        target: Signal providing the reference tempo and output length.

    Returns:
        A signal of exactly ``len(target)`` samples. When either tempo
        estimate is non-positive, or the two estimates differ by less than
        1 BPM, ``source`` is only padded/truncated and not stretched.
    """
    tempo_t, _ = librosa.beat.beat_track(y=target, sr=sr)
    tempo_s, _ = librosa.beat.beat_track(y=source, sr=sr)
    # beat_track may return the tempo as a scalar or as an array depending
    # on the librosa version; normalise both to plain floats.
    tempo_t = float(np.atleast_1d(tempo_t)[0])
    tempo_s = float(np.atleast_1d(tempo_s)[0])

    if tempo_t <= 0 or tempo_s <= 0 or abs(tempo_s - tempo_t) < 1:
        return librosa.util.fix_length(source, size=len(target))

    # time_stretch speeds the signal up for rate > 1, so the stretched tempo
    # is tempo_s * rate. To land on tempo_t the rate must be
    # tempo_t / tempo_s (the inverse ratio would move the source further
    # away from the target tempo).
    rate = tempo_t / tempo_s
    warped = librosa.effects.time_stretch(source, rate=float(rate))
    return librosa.util.fix_length(warped, size=len(target))
|
|
def separate_stems(input_file, job_id):
    """Run the ``demucs`` CLI on *input_file* and return the stem paths.

    Args:
        input_file: Path of the audio file to separate.
        job_id: Identifier used to build the output directory name
            (``sep_<job_id>``, relative to the current working directory).

    Returns:
        A tuple ``(stems, out_dir)`` where ``stems`` maps ``"drums"``,
        ``"bass"``, ``"other"`` and ``"vocals"`` to the expected WAV paths
        and ``out_dir`` is the directory the caller should delete when done.
        The paths are constructed, not checked for existence.

    Raises:
        subprocess.CalledProcessError: demucs exited with a non-zero status.
        OSError: the demucs executable could not be started.
    """
    model_name = "htdemucs_ft"
    out_dir = f"sep_{job_id}"
    cmd = [
        "demucs", "-n", model_name,
        "--out", out_dir,
        "--device", device_type,
        input_file,
    ]
    if not has_gpu:
        cmd.extend(["-j", "1"])

    try:
        subprocess.run(cmd, check=True)
    except (subprocess.CalledProcessError, OSError):
        # The caller never receives out_dir when we raise, so remove any
        # partial output here instead of leaking it on disk.
        shutil.rmtree(out_dir, ignore_errors=True)
        raise

    base = os.path.splitext(os.path.basename(input_file))[0]
    stem_dir = os.path.join(out_dir, model_name, base)
    stem_names = ("drums", "bass", "other", "vocals")
    stems = {name: os.path.join(stem_dir, f"{name}.wav") for name in stem_names}
    return stems, out_dir
|
|
@app.route('/', methods=['GET'])
def health():
    """Report readiness together with the selected device type."""
    payload = {"status": "ready", "hardware": device_type}
    return jsonify(payload), 200
|
|
@app.route('/fuse', methods=['POST'])
def fuse_api():
    """Fuse two uploaded tracks into one mastered WAV file.

    Expects multipart form files ``melody`` and ``style``. Both are separated
    into stems; the melody track contributes its ``other`` and ``bass``
    stems, the style track its ``drums``, ``bass`` and ``other`` stems. The
    melody stems are key-matched and tempo-warped to the style track, mixed
    with fixed weights, mastered, loudness-normalised and returned as
    ``fusion.wav``.

    Returns:
        The WAV attachment on success; a JSON ``{"error": ...}`` body with
        status 400 for missing or empty input and 500 for any other failure.
        Temporary files and stem directories are removed in every case.
    """
    job_id = uuid.uuid4().hex[:8]
    temp_files, cleanup_dirs = [], []
    try:
        t_req = request.files.get('melody')
        m_req = request.files.get('style')
        if not t_req or not m_req:
            return jsonify({"error": "missing files"}), 400

        t_path, m_path = f"t_{job_id}.wav", f"m_{job_id}.wav"
        # Register before saving so a failed or partial save is still
        # cleaned up by the finally block.
        temp_files.extend([t_path, m_path])
        t_req.save(t_path)
        m_req.save(m_path)

        max_workers = 2 if has_gpu else 1
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [
                executor.submit(separate_stems, t_path, f"t_{job_id}"),
                executor.submit(separate_stems, m_path, f"m_{job_id}"),
            ]

        # Collect every result before re-raising so that the output
        # directory of a successful separation is registered for cleanup
        # even when the other separation failed.
        separated, first_error = [], None
        for future in futures:
            try:
                stem_paths, out_dir = future.result()
            except Exception as exc:
                if first_error is None:
                    first_error = exc
            else:
                cleanup_dirs.append(out_dir)
                separated.append(stem_paths)
        if first_error is not None:
            raise first_error
        t_stems, m_stems = separated

        t_other, t_bass = load_mono(t_stems["other"]), load_mono(t_stems["bass"])
        m_drums = load_mono(m_stems["drums"])
        m_bass = load_mono(m_stems["bass"])
        m_other = load_mono(m_stems["other"])

        target_len = min(len(t_other), len(m_drums))
        if target_len == 0:
            return jsonify({"error": "empty audio"}), 400

        # fix_length truncates or zero-pads, so a stem that came back with a
        # different length (e.g. load_mono's silence fallback) cannot break
        # the element-wise mix below.
        t_other, t_bass, m_drums, m_bass, m_other = (
            librosa.util.fix_length(stem, size=target_len)
            for stem in (t_other, t_bass, m_drums, m_bass, m_other)
        )

        t_other = match_key(t_other, m_other)
        t_bass = match_key(t_bass, m_bass)
        t_other = beat_sync_warp(t_other, m_drums)
        t_bass = beat_sync_warp(t_bass, m_drums)

        fusion = normalize_audio(
            1.0*m_drums + 1.0*m_bass + 1.2*t_other + 0.5*m_other + 0.8*t_bass
        )

        board = Pedalboard([
            HighpassFilter(30), LowpassFilter(18000),
            Compressor(threshold_db=-20, ratio=2), Gain(2.0), Limiter(threshold_db=-0.5)
        ])
        fusion_mastered = board(fusion, sr)

        meter = pyln.Meter(sr)
        loudness = meter.integrated_loudness(fusion_mastered)
        # A silent mix yields a non-finite loudness; normalising against it
        # would fill the output with inf/NaN, so skip the gain in that case.
        if np.isfinite(loudness):
            fusion_mastered = pyln.normalize.loudness(
                fusion_mastered, loudness, target_loudness
            )
        # NOTE(review): the normalisation gain is applied after the limiter,
        # so peaks may exceed full scale and clip on export - confirm this
        # is acceptable for the chosen target loudness.

        buf = io.BytesIO()
        sf.write(buf, fusion_mastered, sr, format='WAV')
        buf.seek(0)
        return send_file(buf, mimetype="audio/wav", as_attachment=True, download_name="fusion.wav")

    except Exception as e:
        # Top-level request boundary: keep a server-side traceback, then
        # report the failure to the client as JSON.
        app.logger.exception("fuse job %s failed", job_id)
        return jsonify({"error": str(e)}), 500
    finally:
        for f in temp_files:
            try:
                os.remove(f)
            except OSError:
                # Best-effort cleanup; never mask the response with it.
                pass
        for d in cleanup_dirs:
            shutil.rmtree(d, ignore_errors=True)
|
|
if __name__ == "__main__":
    # Start Flask's built-in server when the file is executed directly,
    # listening on all interfaces at port 7860.
    app.run(port=7860, host='0.0.0.0')