File size: 918 Bytes
004c940 1d5a5e4 004c940 ae1ed29 cef0579 ae1ed29 004c940 ae1ed29 cef0579 ae1ed29 2984362 ae1ed29 1d5a5e4 ae1ed29 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
from audiocraft.data.audio_utils import convert_audio
from audiocraft.data.audio import audio_write
from audiocraft.models import MusicGen
from typing import Dict, List, Any
import logging
class EndpointHandler:
    """Inference Endpoints handler that generates audio with MusicGen."""

    def __init__(self, path=""):
        """Load the pretrained MusicGen model.

        Args:
            path: Path supplied by the serving runtime. It is currently unused;
                the checkpoint is resolved by name below.
        """
        # NOTE(review): `path` is ignored and the checkpoint is looked up by the
        # name "musicgen-medium" -- confirm this name resolves with the
        # installed audiocraft version (some versions expect "medium" or
        # "facebook/musicgen-medium").
        self.model = MusicGen.get_pretrained("musicgen-medium")

    @staticmethod
    def _split_payload(data: Dict[str, Any]):
        """Pop ``inputs`` and ``parameters`` out of a request payload.

        If ``"inputs"`` is absent, the payload itself is used as the inputs
        (with ``"parameters"`` removed from it, since it is popped afterwards).
        A missing or ``None`` ``"parameters"`` yields an empty dict, so callers
        can always unpack it with ``**``.

        Returns:
            A ``(inputs, parameters)`` pair.
        """
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)
        return inputs, parameters or {}

    @staticmethod
    def _as_prompt_list(inputs: Any) -> List[Any]:
        """Normalize a prompt given as a string or a sequence into a list.

        Raises:
            TypeError: If ``inputs`` is neither a string, a list, nor a tuple.
        """
        if isinstance(inputs, str):
            return [inputs]
        if isinstance(inputs, (list, tuple)):
            return list(inputs)
        raise TypeError(
            f"'inputs' must be a string, a list of strings, or a dict of "
            f"generate() arguments, got {type(inputs).__name__}"
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate audio for the text prompt(s) in ``data``.

        Args:
            data (:dict:):
                The payload with the text prompt and generation parameters.
                ``data["inputs"]`` may be a prompt string, a list of prompt
                strings, or a dict of keyword arguments for
                ``MusicGen.generate``. ``data["parameters"]`` (optional) is
                forwarded to ``MusicGen.set_generation_params``.

        Returns:
            A single-element list ``[{"generated_audio": samples}]``, where
            ``samples`` is the first generated output converted to nested
            Python lists. Only the first generated sample is returned.

        Raises:
            TypeError: If ``data["inputs"]`` has an unsupported type.
        """
        # Lazy %-formatting: the payload is only rendered if INFO is enabled.
        logging.info("data: %s", data)
        inputs, parameters = self._split_payload(data)
        self.model.set_generation_params(**parameters)
        if isinstance(inputs, dict):
            # Explicit keyword arguments for generate() (original behaviour).
            outputs = self.model.generate(**inputs)
        else:
            # Plain prompt(s): generate() takes the list of descriptions.
            outputs = self.model.generate(self._as_prompt_list(inputs))
        # postprocess the prediction
        # NOTE(review): `outputs` is presumably a (batch, channels, samples)
        # tensor -- confirm; only the first batch item is serialized.
        prediction = outputs[0].cpu().numpy().tolist()
        return [{"generated_audio": prediction}]
|