# fusionAI / app.py
# Source: Hugging Face Space "fusionAI", commit e9b314b ("Update app.py", verified, by gere)
import os
import io
import uuid
import shutil
import numpy as np
import librosa
import soundfile as sf
import pyloudnorm as pyln
import torch
from flask import Flask, request, send_file, jsonify, make_response
from flask_cors import CORS
from scipy.signal import butter, lfilter
from pedalboard import Pedalboard, Compressor, Limiter, HighpassFilter, LowpassFilter, Gain
import subprocess
from pydub import AudioSegment
from concurrent.futures import ThreadPoolExecutor
app = Flask(__name__)
CORS(app)
has_gpu = torch.cuda.is_available()
device_type = "cuda" if has_gpu else "cpu"
sr = 44100
target_loudness = -9.0
def convert_to_wav(input_path):
    """Ensure *input_path* refers to a WAV file, converting it if necessary.

    Generalized from the original mp3-only branch: any non-WAV input is
    decoded with ``AudioSegment.from_file`` (which autodetects the
    container/codec), so m4a/flac/ogg uploads work too. WAV paths are
    returned unchanged, preserving the original behavior.

    Parameters
    ----------
    input_path : str
        Path to an audio file on disk.

    Returns
    -------
    str
        ``input_path`` itself when it is already a WAV, otherwise the path
        of a freshly written sibling WAV with a unique suffix (the caller
        owns cleanup of that file).
    """
    base, ext = os.path.splitext(input_path)
    if ext.lower() == ".wav":
        return input_path
    # Unique suffix avoids collisions when the same file is converted twice.
    wav_path = f"{base}_{uuid.uuid4().hex}.wav"
    AudioSegment.from_file(input_path).export(wav_path, format="wav")
    return wav_path
def load_mono(file_path):
    """Load *file_path* as a mono signal at the global sample rate ``sr``.

    A missing file yields five seconds of silence instead of raising, so a
    partially failed stem separation degrades gracefully downstream.
    """
    if os.path.exists(file_path):
        audio, _ = librosa.load(file_path, sr=sr, mono=True)
        return audio
    return np.zeros(sr * 5)
def normalize_audio(y):
    """Peak-normalize *y* to roughly [-1, 1]; the epsilon keeps an
    all-zero signal from dividing by zero."""
    peak = np.abs(y).max() + 1e-9
    return y / peak
def highpass(data, cutoff):
    """Apply a 4th-order Butterworth high-pass at *cutoff* Hz (uses the
    module-level sample rate ``sr``)."""
    normalized_cutoff = cutoff / (sr / 2)
    b, a = butter(4, normalized_cutoff, btype='high')
    return lfilter(b, a, data)
def lowpass(data, cutoff):
    """Apply a 4th-order Butterworth low-pass at *cutoff* Hz (uses the
    module-level sample rate ``sr``)."""
    normalized_cutoff = cutoff / (sr / 2)
    b, a = butter(4, normalized_cutoff, btype='low')
    return lfilter(b, a, data)
def detect_key(y):
    """Estimate the dominant pitch class of *y* as an index 0-11 by summing
    CQT chroma energy per bin and taking the argmax."""
    chroma_energy = librosa.feature.chroma_cqt(y=y, sr=sr, hop_length=1024).sum(axis=1)
    return np.argmax(chroma_energy)
def match_key(source, target):
    """Pitch-shift *source* so its dominant pitch class matches *target*'s.

    Returns *source* untouched when the two estimates already agree.
    """
    steps = detect_key(target) - detect_key(source)
    if steps == 0:
        return source
    return librosa.effects.pitch_shift(source, sr=sr, n_steps=float(steps))
def beat_sync_warp(source, target):
    """Time-stretch *source* to match *target*'s estimated tempo, then
    pad/trim it to *target*'s length.

    When either tempo estimate is unusable (<= 0) or the tempi already agree
    within 1 BPM, the source is only length-matched, not stretched.
    """
    bpm_target = float(np.atleast_1d(librosa.beat.beat_track(y=target, sr=sr)[0])[0])
    bpm_source = float(np.atleast_1d(librosa.beat.beat_track(y=source, sr=sr)[0])[0])
    can_stretch = bpm_target > 0 and bpm_source > 0 and abs(bpm_source - bpm_target) >= 1
    if not can_stretch:
        return librosa.util.fix_length(source, size=len(target))
    stretched = librosa.effects.time_stretch(source, rate=float(bpm_source / bpm_target))
    return librosa.util.fix_length(stretched, size=len(target))
def separate_stems(input_file, job_id):
    """Run Demucs (htdemucs_ft model) on *input_file* and return stem paths.

    Returns
    -------
    tuple[dict, str]
        A mapping of stem name (drums/bass/other/vocals) to its WAV path,
        and the output directory the caller is responsible for deleting.
    """
    out_dir = f"sep_{job_id}"
    cmd = [
        "demucs", "-n", "htdemucs_ft",
        "--out", out_dir,
        "--device", device_type,
        input_file,
    ]
    if not has_gpu:
        # A single worker keeps CPU memory usage bounded.
        cmd += ["-j", "1"]
    subprocess.run(cmd, check=True)
    track_name = os.path.splitext(os.path.basename(input_file))[0]
    stem_dir = os.path.join(out_dir, "htdemucs_ft", track_name)
    stems = {
        name: os.path.join(stem_dir, f"{name}.wav")
        for name in ("drums", "bass", "other", "vocals")
    }
    return stems, out_dir
@app.route('/', methods=['GET'])
def health():
    """Liveness probe reporting which device Demucs will run on."""
    payload = {"status": "ready", "hardware": device_type}
    return jsonify(payload), 200
@app.route('/fuse', methods=['POST'])
def fuse_api():
    """Fuse two uploaded tracks: harmonic content of 'melody' over the rhythm of 'style'.

    Expects multipart form files 'melody' and 'style'. Responds with the
    mastered fusion as a WAV attachment, 400 when a file is missing, or 500
    (with the exception text) on any processing failure.
    """
    job_id = uuid.uuid4().hex[:8]
    temp_files, cleanup_dirs = [], []
    try:
        t_req = request.files.get('melody')
        m_req = request.files.get('style')
        if not t_req or not m_req: return jsonify({"error": "missing files"}), 400
        # NOTE(review): uploads are saved under a .wav name regardless of the
        # actual uploaded format, and convert_to_wav() is never called here —
        # confirm demucs tolerates mislabelled input.
        t_path, m_path = f"t_{job_id}.wav", f"m_{job_id}.wav"
        t_req.save(t_path)
        m_req.save(m_path)
        temp_files.extend([t_path, m_path])
        # Separate both tracks concurrently on GPU; sequentially on CPU to
        # bound memory use.
        max_workers = 2 if has_gpu else 1
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            f_t = executor.submit(separate_stems, t_path, f"t_{job_id}")
            f_m = executor.submit(separate_stems, m_path, f"m_{job_id}")
            t_stems, t_dir = f_t.result()
            m_stems, m_dir = f_m.result()
        cleanup_dirs.extend([t_dir, m_dir])
        # Melody track contributes harmonic stems (other/bass); style track
        # contributes drums, bass and a quieter harmonic pad.
        t_other, t_bass = load_mono(t_stems["other"]), load_mono(t_stems["bass"])
        m_drums, m_bass, m_other = load_mono(m_stems["drums"]), load_mono(m_stems["bass"]), load_mono(m_stems["other"])
        # Trim every stem to the shortest relevant length before alignment.
        target_len = min(len(t_other), len(m_drums))
        t_other, t_bass = t_other[:target_len], t_bass[:target_len]
        m_drums, m_bass, m_other = m_drums[:target_len], m_bass[:target_len], m_other[:target_len]
        # Align the melody stems to the style track: key first, then tempo.
        t_other = match_key(t_other, m_other)
        t_bass = match_key(t_bass, m_bass)
        t_other = beat_sync_warp(t_other, m_drums)
        t_bass = beat_sync_warp(t_bass, m_drums)
        # Static mix weights (presumably chosen by ear); melody harmony is
        # slightly boosted, style harmony attenuated.
        fusion = normalize_audio(1.0*m_drums + 1.0*m_bass + 1.2*t_other + 0.5*m_other + 0.8*t_bass)
        # Mastering chain: trim sub-rumble and extreme highs, gentle glue
        # compression, make-up gain, then a brickwall limiter at -0.5 dB.
        board = Pedalboard([
            HighpassFilter(30), LowpassFilter(18000),
            Compressor(threshold_db=-20, ratio=2), Gain(2.0), Limiter(threshold_db=-0.5)
        ])
        fusion_mastered = board(fusion, sr)
        # Loudness-normalize to the module-level target (-9 LUFS).
        meter = pyln.Meter(sr)
        loudness = meter.integrated_loudness(fusion_mastered)
        fusion_mastered = pyln.normalize.loudness(fusion_mastered, loudness, target_loudness)
        # Stream the result from memory; the mastered audio never hits disk.
        buf = io.BytesIO()
        sf.write(buf, fusion_mastered, sr, format='WAV')
        buf.seek(0)
        return send_file(buf, mimetype="audio/wav", as_attachment=True, download_name="fusion.wav")
    except Exception as e:
        # Surface the failure text to the client; cleanup still runs below.
        return jsonify({"error": str(e)}), 500
    finally:
        # Best-effort removal of uploaded temp files and separation output.
        for f in temp_files:
            if os.path.exists(f): os.remove(f)
        for d in cleanup_dirs:
            if os.path.exists(d): shutil.rmtree(d, ignore_errors=True)
if __name__ == "__main__":
    # Bind all interfaces on port 7860 (the port Hugging Face Spaces expects).
    app.run(host='0.0.0.0', port=7860)