Jofthomas committed on
Commit
a6fa7a0
·
verified ·
1 Parent(s): d54e7c0

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +187 -146
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,16 +1,24 @@
1
-
2
  import asyncio
3
  import base64
4
- import json
5
  import os
6
  import queue
7
  import threading
8
  import time
9
  import uuid
 
10
 
11
  import gradio as gr
12
  import numpy as np
13
- import websockets
 
 
 
 
 
 
 
 
 
14
 
15
  # Load Voxtral icon as base64
16
  VOXTRAL_ICON_B64 = ""
@@ -28,8 +36,8 @@ INACTIVITY_TIMEOUT = int(os.environ.get("INACTIVITY_TIMEOUT", "10")) # Close af
28
  MAX_CONCURRENT_SESSIONS = int(os.environ.get("MAX_SESSIONS", "50"))
29
 
30
  # Global config (shared across users)
31
- ws_url = ""
32
- model = ""
33
 
34
  # Global event loop for all websocket connections (runs in single background thread)
35
  _event_loop = None
@@ -138,14 +146,16 @@ def kill_all_sessions():
138
  session.is_running = False
139
  session._stopped_by_user = True
140
 
141
- # Close websocket immediately
142
- if session._websocket is not None:
143
  loop = get_event_loop()
144
  try:
145
- asyncio.run_coroutine_threadsafe(session._websocket.close(), loop)
 
 
146
  except Exception:
147
  pass
148
- session._websocket = None
149
 
150
  # Cancel the task
151
  if session._task is not None:
@@ -165,6 +175,11 @@ def kill_all_sessions():
165
  print(f"CAPACITY RESET: Killed {killed_count} sessions. All sessions cleared.")
166
 
167
 
 
 
 
 
 
168
  def get_event_loop():
169
  """Get or create the shared event loop."""
170
  global _event_loop, _loop_thread
@@ -186,8 +201,9 @@ def _run_event_loop():
186
 
187
  class UserSession:
188
  """Per-user session state."""
189
- def __init__(self):
190
  self.session_id = str(uuid.uuid4())
 
191
  # Use a thread-safe queue for cross-thread communication
192
  self._audio_queue = queue.Queue(maxsize=200)
193
  self.transcription_text = ""
@@ -199,7 +215,7 @@ class UserSession:
199
  self.last_audio_time = None
200
  self._start_lock = threading.Lock()
201
  self._task = None # Track the async task
202
- self._websocket = None # Store websocket for forced closure
203
  self._stopped_by_user = False # Track if user explicitly stopped
204
 
205
  @property
@@ -228,7 +244,7 @@ def get_header_html() -> str:
228
  return f"""
229
  <div class="header-card">
230
  <h1 class="header-title">{logo_html}Real-time Speech Transcription</h1>
231
- <p class="header-subtitle">Click the microphone to start streaming transcriptions. The system will warm up automatically - so there will be a small delay</p>
232
  <p class="header-subtitle">Talk naturally. Talk fast. Talk ridiculously fast. I can handle it.</p>
233
  </div>
234
  """
@@ -343,126 +359,121 @@ def calculate_wpm(session):
343
  return f"{round(wpm, 1)} WPM"
344
 
345
 
346
- async def send_silence(ws, duration=2.0):
347
- """Send silence to warm up the model."""
348
- num_samples = int(SAMPLE_RATE * duration)
 
 
349
  silence = np.zeros(num_samples, dtype=np.int16)
 
350
 
351
- chunk_size = int(SAMPLE_RATE * 0.1)
352
  for i in range(0, num_samples, chunk_size):
 
 
353
  chunk = silence[i:i + chunk_size]
354
- b64_chunk = base64.b64encode(chunk.tobytes()).decode("utf-8")
355
- await ws.send(
356
- json.dumps(
357
- {"type": "input_audio_buffer.append", "audio": b64_chunk}
358
- )
359
- )
360
  await asyncio.sleep(0.05)
361
-
362
-
363
- async def websocket_handler(session):
364
- """Connect to WebSocket and handle audio streaming + transcription."""
365
- ws = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
366
  try:
367
- # Add connection timeout to prevent hanging
368
- async with asyncio.timeout(10): # 10 second connection timeout
369
- ws = await websockets.connect(ws_url)
 
370
 
371
- # Store websocket reference so it can be closed externally
372
- session._websocket = ws
 
373
 
374
- async with ws:
375
- await asyncio.wait_for(ws.recv(), timeout=5)
376
- await ws.send(json.dumps({"type": "session.update", "model": model}))
377
-
378
- session.status_message = "warming"
379
- await send_silence(ws, WARMUP_DURATION)
380
- await ws.send(json.dumps({"type": "input_audio_buffer.commit"}))
381
- session.status_message = "listening"
382
-
383
- async def send_audio():
384
- while session.is_running:
385
- try:
386
- # Check for inactivity timeout
387
- if session.last_audio_time is not None:
388
- idle = time.time() - session.last_audio_time
389
- if idle >= INACTIVITY_TIMEOUT:
390
- session.is_running = False
391
- session.status_message = "ready"
392
- break
393
-
394
- if session.session_start_time is not None:
395
- elapsed = time.time() - session.session_start_time
396
- if elapsed >= SESSION_TIMEOUT:
397
- session.is_running = False
398
- session.status_message = "timeout"
399
- break
400
-
401
- # Use thread-safe queue with non-blocking get + async sleep
402
- try:
403
- chunk = session.audio_queue.get_nowait()
404
- if session.is_running:
405
- await ws.send(
406
- json.dumps(
407
- {"type": "input_audio_buffer.append", "audio": chunk}
408
- )
409
- )
410
- except queue.Empty:
411
- # No audio available, yield control briefly
412
- await asyncio.sleep(0.05)
413
- continue
414
- except Exception as e:
415
- if session.is_running: # Only log if unexpected
416
- print(f"Error sending audio: {e}")
417
- session.is_running = False
418
- break
419
-
420
- async def receive_transcription():
421
- try:
422
- async for message in ws:
423
- if not session.is_running:
424
- break
425
-
426
- if session.session_start_time is not None:
427
- elapsed = time.time() - session.session_start_time
428
- if elapsed >= SESSION_TIMEOUT:
429
- session.status_message = "timeout"
430
- session.is_running = False
431
- break
432
-
433
- data = json.loads(message)
434
- if data.get("type") == "transcription.delta":
435
- delta = data["delta"]
436
- session.transcription_text += delta
437
-
438
- words = delta.split()
439
- for _ in words:
440
- session.word_timestamps.append(time.time())
441
-
442
- session.current_wpm = calculate_wpm(session)
443
- except asyncio.CancelledError:
444
- pass # Normal cancellation
445
- except Exception as e:
446
- if session.is_running:
447
- print(f"Error receiving transcription: {e}")
448
- session.is_running = False
449
-
450
- await asyncio.gather(send_audio(), receive_transcription(), return_exceptions=True)
451
  except asyncio.CancelledError:
452
  pass # Normal cancellation
453
- except websockets.exceptions.ConnectionClosed:
454
- pass # Normal closure
455
- except asyncio.TimeoutError:
456
- print(f"WebSocket connection timeout for session {session.session_id[:8]}")
457
- session.status_message = "error"
458
  except Exception as e:
459
  error_msg = str(e) if str(e) else type(e).__name__
460
- if "ConnectionReset" not in error_msg: # Suppress common disconnect errors
461
- print(f"WebSocket error: {error_msg}")
462
  session.status_message = "error"
463
  finally:
464
  session.is_running = False
465
- session._websocket = None
466
 
467
  # Only remove and log if not already handled by stop_session
468
  if not session._stopped_by_user:
@@ -473,9 +484,10 @@ async def websocket_handler(session):
473
  print(f"Session {session.session_id[:8]} ended. Active sessions: {active_count}")
474
 
475
 
476
- def start_websocket(session):
477
- """Start WebSocket connection using the shared event loop."""
478
  session.is_running = True
 
479
 
480
  # Register this session
481
  with _sessions_lock:
@@ -486,11 +498,11 @@ def start_websocket(session):
486
 
487
  # Submit to the shared event loop
488
  loop = get_event_loop()
489
- future = asyncio.run_coroutine_threadsafe(websocket_handler(session), loop)
490
  session._task = future
491
 
492
  # Don't block - the coroutine runs in the background
493
- # Cleanup happens in websocket_handler's finally block
494
 
495
 
496
  def ensure_session(session_id):
@@ -522,6 +534,11 @@ def auto_start_recording(session):
522
  if session.is_running:
523
  return get_transcription_html(session.transcription_text, session.status_message, session.current_wpm)
524
 
 
 
 
 
 
525
  # Check if we've hit max concurrent sessions - kill all if so
526
  with _sessions_lock:
527
  active_at_capacity = len(_active_sessions) >= MAX_CONCURRENT_SESSIONS
@@ -540,14 +557,14 @@ def auto_start_recording(session):
540
  session.last_audio_time = time.time()
541
  session.status_message = "connecting"
542
 
543
- # Start websocket (now non-blocking, uses shared event loop)
544
- start_websocket(session)
545
 
546
  return get_transcription_html(session.transcription_text, session.status_message, session.current_wpm)
547
 
548
 
549
- def stop_session(session_id):
550
- """Stop the websocket connection and invalidate the session.
551
 
552
  Returns None for session_id so a fresh session is created on next recording.
553
  This prevents duplicate session issues when users stop and restart quickly.
@@ -561,14 +578,16 @@ def stop_session(session_id):
561
  session.last_audio_time = None
562
  session._stopped_by_user = True # Mark as user-stopped to avoid duplicate logging
563
 
564
- # Close the websocket immediately to force cleanup
565
- if session._websocket is not None:
566
  loop = get_event_loop()
567
  try:
568
- asyncio.run_coroutine_threadsafe(session._websocket.close(), loop)
 
 
569
  except Exception:
570
- pass # Ignore errors during close
571
- session._websocket = None
572
 
573
  # Cancel the running task if any
574
  if session._task is not None:
@@ -590,21 +609,28 @@ def stop_session(session_id):
590
  return get_transcription_html(old_transcript, "ready", old_wpm), None
591
 
592
 
593
- def clear_history(session_id):
594
- """Stop the websocket connection and clear all history."""
 
 
 
 
 
595
  session = ensure_session(session_id)
596
  session.is_running = False
597
  session.last_audio_time = None
598
  session._stopped_by_user = True # Mark as user-stopped
599
 
600
- # Close the websocket immediately
601
- if session._websocket is not None:
602
  loop = get_event_loop()
603
  try:
604
- asyncio.run_coroutine_threadsafe(session._websocket.close(), loop)
 
 
605
  except Exception:
606
  pass
607
- session._websocket = None
608
 
609
  # Cancel the running task if any
610
  if session._task is not None:
@@ -628,7 +654,7 @@ def clear_history(session_id):
628
  return get_transcription_html("", "ready", "Calibrating..."), None, session.session_id
629
 
630
 
631
- def process_audio(audio, session_id):
632
  """Process incoming audio and queue for streaming."""
633
  # Check capacity - if at or above max, kill ALL sessions to reset
634
  with _sessions_lock:
@@ -650,13 +676,23 @@ def process_audio(audio, session_id):
650
  ""
651
  ), None
652
 
 
 
 
 
 
 
 
 
653
  # Always ensure we have a valid session first
654
  try:
655
  session = ensure_session(session_id)
 
 
656
  except Exception as e:
657
  print(f"Error creating session: {e}")
658
  # Create a fresh session if ensure_session fails
659
- session = UserSession()
660
  _session_registry[session.session_id] = session
661
 
662
  # Cache session_id early in case of later errors
@@ -727,6 +763,16 @@ with gr.Blocks(title="Voxtral Real-time Transcription") as demo:
727
  # Header
728
  gr.HTML(get_header_html())
729
 
 
 
 
 
 
 
 
 
 
 
730
  # Transcription output
731
  transcription_display = gr.HTML(
732
  value=get_transcription_html("", "ready", "Calibrating..."),
@@ -755,30 +801,25 @@ with gr.Blocks(title="Voxtral Real-time Transcription") as demo:
755
  # Event handlers
756
  clear_btn.click(
757
  clear_history,
758
- inputs=[session_state],
759
  outputs=[transcription_display, audio_input, session_state]
760
  )
761
 
762
 
763
  audio_input.stop_recording(
764
  stop_session,
765
- inputs=[session_state],
766
  outputs=[transcription_display, session_state]
767
  )
768
 
769
  audio_input.stream(
770
  process_audio,
771
- inputs=[audio_input, session_state],
772
  outputs=[transcription_display, session_state],
773
  show_progress="hidden",
774
  concurrency_limit=500,
775
  )
776
 
777
- model = os.environ.get("MODEL", "mistralai/Voxtral-Mini-4B-Realtime-2602")
778
- host = os.environ.get("HOST", "")
779
-
780
- ws_url = f"wss://{host}/v1/realtime"
781
-
782
  get_event_loop()
783
 
784
  demo.queue(default_concurrency_limit=200)
 
 
1
  import asyncio
2
  import base64
 
3
  import os
4
  import queue
5
  import threading
6
  import time
7
  import uuid
8
+ from typing import AsyncIterator
9
 
10
  import gradio as gr
11
  import numpy as np
12
+
13
+ from mistralai import Mistral
14
+ from mistralai.extra.realtime import UnknownRealtimeEvent
15
+ from mistralai.models import (
16
+ AudioFormat,
17
+ RealtimeTranscriptionError,
18
+ RealtimeTranscriptionSessionCreated,
19
+ TranscriptionStreamDone,
20
+ TranscriptionStreamTextDelta,
21
+ )
22
 
23
  # Load Voxtral icon as base64
24
  VOXTRAL_ICON_B64 = ""
 
36
  MAX_CONCURRENT_SESSIONS = int(os.environ.get("MAX_SESSIONS", "50"))
37
 
38
  # Global config (shared across users)
39
+ MISTRAL_BASE_URL = "wss://api.mistral.ai"
40
+ MODEL = "voxtral-mini-transcribe-realtime-2602"
41
 
42
  # Global event loop for all websocket connections (runs in single background thread)
43
  _event_loop = None
 
146
  session.is_running = False
147
  session._stopped_by_user = True
148
 
149
+ # Signal stop event
150
+ if session._stop_event is not None:
151
  loop = get_event_loop()
152
  try:
153
+ asyncio.run_coroutine_threadsafe(
154
+ _set_stop_event_sync(session._stop_event), loop
155
+ )
156
  except Exception:
157
  pass
158
+ session._stop_event = None
159
 
160
  # Cancel the task
161
  if session._task is not None:
 
175
  print(f"CAPACITY RESET: Killed {killed_count} sessions. All sessions cleared.")
176
 
177
 
178
async def _set_stop_event_sync(event):
    """Set an ``asyncio.Event`` from inside the shared event loop.

    Despite the ``_sync`` suffix this is a coroutine. Synchronous callers
    schedule it with ``asyncio.run_coroutine_threadsafe`` so that ``set()``
    runs on the loop thread; ``asyncio.Event`` is documented as not
    thread-safe, so it should not be set directly from another thread.

    Args:
        event: The event to set, or ``None``, in which case nothing happens.
    """
    # Callers clear ``session._stop_event`` right after scheduling this
    # coroutine, so a concurrent stop can hand us None; treat that as a no-op
    # instead of raising AttributeError inside a fire-and-forget future.
    if event is not None:
        event.set()
181
+
182
+
183
  def get_event_loop():
184
  """Get or create the shared event loop."""
185
  global _event_loop, _loop_thread
 
201
 
202
  class UserSession:
203
  """Per-user session state."""
204
+ def __init__(self, api_key: str = None):
205
  self.session_id = str(uuid.uuid4())
206
+ self.api_key = api_key
207
  # Use a thread-safe queue for cross-thread communication
208
  self._audio_queue = queue.Queue(maxsize=200)
209
  self.transcription_text = ""
 
215
  self.last_audio_time = None
216
  self._start_lock = threading.Lock()
217
  self._task = None # Track the async task
218
+ self._stop_event = None # Event to signal stop
219
  self._stopped_by_user = False # Track if user explicitly stopped
220
 
221
  @property
 
244
  return f"""
245
  <div class="header-card">
246
  <h1 class="header-title">{logo_html}Real-time Speech Transcription</h1>
247
+ <p class="header-subtitle">Enter your Mistral API key below, then click the microphone to start streaming transcriptions.</p>
248
  <p class="header-subtitle">Talk naturally. Talk fast. Talk ridiculously fast. I can handle it.</p>
249
  </div>
250
  """
 
359
  return f"{round(wpm, 1)} WPM"
360
 
361
 
362
async def audio_stream_from_queue(session) -> AsyncIterator[bytes]:
    """Yield raw audio bytes for one session: warm-up silence, then live audio.

    The generator has two phases:

    1. Warm-up: ``WARMUP_DURATION`` seconds of int16 silence, emitted in
       100 ms chunks while ``session.status_message`` is ``"warming"``.
    2. Live: base64-encoded chunks are pulled from ``session.audio_queue``,
       decoded, and yielded while ``session.status_message`` is
       ``"listening"``.

    The generator returns (ending the stream) when ``session.is_running``
    becomes false, when the session's stop event is set, or when either the
    inactivity or the overall session timeout elapses. On a timeout it also
    clears ``session.is_running`` and updates ``session.status_message``
    (``"ready"`` for inactivity, ``"timeout"`` for the session limit).

    Args:
        session: The per-user session object whose queue, flags and
            timestamps drive the stream.

    Yields:
        Raw audio bytes (int16 silence during warm-up, decoded queue
        payloads afterwards).
    """
    # First, send silence for warmup
    session.status_message = "warming"
    num_samples = int(SAMPLE_RATE * WARMUP_DURATION)
    silence = np.zeros(num_samples, dtype=np.int16)
    chunk_size = int(SAMPLE_RATE * 0.1)  # 100ms chunks

    for i in range(0, num_samples, chunk_size):
        if not session.is_running:
            return
        chunk = silence[i:i + chunk_size]
        yield chunk.tobytes()
        # Sleep half a chunk's duration so warm-up is pushed faster than
        # real time while still yielding control to the event loop.
        await asyncio.sleep(0.05)

    session.status_message = "listening"

    # Then stream real audio from the queue
    while session.is_running:
        now = time.time()

        # Check for inactivity timeout
        if session.last_audio_time is not None:
            idle = now - session.last_audio_time
            if idle >= INACTIVITY_TIMEOUT:
                session.is_running = False
                session.status_message = "ready"
                return

        # Check for session timeout
        if session.session_start_time is not None:
            elapsed = now - session.session_start_time
            if elapsed >= SESSION_TIMEOUT:
                session.is_running = False
                session.status_message = "timeout"
                return

        # Check if stop was requested
        if session._stop_event and session._stop_event.is_set():
            return

        # The queue is a thread-safe ``queue.Queue`` filled from other
        # threads, so poll it without blocking the event loop.
        try:
            b64_chunk = session.audio_queue.get_nowait()
        except queue.Empty:
            # No audio available, yield control briefly
            await asyncio.sleep(0.05)
            continue

        # The queue contains base64-encoded PCM16 audio. A corrupt chunk
        # must not tear down the whole transcription stream, so drop it and
        # keep going (binascii.Error is a ValueError subclass).
        try:
            audio_bytes = base64.b64decode(b64_chunk)
        except (ValueError, TypeError) as exc:
            print(
                f"Session {session.session_id[:8]}: "
                f"Dropping undecodable audio chunk - {exc}"
            )
            continue

        yield audio_bytes
412
+
413
+
414
+ async def mistral_transcription_handler(session):
415
+ """Connect to Mistral realtime API and handle transcription."""
416
  try:
417
+ if not session.api_key:
418
+ session.status_message = "error"
419
+ print(f"Session {session.session_id[:8]}: No API key provided")
420
+ return
421
 
422
+ # Create Mistral client
423
+ client = Mistral(api_key=session.api_key, server_url=MISTRAL_BASE_URL)
424
+ audio_format = AudioFormat(encoding="pcm_s16le", sample_rate=SAMPLE_RATE)
425
 
426
+ session.status_message = "connecting"
427
+
428
+ # Create the audio stream generator
429
+ audio_stream = audio_stream_from_queue(session)
430
+
431
+ print(f"Session {session.session_id[:8]}: Connecting to Mistral realtime API...")
432
+
433
+ async for event in client.audio.realtime.transcribe_stream(
434
+ audio_stream=audio_stream,
435
+ model=MODEL,
436
+ audio_format=audio_format,
437
+ ):
438
+ if not session.is_running:
439
+ break
440
+
441
+ if isinstance(event, RealtimeTranscriptionSessionCreated):
442
+ print(f"Session {session.session_id[:8]}: Connected to Mistral")
443
+ # Status is already set by audio_stream_from_queue
444
+
445
+ elif isinstance(event, TranscriptionStreamTextDelta):
446
+ delta = event.text
447
+ session.transcription_text += delta
448
+
449
+ # Track words for WPM calculation
450
+ words = delta.split()
451
+ for _ in words:
452
+ session.word_timestamps.append(time.time())
453
+
454
+ session.current_wpm = calculate_wpm(session)
455
+
456
+ elif isinstance(event, TranscriptionStreamDone):
457
+ print(f"Session {session.session_id[:8]}: Transcription done")
458
+ break
459
+
460
+ elif isinstance(event, RealtimeTranscriptionError):
461
+ print(f"Session {session.session_id[:8]}: Error - {event.error}")
462
+ session.status_message = "error"
463
+ break
464
+
465
+ elif isinstance(event, UnknownRealtimeEvent):
466
+ continue # Ignore unknown events
467
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
468
  except asyncio.CancelledError:
469
  pass # Normal cancellation
 
 
 
 
 
470
  except Exception as e:
471
  error_msg = str(e) if str(e) else type(e).__name__
472
+ if "ConnectionReset" not in error_msg and "CancelledError" not in error_msg:
473
+ print(f"Session {session.session_id[:8]}: Mistral API error - {error_msg}")
474
  session.status_message = "error"
475
  finally:
476
  session.is_running = False
 
477
 
478
  # Only remove and log if not already handled by stop_session
479
  if not session._stopped_by_user:
 
484
  print(f"Session {session.session_id[:8]} ended. Active sessions: {active_count}")
485
 
486
 
487
+ def start_transcription(session):
488
+ """Start Mistral transcription using the shared event loop."""
489
  session.is_running = True
490
+ session._stop_event = asyncio.Event()
491
 
492
  # Register this session
493
  with _sessions_lock:
 
498
 
499
  # Submit to the shared event loop
500
  loop = get_event_loop()
501
+ future = asyncio.run_coroutine_threadsafe(mistral_transcription_handler(session), loop)
502
  session._task = future
503
 
504
  # Don't block - the coroutine runs in the background
505
+ # Cleanup happens in mistral_transcription_handler's finally block
506
 
507
 
508
  def ensure_session(session_id):
 
534
  if session.is_running:
535
  return get_transcription_html(session.transcription_text, session.status_message, session.current_wpm)
536
 
537
+ # Check if API key is set
538
+ if not session.api_key:
539
+ session.status_message = "error"
540
+ return get_transcription_html("Please enter your Mistral API key above to start transcription.", "error", "")
541
+
542
  # Check if we've hit max concurrent sessions - kill all if so
543
  with _sessions_lock:
544
  active_at_capacity = len(_active_sessions) >= MAX_CONCURRENT_SESSIONS
 
557
  session.last_audio_time = time.time()
558
  session.status_message = "connecting"
559
 
560
+ # Start Mistral transcription (now non-blocking, uses shared event loop)
561
+ start_transcription(session)
562
 
563
  return get_transcription_html(session.transcription_text, session.status_message, session.current_wpm)
564
 
565
 
566
+ def stop_session(session_id, api_key=None):
567
+ """Stop the transcription and invalidate the session.
568
 
569
  Returns None for session_id so a fresh session is created on next recording.
570
  This prevents duplicate session issues when users stop and restart quickly.
 
578
  session.last_audio_time = None
579
  session._stopped_by_user = True # Mark as user-stopped to avoid duplicate logging
580
 
581
+ # Signal the stop event to terminate the audio stream
582
+ if session._stop_event is not None:
583
  loop = get_event_loop()
584
  try:
585
+ asyncio.run_coroutine_threadsafe(
586
+ _set_stop_event(session._stop_event), loop
587
+ )
588
  except Exception:
589
+ pass
590
+ session._stop_event = None
591
 
592
  # Cancel the running task if any
593
  if session._task is not None:
 
609
  return get_transcription_html(old_transcript, "ready", old_wpm), None
610
 
611
 
612
async def _set_stop_event(event):
    """Set an ``asyncio.Event`` on behalf of a synchronous caller.

    Synchronous code schedules this coroutine with
    ``asyncio.run_coroutine_threadsafe`` so that ``set()`` executes on the
    event-loop thread; ``asyncio.Event`` is documented as not thread-safe,
    so it should not be set directly from another thread.

    Args:
        event: The event to set, or ``None``, in which case nothing happens.
    """
    # The caller resets ``session._stop_event`` to None immediately after
    # scheduling, so a concurrent stop/clear may pass None here; ignore it
    # rather than raising AttributeError inside an unobserved future.
    if event is not None:
        event.set()
615
+
616
+
617
+ def clear_history(session_id, api_key=None):
618
+ """Stop the transcription and clear all history."""
619
  session = ensure_session(session_id)
620
  session.is_running = False
621
  session.last_audio_time = None
622
  session._stopped_by_user = True # Mark as user-stopped
623
 
624
+ # Signal the stop event
625
+ if session._stop_event is not None:
626
  loop = get_event_loop()
627
  try:
628
+ asyncio.run_coroutine_threadsafe(
629
+ _set_stop_event(session._stop_event), loop
630
+ )
631
  except Exception:
632
  pass
633
+ session._stop_event = None
634
 
635
  # Cancel the running task if any
636
  if session._task is not None:
 
654
  return get_transcription_html("", "ready", "Calibrating..."), None, session.session_id
655
 
656
 
657
+ def process_audio(audio, session_id, api_key):
658
  """Process incoming audio and queue for streaming."""
659
  # Check capacity - if at or above max, kill ALL sessions to reset
660
  with _sessions_lock:
 
676
  ""
677
  ), None
678
 
679
+ # Check if API key is provided
680
+ if not api_key or not api_key.strip():
681
+ return get_transcription_html(
682
+ "Please enter your Mistral API key above to start transcription.",
683
+ "error",
684
+ ""
685
+ ), None
686
+
687
  # Always ensure we have a valid session first
688
  try:
689
  session = ensure_session(session_id)
690
+ # Update API key on the session
691
+ session.api_key = api_key.strip()
692
  except Exception as e:
693
  print(f"Error creating session: {e}")
694
  # Create a fresh session if ensure_session fails
695
+ session = UserSession(api_key=api_key.strip())
696
  _session_registry[session.session_id] = session
697
 
698
  # Cache session_id early in case of later errors
 
763
  # Header
764
  gr.HTML(get_header_html())
765
 
766
+ # API Key input
767
+ with gr.Row():
768
+ api_key_input = gr.Textbox(
769
+ label="Mistral API Key",
770
+ placeholder="Enter your Mistral API key...",
771
+ type="password",
772
+ elem_id="api-key-input",
773
+ info="Get your API key from console.mistral.ai"
774
+ )
775
+
776
  # Transcription output
777
  transcription_display = gr.HTML(
778
  value=get_transcription_html("", "ready", "Calibrating..."),
 
801
  # Event handlers
802
  clear_btn.click(
803
  clear_history,
804
+ inputs=[session_state, api_key_input],
805
  outputs=[transcription_display, audio_input, session_state]
806
  )
807
 
808
 
809
  audio_input.stop_recording(
810
  stop_session,
811
+ inputs=[session_state, api_key_input],
812
  outputs=[transcription_display, session_state]
813
  )
814
 
815
  audio_input.stream(
816
  process_audio,
817
+ inputs=[audio_input, session_state, api_key_input],
818
  outputs=[transcription_display, session_state],
819
  show_progress="hidden",
820
  concurrency_limit=500,
821
  )
822
 
 
 
 
 
 
823
  get_event_loop()
824
 
825
  demo.queue(default_concurrency_limit=200)
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
  gradio>=4.0.0
2
  websockets
3
  numpy
 
 
1
  gradio>=4.0.0
2
  websockets
3
  numpy
4
+ mistralai[realtime]