Text-to-Audio
Transformers
PyTorch
musicgen
File size: 918 Bytes
004c940
 
1d5a5e4
004c940
ae1ed29
cef0579
ae1ed29
 
 
004c940
ae1ed29
 
 
 
 
 
 
 
cef0579
ae1ed29
 
2984362
ae1ed29
1d5a5e4
ae1ed29
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from audiocraft.data.audio_utils import convert_audio
from audiocraft.data.audio import audio_write
from audiocraft.models import MusicGen

from typing import Dict, List, Any
import logging

class EndpointHandler:
    """Callable request handler that generates audio with a MusicGen model."""

    def __init__(self, path=""):
        """Load the pretrained MusicGen model.

        Args:
            path: Currently unused; the checkpoint name is hard-coded below.
        """
        self.model = MusicGen.get_pretrained("musicgen-medium")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate audio for one request payload.

        Args:
            data (:dict:):
                The payload with the text prompt and generation parameters.
                ``data["inputs"]`` may be a single prompt string, a list or
                tuple of prompts, or a dict of keyword arguments forwarded
                to ``self.model.generate``. When the ``"inputs"`` key is
                absent, the payload itself (minus ``"parameters"``) is used
                as that dict. ``data["parameters"]`` is an optional dict
                forwarded to ``self.model.set_generation_params``.
                Note: the ``"inputs"`` and ``"parameters"`` keys are popped
                from ``data``.

        Returns:
            A one-element list ``[{"generated_audio": ...}]`` holding the
            first element of the generation output converted to nested
            Python lists.

        Raises:
            TypeError: If the inputs are not a str, list, tuple or dict.
        """
        # process input
        logging.info("data: %s", data)
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", None)

        # Only touch the generation settings when the request supplies them;
        # unpacking ``None`` would raise TypeError on parameter-less requests.
        if parameters is not None:
            self.model.set_generation_params(**parameters)

        if isinstance(inputs, dict):
            outputs = self.model.generate(**inputs)
        elif isinstance(inputs, str):
            # A bare prompt string is wrapped into a one-element batch.
            outputs = self.model.generate([inputs])
        elif isinstance(inputs, (list, tuple)):
            outputs = self.model.generate(list(inputs))
        else:
            raise TypeError(
                "inputs must be a str, a list of str, or a dict of "
                f"generate() keyword arguments, got {type(inputs).__name__}"
            )

        # postprocess the prediction
        # NOTE(review): only outputs[0] is returned, so any further items
        # generated for a multi-prompt request are dropped — confirm intended.
        prediction = outputs[0].cpu().numpy().tolist()

        return [{"generated_audio": prediction}]