Spaces:
Running
Running
| import os | |
| import torch | |
| import torchaudio | |
| import librosa | |
| import numpy as np | |
| import io | |
| import tempfile | |
| import wave | |
| from flask import Flask, request, send_file, jsonify | |
| from flask_cors import CORS | |
| from audiocraft.models import MusicGen | |
# Flask application with CORS enabled so browser front-ends served from
# other origins can call the API endpoints directly.
app = Flask(__name__)
CORS(app)
class FusionEngine:
    """Fuses two audio clips with MusicGen: the first supplies the melody
    (chroma conditioning), the second supplies style cues (tempo + timbre)
    that are distilled into a text prompt."""

    def __init__(self):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Melody-capable checkpoint is required for chroma conditioning.
        # Pass the device explicitly so the model and the conditioning
        # tensors we .to(self.device) later are guaranteed to be co-located.
        self.model = MusicGen.get_pretrained('facebook/musicgen-melody', device=self.device)

    def process(self, melody_bytes, style_bytes):
        """Generate a 15-second fusion track.

        melody_bytes / style_bytes: raw WAV file contents (bytes).
        Returns (audio, sample_rate): audio is the (channels, samples)
        float numpy array produced by MusicGen.
        """
        m_path = s_path = None
        try:
            # delete=False because librosa/torchaudio reopen the files by
            # path; we own cleanup in the finally block below. Creating both
            # inside the try ensures neither temp file leaks if the second
            # write fails.
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as m_file:
                m_file.write(melody_bytes)
                m_path = m_file.name
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as s_file:
                s_file.write(style_bytes)
                s_path = s_file.name

            # --- Analyse the style clip (first 10 s suffice for tempo/timbre) ---
            y2, sr2 = librosa.load(s_path, duration=10)
            tempo_val, _ = librosa.beat.beat_track(y=y2, sr=sr2)
            # librosa >= 0.10 returns tempo as an ndarray; older versions a scalar.
            tempo = float(tempo_val[0]) if isinstance(tempo_val, (np.ndarray, list)) else float(tempo_val)
            spec_centroid = np.mean(librosa.feature.spectral_centroid(y=y2, sr=sr2))
            # Bright spectra (high centroid) read as "electronic"; the 2500 Hz
            # threshold is a heuristic.
            vibe = "electronic" if spec_centroid > 2500 else "organic"
            accurate_prompt = f"A {vibe} version, {int(tempo)} BPM, studio quality."

            self.model.set_generation_params(duration=15, use_sampling=True, top_k=250, temperature=0.7)

            # --- Prepare the melody conditioning signal ---
            m_wav, sr = torchaudio.load(m_path)
            if m_wav.shape[0] > 1:  # downmix stereo -> mono
                m_wav = m_wav.mean(dim=0, keepdim=True)
            if sr != 32000:  # MusicGen conditioning expects 32 kHz
                m_wav = torchaudio.transforms.Resample(sr, 32000)(m_wav)
                sr = 32000

            result = self.model.generate_with_chroma(
                descriptions=[accurate_prompt],
                melody_wavs=m_wav[None, ...].to(self.device),  # add batch dim -> (1, 1, T)
                melody_sample_rate=sr,
            )
            return result[0].cpu().numpy(), self.model.sample_rate
        finally:
            # Guard each path: an early failure can leave one or both unset.
            for path in (m_path, s_path):
                if path and os.path.exists(path):
                    os.remove(path)
| engine = None | |
@app.route('/health', methods=['GET'])
def health():
    """Liveness probe: report that the server process is up.

    Note: this does not verify the model is loaded — the engine is
    lazy-loaded by the fusion endpoint.
    """
    # Without the route decorator this handler was never registered with
    # Flask and the endpoint was unreachable.
    return jsonify({"status": "ready"}), 200
@app.route('/fuse', methods=['POST'])
def fuse():
    """Accept two uploaded WAVs ('melody' and 'style'), run the fusion
    engine, and stream back the result as a mono 16-bit WAV attachment.

    Returns 400 if either upload is missing, 500 on any processing error.
    """
    global engine
    # Lazy-load the model on first request so server startup stays fast.
    if engine is None:
        engine = FusionEngine()
    try:
        # Validate uploads explicitly: a missing file is a client error (400),
        # not a server failure (previously surfaced as a 500 via KeyError).
        if 'melody' not in request.files or 'style' not in request.files:
            return jsonify({"error": "Both 'melody' and 'style' files are required."}), 400
        m = request.files['melody'].read()
        s = request.files['style'].read()
        out_wav, sr = engine.process(m, s)
        # Manually construct the WAV file to bypass FFmpeg AVFormatContext errors
        buffer = io.BytesIO()
        # MusicGen outputs (channels, samples); take channel 0 as mono.
        # Clip to [-1, 1] before scaling: generated samples can exceed full
        # scale, and out-of-range values wrap around in int16 (loud clicks).
        audio_data = (np.clip(out_wav[0], -1.0, 1.0) * 32767).astype(np.int16)
        with wave.open(buffer, 'wb') as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)   # 16-bit PCM
            wf.setframerate(sr)
            wf.writeframes(audio_data.tobytes())
        buffer.seek(0)
        return send_file(buffer, mimetype='audio/wav', as_attachment=True, download_name="fused.wav")
    except Exception as e:
        # Broad catch at the route boundary is deliberate: the audio stack
        # raises many distinct exception types; report them to the client.
        return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
    # Bind on all interfaces; 7860 is the conventional Hugging Face Spaces /
    # Gradio port, which the hosting environment expects the app to expose.
    app.run(host='0.0.0.0', port=7860)