| from audiocraft.models import MusicGen |
| from scipy.io.wavfile import write |
| import torch |
| import io |
|
|
class EndpointHandler:
    """Text-to-music inference handler backed by a pretrained MusicGen model.

    The model is loaded once at construction time; each call turns a text
    prompt into a clip and returns it as the bytes of a WAV file.
    """

    MODEL_NAME = "facebook/musicgen-large"
    DEFAULT_PROMPT = "chill lofi music"
    CLIP_SECONDS = 8

    def __init__(self, path: str = ""):
        """Load the pretrained model on the GPU when available, else on CPU.

        Args:
            path: Accepted for interface compatibility; not used here.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # The MusicGen wrapper is not an nn.Module and exposes no ``.to()``;
        # the target device has to be handed to ``get_pretrained`` directly.
        self.model = MusicGen.get_pretrained(self.MODEL_NAME, device=self.device)
        self.model.set_generation_params(duration=self.CLIP_SECONDS)

    def __call__(self, data: dict) -> dict:
        """Generate one clip for the prompt stored under ``data["inputs"]``.

        Args:
            data: Request payload. The ``"inputs"`` key holds the text prompt;
                when the key is missing a default lo-fi prompt is used.

        Returns:
            A dict whose ``"audio"`` entry is the WAV-encoded clip as bytes.
        """
        prompt = data.get("inputs", self.DEFAULT_PROMPT)
        wav = self.model.generate([prompt])

        # Only one prompt was submitted, so only the first batch item matters.
        # It is laid out as (channels, samples), whereas scipy's WAV writer
        # wants (samples, channels) -- transpose, otherwise every sample would
        # be written as a separate channel.
        clip = wav[0].detach().cpu()
        if clip.dim() == 2:
            clip = clip.transpose(0, 1).contiguous()
        samples = clip.numpy()

        # Ask the model for its output rate instead of hard-coding it, so the
        # WAV header stays correct if a different checkpoint is loaded.
        sample_rate = int(self.model.sample_rate)

        buffer = io.BytesIO()
        write(buffer, sample_rate, samples)
        return {"audio": buffer.getvalue()}
|
|