from typing import Any, Dict

import torch
from transformers import AutoProcessor, MusicgenForConditionalGeneration


class EndpointHandler:
    def __init__(self, model_path: str):
        # Load the processor and model from the same checkpoint so the
        # tokenizer/feature extractor always match the deployed weights.
        self.processor = AutoProcessor.from_pretrained(model_path)
        self.model = MusicgenForConditionalGeneration.from_pretrained(model_path)
        if torch.cuda.is_available():
            self.model = self.model.to("cuda")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # Tokenize the text prompt (and optional melody audio) into model inputs.
        inputs = self.processor(
            text=data["text"],
            audio=data.get("audio"),
            sampling_rate=data.get("sampling_rate"),
            padding=True,
            return_tensors="pt",
        )
        # Move the input tensors to the same device as the model.
        if torch.cuda.is_available():
            inputs = {k: v.to("cuda") for k, v in inputs.items()}
        # Sample audio tokens; MusicGen generates ~50 tokens per second,
        # so the default of 256 new tokens yields roughly 5 seconds of audio.
        audio_values = self.model.generate(
            **inputs,
            do_sample=data.get("do_sample", True),
            guidance_scale=data.get("guidance_scale", 3),
            max_new_tokens=data.get("max_new_tokens", 256),
        )
        # Return the generated waveform as a NumPy array of shape
        # (batch, channels, samples).
        return {"audio_values": audio_values.cpu().numpy()}