"""Flask service that regenerates an uploaded melody with MusicGen.

A second uploaded clip (the "style" clip) is analysed for tempo and spectral
centroid; those two numbers are turned into a text prompt that conditions the
generation together with the melody audio.
"""

import io
import os
import tempfile
import wave

import librosa
import numpy as np
import torch
import torchaudio
from audiocraft.models import MusicGen
from flask import Flask, jsonify, request, send_file
from flask_cors import CORS

app = Flask(__name__)
CORS(app)

# Sample rate the melody is resampled to before it is handed to the model.
_MELODY_SAMPLE_RATE = 32000
# Mean spectral centroid (Hz) above which the style clip is labelled "electronic".
_ELECTRONIC_CENTROID_HZ = 2500
# Largest magnitude used when scaling float samples to signed 16-bit PCM.
_PCM16_SCALE = 32767


def _build_prompt(tempo, spectral_centroid):
    """Return the text prompt describing the style clip.

    Args:
        tempo: Estimated tempo in BPM; truncated to an integer in the prompt.
        spectral_centroid: Mean spectral centroid of the style clip.
    """
    vibe = "electronic" if spectral_centroid > _ELECTRONIC_CENTROID_HZ else "organic"
    return f"A {vibe} version, {int(tempo)} BPM, studio quality."


def _to_pcm16(samples):
    """Convert float samples to int16 PCM, saturating instead of wrapping.

    Values outside [-1, 1] are clipped first; without this, the int16 cast
    wraps around and turns loud peaks into full-scale glitches. NaNs are
    replaced by zero so the cast is well defined.
    """
    cleaned = np.nan_to_num(np.asarray(samples, dtype=np.float32))
    return (np.clip(cleaned, -1.0, 1.0) * _PCM16_SCALE).astype(np.int16)


def _encode_wav(pcm, sample_rate):
    """Return an in-memory mono 16-bit WAV file holding ``pcm``.

    The stdlib ``wave`` module is used so no FFmpeg-backed encoder is involved.
    """
    buffer = io.BytesIO()
    with wave.open(buffer, 'wb') as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)
        wf.setframerate(sample_rate)
        wf.writeframes(pcm.tobytes())
    buffer.seek(0)
    return buffer


class FusionEngine:
    """Wraps a MusicGen melody model and the style-analysis step."""

    def __init__(self):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Using melody model as required for melody conditioning.
        # The device is passed explicitly so the model lives on the same
        # device the melody tensor is moved to in process().
        self.model = MusicGen.get_pretrained(
            'facebook/musicgen-melody', device=self.device
        )

    def process(self, melody_bytes, style_bytes):
        """Generate audio from a melody clip, guided by a style clip.

        Args:
            melody_bytes: Raw bytes of the melody audio file.
            style_bytes: Raw bytes of the style audio file.

        Returns:
            Tuple ``(audio, sample_rate)`` where ``audio`` is the first item
            of the generated batch as a NumPy array and ``sample_rate`` is
            ``self.model.sample_rate``.
        """
        # A temporary directory removes both files on exit, including when
        # writing the second file or any later step raises.
        with tempfile.TemporaryDirectory() as workdir:
            melody_path = os.path.join(workdir, "melody.wav")
            style_path = os.path.join(workdir, "style.wav")
            with open(melody_path, "wb") as melody_file:
                melody_file.write(melody_bytes)
            with open(style_path, "wb") as style_file:
                style_file.write(style_bytes)

            # Only the first 10 seconds of the style clip are analysed.
            style_wav, style_sr = librosa.load(style_path, duration=10)
            tempo_val, _ = librosa.beat.beat_track(y=style_wav, sr=style_sr)
            # beat_track may hand back a scalar, a 0-d array or a 1-d array
            # depending on the librosa version; atleast_1d covers all three.
            tempo = float(np.atleast_1d(tempo_val)[0])
            spec_centroid = float(
                np.mean(librosa.feature.spectral_centroid(y=style_wav, sr=style_sr))
            )
            accurate_prompt = _build_prompt(tempo, spec_centroid)

            self.model.set_generation_params(
                duration=15, use_sampling=True, top_k=250, temperature=0.7
            )

            m_wav, sr = torchaudio.load(melody_path)
            if m_wav.shape[0] > 1:
                # Down-mix multi-channel melodies to mono.
                m_wav = m_wav.mean(dim=0, keepdim=True)
            if sr != _MELODY_SAMPLE_RATE:
                resampler = torchaudio.transforms.Resample(sr, _MELODY_SAMPLE_RATE)
                m_wav = resampler(m_wav)
                sr = _MELODY_SAMPLE_RATE

            result = self.model.generate_with_chroma(
                descriptions=[accurate_prompt],
                melody_wavs=m_wav[None, ...].to(self.device),
                melody_sample_rate=sr,
            )
            return result[0].detach().cpu().numpy(), self.model.sample_rate


# Created on the first /fuse request so importing the module stays cheap.
engine = None


@app.route('/', methods=['GET'])
def health():
    """Liveness probe; always answers ``{"status": "ready"}`` with HTTP 200."""
    return jsonify({"status": "ready"}), 200


@app.route('/fuse', methods=['POST'])
def fuse():
    """Fuse the uploaded ``melody`` and ``style`` files and return a WAV.

    Returns:
        The generated audio as a ``fused.wav`` attachment on success, a JSON
        error with HTTP 400 when an upload is missing or empty, or a JSON
        error with HTTP 500 when processing fails.
    """
    global engine

    melody_upload = request.files.get('melody')
    style_upload = request.files.get('style')
    if melody_upload is None or style_upload is None:
        # A missing field is a client error, not a server failure.
        return jsonify({"error": "Both 'melody' and 'style' files are required."}), 400

    m = melody_upload.read()
    s = style_upload.read()
    if not m or not s:
        return jsonify({"error": "Uploaded 'melody' and 'style' files must not be empty."}), 400

    try:
        # Engine construction is inside the try so a model-load failure is
        # reported as JSON like every other processing error.
        if engine is None:
            engine = FusionEngine()
        out_wav, sr = engine.process(m, s)

        # Manually construct the WAV file to bypass FFmpeg AVFormatContext errors.
        # Only the first row of out_wav is written, as a single mono channel.
        buffer = _encode_wav(_to_pcm16(out_wav[0]), sr)
        return send_file(
            buffer,
            mimetype='audio/wav',
            as_attachment=True,
            download_name="fused.wav",
        )
    except Exception as e:
        # Top-level request boundary: log the traceback, report the message.
        app.logger.exception("Fusion request failed")
        return jsonify({"error": str(e)}), 500


if __name__ == "__main__":
    app.run(host='0.0.0.0', port=7860)