Delay support with partial transcript

#4
by pandora-s - opened
Files changed (3) hide show
  1. app.py +498 -80
  2. requirements.txt +1 -1
  3. style.css +73 -1
app.py CHANGED
@@ -6,6 +6,7 @@ import threading
6
  import time
7
  import uuid
8
  from typing import AsyncIterator
 
9
 
10
  import gradio as gr
11
  import numpy as np
@@ -56,6 +57,7 @@ _last_cleanup = time.time()
56
  SESSION_REGISTRY_CLEANUP_INTERVAL = 90 # seconds
57
  SESSION_MAX_AGE = 90 # 90 seconds - remove sessions older than this
58
 
 
59
 
60
  def get_or_create_session(session_id: str = None) -> "UserSession":
61
  """Get existing session by ID or create a new one."""
@@ -202,12 +204,13 @@ def _run_event_loop():
202
 
203
  class UserSession:
204
  """Per-user session state."""
205
- def __init__(self, api_key: str = None):
206
  self.session_id = str(uuid.uuid4())
207
  self.api_key = api_key
 
208
  # Use a thread-safe queue for cross-thread communication
209
  self._audio_queue = queue.Queue(maxsize=200)
210
- self.transcription_text = ""
211
  self.is_running = False
212
  self.status_message = "ready"
213
  self.word_timestamps = []
@@ -218,6 +221,15 @@ class UserSession:
218
  self._task = None # Track the async task
219
  self._stop_event = None # Event to signal stop
220
  self._stopped_by_user = False # Track if user explicitly stopped
 
 
 
 
 
 
 
 
 
221
 
222
  @property
223
  def audio_queue(self):
@@ -227,6 +239,96 @@ class UserSession:
227
  def reset_queue(self):
228
  """Reset the audio queue."""
229
  self._audio_queue = queue.Queue(maxsize=200)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
 
231
 
232
  # Load CSS from external file
@@ -267,18 +369,81 @@ def get_status_html(status: str) -> str:
267
  return f"""<div class="status-badge {css_class}"><span class="status-dot{dot_anim}"></span><span style="color: inherit !important;">{label}</span></div>"""
268
 
269
 
270
- def get_transcription_html(transcript: str, status: str, wpm: str = "Calibrating...") -> str:
271
  """Generate the full transcription card HTML."""
272
  status_badge = get_status_html(status)
273
  wpm_badge = f'<div class="wpm-badge"><span style="color: #1E1E1E !important;">{wpm}</span></div>'
274
 
275
- if transcript:
276
- cursor_html = '<span class="transcript-cursor"></span>' if status == "listening" else ""
277
- content_html = f"""
278
- <div class="transcript-text" style="color: #000000 !important;">
279
- {transcript}{cursor_html}
280
- </div>
281
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
  elif status in ["listening", "warming", "connecting"]:
283
  content_html = """
284
  <div class="empty-state">
@@ -412,8 +577,101 @@ async def audio_stream_from_queue(session) -> AsyncIterator[bytes]:
412
  continue
413
 
414
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
415
  async def mistral_transcription_handler(session):
416
- """Connect to Mistral realtime API and handle transcription."""
417
  try:
418
  if not session.api_key:
419
  session.status_message = "error"
@@ -426,46 +684,187 @@ async def mistral_transcription_handler(session):
426
 
427
  session.status_message = "connecting"
428
 
429
- # Create the audio stream generator
430
- audio_stream = audio_stream_from_queue(session)
431
-
432
  print(f"Session {session.session_id[:8]}: Connecting to Mistral realtime API...")
433
 
434
- async for event in client.audio.realtime.transcribe_stream(
435
- audio_stream=audio_stream,
436
- model=MODEL,
437
- audio_format=audio_format,
438
- ):
439
- if not session.is_running:
440
- break
441
-
442
- if isinstance(event, RealtimeTranscriptionSessionCreated):
443
- print(f"Session {session.session_id[:8]}: Connected to Mistral")
444
- # Status is already set by audio_stream_from_queue
445
-
446
- elif isinstance(event, TranscriptionStreamTextDelta):
447
- delta = event.text
448
- session.transcription_text += delta
 
 
 
 
449
 
450
- # Track words for WPM calculation
451
- words = delta.split()
452
- for _ in words:
453
- session.word_timestamps.append(time.time())
454
 
455
- session.current_wpm = calculate_wpm(session)
456
-
457
- elif isinstance(event, TranscriptionStreamDone):
458
- print(f"Session {session.session_id[:8]}: Transcription done")
459
- break
460
-
461
- elif isinstance(event, RealtimeTranscriptionError):
462
- print(f"Session {session.session_id[:8]}: Error - {event.error}")
463
- session.status_message = "error"
464
- break
465
-
466
- elif isinstance(event, UnknownRealtimeEvent):
467
- continue # Ignore unknown events
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
468
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
469
  except asyncio.CancelledError:
470
  pass # Normal cancellation
471
  except Exception as e:
@@ -533,12 +932,12 @@ def auto_start_recording(session):
533
  # Protect against startup races: Gradio can call `process_audio` concurrently.
534
  with session._start_lock:
535
  if session.is_running:
536
- return get_transcription_html(session.transcription_text, session.status_message, session.current_wpm)
537
 
538
  # Check if API key is set
539
  if not session.api_key:
540
  session.status_message = "error"
541
- return get_transcription_html("Please enter your Mistral API key above to start transcription.", "error", "")
542
 
543
  # Check if we've hit max concurrent sessions - kill all if so
544
  with _sessions_lock:
@@ -549,29 +948,32 @@ def auto_start_recording(session):
549
  if active_at_capacity or registry_over:
550
  kill_all_sessions()
551
  session.status_message = "error"
552
- return get_transcription_html("Server reset due to capacity. Please click the microphone to restart.", "error", "")
553
 
554
- session.transcription_text = ""
555
  session.word_timestamps = []
556
  session.current_wpm = "Calibrating..."
557
  session.session_start_time = time.time()
558
  session.last_audio_time = time.time()
559
  session.status_message = "connecting"
 
 
 
 
560
 
561
  # Start Mistral transcription (now non-blocking, uses shared event loop)
562
  start_transcription(session)
563
 
564
- return get_transcription_html(session.transcription_text, session.status_message, session.current_wpm)
565
 
566
 
567
- def stop_session(session_id, api_key=None):
568
  """Stop the transcription and invalidate the session.
569
 
570
  Returns None for session_id so a fresh session is created on next recording.
571
  This prevents duplicate session issues when users stop and restart quickly.
572
  """
573
  session = ensure_session(session_id)
574
- old_transcript = session.transcription_text
575
  old_wpm = session.current_wpm
576
 
577
  if session.is_running:
@@ -607,7 +1009,7 @@ def stop_session(session_id, api_key=None):
607
 
608
  # Return None for session_id - a fresh session will be created on next recording
609
  # This ensures no duplicate sessions when users stop/start quickly
610
- return get_transcription_html(old_transcript, "ready", old_wpm), None
611
 
612
 
613
  async def _set_stop_event(event):
@@ -615,7 +1017,7 @@ async def _set_stop_event(event):
615
  event.set()
616
 
617
 
618
- def clear_history(session_id, api_key=None):
619
  """Stop the transcription and clear all history."""
620
  session = ensure_session(session_id)
621
  session.is_running = False
@@ -645,17 +1047,23 @@ def clear_history(session_id, api_key=None):
645
  # Reset the queue
646
  session.reset_queue()
647
 
648
- session.transcription_text = ""
 
 
649
  session.word_timestamps = []
650
  session.current_wpm = "Calibrating..."
651
  session.session_start_time = None
652
  session.status_message = "ready"
 
 
 
 
653
 
654
  # Return the session_id to maintain state
655
- return get_transcription_html("", "ready", "Calibrating..."), None, session.session_id
656
 
657
 
658
- def process_audio(audio, session_id, api_key):
659
  """Process incoming audio and queue for streaming."""
660
  # Check capacity - if at or above max, kill ALL sessions to reset
661
  with _sessions_lock:
@@ -672,28 +1080,33 @@ def process_audio(audio, session_id, api_key):
672
  if registry_count > MAX_CONCURRENT_SESSIONS or active_count > MAX_CONCURRENT_SESSIONS or (active_count >= MAX_CONCURRENT_SESSIONS and not is_active_user):
673
  kill_all_sessions()
674
  return get_transcription_html(
675
- "Server reset due to capacity. Please click the microphone to restart.",
676
  "error",
677
- ""
 
678
  ), None
679
 
680
  # Check if API key is provided
681
  if not api_key or not api_key.strip():
682
- return get_transcription_html(
683
- "Please enter your Mistral API key above to start transcription.",
684
- "error",
685
- ""
686
- ), None
 
687
 
688
  # Always ensure we have a valid session first
689
  try:
690
  session = ensure_session(session_id)
691
  # Update API key on the session
692
  session.api_key = api_key.strip()
 
 
693
  except Exception as e:
694
  print(f"Error creating session: {e}")
695
  # Create a fresh session if ensure_session fails
696
  session = UserSession(api_key=api_key.strip())
 
697
  _session_registry[session.session_id] = session
698
 
699
  # Cache session_id early in case of later errors
@@ -703,7 +1116,7 @@ def process_audio(audio, session_id, api_key):
703
  # Quick return if audio is None
704
  if audio is None:
705
  wpm = session.current_wpm if session.is_running else "Calibrating..."
706
- return get_transcription_html(session.transcription_text, session.status_message, wpm), current_session_id
707
 
708
  # Update last audio time for inactivity tracking
709
  session.last_audio_time = time.time()
@@ -714,7 +1127,7 @@ def process_audio(audio, session_id, api_key):
714
 
715
  # Skip processing if session stopped
716
  if not session.is_running:
717
- return get_transcription_html(session.transcription_text, session.status_message, session.current_wpm), current_session_id
718
 
719
  sample_rate, audio_data = audio
720
 
@@ -747,14 +1160,11 @@ def process_audio(audio, session_id, api_key):
747
  except Exception:
748
  pass # Skip if queue is full
749
 
750
- return get_transcription_html(session.transcription_text, session.status_message, session.current_wpm), current_session_id
751
  except Exception as e:
752
  print(f"Error processing audio: {e}")
753
  # Return safe defaults - always include session_id to maintain state
754
- return get_transcription_html("", "error", ""), current_session_id
755
-
756
-
757
-
758
 
759
  # Gradio interface
760
  with gr.Blocks(title="Voxtral Real-time Transcription") as demo:
@@ -764,19 +1174,27 @@ with gr.Blocks(title="Voxtral Real-time Transcription") as demo:
764
  # Header
765
  gr.HTML(get_header_html())
766
 
767
- # API Key input
768
  with gr.Row():
769
  api_key_input = gr.Textbox(
770
- label="Mistral API Key",
771
- placeholder="Enter your Mistral API key...",
772
  type="password",
773
  elem_id="api-key-input",
774
- info="Get your API key from console.mistral.ai"
 
 
 
 
 
 
 
 
775
  )
776
 
777
  # Transcription output
778
  transcription_display = gr.HTML(
779
- value=get_transcription_html("", "ready", "Calibrating..."),
780
  elem_id="transcription-output"
781
  )
782
 
@@ -802,20 +1220,20 @@ with gr.Blocks(title="Voxtral Real-time Transcription") as demo:
802
  # Event handlers
803
  clear_btn.click(
804
  clear_history,
805
- inputs=[session_state, api_key_input],
806
  outputs=[transcription_display, audio_input, session_state]
807
  )
808
 
809
 
810
  audio_input.stop_recording(
811
  stop_session,
812
- inputs=[session_state, api_key_input],
813
  outputs=[transcription_display, session_state]
814
  )
815
 
816
  audio_input.stream(
817
  process_audio,
818
- inputs=[audio_input, session_state, api_key_input],
819
  outputs=[transcription_display, session_state],
820
  show_progress="hidden",
821
  concurrency_limit=500,
 
6
  import time
7
  import uuid
8
  from typing import AsyncIterator
9
+ import difflib
10
 
11
  import gradio as gr
12
  import numpy as np
 
57
  SESSION_REGISTRY_CLEANUP_INTERVAL = 90 # seconds
58
  SESSION_MAX_AGE = 90 # 90 seconds - remove sessions older than this
59
 
60
+ DEFAULT_API_KEY = os.environ.get("DEFAULT_API_KEY", "")
61
 
62
  def get_or_create_session(session_id: str = None) -> "UserSession":
63
  """Get existing session by ID or create a new one."""
 
204
 
205
  class UserSession:
206
  """Per-user session state."""
207
+ def __init__(self, api_key: str = ""):
208
  self.session_id = str(uuid.uuid4())
209
  self.api_key = api_key
210
+ self.partial_transcript_enabled = False # Default to disabled
211
  # Use a thread-safe queue for cross-thread communication
212
  self._audio_queue = queue.Queue(maxsize=200)
213
+ self.transcription_tuple = ("", "", "") # For 3 streams
214
  self.is_running = False
215
  self.status_message = "ready"
216
  self.word_timestamps = []
 
221
  self._task = None # Track the async task
222
  self._stop_event = None # Event to signal stop
223
  self._stopped_by_user = False # Track if user explicitly stopped
224
+ self.new_color_open = '<span style="color: #FFA500";>'
225
+ self.new_color_close = "</span>"
226
+
227
+ # Enhanced event tracking
228
+ self.stream_events = {
229
+ 'stream_1': [], # List of (timestamp, event_type, event_data) tuples
230
+ 'stream_2': [] # List of (timestamp, event_type, event_data) tuples
231
+ }
232
+ self.last_event_timestamp = None
233
 
234
  @property
235
  def audio_queue(self):
 
239
  def reset_queue(self):
240
  """Reset the audio queue."""
241
  self._audio_queue = queue.Queue(maxsize=200)
242
+
243
def get_event_summary(self):
    """Return per-stream event summaries plus aggregate statistics.

    Each summarized event carries its timestamp, its type, and every other
    recorded field grouped under 'data'.
    """
    def _describe(event):
        # Everything that is not the timestamp/type goes into 'data'.
        extras = {key: value for key, value in event.items()
                  if key not in ('timestamp', 'type')}
        return {
            'timestamp': event.get('timestamp', 0),
            'type': event.get('type', 'unknown'),
            'data': extras,
        }

    events_1 = self.stream_events['stream_1']
    events_2 = self.stream_events['stream_2']
    return {
        'stream_1': [_describe(ev) for ev in events_1],
        'stream_2': [_describe(ev) for ev in events_2],
        'stats': {
            'stream_1_count': len(events_1),
            'stream_2_count': len(events_2),
            'last_event_time': self.last_event_timestamp,
            'total_events': len(events_1) + len(events_2),
        },
    }
265
+
266
def clear_events(self):
    """Drop all recorded stream events and reset the merged transcript state."""
    self.stream_events = {'stream_1': [], 'stream_2': []}
    self.last_event_timestamp = None
    # Empty triple matches the initial (fast, slow, merged) shape.
    self.transcription_tuple = ("", "", "")
274
+
275
+ @staticmethod
276
+ def _normalize_word(word: str) -> str:
277
+ return word.strip(".,!?;:\"'()[]{}").lower()
278
+
279
+ def _compute_display_texts(self, slow_text, fast_text) -> tuple[str, str]:
280
+ slow_words = slow_text.split()
281
+ fast_words = fast_text.split()
282
+
283
+ if not slow_words:
284
+ partial_text = f" {fast_text}".rstrip()
285
+ return "", partial_text
286
+
287
+ slow_norm = [self._normalize_word(word) for word in slow_words]
288
+ fast_norm = [self._normalize_word(word) for word in fast_words]
289
+
290
+ matcher = difflib.SequenceMatcher(None, slow_norm, fast_norm)
291
+ last_fast_index = 0
292
+ slow_progress = 0
293
+ for block in matcher.get_matching_blocks():
294
+ if block.size == 0:
295
+ continue
296
+ slow_end = block.a + block.size
297
+ if slow_end > slow_progress:
298
+ slow_progress = slow_end
299
+ last_fast_index = block.b + block.size
300
+
301
+ if last_fast_index < len(fast_words):
302
+ ahead_words = fast_words[last_fast_index:]
303
+ partial_text = " " + " ".join(ahead_words) if ahead_words else ""
304
+ else:
305
+ partial_text = ""
306
+
307
+ return slow_text, partial_text
308
+
309
def reconstruct_transcription(self):
    """Rebuild the (fast, slow, merged) transcript tuple from stored events.

    Stream 1 is the low-latency preview, stream 2 the delayed/final text;
    the third element is the final text with the not-yet-confirmed preview
    suffix wrapped in the highlight color span.
    """
    def _joined(stream_name):
        # Concatenate every text_delta recorded for the stream, in order.
        return "".join(
            ev.get('text', '')
            for ev in self.stream_events[stream_name]
            if ev.get('type') == 'text_delta'
        )

    preview_text = _joined('stream_1')
    final_text = _joined('stream_2')

    # Slow stream is authoritative; the preview only contributes its tail.
    confirmed, pending = self._compute_display_texts(final_text, preview_text)
    merged = f"{confirmed}{self.new_color_open}{pending}{self.new_color_close}"

    return (preview_text, final_text, merged)
332
 
333
 
334
  # Load CSS from external file
 
369
  return f"""<div class="status-badge {css_class}"><span class="status-dot{dot_anim}"></span><span style="color: inherit !important;">{label}</span></div>"""
370
 
371
 
372
+ def get_transcription_html(transcripts: tuple, status: str, wpm: str = "Calibrating...", partial_transcript_enabled: bool = False) -> str:
373
  """Generate the full transcription card HTML."""
374
  status_badge = get_status_html(status)
375
  wpm_badge = f'<div class="wpm-badge"><span style="color: #1E1E1E !important;">{wpm}</span></div>'
376
 
377
+ if transcripts:
378
+ # Check if partial transcript is enabled and we have 3 streams
379
+ if partial_transcript_enabled and len(transcripts) >= 3 and transcripts[0] and transcripts[1] and transcripts[2]:
380
+ # Split into three streams
381
+ stream1_content, stream2_content, stream3_content = transcripts
382
+
383
+ cursor_html = '<span class="transcript-cursor"></span>' if status == "listening" else ""
384
+ content_html = f"""
385
+ <div class="triple-stream-container">
386
+ <div class="stream-box">
387
+ <div class="stream-label">Stream 1 (Preview - 240ms)</div>
388
+ <div class="transcript-text" style="color: #000000 !important;">
389
+ {stream1_content}{cursor_html}
390
+ </div>
391
+ </div>
392
+ <div class="stream-box">
393
+ <div class="stream-label">Stream 2 (Final - 2.4s)</div>
394
+ <div class="transcript-text" style="color: #000000 !important;">
395
+ {stream2_content}{cursor_html}
396
+ </div>
397
+ </div>
398
+ <div class="stream-box">
399
+ <div class="stream-label">Stream 3 (Merged)</div>
400
+ <div class="transcript-text" style="color: #000000 !important;">
401
+ {stream3_content}{cursor_html}
402
+ </div>
403
+ </div>
404
+ </div>
405
+ """
406
+ # Check if we have 3 streams (backward compatibility for when partial transcript is disabled)
407
+ elif len(transcripts) >= 3 and transcripts[0] and transcripts[1] and transcripts[2]:
408
+ # Show only the merged stream when partial transcript is disabled
409
+ stream3_content = transcripts[2]
410
+
411
+ cursor_html = '<span class="transcript-cursor"></span>' if status == "listening" else ""
412
+ content_html = f"""
413
+ <div class="transcript-text" style="color: #000000 !important;">
414
+ {stream3_content}{cursor_html}
415
+ </div>
416
+ """
417
+ # Check if transcript contains both streams (backward compatibility)
418
+ elif transcripts[0] and transcripts[1]:
419
+ # Split the transcript into two streams
420
+ stream1_content, stream2_content = transcripts
421
+
422
+ cursor_html = '<span class="transcript-cursor"></span>' if status == "listening" else ""
423
+ content_html = f"""
424
+ <div class="dual-stream-container">
425
+ <div class="stream-box">
426
+ <div class="stream-label">Stream 1</div>
427
+ <div class="transcript-text" style="color: #000000 !important;">
428
+ {stream1_content}{cursor_html}
429
+ </div>
430
+ </div>
431
+ <div class="stream-box">
432
+ <div class="stream-label">Stream 2</div>
433
+ <div class="transcript-text" style="color: #000000 !important;">
434
+ {stream2_content}{cursor_html}
435
+ </div>
436
+ </div>
437
+ </div>
438
+ """
439
+ else:
440
+ # Single stream (backward compatibility)
441
+ cursor_html = '<span class="transcript-cursor"></span>' if status == "listening" else ""
442
+ content_html = f"""
443
+ <div class="transcript-text" style="color: #000000 !important;">
444
+ {transcripts[0]}{cursor_html}
445
+ </div>
446
+ """
447
  elif status in ["listening", "warming", "connecting"]:
448
  content_html = """
449
  <div class="empty-state">
 
577
  continue
578
 
579
 
580
class AudioStreamDuplicator:
    """Fan one session audio stream out to multiple independent consumers.

    Chunks are appended to a shared buffer (by an external filler task) and
    each registered consumer keeps its own read position, so every consumer
    sees every chunk.

    NOTE(review): `buffer` grows for the whole session lifetime — chunks are
    never evicted once all consumers have read them. Acceptable for short
    sessions; revisit if sessions can run long.
    """

    def __init__(self, session):
        self.session = session
        self.consumers = []           # registered consumer ids
        self.buffer = []              # shared chunk buffer (bytes per entry)
        self.consumer_positions = {}  # consumer_id -> next buffer index to read
        self.lock = asyncio.Lock()    # guards buffer/consumer_positions

    async def add_consumer(self):
        """Register a new consumer and return its async audio-bytes stream."""
        consumer_id = len(self.consumers)
        self.consumers.append(consumer_id)
        self.consumer_positions[consumer_id] = 0  # start from the beginning
        return self._create_consumer_stream(consumer_id)

    async def _create_consumer_stream(self, consumer_id):
        """Yield warmup silence, then every buffered chunk, until the session stops."""
        # Warmup: stream WARMUP_DURATION seconds of PCM16 silence in 100ms chunks.
        num_samples = int(SAMPLE_RATE * WARMUP_DURATION)
        silence = np.zeros(num_samples, dtype=np.int16)
        chunk_size = int(SAMPLE_RATE * 0.1)  # 100ms chunks

        for i in range(0, num_samples, chunk_size):
            if not self.session.is_running:
                return
            yield silence[i:i + chunk_size].tobytes()
            await asyncio.sleep(0.05)

        # Main loop: drain the shared buffer from this consumer's position.
        while self.session.is_running:
            # Inactivity timeout: no audio received for too long.
            if self.session.last_audio_time is not None:
                idle = time.time() - self.session.last_audio_time
                if idle >= INACTIVITY_TIMEOUT:
                    self.session.is_running = False
                    self.session.status_message = "ready"
                    return

            # Hard cap on total session duration.
            if self.session.session_start_time is not None:
                elapsed = time.time() - self.session.session_start_time
                if elapsed >= SESSION_TIMEOUT:
                    self.session.is_running = False
                    self.session.status_message = "timeout"
                    return

            # Explicit stop requested by the user.
            if self.session._stop_event and self.session._stop_event.is_set():
                return

            # Read under the lock, but yield/sleep OUTSIDE it. The previous
            # version held the lock across `yield` (suspending the generator
            # while locked) and across the 50ms sleep, starving the buffer
            # filler task and the other consumer.
            audio_bytes = None
            async with self.lock:
                position = self.consumer_positions[consumer_id]
                if position < len(self.buffer):
                    audio_bytes = self.buffer[position]
                    self.consumer_positions[consumer_id] = position + 1

            if audio_bytes is not None:
                yield audio_bytes
            else:
                # No new audio yet; back off briefly.
                await asyncio.sleep(0.05)
644
+
645
+
646
async def audio_stream_duplicator_from_queue(session):
    """Create an AudioStreamDuplicator fed from the session's audio queue.

    Spawns a background task that drains the session's thread-safe queue
    (base64-encoded PCM16 chunks) into the duplicator's shared buffer, from
    which each consumer stream reads independently.
    """
    duplicator = AudioStreamDuplicator(session)

    async def fill_buffer():
        # Drain the cross-thread queue into the shared buffer until the
        # session stops.
        while session.is_running:
            try:
                b64_chunk = session.audio_queue.get_nowait()
            except queue.Empty:
                # No audio available; yield control briefly.
                await asyncio.sleep(0.05)
                continue
            # Decode base64 to raw PCM16 bytes; all consumers see this chunk.
            audio_bytes = base64.b64decode(b64_chunk)
            async with duplicator.lock:
                duplicator.buffer.append(audio_bytes)

    # Keep a strong reference to the task: the event loop holds only weak
    # references to tasks, so a bare create_task() result can be garbage
    # collected mid-flight and silently stop filling the buffer.
    duplicator._fill_task = asyncio.create_task(fill_buffer())

    return duplicator
671
+
672
+
673
  async def mistral_transcription_handler(session):
674
+ """Connect to Mistral realtime API and handle transcription with 2 parallel streams."""
675
  try:
676
  if not session.api_key:
677
  session.status_message = "error"
 
684
 
685
  session.status_message = "connecting"
686
 
 
 
 
687
  print(f"Session {session.session_id[:8]}: Connecting to Mistral realtime API...")
688
 
689
+ # Create a duplicator that can serve multiple audio streams
690
+ duplicator = await audio_stream_duplicator_from_queue(session)
691
+ print(f"Session {session.session_id[:8]}: Created audio stream duplicator for parallel processing")
692
+
693
+ # Create separate audio streams from the duplicator
694
+ audio_stream_1 = await duplicator.add_consumer()
695
+ audio_stream_2 = await duplicator.add_consumer()
696
+ print(f"Session {session.session_id[:8]}: Created 2 separate audio streams from duplicator")
697
+
698
+ # Create tasks for both transcription streams
699
+ async def process_stream_1():
700
+ async for event_1 in client.audio.realtime.transcribe_stream(
701
+ audio_stream=audio_stream_1,
702
+ model=MODEL,
703
+ audio_format=audio_format,
704
+ target_streaming_delay_ms=240
705
+ ):
706
+ if not session.is_running:
707
+ break
708
 
709
+ current_time = time.time()
 
 
 
710
 
711
+ if isinstance(event_1, RealtimeTranscriptionSessionCreated):
712
+ event_data = {
713
+ 'type': 'session_created',
714
+ 'timestamp': current_time,
715
+ 'session_id': event_1.session_id if hasattr(event_1, 'session_id') else None
716
+ }
717
+ session.stream_events['stream_1'].append(event_data)
718
+ session.last_event_timestamp = current_time
719
+ print(f"Session {session.session_id[:8]}: Stream 1 connected to Mistral - {current_time:.3f}")
720
+
721
+ elif isinstance(event_1, TranscriptionStreamTextDelta):
722
+ delta = event_1.text
723
+
724
+ # Get current full text by reconstructing from events
725
+ current_full_text = ""
726
+ for e in session.stream_events['stream_1']:
727
+ if e.get('type') == 'text_delta':
728
+ current_full_text += e.get('text', '')
729
+ current_full_text += delta
730
+
731
+ event_data = {
732
+ 'type': 'text_delta',
733
+ 'timestamp': current_time,
734
+ 'text': delta,
735
+ 'full_text': current_full_text
736
+ }
737
+ session.stream_events['stream_1'].append(event_data)
738
+ session.last_event_timestamp = current_time
739
+ print(f'1 [{current_time:.3f}]', delta, end="", flush=True)
740
+
741
+ words = delta.split()
742
+ for _ in words:
743
+ session.word_timestamps.append(time.time())
744
+
745
+ session.current_wpm = calculate_wpm(session)
746
+
747
+ elif isinstance(event_1, TranscriptionStreamDone):
748
+ event_data = {
749
+ 'type': 'stream_done',
750
+ 'timestamp': current_time
751
+ }
752
+ session.stream_events['stream_1'].append(event_data)
753
+ session.last_event_timestamp = current_time
754
+ print(f"Session {session.session_id[:8]}: Stream 1 transcription done - {current_time:.3f}")
755
+ break
756
+
757
+ elif isinstance(event_1, RealtimeTranscriptionError):
758
+ event_data = {
759
+ 'type': 'error',
760
+ 'timestamp': current_time,
761
+ 'error': str(event_1.error)
762
+ }
763
+ session.stream_events['stream_1'].append(event_data)
764
+ session.last_event_timestamp = current_time
765
+ print(f"Session {session.session_id[:8]}: Stream 1 error - {event_1.error} - {current_time:.3f}")
766
+ break
767
+
768
+ elif isinstance(event_1, UnknownRealtimeEvent):
769
+ event_data = {
770
+ 'type': 'unknown_event',
771
+ 'timestamp': current_time,
772
+ 'event': str(event_1)
773
+ }
774
+ session.stream_events['stream_1'].append(event_data)
775
+ session.last_event_timestamp = current_time
776
+ continue # Ignore unknown events
777
+
778
+ async def process_stream_2():
779
+ async for event_2 in client.audio.realtime.transcribe_stream(
780
+ audio_stream=audio_stream_2,
781
+ model=MODEL,
782
+ audio_format=audio_format,
783
+ target_streaming_delay_ms=2400
784
+ ):
785
+ if not session.is_running:
786
+ break
787
+
788
+ current_time = time.time()
789
 
790
+ if isinstance(event_2, RealtimeTranscriptionSessionCreated):
791
+ event_data = {
792
+ 'type': 'session_created',
793
+ 'timestamp': current_time,
794
+ 'session_id': event_2.session_id if hasattr(event_2, 'session_id') else None
795
+ }
796
+ session.stream_events['stream_2'].append(event_data)
797
+ session.last_event_timestamp = current_time
798
+ print(f"Session {session.session_id[:8]}: Stream 2 connected to Mistral - {current_time:.3f}")
799
+
800
+ elif isinstance(event_2, TranscriptionStreamTextDelta):
801
+ delta = event_2.text
802
+
803
+ # Get current full text by reconstructing from events
804
+ current_full_text = ""
805
+ for e in session.stream_events['stream_2']:
806
+ if e.get('type') == 'text_delta':
807
+ current_full_text += e.get('text', '')
808
+ current_full_text += delta
809
+
810
+ event_data = {
811
+ 'type': 'text_delta',
812
+ 'timestamp': current_time,
813
+ 'text': delta,
814
+ 'full_text': current_full_text
815
+ }
816
+ session.stream_events['stream_2'].append(event_data)
817
+ session.last_event_timestamp = current_time
818
+ print(f'2 [{current_time:.3f}]', delta, end="", flush=True)
819
+
820
+ session.current_wpm = calculate_wpm(session)
821
+
822
+ elif isinstance(event_2, TranscriptionStreamDone):
823
+ event_data = {
824
+ 'type': 'stream_done',
825
+ 'timestamp': current_time
826
+ }
827
+ session.stream_events['stream_2'].append(event_data)
828
+ session.last_event_timestamp = current_time
829
+ print(f"Session {session.session_id[:8]}: Stream 2 transcription done - {current_time:.3f}")
830
+ break
831
+
832
+ elif isinstance(event_2, RealtimeTranscriptionError):
833
+ event_data = {
834
+ 'type': 'error',
835
+ 'timestamp': current_time,
836
+ 'error': str(event_2.error)
837
+ }
838
+ session.stream_events['stream_2'].append(event_data)
839
+ session.last_event_timestamp = current_time
840
+ print(f"Session {session.session_id[:8]}: Stream 2 error - {event_2.error} - {current_time:.3f}")
841
+ break
842
+
843
+ elif isinstance(event_2, UnknownRealtimeEvent):
844
+ event_data = {
845
+ 'type': 'unknown_event',
846
+ 'timestamp': current_time,
847
+ 'event': str(event_2)
848
+ }
849
+ session.stream_events['stream_2'].append(event_data)
850
+ session.last_event_timestamp = current_time
851
+ continue # Ignore unknown events
852
+
853
+ # Run both streams in parallel
854
+ stream1_task = asyncio.create_task(process_stream_1())
855
+ stream2_task = asyncio.create_task(process_stream_2())
856
+
857
+ # Wait for both streams to complete
858
+ await asyncio.gather(stream1_task, stream2_task)
859
+
860
+ # Final transcription is already reconstructed from events
861
+ # Just add stats to the display
862
+ event_summary = session.get_event_summary()
863
+ stats_text = f"Events: {event_summary['stats']['total_events']} (S1: {event_summary['stats']['stream_1_count']}, S2: {event_summary['stats']['stream_2_count']})"
864
+
865
+ # Store the reconstructed transcription as tuple
866
+ session.transcription_tuple = session.reconstruct_transcription()
867
+
868
  except asyncio.CancelledError:
869
  pass # Normal cancellation
870
  except Exception as e:
 
932
  # Protect against startup races: Gradio can call `process_audio` concurrently.
933
  with session._start_lock:
934
  if session.is_running:
935
+ return get_transcription_html(session.reconstruct_transcription(), session.status_message, session.current_wpm, session.partial_transcript_enabled)
936
 
937
  # Check if API key is set
938
  if not session.api_key:
939
  session.status_message = "error"
940
+ return get_transcription_html(("Please enter your Mistral API key above to start transcription.","",""), "error", "", False)
941
 
942
  # Check if we've hit max concurrent sessions - kill all if so
943
  with _sessions_lock:
 
948
  if active_at_capacity or registry_over:
949
  kill_all_sessions()
950
  session.status_message = "error"
951
+ return get_transcription_html(("Server reset due to capacity. Please click the microphone to restart.","",""), "error", "", False)
952
 
 
953
  session.word_timestamps = []
954
  session.current_wpm = "Calibrating..."
955
  session.session_start_time = time.time()
956
  session.last_audio_time = time.time()
957
  session.status_message = "connecting"
958
+ session.stream_events = {
959
+ 'stream_1': [],
960
+ 'stream_2': []
961
+ }
962
 
963
  # Start Mistral transcription (now non-blocking, uses shared event loop)
964
  start_transcription(session)
965
 
966
+ return get_transcription_html(session.reconstruct_transcription(), session.status_message, session.current_wpm, session.partial_transcript_enabled)
967
 
968
 
969
+ def stop_session(session_id, api_key=None, partial_transcript=False):
970
  """Stop the transcription and invalidate the session.
971
 
972
  Returns None for session_id so a fresh session is created on next recording.
973
  This prevents duplicate session issues when users stop and restart quickly.
974
  """
975
  session = ensure_session(session_id)
976
+ old_transcripts = session.reconstruct_transcription()
977
  old_wpm = session.current_wpm
978
 
979
  if session.is_running:
 
1009
 
1010
  # Return None for session_id - a fresh session will be created on next recording
1011
  # This ensures no duplicate sessions when users stop/start quickly
1012
+ return get_transcription_html(old_transcripts, "ready", old_wpm, partial_transcript), None
1013
 
1014
 
1015
  async def _set_stop_event(event):
 
1017
  event.set()
1018
 
1019
 
1020
+ def clear_history(session_id, api_key=None, partial_transcript=False):
1021
  """Stop the transcription and clear all history."""
1022
  session = ensure_session(session_id)
1023
  session.is_running = False
 
1047
  # Reset the queue
1048
  session.reset_queue()
1049
 
1050
+ # Clear event history
1051
+ session.clear_events()
1052
+
1053
  session.word_timestamps = []
1054
  session.current_wpm = "Calibrating..."
1055
  session.session_start_time = None
1056
  session.status_message = "ready"
1057
+ session.stream_events = {
1058
+ 'stream_1': [],
1059
+ 'stream_2': []
1060
+ }
1061
 
1062
  # Return the session_id to maintain state
1063
+ return get_transcription_html(("",), "ready", "Calibrating...", False), None, session.session_id
1064
 
1065
 
1066
+ def process_audio(audio, session_id, api_key, partial_transcript=False):
1067
  """Process incoming audio and queue for streaming."""
1068
  # Check capacity - if at or above max, kill ALL sessions to reset
1069
  with _sessions_lock:
 
1080
  if registry_count > MAX_CONCURRENT_SESSIONS or active_count > MAX_CONCURRENT_SESSIONS or (active_count >= MAX_CONCURRENT_SESSIONS and not is_active_user):
1081
  kill_all_sessions()
1082
  return get_transcription_html(
1083
+ ("Server reset due to capacity. Please click the microphone to restart.","",""),
1084
  "error",
1085
+ "",
1086
+ False
1087
  ), None
1088
 
1089
  # Check if API key is provided
1090
  if not api_key or not api_key.strip():
1091
+ # return get_transcription_html(
1092
+ # ("Please enter your Mistral API key above to start transcription.","",""),
1093
+ # "error",
1094
+ # ""
1095
+ # ), None
1096
+ api_key = DEFAULT_API_KEY
1097
 
1098
  # Always ensure we have a valid session first
1099
  try:
1100
  session = ensure_session(session_id)
1101
  # Update API key on the session
1102
  session.api_key = api_key.strip()
1103
+ # Store partial transcript preference on the session
1104
+ session.partial_transcript_enabled = partial_transcript
1105
  except Exception as e:
1106
  print(f"Error creating session: {e}")
1107
  # Create a fresh session if ensure_session fails
1108
  session = UserSession(api_key=api_key.strip())
1109
+ session.partial_transcript_enabled = partial_transcript
1110
  _session_registry[session.session_id] = session
1111
 
1112
  # Cache session_id early in case of later errors
 
1116
  # Quick return if audio is None
1117
  if audio is None:
1118
  wpm = session.current_wpm if session.is_running else "Calibrating..."
1119
+ return get_transcription_html(session.reconstruct_transcription(), session.status_message, wpm, session.partial_transcript_enabled), current_session_id
1120
 
1121
  # Update last audio time for inactivity tracking
1122
  session.last_audio_time = time.time()
 
1127
 
1128
  # Skip processing if session stopped
1129
  if not session.is_running:
1130
+ return get_transcription_html(session.reconstruct_transcription(), session.status_message, session.current_wpm, session.partial_transcript_enabled), current_session_id
1131
 
1132
  sample_rate, audio_data = audio
1133
 
 
1160
  except Exception:
1161
  pass # Skip if queue is full
1162
 
1163
+ return get_transcription_html(session.reconstruct_transcription(), session.status_message, session.current_wpm, session.partial_transcript_enabled), current_session_id
1164
  except Exception as e:
1165
  print(f"Error processing audio: {e}")
1166
  # Return safe defaults - always include session_id to maintain state
1167
+ return get_transcription_html(("",), "error", "", False), current_session_id
 
 
 
1168
 
1169
  # Gradio interface
1170
  with gr.Blocks(title="Voxtral Real-time Transcription") as demo:
 
1174
  # Header
1175
  gr.HTML(get_header_html())
1176
 
1177
+ # API Key input with partial transcript checkbox
1178
  with gr.Row():
1179
  api_key_input = gr.Textbox(
1180
+ label="Mistral API Key (optional)",
1181
+ placeholder="Enter your own Mistral API key if you encounter issues.",
1182
  type="password",
1183
  elem_id="api-key-input",
1184
+ info="Get your API key from console.mistral.ai",
1185
+ scale=4
1186
+ )
1187
+ partial_transcript_checkbox = gr.Checkbox(
1188
+ label="Partial Transcript",
1189
+ info="Enable to show 2 streams + merged output",
1190
+ value=False,
1191
+ elem_id="partial-transcript-checkbox",
1192
+ scale=1
1193
  )
1194
 
1195
  # Transcription output
1196
  transcription_display = gr.HTML(
1197
+ value=get_transcription_html(("","",""), "ready", "Calibrating...", False),
1198
  elem_id="transcription-output"
1199
  )
1200
 
 
1220
  # Event handlers
1221
  clear_btn.click(
1222
  clear_history,
1223
+ inputs=[session_state, api_key_input, partial_transcript_checkbox],
1224
  outputs=[transcription_display, audio_input, session_state]
1225
  )
1226
 
1227
 
1228
  audio_input.stop_recording(
1229
  stop_session,
1230
+ inputs=[session_state, api_key_input, partial_transcript_checkbox],
1231
  outputs=[transcription_display, session_state]
1232
  )
1233
 
1234
  audio_input.stream(
1235
  process_audio,
1236
+ inputs=[audio_input, session_state, api_key_input, partial_transcript_checkbox],
1237
  outputs=[transcription_display, session_state],
1238
  show_progress="hidden",
1239
  concurrency_limit=500,
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
  gradio>=4.0.0
2
  websockets
3
  numpy
4
- mistralai[realtime]
 
1
  gradio>=4.0.0
2
  websockets
3
  numpy
4
+ mistralai[realtime]>=1.12.3
style.css CHANGED
@@ -191,6 +191,57 @@ body, .gradio-container {
191
  color: #000000 !important;
192
  white-space: pre-wrap;
193
  word-break: break-word;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  }
195
 
196
  .transcript-cursor {
@@ -283,6 +334,27 @@ footer {
283
  display: none !important;
284
  }
285
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
286
  .gradio-container .prose {
287
  max-width: none !important;
288
  }
@@ -294,4 +366,4 @@ footer {
294
  font-style: italic;
295
  text-align: center;
296
  margin-top: 1rem;
297
- }
 
191
  color: #000000 !important;
192
  white-space: pre-wrap;
193
  word-break: break-word;
194
+ text-align: left !important;
195
+ margin: 0 !important;
196
+ padding: 0 !important;
197
+ }
198
+
199
+ /* Fix for first text chunk alignment */
200
+ .transcript-text:first-letter {
201
+ margin-left: 0 !important;
202
+ }
203
+
204
+ .transcript-text::first-line {
205
+ text-indent: 0 !important;
206
+ }
207
+
208
+ .dual-stream-container {
209
+ display: grid;
210
+ grid-template-columns: 1fr 1fr;
211
+ gap: 1rem;
212
+ height: 100%;
213
+ text-align: left !important;
214
+ }
215
+
216
+ .triple-stream-container {
217
+ display: grid;
218
+ grid-template-columns: 1fr 1fr 1fr;
219
+ gap: 1rem;
220
+ height: 100%;
221
+ text-align: left !important;
222
+ }
223
+
224
+ .stream-box {
225
+ background: rgba(255, 255, 255, 0.6) !important;
226
+ border: 1px solid #E9E2CB;
227
+ border-radius: 4px;
228
+ padding: 0.75rem;
229
+ height: 100%;
230
+ overflow-y: auto;
231
+ text-align: left !important;
232
+ }
233
+
234
+ .stream-label {
235
+ font-family: 'JetBrains Mono', monospace !important;
236
+ font-size: 0.75rem !important;
237
+ font-weight: 700 !important;
238
+ color: #FF8205 !important;
239
+ text-transform: uppercase;
240
+ letter-spacing: 0.05em;
241
+ margin-bottom: 0.5rem;
242
+ padding-bottom: 0.25rem;
243
+ border-bottom: 1px solid #FF8205;
244
+ text-align: left !important;
245
  }
246
 
247
  .transcript-cursor {
 
334
  display: none !important;
335
  }
336
 
337
+ /* Partial transcript checkbox styling */
338
+ #partial-transcript-checkbox {
339
+ display: flex;
340
+ align-items: center;
341
+ justify-content: center;
342
+ margin-left: 1rem;
343
+ }
344
+
345
+ #partial-transcript-checkbox .gradio-checkbox {
346
+ transform: scale(1.2);
347
+ }
348
+
349
+ #partial-transcript-checkbox label {
350
+ font-family: 'JetBrains Mono', monospace !important;
351
+ font-size: 0.85rem !important;
352
+ font-weight: 600 !important;
353
+ color: #1E1E1E !important;
354
+ text-transform: uppercase;
355
+ letter-spacing: 0.05em;
356
+ }
357
+
358
  .gradio-container .prose {
359
  max-width: none !important;
360
  }
 
366
  font-style: italic;
367
  text-align: center;
368
  margin-top: 1rem;
369
+ }