gere commited on
Commit
62d2c76
·
verified ·
1 Parent(s): 8d7689d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -0
app.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torchaudio
4
+ import librosa
5
+ import numpy as np
6
+ import io
7
+ import tempfile
8
+ from flask import Flask, request, send_file, jsonify
9
+ from flask_cors import CORS
10
+ from audiocraft.models import MusicGen
11
+
12
+ app = Flask(__name__)
13
+ CORS(app)
14
+
15
+ class FusionEngine:
16
+ def __init__(self):
17
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
18
+ self.model = MusicGen.get_pretrained('facebook/musicgen-small')
19
+
20
+ def process(self, melody_bytes, style_bytes):
21
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as m_file:
22
+ m_file.write(melody_bytes)
23
+ m_path = m_file.name
24
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as s_file:
25
+ s_file.write(style_bytes)
26
+ s_path = s_file.name
27
+ try:
28
+ y2, sr2 = librosa.load(s_path, duration=10)
29
+ tempo_val, _ = librosa.beat.beat_track(y=y2, sr=sr2)
30
+ tempo = float(tempo_val[0]) if isinstance(tempo_val, (np.ndarray, list)) else float(tempo_val)
31
+ spec_centroid = np.mean(librosa.feature.spectral_centroid(y=y2, sr=sr2))
32
+ vibe = "electronic" if spec_centroid > 2500 else "organic"
33
+ accurate_prompt = f"A {vibe} version, {int(tempo)} BPM, studio quality."
34
+ self.model.set_generation_params(duration=15, use_sampling=True, top_k=250, temperature=0.7)
35
+ m_wav, sr = torchaudio.load(m_path)
36
+ if m_wav.shape[0] > 1: m_wav = m_wav.mean(dim=0, keepdim=True)
37
+ if sr != 32000:
38
+ resampler = torchaudio.transforms.Resample(sr, 32000)
39
+ m_wav = resampler(m_wav)
40
+ sr = 32000
41
+ result = self.model.generate_with_chroma(descriptions=[accurate_prompt], melody_wavs=m_wav[None, ...].to(self.device), melody_sample_rate=sr)
42
+ return result[0].cpu(), self.model.sample_rate
43
+ finally:
44
+ if os.path.exists(m_path): os.remove(m_path)
45
+ if os.path.exists(s_path): os.remove(s_path)
46
+
47
+ engine = None
48
+
49
+ @app.route('/', methods=['GET'])
50
+ def health():
51
+ return jsonify({"status": "ready"}), 200
52
+
53
+ @app.route('/fuse', methods=['POST'])
54
+ def fuse():
55
+ global engine
56
+ if engine is None: engine = FusionEngine()
57
+ try:
58
+ m = request.files['melody'].read()
59
+ s = request.files['style'].read()
60
+ out_wav, sr = engine.process(m, s)
61
+ buffer = io.BytesIO()
62
+ torchaudio.save(buffer, out_wav, sr, format="wav")
63
+ buffer.seek(0)
64
+ return send_file(buffer, mimetype='audio/wav')
65
+ except Exception as e:
66
+ return jsonify({"error": str(e)}), 500
67
+
68
+ if __name__ == "__main__":
69
+ port = int(os.environ.get("PORT", 7860))
70
+ app.run(host='0.0.0.0', port=port)