import asyncio
import base64
import json
import os
import queue
import threading
import time
import uuid
import gradio as gr
import numpy as np
import websockets
# Voxtral logo as a base64 string, for inlining into generated HTML.
# Left as "" when assets/voxtral.png does not exist; the HTML helpers check for that.
VOXTRAL_ICON_B64 = ""
icon_path = os.path.join(os.path.dirname(__file__), "assets", "voxtral.png")
if os.path.exists(icon_path):
    with open(icon_path, "rb") as icon_file:
        raw_icon = icon_file.read()
    VOXTRAL_ICON_B64 = base64.b64encode(raw_icon).decode("utf-8")
# Audio format expected by the realtime endpoint (mono PCM16).
SAMPLE_RATE = 16000
WARMUP_DURATION = 2.0  # seconds of silence streamed before live audio

# Words-per-minute display.
WPM_WINDOW = 10  # sliding window (seconds) for the running WPM estimate
CALIBRATION_PERIOD = 5  # seconds after session start before a WPM value is shown

# Session limits, overridable through the environment.
SESSION_TIMEOUT = int(os.getenv("SESSION_TIMEOUT", "35"))  # max session length in seconds (default 35)
INACTIVITY_TIMEOUT = int(os.getenv("INACTIVITY_TIMEOUT", "10"))  # stop after this many seconds with no incoming audio chunk
MAX_CONCURRENT_SESSIONS = int(os.getenv("MAX_SESSIONS", "50"))

# Global config shared by all users; assigned at startup.
ws_url = ""
model = ""

# One event loop, run by one background thread, hosts every websocket connection.
_event_loop = None
_loop_thread = None
_loop_lock = threading.Lock()

# Sessions whose websocket handler is live (or starting), keyed by session_id.
_active_sessions = {}
_sessions_lock = threading.Lock()

# Every known session, keyed by session_id; Gradio state holds only the id.
_session_registry = {}
_registry_lock = threading.Lock()
_last_cleanup = time.time()
SESSION_REGISTRY_CLEANUP_INTERVAL = 30  # seconds between stale-session sweeps
SESSION_MAX_AGE = 30  # idle seconds after which a non-running, inactive session is dropped
def get_or_create_session(session_id: str | None = None) -> "UserSession":
    """Return the registered UserSession for *session_id*, creating one if needed.

    Args:
        session_id: Id previously handed to the client. ``None``, an empty
            string, or an unknown id all result in a brand-new session with a
            new id.

    Returns:
        The existing session, with ``_last_accessed`` refreshed, or a newly
        registered one. A registry entry that is not a UserSession is logged,
        removed and replaced by a new session.

    Side effect: at most once per SESSION_REGISTRY_CLEANUP_INTERVAL seconds,
    this call also runs ``_cleanup_stale_sessions()``.
    """
    global _last_cleanup
    now = time.time()
    # Claim the periodic sweep under the lock so concurrent callers don't all run
    # it. The sweep runs after the lock is released because it acquires
    # _registry_lock (a non-reentrant threading.Lock) itself.
    with _registry_lock:
        sweep_due = now - _last_cleanup > SESSION_REGISTRY_CLEANUP_INTERVAL
        if sweep_due:
            _last_cleanup = now
    if sweep_due:
        _cleanup_stale_sessions()

    with _registry_lock:
        if session_id and session_id in _session_registry:
            session = _session_registry[session_id]
            # Validate the session is actually a UserSession instance
            if isinstance(session, UserSession):
                session._last_accessed = now
                return session
            # Corrupted registry entry - remove and create new
            print(f"WARNING: Corrupted session registry entry for {session_id}: {type(session)}")
            del _session_registry[session_id]
        session = UserSession()
        session._last_accessed = now
        _session_registry[session.session_id] = session
        return session
def _cleanup_stale_sessions():
    """Drop idle registry entries and orphaned active-session entries.

    A registry entry is dropped when it is absent from ``_active_sessions``,
    not running, and last accessed more than SESSION_MAX_AGE seconds ago.
    An ``_active_sessions`` entry is dropped when it is no longer in the
    registry and not running. Both dicts are inspected and modified while
    holding both locks (registry lock first). A summary line is printed only
    when something was removed.
    """
    now = time.time()
    with _registry_lock, _sessions_lock:
        # Never evict a registry entry whose websocket handler is still tracked.
        stale = [
            sid
            for sid, sess in _session_registry.items()
            if sid not in _active_sessions
            and not sess.is_running
            and now - getattr(sess, '_last_accessed', 0) > SESSION_MAX_AGE
        ]
        orphaned = [
            sid
            for sid, sess in _active_sessions.items()
            if sid not in _session_registry and not sess.is_running
        ]
        for sid in stale:
            _session_registry.pop(sid, None)
        for sid in orphaned:
            _active_sessions.pop(sid, None)
        active_count = len(_active_sessions)
        registry_count = len(_session_registry)
    if stale or orphaned:
        print(f"Cleaned up {len(stale)} stale + {len(orphaned)} orphaned sessions. Registry: {registry_count}, Active: {active_count}")
def cleanup_session(session_id: str):
    """Forget *session_id* in the session registry; unknown ids are ignored."""
    with _registry_lock:
        if session_id in _session_registry:
            del _session_registry[session_id]
def kill_all_sessions():
    """Emergency cleanup - kill ALL active sessions to free capacity.

    The active-session snapshot and the clearing of both dicts happen in one
    critical section, so every session is accounted for. A session that
    registers concurrently is either in the snapshot (and is told to stop) or
    registers after the clear (and stays tracked). It is never dropped from
    bookkeeping while its websocket keeps running.

    Each killed session is flagged ``_stopped_by_user`` so its handler's exit
    path does not try to unregister it again. Its websocket close is scheduled
    on the shared loop, and its task is cancelled.
    """
    # Same lock order (registry, then sessions) as _cleanup_stale_sessions.
    with _registry_lock, _sessions_lock:
        sessions_to_kill = list(_active_sessions.values())
        _active_sessions.clear()
        _session_registry.clear()

    killed_count = 0
    for session in sessions_to_kill:
        try:
            session.is_running = False
            session._stopped_by_user = True
            # Read once: the loop thread may reset session._websocket concurrently.
            ws = session._websocket
            if ws is not None:
                try:
                    asyncio.run_coroutine_threadsafe(ws.close(), get_event_loop())
                except Exception as e:
                    print(f"Error closing websocket for session {session.session_id[:8]}: {e}")
            session._websocket = None
            task = session._task
            if task is not None:
                task.cancel()
            session._task = None
            killed_count += 1
        except Exception as e:
            print(f"Error killing session {session.session_id[:8]}: {e}")
    print(f"CAPACITY RESET: Killed {killed_count} sessions. All sessions cleared.")
def get_event_loop():
    """Return the shared asyncio event loop, starting it in a daemon thread if needed.

    Websocket handlers are scheduled onto this loop from other threads with
    ``asyncio.run_coroutine_threadsafe``. When the loop has to be (re)created,
    this call blocks until the loop is actually running. Otherwise a second
    caller could see a not-yet-started loop as dead and spawn another loop and
    thread.
    """
    global _event_loop, _loop_thread
    with _loop_lock:
        if _event_loop is None or not _event_loop.is_running():
            loop = asyncio.new_event_loop()
            started = threading.Event()
            # Executed on the loop's first iteration, i.e. once run_forever() is active.
            loop.call_soon(started.set)
            _event_loop = loop
            # Hand the loop to the thread explicitly rather than via the global.
            _loop_thread = threading.Thread(target=_run_event_loop, args=(loop,), daemon=True)
            _loop_thread.start()
            if not started.wait(timeout=5):
                print("WARNING: shared event loop did not start within 5 seconds")
        return _event_loop


def _run_event_loop(loop=None):
    """Thread target: install *loop* (default: the shared loop) and run it forever."""
    if loop is None:
        loop = _event_loop
    asyncio.set_event_loop(loop)
    loop.run_forever()
class UserSession:
    """Per-user session state.

    Fields are written from Gradio handler threads and read by the websocket
    coroutine on the shared event loop. Audio chunks cross that thread
    boundary through a thread-safe ``queue.Queue``.
    """

    # Upper bound on buffered audio chunks waiting to be sent.
    AUDIO_QUEUE_MAXSIZE = 200

    def __init__(self):
        self.session_id = str(uuid.uuid4())
        # Use a thread-safe queue for cross-thread communication
        self._audio_queue = queue.Queue(maxsize=self.AUDIO_QUEUE_MAXSIZE)
        self.transcription_text = ""
        self.is_running = False
        self.status_message = "ready"
        self.word_timestamps = []  # one time.time() stamp per transcribed word
        self.current_wpm = "Calibrating..."
        self.session_start_time = None
        self.last_audio_time = None
        self._start_lock = threading.Lock()  # serializes concurrent auto-starts
        self._task = None  # Track the async task
        self._websocket = None  # Store websocket for forced closure
        self._stopped_by_user = False  # Track if user explicitly stopped
        # Registry bookkeeping, refreshed on each registry lookup. It is
        # initialised here so a session registered by other code is not treated
        # as infinitely old (age measured from 0) by the stale-session sweep.
        self._last_accessed = time.time()

    @property
    def audio_queue(self):
        """Return the thread-safe queue of pending audio chunks."""
        return self._audio_queue

    def reset_queue(self):
        """Discard pending audio by swapping in a fresh, empty queue."""
        self._audio_queue = queue.Queue(maxsize=self.AUDIO_QUEUE_MAXSIZE)
# Custom stylesheet passed to demo.launch(). It is read as UTF-8 explicitly so
# startup does not depend on the platform's default encoding. A missing file
# degrades to default styling (like the optional logo) instead of aborting startup.
css_path = os.path.join(os.path.dirname(__file__), "style.css")
try:
    with open(css_path, "r", encoding="utf-8") as css_file:
        CUSTOM_CSS = css_file.read()
except FileNotFoundError:
    print(f"WARNING: stylesheet not found at {css_path}; using default styling.")
    CUSTOM_CSS = ""
def get_header_html() -> str:
    """Build the page-header HTML.

    ``logo_html`` is chosen by whether the Voxtral icon was loaded
    (``VOXTRAL_ICON_B64`` non-empty).

    NOTE(review): as shown here, both ``logo_html`` literals are empty and the
    returned template is just a newline that does not interpolate
    ``logo_html``. The header markup appears to have been stripped from this
    copy of the file; confirm against the real source.
    """
    if VOXTRAL_ICON_B64:
        logo_html = f''
    else:
        logo_html = ''
    return f"""
"""
def get_status_html(status: str) -> str:
    """Generate the status-badge HTML for a session status.

    Args:
        status: One of "ready", "connecting", "warming", "listening",
            "timeout" or "error". Any other value is rendered like "ready".

    Returns:
        The badge string. As shown here, that is the badge label followed
        by a newline.

    NOTE(review): ``css_class`` and ``dot_anim`` are computed but not
    interpolated into the returned template in this view. The badge markup
    appears to have been stripped from this copy; confirm against the real
    source.
    """
    # status -> (badge label, badge CSS class, dot class; "" means no dot modifier)
    status_configs = {
        "ready": ("STANDBY", "status-ready", ""),
        "connecting": ("CONNECTING", "status-connecting", "fast"),
        "warming": ("WARMING UP", "status-warming", "fast"),
        "listening": ("LISTENING", "status-listening", "animate"),
        "timeout": ("TIMEOUT", "status-timeout", ""),
        "error": ("ERROR", "status-error", ""),
    }
    label, css_class, dot_class = status_configs.get(status, status_configs["ready"])
    # Space-prefixed so it can be appended after another class name (presumably in the badge markup).
    dot_anim = f" {dot_class}" if dot_class else ""
    return f"""{label}
"""
def get_transcription_html(transcript: str, status: str, wpm: str = "Calibrating...") -> str:
    """Generate the full transcription card HTML.

    Args:
        transcript: Text transcribed so far. When non-empty it becomes the
            card body.
        status: Session status. It feeds the status badge, and when
            ``transcript`` is empty it selects the placeholder body: one for
            "listening"/"warming"/"connecting", one for "timeout", and a
            default prompt otherwise.
        wpm: Words-per-minute label for the WPM badge.

    Returns:
        The card HTML.

    NOTE(review): as shown here, the returned template is just a newline and
    interpolates none of the fragments built below. Two single-quoted literals
    (``wpm_badge`` and ``icon_html``) also span a line break, which is a
    SyntaxError. The markup appears to have been stripped from this copy of the
    file; confirm against the real source.
    """
    status_badge = get_status_html(status)
    wpm_badge = f'{wpm}
'
    if transcript:
        # cursor_html is meant to differ only while listening; both literals are empty in this view.
        # NOTE(review): transcript is interpolated without HTML escaping — consider html.escape
        # if server-provided text could contain markup.
        cursor_html = '' if status == "listening" else ""
        content_html = f"""
{transcript}{cursor_html}
"""
    elif status in ["listening", "warming", "connecting"]:
        # Session is live but nothing has been transcribed yet.
        content_html = """
"""
    elif status == "timeout":
        # NOTE(review): the text says "5 minutes", but SESSION_TIMEOUT defaults to 35 seconds
        # (env SESSION_TIMEOUT), so this message looks stale.
        content_html = """
Session timeout (5 minutes)
Click 'Clear History' and refresh to restart.
"""
    else:
        # Idle placeholder shown before any audio has arrived.
        content_html = """
// Awaiting audio input...
// Click the microphone to start.
"""
    # Use base64 image if available
    if VOXTRAL_ICON_B64:
        icon_html = f'
'
    else:
        icon_html = '🎙️'
    return f"""
"""
def calculate_wpm(session):
    """Return the session's speaking rate as a display string.

    * "Calibrating..." during the first CALIBRATION_PERIOD seconds of a session.
    * "0.0 WPM" when fewer than two word timestamps are available, either
      before or after pruning, or when the window spans zero time.
    * Otherwise "<rate> WPM", rounded to one decimal place. The rate is the
      number of timestamps in the last WPM_WINDOW seconds divided by the time
      since the oldest of them.

    Side effect: timestamps older than the window are pruned from
    ``session.word_timestamps``. No pruning happens when the function returns
    during calibration or on the initial "fewer than two" check.
    """
    started = session.session_start_time
    if started is not None and time.time() - started < CALIBRATION_PERIOD:
        return "Calibrating..."
    if len(session.word_timestamps) < 2:
        return "0.0 WPM"

    now = time.time()
    cutoff = now - WPM_WINDOW
    recent = [stamp for stamp in session.word_timestamps if stamp >= cutoff]
    session.word_timestamps = recent
    if len(recent) < 2:
        return "0.0 WPM"

    span = now - recent[0]
    if span == 0:
        return "0.0 WPM"
    rate = (len(recent) / span) * 60
    return f"{round(rate, 1)} WPM"
async def send_silence(ws, duration=2.0):
    """Stream *duration* seconds of int16 zero samples to *ws* for model warm-up.

    The audio is sent as ``input_audio_buffer.append`` events carrying
    base64-encoded 100 ms chunks, with a 50 ms pause after each send.
    """
    total = int(SAMPLE_RATE * duration)
    step = int(SAMPLE_RATE * 0.1)
    zeros = np.zeros(total, dtype=np.int16)
    for start in range(0, total, step):
        encoded = base64.b64encode(zeros[start:start + step].tobytes()).decode("utf-8")
        event = {"type": "input_audio_buffer.append", "audio": encoded}
        await ws.send(json.dumps(event))
        await asyncio.sleep(0.05)
async def websocket_handler(session):
    """Run one realtime-transcription websocket connection for *session*.

    The handler connects to ``ws_url`` and waits up to 5 s for the server's
    first message. It then selects ``model`` with a ``session.update`` event
    and streams WARMUP_DURATION seconds of silence followed by a commit. After
    that it runs two loops concurrently:

    * a sender that forwards base64 chunks from ``session.audio_queue`` as
      ``input_audio_buffer.append`` events, and
    * a receiver that appends ``transcription.delta`` text to
      ``session.transcription_text`` and refreshes the WPM estimate.

    Both loops run until the session is stopped, exceeds SESSION_TIMEOUT,
    receives no audio for INACTIVITY_TIMEOUT seconds, or the connection ends.

    Errors are logged and reflected in ``session.status_message`` rather than
    raised. On exit the session is marked not running. Unless a user or
    capacity stop already handled it (``_stopped_by_user``), the session is
    also removed from ``_active_sessions``.
    """
    ws = None
    try:
        # Add connection timeout to prevent hanging
        async with asyncio.timeout(10):  # 10 second connection timeout
            ws = await websockets.connect(ws_url)
        # Store websocket reference so it can be closed externally
        session._websocket = ws
        async with ws:
            if not session.is_running:
                # Stopped (stop/clear/capacity reset) while connecting: skip the warm-up.
                return
            # Wait for the server's first message; its content is not inspected.
            await asyncio.wait_for(ws.recv(), timeout=5)
            await ws.send(json.dumps({"type": "session.update", "model": model}))
            session.status_message = "warming"
            await send_silence(ws, WARMUP_DURATION)
            await ws.send(json.dumps({"type": "input_audio_buffer.commit"}))
            session.status_message = "listening"

            async def send_audio():
                """Forward queued chunks until the session stops, times out or goes idle."""
                while session.is_running:
                    try:
                        # Read shared timestamps once: other threads may reset them to None.
                        last_audio = session.last_audio_time
                        if last_audio is not None and time.time() - last_audio >= INACTIVITY_TIMEOUT:
                            session.is_running = False
                            session.status_message = "ready"
                            break
                        started = session.session_start_time
                        if started is not None and time.time() - started >= SESSION_TIMEOUT:
                            session.is_running = False
                            session.status_message = "timeout"
                            break
                        # Use thread-safe queue with non-blocking get + async sleep
                        try:
                            chunk = session.audio_queue.get_nowait()
                        except queue.Empty:
                            # No audio available, yield control briefly
                            await asyncio.sleep(0.05)
                            continue
                        if session.is_running:
                            await ws.send(
                                json.dumps({"type": "input_audio_buffer.append", "audio": chunk})
                            )
                    except Exception as e:
                        if session.is_running:  # Only log if unexpected
                            print(f"Error sending audio: {e}")
                        session.is_running = False
                        break
                # Closing the socket is what ends receive_transcription's `async for`.
                # Without it, gather() below keeps waiting for a server message
                # that may never come, and the session lingers as active.
                try:
                    await ws.close()
                except Exception as e:
                    print(f"Error closing websocket: {e}")

            async def receive_transcription():
                """Append transcription deltas until the socket closes or the session stops."""
                try:
                    async for message in ws:
                        if not session.is_running:
                            break
                        started = session.session_start_time
                        if started is not None and time.time() - started >= SESSION_TIMEOUT:
                            session.status_message = "timeout"
                            session.is_running = False
                            break
                        data = json.loads(message)
                        if data.get("type") == "transcription.delta":
                            delta = data["delta"]
                            session.transcription_text += delta
                            for _ in delta.split():
                                session.word_timestamps.append(time.time())
                            session.current_wpm = calculate_wpm(session)
                except asyncio.CancelledError:
                    pass  # Normal cancellation
                except Exception as e:
                    if session.is_running:
                        print(f"Error receiving transcription: {e}")
                finally:
                    # Receiving is over (server closed, timeout, error or stop), so the
                    # connection is unusable. Stop the sender's loop as well.
                    session.is_running = False

            await asyncio.gather(send_audio(), receive_transcription(), return_exceptions=True)
    except asyncio.CancelledError:
        pass  # Cancelled via session._task.cancel() by a stop/clear/reset
    except websockets.exceptions.ConnectionClosed:
        pass  # Normal closure
    except asyncio.TimeoutError:
        print(f"WebSocket connection timeout for session {session.session_id[:8]}")
        session.status_message = "error"
    except Exception as e:
        error_msg = str(e) if str(e) else type(e).__name__
        if "ConnectionReset" not in error_msg:  # Suppress common disconnect errors
            print(f"WebSocket error: {error_msg}")
        session.status_message = "error"
    finally:
        session.is_running = False
        # Only clear our own socket; a later run of this session may already own a new one.
        if session._websocket is ws:
            session._websocket = None
        # Only remove and log if not already handled by stop_session
        if not session._stopped_by_user:
            with _sessions_lock:
                removed = _active_sessions.pop(session.session_id, None)
                active_count = len(_active_sessions)
            if removed:
                print(f"Session {session.session_id[:8]} ended. Active sessions: {active_count}")
def start_websocket(session):
    """Register *session* as active and launch its websocket handler on the shared loop.

    Returns immediately; the coroutine runs in the background. The handler's
    ``finally`` block unregisters the session when it ends, unless
    ``_stopped_by_user`` is set.
    """
    session.is_running = True
    # clear_history() sets _stopped_by_user and keeps the same session id, so a
    # session can be started again with the flag still set. Left set, this run's
    # handler would skip removing itself from _active_sessions on exit. That
    # entry would then never be cleaned up, because the stale-session sweep
    # skips anything still in _active_sessions.
    session._stopped_by_user = False
    # Register this session
    with _sessions_lock:
        _active_sessions[session.session_id] = session
        active_count = len(_active_sessions)
    print(f"Starting session {session.session_id[:8]}. Active sessions: {active_count}")
    # Submit to the shared event loop
    loop = get_event_loop()
    session._task = asyncio.run_coroutine_threadsafe(websocket_handler(session), loop)
def ensure_session(session_id):
    """Resolve the value held in Gradio state into a registered UserSession.

    * ``None`` or a callable (handled defensively) yields a fresh session.
    * A UserSession object (legacy state) is returned unchanged.
    * Anything else is converted with ``str`` and looked up in the registry;
      a new session is created if the id is unknown.
    """
    if session_id is None or callable(session_id):
        return get_or_create_session()
    if isinstance(session_id, UserSession):
        return session_id

    resolved = get_or_create_session(str(session_id))
    if isinstance(resolved, UserSession):
        return resolved
    # Should be unreachable; kept as a diagnostic.
    print(f"WARNING: ensure_session returned non-UserSession: {type(resolved)}")
    return get_or_create_session()
def auto_start_recording(session):
    """Start transcription for *session* when its audio begins.

    Gradio can call ``process_audio`` concurrently, so start-up is serialized
    on ``session._start_lock``. A caller that finds the session already running
    just gets the current card back.

    Before starting, the capacity limits are checked. If MAX_CONCURRENT_SESSIONS
    sessions are already active, or more than that many are registered, every
    session is killed and an error card is returned instead.

    Returns:
        The transcription card HTML to display.
    """
    with session._start_lock:
        if session.is_running:
            return get_transcription_html(session.transcription_text, session.status_message, session.current_wpm)

        with _sessions_lock:
            active_full = len(_active_sessions) >= MAX_CONCURRENT_SESSIONS
        with _registry_lock:
            registry_full = len(_session_registry) > MAX_CONCURRENT_SESSIONS
        if active_full or registry_full:
            kill_all_sessions()
            session.status_message = "error"
            return get_transcription_html("Server reset due to capacity. Please click the microphone to restart.", "error", "")

        # Fresh run: wipe per-run text and metrics, then stamp both clocks.
        session.transcription_text = ""
        session.word_timestamps = []
        session.current_wpm = "Calibrating..."
        session.session_start_time = time.time()
        session.last_audio_time = time.time()
        session.status_message = "connecting"
        # Non-blocking: the handler runs on the shared event loop.
        start_websocket(session)
        return get_transcription_html(session.transcription_text, session.status_message, session.current_wpm)
def _close_session_connection(session):
    """Flag *session* as user-stopped and tear down its websocket and task.

    Safe on sessions that are not running: each step does nothing when there
    is nothing to close. The websocket close is scheduled on the shared loop
    rather than awaited. Errors while scheduling it are logged, not raised.
    """
    session.is_running = False
    session.last_audio_time = None
    session._stopped_by_user = True  # the caller unregisters; the handler must not log/pop again
    # Read once: the loop thread may reset session._websocket concurrently.
    ws = session._websocket
    if ws is not None:
        try:
            asyncio.run_coroutine_threadsafe(ws.close(), get_event_loop())
        except Exception as e:
            print(f"Error closing websocket for session {session.session_id[:8]}: {e}")
    session._websocket = None
    task = session._task
    if task is not None:
        task.cancel()
    session._task = None


def stop_session(session_id):
    """Stop the websocket connection and invalidate the session.

    The connection is torn down even when ``is_running`` is already False, for
    example after an inactivity stop whose socket is still open. The session is
    then removed from both the active set and the registry.

    Returns:
        ``(card_html, None)``. The card keeps the last transcript and WPM. The
        ``None`` session id makes the next recording create a fresh session,
        which prevents duplicate-session issues when users stop and restart
        quickly.
    """
    session = ensure_session(session_id)
    old_transcript = session.transcription_text
    old_wpm = session.current_wpm
    was_running = session.is_running

    _close_session_connection(session)
    with _sessions_lock:
        _active_sessions.pop(session.session_id, None)
        active_count = len(_active_sessions)
    if was_running:
        print(f"Mic stopped - session {session.session_id[:8]} ended. Active sessions: {active_count}")

    # Remove from registry - the session is done
    cleanup_session(session.session_id)
    return get_transcription_html(old_transcript, "ready", old_wpm), None


def clear_history(session_id):
    """Stop the websocket connection and clear all history, keeping the session id.

    Returns:
        ``(blank_card_html, None, session_id)``. The ``None`` resets the audio
        component, and the session id keeps the same session in Gradio state.
    """
    session = ensure_session(session_id)
    _close_session_connection(session)
    with _sessions_lock:
        _active_sessions.pop(session.session_id, None)

    session.reset_queue()
    session.transcription_text = ""
    session.word_timestamps = []
    session.current_wpm = "Calibrating..."
    session.session_start_time = None
    session.status_message = "ready"
    return get_transcription_html("", "ready", "Calibrating..."), None, session.session_id
def _audio_to_pcm16_b64(sample_rate, audio_data):
    """Convert one streamed audio chunk to base64 mono PCM16 at SAMPLE_RATE.

    Args:
        sample_rate: Sample rate of *audio_data* in Hz.
        audio_data: Samples shaped ``(n,)`` or ``(n, channels)``. Signed-integer
            dtypes are scaled by their maximum value. Any other dtype is
            treated as float samples nominally in [-1, 1].

    Returns:
        The base64 string, or ``None`` when the chunk holds no samples.
    """
    samples = np.asarray(audio_data)
    # Scale by the ORIGINAL dtype, before down-mixing. Averaging channels
    # produces floats, so checking the dtype afterwards would skip normalization
    # for stereo int16 input and overflow the PCM16 conversion below.
    if np.issubdtype(samples.dtype, np.signedinteger):
        scale = float(np.iinfo(samples.dtype).max)
        samples = samples.astype(np.float32) / scale
    else:
        samples = samples.astype(np.float32)
    # Convert to mono if stereo
    if samples.ndim > 1:
        samples = samples.mean(axis=1)
    if samples.size == 0:
        return None  # np.interp rejects empty inputs; nothing to send anyway
    # Resample to 16kHz if needed (linear interpolation)
    if sample_rate != SAMPLE_RATE:
        num_samples = int(len(samples) * SAMPLE_RATE / sample_rate)
        if num_samples == 0:
            return None
        samples = np.interp(
            np.linspace(0, len(samples) - 1, num_samples),
            np.arange(len(samples)),
            samples,
        )
    # Clip so out-of-range floats saturate instead of wrapping in the int16 cast.
    pcm16 = (np.clip(samples, -1.0, 1.0) * 32767).astype(np.int16)
    return base64.b64encode(pcm16.tobytes()).decode("utf-8")


def process_audio(audio, session_id):
    """Handle one streamed microphone chunk.

    The chunk is processed in order: enforce capacity, resolve the session,
    auto-start it, then queue the audio.

    Args:
        audio: ``(sample_rate, samples)`` from the streaming ``gr.Audio``
            component, or ``None``.
        session_id: The session id held in ``gr.State`` (``None`` initially).

    Returns:
        ``(transcription_html, session_id)``. The id is ``None`` after a
        capacity reset, so the client starts over with a fresh session.
    """
    # Take the locks one after another, never nested, so this path cannot
    # invert the registry -> sessions order used by the cleanup code.
    with _sessions_lock:
        active_count = len(_active_sessions)
        # _active_sessions is keyed by session_id, so a lookup replaces a linear scan.
        is_active_user = isinstance(session_id, str) and session_id in _active_sessions
    with _registry_lock:
        registry_count = len(_session_registry)
    # Kill all if:
    # 1. Registry exceeds limit (memory safety)
    # 2. Active sessions exceed limit
    # 3. At active capacity AND new user trying to join
    if (
        registry_count > MAX_CONCURRENT_SESSIONS
        or active_count > MAX_CONCURRENT_SESSIONS
        or (active_count >= MAX_CONCURRENT_SESSIONS and not is_active_user)
    ):
        kill_all_sessions()
        return get_transcription_html(
            "Server reset due to capacity. Please click the microphone to restart.",
            "error",
            ""
        ), None

    # Always ensure we have a valid session first
    try:
        session = ensure_session(session_id)
    except Exception as e:
        print(f"Error creating session: {e}")
        # Create a fresh session if ensure_session fails
        session = UserSession()
        session._last_accessed = time.time()
        with _registry_lock:
            _session_registry[session.session_id] = session
    # Cache session_id early in case of later errors
    current_session_id = session.session_id

    try:
        if audio is None:
            wpm = session.current_wpm if session.is_running else "Calibrating..."
            return get_transcription_html(session.transcription_text, session.status_message, wpm), current_session_id

        # Update last audio time for inactivity tracking
        session.last_audio_time = time.time()
        # Auto-start if not running; timeout/error require an explicit restart.
        if not session.is_running and session.status_message not in ("timeout", "error"):
            auto_start_recording(session)
        if not session.is_running:
            return get_transcription_html(session.transcription_text, session.status_message, session.current_wpm), current_session_id

        sample_rate, audio_data = audio
        b64_chunk = _audio_to_pcm16_b64(sample_rate, audio_data)
        if b64_chunk is not None:
            try:
                session.audio_queue.put_nowait(b64_chunk)
            except queue.Full:
                pass  # Sender is behind; drop this chunk rather than block the Gradio worker.
        return get_transcription_html(session.transcription_text, session.status_message, session.current_wpm), current_session_id
    except Exception as e:
        print(f"Error processing audio: {e}")
        # Keep the session id and the transcript so one bad chunk doesn't blank the display.
        return get_transcription_html(session.transcription_text, "error", ""), current_session_id
# Gradio interface
# Single-page UI. Per-user state is only the session id kept in gr.State; the
# handlers resolve it to a UserSession.
with gr.Blocks(title="Voxtral Real-time Transcription") as demo:
    # Store just the session_id string - much more reliable than complex objects
    session_state = gr.State(value=None)
    # Header
    gr.HTML(get_header_html())
    # Transcription output (replaced by each handler's returned card HTML)
    transcription_display = gr.HTML(
        value=get_transcription_html("", "ready", "Calibrating..."),
        elem_id="transcription-output"
    )
    # Audio input: streaming microphone. process_audio receives each chunk as a
    # (sample_rate, samples) numpy pair.
    audio_input = gr.Audio(
        sources=["microphone"],
        streaming=True,
        type="numpy",
        format="wav",
        elem_id="audio-input",
        label="Microphone Input"
    )
    # Clear button
    clear_btn = gr.Button(
        "Clear History",
        elem_classes=["clear-btn"]
    )
    # Info text
    # NOTE(review): as shown, this single-quoted literal spans a line break, which is a
    # SyntaxError; the surrounding markup appears stripped from this copy, so confirm.
    gr.HTML('To start again - click on Clear History AND refresh your website.
')
    # Event handlers
    # Clear: tears down the connection, blanks the transcript, resets the audio
    # component (clear_history returns None for it) and keeps the same session id.
    clear_btn.click(
        clear_history,
        inputs=[session_state],
        outputs=[transcription_display, audio_input, session_state]
    )
    # Stop: ends the connection and drops the session; stop_session returns None for
    # the state so the next recording starts with a fresh session.
    audio_input.stop_recording(
        stop_session,
        inputs=[session_state],
        outputs=[transcription_display, session_state]
    )
    # Stream: every chunk goes to process_audio, which auto-starts the session and
    # queues the audio. NOTE(review): concurrency_limit=500 is higher than the
    # launch(max_threads=200) setting at startup; confirm which limit actually applies.
    audio_input.stream(
        process_audio,
        inputs=[audio_input, session_state],
        outputs=[transcription_display, session_state],
        show_progress="hidden",
        concurrency_limit=500,
    )
# Deployment configuration, read once at import time. These rebind the module-level
# ws_url / model globals that websocket_handler() reads for every new connection.
model = os.environ.get("MODEL", "mistralai/Voxtral-Mini-4B-Realtime-2602")
host = os.environ.get("HOST", "")
if not host:
    # Without a host the URL below is "wss:///v1/realtime", so every session would
    # fail to connect. Report the misconfiguration once at startup instead of
    # only through per-session websocket errors.
    print("WARNING: HOST environment variable is not set; realtime websocket connections will fail.")
ws_url = f"wss://{host}/v1/realtime"
# Start the shared event loop now, so the first user doesn't wait for it.
get_event_loop()
demo.queue(default_concurrency_limit=200)
demo.launch(css=CUSTOM_CSS, theme=gr.themes.Base(), ssr_mode=False, max_threads=200)