| from stable_audio_tools import get_pretrained_model |
| from stable_audio_tools.inference.generation import generate_diffusion_cond |
| import torch |
| import base64 |
| import io |
| import soundfile as sf |
|
|
| |
# Load the pretrained beat-generation diffusion model once at import time so
# every request reuses the same weights instead of reloading them per call.
model, cfg = get_pretrained_model("bharatverse11/BeatGeneration")
model = model.eval().to("cuda")  # eval()/to() return self; inference mode on GPU

# Output sample rate of generated audio; fall back to 44.1 kHz when the
# model config does not specify one.
SAMPLE_RATE = cfg.get("sample_rate", 44100)
|
|
def handler(data):
    """Generate an audio clip from a text prompt and return it as base64 WAV.

    Expects ``data["inputs"]`` to be a dict with optional keys:
    ``prompt`` (str, default ""), ``duration`` (seconds, default 10),
    ``steps`` (diffusion steps, default 50) and ``cfg_scale``
    (classifier-free-guidance strength, default 7).

    Returns a dict with a single ``audio`` key holding the generated WAV
    file encoded as a base64 string.
    """
    payload = data["inputs"]

    # Pull generation parameters, falling back to the documented defaults.
    text_prompt = payload.get("prompt", "")
    seconds = payload.get("duration", 10)
    num_steps = payload.get("steps", 50)
    guidance = payload.get("cfg_scale", 7)

    # Single conditioning entry spanning the full requested duration.
    conditioning = [
        {"prompt": text_prompt, "seconds_start": 0, "seconds_total": seconds}
    ]

    # Inference only: disable autograd to save memory and time.
    with torch.no_grad():
        generated = generate_diffusion_cond(
            model,
            steps=num_steps,
            cfg_scale=guidance,
            conditioning=conditioning,
            sample_size=int(seconds * SAMPLE_RATE),
            device="cuda",
        )

    # First batch item, transposed to (frames, channels) as soundfile expects.
    waveform = generated.cpu().numpy()[0].T

    # Encode the waveform as WAV entirely in memory.
    wav_buffer = io.BytesIO()
    sf.write(wav_buffer, waveform, SAMPLE_RATE, format="WAV")

    return {"audio": base64.b64encode(wav_buffer.getvalue()).decode()}