File size: 6,066 Bytes
651ee01
2dd988c
 
8eb21dd
2dd988c
 
 
 
3f26513
8eb21dd
2dd988c
8eb21dd
 
f3d75a9
8eb21dd
3f26513
b979387
845f995
2dd988c
77eeeaa
3f26513
 
b5dbdf0
 
 
 
 
3f26513
7f8ecbe
8eb21dd
 
 
 
 
 
 
f3d75a9
 
3f26513
e9b314b
f3d75a9
 
 
 
 
 
3f26513
f3d75a9
 
 
3f26513
f3d75a9
 
 
3f26513
f3d75a9
 
 
 
 
 
3f26513
 
f3d75a9
 
3f26513
 
f3d75a9
 
3f26513
f3d75a9
 
 
 
 
 
 
 
3f26513
b5dbdf0
3f26513
 
 
 
b5dbdf0
3f26513
e9b314b
3f26513
 
f3d75a9
b5dbdf0
3f26513
f3d75a9
 
 
 
3f26513
2dd988c
f3d75a9
 
b5dbdf0
b7e27bd
2dd988c
aed4377
8eb21dd
3f26513
2dd988c
3f26513
 
 
b1bd12b
3f26513
 
 
8eb21dd
b1bd12b
b5dbdf0
 
 
 
 
 
 
 
 
 
3f26513
f3d75a9
 
3f26513
 
f3d75a9
 
3f26513
 
f3d75a9
 
 
 
 
 
3f26513
f3d75a9
 
3f26513
 
f3d75a9
3f26513
f3d75a9
3f26513
f3d75a9
3f26513
b1bd12b
2dd988c
3f26513
2dd988c
 
b1bd12b
2dd988c
 
 
3f26513
b1bd12b
3f26513
f3d75a9
4c52bf8
651ee01
a34d937
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import os
import io
import uuid
import shutil
import numpy as np
import librosa
import soundfile as sf
import pyloudnorm as pyln
import torch
from flask import Flask, request, send_file, jsonify, make_response
from flask_cors import CORS
from scipy.signal import butter, lfilter
from pedalboard import Pedalboard, Compressor, Limiter, HighpassFilter, LowpassFilter, Gain
import subprocess
from pydub import AudioSegment
from concurrent.futures import ThreadPoolExecutor

app = Flask(__name__)
CORS(app)

# Detect hardware once at startup; every processing decision below keys off this.
has_gpu = torch.cuda.is_available()
device_type = "cuda" if has_gpu else "cpu"

# use 22k for cpu to make it 2x faster; 44k for gpu quality
sr = 44100 if has_gpu else 22050
# BUG FIX: "htdemucs_lite" is not a published Demucs model name (the repo ships
# htdemucs, htdemucs_ft, htdemucs_6s, hdemucs_mmi, mdx, mdx_extra, ...), so the
# demucs CLI would reject it and every CPU-mode separation would fail. Use the
# plain single-pass "htdemucs" on CPU — it is ~4x faster than the fine-tuned
# "htdemucs_ft" used on GPU, matching the original intent of a lighter model.
model_name = "htdemucs_ft" if has_gpu else "htdemucs"
# Mastering target in LUFS (integrated loudness) for the final mix.
target_loudness = -9.0

def convert_to_wav(input_path):
    """Transcode an .mp3 file to a uniquely named sibling .wav file.

    Paths that do not end in ``.mp3`` (case-insensitive) are returned
    unchanged. Returns the path of the file to use downstream.
    """
    if not input_path.lower().endswith(".mp3"):
        return input_path
    stem = input_path.rsplit(".", 1)[0]
    wav_path = f"{stem}_{uuid.uuid4().hex}.wav"
    audio = AudioSegment.from_mp3(input_path)
    audio.export(wav_path, format="wav")
    return wav_path

def load_mono(file_path):
    """Load an audio file as a mono signal at the module sample rate ``sr``.

    A missing file yields five seconds of silence instead of raising, so a
    single absent stem cannot abort a whole fuse job.
    """
    if not os.path.exists(file_path):
        return np.zeros(sr * 5)
    signal, _ = librosa.load(file_path, sr=sr, mono=True)
    return signal

def normalize_audio(y):
    """Scale *y* so its absolute peak is ~1.0; the epsilon guards all-silence input."""
    peak = np.max(np.abs(y))
    return y / (peak + 1e-9)

def highpass(data, cutoff):
    """Apply a 4th-order Butterworth high-pass at *cutoff* Hz (uses module ``sr``)."""
    normalized_cutoff = cutoff / (sr / 2)
    b, a = butter(4, normalized_cutoff, btype='high')
    return lfilter(b, a, data)

def lowpass(data, cutoff):
    """Apply a 4th-order Butterworth low-pass at *cutoff* Hz (uses module ``sr``)."""
    normalized_cutoff = cutoff / (sr / 2)
    b, a = butter(4, normalized_cutoff, btype='low')
    return lfilter(b, a, data)

def detect_key(y):
    """Estimate the dominant pitch class (0-11) of *y* from summed CQT chroma energy."""
    chroma = librosa.feature.chroma_cqt(y=y, sr=sr, hop_length=1024)
    energy_per_pitch_class = np.sum(chroma, axis=1)
    return np.argmax(energy_per_pitch_class)

def match_key(source, target):
    """Pitch-shift *source* so its dominant pitch class matches *target*'s.

    Returns *source* untouched when the two estimates already agree.
    """
    shift = detect_key(target) - detect_key(source)
    if shift == 0:
        return source
    return librosa.effects.pitch_shift(source, sr=sr, n_steps=float(shift))

def beat_sync_warp(source, target):
    """Time-stretch *source* to *target*'s tempo, then pad/trim it to *target*'s length.

    Stretching is skipped when either tempo estimate is non-positive or the
    two tempos are already within 1 BPM of each other.
    """
    tempo_target, _ = librosa.beat.beat_track(y=target, sr=sr)
    tempo_source, _ = librosa.beat.beat_track(y=source, sr=sr)
    # beat_track may hand back a scalar or a 0-d/1-d array depending on
    # the librosa version; coerce both to plain floats.
    tempo_target = float(np.atleast_1d(tempo_target)[0])
    tempo_source = float(np.atleast_1d(tempo_source)[0])
    close_enough = abs(tempo_source - tempo_target) < 1
    if tempo_target <= 0 or tempo_source <= 0 or close_enough:
        return librosa.util.fix_length(source, size=len(target))
    stretch_rate = tempo_source / tempo_target
    stretched = librosa.effects.time_stretch(source, rate=float(stretch_rate))
    return librosa.util.fix_length(stretched, size=len(target))

def separate_stems(input_file, job_id):
    """Run the Demucs CLI on *input_file* and return (stem_paths, out_dir).

    ``stem_paths`` maps stem names to the wav files demucs writes under
    ``out_dir/<model>/<track>/``; the caller owns deleting ``out_dir``.
    Raises CalledProcessError if the demucs process fails.
    """
    out_dir = f"sep_{job_id}"
    cmd = [
        "demucs", "-n", model_name,
        "--out", out_dir,
        "--device", device_type,
        input_file,
    ]
    if not has_gpu:
        # use only 1 thread for cpu to prevent memory lockups
        cmd += ["-j", "1"]

    subprocess.run(cmd, check=True)

    track_name = os.path.splitext(os.path.basename(input_file))[0]
    stem_dir = os.path.join(out_dir, model_name, track_name)
    stem_paths = {
        stem: os.path.join(stem_dir, f"{stem}.wav")
        for stem in ("drums", "bass", "other", "vocals")
    }
    return stem_paths, out_dir

@app.route('/', methods=['GET'])
def health():
    """Liveness probe reporting which hardware and separation model this instance uses."""
    payload = {"status": "ready", "hardware": device_type, "model": model_name}
    return jsonify(payload), 200

@app.route('/fuse', methods=['POST'])
def fuse_api():
    # Fuse two uploads: the harmonic/bass stems of the 'melody' track are
    # key- and tempo-aligned onto the rhythm section of the 'style' track,
    # the mix is mastered, and the result streams back as a WAV download.
    # Returns 400 on missing files, 500 (with the exception text) on any
    # processing failure.
    job_id = uuid.uuid4().hex[:8]
    temp_files, cleanup_dirs = [], []
    try:
        t_req = request.files.get('melody')
        m_req = request.files.get('style')
        if not t_req or not m_req: return jsonify({"error": "missing files"}), 400

        # Persist uploads to disk so the demucs CLI subprocess can read them.
        # NOTE(review): files are saved with a .wav suffix regardless of the
        # uploaded format — presumably demucs/ffmpeg sniffs content; confirm.
        t_path, m_path = f"t_{job_id}.wav", f"m_{job_id}.wav"
        t_req.save(t_path)
        m_req.save(m_path)
        temp_files.extend([t_path, m_path])

        # run in parallel only on gpu; sequentially on cpu to prevent hangs
        if has_gpu:
            with ThreadPoolExecutor(max_workers=2) as executor:
                f_t = executor.submit(separate_stems, t_path, f"t_{job_id}")
                f_m = executor.submit(separate_stems, m_path, f"m_{job_id}")
                t_stems, t_dir = f_t.result()
                m_stems, m_dir = f_m.result()
        else:
            t_stems, t_dir = separate_stems(t_path, f"t_{job_id}")
            m_stems, m_dir = separate_stems(m_path, f"m_{job_id}")

        cleanup_dirs.extend([t_dir, m_dir])

        # Only the stems mixed below are loaded; both vocal stems are dropped.
        t_other, t_bass = load_mono(t_stems["other"]), load_mono(t_stems["bass"])
        m_drums, m_bass, m_other = load_mono(m_stems["drums"]), load_mono(m_stems["bass"]), load_mono(m_stems["other"])

        # Truncate everything to the shorter track before alignment.
        target_len = min(len(t_other), len(m_drums))
        t_other, t_bass = t_other[:target_len], t_bass[:target_len]
        m_drums, m_bass, m_other = m_drums[:target_len], m_bass[:target_len], m_other[:target_len]

        # Pitch-align the melody stems to their style-track counterparts,
        # then tempo-align both against the style track's drums.
        t_other = match_key(t_other, m_other)
        t_bass = match_key(t_bass, m_bass)
        t_other = beat_sync_warp(t_other, m_drums)
        t_bass = beat_sync_warp(t_bass, m_drums)

        # Weighted sum, peak-normalized. NOTE(review): t_bass (0.8) is layered
        # on top of m_bass (1.0) — confirm doubled bass content is intentional.
        fusion = normalize_audio(1.0*m_drums + 1.0*m_bass + 1.2*t_other + 0.5*m_other + 0.8*t_bass)

        # Mastering chain: trim sub-rumble and extreme highs, gentle glue
        # compression, makeup gain, then a brickwall limiter at -0.5 dB.
        board = Pedalboard([
            HighpassFilter(30), LowpassFilter(18000),
            Compressor(threshold_db=-20, ratio=2), Gain(2.0), Limiter(threshold_db=-0.5)
        ])
        fusion_mastered = board(fusion, sr)

        # Loudness-normalize to the module-level target (LUFS).
        meter = pyln.Meter(sr)
        loudness = meter.integrated_loudness(fusion_mastered)
        fusion_mastered = pyln.normalize.loudness(fusion_mastered, loudness, target_loudness)

        # Stream the result straight from memory; no extra file hits disk.
        buf = io.BytesIO()
        sf.write(buf, fusion_mastered, sr, format='WAV')
        buf.seek(0)
        return send_file(buf, mimetype="audio/wav", as_attachment=True, download_name="fusion.wav")

    except Exception as e:
        return jsonify({"error": str(e)}), 500
    finally:
        # Best-effort cleanup of uploads and demucs output dirs, on success or failure.
        for f in temp_files: 
            if os.path.exists(f): os.remove(f)
        for d in cleanup_dirs: 
            if os.path.exists(d): shutil.rmtree(d, ignore_errors=True)

if __name__ == "__main__":
    app.run(host='0.0.0.0', port=7860)