Surn's picture
add sounds.py
a94fc8f
raw
history blame
2.31 kB
# file: battlewords/sounds.py
import os
import tempfile
import torch
from diffusers import StableAudioPipeline
import scipy.io.wavfile as wav
import base64
# Text prompts handed to the audio model, one per supported effect name.
EFFECT_PROMPTS = {
    "correct guess": "A short, sharp ding sound for a correct guess",
    "incorrect guess": "A low buzz sound for an incorrect guess",
    "miss": "A soft thud sound for a miss",
    "hit": "A bright chime sound for a hit",
    "congratulations": "A triumphant fanfare sound for congratulations",
}

# Module-level memo; starts empty.
_sound_cache = {}
def generate_sound_effect(effect: str) -> str:
    """
    Generate a sound effect using Stable Audio Open based on the effect string.

    The pipeline is loaded on first use and kept on
    ``generate_sound_effect.pipe``. The path of each generated file is stored
    in ``_sound_cache`` so an effect is only synthesised once while its file
    still exists.

    Args:
        effect: A key of ``EFFECT_PROMPTS``.

    Returns:
        The path to the generated WAV file.

    Raises:
        ValueError: If ``effect`` is not a key of ``EFFECT_PROMPTS``.
    """
    if effect not in EFFECT_PROMPTS:
        raise ValueError(f"Unknown effect: {effect}. Available effects: {list(EFFECT_PROMPTS.keys())}")

    # Check cache first, but regenerate if the temp file has since been removed.
    cached_path = _sound_cache.get(effect)
    if cached_path is not None and os.path.exists(cached_path):
        return cached_path

    # Load the model (cached on the function object).
    if not hasattr(generate_sound_effect, 'pipe'):
        use_cuda = torch.cuda.is_available()
        loaded = StableAudioPipeline.from_pretrained(
            "stabilityai/stable-audio-open-1.0",
            # Half precision only on GPU; fall back to float32 on CPU.
            torch_dtype=torch.float16 if use_cuda else torch.float32,
        )
        if use_cuda:
            loaded = loaded.to("cuda")
        # Publish only a fully configured pipeline.
        generate_sound_effect.pipe = loaded
    pipe = generate_sound_effect.pipe

    # Generate audio. The clip length is passed as `audio_end_in_s` and the
    # result is exposed as `.audios`, per the diffusers StableAudioPipeline docs.
    waveforms = pipe(
        EFFECT_PROMPTS[effect],
        audio_end_in_s=2.0,  # Short duration for sound effects
        num_inference_steps=50,
    ).audios

    # Follow the diffusers usage example: transpose the first waveform so that
    # samples come first, upcast to float32 (scipy cannot write float16) and
    # move it to host memory as a NumPy array.
    samples = waveforms[0].T.float().cpu().numpy()

    # Reserve a temp file name and close the handle before writing by path.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
        path = tmpfile.name
    try:
        wav.write(path, pipe.vae.sampling_rate, samples)
    except Exception:
        # Do not leave an empty or half-written file behind.
        if os.path.exists(path):
            os.remove(path)
        raise

    # Cache the path
    _sound_cache[effect] = path
    return path
def get_sound_effect_path(effect: str) -> str:
    """
    Return the file path of the sound for ``effect``.

    Delegates entirely to ``generate_sound_effect`` and passes its result
    (and any exception it raises) straight through.
    """
    wav_path = generate_sound_effect(effect)
    return wav_path
def get_sound_effect_data_url(effect: str) -> str:
    """
    Build a ``data:audio/wav;base64,...`` URL for the sound of ``effect``.

    The file path comes from ``generate_sound_effect``; the file's bytes are
    read in full and base64-encoded into the returned string, which can be
    embedded in HTML.
    """
    with open(generate_sound_effect(effect), "rb") as wav_file:
        payload = base64.b64encode(wav_file.read())
    return f"data:audio/wav;base64,{payload.decode()}"