saos / app.py
hugofloresgarcia's picture
Fix audio saving: use soundfile instead of torchaudio to avoid torchcodec dependency
ace23de
raw
history blame
5.77 kB
import os
import tempfile

import torch
import gradio as gr
import soundfile as sf
import numpy as np
from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond
from huggingface_hub import login
# Global model variables.
# Populated once by load_model() at startup; generate_audio() only reads them.
model = None          # pretrained diffusion model (set by load_model)
model_config = None   # config dict with "sample_rate" / "sample_size" keys (set by load_model)
device = None         # "cuda" or "cpu" (set by load_model)
def load_model():
    """Download the pretrained Stable Audio model and publish it via module globals.

    Side effects: sets the module-level ``model``, ``model_config`` and
    ``device``; logs in to the Hugging Face Hub when ``HF_TOKEN`` is set.

    Returns:
        (model, model_config): the loaded model (half precision, eval mode,
        gradients disabled) and its configuration dict.
    """
    global model, model_config, device

    # Prefer GPU when one is visible to torch.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Loading model on device: {device}")

    # The model repo is gated, so authenticate when a token is available
    # (set HF_TOKEN as a secret in the Space settings).
    token = os.getenv("HF_TOKEN")
    if token:
        print("Using HF_TOKEN for authentication")
        login(token=token)
    else:
        print("Warning: HF_TOKEN not found. Model access may fail if authentication is required.")
        print("Please set HF_TOKEN as a secret in your Space settings.")

    # Fetch weights + config from the Hub.
    model, model_config = get_pretrained_model("stabilityai/stable-audio-open-small")

    # Inference-only setup: move to device, eval mode, no gradients,
    # then fp16 for memory/speed.
    model = model.to(device).eval().requires_grad_(False)
    model = model.to(torch.float16)

    print(
        f"Model loaded successfully. "
        f"Sample rate: {model_config['sample_rate']}, "
        f"Sample size: {model_config['sample_size']}"
    )
    return model, model_config
def generate_audio(prompt, seconds_total=11):
    """Generate 4 audio variations from a text prompt.

    Args:
        prompt: Text description of the desired audio.
        seconds_total: Target duration in seconds (model supports up to 11).

    Returns:
        (audio_files, status): list of WAV file paths (empty on failure) and a
        human-readable status or error message.
    """
    global model, model_config, device
    # Guard clauses: model still loading, or nothing to generate.
    if model is None:
        return [], "Model not loaded. Please wait..."
    if not prompt or not prompt.strip():
        return [], "Please enter a text prompt."
    # One conditioning dict per batch element. The dicts are never mutated,
    # so sharing a single object across the batch via * 4 is safe.
    conditioning = [{
        "prompt": prompt,
        "seconds_total": seconds_total
    }] * 4  # Repeat for batch_size=4
    try:
        # Generate all 4 variations in a single batched diffusion call.
        output = generate_diffusion_cond(
            model,
            steps=8,
            cfg_scale=1.0,
            conditioning=conditioning,
            sample_size=model_config["sample_size"],
            sampler_type="pingpong",
            device=device,
            batch_size=4  # Generate 4 variations
        )
        sample_rate = model_config["sample_rate"]
        audio_files = []
        # Post-process each variation in the batch independently.
        for i in range(4):
            audio = output[i]  # Shape: [channels, samples]
            # Peak-normalize and clip to [-1, 1] as float32 before saving.
            audio = audio.to(torch.float32)
            audio_max = torch.max(torch.abs(audio))
            if audio_max > 0:
                audio = audio.div(audio_max)
            audio = audio.clamp(-1, 1).cpu().numpy()
            # soundfile expects [samples, channels].
            if audio.ndim == 1:
                audio = audio.reshape(-1, 1)
            else:
                audio = audio.T  # [channels, samples] -> [samples, channels]
            # Write each variation to a unique temp file. The previous fixed
            # names ("output_variation_N.wav" in the CWD) were clobbered when
            # two users generated concurrently, so one user could be served
            # another user's audio.
            fd, filename = tempfile.mkstemp(
                prefix="output_", suffix=f"_variation_{i+1}.wav"
            )
            os.close(fd)  # sf.write reopens the path itself
            sf.write(filename, audio, sample_rate)
            audio_files.append(filename)
        return audio_files, f"Generated 4 variations for: '{prompt}'"
    except Exception as e:
        # Surface the full traceback in the status box to ease debugging.
        import traceback
        error_msg = f"Error generating audio: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return [], error_msg
# Load model on startup.
# Runs at import time so the model is ready before the UI is served.
print("Initializing model...")
load_model()
# Create Gradio interface: prompt + duration on the left, status and the
# generated outputs on the right.
with gr.Blocks(title="Stable Audio Open Small - 4 Variations") as demo:
    gr.Markdown("""
# Stable Audio Open Small
Generate up to 4 audio variations from a text prompt.
**Model**: [stabilityai/stable-audio-open-small](https://huggingface.co/stabilityai/stable-audio-open-small)
**Note**: This model requires accepting the license agreement. Make sure to set `HF_TOKEN` as a secret in your Space settings.
Enter a text description and click Generate to create 4 different audio variations.
""")
    with gr.Row():
        with gr.Column():
            # Input controls.
            prompt_input = gr.Textbox(
                label="Text Prompt",
                placeholder="e.g., 128 BPM tech house drum loop",
                lines=2
            )
            # Slider range mirrors the model's ~11 s generation limit.
            seconds_input = gr.Slider(
                minimum=1,
                maximum=11,
                value=11,
                step=1,
                label="Duration (seconds)",
                info="Maximum 11 seconds"
            )
            generate_btn = gr.Button("Generate", variant="primary")
        with gr.Column():
            # Status messages from generate_audio (success or traceback).
            status_output = gr.Textbox(label="Status", interactive=False)
            # NOTE(review): gr.Gallery is an image/video component; it is
            # unlikely to render the .wav file paths generate_audio returns
            # as playable audio — confirm, and consider gr.File or four
            # gr.Audio components instead.
            audio_gallery = gr.Gallery(
                label="Generated Audio Variations",
                show_label=True,
                elem_id="gallery",
                columns=2,
                rows=2,
                height="auto"
            )
    # Wire the button: (prompt, seconds) -> (file list, status message).
    generate_btn.click(
        fn=generate_audio,
        inputs=[prompt_input, seconds_input],
        outputs=[audio_gallery, status_output]
    )
    gr.Markdown("""
### Tips
- The model works best with English descriptions
- Better at generating sound effects and field recordings than music
- Each variation uses a different random seed for diversity
""")

if __name__ == "__main__":
    demo.launch()