File size: 1,140 Bytes
b970ee2
2c2c0fa
b970ee2
2c2c0fa
 
b970ee2
 
3dc6856
b970ee2
 
2c2c0fa
b970ee2
5b92a42
b970ee2
 
5b92a42
b970ee2
 
5b92a42
b970ee2
 
 
5b92a42
 
b970ee2
 
 
5b92a42
b970ee2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from typing import Dict, List, Any
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import torch

class EndpointHandler:
    """Serve a MusicGen text-to-audio model behind an inference endpoint.

    Loads the processor and model from the same checkpoint path and exposes
    a single ``__call__`` entry point mapping a JSON-like request dict to a
    generated waveform.
    """

    def __init__(self, model_path: str) -> None:
        """Load the processor and model from ``model_path``.

        Args:
            model_path: Local directory or Hub id of a MusicGen checkpoint.
        """
        self.processor = AutoProcessor.from_pretrained(model_path)
        # Fix: the model was previously hard-coded to
        # "facebook/musicgen-small" while the processor came from
        # ``model_path``, so any other deployed checkpoint silently served
        # the wrong weights. Load both from the same path.
        self.model = MusicgenForConditionalGeneration.from_pretrained(model_path)
        # Resolve the target device once; reused in __call__.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = self.model.to(self.device)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate audio for one request.

        Args:
            data: Request payload. Required key ``"text"`` (prompt string or
                list of prompts). Optional keys: ``"audio"`` (melody
                conditioning), ``"sampling_rate"``, ``"do_sample"``
                (default ``True``), ``"guidance_scale"`` (default ``3``),
                ``"max_new_tokens"`` (default ``256``).

        Returns:
            Dict with key ``"audio_values"`` holding the generated waveform
            as a numpy array (moved to CPU first).
        """
        inputs = self.processor(
            text=data["text"],
            audio=data.get("audio"),
            padding=True,
            sampling_rate=data.get("sampling_rate"),
            return_tensors="pt",
        )
        # Keep all model inputs on the same device as the model
        # (no-op when running on CPU).
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        audio_values = self.model.generate(
            **inputs,
            do_sample=data.get("do_sample", True),
            guidance_scale=data.get("guidance_scale", 3),
            max_new_tokens=data.get("max_new_tokens", 256),
        )
        return {"audio_values": audio_values.cpu().numpy()}