liuyang commited on
Commit
7cf016f
·
1 Parent(s): aaba71b

Add batched inference support in WhisperTranscriber for improved transcription performance. Update methods to accept batch size parameters and adjust output formatting accordingly.

Browse files
Files changed (1) hide show
  1. app.py +68 -40
app.py CHANGED
@@ -28,7 +28,7 @@ import subprocess
28
  import os
29
  import tempfile
30
  import spaces
31
- from faster_whisper import WhisperModel
32
  from faster_whisper.vad import VadOptions
33
  import requests
34
  import base64
@@ -64,6 +64,7 @@ model_cache_path = LOCAL_DIR # <-- this is what we pass to WhisperModel
64
 
65
  # Lazy global holder ----------------------------------------------------------
66
  _whisper = None
 
67
  _diarizer = None
68
 
69
  # Create global diarization pipeline
@@ -87,7 +88,7 @@ except Exception as e:
87
 
88
  @spaces.GPU # GPU is guaranteed to exist *inside* this function
89
  def _load_models():
90
- global _whisper, _diarizer
91
  if _whisper is None:
92
  print("Loading Whisper model...")
93
  _whisper = WhisperModel(
@@ -95,8 +96,11 @@ def _load_models():
95
  device="cuda",
96
  compute_type="float16",
97
  )
98
- print("Whisper model loaded successfully")
99
- return _whisper, _diarizer
 
 
 
100
 
101
  # -----------------------------------------------------------------------------
102
  class WhisperTranscriber:
@@ -121,18 +125,18 @@ class WhisperTranscriber:
121
  raise RuntimeError(f"Audio conversion failed: {e}")
122
 
123
  @spaces.GPU # each call gets a GPU slice
124
- def transcribe_full_audio(self, audio_path, language=None, translate=False, prompt=None):
125
- """Transcribe the entire audio file without speaker diarization"""
126
- whisper, _ = _load_models() # models live on the GPU
127
 
128
- print("Transcribing full audio...")
129
  start_time = time.time()
130
 
131
- # Prepare options
132
  options = dict(
133
  language=language,
134
  beam_size=5,
135
- vad_filter=True,
136
  vad_parameters=VadOptions(
137
  max_speech_duration_s=whisper.feature_extractor.chunk_length,
138
  min_speech_duration_ms=100,
@@ -146,8 +150,12 @@ class WhisperTranscriber:
146
  task="translate" if translate else "transcribe",
147
  )
148
 
149
- # Transcribe the entire audio
150
- segments, transcript_info = whisper.transcribe(audio_path, batch_size=24, **options)
 
 
 
 
151
  segments = list(segments)
152
 
153
  detected_language = transcript_info.language
@@ -176,9 +184,9 @@ class WhisperTranscriber:
176
  "words": words_list,
177
  "duration": float(seg.end - seg.start)
178
  })
179
- print(results)
180
  transcription_time = time.time() - start_time
181
- print(f"Full audio transcribed in {transcription_time:.2f} seconds")
182
 
183
  return results, detected_language
184
 
@@ -214,14 +222,14 @@ class WhisperTranscriber:
214
  return audio_segments
215
 
216
  @spaces.GPU # each call gets a GPU slice
217
- def transcribe_audio_segments(self, audio_segments, language=None, translate=False, prompt=None):
218
- """Transcribe multiple audio segments using faster_whisper"""
219
- whisper, diarizer = _load_models() # models live on the GPU
220
 
221
- print(f"Transcribing {len(audio_segments)} audio segments...")
222
  start_time = time.time()
223
 
224
- # Prepare options similar to replicate.py
225
  options = dict(
226
  language=language,
227
  beam_size=5,
@@ -245,8 +253,12 @@ class WhisperTranscriber:
245
  for i, segment in enumerate(audio_segments):
246
  print(f"Processing segment {i+1}/{len(audio_segments)}")
247
 
248
- # Transcribe this segment
249
- segments, transcript_info = whisper.transcribe(segment["audio_path"], batch_size=24, **options)
 
 
 
 
250
  segments = list(segments)
251
 
252
  # Get detected language from first segment
@@ -255,7 +267,7 @@ class WhisperTranscriber:
255
 
256
  # Process each transcribed segment
257
  for seg in segments:
258
- # Create result entry with detailed format like replicate.py
259
  words_list = []
260
  if seg.words:
261
  for word in seg.words:
@@ -283,14 +295,14 @@ class WhisperTranscriber:
283
  os.unlink(segment["audio_path"])
284
 
285
  transcription_time = time.time() - start_time
286
- print(f"All segments transcribed in {transcription_time:.2f} seconds")
287
 
288
  return results, detected_language
289
 
290
  @spaces.GPU # each call gets a GPU slice
291
  def perform_diarization(self, audio_path, num_speakers=None):
292
  """Perform speaker diarization"""
293
- whisper, diarizer = _load_models() # models live on the GPU
294
 
295
  if diarizer is None:
296
  print("Diarization model not available, creating single speaker segment")
@@ -376,7 +388,7 @@ class WhisperTranscriber:
376
  return grouped_segments
377
 
378
  @spaces.GPU # each call gets a GPU slice
379
- def process_audio_full(self, audio_file, language=None, translate=False, prompt=None, group_segments=True):
380
  """Process audio with full transcription (no speaker diarization)"""
381
  if audio_file is None:
382
  return {"error": "No audio file provided"}
@@ -389,9 +401,9 @@ class WhisperTranscriber:
389
  print("Converting audio format...")
390
  converted_audio_path = self.convert_audio_format(audio_file)
391
 
392
- # Step 2: Transcribe the entire audio
393
  transcription_results, detected_language = self.transcribe_full_audio(
394
- converted_audio_path, language, translate, prompt
395
  )
396
 
397
  # Step 3: Group segments if requested (based on time gaps and sentence endings)
@@ -403,7 +415,8 @@ class WhisperTranscriber:
403
  "segments": transcription_results,
404
  "language": detected_language,
405
  "num_speakers": 1, # Single speaker assumption
406
- "transcription_method": "full_audio"
 
407
  }
408
 
409
  except Exception as e:
@@ -418,7 +431,7 @@ class WhisperTranscriber:
418
 
419
  @spaces.GPU # each call gets a GPU slice
420
  def process_audio(self, audio_file, num_speakers=None, language=None,
421
- translate=False, prompt=None, group_segments=True):
422
  """Main processing function - diarization first, then transcription"""
423
  if audio_file is None:
424
  return {"error": "No audio file provided"}
@@ -439,21 +452,22 @@ class WhisperTranscriber:
439
  # Step 3: Cut audio into segments based on diarization
440
  audio_segments = self.cut_audio_segments(converted_audio_path, diarization_segments)
441
 
442
- # Step 4: Transcribe each segment
443
  transcription_results, detected_language = self.transcribe_audio_segments(
444
- audio_segments, language, translate, prompt
445
  )
446
 
447
  # Step 5: Group segments if requested
448
  if group_segments:
449
  transcription_results = self.group_segments_by_speaker(transcription_results)
450
 
451
- # Step 6: Return in replicate.py format
452
  return {
453
  "segments": transcription_results,
454
  "language": detected_language,
455
  "num_speakers": detected_num_speakers,
456
- "transcription_method": "diarized_segments"
 
457
  }
458
 
459
  except Exception as e:
@@ -478,12 +492,14 @@ def format_segments_for_display(result):
478
  language = result.get("language", "unknown")
479
  num_speakers = result.get("num_speakers", 1)
480
  method = result.get("transcription_method", "unknown")
 
481
 
482
  output = f"🎯 **Detection Results:**\n"
483
  output += f"- Language: {language}\n"
484
  output += f"- Speakers: {num_speakers}\n"
485
  output += f"- Segments: {len(segments)}\n"
486
- output += f"- Method: {method}\n\n"
 
487
 
488
  output += "📝 **Transcription:**\n\n"
489
 
@@ -499,7 +515,7 @@ def format_segments_for_display(result):
499
  return output
500
 
501
  @spaces.GPU
502
- def process_audio_gradio(audio_file, num_speakers, language, translate, prompt, group_segments, use_diarization):
503
  """Gradio interface function"""
504
  if use_diarization:
505
  result = transcriber.process_audio(
@@ -508,7 +524,8 @@ def process_audio_gradio(audio_file, num_speakers, language, translate, prompt,
508
  language=language if language != "auto" else None,
509
  translate=translate,
510
  prompt=prompt if prompt and prompt.strip() else None,
511
- group_segments=group_segments
 
512
  )
513
  else:
514
  result = transcriber.process_audio_full(
@@ -516,7 +533,8 @@ def process_audio_gradio(audio_file, num_speakers, language, translate, prompt,
516
  language=language if language != "auto" else None,
517
  translate=translate,
518
  prompt=prompt if prompt and prompt.strip() else None,
519
- group_segments=group_segments
 
520
  )
521
 
522
  formatted_output = format_segments_for_display(result)
@@ -533,7 +551,7 @@ with demo:
533
  # πŸŽ™οΈ Advanced Audio Transcription & Speaker Diarization
534
 
535
  Upload an audio file to get accurate transcription with speaker identification, powered by:
536
- - **Whisper Large V3 Turbo** with Flash Attention for fast transcription
537
  - **Pyannote 3.1** for speaker diarization
538
  - **ZeroGPU** acceleration for optimal performance
539
  """)
@@ -552,6 +570,15 @@ with demo:
552
  info="Uncheck for faster transcription without speaker identification"
553
  )
554
 
 
 
 
 
 
 
 
 
 
555
  num_speakers = gr.Slider(
556
  minimum=0,
557
  maximum=20,
@@ -613,7 +640,8 @@ with demo:
613
  translate,
614
  prompt,
615
  group_segments,
616
- use_diarization
 
617
  ],
618
  outputs=[output_text, output_json]
619
  )
@@ -622,7 +650,7 @@ with demo:
622
  gr.Markdown("### 📋 Usage Tips:")
623
  gr.Markdown("""
624
  - **Supported formats**: MP3, WAV, M4A, FLAC, OGG, and more
625
- - **Max duration**: Recommended under 10 minutes for optimal performance
626
  - **Speaker diarization**: Enable for speaker identification (slower), disable for faster transcription
627
  - **Languages**: Supports 100+ languages with auto-detection
628
  - **Vocabulary**: Add names and technical terms in the prompt for better accuracy
 
28
  import os
29
  import tempfile
30
  import spaces
31
+ from faster_whisper import WhisperModel, BatchedInferencePipeline
32
  from faster_whisper.vad import VadOptions
33
  import requests
34
  import base64
 
64
 
65
  # Lazy global holder ----------------------------------------------------------
66
  _whisper = None
67
+ _batched_whisper = None
68
  _diarizer = None
69
 
70
  # Create global diarization pipeline
 
88
 
89
  @spaces.GPU # GPU is guaranteed to exist *inside* this function
90
  def _load_models():
91
+ global _whisper, _batched_whisper, _diarizer
92
  if _whisper is None:
93
  print("Loading Whisper model...")
94
  _whisper = WhisperModel(
 
96
  device="cuda",
97
  compute_type="float16",
98
  )
99
+
100
+ # Create batched inference pipeline for improved performance
101
+ _batched_whisper = BatchedInferencePipeline(model=_whisper)
102
+ print("Whisper model and batched pipeline loaded successfully")
103
+ return _whisper, _batched_whisper, _diarizer
104
 
105
  # -----------------------------------------------------------------------------
106
  class WhisperTranscriber:
 
125
  raise RuntimeError(f"Audio conversion failed: {e}")
126
 
127
  @spaces.GPU # each call gets a GPU slice
128
+ def transcribe_full_audio(self, audio_path, language=None, translate=False, prompt=None, batch_size=16):
129
+ """Transcribe the entire audio file without speaker diarization using batched inference"""
130
+ whisper, batched_whisper, _ = _load_models() # models live on the GPU
131
 
132
+ print(f"Transcribing full audio with batch size {batch_size}...")
133
  start_time = time.time()
134
 
135
+ # Prepare options for batched inference
136
  options = dict(
137
  language=language,
138
  beam_size=5,
139
+ vad_filter=True, # VAD is enabled by default for batched transcription
140
  vad_parameters=VadOptions(
141
  max_speech_duration_s=whisper.feature_extractor.chunk_length,
142
  min_speech_duration_ms=100,
 
150
  task="translate" if translate else "transcribe",
151
  )
152
 
153
+ # Use batched inference for better performance
154
+ segments, transcript_info = batched_whisper.transcribe(
155
+ audio_path,
156
+ batch_size=batch_size,
157
+ **options
158
+ )
159
  segments = list(segments)
160
 
161
  detected_language = transcript_info.language
 
184
  "words": words_list,
185
  "duration": float(seg.end - seg.start)
186
  })
187
+
188
  transcription_time = time.time() - start_time
189
+ print(f"Full audio transcribed in {transcription_time:.2f} seconds using batch size {batch_size}")
190
 
191
  return results, detected_language
192
 
 
222
  return audio_segments
223
 
224
  @spaces.GPU # each call gets a GPU slice
225
+ def transcribe_audio_segments(self, audio_segments, language=None, translate=False, prompt=None, batch_size=8):
226
+ """Transcribe multiple audio segments using faster_whisper with batching"""
227
+ whisper, batched_whisper, _ = _load_models() # models live on the GPU
228
 
229
+ print(f"Transcribing {len(audio_segments)} audio segments with batch size {batch_size}...")
230
  start_time = time.time()
231
 
232
+ # Prepare options
233
  options = dict(
234
  language=language,
235
  beam_size=5,
 
253
  for i, segment in enumerate(audio_segments):
254
  print(f"Processing segment {i+1}/{len(audio_segments)}")
255
 
256
+ # Use batched inference for each segment
257
+ segments, transcript_info = batched_whisper.transcribe(
258
+ segment["audio_path"],
259
+ batch_size=batch_size,
260
+ **options
261
+ )
262
  segments = list(segments)
263
 
264
  # Get detected language from first segment
 
267
 
268
  # Process each transcribed segment
269
  for seg in segments:
270
+ # Create result entry with detailed format
271
  words_list = []
272
  if seg.words:
273
  for word in seg.words:
 
295
  os.unlink(segment["audio_path"])
296
 
297
  transcription_time = time.time() - start_time
298
+ print(f"All segments transcribed in {transcription_time:.2f} seconds using batch size {batch_size}")
299
 
300
  return results, detected_language
301
 
302
  @spaces.GPU # each call gets a GPU slice
303
  def perform_diarization(self, audio_path, num_speakers=None):
304
  """Perform speaker diarization"""
305
+ _, _, diarizer = _load_models() # models live on the GPU
306
 
307
  if diarizer is None:
308
  print("Diarization model not available, creating single speaker segment")
 
388
  return grouped_segments
389
 
390
  @spaces.GPU # each call gets a GPU slice
391
+ def process_audio_full(self, audio_file, language=None, translate=False, prompt=None, group_segments=True, batch_size=16):
392
  """Process audio with full transcription (no speaker diarization)"""
393
  if audio_file is None:
394
  return {"error": "No audio file provided"}
 
401
  print("Converting audio format...")
402
  converted_audio_path = self.convert_audio_format(audio_file)
403
 
404
+ # Step 2: Transcribe the entire audio with batching
405
  transcription_results, detected_language = self.transcribe_full_audio(
406
+ converted_audio_path, language, translate, prompt, batch_size
407
  )
408
 
409
  # Step 3: Group segments if requested (based on time gaps and sentence endings)
 
415
  "segments": transcription_results,
416
  "language": detected_language,
417
  "num_speakers": 1, # Single speaker assumption
418
+ "transcription_method": "full_audio_batched",
419
+ "batch_size": batch_size
420
  }
421
 
422
  except Exception as e:
 
431
 
432
  @spaces.GPU # each call gets a GPU slice
433
  def process_audio(self, audio_file, num_speakers=None, language=None,
434
+ translate=False, prompt=None, group_segments=True, batch_size=8):
435
  """Main processing function - diarization first, then transcription"""
436
  if audio_file is None:
437
  return {"error": "No audio file provided"}
 
452
  # Step 3: Cut audio into segments based on diarization
453
  audio_segments = self.cut_audio_segments(converted_audio_path, diarization_segments)
454
 
455
+ # Step 4: Transcribe each segment with batching
456
  transcription_results, detected_language = self.transcribe_audio_segments(
457
+ audio_segments, language, translate, prompt, batch_size
458
  )
459
 
460
  # Step 5: Group segments if requested
461
  if group_segments:
462
  transcription_results = self.group_segments_by_speaker(transcription_results)
463
 
464
+ # Step 6: Return results
465
  return {
466
  "segments": transcription_results,
467
  "language": detected_language,
468
  "num_speakers": detected_num_speakers,
469
+ "transcription_method": "diarized_segments_batched",
470
+ "batch_size": batch_size
471
  }
472
 
473
  except Exception as e:
 
492
  language = result.get("language", "unknown")
493
  num_speakers = result.get("num_speakers", 1)
494
  method = result.get("transcription_method", "unknown")
495
+ batch_size = result.get("batch_size", "N/A")
496
 
497
  output = f"🎯 **Detection Results:**\n"
498
  output += f"- Language: {language}\n"
499
  output += f"- Speakers: {num_speakers}\n"
500
  output += f"- Segments: {len(segments)}\n"
501
+ output += f"- Method: {method}\n"
502
+ output += f"- Batch Size: {batch_size}\n\n"
503
 
504
  output += "📝 **Transcription:**\n\n"
505
 
 
515
  return output
516
 
517
  @spaces.GPU
518
+ def process_audio_gradio(audio_file, num_speakers, language, translate, prompt, group_segments, use_diarization, batch_size):
519
  """Gradio interface function"""
520
  if use_diarization:
521
  result = transcriber.process_audio(
 
524
  language=language if language != "auto" else None,
525
  translate=translate,
526
  prompt=prompt if prompt and prompt.strip() else None,
527
+ group_segments=group_segments,
528
+ batch_size=batch_size
529
  )
530
  else:
531
  result = transcriber.process_audio_full(
 
533
  language=language if language != "auto" else None,
534
  translate=translate,
535
  prompt=prompt if prompt and prompt.strip() else None,
536
+ group_segments=group_segments,
537
+ batch_size=batch_size
538
  )
539
 
540
  formatted_output = format_segments_for_display(result)
 
551
  # πŸŽ™οΈ Advanced Audio Transcription & Speaker Diarization
552
 
553
  Upload an audio file to get accurate transcription with speaker identification, powered by:
554
+ - **Faster-Whisper Large V3 Turbo** with batched inference for optimal performance
555
  - **Pyannote 3.1** for speaker diarization
556
  - **ZeroGPU** acceleration for optimal performance
557
  """)
 
570
  info="Uncheck for faster transcription without speaker identification"
571
  )
572
 
573
+ batch_size = gr.Slider(
574
+ minimum=1,
575
+ maximum=32,
576
+ value=16,
577
+ step=1,
578
+ label="Batch Size",
579
+ info="Higher values = faster processing but more GPU memory usage. Recommended: 8-24"
580
+ )
581
+
582
  num_speakers = gr.Slider(
583
  minimum=0,
584
  maximum=20,
 
640
  translate,
641
  prompt,
642
  group_segments,
643
+ use_diarization,
644
+ batch_size
645
  ],
646
  outputs=[output_text, output_json]
647
  )
 
650
  gr.Markdown("### 📋 Usage Tips:")
651
  gr.Markdown("""
652
  - **Supported formats**: MP3, WAV, M4A, FLAC, OGG, and more
653
+ - **Batch Size**: Higher values (16-24) = faster processing but more GPU memory
654
  - **Speaker diarization**: Enable for speaker identification (slower), disable for faster transcription
655
  - **Languages**: Supports 100+ languages with auto-detection
656
  - **Vocabulary**: Add names and technical terms in the prompt for better accuracy