File size: 780 Bytes
8465857 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 | from audiocraft.models import MusicGen
from scipy.io.wavfile import write
import torch
import io
class EndpointHandler:
    """Inference-endpoint handler that turns a text prompt into a WAV clip.

    A pretrained MusicGen checkpoint is loaded once at construction time.
    Each call generates a single clip for the prompt found under
    ``data["inputs"]`` and returns the encoded WAV file as raw bytes.
    """

    # Checkpoint loaded by ``__init__``.
    MODEL_NAME = "facebook/musicgen-large"
    # Length of each generated clip, in seconds (passed to set_generation_params).
    DURATION_SECONDS = 8
    # Prompt used when the request carries no usable ``inputs`` value.
    DEFAULT_PROMPT = "chill lofi music"
    # Fallback WAV rate if the model object does not expose ``sample_rate``.
    DEFAULT_SAMPLE_RATE = 32000

    def __init__(self, path=""):
        """Load the model onto the GPU when one is available, else the CPU.

        Args:
            path: Accepted for handler-interface compatibility; unused here
                because the checkpoint is fetched by ``MODEL_NAME``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # MusicGen is a wrapper object, not a torch.nn.Module, so it has no
        # ``.to()``; the target device is passed to get_pretrained instead.
        self.model = MusicGen.get_pretrained(self.MODEL_NAME, device=self.device)
        self.model.set_generation_params(duration=self.DURATION_SECONDS)

    def __call__(self, data):
        """Generate audio for one prompt and return it as WAV-file bytes.

        Args:
            data: Request mapping; the prompt is read from ``data["inputs"]``.
                A missing, ``None`` or empty value falls back to
                ``DEFAULT_PROMPT``.

        Returns:
            ``{"audio": bytes}`` containing a complete float32 WAV file.
        """
        prompt = data.get("inputs") or self.DEFAULT_PROMPT
        wav = self.model.generate([prompt])
        sample_rate = int(getattr(self.model, "sample_rate", self.DEFAULT_SAMPLE_RATE))

        # Take the single batch item; detach and move to CPU before NumPy.
        samples = wav[0].detach().to("cpu", torch.float32)
        # Per the audiocraft docs, generate() yields (batch, channels, time), so
        # samples is (channels, time). scipy's wavfile.write expects (time,) or
        # (time, channels); the raw (1, time) layout would produce a one-frame
        # file with `time` channels.
        if samples.dim() == 2:
            samples = samples[0] if samples.shape[0] == 1 else samples.T.contiguous()

        buffer = io.BytesIO()
        write(buffer, sample_rate, samples.numpy())
        return {"audio": buffer.getvalue()}
|