File size: 5,323 Bytes
62d2c76
 
34b4ff9
 
 
 
 
 
62d2c76
 
34b4ff9
 
 
 
62d2c76
 
 
 
34b4ff9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62d2c76
34b4ff9
 
 
62d2c76
34b4ff9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62d2c76
 
 
 
 
 
34b4ff9
 
 
 
62d2c76
34b4ff9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb1c64b
34b4ff9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62d2c76
 
34b4ff9
 
 
 
 
62d2c76
 
bdee77c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import os
import io
import uuid
import shutil
import numpy as np
import librosa
import soundfile as sf
import pyloudnorm as pyln
from flask import Flask, request, send_file, jsonify
from flask_cors import CORS
from scipy.signal import butter, lfilter
from pedalboard import Pedalboard, Compressor, Limiter, HighpassFilter, LowpassFilter, Gain
import subprocess
from pydub import AudioSegment

# Sample rate (Hz) shared by loading, filter design, mastering and output.
SR = 44100
# Target loudness handed to pyloudnorm when normalising the final mix.
TARGET_LOUDNESS = -9.0

# Flask application; CORS is enabled with flask_cors' default settings.
app = Flask(__name__)
CORS(app)

def convert_to_wav(input_path):
    """Return a WAV path for ``input_path``, transcoding MP3 inputs.

    Paths ending in ``.mp3`` (case-insensitive) are decoded with pydub and
    exported to a new, uniquely named ``.wav`` file whose name is derived
    from the original path; that new path is returned and the caller is
    responsible for deleting it. Any other path is returned unchanged.
    """
    is_mp3 = input_path.lower().endswith(".mp3")
    if not is_mp3:
        return input_path

    # Everything before the final dot, plus a random tag to avoid collisions.
    stem = input_path.rpartition(".")[0]
    unique_tag = uuid.uuid4().hex
    wav_path = f"{stem}_{unique_tag}.wav"

    audio = AudioSegment.from_mp3(input_path)
    audio.export(wav_path, format="wav")
    return wav_path

def load_mono(file):
    """Load ``file`` with librosa as mono audio resampled to ``SR``.

    Only the sample array is returned; the sample rate that librosa
    reports alongside it is discarded.
    """
    samples, _sample_rate = librosa.load(file, sr=SR, mono=True)
    return samples

def normalize_audio(y):
    """Scale ``y`` so that its largest absolute sample is approximately 1.

    A tiny epsilon is added to the peak so an all-zero signal divides
    cleanly (and stays all-zero) instead of producing NaNs.
    """
    peak = np.max(np.abs(y))
    epsilon = 1e-9
    return y / (peak + epsilon)

def highpass(data, cutoff):
    """Apply a 4th-order Butterworth high-pass at ``cutoff`` Hz to ``data``.

    The cutoff is expressed as a fraction of the Nyquist frequency
    (``SR / 2``) for ``butter``, and the resulting filter is applied in a
    single forward pass with ``lfilter``.
    """
    nyquist = SR / 2
    numerator, denominator = butter(4, cutoff / nyquist, btype='high')
    return lfilter(numerator, denominator, data)

def lowpass(data, cutoff):
    """Apply a 4th-order Butterworth low-pass at ``cutoff`` Hz to ``data``.

    The cutoff is expressed as a fraction of the Nyquist frequency
    (``SR / 2``) for ``butter``, and the resulting filter is applied in a
    single forward pass with ``lfilter``.
    """
    nyquist = SR / 2
    coefficients = butter(4, cutoff / nyquist, btype='low')
    numerator, denominator = coefficients
    return lfilter(numerator, denominator, data)

def detect_key(y):
    """Return the index of the strongest chroma bin in ``y``.

    A constant-Q chromagram is computed at ``SR``, each chroma bin's energy
    is summed over time, and the index of the largest total is returned.
    This is a dominant-pitch-class estimate only; no major/minor mode is
    inferred.
    """
    chroma = librosa.feature.chroma_cqt(y=y, sr=SR)
    energy_per_bin = chroma.sum(axis=1)
    return energy_per_bin.argmax()

def match_key(source, target):
    """Pitch-shift ``source`` so its dominant pitch class matches ``target``'s.

    Both signals are analysed with ``detect_key`` and ``source`` is shifted
    by the difference in semitones.

    Args:
        source: Audio samples to transpose.
        target: Audio samples whose dominant pitch class is the goal.

    Returns:
        The transposed samples, or a copy of ``source`` when both signals
        already share the same dominant pitch class.
    """
    key_s = int(detect_key(source))
    key_t = int(detect_key(target))

    # Pitch classes are circular, so the raw difference (-11..11) can ask for
    # e.g. +11 semitones where -1 reaches the same pitch class. Fold it into
    # -6..5 so the shifter always takes the shortest route and introduces the
    # least stretching artefacts.
    shift = (key_t - key_s + 6) % 12 - 6

    if shift == 0:
        # Nothing to transpose: skip the costly STFT/resample round trip.
        return np.copy(source)

    return librosa.effects.pitch_shift(source, sr=SR, n_steps=float(shift))

def beat_sync_warp(source, target):
    """Time-stretch ``source`` to ``target``'s tempo and length.

    The tempo of each signal is estimated with ``librosa.beat.beat_track``;
    ``source`` is then stretched so its tempo equals the target's and is
    finally padded or trimmed to ``len(target)`` samples.

    Args:
        source: Audio samples to warp.
        target: Audio samples supplying the reference tempo and length.

    Returns:
        The warped samples, exactly ``len(target)`` long. If either tempo
        estimate is not positive, ``source`` is only length-fixed.
    """
    tempo_t, _ = librosa.beat.beat_track(y=target, sr=SR)
    tempo_s, _ = librosa.beat.beat_track(y=source, sr=SR)
    # atleast_1d lets this cope with either a scalar or an array tempo.
    tempo_t = float(np.atleast_1d(tempo_t)[0])
    tempo_s = float(np.atleast_1d(tempo_s)[0])

    if tempo_t <= 0 or tempo_s <= 0:
        return librosa.util.fix_length(source, size=len(target))

    # time_stretch speeds the signal up for rate > 1, so going from tempo_s
    # to tempo_t requires rate = tempo_t / tempo_s. The inverse ratio would
    # push the source further away from the target tempo.
    rate = tempo_t / tempo_s
    warped = librosa.effects.time_stretch(source, rate=rate)
    return librosa.util.fix_length(warped, size=len(target))

def separate_stems(input_file, job_id):
    """Run Demucs (``htdemucs`` model) on ``input_file`` and locate its stems.

    Args:
        input_file: Path of the audio file to separate.
        job_id: Identifier used to build the output directory name
            ``separated_<job_id>``.

    Returns:
        A ``(stems, out_dir)`` tuple. ``stems`` maps ``"drums"``, ``"bass"``,
        ``"other"`` and ``"vocals"`` to the expected WAV paths under
        ``<out_dir>/htdemucs/<input basename without extension>/``;
        ``out_dir`` is the directory the caller should delete when done.

    Raises:
        subprocess.CalledProcessError: If the ``demucs`` command exits with a
            non-zero status.
        OSError: If the ``demucs`` executable cannot be launched.
    """
    out_dir = f"separated_{job_id}"
    try:
        subprocess.run(["demucs", "-n", "htdemucs", "--out", out_dir, input_file], check=True)
    except BaseException:
        # The caller only learns out_dir from the return value, so a failed
        # run must remove any partially written output here before the
        # error propagates.
        shutil.rmtree(out_dir, ignore_errors=True)
        raise

    base = os.path.splitext(os.path.basename(input_file))[0]
    stem_dir = f"{out_dir}/htdemucs/{base}"
    stems = {
        "drums": f"{stem_dir}/drums.wav",
        "bass": f"{stem_dir}/bass.wav",
        "other": f"{stem_dir}/other.wav",
        "vocals": f"{stem_dir}/vocals.wav"
    }
    return stems, out_dir

@app.route('/', methods=['GET'])
def health():
    """Liveness probe: always answers HTTP 200 with ``{"status": "ready"}``."""
    payload = {"status": "ready"}
    return jsonify(payload), 200

def _upload_suffix(upload):
    """Return ``".mp3"`` for uploads named ``*.mp3``, otherwise ``".wav"``."""
    filename = (upload.filename or "").lower()
    return ".mp3" if filename.endswith(".mp3") else ".wav"


@app.route('/fuse', methods=['POST'])
def fuse_api():
    """Fuse a "melody" upload with a "style" upload into one mastered WAV.

    Expects a multipart POST with two files: ``melody`` (the traditional
    track) and ``style`` (the modern track). Both are separated into stems
    with Demucs; the melody's ``other`` and ``bass`` stems are key- and
    tempo-matched to the style track, filtered, mixed with the style's
    ``drums``, ``bass`` and ``other`` stems, mastered with Pedalboard and
    loudness-normalised to ``TARGET_LOUDNESS``.

    Returns:
        200 with the fused audio as a ``fusion_output.wav`` attachment,
        400 with ``{"error": "missing files"}`` if either upload is absent,
        500 with ``{"error": <message>}`` if any processing step raises.

    All temporary files and stem directories created for the request are
    removed before the response is returned, whether or not it succeeded.
    """
    job_id = str(uuid.uuid4())
    temp_files = []
    cleanup_dirs = []
    try:
        trad_req = request.files.get('melody')
        modern_req = request.files.get('style')
        if not trad_req or not modern_req:
            return jsonify({"error": "missing files"}), 400

        # Keep a ".mp3" suffix for MP3 uploads so convert_to_wav() actually
        # transcodes them; everything else is stored as ".wav" as before.
        t_path = f"trad_{job_id}{_upload_suffix(trad_req)}"
        m_path = f"mod_{job_id}{_upload_suffix(modern_req)}"
        # Register paths before writing so a failure part-way through still
        # gets cleaned up by the finally block.
        temp_files.extend([t_path, m_path])
        trad_req.save(t_path)
        modern_req.save(m_path)

        t_wav = convert_to_wav(t_path)
        if t_wav != t_path:
            temp_files.append(t_wav)
        m_wav = convert_to_wav(m_path)
        if m_wav != m_path:
            temp_files.append(m_wav)

        # Register each stem directory as soon as it exists so the first one
        # is not leaked if the second separation fails.
        t_stems, t_dir = separate_stems(t_wav, f"t_{job_id}")
        cleanup_dirs.append(t_dir)
        m_stems, m_dir = separate_stems(m_wav, f"m_{job_id}")
        cleanup_dirs.append(m_dir)

        t_other = load_mono(t_stems["other"])
        t_bass = load_mono(t_stems["bass"])
        m_drums = load_mono(m_stems["drums"])
        m_bass = load_mono(m_stems["bass"])
        m_other = load_mono(m_stems["other"])

        # Trim every stem to the shorter of the two tracks so they can be summed.
        target_len = min(len(t_other), len(m_drums))
        t_other = t_other[:target_len]
        t_bass = t_bass[:target_len]
        m_drums = m_drums[:target_len]
        m_bass = m_bass[:target_len]
        m_other = m_other[:target_len]

        # NOTE(review): key and tempo are estimated independently for each
        # traditional stem, so "other" and "bass" may receive different
        # shifts/stretch rates and drift apart - confirm this is intended.
        t_other = match_key(t_other, m_other)
        t_bass = match_key(t_bass, m_bass)
        t_other = beat_sync_warp(t_other, m_drums)
        t_bass = beat_sync_warp(t_bass, m_drums)

        # Band-limit each stem before summing.
        t_other = highpass(t_other, 120)
        t_bass = highpass(t_bass, 60)
        m_bass = lowpass(m_bass, 250)
        m_drums = lowpass(m_drums, 12000)
        m_other = highpass(m_other, 150)

        fusion = (1.0 * m_drums + 1.0 * m_bass + 1.2 * t_other + 0.5 * m_other + 0.8 * t_bass)
        fusion = normalize_audio(fusion)

        board = Pedalboard([HighpassFilter(30), LowpassFilter(18000), Compressor(threshold_db=-20, ratio=2), Gain(2.0), Limiter(threshold_db=-0.5)])
        fusion_mastered = board(fusion, SR)

        meter = pyln.Meter(SR)
        loudness = meter.integrated_loudness(fusion_mastered)
        # A silent mix measures as -inf/NaN; normalising against that would
        # apply an infinite gain and fill the output with inf/NaN samples.
        if np.isfinite(loudness):
            fusion_mastered = pyln.normalize.loudness(fusion_mastered, loudness, TARGET_LOUDNESS)

        buf = io.BytesIO()
        sf.write(buf, fusion_mastered, SR, format='WAV')
        buf.seek(0)
        return send_file(buf, mimetype="audio/wav", as_attachment=True, download_name="fusion_output.wav")
    except Exception as e:
        # Top-level request boundary: record the traceback server-side, then
        # report the failure to the client.
        app.logger.exception("fusion job %s failed", job_id)
        return jsonify({"error": str(e)}), 500
    finally:
        for f in temp_files:
            if os.path.exists(f): os.remove(f)
        for d in cleanup_dirs:
            if os.path.exists(d): shutil.rmtree(d, ignore_errors=True)

if __name__ == "__main__":
    # Run directly as a script: serve on every interface, port 7860.
    app.run(port=7860, host='0.0.0.0')