liuyang committed
Commit 99ff812 · 1 Parent(s): 77abe68

fast whisper

Files changed (2)
  1. app.py +118 -73
  2. requirements.txt +6 -3
app.py CHANGED
@@ -28,64 +28,38 @@ import subprocess
 import os
 import tempfile
 import spaces
-from transformers import pipeline
+from faster_whisper import WhisperModel
+from faster_whisper.vad import VadOptions
 from pyannote.audio import Pipeline
 import requests
 import base64
 
-# Install flash attention for acceleration
-'''
-try:
-    subprocess.run(
-        "pip install flash-attn --no-build-isolation",
-        env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
-        shell=True,
-        check=True
-    )
-except subprocess.CalledProcessError:
-    print("Warning: Could not install flash-attn, falling back to default attention")
-'''
-
-# Create global Whisper pipeline
-pipe = pipeline(
-    "automatic-speech-recognition",
-    model="openai/whisper-large-v3-turbo",
-    torch_dtype=torch.float16,
+# Create global Whisper model
+print("Loading Whisper model...")
+model = WhisperModel(
+    "large-v3-turbo",
     device="cuda",
-    model_kwargs={"attn_implementation": "flash_attention_2"},
-    return_timestamps=True,
+    compute_type="float16",
 )
+print("Whisper model loaded successfully")
 
 # Create global diarization pipeline
 diarization_pipe = None
 try:
     print("Loading diarization model...")
-    torch.backends.cuda.matmul.allow_tf32 = True
-    torch.backends.cudnn.allow_tf32 = True
-    torch.set_float32_matmul_precision('high')
-
     diarization_pipe = Pipeline.from_pretrained(
         "pyannote/speaker-diarization-3.1",
         use_auth_token=os.getenv("HF_TOKEN"),
         torch_dtype=torch.float16,
     ).to(torch.device("cuda"))
-    pipe.model.half()  # FP16
-
-    for m in pipe.model.modules():  # compact LSTM weights
-        if isinstance(m, torch.nn.LSTM):
-            m.flatten_parameters()
-
-    pipe.model = torch.compile(pipe.model, mode="reduce-overhead")
     print("Diarization model loaded successfully")
 except Exception as e:
-    import traceback
-    traceback.print_exc()
     print(f"Could not load diarization model: {e}")
     diarization_pipe = None
 
 class WhisperTranscriber:
     def __init__(self):
-        self.pipe = pipe  # Use global Whisper pipeline
+        self.model = model  # Use global Whisper model
         self.diarization_model = diarization_pipe  # Use global diarization pipeline
 
     def convert_audio_format(self, audio_path):
@@ -137,42 +111,65 @@ class WhisperTranscriber:
 
     @spaces.GPU
     def transcribe_audio_segments(self, audio_segments, language=None, translate=False, prompt=None):
-        """Transcribe multiple audio segments"""
+        """Transcribe multiple audio segments using faster_whisper"""
         print(f"Transcribing {len(audio_segments)} audio segments...")
         start_time = time.time()
 
-        # Prepare generation kwargs
-        generate_kwargs = {}
-        if language:
-            generate_kwargs["language"] = language
-        if translate:
-            generate_kwargs["task"] = "translate"
-        if prompt:
-            generate_kwargs["prompt_ids"] = self.pipe.tokenizer.encode(prompt)
+        # Prepare options similar to replicate.py
+        options = dict(
+            language=language,
+            beam_size=5,
+            vad_filter=True,
+            vad_parameters=VadOptions(
+                max_speech_duration_s=self.model.feature_extractor.chunk_length,
+                min_speech_duration_ms=100,
+                speech_pad_ms=100,
+                threshold=0.25,
+                neg_threshold=0.2,
+            ),
+            word_timestamps=True,
+            initial_prompt=prompt,
+            language_detection_segments=1,
+            task="translate" if translate else "transcribe",
+        )
 
         results = []
+        detected_language = None
+
         for i, segment in enumerate(audio_segments):
             print(f"Processing segment {i+1}/{len(audio_segments)}")
 
             # Transcribe this segment
-            result = self.pipe(
-                segment["audio_path"],
-                return_timestamps=True,
-                generate_kwargs=generate_kwargs,
-                chunk_length_s=30,
-                batch_size=128,
-            )
+            segments, transcript_info = self.model.transcribe(segment["audio_path"], **options)
+            segments = list(segments)
 
-            # Extract text
-            text = result["text"].strip() if "text" in result else ""
+            # Get detected language from the first segment
+            if detected_language is None:
+                detected_language = transcript_info.language
 
-            # Create result entry
-            results.append({
-                "start_time": segment["start"],
-                "end_time": segment["end"],
-                "speaker_label": segment["speaker"],
-                "text": text
-            })
+            # Process each transcribed segment
+            for seg in segments:
+                # Create result entry with detailed format like replicate.py
+                words_list = []
+                if seg.words:
+                    for word in seg.words:
+                        words_list.append({
+                            "start": float(word.start) + segment["start"],
+                            "end": float(word.end) + segment["start"],
+                            "word": word.word,
+                            "probability": word.probability,
+                            "speaker": segment["speaker"]
+                        })
+
+                results.append({
+                    "start": float(seg.start) + segment["start"],
+                    "end": float(seg.end) + segment["start"],
+                    "text": seg.text,
+                    "speaker": segment["speaker"],
+                    "avg_logprob": seg.avg_logprob,
+                    "words": words_list,
+                    "duration": float(seg.end - seg.start)
+                })
 
         # Clean up temporary files
         for segment in audio_segments:
@@ -182,7 +179,7 @@ class WhisperTranscriber:
 
         transcription_time = time.time() - start_time
         print(f"All segments transcribed in {transcription_time:.2f} seconds")
 
-        return results
+        return results, detected_language
 
     def perform_diarization(self, audio_path, num_speakers=None):
         """Perform speaker diarization"""
@@ -228,6 +225,47 @@ class WhisperTranscriber:
 
         return diarize_segments, detected_num_speakers
 
+    def group_segments_by_speaker(self, segments, max_gap=1.0, max_duration=30.0):
+        """Group consecutive segments from the same speaker"""
+        if not segments:
+            return segments
+
+        grouped_segments = []
+        current_group = segments[0].copy()
+        sentence_end_pattern = r"[.!?]+"
+
+        for segment in segments[1:]:
+            time_gap = segment["start"] - current_group["end"]
+            current_duration = current_group["end"] - current_group["start"]
+
+            # Conditions for combining segments
+            can_combine = (
+                segment["speaker"] == current_group["speaker"] and
+                time_gap <= max_gap and
+                current_duration < max_duration and
+                not re.search(sentence_end_pattern, current_group["text"][-1:])
+            )
+
+            if can_combine:
+                # Merge segments
+                current_group["end"] = segment["end"]
+                current_group["text"] += " " + segment["text"]
+                current_group["words"].extend(segment["words"])
+                current_group["duration"] = current_group["end"] - current_group["start"]
+            else:
+                # Start a new group
+                grouped_segments.append(current_group)
+                current_group = segment.copy()
+
+        grouped_segments.append(current_group)
+
+        # Clean up text
+        for segment in grouped_segments:
+            segment["text"] = re.sub(r"\s+", " ", segment["text"]).strip()
+            segment["text"] = re.sub(r"\s+([.,!?])", r"\1", segment["text"])
+
+        return grouped_segments
+
     @spaces.GPU
     def process_audio(self, audio_file, num_speakers=None, language=None,
                       translate=False, prompt=None, group_segments=True):
@@ -252,14 +290,19 @@ class WhisperTranscriber:
 
             audio_segments = self.cut_audio_segments(converted_audio_path, diarization_segments)
 
             # Step 4: Transcribe each segment
-            transcription_results = self.transcribe_audio_segments(
+            transcription_results, detected_language = self.transcribe_audio_segments(
                 audio_segments, language, translate, prompt
             )
 
-            # Step 5: Return in requested format
+            # Step 5: Group segments if requested
+            if group_segments:
+                transcription_results = self.group_segments_by_speaker(transcription_results)
+
+            # Step 6: Return in replicate.py format
             return {
-                "speaker_count": detected_num_speakers,
-                "transcription": transcription_results
+                "segments": transcription_results,
+                "language": detected_language,
+                "num_speakers": detected_num_speakers
             }
 
         except Exception as e:
@@ -280,19 +323,21 @@ def format_segments_for_display(result):
 
     if "error" in result:
        return f"❌ Error: {result['error']}"
 
-    speaker_count = result.get("speaker_count", 1)
-    transcription = result.get("transcription", [])
+    segments = result.get("segments", [])
+    language = result.get("language", "unknown")
+    num_speakers = result.get("num_speakers", 1)
 
     output = f"🎯 **Detection Results:**\n"
-    output += f"- Speakers: {speaker_count}\n"
-    output += f"- Segments: {len(transcription)}\n\n"
+    output += f"- Language: {language}\n"
+    output += f"- Speakers: {num_speakers}\n"
+    output += f"- Segments: {len(segments)}\n\n"
 
     output += "📝 **Transcription:**\n\n"
 
-    for i, segment in enumerate(transcription, 1):
-        start_time = str(datetime.timedelta(seconds=int(segment["start_time"])))
-        end_time = str(datetime.timedelta(seconds=int(segment["end_time"])))
-        speaker = segment.get("speaker_label", "SPEAKER_00")
+    for i, segment in enumerate(segments, 1):
+        start_time = str(datetime.timedelta(seconds=int(segment["start"])))
+        end_time = str(datetime.timedelta(seconds=int(segment["end"])))
+        speaker = segment.get("speaker", "SPEAKER_00")
         text = segment["text"]
 
         output += f"**{speaker}** ({start_time} → {end_time})\n"
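For reference, a minimal sketch of the faster-whisper call pattern this commit switches to. It is not part of the commit: the audio path "sample.wav" and the CPU/int8 settings are illustrative assumptions for local testing, while the Space itself loads the model on CUDA with float16 as shown above.

    # Sketch only: transcribe() returns a lazy generator of segments plus a
    # TranscriptionInfo; decoding runs as the generator is consumed.
    from faster_whisper import WhisperModel

    # CPU/int8 is an assumed local-testing configuration; app.py uses cuda/float16
    model = WhisperModel("large-v3-turbo", device="cpu", compute_type="int8")

    segments, info = model.transcribe("sample.wav", beam_size=5, word_timestamps=True)
    print(f"Detected language: {info.language} (p={info.language_probability:.2f})")

    for seg in segments:  # iteration triggers the actual transcription
        print(f"[{seg.start:.2f}s -> {seg.end:.2f}s] {seg.text}")
        for word in seg.words or []:  # per-word timings, as app.py consumes them
            print(f"  {word.word!r} {word.start:.2f}-{word.end:.2f} p={word.probability:.2f}")

Note that materializing the generator with list(segments), as app.py does, forces the full decode for each diarized segment before the results are merged back onto the global timeline.
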
requirements.txt CHANGED
@@ -1,11 +1,14 @@
 # 1. Do NOT pin torch/torchaudio here – keep the CUDA builds that come with the image
 torch==2.4.0
 transformers==4.48.0
-# pre-built wheel that matches torch+cu126
-https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.0.8/flash_attn-2.7.4.post1+cu126torch2.4-cp310-cp310-linux_x86_64.whl
+# Removed flash-attention since faster-whisper handles this internally
+# https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.0.8/flash_attn-2.7.4.post1+cu126torch2.4-cp310-cp310-linux_x86_64.whl
 pydantic==2.10.6
 
-# 2. Extra libs your app really needs
+# 2. Main whisper model
+faster-whisper>=1.0.0
+
+# 3. Extra libs your app really needs
 gradio==5.0.1
 spaces>=0.19.0
 pyannote.audio>=3.1.0
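
A quick sanity check for the new dependency (a sketch; the try/except CPU fallback is an assumption for machines without CUDA, while the Space itself takes the float16/CUDA path from app.py):

    # pip install "faster-whisper>=1.0.0"
    from faster_whisper import WhisperModel

    try:
        # Same settings as app.py: GPU with half-precision weights
        model = WhisperModel("large-v3-turbo", device="cuda", compute_type="float16")
    except Exception:
        # Assumed fallback for CPU-only environments: int8 quantization
        model = WhisperModel("large-v3-turbo", device="cpu", compute_type="int8")
    print("faster-whisper model ready")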