"""Flask service that fuses two songs: separates stems with Demucs, then
key-matches, beat-syncs, mixes and masters them into a single WAV.

Endpoints:
    GET  /      -> health/hardware info
    POST /fuse  -> multipart files `melody` + `style`, returns fused WAV
"""

import os
import io
import uuid
import shutil
import subprocess
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import librosa
import soundfile as sf
import pyloudnorm as pyln
import torch
from flask import Flask, request, send_file, jsonify, make_response
from flask_cors import CORS
from scipy.signal import butter, lfilter
from pedalboard import Pedalboard, Compressor, Limiter, HighpassFilter, LowpassFilter, Gain
from pydub import AudioSegment

app = Flask(__name__)
CORS(app)

has_gpu = torch.cuda.is_available()
device_type = "cuda" if has_gpu else "cpu"
# 22.05 kHz on CPU roughly halves processing time; keep 44.1 kHz on GPU for quality.
sr = 44100 if has_gpu else 22050
# BUGFIX: "htdemucs_lite" is not a published Demucs model, so `demucs -n` would
# fail on every CPU host. "htdemucs" is the lightest published htdemucs variant
# ("htdemucs_ft" is the heavier fine-tuned one, kept for GPU).
model_name = "htdemucs_ft" if has_gpu else "htdemucs"
target_loudness = -9.0  # LUFS target for the final master


def convert_to_wav(input_path):
    """Convert an .mp3 file to a uniquely-named .wav; pass other paths through.

    Returns the path of the (possibly new) WAV file.
    """
    if input_path.lower().endswith(".mp3"):
        wav_path = input_path.rsplit(".", 1)[0] + f"_{uuid.uuid4().hex}.wav"
        AudioSegment.from_mp3(input_path).export(wav_path, format="wav")
        return wav_path
    return input_path


def load_mono(file_path):
    """Load an audio file as a mono float array at the global sample rate.

    Missing files yield 5 seconds of silence instead of raising, so a
    separation run that produced fewer stems does not abort the whole job.
    """
    if not os.path.exists(file_path):
        return np.zeros(sr * 5)
    y, _ = librosa.load(file_path, sr=sr, mono=True)
    return y


def normalize_audio(y):
    """Peak-normalize to [-1, 1]; the epsilon avoids division by zero on silence."""
    return y / (np.max(np.abs(y)) + 1e-9)


def highpass(data, cutoff):
    """4th-order Butterworth high-pass at `cutoff` Hz."""
    b, a = butter(4, cutoff / (sr / 2), btype='high')
    return lfilter(b, a, data)


def lowpass(data, cutoff):
    """4th-order Butterworth low-pass at `cutoff` Hz."""
    b, a = butter(4, cutoff / (sr / 2), btype='low')
    return lfilter(b, a, data)


def detect_key(y):
    """Return the dominant pitch class (0-11) from summed chroma energy."""
    chroma = librosa.feature.chroma_cqt(y=y, sr=sr, hop_length=1024)
    return np.argmax(np.sum(chroma, axis=1))


def match_key(source, target):
    """Pitch-shift `source` so its dominant pitch class matches `target`.

    BUGFIX: the raw pitch-class difference can be up to +/-11 semitones; a
    +11 shift is musically identical to -1 but sounds far worse. Wrap the
    interval into [-6, +6] so we always take the shortest path.
    """
    shift = int(detect_key(target)) - int(detect_key(source))
    shift = ((shift + 6) % 12) - 6  # shortest equivalent interval
    if shift == 0:
        return source
    return librosa.effects.pitch_shift(source, sr=sr, n_steps=float(shift))


def beat_sync_warp(source, target):
    """Time-stretch `source` to match `target`'s tempo, then pad/trim to its length.

    If either tempo estimate is unusable, or the tempos are already within
    1 BPM, only the length is matched (no stretch).
    """
    tempo_t, _ = librosa.beat.beat_track(y=target, sr=sr)
    tempo_s, _ = librosa.beat.beat_track(y=source, sr=sr)
    # librosa may return tempo as a scalar or a 1-element array depending on version.
    tempo_t = float(np.atleast_1d(tempo_t)[0])
    tempo_s = float(np.atleast_1d(tempo_s)[0])
    if tempo_t <= 0 or tempo_s <= 0 or abs(tempo_s - tempo_t) < 1:
        return librosa.util.fix_length(source, size=len(target))
    rate = tempo_s / tempo_t
    warped = librosa.effects.time_stretch(source, rate=float(rate))
    return librosa.util.fix_length(warped, size=len(target))


def separate_stems(input_file, job_id):
    """Run Demucs on `input_file`, writing stems under a job-specific directory.

    Returns (dict of stem name -> wav path, output directory to clean up).
    Raises CalledProcessError if demucs exits non-zero.
    """
    out_dir = f"sep_{job_id}"
    cmd = [
        "demucs", "-n", model_name,
        "--out", out_dir,
        "--device", device_type,
        input_file,
    ]
    # Single worker on CPU to keep peak memory bounded.
    if not has_gpu:
        cmd.extend(["-j", "1"])
    subprocess.run(cmd, check=True)
    base = os.path.splitext(os.path.basename(input_file))[0]
    stem_dir = os.path.join(out_dir, model_name, base)
    return {
        "drums": os.path.join(stem_dir, "drums.wav"),
        "bass": os.path.join(stem_dir, "bass.wav"),
        "other": os.path.join(stem_dir, "other.wav"),
        "vocals": os.path.join(stem_dir, "vocals.wav"),
    }, out_dir


@app.route('/', methods=['GET'])
def health():
    """Liveness probe reporting the active device and separation model."""
    return jsonify({"status": "ready", "hardware": device_type, "model": model_name}), 200


@app.route('/fuse', methods=['POST'])
def fuse_api():
    """Fuse the `melody` upload's harmonic content with the `style` upload's rhythm.

    Pipeline: separate both -> trim to common length -> key-match and
    beat-warp the melody stems onto the style's drums -> weighted mix ->
    Pedalboard mastering chain -> LUFS normalization -> WAV response.
    """
    job_id = uuid.uuid4().hex[:8]
    temp_files, cleanup_dirs = [], []
    try:
        t_req = request.files.get('melody')
        m_req = request.files.get('style')
        if not t_req or not m_req:
            return jsonify({"error": "missing files"}), 400

        t_path, m_path = f"t_{job_id}.wav", f"m_{job_id}.wav"
        t_req.save(t_path)
        m_req.save(m_path)
        temp_files.extend([t_path, m_path])

        # Parallel separation only on GPU; CPU runs sequentially to avoid
        # two demucs processes exhausting memory.
        if has_gpu:
            with ThreadPoolExecutor(max_workers=2) as executor:
                f_t = executor.submit(separate_stems, t_path, f"t_{job_id}")
                f_m = executor.submit(separate_stems, m_path, f"m_{job_id}")
                t_stems, t_dir = f_t.result()
                m_stems, m_dir = f_m.result()
        else:
            t_stems, t_dir = separate_stems(t_path, f"t_{job_id}")
            m_stems, m_dir = separate_stems(m_path, f"m_{job_id}")
        cleanup_dirs.extend([t_dir, m_dir])

        t_other, t_bass = load_mono(t_stems["other"]), load_mono(t_stems["bass"])
        m_drums, m_bass, m_other = (
            load_mono(m_stems["drums"]),
            load_mono(m_stems["bass"]),
            load_mono(m_stems["other"]),
        )

        # Trim everything to the shortest stem so the mix sums cleanly.
        target_len = min(len(t_other), len(m_drums))
        t_other, t_bass = t_other[:target_len], t_bass[:target_len]
        m_drums, m_bass, m_other = m_drums[:target_len], m_bass[:target_len], m_other[:target_len]

        t_other = match_key(t_other, m_other)
        t_bass = match_key(t_bass, m_bass)
        t_other = beat_sync_warp(t_other, m_drums)
        t_bass = beat_sync_warp(t_bass, m_drums)

        # Weighted sum: style rhythm section dominant, melody harmony on top.
        fusion = normalize_audio(
            1.0 * m_drums + 1.0 * m_bass + 1.2 * t_other + 0.5 * m_other + 0.8 * t_bass
        )

        board = Pedalboard([
            HighpassFilter(30),
            LowpassFilter(18000),
            Compressor(threshold_db=-20, ratio=2),
            Gain(2.0),
            Limiter(threshold_db=-0.5),
        ])
        fusion_mastered = board(fusion, sr)

        meter = pyln.Meter(sr)
        loudness = meter.integrated_loudness(fusion_mastered)
        # BUGFIX: integrated_loudness returns -inf for silence, and
        # pyln.normalize.loudness would then produce NaN/inf samples.
        if np.isfinite(loudness):
            fusion_mastered = pyln.normalize.loudness(fusion_mastered, loudness, target_loudness)

        buf = io.BytesIO()
        sf.write(buf, fusion_mastered, sr, format='WAV')
        buf.seek(0)
        return send_file(buf, mimetype="audio/wav", as_attachment=True, download_name="fusion.wav")
    except Exception as e:
        return jsonify({"error": str(e)}), 500
    finally:
        # Best-effort cleanup of uploads and separation output.
        for f in temp_files:
            if os.path.exists(f):
                os.remove(f)
        for d in cleanup_dirs:
            if os.path.exists(d):
                shutil.rmtree(d, ignore_errors=True)


if __name__ == "__main__":
    app.run(host='0.0.0.0', port=7860)