testaudio2 / app.py
nock2's picture
Update app.py
d4a8611 verified
raw
history blame
3.43 kB
import os
import time
import requests
from huggingface_hub import login
import torch
import torchaudio
from einops import rearrange
import gradio as gr
from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond
# Log in to the Hugging Face Hub with the token supplied via the environment.
# A missing or empty token aborts start-up immediately.
token = os.environ.get("HUGGINGFACE_TOKEN")
if not token:
    raise RuntimeError("HUGGINGFACE_TOKEN not set")
login(token=token, add_to_git_credential=False)

# Pick the compute device: CUDA when torch reports it available, else CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the pretrained audio model together with its config and move the model
# onto the selected device.
model, config = get_pretrained_model("stabilityai/stable-audio-open-small")
model = model.to(device)

# Values read from the model config and reused by generate_audio().
sample_rate, sample_size = config["sample_rate"], config["sample_size"]
# Audio generation function
def generate_audio(prompt):
    """Generate an audio clip for *prompt* and save it as a WAV file.

    Runs the module-level diffusion model for 8 sampling steps with a
    conditioning entry of ``seconds_total=11``, peak-normalises the result,
    converts it to 16-bit integer samples and writes it to ``output.wav``
    in the current working directory.

    Args:
        prompt: Text used as the ``"prompt"`` conditioning value.

    Returns:
        The path of the written file (always ``"output.wav"``).
    """
    conditioning = [{"prompt": prompt, "seconds_total": 11}]
    with torch.no_grad():
        output = generate_diffusion_cond(
            model,
            steps=8,
            conditioning=conditioning,
            sample_size=sample_size,
            device=device,
        )

    # Fold the batch axis into the last axis so the whole batch becomes one
    # continuous sequence ("b d n -> d (b n)").
    output = rearrange(output, "b d n -> d (b n)").to(torch.float32)

    # Peak-normalise. A completely silent result has a peak of 0; dividing by
    # it would yield NaNs whose int16 conversion is meaningless, so the peak
    # is floored at a tiny positive value (silence then stays silence).
    peak = output.abs().max().clamp(min=1e-8)
    pcm = output.div(peak).clamp(-1, 1).mul(32767).to(torch.int16).cpu()

    path = "output.wav"
    torchaudio.save(path, pcm, sample_rate)
    return path
# Image generation function using Replicate
def generate_image(prompt):
    """Generate an image for *prompt* via the Replicate API and save it as a PNG.

    Creates a prediction, polls its ``urls.get`` endpoint every 1.5 seconds
    until it reaches a terminal state, then downloads the resulting image to
    ``output.png`` in the current working directory.

    Args:
        prompt: Text sent as the model's ``"prompt"`` input.

    Returns:
        The path of the written file (always ``"output.png"``).

    Raises:
        RuntimeError: If ``REPLICATE_API_TOKEN`` is not set, if the prediction
            ends in any state other than ``"succeeded"``, or if it succeeds
            without producing an output URL.
        requests.HTTPError: If any HTTP call returns an error status.
    """
    replicate_token = os.getenv("REPLICATE_API_TOKEN")
    if not replicate_token:
        raise RuntimeError("REPLICATE_API_TOKEN not set")

    # Per-request timeout (seconds) so a stalled connection cannot hang the
    # worker indefinitely.
    request_timeout = 30
    # "canceled" is terminal too; without it a cancelled prediction would be
    # polled forever.
    terminal_states = {"succeeded", "failed", "canceled"}

    url = "https://api.replicate.com/v1/predictions"
    headers = {
        "Authorization": f"Token {replicate_token}",
        "Content-Type": "application/json",
    }
    data = {
        "version": "5ee6b41748a4e3e3d3a212ed4a29379d6a13b9265fd00fe59e28c2767a5e82eb",
        "input": {
            "prompt": prompt,
            "style": "surreal",
        },
    }

    response = requests.post(
        url, headers=headers, json=data, timeout=request_timeout
    )
    response.raise_for_status()
    prediction = response.json()
    status = prediction["status"]
    get_url = prediction["urls"]["get"]

    while status not in terminal_states:
        time.sleep(1.5)
        resp = requests.get(get_url, headers=headers, timeout=request_timeout)
        # Surface API errors here instead of failing later with a KeyError on
        # an error payload that has no "status" field.
        resp.raise_for_status()
        prediction = resp.json()
        status = prediction["status"]

    if status != "succeeded":
        raise RuntimeError(f"Image generation failed: {prediction}")

    # The output may be a single URL or a list of URLs depending on the
    # model; use the first entry when it is a list.
    image_url = prediction["output"]
    if isinstance(image_url, (list, tuple)):
        image_url = image_url[0] if image_url else None
    if not image_url:
        raise RuntimeError(f"Image generation failed: {prediction}")

    image_path = "output.png"
    image_response = requests.get(image_url, timeout=request_timeout)
    # Never write an HTTP error body to disk as if it were the image.
    image_response.raise_for_status()
    with open(image_path, "wb") as f:
        f.write(image_response.content)
    return image_path
# Combined generation function
def generate_assets(prompt):
    """Produce both assets for a single prompt.

    The audio clip is generated first, then the image.

    Returns:
        A ``(audio_path, image_path)`` tuple holding the values returned by
        ``generate_audio`` and ``generate_image`` respectively.
    """
    # Tuple elements are evaluated left to right, so audio still runs first.
    return generate_audio(prompt), generate_image(prompt)
# Gradio UI: a single text prompt in, generated audio and image out.
# The labels, title and description previously contained mojibake (UTF-8
# emoji and an em dash decoded with a legacy code page); the intended
# characters are restored here.
gr.Interface(
    fn=generate_assets,
    inputs=gr.Textbox(
        label="🎀 Prompt your sonic + visual art",
        placeholder="e.g. 'drunk driving with mario and yung lean'",
    ),
    outputs=[
        gr.Audio(type="filepath", label="🧠 Generated Audio"),
        gr.Image(type="filepath", label="🎨 Generated Image"),
    ],
    title='🌐 Hot Prompts in Your Area: "My Husband Is Dead"',
    description="Enter a fun sound idea — generate audio *and* visual from one prompt.",
    examples=[
        "ghosts peeing",
        "Tech startup boss villain entrance music",
        "Dolphin hootin'",
    ],
).launch()