Spaces:
Running
Running
fix: make update_audio fn async to avoid server freeze
Browse files
- src/audio/audio_generator.py +62 -44
- src/main.py +0 -1
src/audio/audio_generator.py
CHANGED
|
@@ -1,18 +1,14 @@
|
|
| 1 |
import asyncio
|
| 2 |
from google.genai import types
|
| 3 |
import wave
|
| 4 |
-
import queue
|
| 5 |
import logging
|
| 6 |
import io
|
| 7 |
-
import
|
| 8 |
from config import settings
|
| 9 |
from services.google import GoogleClientFactory
|
| 10 |
|
| 11 |
logger = logging.getLogger(__name__)
|
| 12 |
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
async def generate_music(user_hash: str, music_tone: str, receive_audio):
|
| 17 |
if user_hash in sessions:
|
| 18 |
logger.info(
|
|
@@ -44,7 +40,7 @@ async def generate_music(user_hash: str, music_tone: str, receive_audio):
|
|
| 44 |
logger.info(
|
| 45 |
f"Started music generation for user hash {user_hash}, music tone: {music_tone}"
|
| 46 |
)
|
| 47 |
-
sessions[user_hash] = {"session": session, "queue":
|
| 48 |
|
| 49 |
|
| 50 |
async def change_music_tone(user_hash: str, new_tone):
|
|
@@ -75,7 +71,7 @@ async def receive_audio(session, user_hash):
|
|
| 75 |
audio_data = message.server_content.audio_chunks[0].data
|
| 76 |
queue = sessions[user_hash]["queue"]
|
| 77 |
# audio_data is already bytes (raw PCM)
|
| 78 |
-
await
|
| 79 |
await asyncio.sleep(10**-12)
|
| 80 |
except Exception as e:
|
| 81 |
logger.error(f"Error in receive_audio: {e}")
|
|
@@ -102,44 +98,66 @@ async def cleanup_music_session(user_hash: str):
|
|
| 102 |
del sessions[user_hash]
|
| 103 |
|
| 104 |
|
| 105 |
-
def update_audio(user_hash):
|
| 106 |
-
"""
|
|
|
|
|
|
|
|
|
|
| 107 |
if user_hash == "":
|
| 108 |
return
|
| 109 |
|
| 110 |
logger.info(f"Starting audio update loop for user hash: {user_hash}")
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import asyncio
|
| 2 |
from google.genai import types
|
| 3 |
import wave
|
|
|
|
| 4 |
import logging
|
| 5 |
import io
|
| 6 |
+
import gradio as gr
|
| 7 |
from config import settings
|
| 8 |
from services.google import GoogleClientFactory
|
| 9 |
|
| 10 |
logger = logging.getLogger(__name__)
|
| 11 |
|
|
|
|
|
|
|
|
|
|
| 12 |
async def generate_music(user_hash: str, music_tone: str, receive_audio):
|
| 13 |
if user_hash in sessions:
|
| 14 |
logger.info(
|
|
|
|
| 40 |
logger.info(
|
| 41 |
f"Started music generation for user hash {user_hash}, music tone: {music_tone}"
|
| 42 |
)
|
| 43 |
+
sessions[user_hash] = {"session": session, "queue": asyncio.Queue()}
|
| 44 |
|
| 45 |
|
| 46 |
async def change_music_tone(user_hash: str, new_tone):
|
|
|
|
| 71 |
audio_data = message.server_content.audio_chunks[0].data
|
| 72 |
queue = sessions[user_hash]["queue"]
|
| 73 |
# audio_data is already bytes (raw PCM)
|
| 74 |
+
await queue.put(audio_data)
|
| 75 |
await asyncio.sleep(10**-12)
|
| 76 |
except Exception as e:
|
| 77 |
logger.error(f"Error in receive_audio: {e}")
|
|
|
|
| 98 |
del sessions[user_hash]
|
| 99 |
|
| 100 |
|
| 101 |
+
async def update_audio(user_hash: str, request: gr.Request):
    """
    Async generator that turns queued PCM chunks into standalone WAV blobs.

    For the given ``user_hash`` this repeatedly pulls one chunk from the
    session's queue, wraps it in its own WAV container and yields the
    resulting bytes. The loop stops when the client behind ``request`` is
    reported as disconnected, or when the session/queue lookup fails.
    Whenever the loop is left, the music session is cleaned up.

    An empty ``user_hash`` ends the generator immediately (no logging, no
    cleanup).
    """
    if user_hash == "":
        return

    logger.info(f"Starting audio update loop for user hash: {user_hash}")
    try:
        # The disconnect probe is the loop condition, so every `continue`
        # below re-checks it before doing any further work.
        while not await request.request.is_disconnected():
            if user_hash not in sessions:
                # Session not registered (yet); poll again shortly.
                await asyncio.sleep(0.5)
                continue

            try:
                audio_queue = sessions[user_hash]["queue"]
                # Bounded wait so the disconnect probe runs at least about
                # once per second even when no audio is arriving.
                pcm_data = await asyncio.wait_for(audio_queue.get(), timeout=1.0)
            except asyncio.TimeoutError:
                continue  # Check for disconnect again
            except (KeyError, AttributeError):
                logger.warning(
                    f"Session or queue for {user_hash} not found. Stopping audio loop."
                )
                break

            if not isinstance(pcm_data, bytes):
                logger.warning(
                    f"Expected bytes from audio_queue, got {type(pcm_data)}. Skipping."
                )
                continue

            # Lyria provides stereo, 16-bit PCM at 48kHz. One frame occupies
            # NUM_CHANNELS * SAMPLE_WIDTH bytes; a chunk whose length is not
            # a whole number of frames is reported but still forwarded.
            bytes_per_frame = NUM_CHANNELS * SAMPLE_WIDTH
            if len(pcm_data) % bytes_per_frame != 0:
                logger.warning(
                    f"Received PCM data with length {len(pcm_data)}, which is not a multiple of "
                    f"bytes_per_frame ({bytes_per_frame}). This might cause issues with WAV formatting."
                )

            # Wrap the raw PCM chunk in an in-memory WAV container.
            buffer = io.BytesIO()
            with wave.open(buffer, "wb") as writer:
                writer.setnchannels(NUM_CHANNELS)
                writer.setsampwidth(SAMPLE_WIDTH)  # Corresponds to 16-bit audio
                writer.setframerate(SAMPLE_RATE)
                writer.writeframes(pcm_data)
            yield buffer.getvalue()
        else:
            # Reached only when the loop condition failed, i.e. the client
            # disconnected; a `break` above skips this branch.
            logger.info(f"Client disconnected for user hash {user_hash}.")
    finally:
        logger.info(
            f"Audio update loop finished for {user_hash}. Cleaning up music session."
        )
        await cleanup_music_session(user_hash)
|
src/main.py
CHANGED
|
@@ -357,7 +357,6 @@ with gr.Blocks(
|
|
| 357 |
outputs=[game_text, game_image, game_choices, custom_choice],
|
| 358 |
)
|
| 359 |
|
| 360 |
-
demo.unload(cleanup_music_session)
|
| 361 |
demo.load(
|
| 362 |
fn=generate_user_hash,
|
| 363 |
inputs=[],
|
|
|
|
| 357 |
outputs=[game_text, game_image, game_choices, custom_choice],
|
| 358 |
)
|
| 359 |
|
|
|
|
| 360 |
demo.load(
|
| 361 |
fn=generate_user_hash,
|
| 362 |
inputs=[],
|