# File size: 1,152 Bytes
# 5c04df5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond
import torch
import base64
import io
import soundfile as sf

# The model is loaded a single time at import so every call to the handler
# reuses the same instance instead of reloading the weights per request.
model, cfg = get_pretrained_model("bharatverse11/BeatGeneration")
model.eval()  # switch to inference mode
model.to("cuda")  # move the model to the GPU

# Sample rate comes from the model config when it provides one; 44100 is the
# fallback. It is used both to size the generation and to write the WAV.
SAMPLE_RATE = cfg.get("sample_rate", 44100)

def handler(data):
    """Generate audio for a text prompt and return it as a base64 WAV.

    Args:
        data: Request payload. ``data["inputs"]`` is either a mapping with
            the optional keys ``"prompt"`` (str, default ``""``),
            ``"duration"`` (seconds, default 10), ``"steps"`` (default 50)
            and ``"cfg_scale"`` (default 7), or a bare string, which is
            used as the prompt with every other setting left at its default.

    Returns:
        A dict ``{"audio": <str>}`` holding the base64 text of a WAV file
        written at ``SAMPLE_RATE``.

    Raises:
        KeyError: If ``data`` has no ``"inputs"`` entry.
        ValueError: If ``inputs`` is neither a string nor a mapping, if
            ``prompt`` is not a string, if a numeric field cannot be parsed,
            if ``duration`` is not a positive finite number (or is too short
            to yield a single sample), or if ``steps`` is below 1.
    """
    inputs = data["inputs"]
    if isinstance(inputs, str):
        # Accept a plain-string payload as shorthand for {"prompt": ...}.
        inputs = {"prompt": inputs}
    elif not hasattr(inputs, "get"):
        raise ValueError("'inputs' must be a string or a mapping of options")

    prompt = inputs.get("prompt", "")
    if not isinstance(prompt, str):
        raise ValueError("'prompt' must be a string")

    # Request fields are untrusted: coerce them up front so that e.g. a
    # string duration cannot turn `duration * SAMPLE_RATE` into a string
    # repetition, and reject bad values before any GPU work starts.
    try:
        duration = float(inputs.get("duration", 10))
        steps = int(inputs.get("steps", 50))
        cfg_scale = float(inputs.get("cfg_scale", 7))
    except (TypeError, ValueError, OverflowError) as err:
        raise ValueError(
            "'duration', 'steps' and 'cfg_scale' must be numeric"
        ) from err

    # The chained comparison is False for NaN as well as for <= 0 and inf.
    if not 0 < duration < float("inf"):
        raise ValueError("'duration' must be a positive, finite number of seconds")
    if steps < 1:
        raise ValueError("'steps' must be at least 1")

    sample_size = int(duration * SAMPLE_RATE)
    if sample_size < 1:
        raise ValueError("'duration' is too short to produce any audio samples")

    conditioning = [{
        "prompt": prompt,
        "seconds_start": 0,
        "seconds_total": duration,
    }]

    with torch.no_grad():
        output = generate_diffusion_cond(
            model,
            steps=steps,
            cfg_scale=cfg_scale,
            conditioning=conditioning,
            sample_size=sample_size,
            device="cuda",
        )

    # Take the first item of the result and transpose it before writing.
    # NOTE(review): assumes output is (batch, channels, samples), so the
    # transpose yields (samples, channels) for soundfile -- confirm.
    audio = output.cpu().numpy()[0].T

    # Encode the WAV entirely in memory; the buffer is released on exit.
    with io.BytesIO() as buffer:
        sf.write(buffer, audio, SAMPLE_RATE, format="WAV")
        encoded = base64.b64encode(buffer.getvalue()).decode("ascii")

    return {
        "audio": encoded
    }