from typing import Any, Dict

import torch
from transformers import AutoProcessor, MusicgenForConditionalGeneration


class EndpointHandler:
    """Inference-endpoint handler for a MusicGen text-to-audio model."""

    def __init__(self, model_path: str) -> None:
        """Load the processor and model from ``model_path``.

        Moves the model to GPU when CUDA is available.
        """
        self.processor = AutoProcessor.from_pretrained(model_path)
        # BUG FIX: the model was previously hard-coded to
        # "facebook/musicgen-small", ignoring model_path entirely — the
        # processor and model could silently come from different checkpoints.
        # Load both from the same path.
        self.model = MusicgenForConditionalGeneration.from_pretrained(model_path)
        # Cache the device once so __call__ doesn't re-query CUDA per request.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = self.model.to(self.device)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate audio for one request payload.

        Expected keys in ``data``:
            text (required): prompt string (or list of strings).
            audio, sampling_rate (optional): melody-conditioning inputs.
            do_sample, guidance_scale, max_new_tokens (optional): generation
                parameters; defaults are True, 3, and 256 respectively.

        Returns:
            Dict with key "audio_values" holding a NumPy array of the
            generated waveform(s).
        """
        inputs = self.processor(
            text=data["text"],
            audio=data.get("audio"),
            padding=True,
            sampling_rate=data.get("sampling_rate"),
            return_tensors="pt",
        )
        # Move only tensor-valued entries to the device; processor output may
        # contain non-tensor metadata for some configurations.
        inputs = {
            k: v.to(self.device) if isinstance(v, torch.Tensor) else v
            for k, v in inputs.items()
        }
        # Inference only — disable autograd bookkeeping to save memory/time.
        with torch.no_grad():
            audio_values = self.model.generate(
                **inputs,
                do_sample=data.get("do_sample", True),
                guidance_scale=data.get("guidance_scale", 3),
                max_new_tokens=data.get("max_new_tokens", 256),
            )
        return {"audio_values": audio_values.cpu().numpy()}