# fusionmodel / app.py
# gere — "Update app.py" (commit bdee77c, verified)
import os
import torch
import torchaudio
import librosa
import numpy as np
import io
import tempfile
import wave
from flask import Flask, request, send_file, jsonify
from flask_cors import CORS
from audiocraft.models import MusicGen
app = Flask(__name__)
CORS(app)
class FusionEngine:
    """Fuses two audio clips: keeps the melody of one, borrows the 'style' of the other.

    The style clip is analyzed (tempo + spectral brightness) to build a text
    prompt; the melody clip conditions MusicGen's chroma-based generation.
    """

    def __init__(self):
        # Prefer GPU when available. Fix: pass the device to get_pretrained —
        # previously self.device was computed but never used, so the model was
        # not guaranteed to be placed on the selected device.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # The melody checkpoint is required for chroma (melody) conditioning.
        self.model = MusicGen.get_pretrained('facebook/musicgen-melody', device=self.device)

    def process(self, melody_bytes, style_bytes):
        """Generate a 15 s fused track.

        Args:
            melody_bytes: raw WAV bytes providing the melody to condition on.
            style_bytes: raw WAV bytes whose tempo/brightness shape the prompt.

        Returns:
            (audio, sample_rate): audio is a numpy array shaped
            (channels, samples); sample_rate is the model's output rate.
        """
        # librosa/torchaudio need file paths, so spill the uploads to temp files.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as m_file:
            m_file.write(melody_bytes)
            m_path = m_file.name
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as s_file:
            s_file.write(style_bytes)
            s_path = s_file.name
        try:
            # Analyze only the first 10 s of the style clip — enough for tempo.
            y2, sr2 = librosa.load(s_path, duration=10)
            tempo_val, _ = librosa.beat.beat_track(y=y2, sr=sr2)
            # librosa >= 0.10 returns tempo as an ndarray; older versions a scalar.
            tempo = float(tempo_val[0]) if isinstance(tempo_val, (np.ndarray, list)) else float(tempo_val)
            spec_centroid = np.mean(librosa.feature.spectral_centroid(y=y2, sr=sr2))
            # Heuristic: a bright spectrum reads as "electronic", else "organic".
            vibe = "electronic" if spec_centroid > 2500 else "organic"
            accurate_prompt = f"A {vibe} version, {int(tempo)} BPM, studio quality."
            self.model.set_generation_params(duration=15, use_sampling=True, top_k=250, temperature=0.7)
            m_wav, sr = torchaudio.load(m_path)
            # Down-mix to mono; melody conditioning expects a single channel.
            if m_wav.shape[0] > 1:
                m_wav = m_wav.mean(dim=0, keepdim=True)
            # MusicGen's melody conditioner expects 32 kHz input.
            if sr != 32000:
                resampler = torchaudio.transforms.Resample(sr, 32000)
                m_wav = resampler(m_wav)
                sr = 32000
            result = self.model.generate_with_chroma(
                descriptions=[accurate_prompt],
                melody_wavs=m_wav[None, ...].to(self.device),  # add batch dim
                melody_sample_rate=sr,
            )
            # result is (batch, channels, samples); return the single item on CPU.
            return result[0].cpu().numpy(), self.model.sample_rate
        finally:
            # Always clean up the temp files, even if analysis/generation failed.
            if os.path.exists(m_path):
                os.remove(m_path)
            if os.path.exists(s_path):
                os.remove(s_path)
engine = None
@app.route('/', methods=['GET'])
def health():
    """Liveness probe: confirm the service is up without touching the model."""
    payload = {"status": "ready"}
    return jsonify(payload), 200
@app.route('/fuse', methods=['POST'])
def fuse():
    """POST /fuse: accept 'melody' and 'style' WAV uploads, return the fused WAV.

    Returns 400 when either upload is missing, 500 on any processing failure.
    """
    global engine
    # Lazy init: defer the slow model load to the first request so the
    # container starts (and the health check passes) quickly.
    if engine is None:
        engine = FusionEngine()
    try:
        # Validate uploads explicitly — a missing field is a client error (400),
        # not a server failure (previously it fell through to the 500 handler).
        if 'melody' not in request.files or 'style' not in request.files:
            return jsonify({"error": "Both 'melody' and 'style' files are required."}), 400
        m = request.files['melody'].read()
        s = request.files['style'].read()
        out_wav, sr = engine.process(m, s)
        # Manually construct the WAV file to bypass FFmpeg AVFormatContext errors
        buffer = io.BytesIO()
        # out_wav is (channels, samples); take channel 0. Fix: clamp to [-1, 1]
        # before int16 conversion — generated audio can overshoot, and an
        # unclipped cast wraps around, producing loud clicks.
        audio_data = (np.clip(out_wav[0], -1.0, 1.0) * 32767).astype(np.int16)
        with wave.open(buffer, 'wb') as wf:
            wf.setnchannels(1)   # mono output
            wf.setsampwidth(2)   # 16-bit PCM
            wf.setframerate(sr)
            wf.writeframes(audio_data.tobytes())
        buffer.seek(0)
        return send_file(buffer, mimetype='audio/wav', as_attachment=True, download_name="fused.wav")
    except Exception as e:
        # Top-level request boundary: surface the failure to the client as JSON.
        return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
    # Entry point for direct execution; 7860 is the conventional HF Spaces port.
    server_port = 7860
    app.run(host='0.0.0.0', port=server_port)