Joffrey Thomas commited on
Commit
9930cc6
·
1 Parent(s): f67b151

change app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -25
app.py CHANGED
@@ -39,6 +39,61 @@ _loop_lock = threading.Lock()
39
  _active_sessions = {}
40
  _sessions_lock = threading.Lock()
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  def get_event_loop():
44
  """Get or create the shared event loop."""
@@ -63,7 +118,7 @@ class UserSession:
63
  """Per-user session state."""
64
  def __init__(self):
65
  self.session_id = str(uuid.uuid4())
66
- self.audio_queue = asyncio.Queue(maxsize=100) # Use async queue
67
  self.transcription_text = ""
68
  self.is_running = False
69
  self.status_message = "ready"
@@ -73,6 +128,17 @@ class UserSession:
73
  self.last_audio_time = None
74
  self._start_lock = threading.Lock()
75
  self._task = None # Track the async task
 
 
 
 
 
 
 
 
 
 
 
76
 
77
 
78
  # Load CSS from external file
@@ -314,6 +380,7 @@ async def websocket_handler(session):
314
  # Remove from active sessions
315
  with _sessions_lock:
316
  _active_sessions.pop(session.session_id, None)
 
317
 
318
 
319
  def start_websocket(session):
@@ -333,11 +400,18 @@ def start_websocket(session):
333
  # Cleanup happens in websocket_handler's finally block
334
 
335
 
336
- def ensure_session(session):
337
- """Ensure we have a valid UserSession instance (not the lambda factory)."""
338
- if session is None or callable(session):
339
- return UserSession()
340
- return session
 
 
 
 
 
 
 
341
 
342
 
343
  def auto_start_recording(session):
@@ -366,9 +440,9 @@ def auto_start_recording(session):
366
  return get_transcription_html(session.transcription_text, session.status_message, session.current_wpm)
367
 
368
 
369
- def clear_history(session):
370
  """Stop the websocket connection and clear all history."""
371
- session = ensure_session(session)
372
  session.is_running = False
373
  session.last_audio_time = None
374
 
@@ -381,8 +455,8 @@ def clear_history(session):
381
  with _sessions_lock:
382
  _active_sessions.pop(session.session_id, None)
383
 
384
- # Create a fresh async queue (old one may have items)
385
- session.audio_queue = asyncio.Queue(maxsize=100)
386
 
387
  session.transcription_text = ""
388
  session.word_timestamps = []
@@ -390,17 +464,18 @@ def clear_history(session):
390
  session.session_start_time = None
391
  session.status_message = "ready"
392
 
393
- return get_transcription_html("", "ready", "Calibrating..."), None
 
394
 
395
 
396
- def process_audio(audio, session):
397
  """Process incoming audio and queue for streaming."""
398
- session = ensure_session(session)
399
  try:
400
  # Quick return if audio is None
401
  if audio is None:
402
  wpm = session.current_wpm if session.is_running else "Calibrating..."
403
- return get_transcription_html(session.transcription_text, session.status_message, wpm)
404
 
405
  # Update last audio time for inactivity tracking
406
  session.last_audio_time = time.time()
@@ -411,7 +486,7 @@ def process_audio(audio, session):
411
 
412
  # Skip processing if session stopped
413
  if not session.is_running:
414
- return get_transcription_html(session.transcription_text, session.status_message, session.current_wpm)
415
 
416
  sample_rate, audio_data = audio
417
 
@@ -445,14 +520,11 @@ def process_audio(audio, session):
445
  except Exception:
446
  pass # Skip if queue is full or loop issues
447
 
448
- return get_transcription_html(session.transcription_text, session.status_message, session.current_wpm)
449
  except Exception as e:
450
  print(f"Error processing audio: {e}")
451
- # Safely get session attributes with fallbacks
452
- text = getattr(session, 'transcription_text', '') if not callable(session) else ''
453
- status = getattr(session, 'status_message', 'error') if not callable(session) else 'error'
454
- wpm = getattr(session, 'current_wpm', '') if not callable(session) else ''
455
- return get_transcription_html(text, status, wpm)
456
 
457
 
458
  def _safe_queue_put(q, item):
@@ -465,7 +537,8 @@ def _safe_queue_put(q, item):
465
 
466
  # Gradio interface
467
  with gr.Blocks(title="Voxtral Real-time Transcription") as demo:
468
- session_state = gr.State(value=lambda: UserSession())
 
469
 
470
  # Header
471
  gr.HTML(get_header_html())
@@ -499,13 +572,13 @@ with gr.Blocks(title="Voxtral Real-time Transcription") as demo:
499
  clear_btn.click(
500
  clear_history,
501
  inputs=[session_state],
502
- outputs=[transcription_display, audio_input]
503
  )
504
 
505
  audio_input.stream(
506
  process_audio,
507
  inputs=[audio_input, session_state],
508
- outputs=[transcription_display],
509
  show_progress="hidden",
510
  concurrency_limit=100, # Allow many concurrent audio streams
511
  )
@@ -515,7 +588,6 @@ host = os.environ.get("HOST", "")
515
 
516
  ws_url = f"wss://{host}/v1/realtime"
517
 
518
- # Initialize the shared event loop at startup
519
  get_event_loop()
520
 
521
  demo.queue(default_concurrency_limit=50)
 
39
  _active_sessions = {}
40
  _sessions_lock = threading.Lock()
41
 
42
+ # Global session registry - sessions are stored here and looked up by ID
43
+ _session_registry = {}
44
+ _registry_lock = threading.Lock()
45
+ _last_cleanup = time.time()
46
+ SESSION_REGISTRY_CLEANUP_INTERVAL = 60 # seconds
47
+ SESSION_MAX_AGE = 300 # 5 minutes - remove sessions older than this
48
+
49
+
50
+ def get_or_create_session(session_id: str = None) -> "UserSession":
51
+ """Get existing session by ID or create a new one."""
52
+ global _last_cleanup
53
+
54
+ # Periodic cleanup of stale sessions
55
+ now = time.time()
56
+ if now - _last_cleanup > SESSION_REGISTRY_CLEANUP_INTERVAL:
57
+ _cleanup_stale_sessions()
58
+ _last_cleanup = now
59
+
60
+ with _registry_lock:
61
+ if session_id and session_id in _session_registry:
62
+ session = _session_registry[session_id]
63
+ session._last_accessed = now
64
+ return session
65
+
66
+ # Create new session
67
+ session = UserSession()
68
+ session._last_accessed = now
69
+ _session_registry[session.session_id] = session
70
+ return session
71
+
72
+
73
+ def _cleanup_stale_sessions():
74
+ """Remove sessions that haven't been accessed recently."""
75
+ now = time.time()
76
+ to_remove = []
77
+
78
+ with _registry_lock:
79
+ for session_id, session in _session_registry.items():
80
+ last_accessed = getattr(session, '_last_accessed', 0)
81
+ # Remove if: not running AND (no activity for SESSION_MAX_AGE)
82
+ if not session.is_running and (now - last_accessed > SESSION_MAX_AGE):
83
+ to_remove.append(session_id)
84
+
85
+ for session_id in to_remove:
86
+ _session_registry.pop(session_id, None)
87
+
88
+ if to_remove:
89
+ print(f"Cleaned up {len(to_remove)} stale sessions. Active: {len(_session_registry)}")
90
+
91
+
92
+ def cleanup_session(session_id: str):
93
+ """Remove session from registry."""
94
+ with _registry_lock:
95
+ _session_registry.pop(session_id, None)
96
+
97
 
98
  def get_event_loop():
99
  """Get or create the shared event loop."""
 
118
  """Per-user session state."""
119
  def __init__(self):
120
  self.session_id = str(uuid.uuid4())
121
+ self._audio_queue = None # Created lazily in the correct event loop
122
  self.transcription_text = ""
123
  self.is_running = False
124
  self.status_message = "ready"
 
128
  self.last_audio_time = None
129
  self._start_lock = threading.Lock()
130
  self._task = None # Track the async task
131
+
132
+ @property
133
+ def audio_queue(self):
134
+ """Lazily create audio queue to ensure it's in the right event loop."""
135
+ if self._audio_queue is None:
136
+ self._audio_queue = asyncio.Queue(maxsize=100)
137
+ return self._audio_queue
138
+
139
+ def reset_queue(self):
140
+ """Reset the audio queue."""
141
+ self._audio_queue = asyncio.Queue(maxsize=100)
142
 
143
 
144
  # Load CSS from external file
 
380
  # Remove from active sessions
381
  with _sessions_lock:
382
  _active_sessions.pop(session.session_id, None)
383
+ # Note: Don't remove from registry here - session might be reused
384
 
385
 
386
  def start_websocket(session):
 
400
  # Cleanup happens in websocket_handler's finally block
401
 
402
 
403
+ def ensure_session(session_id):
404
+ """Get or create a valid UserSession from a session_id."""
405
+ # Handle various invalid inputs
406
+ if session_id is None or callable(session_id):
407
+ return get_or_create_session()
408
+
409
+ # If it's already a UserSession object (legacy), return it
410
+ if isinstance(session_id, UserSession):
411
+ return session_id
412
+
413
+ # Otherwise treat it as a session_id string
414
+ return get_or_create_session(str(session_id))
415
 
416
 
417
  def auto_start_recording(session):
 
440
  return get_transcription_html(session.transcription_text, session.status_message, session.current_wpm)
441
 
442
 
443
+ def clear_history(session_id):
444
  """Stop the websocket connection and clear all history."""
445
+ session = ensure_session(session_id)
446
  session.is_running = False
447
  session.last_audio_time = None
448
 
 
455
  with _sessions_lock:
456
  _active_sessions.pop(session.session_id, None)
457
 
458
+ # Reset the queue
459
+ session.reset_queue()
460
 
461
  session.transcription_text = ""
462
  session.word_timestamps = []
 
464
  session.session_start_time = None
465
  session.status_message = "ready"
466
 
467
+ # Return the session_id to maintain state
468
+ return get_transcription_html("", "ready", "Calibrating..."), None, session.session_id
469
 
470
 
471
+ def process_audio(audio, session_id):
472
  """Process incoming audio and queue for streaming."""
473
+ session = ensure_session(session_id)
474
  try:
475
  # Quick return if audio is None
476
  if audio is None:
477
  wpm = session.current_wpm if session.is_running else "Calibrating..."
478
+ return get_transcription_html(session.transcription_text, session.status_message, wpm), session.session_id
479
 
480
  # Update last audio time for inactivity tracking
481
  session.last_audio_time = time.time()
 
486
 
487
  # Skip processing if session stopped
488
  if not session.is_running:
489
+ return get_transcription_html(session.transcription_text, session.status_message, session.current_wpm), session.session_id
490
 
491
  sample_rate, audio_data = audio
492
 
 
520
  except Exception:
521
  pass # Skip if queue is full or loop issues
522
 
523
+ return get_transcription_html(session.transcription_text, session.status_message, session.current_wpm), session.session_id
524
  except Exception as e:
525
  print(f"Error processing audio: {e}")
526
+ # Return safe defaults
527
+ return get_transcription_html("", "error", ""), session.session_id if hasattr(session, 'session_id') else None
 
 
 
528
 
529
 
530
  def _safe_queue_put(q, item):
 
537
 
538
  # Gradio interface
539
  with gr.Blocks(title="Voxtral Real-time Transcription") as demo:
540
+ # Store just the session_id string - much more reliable than complex objects
541
+ session_state = gr.State(value=None)
542
 
543
  # Header
544
  gr.HTML(get_header_html())
 
572
  clear_btn.click(
573
  clear_history,
574
  inputs=[session_state],
575
+ outputs=[transcription_display, audio_input, session_state]
576
  )
577
 
578
  audio_input.stream(
579
  process_audio,
580
  inputs=[audio_input, session_state],
581
+ outputs=[transcription_display, session_state],
582
  show_progress="hidden",
583
  concurrency_limit=100, # Allow many concurrent audio streams
584
  )
 
588
 
589
  ws_url = f"wss://{host}/v1/realtime"
590
 
 
591
  get_event_loop()
592
 
593
  demo.queue(default_concurrency_limit=50)