# musicgen-custom / handler.py
# Last change: "Update handler.py" by yuthrb (commit 3dc6856)
from typing import Dict, List, Any
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import torch
class EndpointHandler:
    """Custom inference handler that generates audio with MusicGen.

    The processor and model are loaded once at construction time. Each call
    turns a request payload into generated audio.
    """

    def __init__(self, model_path, model_id: str = "facebook/musicgen-small"):
        """Load the processor and model, placing the model on GPU when available.

        Args:
            model_path: Path or hub id the processor is loaded from.
            model_id: Checkpoint the model weights are loaded from. Defaults to
                ``"facebook/musicgen-small"`` (the previously hard-coded value),
                so existing deployments behave exactly as before.
        """
        # NOTE(review): the processor comes from ``model_path`` but the weights
        # come from ``model_id``. If ``model_path`` contains fine-tuned weights,
        # pass ``model_id=model_path`` so processor and model stay in sync.
        self.processor = AutoProcessor.from_pretrained(model_path)
        self.model = MusicgenForConditionalGeneration.from_pretrained(model_id)
        # Decide the device once so the model and every request's tensors agree.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = self.model.to(self.device)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate audio for a single request payload.

        Args:
            data: Request payload.
                - ``"text"`` (required) is passed to the processor as the prompt.
                - ``"audio"`` and ``"sampling_rate"`` are optional and forwarded
                  to the processor.
                - ``"do_sample"`` (default ``True``), ``"guidance_scale"``
                  (default ``3``) and ``"max_new_tokens"`` (default ``256``)
                  are forwarded to ``generate``.

        Returns:
            ``{"audio_values": <numpy array>}`` with the generated output moved
            to CPU.

        Raises:
            KeyError: if ``"text"`` is missing from ``data``.
        """
        inputs = self.processor(
            text=data["text"],
            audio=data.get("audio"),
            padding=True,
            sampling_rate=data.get("sampling_rate"),
            return_tensors="pt",
        )
        # Move every processor output to the device the model lives on.
        inputs = {name: tensor.to(self.device) for name, tensor in inputs.items()}
        audio_values = self.model.generate(
            **inputs,
            do_sample=data.get("do_sample", True),
            guidance_scale=data.get("guidance_scale", 3),
            max_new_tokens=data.get("max_new_tokens", 256),
        )
        return {"audio_values": audio_values.cpu().numpy()}