Ubuntu committed on
Commit
5a0bcf6
·
1 Parent(s): 7ce11b0

Add WebSocket on_final mode support for faster transcription and update requirements

Browse files
Files changed (2) hide show
  1. app.py +139 -38
  2. requirements.txt +1 -0
app.py CHANGED
@@ -7,6 +7,9 @@ Real-time streaming transcription using Gradio's audio streaming.
7
  import os
8
  import tempfile
9
  from pathlib import Path
 
 
 
10
 
11
  import gradio as gr
12
  import requests
@@ -14,6 +17,13 @@ import numpy as np
14
  import soundfile as sf
15
  from dotenv import load_dotenv
16
 
 
 
 
 
 
 
 
17
  try:
18
  import librosa
19
  HAS_LIBROSA = True
@@ -87,6 +97,80 @@ class RinggSTTClient:
87
  print(f"Transcription error: {e}")
88
  return ""
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  def transcribe_file(self, audio_file_path: str, language: str = "hi") -> str:
91
  """Transcribe audio file via multipart upload API"""
92
  try:
@@ -141,12 +225,8 @@ def resample_audio(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarra
141
 
142
  def transcribe_stream(audio, language, audio_buffer, last_transcription, samples_processed):
143
  """
144
- Process streaming audio from microphone.
145
-
146
- Simplified approach:
147
- - Accumulate ALL audio chunks
148
- - When we have enough new audio, transcribe the ENTIRE recording
149
- - Display the complete transcription (backend handles everything)
150
  """
151
  # Initialize states
152
  if audio_buffer is None:
@@ -183,22 +263,35 @@ def transcribe_stream(audio, language, audio_buffer, last_transcription, samples
183
  total_samples = sum(len(arr) for arr in audio_buffer)
184
  total_duration = total_samples / sample_rate
185
 
186
- # Calculate new audio since last transcription
187
- new_samples = total_samples - samples_processed
188
- new_duration = new_samples / sample_rate
189
-
190
- # Only transcribe if we have enough NEW audio (to avoid too frequent API calls)
191
- if new_duration < MIN_AUDIO_LENGTH:
192
- display = last_transcription if last_transcription else f"🎤 Recording... ({total_duration:.1f}s)"
193
- return display, audio_buffer, last_transcription, samples_processed
 
 
 
 
 
194
 
195
  try:
196
  # Concatenate ALL buffered audio
197
  full_audio = np.concatenate(audio_buffer)
198
 
 
 
 
 
 
 
 
199
  # Resample to 16kHz if needed
200
  if sample_rate != TARGET_SAMPLE_RATE:
201
  full_audio = resample_audio(full_audio, sample_rate, TARGET_SAMPLE_RATE)
 
202
 
203
  # Normalize audio
204
  max_val = np.max(np.abs(full_audio))
@@ -208,32 +301,31 @@ def transcribe_stream(audio, language, audio_buffer, last_transcription, samples
208
  # Get language code
209
  lang_code = "hi" if language == "Hindi" else "en"
210
 
211
- # Transcribe the ENTIRE audio
212
- transcription = stt_client.transcribe_audio_data(
213
- full_audio.astype(np.float32),
214
- TARGET_SAMPLE_RATE,
215
- lang_code
 
 
216
  )
217
 
218
  # Update state
219
- if transcription.strip():
220
  last_transcription = transcription
221
-
222
- # Mark all current samples as processed
223
- samples_processed = total_samples
224
-
225
- display = last_transcription if last_transcription else f"🎤 Recording... ({total_duration:.1f}s)"
226
- return display, audio_buffer, last_transcription, samples_processed
227
 
228
  except Exception as e:
229
  print(f"Processing error: {e}")
230
- display = last_transcription if last_transcription else "🎤 Listening..."
231
- return display, audio_buffer, last_transcription, samples_processed
232
 
233
 
234
  def clear_transcription():
235
  """Clear all transcription state"""
236
- return "🎤 Click microphone to start...", None, "", 0
237
 
238
 
239
  def transcribe_file(audio_file, language):
@@ -270,16 +362,16 @@ def create_interface():
270
 
271
  # Real-time streaming section
272
  gr.Markdown("""
273
- ## 🎤 Real-time Transcription
274
- Click the microphone to start recording. Transcription updates as you speak.
275
 
276
- *The entire recording is transcribed each time, so text may refine as more context is added.*
277
  """)
278
 
279
  # States for streaming
280
  audio_buffer = gr.State(None)
281
  last_transcription = gr.State("")
282
- samples_processed = gr.State(0)
283
 
284
  with gr.Row():
285
  with gr.Column(scale=1):
@@ -294,7 +386,9 @@ def create_interface():
294
  streaming=True,
295
  label="🎤 Click to start recording",
296
  )
297
- clear_btn = gr.Button("🗑️ Clear & Reset", variant="secondary")
 
 
298
 
299
  with gr.Column(scale=2):
300
  text_output = gr.Textbox(
@@ -304,18 +398,25 @@ def create_interface():
304
  interactive=False,
305
  )
306
 
307
- # Wire up streaming
308
  audio_input.stream(
309
  fn=transcribe_stream,
310
- inputs=[audio_input, stream_language, audio_buffer, last_transcription, samples_processed],
311
- outputs=[text_output, audio_buffer, last_transcription, samples_processed],
 
 
 
 
 
 
 
312
  )
313
 
314
  # Clear button
315
  clear_btn.click(
316
  fn=clear_transcription,
317
  inputs=[],
318
- outputs=[text_output, audio_buffer, last_transcription, samples_processed],
319
  )
320
 
321
  gr.Markdown("<br>")
 
7
  import os
8
  import tempfile
9
  from pathlib import Path
10
+ import json
11
+ import struct
12
+ import asyncio
13
 
14
  import gradio as gr
15
  import requests
 
17
  import soundfile as sf
18
  from dotenv import load_dotenv
19
 
20
+ try:
21
+ import websockets
22
+ HAS_WEBSOCKETS = True
23
+ except ImportError:
24
+ HAS_WEBSOCKETS = False
25
+ print("⚠️ websockets not installed. Install with: pip install websockets")
26
+
27
  try:
28
  import librosa
29
  HAS_LIBROSA = True
 
97
  print(f"Transcription error: {e}")
98
  return ""
99
 
100
+ async def transcribe_websocket_on_final(self, audio_data: np.ndarray, sample_rate: int, language: str = "hi") -> str:
101
+ """Transcribe audio via WebSocket on_final endpoint"""
102
+ if not HAS_WEBSOCKETS:
103
+ return "❌ websockets library not installed"
104
+
105
+ try:
106
+ # Convert HTTP endpoint to WebSocket
107
+ ws_endpoint = self.api_endpoint.replace("http://", "ws://").replace("https://", "wss://")
108
+ ws_url = f"{ws_endpoint}/v1/audio/stream"
109
+
110
+ # Convert audio to int16 PCM
111
+ audio_int16 = (audio_data * 32767).astype(np.int16)
112
+ audio_bytes = audio_int16.tobytes()
113
+
114
+ # Chunk size for streaming (send in 1 second chunks)
115
+ chunk_size = sample_rate * 2 # 2 bytes per sample (int16)
116
+
117
+ async with websockets.connect(ws_url, max_size=None) as ws:
118
+ # Send start message with on_final mode (first message must be "start")
119
+ start_msg = {
120
+ "type": "start",
121
+ "prediction_method": "on_final",
122
+ "sample_rate": sample_rate,
123
+ "encoding": "int16",
124
+ "language": "Hindi" if language == "hi" else "English",
125
+ "api_key": "gradio-client",
126
+ "punctuate": False
127
+ }
128
+ await ws.send(json.dumps(start_msg))
129
+
130
+ # Wait for ready response
131
+ ready_msg = await ws.recv()
132
+ ready_data = json.loads(ready_msg)
133
+
134
+ if ready_data.get("type") != "ready":
135
+ return f"❌ Unexpected response: {ready_data}"
136
+
137
+ print(f"✅ WebSocket ready: {ready_data}")
138
+
139
+ # Send audio in chunks
140
+ for i in range(0, len(audio_bytes), chunk_size):
141
+ chunk = audio_bytes[i:i + chunk_size]
142
+ await ws.send(chunk)
143
+
144
+ # Receive chunk acknowledgment
145
+ ack = await ws.recv()
146
+ ack_data = json.loads(ack)
147
+ if ack_data.get("type") == "chunk":
148
+ print(f"Buffered: {ack_data.get('total_buffered', 0)} samples")
149
+
150
+ # Send end signal to trigger transcription
151
+ end_msg = {"type": "end"}
152
+ await ws.send(json.dumps(end_msg))
153
+
154
+ # Receive transcription
155
+ transcription = ""
156
+ result_msg = await ws.recv()
157
+ result_data = json.loads(result_msg)
158
+
159
+ if result_data.get("type") == "transcript":
160
+ transcription = result_data.get("transcription", "")
161
+ elif result_data.get("type") == "error":
162
+ return f"❌ Error: {result_data.get('detail', 'Unknown error')}"
163
+
164
+ # Send stop to end session
165
+ stop_msg = {"type": "stop"}
166
+ await ws.send(json.dumps(stop_msg))
167
+
168
+ return transcription
169
+
170
+ except Exception as e:
171
+ print(f"WebSocket transcription error: {e}")
172
+ return f"❌ WebSocket Error: {str(e)}"
173
+
174
  def transcribe_file(self, audio_file_path: str, language: str = "hi") -> str:
175
  """Transcribe audio file via multipart upload API"""
176
  try:
 
225
 
226
  def transcribe_stream(audio, language, audio_buffer, last_transcription, samples_processed):
227
  """
228
+ Accumulate audio chunks during recording.
229
+ Just buffer the audio, don't transcribe yet.
 
 
 
 
230
  """
231
  # Initialize states
232
  if audio_buffer is None:
 
263
  total_samples = sum(len(arr) for arr in audio_buffer)
264
  total_duration = total_samples / sample_rate
265
 
266
+ # Just show recording status, don't transcribe yet
267
+ display = last_transcription if last_transcription else f"🎤 Recording... ({total_duration:.1f}s)"
268
+ return display, audio_buffer, last_transcription, sample_rate
269
+
270
+
271
+ def process_recorded_audio(audio_buffer, sample_rate, language, last_transcription):
272
+ """
273
+ Process the entire recorded audio after user stops recording.
274
+ This is called when the stop recording button is pressed.
275
+ Uses WebSocket on_final endpoint for faster transcription.
276
+ """
277
+ if audio_buffer is None or len(audio_buffer) == 0:
278
+ return "⚠️ No audio recorded", audio_buffer, last_transcription, 0
279
 
280
  try:
281
  # Concatenate ALL buffered audio
282
  full_audio = np.concatenate(audio_buffer)
283
 
284
+ # Calculate duration
285
+ total_samples = len(full_audio)
286
+ total_duration = total_samples / sample_rate
287
+
288
+ # Show processing message
289
+ print(f"Processing {total_duration:.1f}s of audio via WebSocket...")
290
+
291
  # Resample to 16kHz if needed
292
  if sample_rate != TARGET_SAMPLE_RATE:
293
  full_audio = resample_audio(full_audio, sample_rate, TARGET_SAMPLE_RATE)
294
+ sample_rate = TARGET_SAMPLE_RATE
295
 
296
  # Normalize audio
297
  max_val = np.max(np.abs(full_audio))
 
301
  # Get language code
302
  lang_code = "hi" if language == "Hindi" else "en"
303
 
304
+ # Transcribe via WebSocket on_final endpoint
305
+ transcription = asyncio.run(
306
+ stt_client.transcribe_websocket_on_final(
307
+ full_audio.astype(np.float32),
308
+ sample_rate,
309
+ lang_code
310
+ )
311
  )
312
 
313
  # Update state
314
+ if transcription and transcription.strip() and not transcription.startswith("❌"):
315
  last_transcription = transcription
316
+ return transcription, audio_buffer, last_transcription, sample_rate
317
+ else:
318
+ return transcription or "⚠️ No speech detected in the recording", audio_buffer, last_transcription, sample_rate
 
 
 
319
 
320
  except Exception as e:
321
  print(f"Processing error: {e}")
322
+ error_msg = f"❌ Error processing audio: {str(e)}"
323
+ return error_msg, audio_buffer, last_transcription, sample_rate
324
 
325
 
326
def clear_transcription():
    """Reset the UI text and all streaming state to their initial values.

    Returns a 4-tuple matching the (text_output, audio_buffer,
    last_transcription, sample_rate_state) outputs: prompt text, empty
    buffer, empty transcript, and the default 16 kHz sample rate.
    """
    initial_prompt = "🎤 Click microphone to start..."
    return initial_prompt, None, "", 16000
329
 
330
 
331
  def transcribe_file(audio_file, language):
 
362
 
363
  # Real-time streaming section
364
  gr.Markdown("""
365
+ ## 🎤 Record & Transcribe (WebSocket)
366
+ Click the microphone to start recording. Click stop when finished to get transcription.
367
 
368
+ *The entire recording will be transcribed via WebSocket on_final endpoint with TensorRT acceleration.*
369
  """)
370
 
371
  # States for streaming
372
  audio_buffer = gr.State(None)
373
  last_transcription = gr.State("")
374
+ sample_rate_state = gr.State(16000)
375
 
376
  with gr.Row():
377
  with gr.Column(scale=1):
 
386
  streaming=True,
387
  label="🎤 Click to start recording",
388
  )
389
+ with gr.Row():
390
+ stop_btn = gr.Button("⏹️ Stop & Transcribe", variant="primary", size="lg")
391
+ clear_btn = gr.Button("🗑️ Clear & Reset", variant="secondary")
392
 
393
  with gr.Column(scale=2):
394
  text_output = gr.Textbox(
 
398
  interactive=False,
399
  )
400
 
401
+ # Wire up streaming (just accumulates audio, doesn't transcribe)
402
  audio_input.stream(
403
  fn=transcribe_stream,
404
+ inputs=[audio_input, stream_language, audio_buffer, last_transcription, sample_rate_state],
405
+ outputs=[text_output, audio_buffer, last_transcription, sample_rate_state],
406
+ )
407
+
408
+ # Stop button - processes all accumulated audio
409
+ stop_btn.click(
410
+ fn=process_recorded_audio,
411
+ inputs=[audio_buffer, sample_rate_state, stream_language, last_transcription],
412
+ outputs=[text_output, audio_buffer, last_transcription, sample_rate_state],
413
  )
414
 
415
  # Clear button
416
  clear_btn.click(
417
  fn=clear_transcription,
418
  inputs=[],
419
+ outputs=[text_output, audio_buffer, last_transcription, sample_rate_state],
420
  )
421
 
422
  gr.Markdown("<br>")
requirements.txt CHANGED
@@ -5,3 +5,4 @@ requests==2.32.5
5
  huggingface-hub==1.0.1
6
  python-dotenv
7
  soundfile
 
 
5
  huggingface-hub==1.0.1
6
  python-dotenv
7
  soundfile
8
+ websockets