File size: 3,428 Bytes
2e07021
d4a8611
 
2e07021
2193261
 
 
 
 
 
2e07021
d4a8611
7e79b1c
2193261
 
 
2e07021
d4a8611
2193261
b14b378
 
 
 
2e07021
d4a8611
2193261
 
 
 
b14b378
2193261
 
 
 
 
 
 
 
 
 
 
d4a8611
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b14b378
d4a8611
aa89711
d4a8611
aa89711
 
d4a8611
 
 
 
aa89711
d4a8611
1c7efde
d4a8611
 
 
b14b378
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import time
import requests
from huggingface_hub import login
import torch
import torchaudio
from einops import rearrange
import gradio as gr
from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond

# Hugging Face Hub authentication: the token is read from the environment and
# startup aborts if it is missing or empty.
token = os.environ.get("HUGGINGFACE_TOKEN")
if token is None or token == "":
    raise RuntimeError("HUGGINGFACE_TOKEN not set")
login(token=token, add_to_git_credential=False)

# Audio model setup: choose the device, load the pretrained model together
# with its config, and move the model onto that device.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
model, config = get_pretrained_model("stabilityai/stable-audio-open-small")
model = model.to(device)
# Both values are read by generate_audio further down the file.
sample_rate, sample_size = config["sample_rate"], config["sample_size"]

# Audio generation function
def generate_audio(prompt):
    """Generate an audio clip for *prompt* and save it as a 16-bit WAV file.

    Calls ``generate_diffusion_cond`` on the module-level ``model`` with a
    single conditioning entry (the prompt plus ``seconds_total=11``) for 8
    steps, peak-normalizes the result, converts it to int16 and writes it to
    ``output.wav`` in the current working directory. Each call overwrites the
    previous file.

    Args:
        prompt: Text used as the ``"prompt"`` conditioning value.

    Returns:
        The path of the written file, ``"output.wav"``.
    """
    conditioning = [{"prompt": prompt, "seconds_total": 11}]
    with torch.no_grad():
        output = generate_diffusion_cond(
            model,
            steps=8,
            conditioning=conditioning,
            sample_size=sample_size,
            device=device,
        )

    # Concatenate the batch items along the last axis, leaving a 2-D tensor.
    # NOTE(review): presumably (channels, samples) as torchaudio.save expects.
    output = rearrange(output, "b d n -> d (b n)").to(torch.float32)

    # Peak-normalize to [-1, 1]. The lower bound on the peak keeps an all-zero
    # (silent) result from turning into 0/0 = NaN before the int16 cast; for
    # any non-silent signal the result is unchanged.
    peak = output.abs().max().clamp(min=1e-8)
    pcm = output.div(peak).clamp(-1, 1).mul(32767).to(torch.int16).cpu()

    path = "output.wav"
    torchaudio.save(path, pcm, sample_rate)
    return path

# Image generation function using Replicate
# States after which the prediction will not change any more, so polling stops.
_REPLICATE_TERMINAL_STATES = frozenset({"succeeded", "failed", "canceled"})
# Per-request timeout (seconds) so a stalled connection cannot hang the app.
_REPLICATE_REQUEST_TIMEOUT_S = 30
# Pause between two status polls (seconds).
_REPLICATE_POLL_INTERVAL_S = 1.5
# Upper bound (seconds) on how long one prediction is polled before giving up.
_REPLICATE_MAX_WAIT_S = 600


def generate_image(prompt):
    """Generate an image for *prompt* via the Replicate API and save it.

    Creates a prediction, polls its status URL until it reaches a terminal
    state, downloads the resulting image and writes it to ``output.png`` in
    the current working directory. Each call overwrites the previous file.

    Args:
        prompt: Text sent as the model's ``"prompt"`` input.

    Returns:
        The path of the written file, ``"output.png"``.

    Raises:
        RuntimeError: If ``REPLICATE_API_TOKEN`` is unset or empty, the
            prediction does not succeed, polling exceeds the maximum wait,
            or a succeeded prediction carries no output.
        requests.RequestException: If an HTTP call fails, times out, or
            returns an error status.
    """
    replicate_token = os.getenv("REPLICATE_API_TOKEN")
    if not replicate_token:
        raise RuntimeError("REPLICATE_API_TOKEN not set")

    url = "https://api.replicate.com/v1/predictions"
    headers = {
        "Authorization": f"Token {replicate_token}",
        "Content-Type": "application/json"
    }
    data = {
        "version": "5ee6b41748a4e3e3d3a212ed4a29379d6a13b9265fd00fe59e28c2767a5e82eb",
        "input": {
            "prompt": prompt,
            "style": "surreal"
        }
    }
    response = requests.post(
        url, headers=headers, json=data, timeout=_REPLICATE_REQUEST_TIMEOUT_S
    )
    response.raise_for_status()
    prediction = response.json()

    get_url = prediction["urls"]["get"]
    deadline = time.monotonic() + _REPLICATE_MAX_WAIT_S

    # Poll until a terminal state is reached. "canceled" is included so a
    # canceled prediction ends the loop instead of being polled forever.
    while prediction["status"] not in _REPLICATE_TERMINAL_STATES:
        if time.monotonic() >= deadline:
            raise RuntimeError(
                f"Image generation timed out after {_REPLICATE_MAX_WAIT_S}s: "
                f"{prediction}"
            )
        time.sleep(_REPLICATE_POLL_INTERVAL_S)
        resp = requests.get(
            get_url, headers=headers, timeout=_REPLICATE_REQUEST_TIMEOUT_S
        )
        # Surface HTTP errors here rather than as a KeyError on "status".
        resp.raise_for_status()
        prediction = resp.json()

    if prediction["status"] != "succeeded":
        raise RuntimeError(f"Image generation failed: {prediction}")

    # The output may be a single URL or a list of URLs; use the first one.
    image_url = prediction.get("output")
    if isinstance(image_url, (list, tuple)):
        image_url = image_url[0] if image_url else None
    if not image_url:
        raise RuntimeError(f"Image generation returned no output: {prediction}")

    image_path = "output.png"
    image_response = requests.get(image_url, timeout=_REPLICATE_REQUEST_TIMEOUT_S)
    # Avoid writing an HTTP error body to disk as if it were the image.
    image_response.raise_for_status()
    with open(image_path, "wb") as f:
        f.write(image_response.content)

    return image_path

# Combined generation function
def generate_assets(prompt):
    """Produce both assets for *prompt*.

    The audio is generated first, then the image; either call may raise.

    Returns:
        A ``(audio_path, image_path)`` tuple of file paths.
    """
    # Tuple elements are evaluated left to right: audio before image.
    return generate_audio(prompt), generate_image(prompt)

# Gradio UI: one text prompt in, the generated audio and image out.
# The interface is bound to a name before launching so it can be referenced
# after construction; the launch itself still happens at import time.
demo = gr.Interface(
    fn=generate_assets,
    inputs=gr.Textbox(
        label="🎀 Prompt your sonic + visual art",
        placeholder="e.g. 'drunk driving with mario and yung lean'"
    ),
    # Order matches the (audio_path, image_path) tuple from generate_assets.
    outputs=[
        gr.Audio(type="filepath", label="🧠 Generated Audio"),
        gr.Image(type="filepath", label="🎨 Generated Image")
    ],
    title='🌐 Hot Prompts in Your Area: "My Husband Is Dead"',
    # The em dash here was previously mis-encoded as the three-character
    # sequence U+03B2 U+20AC U+201D, which showed up verbatim in the UI.
    description="Enter a fun sound idea — generate audio *and* visual from one prompt.",
    examples=[
        "ghosts peeing",
        "Tech startup boss villain entrance music",
        "Dolphin hootin'"
    ]
)
demo.launch()