# Source: Hugging Face repository "mg", file handler.py (tags: PyTorch, musicgen).
# Last commit: "Update handler.py" by kkhatri (8465857, verified); file size 780 bytes.
from audiocraft.models import MusicGen
from scipy.io.wavfile import write
import torch
import io
class EndpointHandler:
    """Inference-endpoint handler that turns a text prompt into WAV audio.

    The MusicGen model is loaded once at construction time; each call
    generates one clip for the prompt found under the ``"inputs"`` key of
    the request payload and returns it as the bytes of a WAV file.
    """

    def __init__(self, path=""):
        """Load the MusicGen model onto the GPU if available, else the CPU.

        Args:
            path: Accepted for compatibility with the handler interface;
                it is not used, the checkpoint name is fixed.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Select the device through ``get_pretrained`` itself rather than
        # chaining ``.to()`` on the returned wrapper: the device argument is
        # the supported way to place the model.
        self.model = MusicGen.get_pretrained(
            "facebook/musicgen-large", device=self.device
        )
        self.model.set_generation_params(duration=8)

    def __call__(self, data):
        """Generate audio for the request prompt.

        Args:
            data: Mapping with an optional ``"inputs"`` entry holding the
                text prompt. Defaults to ``"chill lofi music"`` when absent.

        Returns:
            ``{"audio": <bytes>}`` where the bytes are a complete WAV file.
        """
        prompt = data.get("inputs", "chill lofi music")
        wav = self.model.generate([prompt])  # runs on the model's device

        # Use the rate the model reports; fall back to the previously
        # hard-coded value if the attribute is missing.
        sample_rate = int(getattr(self.model, "sample_rate", 32000))

        # The first batch item is laid out as (channels, samples), while
        # scipy's WAV writer expects (samples, channels). Transpose so a mono
        # clip is not written as a single frame with thousands of channels.
        # ``float()`` guards against half-precision output, which scipy
        # cannot serialize.
        clip = wav[0].detach().to("cpu").float()
        if clip.dim() == 2:
            clip = clip.transpose(0, 1).contiguous()
        audio_array = clip.numpy()

        buffer = io.BytesIO()
        write(buffer, sample_rate, audio_array)
        return {"audio": buffer.getvalue()}