Joffrey Thomas commited on
Commit
68d5702
·
1 Parent(s): f259931

change app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -49
app.py CHANGED
@@ -3,9 +3,9 @@ import asyncio
3
  import base64
4
  import json
5
  import os
6
- import queue
7
  import threading
8
  import time
 
9
 
10
  import gradio as gr
11
  import numpy as np
@@ -22,19 +22,48 @@ SAMPLE_RATE = 16_000
22
  WARMUP_DURATION = 2.0 # seconds of silence for warmup
23
  WPM_WINDOW = 10 # seconds for running mean calculation
24
  CALIBRATION_PERIOD = 5 # seconds before showing WPM
25
- SESSION_TIMEOUT = 60 # 60 seconds session timeout
26
- # Close the websocket after this many seconds without receiving any audio frames.
27
  INACTIVITY_TIMEOUT = int(os.environ.get("INACTIVITY_TIMEOUT", "20"))
 
28
 
29
  # Global config (shared across users)
30
  ws_url = ""
31
  model = ""
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
  class UserSession:
35
  """Per-user session state."""
36
  def __init__(self):
37
- self.audio_queue = queue.Queue(maxsize=100) # Limit queue size
 
38
  self.transcription_text = ""
39
  self.is_running = False
40
  self.status_message = "ready"
@@ -43,6 +72,7 @@ class UserSession:
43
  self.session_start_time = None
44
  self.last_audio_time = None
45
  self._start_lock = threading.Lock()
 
46
 
47
 
48
  # Load CSS from external file
@@ -212,7 +242,6 @@ async def websocket_handler(session):
212
  if session.last_audio_time is not None:
213
  idle = time.time() - session.last_audio_time
214
  if idle >= INACTIVITY_TIMEOUT:
215
- print(f"Inactivity timeout reached ({INACTIVITY_TIMEOUT}s). Closing websocket.")
216
  session.is_running = False
217
  session.status_message = "ready"
218
  break
@@ -220,24 +249,23 @@ async def websocket_handler(session):
220
  if session.session_start_time is not None:
221
  elapsed = time.time() - session.session_start_time
222
  if elapsed >= SESSION_TIMEOUT:
223
- print(f"Session timeout reached ({SESSION_TIMEOUT}s)")
224
  session.is_running = False
225
  session.status_message = "timeout"
226
  break
227
 
228
- chunk = await asyncio.get_event_loop().run_in_executor(
229
- None, lambda: session.audio_queue.get(timeout=0.1)
230
- )
231
- if session.is_running:
232
- await ws.send(
233
- json.dumps(
234
- {"type": "input_audio_buffer.append", "audio": chunk}
235
  )
236
- )
237
- except queue.Empty:
238
- continue
239
  except Exception as e:
240
- print(f"Error sending audio: {e}")
 
241
  session.is_running = False
242
  break
243
 
@@ -264,37 +292,45 @@ async def websocket_handler(session):
264
  session.word_timestamps.append(time.time())
265
 
266
  session.current_wpm = calculate_wpm(session)
 
 
267
  except Exception as e:
268
- print(f"Error receiving transcription: {e}")
 
269
  session.is_running = False
270
 
271
  await asyncio.gather(send_audio(), receive_transcription(), return_exceptions=True)
 
 
272
  except websockets.exceptions.ConnectionClosed:
273
- # Normal closure, not an error
274
- pass
275
  except Exception as e:
276
  error_msg = str(e) if str(e) else type(e).__name__
277
- print(f"WebSocket connection error: {error_msg}")
 
278
  session.status_message = "error"
279
  finally:
280
  session.is_running = False
 
 
 
281
 
282
 
283
  def start_websocket(session):
284
- """Start WebSocket connection in background thread."""
285
  session.is_running = True
286
- loop = asyncio.new_event_loop()
287
- asyncio.set_event_loop(loop)
288
- try:
289
- loop.run_until_complete(websocket_handler(session))
290
- except Exception as e:
291
- print(f"WebSocket error: {e}")
292
- finally:
293
- session.is_running = False
294
- try:
295
- loop.close()
296
- except Exception:
297
- pass
298
 
299
 
300
  def auto_start_recording(session):
@@ -304,14 +340,21 @@ def auto_start_recording(session):
304
  if session.is_running:
305
  return get_transcription_html(session.transcription_text, session.status_message, session.current_wpm)
306
 
 
 
 
 
 
 
307
  session.transcription_text = ""
308
  session.word_timestamps = []
309
  session.current_wpm = "Calibrating..."
310
  session.session_start_time = time.time()
311
  session.last_audio_time = time.time()
312
  session.status_message = "connecting"
313
- thread = threading.Thread(target=start_websocket, args=(session,), daemon=True)
314
- thread.start()
 
315
 
316
  return get_transcription_html(session.transcription_text, session.status_message, session.current_wpm)
317
 
@@ -321,12 +364,17 @@ def clear_history(session):
321
  session.is_running = False
322
  session.last_audio_time = None
323
 
324
- # Clear the audio queue without blocking
325
- try:
326
- while True:
327
- session.audio_queue.get_nowait()
328
- except queue.Empty:
329
- pass
 
 
 
 
 
330
 
331
  session.transcription_text = ""
332
  session.word_timestamps = []
@@ -381,11 +429,12 @@ def process_audio(audio, session):
381
  pcm16 = (audio_float * 32767).astype(np.int16)
382
  b64_chunk = base64.b64encode(pcm16.tobytes()).decode("utf-8")
383
 
384
- # Non-blocking put to queue
385
  try:
386
- session.audio_queue.put_nowait(b64_chunk)
387
- except queue.Full:
388
- pass # Skip if queue is full
 
389
 
390
  return get_transcription_html(session.transcription_text, session.status_message, session.current_wpm)
391
  except Exception as e:
@@ -393,6 +442,14 @@ def process_audio(audio, session):
393
  return get_transcription_html(session.transcription_text, session.status_message, session.current_wpm)
394
 
395
 
 
 
 
 
 
 
 
 
396
  # Gradio interface
397
  with gr.Blocks(title="Voxtral Real-time Transcription") as demo:
398
  session_state = gr.State(value=lambda: UserSession())
@@ -437,7 +494,7 @@ with gr.Blocks(title="Voxtral Real-time Transcription") as demo:
437
  inputs=[audio_input, session_state],
438
  outputs=[transcription_display],
439
  show_progress="hidden",
440
- concurrency_limit=None,
441
  )
442
 
443
  model = os.environ.get("MODEL", "mistralai/Voxtral-Mini-4B-Realtime-2602")
@@ -445,5 +502,8 @@ host = os.environ.get("HOST", "")
445
 
446
  ws_url = f"wss://{host}/v1/realtime"
447
 
448
- demo.queue(default_concurrency_limit=20)
449
- demo.launch(css=CUSTOM_CSS, theme=gr.themes.Base(), ssr_mode=False, max_threads=40)
 
 
 
 
3
  import base64
4
  import json
5
  import os
 
6
  import threading
7
  import time
8
+ import uuid
9
 
10
  import gradio as gr
11
  import numpy as np
 
22
  WARMUP_DURATION = 2.0 # seconds of silence for warmup
23
  WPM_WINDOW = 10 # seconds for running mean calculation
24
  CALIBRATION_PERIOD = 5 # seconds before showing WPM
25
+ SESSION_TIMEOUT = int(os.environ.get("SESSION_TIMEOUT", "60"))
 
26
  INACTIVITY_TIMEOUT = int(os.environ.get("INACTIVITY_TIMEOUT", "20"))
27
+ MAX_CONCURRENT_SESSIONS = int(os.environ.get("MAX_SESSIONS", "50"))
28
 
29
  # Global config (shared across users)
30
  ws_url = ""
31
  model = ""
32
 
33
+ # Global event loop for all websocket connections (runs in single background thread)
34
+ _event_loop = None
35
+ _loop_thread = None
36
+ _loop_lock = threading.Lock()
37
+
38
+ # Track active sessions for resource management
39
+ _active_sessions = {}
40
+ _sessions_lock = threading.Lock()
41
+
42
+
43
+ def get_event_loop():
44
+ """Get or create the shared event loop."""
45
+ global _event_loop, _loop_thread
46
+ with _loop_lock:
47
+ if _event_loop is None or not _event_loop.is_running():
48
+ _event_loop = asyncio.new_event_loop()
49
+ _loop_thread = threading.Thread(target=_run_event_loop, daemon=True)
50
+ _loop_thread.start()
51
+ # Wait for loop to start
52
+ time.sleep(0.1)
53
+ return _event_loop
54
+
55
+
56
+ def _run_event_loop():
57
+ """Run the event loop in background thread."""
58
+ asyncio.set_event_loop(_event_loop)
59
+ _event_loop.run_forever()
60
+
61
 
62
  class UserSession:
63
  """Per-user session state."""
64
  def __init__(self):
65
+ self.session_id = str(uuid.uuid4())
66
+ self.audio_queue = asyncio.Queue(maxsize=100) # Use async queue
67
  self.transcription_text = ""
68
  self.is_running = False
69
  self.status_message = "ready"
 
72
  self.session_start_time = None
73
  self.last_audio_time = None
74
  self._start_lock = threading.Lock()
75
+ self._task = None # Track the async task
76
 
77
 
78
  # Load CSS from external file
 
242
  if session.last_audio_time is not None:
243
  idle = time.time() - session.last_audio_time
244
  if idle >= INACTIVITY_TIMEOUT:
 
245
  session.is_running = False
246
  session.status_message = "ready"
247
  break
 
249
  if session.session_start_time is not None:
250
  elapsed = time.time() - session.session_start_time
251
  if elapsed >= SESSION_TIMEOUT:
 
252
  session.is_running = False
253
  session.status_message = "timeout"
254
  break
255
 
256
+ try:
257
+ chunk = await asyncio.wait_for(session.audio_queue.get(), timeout=0.1)
258
+ if session.is_running:
259
+ await ws.send(
260
+ json.dumps(
261
+ {"type": "input_audio_buffer.append", "audio": chunk}
262
+ )
263
  )
264
+ except asyncio.TimeoutError:
265
+ continue
 
266
  except Exception as e:
267
+ if session.is_running: # Only log if unexpected
268
+ print(f"Error sending audio: {e}")
269
  session.is_running = False
270
  break
271
 
 
292
  session.word_timestamps.append(time.time())
293
 
294
  session.current_wpm = calculate_wpm(session)
295
+ except asyncio.CancelledError:
296
+ pass # Normal cancellation
297
  except Exception as e:
298
+ if session.is_running:
299
+ print(f"Error receiving transcription: {e}")
300
  session.is_running = False
301
 
302
  await asyncio.gather(send_audio(), receive_transcription(), return_exceptions=True)
303
+ except asyncio.CancelledError:
304
+ pass # Normal cancellation
305
  except websockets.exceptions.ConnectionClosed:
306
+ pass # Normal closure
 
307
  except Exception as e:
308
  error_msg = str(e) if str(e) else type(e).__name__
309
+ if "ConnectionReset" not in error_msg: # Suppress common disconnect errors
310
+ print(f"WebSocket error: {error_msg}")
311
  session.status_message = "error"
312
  finally:
313
  session.is_running = False
314
+ # Remove from active sessions
315
+ with _sessions_lock:
316
+ _active_sessions.pop(session.session_id, None)
317
 
318
 
319
  def start_websocket(session):
320
+ """Start WebSocket connection using the shared event loop."""
321
  session.is_running = True
322
+
323
+ # Register this session
324
+ with _sessions_lock:
325
+ _active_sessions[session.session_id] = session
326
+
327
+ # Submit to the shared event loop
328
+ loop = get_event_loop()
329
+ future = asyncio.run_coroutine_threadsafe(websocket_handler(session), loop)
330
+ session._task = future
331
+
332
+ # Don't block - the coroutine runs in the background
333
+ # Cleanup happens in websocket_handler's finally block
334
 
335
 
336
  def auto_start_recording(session):
 
340
  if session.is_running:
341
  return get_transcription_html(session.transcription_text, session.status_message, session.current_wpm)
342
 
343
+ # Check if we've hit max concurrent sessions
344
+ with _sessions_lock:
345
+ if len(_active_sessions) >= MAX_CONCURRENT_SESSIONS:
346
+ session.status_message = "error"
347
+ return get_transcription_html("Server at capacity. Please try again later.", "error", "")
348
+
349
  session.transcription_text = ""
350
  session.word_timestamps = []
351
  session.current_wpm = "Calibrating..."
352
  session.session_start_time = time.time()
353
  session.last_audio_time = time.time()
354
  session.status_message = "connecting"
355
+
356
+ # Start websocket (now non-blocking, uses shared event loop)
357
+ start_websocket(session)
358
 
359
  return get_transcription_html(session.transcription_text, session.status_message, session.current_wpm)
360
 
 
364
  session.is_running = False
365
  session.last_audio_time = None
366
 
367
+ # Cancel the running task if any
368
+ if session._task is not None:
369
+ session._task.cancel()
370
+ session._task = None
371
+
372
+ # Remove from active sessions
373
+ with _sessions_lock:
374
+ _active_sessions.pop(session.session_id, None)
375
+
376
+ # Create a fresh async queue (old one may have items)
377
+ session.audio_queue = asyncio.Queue(maxsize=100)
378
 
379
  session.transcription_text = ""
380
  session.word_timestamps = []
 
429
  pcm16 = (audio_float * 32767).astype(np.int16)
430
  b64_chunk = base64.b64encode(pcm16.tobytes()).decode("utf-8")
431
 
432
+ # Non-blocking put to async queue (thread-safe)
433
  try:
434
+ loop = get_event_loop()
435
+ loop.call_soon_threadsafe(lambda: _safe_queue_put(session.audio_queue, b64_chunk))
436
+ except Exception:
437
+ pass # Skip if queue is full or loop issues
438
 
439
  return get_transcription_html(session.transcription_text, session.status_message, session.current_wpm)
440
  except Exception as e:
 
442
  return get_transcription_html(session.transcription_text, session.status_message, session.current_wpm)
443
 
444
 
445
+ def _safe_queue_put(q, item):
446
+ """Safely put item in async queue without blocking."""
447
+ try:
448
+ q.put_nowait(item)
449
+ except asyncio.QueueFull:
450
+ pass # Drop frame if queue is full
451
+
452
+
453
  # Gradio interface
454
  with gr.Blocks(title="Voxtral Real-time Transcription") as demo:
455
  session_state = gr.State(value=lambda: UserSession())
 
494
  inputs=[audio_input, session_state],
495
  outputs=[transcription_display],
496
  show_progress="hidden",
497
+ concurrency_limit=100, # Allow many concurrent audio streams
498
  )
499
 
500
  model = os.environ.get("MODEL", "mistralai/Voxtral-Mini-4B-Realtime-2602")
 
502
 
503
  ws_url = f"wss://{host}/v1/realtime"
504
 
505
+ # Initialize the shared event loop at startup
506
+ get_event_loop()
507
+
508
+ demo.queue(default_concurrency_limit=50)
509
+ demo.launch(css=CUSTOM_CSS, theme=gr.themes.Base(), ssr_mode=False, max_threads=100)