"""Gradio front-end that streams microphone audio to a realtime transcription WebSocket.

Each browser session gets a ``UserSession``.  Audio chunks arriving from Gradio
are converted to 16 kHz PCM16, base64-encoded and queued; a coroutine running on
one shared background event loop forwards them to the server and accumulates the
transcription deltas that come back.
"""

import asyncio
import base64
import html
import json
import os
import queue
import threading
import time
import uuid

import gradio as gr
import numpy as np
import websockets

# Load Voxtral icon as base64
VOXTRAL_ICON_B64 = ""
icon_path = os.path.join(os.path.dirname(__file__), "assets", "voxtral.png")
if os.path.exists(icon_path):
    with open(icon_path, "rb") as f:
        VOXTRAL_ICON_B64 = base64.b64encode(f.read()).decode("utf-8")

SAMPLE_RATE = 16_000
WARMUP_DURATION = 2.0  # seconds of silence for warmup
WPM_WINDOW = 10  # seconds for running mean calculation
CALIBRATION_PERIOD = 5  # seconds before showing WPM
# Hard cap on one streaming session, in seconds (default 35).
SESSION_TIMEOUT = int(os.environ.get("SESSION_TIMEOUT", "35"))
# Close the stream after this many seconds without incoming audio (default 10).
INACTIVITY_TIMEOUT = int(os.environ.get("INACTIVITY_TIMEOUT", "10"))
MAX_CONCURRENT_SESSIONS = int(os.environ.get("MAX_SESSIONS", "50"))

# Global config (shared across users); assigned at the bottom of the module.
ws_url = ""
model = ""

# Global event loop for all websocket connections (runs in single background thread)
_event_loop = None
_loop_thread = None
_loop_lock = threading.Lock()

# Track active sessions for resource management
_active_sessions = {}
_sessions_lock = threading.Lock()

# Global session registry - sessions are stored here and looked up by ID
_session_registry = {}
_registry_lock = threading.Lock()
_last_cleanup = time.time()
SESSION_REGISTRY_CLEANUP_INTERVAL = 30  # seconds
SESSION_MAX_AGE = 30  # seconds - idle registry entries older than this are removed


def get_or_create_session(session_id: str | None = None) -> "UserSession":
    """Get existing session by ID or create a new one.

    Also triggers a registry sweep when more than
    ``SESSION_REGISTRY_CLEANUP_INTERVAL`` seconds have passed since the last one.
    """
    global _last_cleanup

    # Periodic cleanup of stale sessions
    now = time.time()
    if now - _last_cleanup > SESSION_REGISTRY_CLEANUP_INTERVAL:
        _cleanup_stale_sessions()
        _last_cleanup = now

    with _registry_lock:
        if session_id and session_id in _session_registry:
            existing = _session_registry[session_id]
            # Validate the session is actually a UserSession instance
            if isinstance(existing, UserSession):
                existing._last_accessed = now
                return existing
            # Corrupted registry entry - remove and create new
            print(f"WARNING: Corrupted session registry entry for {session_id}: {type(existing)}")
            del _session_registry[session_id]

        # Create new session
        session = UserSession()
        session._last_accessed = now
        _session_registry[session.session_id] = session
        return session


def _cleanup_stale_sessions():
    """Remove sessions that haven't been accessed recently.

    Registry entries are dropped when they are not active, not running and
    older than ``SESSION_MAX_AGE``.  Entries of ``_active_sessions`` that are no
    longer in the registry and not running are dropped as orphans.
    """
    now = time.time()
    to_remove_from_registry = []
    to_remove_from_active = []

    # Need both locks to safely check both dictionaries
    with _registry_lock:
        with _sessions_lock:
            # Find stale sessions in registry
            for session_id, session in _session_registry.items():
                # NEVER remove if still in active_sessions (websocket still running)
                if session_id in _active_sessions:
                    continue
                last_accessed = getattr(session, "_last_accessed", 0)
                # Remove if: not running AND not active AND old
                if not session.is_running and (now - last_accessed > SESSION_MAX_AGE):
                    to_remove_from_registry.append(session_id)

            # Find orphaned sessions in active_sessions (not in registry anymore)
            for session_id, session in list(_active_sessions.items()):
                if session_id not in _session_registry and not session.is_running:
                    to_remove_from_active.append(session_id)

            # Clean up registry
            for session_id in to_remove_from_registry:
                _session_registry.pop(session_id, None)

            # Clean up orphaned active sessions
            for session_id in to_remove_from_active:
                _active_sessions.pop(session_id, None)

            active_count = len(_active_sessions)
            registry_count = len(_session_registry)

    total_cleaned = len(to_remove_from_registry) + len(to_remove_from_active)
    if total_cleaned > 0:
        print(f"Cleaned up {len(to_remove_from_registry)} stale + {len(to_remove_from_active)} orphaned sessions. Registry: {registry_count}, Active: {active_count}")


def cleanup_session(session_id: str):
    """Remove session from registry."""
    with _registry_lock:
        _session_registry.pop(session_id, None)


def _shutdown_session_io(session):
    """Request closure of the session's websocket and cancel its handler task.

    Best effort: a failure to schedule the close is reported but never raised,
    so callers can continue tearing the session down.
    """
    ws = session._websocket
    if ws is not None:
        try:
            asyncio.run_coroutine_threadsafe(ws.close(), get_event_loop())
        except Exception as e:
            print(f"Error closing websocket for session {session.session_id[:8]}: {e}")
        session._websocket = None

    task = session._task
    if task is not None:
        task.cancel()
        session._task = None


def kill_all_sessions():
    """Emergency cleanup - kill ALL active sessions to free capacity."""
    with _sessions_lock:
        sessions_to_kill = list(_active_sessions.values())

    killed_count = 0
    for session in sessions_to_kill:
        try:
            session.is_running = False
            session._stopped_by_user = True
            _shutdown_session_io(session)
            killed_count += 1
        except Exception as e:
            print(f"Error killing session {session.session_id[:8]}: {e}")

    # Clear both dictionaries
    with _registry_lock:
        with _sessions_lock:
            _active_sessions.clear()
            _session_registry.clear()

    print(f"CAPACITY RESET: Killed {killed_count} sessions. All sessions cleared.")


def get_event_loop():
    """Get or create the shared event loop.

    The loop runs forever in a daemon thread.  The call returns once the loop
    has actually started (or after a 5 second safety timeout).
    """
    global _event_loop, _loop_thread
    with _loop_lock:
        if _event_loop is None or not _event_loop.is_running():
            loop = asyncio.new_event_loop()
            started = threading.Event()
            # Runs as the loop's first callback, i.e. once run_forever() is live.
            loop.call_soon(started.set)
            _event_loop = loop
            _loop_thread = threading.Thread(target=_run_event_loop, args=(loop,), daemon=True)
            _loop_thread.start()
            started.wait(timeout=5)
        return _event_loop


def _run_event_loop(loop=None):
    """Run the event loop in background thread.

    ``loop`` defaults to the module-level ``_event_loop`` for backward
    compatibility with callers that pass no argument.
    """
    if loop is None:
        loop = _event_loop
    asyncio.set_event_loop(loop)
    loop.run_forever()


class UserSession:
    """Per-user session state."""

    def __init__(self):
        self.session_id = str(uuid.uuid4())
        # Use a thread-safe queue for cross-thread communication
        self._audio_queue = queue.Queue(maxsize=200)
        self.transcription_text = ""
        self.is_running = False
        self.status_message = "ready"
        self.word_timestamps = []
        self.current_wpm = "Calibrating..."
        self.session_start_time = None
        self.last_audio_time = None
        self._start_lock = threading.Lock()
        self._task = None  # Track the async task
        self._websocket = None  # Store websocket for forced closure
        self._stopped_by_user = False  # Track if user explicitly stopped
        self._last_accessed = time.time()  # Read by the registry sweep

    @property
    def audio_queue(self):
        """Return the thread-safe queue."""
        return self._audio_queue

    def reset_queue(self):
        """Reset the audio queue."""
        self._audio_queue = queue.Queue(maxsize=200)


# Load CSS from external file
css_path = os.path.join(os.path.dirname(__file__), "style.css")
with open(css_path, "r", encoding="utf-8") as f:
    CUSTOM_CSS = f.read()


def get_header_html() -> str:
    """Generate the header HTML with Voxtral logo."""
    # NOTE(review): both branches are empty strings in the text under review;
    # the logo markup appears to be missing - confirm against the deployed file.
    if VOXTRAL_ICON_B64:
        logo_html = f''
    else:
        logo_html = ''
    return f"""

{logo_html}Real-time Speech Transcription

Click the microphone to start streaming transcriptions. The system will warm up automatically - so there will be a small delay

Talk naturally. Talk fast. Talk ridiculously fast. I can handle it.

"""


def get_status_html(status: str) -> str:
    """Generate status badge HTML based on current status.

    Unknown status values fall back to the "ready" configuration.
    """
    status_configs = {
        "ready": ("STANDBY", "status-ready", ""),
        "connecting": ("CONNECTING", "status-connecting", "fast"),
        "warming": ("WARMING UP", "status-warming", "fast"),
        "listening": ("LISTENING", "status-listening", "animate"),
        "timeout": ("TIMEOUT", "status-timeout", ""),
        "error": ("ERROR", "status-error", ""),
    }
    label, css_class, dot_class = status_configs.get(status, status_configs["ready"])
    # NOTE(review): css_class and dot_anim are not referenced by the template
    # below in the text under review; presumably the markup that used them is
    # missing - confirm before removing them.
    dot_anim = f" {dot_class}" if dot_class else ""
    return f"""
{label}
"""


def get_transcription_html(transcript: str, status: str, wpm: str = "Calibrating...") -> str:
    """Generate the full transcription card HTML.

    ``transcript`` is HTML-escaped before interpolation because it carries
    text produced outside this module (the transcription server).
    """
    status_badge = get_status_html(status)
    wpm_badge = f"\n{wpm}\n"

    if transcript:
        cursor_html = '' if status == "listening" else ""
        safe_transcript = html.escape(transcript, quote=False)
        content_html = f"""
{safe_transcript}{cursor_html}
"""
    elif status in ["listening", "warming", "connecting"]:
        content_html = """

Listening for audio...

"""
    elif status == "timeout":
        # Report the configured limit instead of a hard-coded duration.
        content_html = f"""

Session timeout ({SESSION_TIMEOUT} seconds)

Click 'Clear History' and refresh to restart.

"""
    else:
        content_html = """

// Awaiting audio input...

// Click the microphone to start.

"""

    # Use base64 image if available
    if VOXTRAL_ICON_B64:
        icon_html = f'Voxtral'
    else:
        icon_html = '🎙️'

    return f"""
{icon_html} Transcription Output
{wpm_badge} {status_badge}
{content_html}
"""


def calculate_wpm(session):
    """Calculate words per minute based on running mean of last WPM_WINDOW seconds.

    Returns "Calibrating..." during the first ``CALIBRATION_PERIOD`` seconds of
    a session.  Timestamps older than the window are pruned from
    ``session.word_timestamps`` as a side effect.
    """
    if session.session_start_time is not None:
        elapsed = time.time() - session.session_start_time
        if elapsed < CALIBRATION_PERIOD:
            return "Calibrating..."

    if len(session.word_timestamps) < 2:
        return "0.0 WPM"

    current_time = time.time()
    cutoff_time = current_time - WPM_WINDOW
    session.word_timestamps = [ts for ts in session.word_timestamps if ts >= cutoff_time]

    if len(session.word_timestamps) < 2:
        return "0.0 WPM"

    time_span = current_time - session.word_timestamps[0]
    if time_span == 0:
        return "0.0 WPM"

    word_count = len(session.word_timestamps)
    wpm = (word_count / time_span) * 60
    return f"{round(wpm, 1)} WPM"


async def send_silence(ws, duration=2.0):
    """Send silence to warm up the model.

    ``duration`` seconds of zeroed PCM16 are sent in 100 ms chunks, with a
    50 ms pause after each chunk.
    """
    num_samples = int(SAMPLE_RATE * duration)
    silence = np.zeros(num_samples, dtype=np.int16)
    chunk_size = int(SAMPLE_RATE * 0.1)

    for i in range(0, num_samples, chunk_size):
        chunk = silence[i:i + chunk_size]
        b64_chunk = base64.b64encode(chunk.tobytes()).decode("utf-8")
        await ws.send(
            json.dumps(
                {"type": "input_audio_buffer.append", "audio": b64_chunk}
            )
        )
        await asyncio.sleep(0.05)


async def websocket_handler(session):
    """Connect to WebSocket and handle audio streaming + transcription.

    Runs on the shared event loop.  After the warm-up handshake a sender and a
    receiver run concurrently; as soon as either one finishes the other is
    cancelled so the connection is released promptly.
    """
    ws = None
    try:
        # Add connection timeout to prevent hanging
        async with asyncio.timeout(10):  # 10 second connection timeout
            ws = await websockets.connect(ws_url)

        # Store websocket reference so it can be closed externally
        session._websocket = ws

        async with ws:
            await asyncio.wait_for(ws.recv(), timeout=5)
            await ws.send(json.dumps({"type": "session.update", "model": model}))

            session.status_message = "warming"
            await send_silence(ws, WARMUP_DURATION)
            await ws.send(json.dumps({"type": "input_audio_buffer.commit"}))
            session.status_message = "listening"

            async def send_audio():
                while session.is_running:
                    try:
                        # Check for inactivity timeout
                        if session.last_audio_time is not None:
                            idle = time.time() - session.last_audio_time
                            if idle >= INACTIVITY_TIMEOUT:
                                session.is_running = False
                                session.status_message = "ready"
                                break

                        if session.session_start_time is not None:
                            elapsed = time.time() - session.session_start_time
                            if elapsed >= SESSION_TIMEOUT:
                                session.is_running = False
                                session.status_message = "timeout"
                                break

                        # Use thread-safe queue with non-blocking get + async sleep
                        try:
                            chunk = session.audio_queue.get_nowait()
                        except queue.Empty:
                            # No audio available, yield control briefly
                            await asyncio.sleep(0.05)
                            continue

                        if session.is_running:
                            await ws.send(
                                json.dumps(
                                    {"type": "input_audio_buffer.append", "audio": chunk}
                                )
                            )
                    except Exception as e:
                        if session.is_running:  # Only log if unexpected
                            print(f"Error sending audio: {e}")
                        session.is_running = False
                        break

            async def receive_transcription():
                try:
                    async for message in ws:
                        if not session.is_running:
                            break

                        if session.session_start_time is not None:
                            elapsed = time.time() - session.session_start_time
                            if elapsed >= SESSION_TIMEOUT:
                                session.status_message = "timeout"
                                session.is_running = False
                                break

                        data = json.loads(message)
                        if data.get("type") == "transcription.delta":
                            # A delta-less message must not tear the session down.
                            delta = data.get("delta") or ""
                            session.transcription_text += delta
                            stamp_count = len(delta.split())
                            session.word_timestamps.extend(time.time() for _ in range(stamp_count))
                            session.current_wpm = calculate_wpm(session)
                except asyncio.CancelledError:
                    pass  # Normal cancellation
                except Exception as e:
                    if session.is_running:
                        print(f"Error receiving transcription: {e}")
                    session.is_running = False

            workers = [
                asyncio.create_task(send_audio()),
                asyncio.create_task(receive_transcription()),
            ]
            try:
                # Waiting for *both* would leave the receiver parked on the
                # socket after the sender has timed out, holding the session
                # open until the server happens to send or close.
                await asyncio.wait(workers, return_when=asyncio.FIRST_COMPLETED)
            finally:
                for worker in workers:
                    worker.cancel()
                await asyncio.gather(*workers, return_exceptions=True)
    except asyncio.CancelledError:
        pass  # Normal cancellation
    except websockets.exceptions.ConnectionClosed:
        pass  # Normal closure
    except asyncio.TimeoutError:
        print(f"WebSocket connection timeout for session {session.session_id[:8]}")
        session.status_message = "error"
    except Exception as e:
        error_msg = str(e) if str(e) else type(e).__name__
        if "ConnectionReset" not in error_msg:  # Suppress common disconnect errors
            print(f"WebSocket error: {error_msg}")
        session.status_message = "error"
    finally:
        session.is_running = False
        session._websocket = None
        # Only remove and log if not already handled by stop_session
        if not session._stopped_by_user:
            with _sessions_lock:
                removed = _active_sessions.pop(session.session_id, None)
                active_count = len(_active_sessions)
            if removed:
                print(f"Session {session.session_id[:8]} ended. Active sessions: {active_count}")


def start_websocket(session):
    """Start WebSocket connection using the shared event loop."""
    session.is_running = True
    # A session kept after "Clear History" still carries the flag from that
    # stop; without this reset the handler would never unregister it again.
    session._stopped_by_user = False

    # Register this session
    with _sessions_lock:
        _active_sessions[session.session_id] = session
        active_count = len(_active_sessions)
    print(f"Starting session {session.session_id[:8]}. Active sessions: {active_count}")

    # Submit to the shared event loop
    loop = get_event_loop()
    future = asyncio.run_coroutine_threadsafe(websocket_handler(session), loop)
    session._task = future
    # Don't block - the coroutine runs in the background
    # Cleanup happens in websocket_handler's finally block


def ensure_session(session_id):
    """Get or create a valid UserSession from a session_id.

    Accepts ``None`` or a callable (a fresh session is created), a
    ``UserSession`` instance (returned unchanged) or anything else, which is
    converted with ``str()`` and looked up in the registry.
    """
    # Handle various invalid inputs
    if session_id is None or callable(session_id):
        return get_or_create_session()

    # If it's already a UserSession object (legacy), return it
    if isinstance(session_id, UserSession):
        return session_id

    # Otherwise treat it as a session_id string
    session = get_or_create_session(str(session_id))

    # Defensive check - this should never happen but helps debug
    if not isinstance(session, UserSession):
        print(f"WARNING: ensure_session returned non-UserSession: {type(session)}")
        return get_or_create_session()

    return session


def auto_start_recording(session):
    """Automatically start the transcription service when audio begins.

    Returns the transcription card HTML for the resulting state.  If the server
    is at capacity, every session is killed and an error card is returned.
    """
    # Protect against startup races: Gradio can call `process_audio` concurrently.
    with session._start_lock:
        if session.is_running:
            return get_transcription_html(session.transcription_text, session.status_message, session.current_wpm)

        # Check if we've hit max concurrent sessions - kill all if so
        with _sessions_lock:
            active_at_capacity = len(_active_sessions) >= MAX_CONCURRENT_SESSIONS
        with _registry_lock:
            registry_over = len(_session_registry) > MAX_CONCURRENT_SESSIONS

        if active_at_capacity or registry_over:
            kill_all_sessions()
            session.status_message = "error"
            return get_transcription_html("Server reset due to capacity. Please click the microphone to restart.", "error", "")

        session.transcription_text = ""
        session.word_timestamps = []
        session.current_wpm = "Calibrating..."
        session.session_start_time = time.time()
        session.last_audio_time = time.time()
        session.status_message = "connecting"

        # Start websocket (now non-blocking, uses shared event loop)
        start_websocket(session)

        return get_transcription_html(session.transcription_text, session.status_message, session.current_wpm)


def stop_session(session_id):
    """Stop the websocket connection and invalidate the session.

    Returns None for session_id so a fresh session is created on next recording.
    This prevents duplicate session issues when users stop and restart quickly.
    """
    session = ensure_session(session_id)
    old_transcript = session.transcription_text
    old_wpm = session.current_wpm

    if session.is_running:
        session.is_running = False
        session.last_audio_time = None
        session._stopped_by_user = True  # Mark as user-stopped to avoid duplicate logging

        # Close the websocket immediately and cancel the running task
        _shutdown_session_io(session)

        # Remove from active sessions
        with _sessions_lock:
            _active_sessions.pop(session.session_id, None)
            active_count = len(_active_sessions)
        print(f"Mic stopped - session {session.session_id[:8]} ended. Active sessions: {active_count}")

    # Remove from registry - the session is done
    cleanup_session(session.session_id)

    # Return None for session_id - a fresh session will be created on next recording
    # This ensures no duplicate sessions when users stop/start quickly
    return get_transcription_html(old_transcript, "ready", old_wpm), None


def clear_history(session_id):
    """Stop the websocket connection and clear all history.

    Returns the empty transcription card, ``None`` for the audio component and
    the (unchanged) session id.
    """
    session = ensure_session(session_id)
    session.is_running = False
    session.last_audio_time = None
    session._stopped_by_user = True  # Mark as user-stopped

    # Close the websocket immediately and cancel the running task
    _shutdown_session_io(session)

    # Remove from active sessions
    with _sessions_lock:
        _active_sessions.pop(session.session_id, None)

    # Reset the queue
    session.reset_queue()
    session.transcription_text = ""
    session.word_timestamps = []
    session.current_wpm = "Calibrating..."
    session.session_start_time = None
    session.status_message = "ready"

    # Return the session_id to maintain state
    return get_transcription_html("", "ready", "Calibrating..."), None, session.session_id


def _encode_audio_chunk(sample_rate, audio_data):
    """Convert one audio chunk to base64 PCM16 mono at ``SAMPLE_RATE``.

    Returns ``None`` when the chunk holds no samples.
    """
    samples = np.asarray(audio_data)
    if samples.size == 0:
        return None

    # Normalise *before* mixing channels: averaging int16 stereo yields floats
    # still in the int16 range, which would skip scaling and overflow below.
    if np.issubdtype(samples.dtype, np.signedinteger):
        audio_float = samples.astype(np.float32) / float(np.iinfo(samples.dtype).max)
    else:
        audio_float = samples.astype(np.float32)

    # Convert to mono if stereo
    if audio_float.ndim > 1:
        audio_float = audio_float.mean(axis=1)

    # Resample to 16kHz if needed
    if sample_rate != SAMPLE_RATE:
        num_samples = int(len(audio_float) * SAMPLE_RATE / sample_rate)
        if num_samples == 0:
            return None
        audio_float = np.interp(
            np.linspace(0, len(audio_float) - 1, num_samples),
            np.arange(len(audio_float)),
            audio_float,
        )

    # Convert to PCM16 (clipped so out-of-range samples cannot wrap) and base64 encode
    pcm16 = (np.clip(audio_float, -1.0, 1.0) * 32767).astype(np.int16)
    return base64.b64encode(pcm16.tobytes()).decode("utf-8")


def process_audio(audio, session_id):
    """Process incoming audio and queue for streaming.

    Returns ``(transcription_html, session_id)``; the id is ``None`` after a
    capacity reset so the client starts over with a fresh session.
    """
    # Check capacity - if at or above max, kill ALL sessions to reset
    with _sessions_lock:
        active_count = len(_active_sessions)
        # _active_sessions is keyed by session id.
        is_active_user = isinstance(session_id, str) and session_id in _active_sessions

    with _registry_lock:
        registry_count = len(_session_registry)

    # Kill all if:
    # 1. Registry exceeds limit (memory safety)
    # 2. Active sessions exceed limit
    # 3. At active capacity AND new user trying to join
    over_capacity = (
        registry_count > MAX_CONCURRENT_SESSIONS
        or active_count > MAX_CONCURRENT_SESSIONS
        or (active_count >= MAX_CONCURRENT_SESSIONS and not is_active_user)
    )
    if over_capacity:
        kill_all_sessions()
        return get_transcription_html(
            "Server reset due to capacity. Please click the microphone to restart.",
            "error",
            ""
        ), None

    # Always ensure we have a valid session first
    try:
        session = ensure_session(session_id)
    except Exception as e:
        print(f"Error creating session: {e}")
        # Create a fresh session if ensure_session fails
        session = UserSession()
        with _registry_lock:
            _session_registry[session.session_id] = session

    # Cache session_id early in case of later errors
    current_session_id = session.session_id

    try:
        # Quick return if audio is None
        if audio is None:
            wpm = session.current_wpm if session.is_running else "Calibrating..."
            return get_transcription_html(session.transcription_text, session.status_message, wpm), current_session_id

        # Update last audio time for inactivity tracking
        session.last_audio_time = time.time()

        # Auto-start if not running
        if not session.is_running and session.status_message not in ["timeout", "error"]:
            auto_start_recording(session)

        # Skip processing if session stopped
        if not session.is_running:
            return get_transcription_html(session.transcription_text, session.status_message, session.current_wpm), current_session_id

        sample_rate, audio_data = audio
        b64_chunk = _encode_audio_chunk(sample_rate, audio_data)

        if b64_chunk is not None:
            # Put directly into thread-safe queue (no event loop needed)
            try:
                session.audio_queue.put_nowait(b64_chunk)
            except queue.Full:
                pass  # Skip if queue is full

        return get_transcription_html(session.transcription_text, session.status_message, session.current_wpm), current_session_id
    except Exception as e:
        print(f"Error processing audio: {e}")
        # Return safe defaults - always include session_id to maintain state
        return get_transcription_html("", "error", ""), current_session_id


# Gradio interface
with gr.Blocks(title="Voxtral Real-time Transcription") as demo:
    # Store just the session_id string - much more reliable than complex objects
    session_state = gr.State(value=None)

    # Header
    gr.HTML(get_header_html())

    # Transcription output
    transcription_display = gr.HTML(
        value=get_transcription_html("", "ready", "Calibrating..."),
        elem_id="transcription-output"
    )

    # Audio input
    audio_input = gr.Audio(
        sources=["microphone"],
        streaming=True,
        type="numpy",
        format="wav",
        elem_id="audio-input",
        label="Microphone Input"
    )

    # Clear button
    clear_btn = gr.Button(
        "Clear History",
        elem_classes=["clear-btn"]
    )

    # Info text
    gr.HTML("\n\nTo start again - click on Clear History AND refresh your website.\n\n")

    # Event handlers
    clear_btn.click(
        clear_history,
        inputs=[session_state],
        outputs=[transcription_display, audio_input, session_state]
    )

    audio_input.stop_recording(
        stop_session,
        inputs=[session_state],
        outputs=[transcription_display, session_state]
    )

    audio_input.stream(
        process_audio,
        inputs=[audio_input, session_state],
        outputs=[transcription_display, session_state],
        show_progress="hidden",
        concurrency_limit=500,
    )

model = os.environ.get("MODEL", "mistralai/Voxtral-Mini-4B-Realtime-2602")
# NOTE(review): with HOST unset the URL becomes "wss:///v1/realtime" - confirm
# the deployment always provides HOST.
host = os.environ.get("HOST", "")
ws_url = f"wss://{host}/v1/realtime"
get_event_loop()
demo.queue(default_concurrency_limit=200)
demo.launch(css=CUSTOM_CSS, theme=gr.themes.Base(), ssr_mode=False, max_threads=200)