Text-to-Audio
Transformers
PyTorch
musicgen
musicgen-medium / handler.py
zsolt-klang's picture
Update handler.py
624c170
raw
history blame contribute delete
918 Bytes
from audiocraft.data.audio_utils import convert_audio
from audiocraft.data.audio import audio_write
from audiocraft.models import MusicGen
from typing import Dict, List, Any
import logging
class EndpointHandler:
    """Inference-endpoint handler that turns text prompts into MusicGen audio."""

    def __init__(self, path=""):
        """Load the pretrained MusicGen model.

        Args:
            path: Accepted for compatibility with the handler interface; it is
                not used here, the checkpoint name is hard-coded below.
        """
        self.model = MusicGen.get_pretrained("musicgen-medium")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate audio for one request payload.

        Args:
            data: Request payload. ``data["inputs"]`` holds the prompt and may
                be a single string, a sequence of strings, or a mapping of
                keyword arguments forwarded verbatim to ``model.generate``.
                If ``"inputs"`` is absent, the payload itself (minus
                ``"parameters"``) is used as the inputs. The optional
                ``data["parameters"]`` mapping is forwarded to
                ``model.set_generation_params``.

        Returns:
            A one-element list ``[{"generated_audio": ...}]`` where the value
            is the first generated item converted to nested Python lists.
        """
        # process input
        logging.info("data: %s", data)

        # Work on a shallow copy so the caller's payload is left untouched.
        payload = dict(data)
        # A missing (or explicit null) "parameters" entry means "no overrides".
        parameters = payload.pop("parameters", None) or {}
        inputs = payload.pop("inputs", payload)

        # Always re-apply generation params so settings from a previous request
        # do not carry over; with no overrides this relies on the library's
        # own defaults for set_generation_params.
        self.model.set_generation_params(**parameters)

        if isinstance(inputs, str):
            # A bare prompt: the model is given a list of descriptions.
            outputs = self.model.generate([inputs])
        elif isinstance(inputs, dict):
            # Mapping form (original behaviour): keys are generate() kwargs.
            outputs = self.model.generate(**inputs)
        else:
            # Any other iterable is treated as a batch of prompts.
            outputs = self.model.generate(list(inputs))

        # postprocess the prediction: only the first item of the batch is returned
        prediction = outputs[0].cpu().numpy().tolist()
        return [{"generated_audio": prediction}]